mharkey committed on
Commit
25b3bcb
Β·
verified Β·
1 Parent(s): 4b3ecc2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -28
app.py CHANGED
@@ -1,38 +1,48 @@
1
  import gradio as gr
2
  from transformers import pipeline
3
  from datasets import load_dataset
 
4
 
5
- # βœ… Load the smaller model (fits in 16GB)
6
- model_name = "Qwen/Qwen2.5-3B"
7
- pipe = pipeline("text-generation", model=model_name, device=0)
8
-
9
- # βœ… Load the GTA dataset (correct split is 'train')
10
  gta = load_dataset("Jize1/GTA", split="train")
11
 
12
# Inference entry point: input is either free text or a GTA sample index.
def run_model(input_text, use_gta_idx):
    """Generate a model response for free-form text or a GTA dataset index.

    When use_gta_idx is True, input_text is parsed as an integer index into
    the GTA dataset and the sample's first dialog turn becomes the prompt;
    otherwise input_text itself (stripped) is the prompt. Returns a
    Markdown-formatted string in both the success and the error case.
    """
    if not use_gta_idx:
        question = input_text.strip()
    else:
        try:
            sample = gta[int(input_text)]
            question = sample["dialogs"][0]["content"]
        except Exception as e:
            # Bad index (non-numeric or out of range): report instead of raising.
            return f"❌ Invalid index (0–{len(gta)-1}): {e}"

    generations = pipe(question, max_new_tokens=256, do_sample=True)
    answer = generations[0]["generated_text"]
    return f"**Question:** {question}\n\n**Response:**\n{answer}"
25
-
26
# Gradio front-end: one textbox that doubles as a free-form prompt or a GTA index.
with gr.Blocks() as demo:
    gr.Markdown("# πŸ€– GTA Reasoning Demo (Qwen2.5‑3B + GTA Dataset)")
    gr.Markdown("Enter a custom question or choose a sample from the GTA dataset (index 0–228).")

    with gr.Row():
        input_text = gr.Textbox(label="Your input or GTA index")
        use_index = gr.Checkbox(label="Treat input as GTA index", value=False)

    run_btn = gr.Button("Generate")
    output_md = gr.Markdown()

    # The button click drives generation; the result renders as Markdown.
    run_btn.click(fn=run_model, inputs=[input_text, use_index], outputs=[output_md])

demo.launch()
 
1
  import gradio as gr
2
  from transformers import pipeline
3
  from datasets import load_dataset
4
+ import torch
5
 
6
# Load the GTA benchmark dataset (it ships a single 'train' split).
gta = load_dataset("Jize1/GTA", split="train")
8
 
9
def evaluate_model(model_name, num_samples):
    """Evaluate a HF text-generation model on the GTA benchmark (answer accuracy).

    For each of the first num_samples GTA queries, generates a greedy
    completion and counts it correct when any whitelisted ground-truth answer
    appears as a substring of the lower-cased output. Returns a Markdown
    report; on any failure returns an error string instead of raising.
    """
    try:
        # CPU fallback keeps the app usable on machines without CUDA.
        device = 0 if torch.cuda.is_available() else -1
        pipe = pipeline("text-generation", model=model_name, device=device)

        # Slider values can arrive as floats; range() needs an int.
        # Guard against 0 samples, which would otherwise divide by zero below.
        limit = min(int(num_samples), len(gta))
        if limit <= 0:
            return "❌ No samples to evaluate (dataset empty or sample count < 1)."

        correct = 0
        log = []

        for i in range(limit):
            sample = gta[i]
            query = sample["dialogs"][0]["content"]
            gt_answers = sample["gt_answer"].get("whitelist", [])
            # Flatten the nested whitelist groups into one set of normalized strings.
            # NOTE(review): an empty whitelist makes the sample unmatchable — it
            # always counts as incorrect; confirm that is the intended metric.
            flat_gt = {ans.strip().lower() for group in gt_answers for ans in group if isinstance(ans, str)}

            # Greedy decoding (do_sample=False) for a reproducible evaluation.
            out = pipe(query, max_new_tokens=128, do_sample=False)[0]["generated_text"].strip().lower()

            # Match: exact substring match with any whitelist answer.
            matched = any(gt in out for gt in flat_gt)

            log.append(f"### Query {i}\n**Input**: {query}\n**Prediction**: {out}\n**GT**: {flat_gt}\n**βœ”οΈ Correct**: {matched}\n")
            correct += int(matched)

        acc = round((correct / limit) * 100, 2)
        summary = f"### πŸ” GTA Answer Accuracy (AnsAcc) for `{model_name}`: **{acc}%** on {limit} queries\n\n---\n"
        return summary + "\n".join(log)

    except Exception as e:
        # UI boundary: surface the error as text rather than crashing the app.
        return f"❌ Evaluation failed: {e}"
38
+
39
# Gradio front-end for the benchmark run.
with gr.Blocks() as demo:
    gr.Markdown("# πŸ§ͺ Real GTA Evaluation (Answer Accuracy Only)")
    model_input = gr.Textbox(label="Enter Hugging Face Model Name", value="Qwen/Qwen2.5-3B")
    sample_count = gr.Slider(label="Number of GTA samples to evaluate", minimum=1, maximum=229, value=10, step=1)
    run_btn = gr.Button("Run Evaluation")
    output_md = gr.Markdown()

    # Explicit trigger: wiring .change on the textbox would re-download and
    # re-load a full model on every keystroke (and on partial model names).
    run_btn.click(fn=evaluate_model, inputs=[model_input, sample_count], outputs=output_md)

demo.launch()