Spaces:
Sleeping
Sleeping
File size: 2,173 Bytes
7d6a969 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 | # app.py
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# ================= CONFIG ================= #
# Local checkpoint directory; both the tokenizer and the seq2seq model are
# loaded from it.
MODEL_DIR: str = "finetuned_model"

# Token budgets. The encoder input (prompt + dialogue) is truncated to
# MAX_INPUT_LEN tokens; generation is capped via `max_length=MAX_OUTPUT_LEN`.
MAX_INPUT_LEN: int = 1024
MAX_OUTPUT_LEN: int = 256

# Beam width handed to model.generate().
NUM_BEAMS: int = 4

# Instruction text prepended verbatim to every user dialogue.
# NOTE(review): presumably this must match the prompt used during
# fine-tuning — confirm before changing the wording.
PROMPT: str = (
    "Generate a structured SOAP clinical summary with clearly separated "
    "Subjective (S), Objective (O), Assessment (A), and Plan (P) sections "
    "from the following medical dialogue:\n"
)
# ================= LOAD MODEL ================= #
# Executed once at import time; every request then reuses these globals.
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
# nn.Module.eval() returns the module itself, so it can be chained here.
# It switches the model to inference mode (e.g. disables dropout).
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR).eval()
# ================= INFERENCE FUNCTION ================= #
def generate_soap(dialogue):
    """Summarise a doctor-patient dialogue as a SOAP clinical note.

    The dialogue is appended to ``PROMPT``, tokenized with truncation to
    ``MAX_INPUT_LEN`` tokens, and decoded with beam search
    (``NUM_BEAMS`` beams, ``max_length=MAX_OUTPUT_LEN``).

    Args:
        dialogue: Raw text from the Gradio textbox; may be ``None`` or blank.

    Returns:
        The decoded text of the first generated sequence with special tokens
        removed, or a short instruction message when ``dialogue`` is ``None``,
        empty, or whitespace-only.
    """
    # Guard clause: nothing meaningful to summarise.
    if not (dialogue and dialogue.strip()):
        return "Please enter a medical dialogue."

    encoded = tokenizer(
        PROMPT + dialogue,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_LEN,
    )

    # Decoding settings. no_repeat_ngram_size and repetition_penalty both
    # discourage the repetitive output beam search tends to produce.
    generation_kwargs = dict(
        max_length=MAX_OUTPUT_LEN,
        num_beams=NUM_BEAMS,
        no_repeat_ngram_size=3,
        repetition_penalty=1.3,
        length_penalty=1.0,
        early_stopping=True,
    )

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        generated = model.generate(**encoded, **generation_kwargs)

    best_sequence = generated[0]
    return tokenizer.decode(best_sequence, skip_special_tokens=True)
# ================= GRADIO UI ================= #
def _build_interface():
    """Assemble the single-input, single-output Gradio app around generate_soap."""
    dialogue_box = gr.Textbox(
        lines=10,
        placeholder="Enter doctor–patient medical dialogue here...",
        label="Medical Dialogue",
    )
    summary_box = gr.Textbox(
        lines=12,
        label="Generated SOAP Clinical Summary",
    )
    # One inner list per example; each holds one value per input component.
    sample_dialogues = [
        ["Patient reports fever and cough for three days. No history of chronic illness."],
        ["Patient complains of chest pain and shortness of breath during exertion."],
    ]
    return gr.Interface(
        fn=generate_soap,
        inputs=dialogue_box,
        outputs=summary_box,
        title="SOAP Clinical Summary Generator",
        description="Fine-tuned FLAN-T5 model for SOAP note generation.",
        examples=sample_dialogues,
    )


iface = _build_interface()

if __name__ == "__main__":
    iface.launch()
|