# HK2184's picture
# Update app.py
# dea1fe8 verified
import html
import os
import re
import time

import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Module-level setup: load the tokenizer and base model, apply the LoRA
# adapter, merge it into the base weights and switch to eval mode. All of this
# runs at import time, before the UI is built.

# Hub identifiers of the base causal LM and of the LoRA adapter applied to it.
BASE_MODEL = "Qwen/Qwen3-1.7B"
ADAPTER_PATH = "HK2184/medqa-qwen3-lora"
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
# Reuse the EOS token as the padding token and pad on the left.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
print("Loading model...")
# bfloat16 when torch reports an available accelerator, float32 otherwise.
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# NOTE(review): device_map="cpu" is passed unconditionally and nothing below
# moves the model to another device, so inference appears to run on the CPU
# even when torch.cuda.is_available() is true -- confirm this is intended.
base = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
torch_dtype=DTYPE,
device_map="cpu",
trust_remote_code=True,
low_cpu_mem_usage=False,
)
# Attach the adapter for inference only (is_trainable=False).
model = PeftModel.from_pretrained(
base,
ADAPTER_PATH,
is_trainable=False,
)
# Merge the adapter into the base model and drop the PEFT wrapper, then make
# sure the merged weights use DTYPE and the model is in eval mode.
model = model.merge_and_unload()
model = model.to(DTYPE)
model.eval()
print("Ready!")
# Human-readable device label.
# NOTE(review): not referenced anywhere in the visible part of this file.
DEVICE_INFO = f"{'GPU (ROCm)' if torch.cuda.is_available() else 'CPU'}"
# Process-wide counter of answered queries; generate_answer increments "total".
query_count = {"total": 0}
# Sample MCQs. Each row is [question, option A, option B, option C, option D].
EXAMPLES = [
    [
        "Which artery is occluded in inferior MI with ST elevation in leads II, III, aVF?",
        "Left anterior descending artery",
        "Right coronary artery",
        "Left circumflex artery",
        "Left main coronary artery",
    ],
    [
        "First-line treatment for hypertensive emergency?",
        "Oral amlodipine",
        "IV labetalol or IV nitroprusside",
        "Sublingual nifedipine",
        "IM hydralazine",
    ],
    [
        "Most common cause of community-acquired pneumonia?",
        "Klebsiella pneumoniae",
        "Streptococcus pneumoniae",
        "Haemophilus influenzae",
        "Mycoplasma pneumoniae",
    ],
    [
        "Drug of choice for absence seizures?",
        "Phenytoin",
        "Carbamazepine",
        "Ethosuximide",
        "Valproate",
    ],
    [
        "A patient with sickle cell disease presents with acute chest pain and hypoxia. What is this called?",
        "Pulmonary embolism",
        "Acute chest syndrome",
        "Pneumonia",
        "Pleuritis",
    ],
    [
        "Which vitamin deficiency causes Wernicke encephalopathy?",
        "Vitamin B12",
        "Vitamin B1 (Thiamine)",
        "Vitamin B6",
        "Vitamin C",
    ],
    [
        "What is the antidote for acetaminophen overdose?",
        "Naloxone",
        "Flumazenil",
        "N-acetylcysteine",
        "Atropine",
    ],
    [
        "A 60-year-old smoker presents with hemoptysis and weight loss. Most likely diagnosis?",
        "Tuberculosis",
        "Lung carcinoma",
        "Pulmonary embolism",
        "Bronchiectasis",
    ],
]

# Choices offered by the subject dropdown. "All Subjects" deliberately has no
# entry in SUBJECT_EXAMPLES.
SUBJECTS = [
    "All Subjects",
    "Cardiology",
    "Pharmacology",
    "Pulmonology",
    "Neurology",
    "Endocrinology",
    "Infectious Disease",
    "Emergency Medicine",
]

# Positions in EXAMPLES that belong to each subject (Endocrinology has none).
_EXAMPLE_POSITIONS_BY_SUBJECT = {
    "Cardiology": (0, 1),
    "Pharmacology": (3, 6),
    "Pulmonology": (2, 7),
    "Neurology": (3, 5),
    "Endocrinology": (),
    "Infectious Disease": (2,),
    "Emergency Medicine": (1, 4),
}

# Subject -> list of rows; the rows are the very same list objects as in EXAMPLES.
SUBJECT_EXAMPLES = {
    subject: [EXAMPLES[position] for position in positions]
    for subject, positions in _EXAMPLE_POSITIONS_BY_SUBJECT.items()
}

# In-memory log of answered queries, appended to by generate_answer and
# rendered by get_history_html.
history_store = []
# One option line: an optional "(", a letter A-D, then ")", "." or ":" and the
# option text, e.g. "A) foo", "(b) foo", "C. foo", "D: foo".
_OPTION_LINE_RE = re.compile(r"^\(?([A-Da-d])\s*[).:]\s*(.*)$")


def _build_options_prompt(question):
    """Return the instruction prompt asking the model for four options."""
    return (
        f"Generate exactly 4 multiple choice options for this medical question. "
        f"One must be correct, three must be plausible but wrong.\n"
        f"Question: {question}\n\n"
        f"Reply ONLY in this exact format, nothing else:\n"
        f"A) <option>\n"
        f"B) <option>\n"
        f"C) <option>\n"
        f"D) <option>"
    )


def _parse_options(text):
    """Extract the A-D option texts from *text*; missing letters map to ""."""
    found = dict.fromkeys("ABCD", "")
    for raw_line in text.splitlines():
        # Tolerate list bullets / markdown emphasis in front of the letter.
        candidate = raw_line.strip().lstrip("-*• ").strip()
        match = _OPTION_LINE_RE.match(candidate)
        if match is None:
            continue
        letter = match.group(1).upper()
        option_text = match.group(2).strip().strip("*").strip()
        # Keep the first non-empty text per letter: generation is capped by
        # token count only, so the model may ramble on and emit a second,
        # unrelated set of options that must not overwrite the first.
        if option_text and not found[letter]:
            found[letter] = option_text
    return found["A"], found["B"], found["C"], found["D"]


def autogenerate_options(question):
    """Ask the model to propose four answer options for *question*.

    Args:
        question: Free-text question. ``None``, empty and whitespace-only
            values skip generation entirely.

    Returns:
        A 4-tuple of strings ``(A, B, C, D)``. Any option the model output
        does not contain in a recognisable form is returned as ``""``.
    """
    if not question or not question.strip():
        return "", "", "", ""

    prompt = _build_options_prompt(question)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=120,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
            repetition_penalty=1.2,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Drop the prompt tokens so only the newly generated continuation is decoded.
    generated = out[0][inputs["input_ids"].shape[-1]:]
    reply = tokenizer.decode(generated, skip_special_tokens=True).strip()
    return _parse_options(reply)
# Patterns tried in order to find the option letter the model chose.
_CHOICE_PATTERNS = (
    # Output that starts with the letter: "B) ...", "(B) ...", "B." or a bare "B".
    re.compile(r"\A\s*\(?([A-Da-d])\s*(?:[).:]|$)", re.MULTILINE),
    # "answer is B", "Answer: B", "option B" (the letter itself must be a capital).
    re.compile(r"(?i:answer|option)\s*(?i:is)?\s*:?\s*\(?([A-D])(?![A-Za-z])"),
    # A standalone capital letter directly followed by ")", e.g. "... so B) ...".
    re.compile(r"(?<![A-Za-z])([A-D])\)"),
)

# Upper bound on retained history entries so the shared list cannot grow forever.
_HISTORY_LIMIT = 100


def _extract_choice(text):
    """Return the option letter (A-D) named in *text*, or "" if none is found."""
    for pattern in _CHOICE_PATTERNS:
        match = pattern.search(text)
        if match:
            return match.group(1).upper()
    return ""


def generate_answer(question, opa, opb, opc, opd, temperature, max_tokens):
    """Answer a four-option MCQ with the model and record the query.

    Args:
        question: Question text.
        opa, opb, opc, opd: Texts of options A-D; all four are required.
        temperature: Sampling temperature (converted with ``float``).
        max_tokens: Maximum number of new tokens (converted with ``int``).

    Returns:
        A 4-tuple ``(answer_text, confidence_html, elapsed, total_queries)``
        of strings. When the question or any option is missing (``None``,
        empty or whitespace-only) a warning message is returned in place of
        the answer and neither the counter nor the history is touched.
    """
    # Treat None like an empty field instead of raising AttributeError.
    question = question or ""
    opa, opb, opc, opd = (opt or "" for opt in (opa, opb, opc, opd))

    if not question.strip():
        return "⚠️ Please enter a question.", "", "0.00s", str(query_count["total"])
    if not all(opt.strip() for opt in (opa, opb, opc, opd)):
        return "⚠️ Please fill in all four options.", "", "0.00s", str(query_count["total"])

    prompt = (
        f"### Question:\n{question}\n\n"
        f"### Options:\nA) {opa}\nB) {opb}\nC) {opc}\nD) {opd}\n\n"
        f"### Answer:\n"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    t0 = time.time()
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            do_sample=True,
            temperature=float(temperature),
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.3,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
    elapsed = time.time() - t0
    # Drop the prompt tokens so only the newly generated continuation is decoded.
    generated = out[0][inputs["input_ids"].shape[-1]:]
    result = tokenizer.decode(generated, skip_special_tokens=True)
    query_count["total"] += 1

    # Scanning for the first a/b/c/d character would pick up letters inside
    # ordinary words ("ANSWER" -> "A"), so look for an explicit choice marker.
    pred_letter = _extract_choice(result)

    history_store.append({
        "q": question[:60] + "..." if len(question) > 60 else question,
        # Same letter the confidence panel shows; "?" when none was recognised.
        "ans": pred_letter or "?",
        "time": f"{elapsed:.2f}s",
    })
    # Trim in place (other code holds a reference to this list).
    del history_store[:-_HISTORY_LIMIT]

    confidence_html = build_confidence(pred_letter, result)
    return result, confidence_html, f"{elapsed:.2f}s", str(query_count["total"])
def build_confidence(pred_letter, full_text):
    """Render a four-row HTML bar chart that highlights the predicted option.

    The percentages are fixed presentation values, not model probabilities:
    the predicted letter is always drawn at 85 and the remaining 15 is split
    60/25/15 between the other letters in A-D order.

    Args:
        pred_letter: Chosen option letter; a falsy value means "none found".
        full_text: Not used by this function.

    Returns:
        An HTML fragment, or "" when *pred_letter* is falsy.
    """
    if not pred_letter:
        return ""

    top_share = 85
    leftover = 100 - top_share
    split = (leftover * 0.6, leftover * 0.25, leftover * 0.15)

    shares = {pred_letter: top_share}
    runners_up = [key for key in ("A", "B", "C", "D") if key != pred_letter]
    for position, key in enumerate(runners_up):
        # A fourth runner-up only exists when pred_letter is not one of A-D.
        shares[key] = split[position] if position < len(split) else 0

    palette = {"A": "#00c8f0", "B": "#00f0a0", "C": "#ff6030", "D": "#ffcc00"}
    rows = []
    for letter in ("A", "B", "C", "D"):
        share = shares[letter]
        tint = palette[letter]
        emphasis = "font-weight:700;" if letter == pred_letter else ""
        rows.append(
            "\n"
            '<div style="display:flex;align-items:center;gap:8px;margin-bottom:6px;">\n'
            f'<span style="width:16px;color:{tint};{emphasis}font-size:13px;">{letter}</span>\n'
            '<div style="flex:1;background:#162030;border-radius:4px;height:8px;overflow:hidden;">\n'
            f'<div style="width:{share}%;background:{tint};height:100%;border-radius:4px;transition:width 0.5s;"></div>\n'
            "</div>\n"
            f'<span style="width:38px;text-align:right;font-size:12px;color:#4a6080;">{share:.0f}%</span>\n'
            "</div>"
        )
    return '<div style="padding:12px 0;">' + "".join(rows) + "</div>"
def get_history_html():
    """Render the ten most recent history entries, newest first, as HTML.

    Every stored value is HTML-escaped before interpolation: the question is
    typed by the user and the answer comes from the model, so neither may be
    inserted into markup verbatim.

    Returns:
        An HTML fragment with one row per entry, or a "No queries yet."
        paragraph when the history is empty.
    """
    if not history_store:
        return "<p style='color:#4a6080;font-size:13px;'>No queries yet.</p>"

    rows = []
    for entry in reversed(history_store[-10:]):
        asked = html.escape(str(entry["q"]))
        answer = html.escape(str(entry["ans"]))
        took = html.escape(str(entry["time"]))
        rows.append(
            "\n"
            '<div style="display:flex;justify-content:space-between;align-items:center;\n'
            'padding:8px 12px;background:#0f1624;border-radius:8px;margin-bottom:6px;">\n'
            f'<span style="color:#deeeff;font-size:12px;flex:1;">{asked}</span>\n'
            f'<span style="color:#00c8f0;font-size:13px;font-weight:700;margin:0 12px;">→ {answer}</span>\n'
            f'<span style="color:#4a6080;font-size:11px;">{took}</span>\n'
            "</div>"
        )
    return "".join(rows)
def load_subject_examples(subject):
    """Return a Gradio update that loads the first sample question of *subject*.

    The update carries ``value=None`` for "All Subjects", for subjects missing
    from SUBJECT_EXAMPLES and for subjects whose example list is empty.
    """
    picked = None
    if subject != "All Subjects":
        rows = SUBJECT_EXAMPLES.get(subject, [])
        if rows:
            # First column of a row is the question text.
            picked = rows[0][0]
    return gr.update(value=picked)
def clear_all():
    """Return reset values for the eight components wired to the Clear button.

    Order: six empty strings (question, options A-D, answer text), then the
    "Cleared." notice for the confidence panel, then the zeroed inference time.
    """
    blank_fields = ("",) * 6
    cleared_notice = "<p style='color:#4a6080;font-size:13px;'>Cleared.</p>"
    return (*blank_fields, cleared_notice, "0.00s")
# Custom stylesheet for the app, handed to Gradio via demo.launch(css=CSS).
# The ids and classes styled here (#header, #stats, #footer, .badge, .stat,
# .section-label, .auto-btn, .out-box, ...) are the ones used by the gr.HTML
# fragments and elem_classes in the UI definition; the remaining selectors
# target elements rendered by Gradio itself.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;600;700;800&family=DM+Sans:wght@300;400;500&display=swap');
:root {
--bg: #080d1a;
--surface: #0f1624;
--surface2: #162030;
--border: #1a3356;
--accent: #00c8f0;
--accent2: #0055ff;
--green: #00f0a0;
--text: #deeeff;
--muted: #4a6080;
}
body, .gradio-container {
background: var(--bg) !important;
font-family: 'DM Sans', sans-serif !important;
color: var(--text) !important;
}
.gradio-container {
max-width: 1200px !important;
margin: 0 auto !important;
padding: 0 20px 60px !important;
}
#header {
padding: 44px 0 28px;
border-bottom: 1px solid var(--border);
margin-bottom: 28px;
position: relative;
}
#header::after {
content: '';
position: absolute;
bottom: -1px; left: 0; right: 0; height: 2px;
background: linear-gradient(90deg, var(--accent2), var(--accent), var(--green));
}
.badges { display: flex; gap: 8px; margin-bottom: 14px; flex-wrap: wrap; }
.badge {
font-size: 10px; font-weight: 600; letter-spacing: 0.1em;
text-transform: uppercase; padding: 3px 9px; border-radius: 4px; border: 1px solid;
}
.b-amd { color: #ff6030; border-color: #ff603030; background: #ff603010; }
.b-rocm { color: var(--accent); border-color: #00c8f030; background: #00c8f008; }
.b-lora { color: var(--green); border-color: #00f0a030; background: #00f0a008; }
.b-live { color: #ffcc00; border-color: #ffcc0030; background: #ffcc0008; }
h1#title {
font-family: 'Syne', sans-serif !important;
font-size: 42px !important; font-weight: 800 !important;
letter-spacing: -0.03em !important; line-height: 1 !important;
color: var(--text) !important; margin-bottom: 10px !important;
}
h1#title em { color: var(--accent); font-style: normal; }
.subtitle { font-size: 14px; color: var(--muted); font-weight: 300; line-height: 1.6; max-width: 600px; }
#stats {
display: flex; border: 1px solid var(--border);
border-radius: 12px; overflow: hidden;
background: var(--surface); margin-bottom: 24px;
}
.stat { flex: 1; padding: 14px 16px; text-align: center; border-right: 1px solid var(--border); }
.stat:last-child { border-right: none; }
.sv { font-family: 'Syne', sans-serif; font-size: 20px; font-weight: 700; color: var(--accent); display: block; }
.sl { font-size: 10px; color: var(--muted); text-transform: uppercase; letter-spacing: 0.08em; }
.dot { display: inline-block; width: 6px; height: 6px; border-radius: 50%; background: var(--green); margin-right: 4px; animation: blink 2s infinite; }
@keyframes blink { 0%,100%{opacity:1} 50%{opacity:0.3} }
label span, .label-wrap span {
font-family: 'DM Sans', sans-serif !important;
font-size: 11px !important; font-weight: 500 !important;
color: var(--muted) !important; text-transform: uppercase !important;
letter-spacing: 0.07em !important;
}
textarea, input[type=text] {
background: var(--surface2) !important;
border: 1px solid var(--border) !important;
border-radius: 10px !important; color: var(--text) !important;
font-family: 'DM Sans', sans-serif !important;
font-size: 14px !important; line-height: 1.6 !important;
transition: border-color 0.2s, box-shadow 0.2s !important;
}
textarea:focus, input[type=text]:focus {
border-color: var(--accent) !important;
box-shadow: 0 0 0 3px #00c8f012 !important; outline: none !important;
}
.section-label {
font-size: 10px; font-weight: 600; letter-spacing: 0.12em;
text-transform: uppercase; color: var(--muted); margin-bottom: 10px;
display: flex; align-items: center; gap: 7px;
}
.section-label::before {
content: ''; width: 5px; height: 5px; border-radius: 50%;
background: var(--accent); display: inline-block;
}
.tab-nav button {
background: transparent !important; color: var(--muted) !important;
border: none !important; border-bottom: 2px solid transparent !important;
font-family: 'DM Sans', sans-serif !important;
font-size: 13px !important; font-weight: 500 !important;
padding: 10px 16px !important;
transition: color 0.2s, border-color 0.2s !important;
}
.tab-nav button.selected {
color: var(--accent) !important;
border-bottom-color: var(--accent) !important;
}
button.lg.primary {
background: linear-gradient(135deg, var(--accent2), var(--accent)) !important;
border: none !important; border-radius: 10px !important;
color: #fff !important; font-family: 'Syne', sans-serif !important;
font-size: 14px !important; font-weight: 700 !important;
padding: 14px !important; width: 100% !important;
margin-top: 14px !important; cursor: pointer !important;
transition: opacity 0.2s, transform 0.15s !important;
}
button.lg.primary:hover { opacity: 0.85 !important; transform: translateY(-1px) !important; }
button.lg.secondary {
background: var(--surface2) !important;
border: 1px solid var(--border) !important;
border-radius: 10px !important; color: var(--muted) !important;
font-family: 'DM Sans', sans-serif !important;
font-size: 13px !important; padding: 10px !important;
width: 100% !important; cursor: pointer !important;
transition: border-color 0.2s !important;
}
button.lg.secondary:hover { border-color: var(--accent) !important; color: var(--accent) !important; }
.auto-btn button {
background: linear-gradient(135deg, #1a0055, #0055ff44) !important;
border: 1px solid var(--accent2) !important;
border-radius: 10px !important; color: var(--accent) !important;
font-family: 'DM Sans', sans-serif !important;
font-size: 13px !important; font-weight: 600 !important;
padding: 10px !important; width: 100% !important;
cursor: pointer !important; letter-spacing: 0.04em !important;
transition: opacity 0.2s, box-shadow 0.2s !important;
}
.auto-btn button:hover {
box-shadow: 0 0 12px #0055ff44 !important;
opacity: 0.9 !important;
}
.out-box textarea {
background: var(--surface2) !important;
border: 1px solid var(--border) !important;
border-radius: 10px !important; font-size: 14px !important;
line-height: 1.8 !important; color: var(--text) !important;
min-height: 220px !important;
}
input[type=range] { accent-color: var(--accent) !important; }
.wrap-inner { background: var(--surface2) !important; border-color: var(--border) !important; }
.examples-holder table {
background: var(--surface) !important;
border: 1px solid var(--border) !important;
border-radius: 10px !important; overflow: hidden !important;
}
.examples-holder td, .examples-holder th {
background: transparent !important; color: var(--text) !important;
font-size: 12px !important; border-color: var(--border) !important;
font-family: 'DM Sans', sans-serif !important;
}
.examples-holder tr:hover td { background: var(--surface2) !important; cursor: pointer; }
#footer {
margin-top: 44px; padding-top: 22px;
border-top: 1px solid var(--border);
display: flex; justify-content: space-between;
align-items: center; flex-wrap: wrap; gap: 10px;
}
.fl { font-size: 12px; color: var(--muted); }
.fl strong { color: var(--text); }
.fr { display: flex; gap: 14px; }
.flink { font-size: 12px; color: var(--accent); text-decoration: none; }
"""
# UI definition: a header, three tabs (Ask a Question / Query History / About),
# a footer, and the event wiring that connects the widgets to the handler
# functions defined above.
#
# NOTE(review): the nesting below was reconstructed from a copy of this file
# whose indentation had been lost -- confirm the layout (in particular that the
# "Browse by Subject" and sample-question sections sit inside the first tab,
# below the two-column row).
#
# User-facing strings that contained mis-decoded UTF-8 sequences have been
# restored to the intended characters (em dash, arrow, middle dot and the
# button glyphs).
with gr.Blocks(title="MedQA — AMD ROCm") as demo:
    gr.HTML("""
<div id="header">
<div class="badges">
<span class="badge b-amd">AMD MI300X</span>
<span class="badge b-rocm">ROCm 7.2</span>
<span class="badge b-lora">LoRA Fine-tuned</span>
<span class="badge b-live"><span class="dot"></span>Live Inference</span>
</div>
<h1 id="title">Med<em>QA</em> Assistant</h1>
<p class="subtitle">
Clinical question-answering AI fine-tuned on MedMCQA.
Running on AMD Instinct MI300X via ROCm — no CUDA required.
Enter any medical MCQ and get an answer with clinical reasoning.
</p>
</div>
<div id="stats">
<div class="stat"><span class="sv">1.7B</span><span class="sl">Parameters</span></div>
<div class="stat"><span class="sv">LoRA</span><span class="sl">Fine-tuning</span></div>
<div class="stat"><span class="sv">193k</span><span class="sl">Training QA</span></div>
<div class="stat"><span class="sv">MI300X</span><span class="sl">AMD GPU</span></div>
<div class="stat"><span class="sv">bf16</span><span class="sl">Precision</span></div>
</div>
""")
    with gr.Tabs():
        with gr.Tab("Ask a Question"):
            with gr.Row():
                # Left column: question, options, action buttons and settings.
                with gr.Column(scale=5):
                    gr.HTML('<div class="section-label">Clinical Question</div>')
                    question = gr.Textbox(
                        label="",
                        placeholder="e.g. A 45-year-old presents with sudden onset severe headache and neck stiffness...",
                        lines=4,
                    )
                    auto_btn = gr.Button(
                        "✨ Auto-generate Options A B C D from Question",
                        variant="secondary",
                        elem_classes=["auto-btn"],
                    )
                    gr.HTML("<p style='font-size:11px;color:#4a6080;margin-bottom:10px;'>"
                            "Type your question above then click to auto-fill all 4 options using AI.</p>")
                    gr.HTML('<div class="section-label">Answer Options</div>')
                    with gr.Row():
                        opa = gr.Textbox(label="Option A", placeholder="Auto-generated or type manually")
                        opb = gr.Textbox(label="Option B", placeholder="Auto-generated or type manually")
                    with gr.Row():
                        opc = gr.Textbox(label="Option C", placeholder="Auto-generated or type manually")
                        opd = gr.Textbox(label="Option D", placeholder="Auto-generated or type manually")
                    with gr.Row():
                        btn = gr.Button("⚕ Analyze Question", variant="primary")
                        clr_btn = gr.Button("✕ Clear", variant="secondary")
                    with gr.Accordion("⚙ Generation Settings", open=False):
                        temperature = gr.Slider(
                            minimum=0.1, maximum=1.5, value=0.7, step=0.05,
                            label="Temperature (creativity)",
                        )
                        max_tokens = gr.Slider(
                            minimum=50, maximum=400, value=200, step=10,
                            label="Max output tokens",
                        )
                        gr.HTML("""
<p style='font-size:12px;color:#4a6080;margin-top:8px;'>
Lower temperature = more deterministic answers.<br>
Higher = more creative explanations.
</p>""")
                # Right column: model output, confidence panel and run stats.
                with gr.Column(scale=5):
                    gr.HTML('<div class="section-label">AI Answer & Reasoning</div>')
                    output = gr.Textbox(
                        label="",
                        placeholder="Answer and clinical explanation will appear here...",
                        lines=10,
                        elem_classes=["out-box"],
                    )
                    gr.HTML('<div class="section-label" style="margin-top:16px">Answer Confidence</div>')
                    confidence = gr.HTML(
                        value="<p style='color:#4a6080;font-size:13px;'>Run a query to see confidence distribution.</p>"
                    )
                    with gr.Row():
                        inf_time = gr.Textbox(label="Inference Time", value="—", interactive=False, scale=1)
                        query_disp = gr.Textbox(label="Total Queries", value="0", interactive=False, scale=1)
            gr.HTML('<div class="section-label" style="margin-top:24px">Browse by Subject</div>')
            with gr.Row():
                subject_dd = gr.Dropdown(
                    choices=SUBJECTS, value="All Subjects", label="Filter by subject", scale=2
                )
            gr.HTML('<div class="section-label" style="margin-top:12px">Sample Questions — click any to load</div>')
            # Clicking a sample row fills the question and all four options.
            gr.Examples(
                examples=EXAMPLES,
                inputs=[question, opa, opb, opc, opd],
                label="",
            )
        with gr.Tab("Query History"):
            gr.HTML('<div class="section-label">Recent Queries</div>')
            history_html = gr.HTML(
                value="<p style='color:#4a6080;font-size:13px;'>No queries yet — ask a question first.</p>"
            )
            # The history panel only changes when this button is pressed.
            refresh_btn = gr.Button("↻ Refresh History", variant="secondary")
        with gr.Tab("About"):
            gr.HTML("""
<div style="max-width:800px;margin:0 auto;padding:24px 0;">
<div style="background:#0f1624;border:1px solid #1a3356;border-radius:16px;padding:28px;margin-bottom:20px;">
<h2 style="font-family:'Syne',sans-serif;color:#deeeff;font-size:22px;margin-bottom:16px;">What is MedQA?</h2>
<p style="color:#4a6080;font-size:14px;line-height:1.8;">
MedQA is a clinical question-answering AI fine-tuned on the MedMCQA dataset —
193,000 multiple-choice questions from Indian medical entrance exams (AIIMS, USMLE-style).
Given a clinical MCQ with 4 options, the model selects the correct answer and explains
the clinical reasoning.
</p>
</div>
<div style="display:grid;grid-template-columns:1fr 1fr;gap:16px;margin-bottom:20px;">
<div style="background:#0f1624;border:1px solid #1a3356;border-radius:12px;padding:20px;">
<h3 style="color:#00c8f0;font-size:14px;margin-bottom:12px;">MODEL</h3>
<p style="color:#4a6080;font-size:13px;line-height:1.8;">
Base: Qwen3-1.7B<br>Fine-tuning: LoRA (r=4)<br>
Trainable: 2.2M / 1.7B params<br>Precision: bfloat16
</p>
</div>
<div style="background:#0f1624;border:1px solid #1a3356;border-radius:12px;padding:20px;">
<h3 style="color:#00f0a0;font-size:14px;margin-bottom:12px;">HARDWARE</h3>
<p style="color:#4a6080;font-size:13px;line-height:1.8;">
AMD Instinct MI300X<br>192GB HBM3 memory<br>
ROCm 7.2 on Ubuntu 24.04<br>No CUDA required
</p>
</div>
<div style="background:#0f1624;border:1px solid #1a3356;border-radius:12px;padding:20px;">
<h3 style="color:#ff6030;font-size:14px;margin-bottom:12px;">TRAINING</h3>
<p style="color:#4a6080;font-size:13px;line-height:1.8;">
Dataset: MedMCQA (500 samples)<br>Time: ~5 minutes on MI300X<br>
Optimizer: AdamW<br>Scheduler: Constant + warmup
</p>
</div>
<div style="background:#0f1624;border:1px solid #1a3356;border-radius:12px;padding:20px;">
<h3 style="color:#ffcc00;font-size:14px;margin-bottom:12px;">LINKS</h3>
<p style="font-size:13px;line-height:2.0;">
<a href="https://github.com/HK2184/MedQA-Medical-AI-on-AMD-ROCm" style="color:#00c8f0;">GitHub →</a><br>
<a href="https://huggingface.co/HK2184/medqa-qwen3-lora" style="color:#00c8f0;">HuggingFace Model →</a><br>
<a href="https://cloud.amd.com" style="color:#00c8f0;">AMD Developer Cloud →</a><br>
<a href="https://lablab.ai" style="color:#00c8f0;">lablab.ai Hackathon →</a>
</p>
</div>
</div>
<div style="background:#0f1624;border:1px solid #1a3356;border-radius:12px;padding:20px;">
<h3 style="color:#deeeff;font-size:14px;margin-bottom:12px;">BUILT BY</h3>
<p style="color:#4a6080;font-size:13px;">
Harikrishna Sivanand Iyer &nbsp;·&nbsp; Srijan Sivaram A<br>
AMD Hackathon 2025 on lablab.ai
</p>
</div>
</div>
""")
    gr.HTML("""
<div id="footer">
<div class="fl">
Built on <strong>AMD Developer Cloud</strong> &nbsp;·&nbsp;
Model: <strong>Qwen3-1.7B + LoRA</strong> &nbsp;·&nbsp;
Dataset: <strong>MedMCQA</strong>
</div>
<div class="fr">
<a class="flink" href="https://github.com/HK2184/MedQA-Medical-AI-on-AMD-ROCm" target="_blank">GitHub →</a>
<a class="flink" href="https://huggingface.co/HK2184/medqa-qwen3-lora" target="_blank">Model →</a>
<a class="flink" href="https://lablab.ai" target="_blank">lablab.ai →</a>
</div>
</div>
""")
    # ── Events ────────────────────────────────────────────────────────────────
    # Fill the four option boxes from the question text.
    auto_btn.click(
        fn=autogenerate_options,
        inputs=[question],
        outputs=[opa, opb, opc, opd],
    )
    # Run inference; the four outputs match generate_answer's 4-tuple.
    btn.click(
        fn=generate_answer,
        inputs=[question, opa, opb, opc, opd, temperature, max_tokens],
        outputs=[output, confidence, inf_time, query_disp],
    )
    # Reset the eight components in the order clear_all returns its values.
    clr_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[question, opa, opb, opc, opd, output, confidence, inf_time],
    )
    refresh_btn.click(
        fn=get_history_html,
        inputs=[],
        outputs=[history_html],
    )
    # Choosing a subject only updates the question box, not the options.
    subject_dd.change(
        fn=load_subject_examples,
        inputs=[subject_dd],
        outputs=[question],
    )

if __name__ == "__main__":
    demo.launch(css=CSS)