Update app.py
Browse files
app.py
CHANGED
|
@@ -1,38 +1,48 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from transformers import pipeline
|
| 3 |
from datasets import load_dataset
|
|
|
|
| 4 |
|
| 5 |
-
#
|
| 6 |
-
model_name = "Qwen/Qwen2.5-3B"
|
| 7 |
-
pipe = pipeline("text-generation", model=model_name, device=0)
|
| 8 |
-
|
| 9 |
-
# β
Load the GTA dataset (correct split is 'train')
|
| 10 |
gta = load_dataset("Jize1/GTA", split="train")
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
with gr.Blocks() as demo:
|
| 28 |
-
gr.Markdown("#
|
| 29 |
-
gr.
|
| 30 |
-
|
| 31 |
-
input_text = gr.Textbox(label="Your input or GTA index")
|
| 32 |
-
use_index = gr.Checkbox(label="Treat input as GTA index", value=False)
|
| 33 |
-
run_btn = gr.Button("Generate")
|
| 34 |
output_md = gr.Markdown()
|
| 35 |
-
|
| 36 |
-
|
|
|
|
| 37 |
|
| 38 |
demo.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from transformers import pipeline
|
| 3 |
from datasets import load_dataset
|
| 4 |
+
import torch
|
| 5 |
|
| 6 |
+
# Load GTA dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
# Load the GTA benchmark once at startup; per the earlier revision's note,
# "train" is the correct (and only) split for Jize1/GTA.
gta = load_dataset("Jize1/GTA", split="train")
|
| 8 |
|
| 9 |
+
def evaluate_model(model_name, num_samples):
    """Evaluate a Hugging Face text-generation model on the GTA benchmark.

    Scores "answer accuracy" (AnsAcc): a prediction is correct when any
    ground-truth whitelist string appears (case-insensitively) as a
    substring of the model output.

    Args:
        model_name: Hugging Face model id to load into a text-generation
            pipeline (e.g. "Qwen/Qwen2.5-3B").
        num_samples: How many GTA queries to evaluate; clamped to the
            dataset size. May arrive as a float from a Gradio slider.

    Returns:
        A Markdown string with an accuracy summary followed by a per-query
        log, or a "❌ Evaluation failed" message on any error.
    """
    try:
        pipe = pipeline(
            "text-generation",
            model=model_name,
            device=0 if torch.cuda.is_available() else -1,
        )

        # Gradio sliders can deliver floats even with step=1; range() needs int.
        limit = min(int(num_samples), len(gta))

        correct = 0
        total = 0
        log = []

        for i in range(limit):
            query = gta[i]["dialogs"][0]["content"]
            gt_answers = gta[i]["gt_answer"].get("whitelist", [])
            # Flatten the nested whitelist into a set of normalized strings.
            flat_gt = {
                ans.strip().lower()
                for group in gt_answers
                for ans in group
                if isinstance(ans, str)
            }

            # return_full_text=False keeps the prompt out of the output;
            # otherwise a ground-truth string already present in the query
            # would count as a false-positive match.
            out = pipe(
                query, max_new_tokens=128, do_sample=False, return_full_text=False
            )[0]["generated_text"].strip().lower()

            # Match: case-insensitive substring match with any whitelist answer.
            matched = any(gt in out for gt in flat_gt)

            log.append(
                f"### Query {i}\n**Input**: {query}\n**Prediction**: {out}\n"
                f"**GT**: {flat_gt}\n**✔️ Correct**: {matched}\n"
            )
            correct += int(matched)
            total += 1

        # Guard the division explicitly instead of letting ZeroDivisionError
        # fall into the generic except below with a confusing message.
        if total == 0:
            return "❌ Evaluation failed: no samples were evaluated (num_samples must be >= 1)."

        acc = round((correct / total) * 100, 2)
        summary = (
            f"### 📊 GTA Answer Accuracy (AnsAcc) for `{model_name}`: "
            f"**{acc}%** on {total} queries\n\n---\n"
        )
        return summary + "\n".join(log)

    except Exception as e:
        # Surface the failure in the UI rather than crashing the Space.
        return f"❌ Evaluation failed: {e}"
|
| 38 |
+
|
| 39 |
# Build the UI: a model-name box and a sample-count slider drive the
# evaluation, whose Markdown report lands in the results panel.
with gr.Blocks() as demo:
    gr.Markdown("# 🧪 Real GTA Evaluation (Answer Accuracy Only)")

    model_box = gr.Textbox(
        label="Enter Hugging Face Model Name",
        value="Qwen/Qwen2.5-3B",
    )
    n_samples = gr.Slider(
        label="Number of GTA samples to evaluate",
        minimum=1,
        maximum=229,
        value=10,
        step=1,
    )
    results = gr.Markdown()

    # Re-run the evaluation whenever either control changes.
    for widget in (model_box, n_samples):
        widget.change(fn=evaluate_model, inputs=[model_box, n_samples], outputs=results)

demo.launch()
|