Spaces:
Runtime error
Runtime error
import ast
import html
import json
import os

import spaces
import gradio as gr
import xgrammar as xgr
from huggingface_hub import login as hf_login
from pydantic import BaseModel
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
# Authenticate against the Hugging Face Hub only when a token is configured.
# Passing token=None makes `login` fall back to an interactive prompt, which
# cannot be answered when the app runs headless.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    hf_login(token=_hf_token)

# Checkpoint used for extraction; the tokenizer and config are loaded from
# the same repository further down.
model_name = "gregorlied/Llama-3.2-1B-Instruct-Medical-Report-Summarization-FP32"

# NOTE(review): trust_remote_code=True allows Python code shipped with the
# model repository to run at load time -- confirm it is actually required
# for this checkpoint and that the repository is trusted.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    attn_implementation="eager",
    trust_remote_code=True,
)
class Person(BaseModel):
    """Schema of the structured record extracted from one clinical report.

    The class is compiled into a JSON-schema grammar that constrains the
    model output, and its field names are the keys the HTML renderer looks
    up. Every field is a plain string; the prompt asks the model to write
    "N/A" when nothing is reported and to separate multiple facts with "; ".

    NOTE(review): field order presumably determines the property order of
    the generated JSON -- confirm before reordering or renaming fields.
    """

    # Shown in the UI as "Lifestyle".
    life_style: str
    # Shown in the UI as "Family Members".
    family_history: str
    # Shown in the UI as "Social Background".
    social_history: str
    # Shown in the UI as "Personal" (medical history).
    medical_surgical_history: str
    # Shown in the UI as "Symptoms".
    signs_symptoms: str
    # Shown in the UI as "Comorbid Conditions".
    comorbidities: str
    # Shown in the UI as "Diagnostic Procedures".
    diagnostic_techniques_procedures: str
    # Shown in the UI as "Diagnosis".
    diagnosis: str
    # Shown in the UI as "Laboratory Results".
    laboratory_values: str
    # Shown in the UI as "Pathology Report".
    pathology: str
    # Shown in the UI as "Pharmacological Therapy".
    pharmacological_therapy: str
    # Shown in the UI as "Interventional Therapy".
    interventional_therapy: str
    # Shown in the UI as "Patient Outcome".
    patient_outcome_assessment: str
    # Shown in the UI as "Age".
    age: str
    # Shown in the UI as "Gender".
    gender: str
# Tokenizer and model config come from the same checkpoint as the model.
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# xgrammar needs the size of the model's output vocabulary (the width of the
# logits), which is what the config reports. The tokenizer length can differ
# from it, e.g. when the embedding matrix is padded.
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
    tokenizer, vocab_size=config.vocab_size
)

# Compile the Person schema once at start-up; each request only has to wrap
# the compiled grammar in a fresh logits processor.
grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
compiled_grammar = grammar_compiler.compile_json_schema(Person)
# Sample clinical report pre-filled in the input box. `summarize` compares the
# submitted text against this constant to serve a stored answer for the demo
# report, so the text box default and this value must stay identical.
# The degree and multiplication signs were restored from mis-encoded bytes.
default_value = "A 57-year-old male presented with fever (38.9°C), chest pain, cough, and progressive dyspnea. The patient exhibited tachypnea (34 breaths/min) and tachycardia (134 bpm). Auscultation revealed decreased breath sounds in both lung bases, with crackles on the left. A chest X-ray revealed bilateral pleural opacities and enlargement of the cardiac silhouette ( A). Echocardiography showed moderate pericardial effusion affecting the entire cardiac silhouette. Pericardiocentesis yielded 250 mL of exudative fluid. A CT scan of the chest showed pneumonia in the left lower lobe, bilateral pleural effusion, and moderate pericardial effusion ( B). Thoracentesis was performed and yielded 1,050 mL of exudative fluid. Laboratory tests yielded the following data: white blood cell count, 11.78 × 109 cells/L (84.3% neutrophils, 4.3% lymphocytes, and 9.1% monocytes); platelet count, 512 × 109/L; serum C-reactive protein, 31.27 mg/dL; serum creatinine, 0.94 mg/dL; serum sodium, 133 mEq/L; and serum potassium, 3.72 mEq/L. Examination of the pleural fluid showed a pH of 7.16, a glucose level of 4.5 mg/dL, proteins at 49.1 g/L, and an LDH content of 1,385 U/L. A urinary pneumococcal antigen test was positive. Pleural fluid culture was positive for S. pneumoniae. The patient was treated for four weeks with amoxicillin-clavulanate (2.2 g/8 h, i.v.) plus levofloxacin (500 mg twice a day), together with a nonsteroidal anti-inflammatory drug (ibuprofen, 800 mg/day), after which there was nearly complete resolution of the alterations seen on the chest X-ray and CT scan."
# System prompt sent ahead of every clinical report (it is stripped before
# use). The dash in the second instruction was restored from a mis-encoded
# character.
# NOTE(review): "multile" looks like a typo for "multiple", and the schema
# below is not strict JSON (single quotes, trailing comma). Both are left
# untouched because the fine-tuned checkpoint presumably saw exactly this
# wording during training -- confirm before editing the text.
prompt = """You are a text extraction system for clinical reports.
Please extract relevant clinical information from the report.
### Instructions
- Use the JSON Schema given below.
- Return only a valid JSON object — no markdown, no comments.
- If no relevant facts are given for a field, set its value to "N/A".
- If multile relevant facts are given for a field, separate them with "; ".
### JSON Schema
{
'life_style': '',
'family_history': '',
'social_history': '',
'medical_surgical_history': '',
'signs_symptoms': '',
'comorbidities': '',
'diagnostic_techniques_procedures': '',
'diagnosis': '',
'laboratory_values': '',
'pathology': '',
'pharmacological_therapy': '',
'interventional_therapy': '',
'patient_outcome_assessment': '',
'age': '',
'gender': '',
}
### Clinical Report
"""
def generate_html_tables(data, selected_fields):
    """Render extracted report fields as grouped HTML tables.

    Args:
        data: Mapping from schema key (for example ``"signs_symptoms"``) to
            the extracted value. Values are normally strings in which several
            facts are separated by ``";"``. A missing key, ``None`` or the
            literal ``"N/A"`` is rendered as "Not Available"; any other
            non-string value is converted with ``str``.
        selected_fields: Display labels (for example ``"Symptoms"``) chosen
            in the UI. Labels that belong to no section are ignored; ``None``
            is treated as an empty selection.

    Returns:
        One HTML string. Each section with at least one selected field
        becomes a table; the first four rendered tables are laid out two per
        row, any further tables three per row.
    """
    key_label_map = {
        'age': 'Age',
        'gender': 'Gender',
        'life_style': 'Lifestyle',
        'social_history': 'Social Background',
        'medical_surgical_history': 'Personal',
        'family_history': 'Family Members',
        'signs_symptoms': 'Symptoms',
        'comorbidities': 'Comorbid Conditions',
        'diagnostic_techniques_procedures': 'Diagnostic Procedures',
        'laboratory_values': 'Laboratory Results',
        'pathology': 'Pathology Report',
        'diagnosis': 'Diagnosis',
        'interventional_therapy': 'Interventional Therapy',
        'pharmacological_therapy': 'Pharmacological Therapy',
        'patient_outcome_assessment': 'Patient Outcome',
    }
    label_key_map = {label: key for key, label in key_label_map.items()}

    # Section title -> labels, in display order.
    categories = {
        "Personal Information": ["Age", "Gender", "Lifestyle", "Social Background"],
        "Medical History": ["Personal", "Family Members"],
        "Clinical Presentation": ["Symptoms", "Comorbid Conditions"],
        "Medical Assessment": ["Diagnostic Procedures", "Laboratory Results", "Pathology Report"],
        "Diagnosis": ["Diagnosis"],
        "Treatment": ["Interventional Therapy", "Pharmacological Therapy"],
        "Patient Outcome": ["Patient Outcome"],
    }

    not_available = "<i>Not Available</i>"

    def format_details(value):
        """Return the HTML for one cell: placeholder, plain text or bullets."""
        if value is None or value == "N/A":
            return not_available
        # The value originates from model output over user-supplied text, so
        # it is escaped before being embedded in markup.
        items = [
            html.escape(part.strip())
            for part in str(value).split(";")
            if part.strip()
        ]
        if not items:
            return not_available
        if len(items) == 1:
            return items[0]
        bullets = "".join(f"<li>{item}</li>" for item in items)
        return f"<ul style='margin: 0; padding-left: 1em'>{bullets}</ul>"

    table_style = (
        "width: 100%;"
        "height: 100%;"
        "table-layout: fixed;"
    )
    th_td_style = (
        "padding: 8px;"
        "border: 1px solid #ccc;"
        "vertical-align: top;"
        "text-align: left;"
    )

    selected = set(selected_fields or ())

    html_tables = []
    for section, labels in categories.items():
        section_fields = [label for label in labels if label in selected]
        if not section_fields:
            continue
        parts = [
            f"<h3 style='margin-bottom: 0.5em;'>{section}</h3>",
            f"<table style='{table_style}'>",
            f"<tr><th style='height: 30px; {th_td_style}; width: 150px;'>Field</th>"
            f"<th style='height: 30px; {th_td_style};'>Details</th></tr>",
        ]
        for label in section_fields:
            details = format_details(data.get(label_key_map[label], "N/A"))
            parts.append(
                f"<tr><td style='{th_td_style}; width: 150px;'><b>{label}</b></td>"
                f"<td style='{th_td_style}'>{details}</td></tr>"
            )
        parts.append("</table>")
        html_tables.append("".join(parts))

    # Lay the tables out in flex rows: two per row for the first four
    # rendered tables, three per row afterwards.
    rows = []
    start = 0
    while start < len(html_tables):
        num_per_row = 2 if start < 4 else 3
        cells = "".join(
            f"<div style='display: flex; flex-direction: column;'>{table}</div>"
            for table in html_tables[start:start + num_per_row]
        )
        rows.append(
            f"<div style='display: flex; gap: 1em; margin-bottom: 2em;'>{cells}</div>"
        )
        start += num_per_row

    return f"<div style='font-family: sans-serif;'>{''.join(rows)}</div>"
# Keys of the extraction schema, used to build the all-"N/A" fallback record.
_EXTRACTION_KEYS = (
    'life_style',
    'family_history',
    'social_history',
    'medical_surgical_history',
    'signs_symptoms',
    'comorbidities',
    'diagnostic_techniques_procedures',
    'diagnosis',
    'laboratory_values',
    'pathology',
    'pharmacological_therapy',
    'interventional_therapy',
    'patient_outcome_assessment',
    'age',
    'gender',
)

# Stored model answer for the pre-filled sample report, so the demo responds
# without running generation.
_DEFAULT_REPORT_RESPONSE = '{"life_style": "N/A", "family_history": "N/A", "social_history": "N/A", "medical_surgical_history": "N/A", "signs_symptoms": "Fever; Chest pain; Cough; Progressive dyspnea; Tachypnea; Tachycardia; Decreased breath sounds in both lung bases; Crackles on the left", "comorbidities": "N/A", "diagnostic_techniques_procedures": "Chest X-ray; Echocardiography; Thoracentesis; Laboratory tests; Pleural fluid analysis; Urinary pneumococcal antigen test; Pleural fluid culture", "diagnosis": "Pneumonia; Pericardial effusion; S. pneumoniae infection", "laboratory_values": "White blood cell count: 11.78 \\u00d7 10^9 cells/L (84.3% neutrophils, 4.3% lymphocytes, 9.1% monocytes); Platelet count: 512 \\u00d7 10^9/L; Serum C-reactive protein: 31.27 mg/dL; Serum creatinine: 0.94 mg/dL; Serum sodium: 133 mEq/L; Serum potassium: 3.72 mEq/L; Pleural fluid pH: 7.16; Pleural fluid glucose: 4.5 mg/dL; Pleural fluid proteins: 49.1 g/L; Pleural fluid LDH: 1,385 U/L", "pathology": "N/A", "pharmacological_therapy": "Amoxicillin-clavulanate (2.2 g/8 h, i.v.); Levofloxacin (500 mg twice a day); Ibuprofen (800 mg/day)", "interventional_therapy": "Pericardiocentesis; Thoracentesis", "patient_outcome_assessment": "Nearly complete resolution of alterations on chest X-ray and CT scan", "age": "57 year", "gender": "Male"}'


def _parse_extraction(raw):
    """Decode the model's JSON answer; return an all-"N/A" record on failure."""
    try:
        parsed = json.loads(raw)
    except (TypeError, ValueError):
        # Not valid JSON, e.g. generation stopped at the token limit before
        # the object was closed.
        parsed = None
    if isinstance(parsed, dict):
        return parsed
    return dict.fromkeys(_EXTRACTION_KEYS, 'N/A')


def _generate_extraction(report):
    """Run grammar-constrained generation on one report; return the raw text."""
    messages = [
        {"role": "system", "content": prompt.strip()},
        {"role": "user", "content": report},
    ]
    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,  # only relevant for qwen
    )
    # We cannot reset here because __call__ is not invoked when stop token is sampled.
    # Therefore, each `generate()` call needs to instantiate a LogitsProcessor.
    xgr_logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar)
    model_inputs = tokenizer([chat_text], return_tensors="pt").to(model.device)
    output_ids = model.generate(
        input_ids=model_inputs["input_ids"],
        attention_mask=model_inputs["attention_mask"],
        max_new_tokens=2048,
        logits_processor=[xgr_logits_processor],
    )
    # Drop the prompt tokens so that only the newly generated part is decoded.
    completions = [
        sequence[len(prompt_ids):]
        for prompt_ids, sequence in zip(model_inputs["input_ids"], output_ids)
    ]
    return tokenizer.batch_decode(completions, skip_special_tokens=True)[0]


def summarize(
    text,
    personal_info,
    medical_history,
    clinical_presentation,
    medical_assessment,
    diagnosis,
    treatment,
    patient_outcome,
):
    """Extract structured information from a clinical report and render it.

    Args:
        text: The clinical report entered by the user.
        personal_info, medical_history, clinical_presentation,
        medical_assessment, diagnosis, treatment, patient_outcome: The labels
            ticked in the corresponding checkbox group; only these fields are
            shown in the result.

    Returns:
        An HTML string with the result tables, or a plain hint when no report
        text was supplied.
    """
    if not text or not text.strip():
        return "Please enter some text to summarize."

    report = text.strip()
    if report == default_value.strip():
        raw_response = _DEFAULT_REPORT_RESPONSE
    else:
        raw_response = _generate_extraction(report)

    data = _parse_extraction(raw_response)

    selected_fields = [
        *personal_info,
        *medical_history,
        *clinical_presentation,
        *medical_assessment,
        *diagnosis,
        *treatment,
        *patient_outcome,
    ]
    return generate_html_tables(data, selected_fields)
# NOTE(review): the container nesting below was reconstructed from a listing
# that had lost its indentation -- confirm the layout against the running app.

# Centered logo and title shown above the tabs.
_HEADER_HTML = """
<div align="center">
<img src="https://huggingface.co/spaces/gregorlied/medical-text-summarization/resolve/main/assets/LlamaMD-logo.png" alt="LlamaMD Logo" width="120" style="margin-bottom: 10px;">
<h2><strong>LlamaMD</strong></h2>
<p><em>Structured Information Extraction from Clinical Reports</em></p>
</div>
"""

# Help tab: one entry per selectable field, grouped like the checkbox groups.
_HELP_MARKDOWN = """## Help
### Personal Information
**Age**: Age of the patient.<br>
**Gender**: Gender of the patient.<br>
**Lifestyle**: Daily habits and activities of the patient (e.g. alcohol consumption, diet, smoking status).<br>
**Social Background**: Social factors of the patient (e.g. housing situation, marital status).<br>
### Medical History
**Personal**: Past medical conditions, previous surgeries or treatments of the patient.<br>
**Family Members**: Relevant medical conditions or genetic disorders in the patient’s family (e.g. cancer, heart disease).<br>
### Clinical Presentation
**Symptoms**: Current symptoms of the patient.<br>
**Comorbid Conditions**: Other medical conditions of the patient that may influence the treatment.<br>
### Medical Assessment
**Diagnostic Procedures**: Description of the diagnostic tests or procedures performed (e.g. X-rays, MRIs)<br>
**Laboratory Results**: Results from laboratory tests (e.g. blood counts, electrolyte levels)<br>
**Pathology Report**: Findings from pathological examinations (e.g. biopsy results)<br>
### Diagnosis
**Diagnosis**: All levels of diagnosis mentioned in the report.<br>
### Treatment
**Interventional Therapy**: Information on surgical or non-surgical interventions performed.<br>
**Pharmacological Therapy**: Medications prescribed to the patient.<br>
### Patient Outcome
**Patient Outcome**: Evaluation of the patient’s health status at the end of treatment.<br>
"""

_ABOUT_MARKDOWN = """## About
LlamaMD is a project developed as part of the "NLP for Social Good" course at TU Berlin.
The goal of this project is to perform structured information extraction from clinical reports, helping doctors to have more time for their patients.
The system is based on `meta-llama/Llama-3.2-1B-Instruct`, which has been fine-tuned on the ELMTEX dataset.
"""


def _checkbox_group(label, choices):
    """Create a checkbox group in which every choice starts out selected."""
    return gr.CheckboxGroup(label=label, choices=list(choices), value=list(choices))


with gr.Blocks() as demo:
    # need to be combined with `hf_oauth: true` in README.md
    # button = gr.LoginButton("Sign in")
    with gr.Column():
        gr.HTML(_HEADER_HTML)

    with gr.Tabs():
        with gr.Tab("LlamaMD"):
            with gr.Row():
                input_text = gr.Textbox(
                    label="Clinical Report",
                    autoscroll=False,
                    lines=15,
                    max_lines=15,
                    placeholder="Paste your clinical report here...",
                    value=default_value,
                )

            # Field selection. The choice labels must match the labels that
            # generate_html_tables groups into its sections.
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    with gr.Column():
                        personal_info = _checkbox_group(
                            "Personal Information",
                            ["Age", "Gender", "Lifestyle", "Social Background"],
                        )
                    with gr.Column():
                        medical_history = _checkbox_group(
                            "Medical History",
                            ["Personal", "Family Members"],
                        )
                with gr.Row():
                    with gr.Column():
                        clinical_presentation = _checkbox_group(
                            "Clinical Presentation",
                            ["Symptoms", "Comorbid Conditions"],
                        )
                    with gr.Column():
                        medical_assessment = _checkbox_group(
                            "Medical Assessment",
                            ["Diagnostic Procedures", "Laboratory Results", "Pathology Report"],
                        )
                with gr.Row():
                    with gr.Column():
                        diagnosis = _checkbox_group("Diagnosis", ["Diagnosis"])
                    with gr.Column():
                        treatment = _checkbox_group(
                            "Treatment",
                            ["Interventional Therapy", "Pharmacological Therapy"],
                        )
                    with gr.Column():
                        patient_outcome = _checkbox_group(
                            "Patient Outcome",
                            ["Patient Outcome"],
                        )

            with gr.Row():
                summarize_btn = gr.Button("Extract")
            with gr.Row():
                output_text = gr.HTML()

            # The input order must match the parameter order of `summarize`.
            summarize_btn.click(
                fn=summarize,
                inputs=[
                    input_text,
                    personal_info,
                    medical_history,
                    clinical_presentation,
                    medical_assessment,
                    diagnosis,
                    treatment,
                    patient_outcome,
                ],
                outputs=output_text,
                show_progress=True,
            )

        with gr.Tab("Help"):
            gr.Markdown(_HELP_MARKDOWN)

        with gr.Tab("About"):
            gr.Markdown(_ABOUT_MARKDOWN)


if __name__ == "__main__":
    demo.launch()