MedGemma-4B ICD-10 Diagnosis Coding β€” QLoRA Adapter (v2, Epoch 5)

A QLoRA fine-tuned adapter for google/medgemma-4b-it that predicts ICD-10-CM diagnosis codes from clinical notes. Focused on Chapter 6: Diseases of the Nervous System (G00-G99) β€” 665 billable codes.

Model Description

This is a LoRA adapter (not a full model) that must be loaded on top of the MedGemma-4B-IT base model. It was trained on 3,325 synthetic clinical notes generated by MedGemma itself (self-distillation), covering diverse documentation styles: SOAP notes, H&P exams, progress notes, consultation reports, and brief assessments.

| Property | Value |
| --- | --- |
| Base Model | google/medgemma-4b-it (4B params) |
| Adapter Type | LoRA (PEFT v0.18.1) |
| Adapter Size | 250 MB |
| Task | ICD-10-CM diagnosis code prediction from clinical notes |
| Domain | Nervous System (G00-G99), 665 billable codes |
| Training Data | 3,325 LLM-generated clinical notes (5 per code) |
| Training Epochs | 5 (3 initial + 2 resumed with lower LR) |
| Hardware | Single NVIDIA RTX 5070 (12GB VRAM) |

Evaluation Results

Evaluated on 250 held-out clinical notes across 50 ICD-10 codes:

| Metric | Baseline (No FT) | After Fine-Tuning |
| --- | --- | --- |
| Exact Code Match | 0% | 0.4–0.8% |
| Category Match (3-char prefix) | ~10% | 73.2% |
| Produces Valid ICD-10 Code | ~20% | 100% |

Note on category match: The 73.2% was measured on a challenging 250-example eval set with diverse, LLM-generated clinical notes. When combined with BM25 retrieval and trie-based constrained decoding (see inference pipeline below), accuracy improves substantially.

An earlier V1 model trained on simpler template-based data achieved 38% exact match and 88% category match on a smaller 50-example eval set β€” demonstrating that evaluation difficulty scales with dataset diversity.
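The exact-match and category-match metrics in the table above can be reproduced with a small helper. The codes below are illustrative, not drawn from the actual eval set:

```python
def category(code: str) -> str:
    """3-character ICD-10 category prefix, e.g. 'G20.A1' -> 'G20'."""
    return code.strip().upper()[:3]

def score(predictions, references):
    """Fraction of exact code matches and 3-char category matches."""
    pairs = list(zip(predictions, references))
    exact = sum(p.strip().upper() == r.strip().upper() for p, r in pairs)
    cat = sum(category(p) == category(r) for p, r in pairs)
    return exact / len(pairs), cat / len(pairs)

# Illustrative predictions vs. gold codes
preds = ["G20.A1", "G43.909", "G40.909"]
refs = ["G20.A1", "G43.919", "G35"]
exact_rate, cat_rate = score(preds, refs)  # 1/3 exact, 2/3 category
```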

QLoRA Configuration

{
  "peft_type": "LORA",
  "r": 32,
  "lora_alpha": 64,
  "lora_dropout": 0.05,
  "bias": "none",
  "task_type": "CAUSAL_LM",
  "target_modules": [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj"
  ]
}

Quantization: 4-bit NF4 with double quantization, bfloat16 compute dtype.
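The JSON above maps directly onto peft's `LoraConfig`; a minimal sketch of constructing the equivalent config object (assuming peft is installed):

```python
from peft import LoraConfig

# Mirrors the adapter config above; field names match peft's LoraConfig
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)
```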

Training Hyperparameters

| Parameter | Epochs 1–3 | Epochs 4–5 (resumed) |
| --- | --- | --- |
| Learning rate | 1e-4 | 5e-5 |
| Batch size | 2 | 2 |
| Gradient accumulation | 4 (eff. batch = 8) | 4 (eff. batch = 8) |
| Max sequence length | 768 tokens | 768 tokens |
| Optimizer | AdamW (wd=0.01) | AdamW (wd=0.01) |
| LR scheduler | Cosine (5% warmup) | Cosine (10% warmup) |
| Gradient clipping | max_norm=1.0 | max_norm=1.0 |
| Mixed precision | bfloat16 | bfloat16 |
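The cosine-with-warmup schedule can be sketched in pure Python. This is an illustrative reconstruction; the exact warmup and step accounting in the training code may differ:

```python
import math

def cosine_lr(step: int, total_steps: int, base_lr: float, warmup_frac: float) -> float:
    """Linear warmup to base_lr, then cosine decay toward zero."""
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

# Epochs 1-3 settings: base LR 1e-4, 5% warmup (step counts are hypothetical)
lr_mid = cosine_lr(500, 1000, 1e-4, 0.05)  # roughly half of the base LR
```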

How to Use

Basic Usage (Direct Inference)

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# Load base model with 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    "google/medgemma-4b-it",
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained("google/medgemma-4b-it")

# Load the LoRA adapter
model = PeftModel.from_pretrained(base_model, "YOUR_USERNAME/medgemma-icd10-lora-v2")
model.eval()

# Predict ICD-10 code from a clinical note
clinical_note = """
68-year-old male presenting with 2-year history of progressive right-hand
resting tremor. Reports difficulty with fine motor tasks. Examination reveals
4-5 Hz pill-rolling tremor, cogwheel rigidity bilateral upper extremities,
bradykinesia on finger tapping. Gait shows reduced arm swing and mild shuffling.
"""

messages = [
    {"role": "user", "content": f"Given the following clinical note, predict the ICD-10-CM diagnosis code:\n\n{clinical_note}"}
]

input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
input_ids = input_ids.to(model.device)

with torch.no_grad():
    output = model.generate(input_ids, max_new_tokens=50, temperature=0.1, do_sample=True)

response = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
print(response)
# Example output: "ICD-10-CM: G20.A1 - Parkinson disease, without fluctuations"

Recommended: With BM25 RAG + Constrained Decoding

For production use, we recommend the full inference pipeline with:

  1. BM25 retrieval to narrow candidates to top-15 codes
  2. Trie-based constrained decoding to guarantee valid ICD-10 output

See the Gradio app for the complete implementation.
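The trie-constrained step can be sketched in plain Python. This character-level version is for illustration only; a real implementation would build the trie over tokenizer IDs and pass the lookup to `model.generate` via `prefix_allowed_tokens_fn`. The candidate list stands in for BM25's top-15:

```python
class CodeTrie:
    """Character-level trie over candidate codes: at each decoding step,
    only characters that extend some valid candidate are allowed."""

    END = "#"  # marks a complete code

    def __init__(self, codes):
        self.root = {}
        for code in codes:
            node = self.root
            for ch in code:
                node = node.setdefault(ch, {})
            node[self.END] = {}

    def allowed_next(self, prefix):
        """Characters that may follow `prefix` while staying on a candidate."""
        node = self.root
        for ch in prefix:
            node = node.get(ch)
            if node is None:
                return set()  # prefix has already left the candidate set
        return {ch for ch in node if ch != self.END}

# Hypothetical BM25 shortlist
trie = CodeTrie(["G20.A1", "G20.B1", "G35", "G40.909"])
print(sorted(trie.allowed_next("G")))     # ['2', '3', '4']
print(sorted(trie.allowed_next("G20.")))  # ['A', 'B']
```

Because every decoding step is restricted to characters on the trie, the model can only ever emit one of the retrieved candidate codes, which is what guarantees 100% valid output.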

Training Data

Trained on the MedGemma ICD-10 Clinical Notes Dataset β€” 3,325 synthetic clinical notes generated by MedGemma-4B-IT (self-distillation).

Key characteristics:

  • 665 ICD-10 codes (G00-G99, billable only)
  • 5 notes per code with varied styles and demographics
  • 10 prompt templates for training input diversity
  • No data leakage β€” clinical notes describe presentations without naming codes or diagnoses directly
  • Average note length: 265 words (range: 152–445)
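To show how a note/code pair might be turned into a supervised chat example, here is a minimal sketch; the template wording and the `ICD-10-CM:` target format are assumptions, not the dataset's exact strings:

```python
# Two of the (hypothetical) prompt templates used for input diversity
TEMPLATES = [
    "Given the following clinical note, predict the ICD-10-CM diagnosis code:\n\n{note}",
    "Assign the most specific ICD-10-CM code for this clinical documentation:\n\n{note}",
]

def build_example(note: str, code: str, template_idx: int = 0) -> list:
    """Pair a templated user prompt with the target code as the assistant turn."""
    return [
        {"role": "user", "content": TEMPLATES[template_idx].format(note=note)},
        {"role": "assistant", "content": f"ICD-10-CM: {code}"},
    ]

example = build_example("Patient with relapsing neurologic deficits...", "G35")
```

Note that the target is the bare code: the note itself never names the diagnosis, which is what keeps the dataset free of label leakage.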

Intended Use

  • Medical coding assistance: Suggest ICD-10 codes from clinical documentation
  • Research: Benchmarking clinical NLP models on structured code prediction
  • Education: Demonstrating QLoRA fine-tuning for domain-specific medical AI tasks

Limitations

  • Nervous System only β€” trained on G00-G99 codes; will not predict codes from other ICD-10 chapters
  • Single diagnosis β€” predicts one code per note; real encounters often require multiple codes
  • Synthetic training data β€” not trained on real clinical records
  • Not clinically validated β€” has not been evaluated by certified medical coders against production data
  • English only

Ethical Considerations

  • This model is for research and educational purposes only
  • ICD-10 coding in production requires certified medical coders and validated, regulated systems
  • Incorrect diagnosis codes can lead to claim denials, billing errors, and patient safety issues
  • Always have a human expert review model predictions before clinical or billing use

Technical Details

  • Framework: Pure PyTorch training loop (no HF Trainer dependency)
  • Environment: Python 3.14, CUDA 12.8, PyTorch 2.10+cu128
  • Training time: ~6 hours for 5 epochs on RTX 5070

Citation

@misc{medgemma_icd10_finetuning,
  title={Fine-Tuning MedGemma-4B for ICD-10 Diagnosis Coding},
  author={singhak-abbvie},
  year={2026},
  url={https://github.com/singhak-abbvie/medgemma_finetuning_ICD_10}
}
