singhankit16 committed on
Commit d88facc · verified · 1 Parent(s): db31955

Upload 6 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,200 @@
- ---
- license: mit
- ---
+ ---
+ language:
+ - en
+ license: gemma
+ library_name: peft
+ base_model: google/medgemma-4b-it
+ tags:
+ - medical
+ - icd-10
+ - clinical-coding
+ - lora
+ - qlora
+ - medgemma
+ - healthcare
+ - neurology
+ - peft
+ pipeline_tag: text-generation
+ datasets:
+ - YOUR_USERNAME/medgemma-icd10-clinical-notes
+ ---
+
+ # MedGemma-4B ICD-10 Diagnosis Coding — QLoRA Adapter (v2, Epoch 5)
+
+ A **QLoRA fine-tuned adapter** for [google/medgemma-4b-it](https://huggingface.co/google/medgemma-4b-it) that predicts ICD-10-CM diagnosis codes from clinical notes. Focused on **Chapter 6: Diseases of the Nervous System (G00-G99)** — 665 billable codes.
+
+ ## Model Description
+
+ This is a **LoRA adapter** (not a full model) that must be loaded on top of the MedGemma-4B-IT base model. It was trained on 3,325 synthetic clinical notes generated by MedGemma itself (self-distillation), covering diverse documentation styles: SOAP notes, H&P exams, progress notes, consultation reports, and brief assessments.
+
+ | Property | Value |
+ |----------|-------|
+ | **Base Model** | [google/medgemma-4b-it](https://huggingface.co/google/medgemma-4b-it) (4B params) |
+ | **Adapter Type** | LoRA (PEFT v0.18.1) |
+ | **Adapter Size** | 250 MB |
+ | **Task** | ICD-10-CM diagnosis code prediction from clinical notes |
+ | **Domain** | Nervous System (G00-G99), 665 billable codes |
+ | **Training Data** | 3,325 LLM-generated clinical notes (5 per code) |
+ | **Training Epochs** | 5 (3 initial + 2 resumed with lower LR) |
+ | **Hardware** | Single NVIDIA RTX 5070 (12GB VRAM) |
+
+ ## Evaluation Results
+
+ Evaluated on 250 held-out clinical notes across 50 ICD-10 codes:
+
+ | Metric | Baseline (No FT) | After Fine-Tuning |
+ |--------|-------------------|-------------------|
+ | **Exact Code Match** | 0% | 0.4–0.8% |
+ | **Category Match (3-char prefix)** | ~10% | **73.2%** |
+ | **Produces Valid ICD-10 Code** | ~20% | **100%** |
+
+ **Note on category match**: The 73.2% was measured on a challenging 250-example eval set with diverse, LLM-generated clinical notes. When combined with BM25 retrieval and trie-based constrained decoding (see the inference pipeline below), accuracy improves substantially.
+
+ An earlier V1 model trained on simpler template-based data achieved 38% exact match and 88% category match on a smaller 50-example eval set — demonstrating that evaluation difficulty scales with dataset diversity.
+
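+ For reference, exact and category match reduce to simple string comparisons against the gold code. The sketch below is illustrative only; the regex and helper names are not taken from the evaluation script:
+
+ ```python
+ import re
+
+ def extract_code(response: str) -> str | None:
+     """Pull the first G-chapter, ICD-10-CM-looking token (e.g. 'G20.A1') out of a response."""
+     m = re.search(r"\bG\d{2}(?:\.[A-Z0-9]{1,4})?\b", response)
+     return m.group(0) if m else None
+
+ def score(responses: list[str], gold: list[str]) -> dict[str, float]:
+     exact = category = valid = 0
+     for resp, code in zip(responses, gold):
+         pred = extract_code(resp)
+         if pred is None:
+             continue
+         valid += 1                             # format-valid; the card's "valid" metric likely
+                                                # also checks membership in the G00-G99 code list
+         exact += int(pred == code)             # exact code match
+         category += int(pred[:3] == code[:3])  # category match: 3-character prefix, e.g. 'G20'
+     n = len(gold)
+     return {"exact": exact / n, "category": category / n, "valid": valid / n}
+ ```
+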
+ ## QLoRA Configuration
+
+ ```json
+ {
+   "peft_type": "LORA",
+   "r": 32,
+   "lora_alpha": 64,
+   "lora_dropout": 0.05,
+   "bias": "none",
+   "task_type": "CAUSAL_LM",
+   "target_modules": [
+     "q_proj", "k_proj", "v_proj", "o_proj",
+     "gate_proj", "up_proj", "down_proj"
+   ]
+ }
+ ```
+
+ **Quantization**: 4-bit NF4 with double quantization, bfloat16 compute dtype.
+
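+ Expressed with `peft`, the adapter settings above correspond roughly to the following `LoraConfig` (a sketch derived from the JSON, not copied from the training code):
+
+ ```python
+ from peft import LoraConfig
+
+ lora_config = LoraConfig(
+     r=32,
+     lora_alpha=64,
+     lora_dropout=0.05,
+     bias="none",
+     task_type="CAUSAL_LM",
+     target_modules=[
+         "q_proj", "k_proj", "v_proj", "o_proj",
+         "gate_proj", "up_proj", "down_proj",
+     ],
+ )
+ # get_peft_model(base_model, lora_config) would wrap a 4-bit base model with these adapters.
+ ```
+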
+ ## Training Hyperparameters
+
+ | Parameter | Epochs 1–3 | Epochs 4–5 (resumed) |
+ |-----------|------------|----------------------|
+ | Learning rate | 1e-4 | 5e-5 |
+ | Batch size | 2 | 2 |
+ | Gradient accumulation | 4 (eff. batch = 8) | 4 (eff. batch = 8) |
+ | Max sequence length | 768 tokens | 768 tokens |
+ | Optimizer | AdamW (wd=0.01) | AdamW (wd=0.01) |
+ | LR scheduler | Cosine (5% warmup) | Cosine (10% warmup) |
+ | Gradient clipping | max_norm=1.0 | max_norm=1.0 |
+ | Mixed precision | bfloat16 | bfloat16 |
+
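+ The training code is a plain PyTorch loop (see Technical Details below); the sketch here shows how the epoch 1–3 settings would map onto such a loop. The `model` and `train_loader` objects and all variable names are assumptions for illustration, not the repository's actual script:
+
+ ```python
+ import torch
+ from transformers import get_cosine_schedule_with_warmup
+
+ # Assumed to exist: `model` (4-bit base wrapped with the LoRA adapter) and `train_loader`
+ # (tokenized batches of up to 768 tokens, batch size 2).
+ accum_steps = 4                                           # effective batch size 2 x 4 = 8
+ epochs = 3                                                # first training phase
+ total_updates = epochs * len(train_loader) // accum_steps
+
+ optimizer = torch.optim.AdamW(
+     (p for p in model.parameters() if p.requires_grad),   # only LoRA params are trainable
+     lr=1e-4, weight_decay=0.01,
+ )
+ scheduler = get_cosine_schedule_with_warmup(
+     optimizer,
+     num_warmup_steps=int(0.05 * total_updates),           # 5% warmup (10% in the resumed run)
+     num_training_steps=total_updates,
+ )
+
+ model.train()
+ for epoch in range(epochs):
+     for step, batch in enumerate(train_loader):
+         with torch.autocast("cuda", dtype=torch.bfloat16):  # bfloat16 mixed precision
+             loss = model(**batch).loss / accum_steps
+         loss.backward()
+         if (step + 1) % accum_steps == 0:
+             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+             optimizer.step()
+             scheduler.step()
+             optimizer.zero_grad()
+ ```
+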
+ ## How to Use
+
+ ### Basic Usage (Direct Inference)
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+ from peft import PeftModel
+
+ # Load base model with 4-bit quantization
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_compute_dtype=torch.bfloat16,
+ )
+
+ base_model = AutoModelForCausalLM.from_pretrained(
+     "google/medgemma-4b-it",
+     quantization_config=bnb_config,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained("google/medgemma-4b-it")
+
+ # Load the LoRA adapter
+ model = PeftModel.from_pretrained(base_model, "YOUR_USERNAME/medgemma-icd10-lora-v2")
+ model.eval()
+
+ # Predict ICD-10 code from a clinical note
+ clinical_note = """
+ 68-year-old male presenting with 2-year history of progressive right-hand
+ resting tremor. Reports difficulty with fine motor tasks. Examination reveals
+ 4-5 Hz pill-rolling tremor, cogwheel rigidity bilateral upper extremities,
+ bradykinesia on finger tapping. Gait shows reduced arm swing and mild shuffling.
+ """
+
+ messages = [
+     {"role": "user", "content": f"Given the following clinical note, predict the ICD-10-CM diagnosis code:\n\n{clinical_note}"}
+ ]
+
+ input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
+ input_ids = input_ids.to(model.device)
+
+ with torch.no_grad():
+     output = model.generate(input_ids, max_new_tokens=50, temperature=0.1, do_sample=True)
+
+ response = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
+ print(response)
+ # Example output: "ICD-10-CM: G20.A1 - Parkinson disease, without fluctuations"
+ ```
+
+ ### Recommended: With BM25 RAG + Constrained Decoding
+
+ For production use, we recommend the full inference pipeline with:
+ 1. **BM25 retrieval** to narrow candidates to top-15 codes
+ 2. **Trie-based constrained decoding** to guarantee valid ICD-10 output
+
+ See the [Gradio app](https://github.com/singhak-abbvie/medgemma_finetuning_ICD_10/blob/main/app_icd10.py) for the complete implementation.
+
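+ As a rough sketch of the retrieval step, a BM25 index over ICD-10 code descriptions can shortlist candidates before generation. The `rank_bm25` package and the `code_descriptions` mapping below are assumptions; the Gradio app may implement this differently:
+
+ ```python
+ from rank_bm25 import BM25Okapi  # third-party package; an assumption, not confirmed by the repo
+
+ # code_descriptions: {"G20.A1": "Parkinson's disease ...", ...}, an illustrative mapping
+ # built from the same G00-G99 billable code list used for training.
+ codes = list(code_descriptions)
+ bm25 = BM25Okapi([desc.lower().split() for desc in code_descriptions.values()])
+
+ def retrieve_candidates(note: str, k: int = 15) -> list[str]:
+     """Return the k ICD-10 codes whose descriptions best match the clinical note."""
+     scores = bm25.get_scores(note.lower().split())
+     ranked = sorted(range(len(codes)), key=lambda i: scores[i], reverse=True)
+     return [codes[i] for i in ranked[:k]]
+
+ # The shortlist is injected into the prompt, and trie-based constrained decoding then restricts
+ # generation to token sequences of codes in the shortlist (see app_icd10.py for the real pipeline).
+ ```
+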
+ ## Training Data
+
+ Trained on the [MedGemma ICD-10 Clinical Notes Dataset](https://huggingface.co/datasets/YOUR_USERNAME/medgemma-icd10-clinical-notes) — 3,325 synthetic clinical notes generated by MedGemma-4B-IT (self-distillation).
+
+ Key characteristics:
+ - **665 ICD-10 codes** (G00-G99, billable only)
+ - **5 notes per code** with varied styles and demographics
+ - **10 prompt templates** for training input diversity
+ - **No data leakage** — clinical notes describe presentations without naming codes or diagnoses directly
+ - **Average note length**: 265 words (range: 152–445)
+
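+ Once the placeholder namespace is filled in, the dataset can be loaded with the `datasets` library:
+
+ ```python
+ from datasets import load_dataset
+
+ # Replace YOUR_USERNAME with the actual namespace from the dataset link above.
+ # Column names are defined by the dataset card, not this model card.
+ ds = load_dataset("YOUR_USERNAME/medgemma-icd10-clinical-notes")
+ print(ds)
+ ```
+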
+ ## Intended Use
+
+ - **Medical coding assistance**: Suggest ICD-10 codes from clinical documentation
+ - **Research**: Benchmarking clinical NLP models on structured code prediction
+ - **Education**: Demonstrating QLoRA fine-tuning for domain-specific medical AI tasks
+
+ ## Limitations
+
+ - **Nervous System only** — trained on G00-G99 codes; will not predict codes from other ICD-10 chapters
+ - **Single diagnosis** — predicts one code per note; real encounters often require multiple codes
+ - **Synthetic training data** — not trained on real clinical records
+ - **Not clinically validated** — has not been evaluated by certified medical coders against production data
+ - **English only**
+
+ ## Ethical Considerations
+
+ - This model is for **research and educational purposes only**
+ - ICD-10 coding in production requires certified medical coders and validated, regulated systems
+ - Incorrect diagnosis codes can lead to claim denials, billing errors, and patient safety issues
+ - Always have a human expert review model predictions before clinical or billing use
+
+ ## Technical Details
+
+ - **Framework**: Pure PyTorch training loop (no HF Trainer dependency)
+ - **Environment**: Python 3.14, CUDA 12.8, PyTorch 2.10+cu128
+ - **Training time**: ~6 hours for 5 epochs on RTX 5070
+
+ ## Citation
+
+ ```bibtex
+ @misc{medgemma_icd10_finetuning,
+   title={Fine-Tuning MedGemma-4B for ICD-10 Diagnosis Coding},
+   author={singhak-abbvie},
+   year={2026},
+   url={https://github.com/singhak-abbvie/medgemma_finetuning_ICD_10}
+ }
+ ```
+
+ ## Links
+
+ - **GitHub**: [singhak-abbvie/medgemma_finetuning_ICD_10](https://github.com/singhak-abbvie/medgemma_finetuning_ICD_10)
+ - **Dataset**: [HF Dataset](https://huggingface.co/datasets/YOUR_USERNAME/medgemma-icd10-clinical-notes)
adapter_config.json ADDED
@@ -0,0 +1,46 @@
+ {
+   "alora_invocation_tokens": null,
+   "alpha_pattern": {},
+   "arrow_config": null,
+   "auto_mapping": null,
+   "base_model_name_or_path": "google/medgemma-4b-it",
+   "bias": "none",
+   "corda_config": null,
+   "ensure_weight_tying": false,
+   "eva_config": null,
+   "exclude_modules": null,
+   "fan_in_fan_out": false,
+   "inference_mode": true,
+   "init_lora_weights": true,
+   "layer_replication": null,
+   "layers_pattern": null,
+   "layers_to_transform": null,
+   "loftq_config": {},
+   "lora_alpha": 64,
+   "lora_bias": false,
+   "lora_dropout": 0.05,
+   "megatron_config": null,
+   "megatron_core": "megatron.core",
+   "modules_to_save": null,
+   "peft_type": "LORA",
+   "peft_version": "0.18.1",
+   "qalora_group_size": 16,
+   "r": 32,
+   "rank_pattern": {},
+   "revision": null,
+   "target_modules": [
+     "down_proj",
+     "gate_proj",
+     "o_proj",
+     "k_proj",
+     "v_proj",
+     "q_proj",
+     "up_proj"
+   ],
+   "target_parameters": null,
+   "task_type": "CAUSAL_LM",
+   "trainable_token_indices": null,
+   "use_dora": false,
+   "use_qalora": false,
+   "use_rslora": false
+ }
adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bfebe700a30cd45ce829cc63d18ea20b97a44b2616aefd8eecd583ea23b9aa04
+ size 262406656
chat_template.jinja ADDED
@@ -0,0 +1,47 @@
+ {{ bos_token }}
+ {%- if messages[0]['role'] == 'system' -%}
+ {%- if messages[0]['content'] is string -%}
+ {%- set first_user_prefix = messages[0]['content'] + '
+
+ ' -%}
+ {%- else -%}
+ {%- set first_user_prefix = messages[0]['content'][0]['text'] + '
+
+ ' -%}
+ {%- endif -%}
+ {%- set loop_messages = messages[1:] -%}
+ {%- else -%}
+ {%- set first_user_prefix = "" -%}
+ {%- set loop_messages = messages -%}
+ {%- endif -%}
+ {%- for message in loop_messages -%}
+ {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
+ {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
+ {%- endif -%}
+ {%- if (message['role'] == 'assistant') -%}
+ {%- set role = "model" -%}
+ {%- else -%}
+ {%- set role = message['role'] -%}
+ {%- endif -%}
+ {{ '<start_of_turn>' + role + '
+ ' + (first_user_prefix if loop.first else "") }}
+ {%- if message['content'] is string -%}
+ {{ message['content'] | trim }}
+ {%- elif message['content'] is iterable -%}
+ {%- for item in message['content'] -%}
+ {%- if item['type'] == 'image' -%}
+ {{ '<start_of_image>' }}
+ {%- elif item['type'] == 'text' -%}
+ {{ item['text'] | trim }}
+ {%- endif -%}
+ {%- endfor -%}
+ {%- else -%}
+ {{ raise_exception("Invalid content type") }}
+ {%- endif -%}
+ {{ '<end_of_turn>
+ ' }}
+ {%- endfor -%}
+ {%- if add_generation_prompt -%}
+ {{'<start_of_turn>model
+ '}}
+ {%- endif -%}
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:96e462e2ee46f6a9af0bef56724cb91559bf1c618d7c2900174643bcda3cf563
+ size 33384831
tokenizer_config.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "backend": "tokenizers",
+   "boi_token": "<start_of_image>",
+   "bos_token": "<bos>",
+   "clean_up_tokenization_spaces": false,
+   "eoi_token": "<end_of_image>",
+   "eos_token": "<eos>",
+   "image_token": "<image_soft_token>",
+   "is_local": true,
+   "mask_token": "<mask>",
+   "max_length": 768,
+   "model_max_length": 1000000000000000019884624838656,
+   "model_specific_special_tokens": {
+     "boi_token": "<start_of_image>",
+     "eoi_token": "<end_of_image>",
+     "image_token": "<image_soft_token>"
+   },
+   "pad_to_multiple_of": null,
+   "pad_token": "<pad>",
+   "pad_token_type_id": 0,
+   "padding_side": "left",
+   "processor_class": "Gemma3Processor",
+   "sp_model_kwargs": null,
+   "spaces_between_special_tokens": false,
+   "stride": 0,
+   "tokenizer_class": "GemmaTokenizer",
+   "truncation_side": "right",
+   "truncation_strategy": "longest_first",
+   "unk_token": "<unk>",
+   "use_default_system_prompt": false
+ }