icd10-api / app.py
StudioIlios's picture
Update app.py
ac332d0 verified
Raw
History Blame Contribute Delete
1.67 kB
import os
import re
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
# Repository ids: MODEL_ID supplies the tokenizer and the LoRA adapter,
# BASE_MODEL is the checkpoint the adapter is attached to.
MODEL_ID = "StudioIlios/icd10-model"
BASE_MODEL = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"

# Hugging Face access token taken from the environment (None when unset).
hf_token = os.getenv("HF_TOKEN")

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token)
# Reuse the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

print("Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    token=hf_token,
    device_map="auto",
)

print("Loading LoRA adapter...")
model = PeftModel.from_pretrained(base_model, MODEL_ID, token=hf_token)
# Switch to evaluation mode; this script only runs inference.
model.eval()
print("Model loaded successfully!")
# Shape of an ICD-10 code: a letter, a digit, one alphanumeric, then an
# optional dot followed by one to four alphanumerics (e.g. "E11" or "E11.9").
# Compiled once at import time instead of on every request.
_ICD10_PATTERN = re.compile(r'\b[A-Z][0-9][A-Z0-9](?:\.[A-Z0-9]{1,4})?\b')


def predict(clinical_text):
    """Predict an ICD-10 code for a free-text clinical description.

    Builds the prompt, greedily generates up to 20 new tokens with the
    module-level ``model``/``tokenizer``, and extracts the first
    ICD-10-shaped token from the generated continuation.

    Args:
        clinical_text: Free-text description of the patient's condition.

    Returns:
        The first ICD-10-shaped code found in the model's generated text.
        If the generation contains no such code, the full decoded output
        (prompt plus generation) is returned. Empty, whitespace-only or
        ``None`` input returns ``""`` without running the model.
    """
    # Nothing to classify: skip generation rather than let the model
    # invent a code for an empty description.
    if not clinical_text or not clinical_text.strip():
        return ""

    # Prompt text is kept byte-for-byte identical to the original template.
    prompt = f"\nPatient has {clinical_text}.\nWhat is the ICD10 code?\n"

    # The model is loaded with device_map="auto", so it may live on an
    # accelerator; the input tensors must be on the same device.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=20,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # A causal LM's generate() output starts with the prompt tokens. Decode
    # only the continuation so that code-like tokens in the user's own text
    # (e.g. "A1C") are never mistaken for the model's prediction.
    prompt_length = inputs["input_ids"].shape[1]
    generated_text = tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True,
    )

    # NOTE(review): these prints write raw clinical text to the process
    # logs - confirm that is acceptable for the deployment.
    print("INPUT:", clinical_text)
    print("OUTPUT:", full_output)

    icd_match = _ICD10_PATTERN.search(generated_text)
    if icd_match:
        return icd_match.group(0)

    # No recognisable code in the generation: fall back to the full
    # decoded text, as before.
    return full_output
# Gradio UI: a four-line text box feeding `predict`, and a text box that
# shows the value `predict` returns.
_clinical_text_box = gr.Textbox(label="Clinical Text", lines=4)
_icd_code_box = gr.Textbox(label="Predicted ICD-10 Code")

demo = gr.Interface(
    fn=predict,
    inputs=_clinical_text_box,
    outputs=_icd_code_box,
    title="ICD-10 Predictor",
)

# Start the web app as soon as this module is executed.
demo.launch()