Spaces:
Sleeping
Sleeping
| # app.py | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| # ================= CONFIG ================= # | |
# Directory passed to from_pretrained() for both the tokenizer and the model.
MODEL_DIR = "finetuned_model"

# Token budgets: the tokenizer truncates the prompt + dialogue to
# MAX_INPUT_LEN, and generate() is capped at MAX_OUTPUT_LEN.
MAX_INPUT_LEN = 1024
MAX_OUTPUT_LEN = 256

# Beam width used for beam-search decoding.
NUM_BEAMS = 4

# Instruction prepended verbatim to every dialogue before tokenization.
PROMPT = (
    "Generate a structured SOAP clinical summary "
    "with clearly separated Subjective (S), Objective (O), "
    "Assessment (A), and Plan (P) sections "
    "from the following medical dialogue:\n"
)
| # ================= LOAD MODEL ================= # | |
# Tokenizer and model are loaded once, at import time, from MODEL_DIR.
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
# eval() returns the module itself, so it can be chained onto the load.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR).eval()
| # ================= INFERENCE FUNCTION ================= # | |
def generate_soap(dialogue):
    """Generate a SOAP clinical summary for a medical dialogue.

    Args:
        dialogue: Dialogue text entered in the UI. May be ``None``, empty,
            or whitespace-only.

    Returns:
        The decoded model output with special tokens stripped, or the
        message ``"Please enter a medical dialogue."`` when ``dialogue``
        is missing or blank.
    """
    # Guard clause: nothing to summarise, so skip tokenization and generation.
    if not dialogue or not dialogue.strip():
        return "Please enter a medical dialogue."

    # The prompt counts toward the limit; anything past MAX_INPUT_LEN tokens
    # is truncated by the tokenizer.
    inputs = tokenizer(
        PROMPT + dialogue,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_LEN,
    )
    # Keep the input tensors on the same device as the model weights so
    # generation also works if the model is placed on an accelerator.
    inputs = inputs.to(model.device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_length=MAX_OUTPUT_LEN,
            num_beams=NUM_BEAMS,
            no_repeat_ngram_size=3,
            repetition_penalty=1.3,
            length_penalty=1.0,
            early_stopping=True,
        )

    # Only the first returned sequence is decoded.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
| # ================= GRADIO UI ================= # | |
# One text box in, one text box out, with generate_soap as the handler.
iface = gr.Interface(
    title="SOAP Clinical Summary Generator",
    description="Fine-tuned FLAN-T5 model for SOAP note generation.",
    fn=generate_soap,
    inputs=gr.Textbox(
        label="Medical Dialogue",
        placeholder="Enter doctor–patient medical dialogue here...",
        lines=10,
    ),
    outputs=gr.Textbox(
        label="Generated SOAP Clinical Summary",
        lines=12,
    ),
    # Each example row holds a single string for the one input text box.
    examples=[
        ["Patient reports fever and cough for three days. No history of chronic illness."],
        ["Patient complains of chest pain and shortness of breath during exertion."],
    ],
)

# Launch the interface only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()