Spaces:
Sleeping
Sleeping
Expose LLM tuning params (temp/tokens) via Gradio API
Browse files
app.py
CHANGED
|
@@ -27,7 +27,7 @@ if ADAPTER_ID:
|
|
| 27 |
model.load_adapter(ADAPTER_ID)
|
| 28 |
|
| 29 |
# 2. Main API Function called by our Next App
|
| 30 |
-
def diagnose_api(history: str, examination: str, image: Image.Image = None, audio_path: str = None):
|
| 31 |
try:
|
| 32 |
if image is None:
|
| 33 |
# Fallback if no image is passed
|
|
@@ -58,7 +58,7 @@ def diagnose_api(history: str, examination: str, image: Image.Image = None, audi
|
|
| 58 |
).to(device)
|
| 59 |
|
| 60 |
with torch.no_grad():
|
| 61 |
-
generated_ids = model.generate(**inputs, max_new_tokens=
|
| 62 |
|
| 63 |
generated_ids_trimmed = [
|
| 64 |
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
|
@@ -80,7 +80,9 @@ demo = gr.Interface(
|
|
| 80 |
gr.Textbox(lines=5, label="Patient History (String)", placeholder="Age, symptoms, past medical history..."),
|
| 81 |
gr.Textbox(lines=5, label="Examination Findings (String)", placeholder="Vitals, systemic exam..."),
|
| 82 |
gr.Image(type="pil", label="Diagnostic Scan / Image"),
|
| 83 |
-
gr.Audio(type="filepath", label="Optional Dictation Audio", visible=False)
|
|
|
|
|
|
|
| 84 |
],
|
| 85 |
outputs=gr.Markdown(label="Clinical Report Output"),
|
| 86 |
title="Multi-Modal Diagnostic Co-Pilot API (Trained via Kaggle)",
|
|
|
|
| 27 |
model.load_adapter(ADAPTER_ID)
|
| 28 |
|
| 29 |
# 2. Main API Function called by our Next App
|
| 30 |
+
def diagnose_api(history: str, examination: str, image: Image.Image = None, audio_path: str = None, temp: float = 0.2, max_tokens: int = 1500):
|
| 31 |
try:
|
| 32 |
if image is None:
|
| 33 |
# Fallback if no image is passed
|
|
|
|
| 58 |
).to(device)
|
| 59 |
|
| 60 |
with torch.no_grad():
|
| 61 |
+
generated_ids = model.generate(**inputs, max_new_tokens=int(max_tokens), temperature=float(temp), top_p=0.9, do_sample=True)
|
| 62 |
|
| 63 |
generated_ids_trimmed = [
|
| 64 |
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
|
|
|
| 80 |
gr.Textbox(lines=5, label="Patient History (String)", placeholder="Age, symptoms, past medical history..."),
|
| 81 |
gr.Textbox(lines=5, label="Examination Findings (String)", placeholder="Vitals, systemic exam..."),
|
| 82 |
gr.Image(type="pil", label="Diagnostic Scan / Image"),
|
| 83 |
+
gr.Audio(type="filepath", label="Optional Dictation Audio", visible=False),
|
| 84 |
+
gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, label="Temperature (Creativity)"),
|
| 85 |
+
gr.Slider(minimum=256, maximum=4096, value=1500, step=256, label="Max Output Tokens")
|
| 86 |
],
|
| 87 |
outputs=gr.Markdown(label="Clinical Report Output"),
|
| 88 |
title="Multi-Modal Diagnostic Co-Pilot API (Trained via Kaggle)",
|