MedQA: Fine-Tuning a Clinical AI on AMD ROCm — No CUDA Required

Team Article Published May 8, 2026

A complete walkthrough of LoRA fine-tuning Qwen3-1.7B on MedMCQA using AMD MI300X, built for the AMD Developer Hackathon on lablab.ai.


The Idea

Medical question answering is one of those tasks where the stakes are genuinely high. A model that confidently picks the wrong answer on a clinical MCQ isn't just wrong — it's dangerous. At the same time, most open-source medical AI work assumes you have an NVIDIA GPU. CUDA is the default. Everything else is an afterthought.

This project challenges that assumption.

MedQA is a LoRA fine-tuned clinical question-answering model built entirely on AMD hardware using ROCm. It takes a multiple-choice medical question and returns both the correct answer letter and a clinical explanation of the reasoning. The entire training pipeline — from data loading to adapter export — runs on an AMD Instinct MI300X without a single CUDA dependency.


Why AMD ROCm?

The AMD Instinct MI300X is a remarkable piece of hardware: 192 GB of HBM3 memory in a single device. For LLM fine-tuning, VRAM is often the binding constraint — it dictates batch size, sequence length, and whether you need to quantize at all. With 192 GB available, we trained Qwen3-1.7B with LoRA in full fp16 without any 4-bit or 8-bit quantization hacks.

More importantly, the goal was to prove that the HuggingFace ecosystem — Transformers, PEFT, TRL, Accelerate — works seamlessly on ROCm. It does. The same training code that runs on CUDA runs on ROCm with three environment variables set:

os.environ["ROCR_VISIBLE_DEVICES"] = "0"
os.environ["HIP_VISIBLE_DEVICES"] = "0"
os.environ["HSA_OVERRIDE_GFX_VERSION"] = "9.4.2"

That's it. No code changes. No custom kernels. No CUDA compatibility shims.
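One practical note: these variables must be in place before PyTorch initializes HIP, so they belong at the very top of the script. A small helper can apply them without clobbering values a user has already exported (a sketch; `ROCM_ENV` and `apply_rocm_env` are names invented here, not part of the original pipeline):

```python
import os

# Mirrors the three variables above; values are this project's MI300X setup.
ROCM_ENV = {
    "ROCR_VISIBLE_DEVICES": "0",
    "HIP_VISIBLE_DEVICES": "0",
    "HSA_OVERRIDE_GFX_VERSION": "9.4.2",
}

def apply_rocm_env(env=None):
    """Set the ROCm variables, keeping any values the user already exported."""
    env = os.environ if env is None else env
    for key, value in ROCM_ENV.items():
        env.setdefault(key, value)
    return env
```

Call `apply_rocm_env()` before `import torch` so the HIP runtime sees the right device configuration.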


The Dataset: MedMCQA

MedMCQA is a large-scale multiple-choice question dataset derived from Indian medical entrance exams (AIIMS and NEET PG). Each example contains:

  • A clinical question
  • Four answer options (A–D)
  • The correct answer index
  • An optional free-text explanation (exp field)

For this project we used 2,000 training samples — a deliberately small slice to demonstrate that meaningful fine-tuning is achievable quickly. Training took approximately 5 minutes on the MI300X.


Model: Qwen3-1.7B

The base model is Qwen/Qwen3-1.7B — Alibaba's latest small-scale language model. At 1.7 billion parameters it's compact enough to fine-tune cheaply but capable enough to produce coherent clinical reasoning, and it loads cleanly with HuggingFace Transformers (we pass trust_remote_code=True throughout).


The Prompt Format

Consistency in prompt formatting is critical for instruction fine-tuning. Every training example and every inference call uses the same template:

### Question:
{question}

### Options:
A) {opa}
B) {opb}
C) {opc}
D) {opd}

### Answer:
{answer_letter}) {answer_text}

### Explanation:
{explanation}

During training the model sees the full sequence including the answer and explanation. During inference we provide everything up to ### Answer:\n and let the model complete from there.
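The template can be rendered from MedMCQA's raw fields with a single function. This is a sketch, not the project's actual code: `build_prompt` is a hypothetical helper, and it assumes the dataset's `cop` field is the 0-based correct-option index alongside `opa`–`opd` and `exp`:

```python
LETTERS = "ABCD"

def build_prompt(example, with_answer=True):
    # Field names follow the raw dataset: opa..opd are the four options,
    # cop is the 0-based correct-option index, exp the free-text explanation.
    options = [example["opa"], example["opb"], example["opc"], example["opd"]]
    option_block = "\n".join(
        f"{LETTERS[i]}) {opt}" for i, opt in enumerate(options)
    )
    prompt = (
        f"### Question:\n{example['question']}\n\n"
        f"### Options:\n{option_block}\n\n"
        "### Answer:\n"
    )
    if with_answer:
        letter = LETTERS[example["cop"]]
        prompt += (
            f"{letter}) {options[example['cop']]}\n\n"
            f"### Explanation:\n{example.get('exp') or ''}"
        )
    return prompt
```

With `with_answer=True` this yields a training sequence; with `with_answer=False` it stops at `### Answer:\n`, ready for the model to complete.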


Training with LoRA

Rather than fine-tuning all of the model's parameters, we use LoRA (Low-Rank Adaptation) via the PEFT library. LoRA injects small trainable rank-decomposition matrices into the attention layers, leaving the base weights frozen.

LoRA Configuration

from peft import LoraConfig, get_peft_model, TaskType

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    bias="none",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# trainable params: 2,228,224 || all params: 1,543,901,184 || trainable%: 0.1443

Only ~2.2 million of the model's 1.5 billion parameters are trained. This keeps memory usage low and training fast.
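The count follows directly from the LoRA construction: each adapted weight W of shape d_out × d_in gains two low-rank factors, A (r × d_in) and B (d_out × r), so every targeted matrix contributes r · (d_in + d_out) trainable parameters. A toy calculation (the dimensions below are illustrative, not Qwen3's exact shapes):

```python
def lora_trainable(r, shapes):
    # Each adapted weight W (d_out x d_in) gains A (r x d_in) and B (d_out x r),
    # so it contributes r * (d_in + d_out) trainable parameters.
    return sum(r * (d_in + d_out) for d_out, d_in in shapes)

# Toy per-layer example at r=8: a 2048x2048 q_proj and a 1024x2048 v_proj.
per_layer = lora_trainable(8, [(2048, 2048), (1024, 2048)])
print(per_layer)  # 57344
```

Summing this per-layer figure over all attention layers gives the total that print_trainable_parameters() reports.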

Training Arguments

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./outputs",
    num_train_epochs=2,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,     # effective batch size = 16
    learning_rate=2e-4,
    fp16=True,
    bf16=False,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    gradient_checkpointing=True,
    optim="adamw_torch",
    warmup_ratio=0.05,
    lr_scheduler_type="cosine",
    report_to="none",
)

A few things worth noting:

  • fp16=True, bf16=False — We use standard fp16. In early experiments with bfloat16 we encountered NaN loss; switching to fp16 resolved it entirely.
  • gradient_checkpointing=True — Trades compute for memory. Not strictly necessary on MI300X given the 192 GB VRAM, but it keeps the script portable to smaller GPUs.
  • gradient_accumulation_steps=4 — Effective batch size of 16 with a physical batch of 4.
  • Cosine LR schedule with warmup — Smoother convergence than a flat schedule for short training runs.
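Under the run's own numbers (2,000 samples, effective batch 16, 2 epochs, warmup_ratio=0.05), the schedule arithmetic works out to roughly 250 optimizer steps with about a dozen of them spent on warmup:

```python
import math

samples, phys_batch, grad_accum, epochs = 2000, 4, 4, 2

eff_batch = phys_batch * grad_accum               # 16, as in the config comment
steps_per_epoch = math.ceil(samples / eff_batch)  # 125 optimizer updates/epoch
total_steps = steps_per_epoch * epochs            # 250
warmup_steps = int(0.05 * total_steps)            # 12 (warmup_ratio=0.05)
```

At this scale, cosine decay over ~250 steps is short enough that warmup and schedule shape matter more than they would on a long run.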

The Full Training Loop

from transformers import DataCollatorForSeq2Seq, Trainer

collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    padding=True,
    pad_to_multiple_of=8,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collator,
)

trainer.train()

# Save adapter + tokenizer
model.save_pretrained("./outputs")
tokenizer.save_pretrained("./outputs")

After training, ./outputs contains the LoRA adapter weights — a few MB of files rather than a full multi-GB model checkpoint.
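A quick back-of-envelope confirms why the adapter is so small: at two bytes per fp16 weight, the ~2.2M trainable parameters come to only a few megabytes (ignoring safetensors metadata):

```python
trainable_params = 2_228_224          # from print_trainable_parameters()
adapter_bytes = trainable_params * 2  # two bytes per fp16 weight
adapter_mb = adapter_bytes / 1e6
print(f"adapter ~ {adapter_mb:.1f} MB")
```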


Inference

At inference time we load the base model, attach the LoRA adapter, and optionally merge the weights:

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

tokenizer = AutoTokenizer.from_pretrained("./outputs", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-1.7B",
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

model = PeftModel.from_pretrained(base_model, "./outputs")
model.eval()

Generation uses greedy decoding (do_sample=False) with a repetition penalty to prevent the model from looping:

def generate(prompt, model, tokenizer):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=False,              # greedy decoding
            repetition_penalty=1.1,       # discourage looping
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )

    new_tokens = output[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

Sample Output

Question: Which of the following is the first-line treatment for hypertensive emergency?

A) Oral amlodipine
B) IV labetalol or IV nitroprusside
C) Sublingual nifedipine
D) IM hydralazine

Model Output:

B) IV labetalol or IV nitroprusside

Explanation:
Intravenous labetalol (beta-blocker) or nitroprusside rapidly reduces blood
pressure in emergency settings. Oral agents act too slowly for hypertensive
emergencies requiring immediate BP control to prevent end-organ damage.

The model doesn't just output a letter — it explains why, which is what makes it clinically useful.
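Because the completion format is fixed, extracting the letter and explanation for downstream use is a small parsing job. A sketch (`parse_answer` is a hypothetical helper, not part of the repo):

```python
import re

def parse_answer(completion):
    # Expects the fine-tuned format: "{letter}) {text}\n\n### Explanation:\n{exp}".
    m = re.match(r"\s*([A-D])\)", completion)
    letter = m.group(1) if m else None
    _, _, tail = completion.partition("Explanation:")
    return letter, tail.strip() or None
```

Returning `None` for an unparsable letter makes it easy to count such outputs as wrong during evaluation.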


Loading from HuggingFace Hub

The fine-tuned adapter is publicly available. You can load it directly without cloning the repo:

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-1.7B", trust_remote_code=True
)

base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-1.7B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

model = PeftModel.from_pretrained(base, "HK2184/medqa-qwen3-lora")
model = model.merge_and_unload()
model.eval()

Challenges and Fixes

No AMD ROCm project is complete without a war story section. Here's what we ran into:

| Challenge | Root Cause | Fix |
| --- | --- | --- |
| NaN loss | Mixed-precision instability | Switched from bfloat16 to fp16 |
| GPU not detected | Missing ROCm env variables | Set ROCR_VISIBLE_DEVICES, HIP_VISIBLE_DEVICES, HSA_OVERRIDE_GFX_VERSION |
| bitsandbytes unsupported | No ROCm build of bitsandbytes | Dropped quantization entirely; MI300X has enough VRAM |
| Garbage inference output | Tokenizer padding misconfigured | Set pad_token = eos_token and fixed padding_side |
| Trainer eval errors | Transformers version mismatch | Pinned transformers>=4.40.0 |

The bitsandbytes issue deserves a note: on NVIDIA hardware, 4-bit quantization is often required to fit a model in memory. On MI300X with 192 GB HBM3, it's simply unnecessary. This is a genuine hardware advantage — cleaner training, no quantization artifacts.
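Some rough arithmetic shows why: fp16 weights for a 1.7B-parameter model occupy about 3.4 GB, and even a full AdamW fine-tune would stay far below 192 GB before activations. The footprint formula below is a common rule of thumb, not a measurement from this project:

```python
params = 1_700_000_000  # Qwen3-1.7B, nominal parameter count

weights_fp16_gb = params * 2 / 1e9  # ~3.4 GB of fp16 weights
# Rule of thumb for full AdamW fine-tuning in mixed precision:
# fp16 weights (2 B) + fp32 master weights (4 B) + two fp32 moments (4 B + 4 B).
full_adamw_gb = params * (2 + 4 + 4 + 4) / 1e9  # ~23.8 GB before activations
```

With LoRA, only the ~2.2M adapter parameters carry optimizer state, so the margin is even larger.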


Results

| Metric | Value |
| --- | --- |
| Trainable parameters | ~2.2M (0.14% of total) |
| Training time on MI300X | ~5 minutes |
| Dataset size used | 2,000 samples |
| Baseline MedMCQA accuracy | ~45% |
| Framework | PyTorch + ROCm 6.1 |

Try It Yourself

No GPU? No problem. The live Gradio demo runs on HuggingFace Spaces (CPU inference):

👉 Live Demo on HuggingFace Spaces

Have AMD hardware? Clone the repo and run it natively:

git clone https://github.com/HK2184/MedQA-Medical-AI-on-AMD-ROCm.git
cd MedQA-Medical-AI-on-AMD-ROCm
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1
pip install transformers datasets peft accelerate trl gradio
python train.py   # ~5 minutes
python infer.py   # run sample questions
python app.py     # launch Gradio UI

What's Next

This project proves the pipeline works. The next steps are about scaling and hardening it:

  • Larger dataset — Train on the full MedMCQA corpus (~180k questions) and add PubMedQA
  • Confidence scoring — Add calibrated confidence estimates alongside answers
  • RAG integration — Ground answers in real-time medical literature retrieval
  • Evaluation harness — Proper held-out accuracy benchmarking beyond the training split
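The evaluation harness can start as simple as letter-matching accuracy over parsed predictions (a sketch; `accuracy` is a name invented here, and unparsable outputs count as wrong):

```python
def accuracy(preds, golds):
    """Letter-match accuracy; an unparsable prediction (None) counts as wrong."""
    assert len(preds) == len(golds), "prediction/gold lists must align"
    return sum(p == g for p, g in zip(preds, golds)) / len(golds)
```

Running this over a held-out MedMCQA split would give the proper benchmark number the table above only approximates.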

Conclusion

MedQA shows that building a capable, explainable medical AI on open-source AMD hardware is not only possible — it's straightforward. The HuggingFace ecosystem's ROCm compatibility is genuinely good. The MI300X's memory headroom removes an entire category of engineering problems. And LoRA makes fine-tuning a 1.7B model a 5-minute job.

If you're building on AMD ROCm and hitting walls, the fixes above should save you hours. And if you're building medical AI, the emphasis on explanation over bare accuracy is worth taking seriously.


Built for the AMD Developer Hackathon on lablab.ai · Powered by AMD ROCm + HuggingFace ecosystem

— Harikrishna Sivanand Iyer and Srijan Sivaram A

