mharkey commited on
Commit
4b3ecc2
·
verified ·
1 Parent(s): 7aab66d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -2,18 +2,19 @@ import gradio as gr
2
  from transformers import pipeline
3
  from datasets import load_dataset
4
 
5
- # Load smaller model to fit memory
6
  model_name = "Qwen/Qwen2.5-3B"
7
  pipe = pipeline("text-generation", model=model_name, device=0)
8
 
9
- # Load the GTA dataset from Hugging Face
10
- gta = load_dataset("Jize1/GTA", split="test")
11
 
 
12
  def run_model(input_text, use_gta_idx):
13
  if use_gta_idx:
14
  try:
15
  idx = int(input_text)
16
- question = gta[idx]["dialogs"][0]["content"].strip()
17
  except Exception as e:
18
  return f"❌ Invalid index (0–{len(gta)-1}): {e}"
19
  else:
@@ -22,9 +23,10 @@ def run_model(input_text, use_gta_idx):
22
  output = pipe(question, max_new_tokens=256, do_sample=True)
23
  return f"**Question:** {question}\n\n**Response:**\n{output[0]['generated_text']}"
24
 
 
25
  with gr.Blocks() as demo:
26
  gr.Markdown("# 🤖 GTA Reasoning Demo (Qwen2.5‑3B + GTA Dataset)")
27
- gr.Markdown("Enter your own question or a GTA index (0–228).")
28
  with gr.Row():
29
  input_text = gr.Textbox(label="Your input or GTA index")
30
  use_index = gr.Checkbox(label="Treat input as GTA index", value=False)
 
2
  from transformers import pipeline
3
  from datasets import load_dataset
4
 
5
+ # ✅ Load the smaller model (fits in 16GB)
6
  model_name = "Qwen/Qwen2.5-3B"
7
  pipe = pipeline("text-generation", model=model_name, device=0)
8
 
9
+ # ✅ Load the GTA dataset (correct split is 'train')
10
+ gta = load_dataset("Jize1/GTA", split="train")
11
 
12
+ # ✅ Inference function
13
  def run_model(input_text, use_gta_idx):
14
  if use_gta_idx:
15
  try:
16
  idx = int(input_text)
17
+ question = gta[idx]["dialogs"][0]["content"]
18
  except Exception as e:
19
  return f"❌ Invalid index (0–{len(gta)-1}): {e}"
20
  else:
 
23
  output = pipe(question, max_new_tokens=256, do_sample=True)
24
  return f"**Question:** {question}\n\n**Response:**\n{output[0]['generated_text']}"
25
 
26
+ # ✅ Gradio UI
27
  with gr.Blocks() as demo:
28
  gr.Markdown("# 🤖 GTA Reasoning Demo (Qwen2.5‑3B + GTA Dataset)")
29
+ gr.Markdown("Enter a custom question or choose a sample from the GTA dataset (index 0–228).")
30
  with gr.Row():
31
  input_text = gr.Textbox(label="Your input or GTA index")
32
  use_index = gr.Checkbox(label="Treat input as GTA index", value=False)