Model Card for MedGemma-4b-ICD

This model is a fine-tuned version of google/medgemma-4b-it. It has been trained using TRL.

Quick start

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("abnuel/MedGemma-4b-ICD", torch_dtype=torch.bfloat16, device_map="auto")

tokenizer = AutoTokenizer.from_pretrained("abnuel/MedGemma-4b-ICD")

SYSTEM_PROMPT = "You are an expert medical coder. Your task is to analyze the clinical description provided and output only the single, most appropriate ICD-10-CM code. Do not include any text or justification other than the code itself."

def generate_response(clinical_note):
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Code the following: {clinical_note}"},
    ]

    # 3. Apply chat template and tokenize

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)

    input_len = inputs["input_ids"].shape[-1]

    # 4. Generate the response
    with torch.inference_mode():
        generation = model.generate(
            **inputs,
            max_new_tokens=200,      # Upper bound on generated tokens; ICD codes are short
            do_sample=False,         # Greedy decoding for deterministic output
        )

    # 5. Decode the output
    # Extract only the newly generated tokens
    generation = generation[0][input_len:]
    decoded_output = tokenizer.decode(generation, skip_special_tokens=True)

    return decoded_output.strip()

# --- Example Usage ---
test_note = "Sudden onset chest pain and shortness of breath. Initial diagnosis points towards unstable angina."


print(f"Clinical Note: {test_note}")
response = generate_response(test_note)
print(f"Generated ICD Codes: {response}")
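Since the model is instructed to emit a single ICD-10-CM code, it can be useful to sanity-check the generated string before using it downstream. The helper below is a hypothetical addition, not part of this model's tooling: the regex is a simplified approximation of the ICD-10-CM code shape (it checks format only, not membership in the official code set).

```python
import re

# Simplified ICD-10-CM shape: a letter, a digit, an alphanumeric, then an
# optional dot followed by 1-4 alphanumerics (e.g. "I20.0", "J45.909").
# Format check only; it does NOT verify the code exists in the official set.
ICD10CM_PATTERN = re.compile(r"^[A-Z][0-9][0-9A-Z](?:\.[0-9A-Z]{1,4})?$")

def is_plausible_icd10cm(code: str) -> bool:
    """Return True if the string looks like a single ICD-10-CM code."""
    return bool(ICD10CM_PATTERN.match(code.strip().upper()))
```

For example, `is_plausible_icd10cm("I20.0")` returns True, while free-text output such as "unstable angina" fails the check and can be flagged for review.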

Training procedure

This model was trained with SFT.
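As a rough illustration of what SFT with TRL looks like for this task, the sketch below shows one plausible way to cast (clinical note, ICD code) pairs into TRL's conversational format. The dataset fields, prompt wording, and the commented trainer settings are assumptions for illustration, not the actual training recipe.

```python
# Hypothetical sketch of SFT data preparation; field names, prompt wording,
# and hyperparameters are assumptions, not the actual recipe used.

SYSTEM_PROMPT = (
    "You are an expert medical coder. Your task is to analyze the clinical "
    "description provided and output only the single, most appropriate "
    "ICD-10-CM code."
)

def to_chat_example(clinical_note: str, icd_code: str) -> dict:
    """Convert one (note, code) pair into TRL's conversational format."""
    return {
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Code the following: {clinical_note}"},
            {"role": "assistant", "content": icd_code},
        ]
    }

# Training would then use TRL's SFTTrainer on the formatted dataset, roughly:
#
# from trl import SFTTrainer, SFTConfig
# trainer = SFTTrainer(
#     model="google/medgemma-4b-it",
#     train_dataset=formatted_dataset,  # rows in the "messages" format above
#     args=SFTConfig(output_dir="MedGemma-4b-ICD"),
# )
# trainer.train()
```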

Framework versions

  • TRL: 0.24.0
  • Transformers: 4.57.1
  • Pytorch: 2.6.0+cu124
  • Datasets: 4.1.1
  • Tokenizers: 0.22.1

Citations

Cite TRL as:

@misc{vonwerra2022trl,
    title        = {{TRL: Transformer Reinforcement Learning}},
    author       = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec},
    year         = 2020,
    journal      = {GitHub repository},
    publisher    = {GitHub},
    howpublished = {\url{https://github.com/huggingface/trl}}
}