MedQA: Fine-Tuning a Clinical AI on AMD ROCm — No CUDA Required
The Idea
Medical question answering is one of those tasks where the stakes are genuinely high. A model that confidently picks the wrong answer on a clinical MCQ isn't just wrong — it's dangerous. At the same time, most open-source medical AI work assumes you have an NVIDIA GPU. CUDA is the default. Everything else is an afterthought.
This project challenges that assumption.
MedQA is a LoRA fine-tuned clinical question-answering model built entirely on AMD hardware using ROCm. It takes a multiple-choice medical question and returns both the correct answer letter and a clinical explanation of the reasoning. The entire training pipeline — from data loading to adapter export — runs on an AMD Instinct MI300X without a single CUDA dependency.
- 🤗 Model on HuggingFace Hub: HK2184/medqa-qwen3-lora
- 🚀 Live Demo: HuggingFace Spaces
- 💻 GitHub: MedQA-Medical-AI-on-AMD-ROCm
Why AMD ROCm?
The AMD Instinct MI300X is a remarkable piece of hardware: 192 GB of HBM3 memory in a single device. For LLM fine-tuning, VRAM is often the binding constraint — it dictates batch size, sequence length, and whether you need to quantize at all. With 192 GB available, we trained Qwen3-1.7B with LoRA in full fp16 without any 4-bit or 8-bit quantization hacks.
More importantly, the goal was to prove that the HuggingFace ecosystem — Transformers, PEFT, TRL, Accelerate — works seamlessly on ROCm. It does. The same training code that runs on CUDA runs on ROCm with three environment variables set:
os.environ["ROCR_VISIBLE_DEVICES"] = "0"
os.environ["HIP_VISIBLE_DEVICES"] = "0"
os.environ["HSA_OVERRIDE_GFX_VERSION"] = "9.4.2"
That's it. No code changes. No custom kernels. No CUDA compatibility shims.
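A quick sanity check: PyTorch's ROCm build reuses the torch.cuda namespace, so the usual device checks work unchanged (the printed device name is illustrative):

import torch

# On a working ROCm install, the AMD GPU is reported through the CUDA API surface
print(torch.cuda.is_available())      # True
print(torch.cuda.get_device_name(0))  # e.g. "AMD Instinct MI300X"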
The Dataset: MedMCQA
MedMCQA is a large-scale multiple-choice question dataset derived from Indian medical entrance exams (AIIMS and NEET PG). Each example contains:
- A clinical question
- Four answer options (A–D)
- The correct answer index
- An optional free-text explanation (the exp field)
For this project we used 2,000 training samples — a deliberately small slice to demonstrate that meaningful fine-tuning is achievable quickly. Training took approximately 5 minutes on the MI300X.
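As a sketch of the data-loading step (the dataset id openlifescienceai/medmcqa and the shuffle seed are assumptions based on the public MedMCQA release on the Hub, not code from the repo):

from datasets import load_dataset

# Pull MedMCQA from the HuggingFace Hub and take a small, reproducible slice
dataset = load_dataset("openlifescienceai/medmcqa", split="train")
train_raw = dataset.shuffle(seed=42).select(range(2000))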
Model: Qwen3-1.7B
The base model is Qwen/Qwen3-1.7B — Alibaba's latest small-scale language model. At 1.7 billion parameters it's compact enough to fine-tune cheaply but capable enough to produce coherent clinical reasoning, and it loads cleanly with HuggingFace Transformers (pass trust_remote_code=True when loading).
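The repo's exact loading code isn't reproduced here, but a minimal sketch consistent with the training setup described below (fp16, pad token set to EOS) would be:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token  # Qwen ships without a dedicated pad token

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-1.7B",
    torch_dtype=torch.float16,  # full fp16; no quantization needed on MI300X
    device_map="auto",
    trust_remote_code=True,
)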
The Prompt Format
Consistency in prompt formatting is critical for instruction fine-tuning. Every training example and every inference call uses the same template:
### Question:
{question}
### Options:
A) {opa}
B) {opb}
C) {opc}
D) {opd}
### Answer:
{answer_letter}) {answer_text}
### Explanation:
{explanation}
During training the model sees the full sequence including the answer and explanation. During inference we provide everything up to ### Answer:\n and let the model complete from there.
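A small helper makes this concrete. This is a hedged sketch (build_prompt is our name, not the repo's); cop is MedMCQA's 0-based index of the correct option:

def build_prompt(example, include_answer=True):
    letters = "ABCD"
    options = [example["opa"], example["opb"], example["opc"], example["opd"]]
    prompt = f"### Question:\n{example['question']}\n### Options:\n"
    for letter, option in zip(letters, options):
        prompt += f"{letter}) {option}\n"
    prompt += "### Answer:\n"
    if include_answer:  # training sees the full sequence; inference stops here
        idx = example["cop"]
        prompt += f"{letters[idx]}) {options[idx]}\n### Explanation:\n{example['exp'] or ''}"
    return prompt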
Training with LoRA
Rather than fine-tuning every weight of the base model, we use LoRA (Low-Rank Adaptation) via the PEFT library. LoRA injects small trainable rank-decomposition matrices into the attention layers, leaving the base weights frozen.
LoRA Configuration
from peft import LoraConfig, get_peft_model, TaskType
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=8,
lora_alpha=16,
lora_dropout=0.05,
target_modules=["q_proj", "v_proj"],
bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# trainable params: 2,228,224 || all params: 1,543,901,184 || trainable%: 0.1443
Only ~2.2 million of the model's 1.5 billion parameters are trained. This keeps memory usage low and training fast.
Training Arguments
from transformers import TrainingArguments
args = TrainingArguments(
output_dir="./outputs",
num_train_epochs=2,
per_device_train_batch_size=4,
gradient_accumulation_steps=4, # effective batch size = 16
learning_rate=2e-4,
fp16=True,
bf16=False,
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
gradient_checkpointing=True,
optim="adamw_torch",
warmup_ratio=0.05,
lr_scheduler_type="cosine",
report_to="none",
)
A few things worth noting:
- fp16=True, bf16=False — We use standard fp16. In early experiments with bfloat16 we encountered NaN loss; switching to fp16 resolved it entirely.
- gradient_checkpointing=True — Trades compute for memory. Not strictly necessary on MI300X given the 192 GB VRAM, but good practice for reproducibility on smaller GPUs.
- gradient_accumulation_steps=4 — Effective batch size of 16 with a physical batch of 4.
- Cosine LR schedule with warmup — Smoother convergence than a flat schedule for short training runs.
The Full Training Loop
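One step not shown above is turning the formatted prompts into tokenized datasets. Here is a sketch, assuming the build_prompt helper and train_raw slice from earlier (the max_length and split ratio are assumptions); for causal LM fine-tuning the labels simply mirror the input ids:

def tokenize(example):
    # Append EOS so the model learns where a completed answer ends
    text = build_prompt(example) + tokenizer.eos_token
    enc = tokenizer(text, truncation=True, max_length=512)
    enc["labels"] = enc["input_ids"].copy()  # causal LM: predict the same sequence
    return enc

splits = train_raw.train_test_split(test_size=0.1, seed=42)
train_ds = splits["train"].map(tokenize, remove_columns=splits["train"].column_names)
val_ds = splits["test"].map(tokenize, remove_columns=splits["test"].column_names)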
from transformers import DataCollatorForSeq2Seq, Trainer
collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
padding=True,
pad_to_multiple_of=8,
)
trainer = Trainer(
model=model,
args=args,
train_dataset=train_ds,
eval_dataset=val_ds,
data_collator=collator,
)
trainer.train()
# Save adapter + tokenizer
model.save_pretrained("./outputs")
tokenizer.save_pretrained("./outputs")
After training, ./outputs contains the LoRA adapter weights — a few MB of files rather than a full multi-GB model checkpoint.
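To publish the adapter (as was done for the Hub repo used below), a minimal sketch, assuming you've authenticated with huggingface-cli login:

# Upload the LoRA adapter and tokenizer to the HuggingFace Hub
model.push_to_hub("HK2184/medqa-qwen3-lora")
tokenizer.push_to_hub("HK2184/medqa-qwen3-lora")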
Inference
At inference time we load the base model, attach the LoRA adapter, and optionally merge the weights:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
tokenizer = AutoTokenizer.from_pretrained("./outputs", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
base_model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-1.7B",
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True,
)
model = PeftModel.from_pretrained(base_model, "./outputs")
model.eval()
Generation uses greedy decoding (do_sample=False) with a repetition penalty to prevent the model from looping:
def generate(prompt, model, tokenizer):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=False,          # greedy decoding
            temperature=1.0,          # ignored when do_sample=False
            repetition_penalty=1.1,   # discourage degenerate loops
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt
    new_tokens = output[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
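For example, with an illustrative question of ours formatted using the same template and truncated at the answer header:

prompt = (
    "### Question:\nWhich vitamin deficiency causes scurvy?\n"
    "### Options:\nA) Vitamin A\nB) Vitamin B12\nC) Vitamin C\nD) Vitamin D\n"
    "### Answer:\n"
)
print(generate(prompt, model, tokenizer))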
Sample Output
Question: Which of the following is the first-line treatment for hypertensive emergency?
A) Oral amlodipine
B) IV labetalol or IV nitroprusside
C) Sublingual nifedipine
D) IM hydralazine
Model Output:
B) IV labetalol or IV nitroprusside
Explanation:
Intravenous labetalol (beta-blocker) or nitroprusside rapidly reduces blood
pressure in emergency settings. Oral agents act too slowly for hypertensive
emergencies requiring immediate BP control to prevent end-organ damage.
The model doesn't just output a letter — it explains why, which is what makes it clinically useful.
Loading from HuggingFace Hub
The fine-tuned adapter is publicly available. You can load it directly without cloning the repo:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
tokenizer = AutoTokenizer.from_pretrained(
"Qwen/Qwen3-1.7B", trust_remote_code=True
)
base = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-1.7B",
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
)
model = PeftModel.from_pretrained(base, "HK2184/medqa-qwen3-lora")
model = model.merge_and_unload()
model.eval()
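Because merge_and_unload() folds the adapter into the base weights, the result can be saved as an ordinary standalone checkpoint (the output path here is illustrative):

# Persist the merged model and tokenizer as a plain Transformers checkpoint
model.save_pretrained("./medqa-merged")
tokenizer.save_pretrained("./medqa-merged")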
Challenges and Fixes
No AMD ROCm project is complete without a war story section. Here's what we ran into:
| Challenge | Root Cause | Fix |
|---|---|---|
| NaN loss | Mixed precision instability | Switched from bfloat16 → fp16 |
| GPU not detected | Missing ROCm env variables | Set ROCR_VISIBLE_DEVICES, HIP_VISIBLE_DEVICES, HSA_OVERRIDE_GFX_VERSION |
| bitsandbytes unsupported | No ROCm build of bitsandbytes | Dropped quantization entirely — MI300X has enough VRAM |
| Garbage inference output | Tokenizer padding misconfigured | Set pad_token = eos_token and fixed padding_side |
| Trainer eval errors | Transformers version mismatch | Pinned transformers>=4.40.0 |
The bitsandbytes issue deserves a note: on NVIDIA hardware, 4-bit quantization is often required to fit a model in memory. On MI300X with 192 GB HBM3, it's simply unnecessary. This is a genuine hardware advantage — cleaner training, no quantization artifacts.
Results
| Metric | Value |
|---|---|
| Trainable parameters | ~2.2M (0.15% of total) |
| Training time on MI300X | ~5 minutes |
| Dataset size used | 2,000 samples |
| Baseline MedMCQA accuracy | ~45% |
| Framework | PyTorch + ROCm 6.1 |
Try It Yourself
No GPU? No problem. The live Gradio demo runs on HuggingFace Spaces (CPU inference):
👉 Live Demo on HuggingFace Spaces
Have AMD hardware? Clone the repo and run it natively:
git clone https://github.com/HK2184/MedQA-Medical-AI-on-AMD-ROCm.git
cd MedQA-Medical-AI-on-AMD-ROCm
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1
pip install transformers datasets peft accelerate trl gradio
python train.py # ~5 minutes
python infer.py # run sample questions
python app.py # launch Gradio UI
What's Next
This project proves the pipeline works. The next steps are about scaling and hardening it:
- Larger dataset — Train on the full MedMCQA corpus (~180k questions) and add PubMedQA
- Confidence scoring — Add calibrated confidence estimates alongside answers
- RAG integration — Ground answers in real-time medical literature retrieval
- Evaluation harness — Proper held-out accuracy benchmarking beyond the training split
Conclusion
MedQA shows that building a capable, explainable medical AI on open-source AMD hardware is not only possible — it's straightforward. The HuggingFace ecosystem's ROCm compatibility is genuinely good. The MI300X's memory headroom removes an entire category of engineering problems. And LoRA makes fine-tuning a 1.7B model a 5-minute job.
If you're building on AMD ROCm and hitting walls, the fixes above should save you hours. And if you're building medical AI, the emphasis on explanation over bare accuracy is worth taking seriously.
Built for the AMD Developer Hackathon on lablab.ai · Powered by AMD ROCm + HuggingFace ecosystem
— Harikrishna Sivanand Iyer and Srijan Sivaram A
