Diabetes_readmission / models /bart_explainer.py
Parishri07's picture
Upload 11 files
dda3dc2 verified
raw
history blame contribute delete
788 Bytes
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Hugging Face Hub checkpoint used to generate explanations.
# NOTE(review): the name suggests BART-large fine-tuned on MIMIC-III notes;
# confirm against the model card (it may be a summarizer rather than an
# instruction-following model).
MODEL_NAME = "dmacres/bart-large-mimiciii-v2"
# Tokenizer and model are loaded once, at import time, as module-level
# globals used by generate_explanation.
# NOTE(review): from_pretrained may download weights on first import, so
# importing this module can require network access and be slow.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
def generate_explanation(
    note: str,
    risk_score: float,
    *,
    max_input_length: int = 1024,
    max_output_length: int = 200,
    num_beams: int = 4,
) -> str:
    """Generate a free-text explanation for a predicted readmission risk.

    Builds a prompt from the discharge note and the risk score, then runs
    beam-search generation with the module-level ``tokenizer`` and ``model``.

    When the prompt would exceed ``max_input_length`` tokens, only the note
    is shortened (its beginning is kept). The risk score and the instruction
    that follow the note always reach the model.

    Args:
        note: Discharge summary text.
        risk_score: Predicted readmission risk, rendered with two decimals.
        max_input_length: Maximum encoder input length in tokens, special
            tokens included. Defaults to 1024.
        max_output_length: ``max_length`` passed to ``model.generate``.
        num_beams: Beam width passed to ``model.generate``.

    Returns:
        The generated text, decoded with special tokens removed.
    """
    head = "Discharge summary:\n"
    tail = (
        f"\nPredicted readmission risk: {risk_score:.2f}\n"
        "Explain the key clinical reasons for readmission risk."
    )

    # Token cost of the fixed parts of the prompt, including the special
    # tokens the tokenizer adds around a full input.
    frame_len = len(tokenizer(head + tail)["input_ids"])
    budget = max(max_input_length - frame_len, 0)

    # Right-truncating the whole prompt would drop the score and the
    # instruction (they come after the note), so trim the note itself.
    note_ids = tokenizer(note, add_special_tokens=False)["input_ids"]
    if len(note_ids) > budget:
        note = tokenizer.decode(
            note_ids[:budget],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

    inputs = tokenizer(
        head + note + tail,
        return_tensors="pt",
        # Backstop only: tokenizing the pieces separately can differ from
        # tokenizing the joined prompt by a token or so at the seams.
        truncation=True,
        max_length=max_input_length,
    )
    outputs = model.generate(
        **inputs,
        max_length=max_output_length,
        num_beams=num_beams,
        early_stopping=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)