Update app.py
Browse files
app.py
CHANGED
|
@@ -3,46 +3,74 @@ from transformers import pipeline
|
|
| 3 |
from datasets import load_dataset
|
| 4 |
import torch
|
| 5 |
|
| 6 |
-
# Load GTA dataset
|
| 7 |
gta = load_dataset("Jize1/GTA", split="train")
|
| 8 |
|
| 9 |
def evaluate_model(model_name, num_samples):
|
| 10 |
try:
|
| 11 |
pipe = pipeline("text-generation", model=model_name, device=0 if torch.cuda.is_available() else -1)
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
log = []
|
| 16 |
|
| 17 |
for i in range(min(num_samples, len(gta))):
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
|
| 22 |
-
|
| 23 |
-
out = pipe(query, max_new_tokens=128, do_sample=False)[0]["generated_text"].strip().lower()
|
| 24 |
|
| 25 |
-
#
|
| 26 |
-
|
|
|
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
except Exception as e:
|
| 37 |
-
return f"❌
|
| 38 |
|
|
|
|
| 39 |
with gr.Blocks() as demo:
|
| 40 |
-
gr.Markdown("#
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
| 43 |
output_md = gr.Markdown()
|
| 44 |
|
| 45 |
-
|
| 46 |
-
sample_count.change(fn=evaluate_model, inputs=[model_input, sample_count], outputs=output_md)
|
| 47 |
|
| 48 |
demo.launch()
|
|
|
|
| 3 |
from datasets import load_dataset
|
| 4 |
import torch
|
| 5 |
|
|
|
|
| 6 |
gta = load_dataset("Jize1/GTA", split="train")
|
| 7 |
|
| 8 |
def evaluate_model(model_name, num_samples):
|
| 9 |
try:
|
| 10 |
pipe = pipeline("text-generation", model=model_name, device=0 if torch.cuda.is_available() else -1)
|
| 11 |
|
| 12 |
+
inst_correct, tool_correct, summ_correct, ans_correct = 0, 0, 0, 0
|
| 13 |
+
logs = []
|
|
|
|
| 14 |
|
| 15 |
for i in range(min(num_samples, len(gta))):
|
| 16 |
+
sample = gta[i]
|
| 17 |
+
query = sample["dialogs"][0]["content"]
|
| 18 |
+
tools_used = [step["function"]["name"].lower() for step in sample["dialogs"] if "function" in step.get("function", {})]
|
| 19 |
|
| 20 |
+
prediction = pipe(query, max_new_tokens=256, do_sample=False)[0]["generated_text"].strip().lower()
|
|
|
|
| 21 |
|
| 22 |
+
# Instruction following: if answer is long enough and not hallucinated
|
| 23 |
+
inst_pass = len(prediction) > 10 and any(w in prediction for w in ["use", "calculate", "looks like", "means", "based on"])
|
| 24 |
+
inst_correct += inst_pass
|
| 25 |
|
| 26 |
+
# ToolAcc: if any known tool name is mentioned
|
| 27 |
+
tool_pass = any(tool in prediction for tool in tools_used)
|
| 28 |
+
tool_correct += tool_pass
|
| 29 |
|
| 30 |
+
# SummAcc: if answer includes concluding phrases or numbers (as proxy)
|
| 31 |
+
summ_pass = any(x in prediction for x in ["so", "therefore", "the answer is", "equals", "you will need", "hence"])
|
| 32 |
+
summ_correct += summ_pass
|
| 33 |
+
|
| 34 |
+
# AnsAcc: match whitelist phrase
|
| 35 |
+
gt_phrases = sample["gt_answer"].get("whitelist", [])
|
| 36 |
+
flat_gt = {s.strip().lower() for group in gt_phrases for s in group if isinstance(s, str)}
|
| 37 |
+
ans_pass = any(g in prediction for g in flat_gt)
|
| 38 |
+
ans_correct += ans_pass
|
| 39 |
+
|
| 40 |
+
logs.append(f"""
|
| 41 |
+
### Query {i}
|
| 42 |
+
**Input**: {query}
|
| 43 |
+
**Prediction**: {prediction}
|
| 44 |
+
**GT**: {flat_gt}
|
| 45 |
+
**Instruction✔️**: {inst_pass}
|
| 46 |
+
**Tool✔️**: {tool_pass}
|
| 47 |
+
**Summary✔️**: {summ_pass}
|
| 48 |
+
**Answer✔️**: {ans_pass}
|
| 49 |
+
---""")
|
| 50 |
+
|
| 51 |
+
total = min(num_samples, len(gta))
|
| 52 |
+
results = {
|
| 53 |
+
"InstAcc": round((inst_correct / total) * 100, 2),
|
| 54 |
+
"ToolAcc": round((tool_correct / total) * 100, 2),
|
| 55 |
+
"SummAcc": round((summ_correct / total) * 100, 2),
|
| 56 |
+
"AnsAcc": round((ans_correct / total) * 100, 2),
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
summary = "\n".join([f"**{k}**: {v}%" for k, v in results.items()])
|
| 60 |
+
return f"## 🔬 GTA Evaluation for `{model_name}` on {total} queries\n\n{summary}\n\n---\n" + "\n".join(logs)
|
| 61 |
|
| 62 |
except Exception as e:
|
| 63 |
+
return f"❌ Error: {e}"
|
| 64 |
|
| 65 |
+
# Gradio UI
|
| 66 |
with gr.Blocks() as demo:
|
| 67 |
+
gr.Markdown("# 🧠 GTA Tool Use Evaluation (Real Metrics, Real Queries)")
|
| 68 |
+
with gr.Row():
|
| 69 |
+
model_input = gr.Textbox(label="Model Name", value="Qwen/Qwen2.5-3B")
|
| 70 |
+
sample_slider = gr.Slider(label="Number of GTA samples", minimum=1, maximum=229, value=10, step=1)
|
| 71 |
+
run_btn = gr.Button("Run Evaluation")
|
| 72 |
output_md = gr.Markdown()
|
| 73 |
|
| 74 |
+
run_btn.click(fn=evaluate_model, inputs=[model_input, sample_slider], outputs=output_md)
|
|
|
|
| 75 |
|
| 76 |
demo.launch()
|