Mina Parham committed on
Commit Β·
38366a7
1
Parent(s): 55aae7a
Initial commit π
Browse files- app.py +136 -63
- requirements.txt +54 -1
app.py
CHANGED
|
@@ -1,64 +1,137 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
"""
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import requests
|
| 3 |
+
import asyncio
|
| 4 |
+
import aiohttp
|
| 5 |
+
|
| 6 |
+
import os

# Models setup: Hugging Face Inference API endpoints for each candidate model.
models = {
    "Mistral-7B-Instruct": "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2",
    "DeepSeek-7B-Instruct": "https://api-inference.huggingface.co/models/deepseek-ai/deepseek-llm-7b-instruct",
    "Qwen-7B-Chat": "https://api-inference.huggingface.co/models/Qwen/Qwen-7B-Chat"
}

# Judge model (Mixtral-8x7B) used to pick the best candidate answer.
judge_model_url = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"

# Read the Hugging Face API token from the environment instead of hard-coding
# a secret in source control (Spaces and most deployments expose secrets as
# environment variables). The original placeholder is kept as the fallback so
# behavior is unchanged when the variable is unset.
API_TOKEN = os.environ.get("HF_API_TOKEN", "YOUR_HUGGINGFACE_API_TOKEN")
HEADERS = {"Authorization": f"Bearer {API_TOKEN}"}
|
| 19 |
+
|
| 20 |
+
# Async helper: POST one question to one model endpoint.
async def query_model(session, model_name, question):
    """Query a single Hugging Face inference endpoint.

    Args:
        session: An open aiohttp.ClientSession.
        model_name: Key into the module-level ``models`` URL map.
        question: The user's question, sent as the generation prompt.

    Returns:
        ``(model_name, answer_text)``; on any failure the answer text is an
        ``"Error: ..."`` string rather than an exception.
    """
    request_body = {"inputs": question, "parameters": {"max_new_tokens": 300}}
    try:
        async with session.post(models[model_name], headers=HEADERS, json=request_body, timeout=60) as response:
            data = await response.json()
            # The API returns either [{"generated_text": ...}] or a bare dict.
            if isinstance(data, list) and len(data) > 0:
                answer = data[0]["generated_text"]
            elif isinstance(data, dict) and "generated_text" in data:
                answer = data["generated_text"]
            else:
                # Unexpected payload (e.g. an error object) — surface it verbatim.
                answer = str(data)
            return model_name, answer
    except Exception as e:
        # Network errors, timeouts, and malformed payloads all degrade to a
        # readable error string so one bad model never breaks the whole fan-out.
        return model_name, f"Error: {str(e)}"
|
| 34 |
+
|
| 35 |
+
# Fan the question out to every configured model concurrently.
async def gather_model_answers(question):
    """Query all models in parallel and return ``{model_name: answer_text}``."""
    async with aiohttp.ClientSession() as session:
        pending = [query_model(session, name, question) for name in models]
        pairs = await asyncio.gather(*pending)
        # query_model yields (name, text) tuples, so dict() re-keys by model.
        return dict(pairs)
|
| 41 |
+
|
| 42 |
+
# Ask the judge model to choose the best of the three candidate answers.
def judge_best_answer(question, answers):
    """Ask the Mixtral judge which candidate answer is best.

    Args:
        question: The user's original question.
        answers: Mapping of model name -> answer text; must contain the three
            keys referenced in the prompt below.

    Returns:
        dict with ``"best_answer"`` ("Answer 1"/"Answer 2"/"Answer 3", or
        "Unknown" on failure) and ``"reason"`` explaining the choice.
    """
    import json
    import re

    judge_prompt = f"""
You are a wise AI Judge. A user asked the following question:

Question:
{question}

Here are the answers provided by different models:

Answer 1 (Mistral-7B-Instruct):
{answers['Mistral-7B-Instruct']}

Answer 2 (DeepSeek-7B-Instruct):
{answers['DeepSeek-7B-Instruct']}

Answer 3 (Qwen-7B-Chat):
{answers['Qwen-7B-Chat']}

Please carefully read all three answers. Your job:
- Pick the best answer (Answer 1, Answer 2, or Answer 3).
- Explain briefly why you chose that answer.

Respond in this JSON format:
{{"best_answer": "Answer X", "reason": "Your reasoning here"}}
""".strip()

    payload = {"inputs": judge_prompt, "parameters": {"max_new_tokens": 300}}
    try:
        # Bound the request so a stuck inference endpoint cannot hang the app.
        response = requests.post(judge_model_url, headers=HEADERS, json=payload, timeout=60)
    except requests.RequestException as e:
        return {"best_answer": "Unknown", "reason": f"Judge request failed: {e}"}

    if response.status_code != 200:
        return {"best_answer": "Unknown", "reason": f"Judge API error: {response.status_code}"}

    result = response.json()
    # Extract the generated text itself. The previous implementation ran the
    # JSON regex over str(result) — the Python repr of the payload — which
    # uses single quotes and therefore (almost) never parsed as valid JSON.
    if isinstance(result, list) and result and isinstance(result[0], dict):
        generated = result[0].get("generated_text", "")
    elif isinstance(result, dict):
        generated = result.get("generated_text", str(result))
    else:
        generated = str(result)

    # The endpoint may echo the prompt (which itself contains a JSON example),
    # so scan all brace-delimited candidates and prefer the last parsable one.
    for candidate in reversed(re.findall(r"\{.*?\}", generated, re.DOTALL)):
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            continue
    return {"best_answer": "Unknown", "reason": "No JSON found in judge output."}
|
| 91 |
+
|
| 92 |
+
# Final app logic: run all models, consult the judge, unpack for the UI.
def multi_model_qa(question):
    """Return (mistral, deepseek, qwen, best_answer, judge_reason) for Gradio."""
    answers = asyncio.run(gather_model_answers(question))
    judge_decision = judge_best_answer(question, answers)

    # The judge replies "Answer 1/2/3"; map the digit back to the model key.
    # Checked in insertion order, matching the original 1 -> 2 -> 3 priority.
    verdict = judge_decision.get("best_answer", "")
    digit_to_model = {
        "1": "Mistral-7B-Instruct",
        "2": "DeepSeek-7B-Instruct",
        "3": "Qwen-7B-Chat",
    }
    best_answer_text = "Could not determine best answer."
    for digit, model_name in digit_to_model.items():
        if digit in verdict:
            best_answer_text = answers[model_name]
            break

    return (
        answers["Mistral-7B-Instruct"],
        answers["DeepSeek-7B-Instruct"],
        answers["Qwen-7B-Chat"],
        best_answer_text,
        judge_decision.get("reason", "No reasoning provided.")
    )
|
| 116 |
+
|
| 117 |
+
# Gradio UI: one input box, one button, and five read-only output boxes
# (three candidate answers + the judge's pick and reasoning).
with gr.Blocks() as demo:
    # Page header and short usage blurb.
    gr.Markdown("# π§ Multi-Model Answer Aggregator")
    gr.Markdown("Ask any question. The system queries multiple models and the AI Judge selects the best answer.")

    # User input and trigger.
    question_input = gr.Textbox(label="Enter your question", placeholder="Ask me anything...", lines=2)
    submit_btn = gr.Button("Get Best Answer")

    # One output per candidate model, declared in the same order that
    # multi_model_qa returns its tuple.
    mistral_output = gr.Textbox(label="Mistral-7B-Instruct Answer")
    deepseek_output = gr.Textbox(label="DeepSeek-7B-Instruct Answer")
    qwen_output = gr.Textbox(label="Qwen-7B-Chat Answer")
    best_answer_output = gr.Textbox(label="π Best Answer Selected")
    judge_reasoning_output = gr.Textbox(label="βοΈ Judge's Reasoning")

    # Wire the button to the aggregation pipeline; outputs map positionally
    # onto multi_model_qa's 5-tuple return value.
    submit_btn.click(
        multi_model_qa,
        inputs=[question_input],
        outputs=[mistral_output, deepseek_output, qwen_output, best_answer_output, judge_reasoning_output]
    )

# Start the Gradio server (blocking call).
demo.launch()
|
requirements.txt
CHANGED
|
@@ -1 +1,54 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
aiofiles==24.1.0
aiohttp==3.11.18
|
| 2 |
+
annotated-types==0.7.0
|
| 3 |
+
anyio==4.9.0
|
| 4 |
+
audioop-lts==0.2.1
|
| 5 |
+
certifi==2025.4.26
|
| 6 |
+
charset-normalizer==3.4.1
|
| 7 |
+
click==8.1.8
|
| 8 |
+
fastapi==0.115.12
|
| 9 |
+
ffmpy==0.5.0
|
| 10 |
+
filelock==3.18.0
|
| 11 |
+
fsspec==2025.3.2
|
| 12 |
+
gradio==5.27.0
|
| 13 |
+
gradio_client==1.9.0
|
| 14 |
+
groovy==0.1.2
|
| 15 |
+
h11==0.16.0
|
| 16 |
+
httpcore==1.0.9
|
| 17 |
+
httpx==0.28.1
|
| 18 |
+
huggingface-hub==0.30.2
|
| 19 |
+
idna==3.10
|
| 20 |
+
Jinja2==3.1.6
|
| 21 |
+
markdown-it-py==3.0.0
|
| 22 |
+
MarkupSafe==3.0.2
|
| 23 |
+
mdurl==0.1.2
|
| 24 |
+
numpy==2.2.5
|
| 25 |
+
orjson==3.10.16
|
| 26 |
+
packaging==25.0
|
| 27 |
+
pandas==2.2.3
|
| 28 |
+
pillow==11.2.1
|
| 29 |
+
pydantic==2.11.3
|
| 30 |
+
pydantic_core==2.33.1
|
| 31 |
+
pydub==0.25.1
|
| 32 |
+
Pygments==2.19.1
|
| 33 |
+
python-dateutil==2.9.0.post0
|
| 34 |
+
python-multipart==0.0.20
|
| 35 |
+
pytz==2025.2
|
| 36 |
+
PyYAML==6.0.2
|
| 37 |
+
requests==2.32.3
|
| 38 |
+
rich==14.0.0
|
| 39 |
+
ruff==0.11.7
|
| 40 |
+
safehttpx==0.1.6
|
| 41 |
+
semantic-version==2.10.0
|
| 42 |
+
shellingham==1.5.4
|
| 43 |
+
six==1.17.0
|
| 44 |
+
sniffio==1.3.1
|
| 45 |
+
starlette==0.46.2
|
| 46 |
+
tomlkit==0.13.2
|
| 47 |
+
tqdm==4.67.1
|
| 48 |
+
typer==0.15.2
|
| 49 |
+
typing-inspection==0.4.0
|
| 50 |
+
typing_extensions==4.13.2
|
| 51 |
+
tzdata==2025.2
|
| 52 |
+
urllib3==2.4.0
|
| 53 |
+
uvicorn==0.34.2
|
| 54 |
+
websockets==15.0.1
|