# MedGemma-4B ICD-10 Diagnosis Coding: QLoRA Adapter (v2, Epoch 5)

A QLoRA fine-tuned adapter for `google/medgemma-4b-it` that predicts ICD-10-CM diagnosis codes from clinical notes. Focused on Chapter 6, Diseases of the Nervous System (G00-G99), with 665 billable codes.
## Model Description
This is a LoRA adapter (not a full model) that must be loaded on top of the MedGemma-4B-IT base model. It was trained on 3,325 synthetic clinical notes generated by MedGemma itself (self-distillation), covering diverse documentation styles: SOAP notes, H&P exams, progress notes, consultation reports, and brief assessments.
| Property | Value |
|---|---|
| Base Model | google/medgemma-4b-it (4B params) |
| Adapter Type | LoRA (PEFT v0.18.1) |
| Adapter Size | 250 MB |
| Task | ICD-10-CM diagnosis code prediction from clinical notes |
| Domain | Nervous System (G00-G99), 665 billable codes |
| Training Data | 3,325 LLM-generated clinical notes (5 per code) |
| Training Epochs | 5 (3 initial + 2 resumed with lower LR) |
| Hardware | Single NVIDIA RTX 5070 (12GB VRAM) |
## Evaluation Results
Evaluated on 250 held-out clinical notes across 50 ICD-10 codes:
| Metric | Baseline (No FT) | After Fine-Tuning |
|---|---|---|
| Exact Code Match | 0% | 0.4–0.8% |
| Category Match (3-char prefix) | ~10% | 73.2% |
| Produces Valid ICD-10 Code | ~20% | 100% |
Note on category match: the 73.2% was measured on a challenging 250-example eval set of diverse, LLM-generated clinical notes. When combined with BM25 retrieval and trie-based constrained decoding (see the inference pipeline below), accuracy improves substantially.

An earlier v1 model trained on simpler template-based data achieved 38% exact match and 88% category match on a smaller 50-example eval set, demonstrating that evaluation difficulty scales with dataset diversity.
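As a reference for how the two scoring rules above behave, here is a minimal sketch of exact match versus 3-character category match. Function names and the toy predictions are illustrative, not from the training repo.

```python
def exact_match(pred: str, gold: str) -> bool:
    """Full ICD-10-CM code must match, e.g. 'G20.A1' vs 'G20.A1'."""
    return pred.strip().upper() == gold.strip().upper()

def category_match(pred: str, gold: str) -> bool:
    """Only the 3-character category must match, e.g. 'G20.A1' vs 'G20.B2'."""
    return pred.strip().upper()[:3] == gold.strip().upper()[:3]

# Toy eval: one exact hit, two category hits
preds = ["G20.A1", "G43.909", "G40.909"]
golds = ["G20.B1", "G43.909", "G35"]

exact = sum(exact_match(p, g) for p, g in zip(preds, golds)) / len(golds)
category = sum(category_match(p, g) for p, g in zip(preds, golds)) / len(golds)
print(f"exact={exact:.2f} category={category:.2f}")  # exact=0.33 category=0.67
```

Category match is deliberately lenient: it credits the model for landing in the right disease family (e.g. G20, Parkinson disease) even when the billable extension is wrong.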
## QLoRA Configuration
```json
{
  "peft_type": "LORA",
  "r": 32,
  "lora_alpha": 64,
  "lora_dropout": 0.05,
  "bias": "none",
  "task_type": "CAUSAL_LM",
  "target_modules": [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj"
  ]
}
```
Quantization: 4-bit NF4 with double quantization, bfloat16 compute dtype.
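The JSON above maps directly onto a PEFT `LoraConfig`, and the quantization line onto a `BitsAndBytesConfig`. A minimal sketch of building both (a reconstruction from the config shown, not the exact training script):

```python
import torch
from peft import LoraConfig
from transformers import BitsAndBytesConfig

lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention projections
        "gate_proj", "up_proj", "down_proj",     # MLP projections
    ],
)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # NF4 quantization
    bnb_4bit_use_double_quant=True,         # double quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # bfloat16 compute dtype
)
```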
## Training Hyperparameters
| Parameter | Epochs 1–3 | Epochs 4–5 (resumed) |
|---|---|---|
| Learning rate | 1e-4 | 5e-5 |
| Batch size | 2 | 2 |
| Gradient accumulation | 4 (eff. batch = 8) | 4 (eff. batch = 8) |
| Max sequence length | 768 tokens | 768 tokens |
| Optimizer | AdamW (wd=0.01) | AdamW (wd=0.01) |
| LR scheduler | Cosine (5% warmup) | Cosine (10% warmup) |
| Gradient clipping | max_norm=1.0 | max_norm=1.0 |
| Mixed precision | bfloat16 | bfloat16 |
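The LR schedule in the table (linear warmup, then cosine decay) can be written as a pure-Python multiplier of the kind passed to `torch.optim.lr_scheduler.LambdaLR`. The helper below is illustrative, not the repo's implementation:

```python
import math

def cosine_with_warmup(step: int, total_steps: int, warmup_frac: float = 0.05) -> float:
    """LR multiplier at `step`: linear warmup, then cosine decay to zero."""
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

# With 1000 optimizer steps and 5% warmup (epochs 1-3 setting):
total = 1000
print(cosine_with_warmup(0, total))    # 0.0 (start of warmup)
print(cosine_with_warmup(50, total))   # 1.0 (end of warmup, full base LR)
print(cosine_with_warmup(1000, total)) # ~0.0 (fully decayed)
```

In a pure PyTorch loop this would be wired up as `LambdaLR(optimizer, lr_lambda=lambda s: cosine_with_warmup(s, total))`, with the base LR (1e-4 or 5e-5) set on the AdamW optimizer itself.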
## How to Use

### Basic Usage (Direct Inference)
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# Load base model with 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
base_model = AutoModelForCausalLM.from_pretrained(
    "google/medgemma-4b-it",
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained("google/medgemma-4b-it")

# Load the LoRA adapter
model = PeftModel.from_pretrained(base_model, "YOUR_USERNAME/medgemma-icd10-lora-v2")
model.eval()

# Predict an ICD-10 code from a clinical note
clinical_note = """
68-year-old male presenting with 2-year history of progressive right-hand
resting tremor. Reports difficulty with fine motor tasks. Examination reveals
4-5 Hz pill-rolling tremor, cogwheel rigidity bilateral upper extremities,
bradykinesia on finger tapping. Gait shows reduced arm swing and mild shuffling.
"""
messages = [
    {"role": "user", "content": f"Given the following clinical note, predict the ICD-10-CM diagnosis code:\n\n{clinical_note}"}
]
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
input_ids = input_ids.to(model.device)

with torch.no_grad():
    output = model.generate(input_ids, max_new_tokens=50, temperature=0.1, do_sample=True)

response = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
print(response)
# Example output: "ICD-10-CM: G20.A1 - Parkinson disease, without fluctuations"
```
### Recommended: With BM25 RAG + Constrained Decoding
For production use, we recommend the full inference pipeline with:
- BM25 retrieval to narrow candidates to top-15 codes
- Trie-based constrained decoding to guarantee valid ICD-10 output
See the Gradio app for the complete implementation.
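To illustrate the idea behind trie-based constrained decoding, here is a minimal character-level sketch. The real pipeline builds the trie over tokenizer ids and queries it from `generate`'s `prefix_allowed_tokens_fn`; all names below are illustrative:

```python
def build_trie(codes):
    """Nested-dict trie over the candidate codes (e.g. the BM25 top-15)."""
    root = {}
    for code in codes:
        node = root
        for ch in code:
            node = node.setdefault(ch, {})
        node["<eos>"] = {}  # mark end of a complete code
    return root

def allowed_next(trie, prefix):
    """Symbols that may follow `prefix` while staying inside the trie."""
    node = trie
    for ch in prefix:
        if ch not in node:
            return []  # prefix already invalid; nothing is allowed
        node = node[ch]
    return list(node.keys())

candidates = ["G20.A1", "G20.B1", "G35"]  # e.g. BM25-retrieved candidates
trie = build_trie(candidates)
print(allowed_next(trie, ""))      # ['G']
print(allowed_next(trie, "G20."))  # ['A', 'B']
print(allowed_next(trie, "G35"))   # ['<eos>'] -> only option is to stop
```

Because the decoder can only ever emit symbols returned by `allowed_next`, every completed generation is by construction a valid candidate code, which is what drives the 100% "Produces Valid ICD-10 Code" row in the evaluation table.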
## Training Data

Trained on the MedGemma ICD-10 Clinical Notes Dataset: 3,325 synthetic clinical notes generated by MedGemma-4B-IT (self-distillation).
Key characteristics:
- 665 ICD-10 codes (G00-G99, billable only)
- 5 notes per code with varied styles and demographics
- 10 prompt templates for training input diversity
- No data leakage: clinical notes describe presentations without naming codes or diagnoses directly
- Average note length: 265 words (range: 152–445)
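For illustration, one way a (note, code) pair could be turned into a supervised chat example with template-based input diversity. The templates and helper below are assumptions for the sketch, not the dataset's actual 10 templates:

```python
import random

# Hypothetical stand-ins for the dataset's prompt templates
PROMPT_TEMPLATES = [
    "Given the following clinical note, predict the ICD-10-CM diagnosis code:\n\n{note}",
    "Assign the most specific ICD-10-CM code for this documentation:\n\n{note}",
]

def make_example(note: str, code: str, description: str, seed: int = 0) -> list:
    """One training chat: a templated user prompt plus the target completion."""
    template = random.Random(seed).choice(PROMPT_TEMPLATES)
    return [
        {"role": "user", "content": template.format(note=note)},
        {"role": "assistant", "content": f"ICD-10-CM: {code} - {description}"},
    ]

example = make_example(
    "68-year-old male with progressive right-hand resting tremor...",
    "G20.A1",
    "Parkinson disease, without fluctuations",
)
print(example[1]["content"])
```

Each of the 665 codes gets five such examples with varied templates, note styles, and patient demographics, giving the 3,325-note training set.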
## Intended Use
- Medical coding assistance: Suggest ICD-10 codes from clinical documentation
- Research: Benchmarking clinical NLP models on structured code prediction
- Education: Demonstrating QLoRA fine-tuning for domain-specific medical AI tasks
## Limitations

- Nervous System only: trained on G00-G99 codes; will not predict codes from other ICD-10 chapters
- Single diagnosis: predicts one code per note; real encounters often require multiple codes
- Synthetic training data: not trained on real clinical records
- Not clinically validated: has not been evaluated by certified medical coders against production data
- English only
## Ethical Considerations
- This model is for research and educational purposes only
- ICD-10 coding in production requires certified medical coders and validated, regulated systems
- Incorrect diagnosis codes can lead to claim denials, billing errors, and patient safety issues
- Always have a human expert review model predictions before clinical or billing use
## Technical Details
- Framework: Pure PyTorch training loop (no HF Trainer dependency)
- Environment: Python 3.14, CUDA 12.8, PyTorch 2.10+cu128
- Training time: ~6 hours for 5 epochs on RTX 5070
## Citation

```bibtex
@misc{medgemma_icd10_finetuning,
  title={Fine-Tuning MedGemma-4B for ICD-10 Diagnosis Coding},
  author={singhak-abbvie},
  year={2026},
  url={https://github.com/singhak-abbvie/medgemma_finetuning_ICD_10}
}
```
## Links
- GitHub: singhak-abbvie/medgemma_finetuning_ICD_10
- Dataset: HF Dataset