hssling committed on
Commit
e83f1b8
·
1 Parent(s): 2567232

Expose LLM tuning params (temp/tokens) via Gradio API

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -27,7 +27,7 @@ if ADAPTER_ID:
27
  model.load_adapter(ADAPTER_ID)
28
 
29
  # 2. Main API Function called by our Next App
30
- def diagnose_api(history: str, examination: str, image: Image.Image = None, audio_path: str = None):
31
  try:
32
  if image is None:
33
  # Fallback if no image is passed
@@ -58,7 +58,7 @@ def diagnose_api(history: str, examination: str, image: Image.Image = None, audi
58
  ).to(device)
59
 
60
  with torch.no_grad():
61
- generated_ids = model.generate(**inputs, max_new_tokens=1024, temperature=0.2)
62
 
63
  generated_ids_trimmed = [
64
  out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
@@ -80,7 +80,9 @@ demo = gr.Interface(
80
  gr.Textbox(lines=5, label="Patient History (String)", placeholder="Age, symptoms, past medical history..."),
81
  gr.Textbox(lines=5, label="Examination Findings (String)", placeholder="Vitals, systemic exam..."),
82
  gr.Image(type="pil", label="Diagnostic Scan / Image"),
83
- gr.Audio(type="filepath", label="Optional Dictation Audio", visible=False) # Qwen-VL does not naturally support audio, handling externally or ignoring
 
 
84
  ],
85
  outputs=gr.Markdown(label="Clinical Report Output"),
86
  title="Multi-Modal Diagnostic Co-Pilot API (Trained via Kaggle)",
 
27
  model.load_adapter(ADAPTER_ID)
28
 
29
  # 2. Main API Function called by our Next App
30
+ def diagnose_api(history: str, examination: str, image: Image.Image = None, audio_path: str = None, temp: float = 0.2, max_tokens: int = 1500):
31
  try:
32
  if image is None:
33
  # Fallback if no image is passed
 
58
  ).to(device)
59
 
60
  with torch.no_grad():
61
+ generated_ids = model.generate(**inputs, max_new_tokens=int(max_tokens), temperature=float(temp), top_p=0.9, do_sample=True)
62
 
63
  generated_ids_trimmed = [
64
  out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
 
80
  gr.Textbox(lines=5, label="Patient History (String)", placeholder="Age, symptoms, past medical history..."),
81
  gr.Textbox(lines=5, label="Examination Findings (String)", placeholder="Vitals, systemic exam..."),
82
  gr.Image(type="pil", label="Diagnostic Scan / Image"),
83
+ gr.Audio(type="filepath", label="Optional Dictation Audio", visible=False),
84
+ gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, label="Temperature (Creativity)"),
85
+ gr.Slider(minimum=256, maximum=4096, value=1500, step=256, label="Max Output Tokens")
86
  ],
87
  outputs=gr.Markdown(label="Clinical Report Output"),
88
  title="Multi-Modal Diagnostic Co-Pilot API (Trained via Kaggle)",