Spaces:
Sleeping
Sleeping
| # # # ============================================================ | |
| # # # ๐ง Medical Dialogue โ SOAP Note Generator (Fine-tuned Phi-3) | |
| # # # ============================================================ | |
| # # import gradio as gr | |
| # # import torch | |
| # # import transformers | |
| # # from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
| # # from peft import PeftModel, PeftConfig | |
| # # import os | |
| # # import gc | |
| # # if hasattr(transformers, "DynamicCache"): | |
| # # transformers.DynamicCache.seen_tokens = property(lambda self: None) | |
| # # # ------------------------------------------------------------ | |
| # # # โ๏ธ Environment setup for memory management | |
| # # # ------------------------------------------------------------ | |
| # # os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" | |
| # # os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| # # # ------------------------------------------------------------ | |
| # # # ๐๏ธ Model identifiers | |
| # # # ------------------------------------------------------------ | |
| # # base_model_name = "microsoft/phi-3-mini-4k-instruct" | |
| # # fine_tuned_model_name = "raselmeya2194/med_dialogue2soap" | |
| # # # ------------------------------------------------------------ | |
| # # # ๐งฎ Quantization (4-bit for memory efficiency if GPU is present) | |
| # # # ------------------------------------------------------------ | |
| # # bnb_config = BitsAndBytesConfig( | |
| # # load_in_4bit=True, | |
| # # bnb_4bit_use_double_quant=True, | |
| # # bnb_4bit_quant_type="nf4", | |
| # # bnb_4bit_compute_dtype=torch.bfloat16 | |
| # # ) if torch.cuda.is_available() else None | |
| # # # ------------------------------------------------------------ | |
| # # # ๐ง Load Base Model (with automatic device mapping) | |
| # # # ------------------------------------------------------------ | |
| # # print("๐น Loading base model...") | |
| # # try: | |
| # # fine_tuned_model = AutoModelForCausalLM.from_pretrained( | |
| # # base_model_name, | |
| # # quantization_config=bnb_config, | |
| # # device_map="auto" if torch.cuda.is_available() else None, | |
| # # trust_remote_code=True, | |
| # # torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32 | |
| # # ) | |
| # # print("โ Base model loaded successfully.") | |
| # # except Exception as e: | |
| # # print(f"โ Error loading base model: {e}") | |
| # # raise | |
| # # # ------------------------------------------------------------ | |
| # # # ๐ค Load Tokenizer | |
| # # # ------------------------------------------------------------ | |
| # # print("๐น Loading tokenizer...") | |
| # # try: | |
| # # tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True) | |
| # # if tokenizer.pad_token is None: | |
| # # tokenizer.pad_token = tokenizer.eos_token | |
| # # print("โ Tokenizer loaded successfully.") | |
| # # except Exception as e: | |
| # # print(f"โ Error loading tokenizer: {e}") | |
| # # raise | |
| # # # ------------------------------------------------------------ | |
| # # # ๐งฉ Load LoRA Adapter (PEFT fine-tuned weights) | |
| # # # ------------------------------------------------------------ | |
| # # print("๐น Loading LoRA adapter...") | |
| # # offload_dir = "./offload_folder" | |
| # # os.makedirs(offload_dir, exist_ok=True) | |
| # # try: | |
| # # peft_config = PeftConfig.from_pretrained(fine_tuned_model_name) | |
| # # fine_tuned_model = PeftModel.from_pretrained( | |
| # # fine_tuned_model, | |
| # # fine_tuned_model_name, | |
| # # offload_folder=offload_dir, | |
| # # config=peft_config | |
| # # ) | |
| # # print("โ LoRA adapter loaded successfully.") | |
| # # except Exception as e: | |
| # # print(f"โ Error loading LoRA adapter: {e}") | |
| # # raise | |
| # # # ------------------------------------------------------------ | |
| # # # ๐ Prepare model for inference | |
| # # # ------------------------------------------------------------ | |
| # # fine_tuned_model.eval() | |
| # # device = "cuda" if torch.cuda.is_available() else "cpu" | |
| # # fine_tuned_model.to(device) | |
| # # if torch.cuda.device_count() > 1: | |
| # # print(f"๐ Using {torch.cuda.device_count()} GPUs via DataParallel!") | |
| # # fine_tuned_model = torch.nn.DataParallel(fine_tuned_model) | |
| # # if torch.cuda.is_available(): | |
| # # torch.cuda.empty_cache() | |
| # # gc.collect() | |
| # # # ------------------------------------------------------------ | |
| # # # ๐งพ SOAP Note Generation Function | |
| # # # ------------------------------------------------------------ | |
| # # def generate_soap(input_text: str, temperature: float = 0.7, max_new_tokens: int = 512): | |
| # # """ | |
| # # Generates a SOAP note from a doctor-patient dialogue using the fine-tuned model. | |
| # # """ | |
| # # try: | |
| # # # ๐งฉ Format input | |
| # # prompt = ( | |
| # # f"<|user|>\n" | |
| # # f"Generate a structured SOAP note based on the following doctor-patient dialogue:\n\n" | |
| # # f"{input_text.strip()}\n" | |
| # # f"<|end|>\n<|assistant|>SOAP Notes:\n" | |
| # # ) | |
| # # # ๐ Tokenize input | |
| # # inputs = tokenizer( | |
| # # prompt, | |
| # # return_tensors="pt", | |
| # # truncation=True, | |
| # # max_length=4096, | |
| # # padding=True | |
| # # ).to(device) | |
| # # autocast_device = "cuda" if torch.cuda.is_available() else "cpu" | |
| # # # โก Inference | |
| # # with torch.no_grad(): | |
| # # with torch.amp.autocast(autocast_device): | |
| # # model = fine_tuned_model.module if isinstance(fine_tuned_model, torch.nn.DataParallel) else fine_tuned_model | |
| # # outputs = model.generate( | |
| # # input_ids=inputs["input_ids"], | |
| # # attention_mask=inputs.get("attention_mask"), | |
| # # max_new_tokens=max_new_tokens, | |
| # # temperature=temperature, | |
| # # top_p=0.9, | |
| # # top_k=50, | |
| # # pad_token_id=tokenizer.pad_token_id, | |
| # # eos_token_id=tokenizer.eos_token_id, | |
| # # do_sample=True, | |
| # # no_repeat_ngram_size=3 | |
| # # ) | |
| # # # ๐ Extract and decode generated tokens | |
| # # generated_ids = outputs[0][inputs["input_ids"].shape[-1]:] | |
| # # generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip() | |
| # # # ๐งผ Clean output | |
| # # if "SOAP Notes:" in generated_text: | |
| # # generated_text = generated_text.split("SOAP Notes:")[-1].strip() | |
| # # return generated_text | |
| # # except Exception as e: | |
| # # return f"โ Error generating SOAP note: {str(e)}" | |
| # # # ------------------------------------------------------------ | |
| # # # ๐จ Gradio Interface | |
| # # # ------------------------------------------------------------ | |
| # # ============================================================ | |
| # # ๐ง Medical Dialogue โ SOAP Note Generator (Fine-tuned Phi-3) | |
| # # ============================================================ | |
| # import gradio as gr | |
| # import torch | |
| # from transformers import AutoModelForCausalLM, AutoTokenizer | |
| # from peft import PeftModel, PeftConfig | |
| # import os | |
| # import gc | |
| # # ------------------------------------------------------------ | |
| # # โ๏ธ Environment setup for memory management | |
| # # ------------------------------------------------------------ | |
| # os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| # # ------------------------------------------------------------ | |
| # # ๐๏ธ Model identifiers | |
| # # ------------------------------------------------------------ | |
| # base_model_name = "microsoft/phi-3-mini-4k-instruct" | |
| # fine_tuned_model_name = "raselmeya2194/med_dialogue2soap" | |
| # # ------------------------------------------------------------ | |
| # # ๐ง Load Base Model (CPU mode) | |
| # # ------------------------------------------------------------ | |
| # print("๐น Loading base model on CPU...") | |
| # try: | |
| # fine_tuned_model = AutoModelForCausalLM.from_pretrained( | |
| # base_model_name, | |
| # device_map=None, # Force CPU | |
| # trust_remote_code=True, | |
| # torch_dtype=torch.float32 | |
| # ) | |
| # print("โ Base model loaded successfully.") | |
| # except Exception as e: | |
| # print(f"โ Error loading base model: {e}") | |
| # raise | |
| # # ------------------------------------------------------------ | |
| # # ๐ค Load Tokenizer | |
| # # ------------------------------------------------------------ | |
| # print("๐น Loading tokenizer...") | |
| # try: | |
| # tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True) | |
| # if tokenizer.pad_token is None: | |
| # tokenizer.pad_token = tokenizer.eos_token | |
| # print("โ Tokenizer loaded successfully.") | |
| # except Exception as e: | |
| # print(f"โ Error loading tokenizer: {e}") | |
| # raise | |
| # # ------------------------------------------------------------ | |
| # # ๐งฉ Load LoRA Adapter (PEFT fine-tuned weights) | |
| # # ------------------------------------------------------------ | |
| # print("๐น Loading LoRA adapter...") | |
| # try: | |
| # peft_config = PeftConfig.from_pretrained(fine_tuned_model_name) | |
| # fine_tuned_model = PeftModel.from_pretrained( | |
| # fine_tuned_model, | |
| # fine_tuned_model_name, | |
| # config=peft_config | |
| # ) | |
| # print("โ LoRA adapter loaded successfully.") | |
| # except Exception as e: | |
| # print(f"โ Error loading LoRA adapter: {e}") | |
| # raise | |
| # # ------------------------------------------------------------ | |
| # # ๐ Prepare model for inference | |
| # # ------------------------------------------------------------ | |
| # fine_tuned_model.eval() | |
| # device = torch.device("cpu") | |
| # fine_tuned_model.to(device) | |
| # gc.collect() | |
| # torch.cuda.empty_cache() | |
| # # ------------------------------------------------------------ | |
| # # ๐งพ SOAP Note Generation Function | |
| # # ------------------------------------------------------------ | |
| # def generate_soap(input_text: str, temperature: float = 0.7, max_new_tokens: int = 512): | |
| # """ | |
| # Generates a SOAP note from a doctor-patient dialogue using the fine-tuned model (CPU). | |
| # """ | |
| # try: | |
| # # Format input | |
| # prompt = ( | |
| # f"<|user|>\n" | |
| # f"Generate a structured SOAP note based on the following doctor-patient dialogue:\n\n" | |
| # f"{input_text.strip()}\n" | |
| # f"<|end|>\n<|assistant|>SOAP Notes:\n" | |
| # ) | |
| # # Tokenize input | |
| # inputs = tokenizer( | |
| # prompt, | |
| # return_tensors="pt", | |
| # truncation=True, | |
| # max_length=2048, | |
| # padding=True | |
| # ).to(device) | |
| # # Run model (no autocast or GPU) | |
| # with torch.no_grad(): | |
| # outputs = fine_tuned_model.generate( | |
| # input_ids=inputs["input_ids"], | |
| # attention_mask=inputs.get("attention_mask"), | |
| # max_new_tokens=max_new_tokens, | |
| # temperature=temperature, | |
| # top_p=0.9, | |
| # top_k=50, | |
| # pad_token_id=tokenizer.pad_token_id, | |
| # eos_token_id=tokenizer.eos_token_id, | |
| # do_sample=True, | |
| # no_repeat_ngram_size=3, | |
| # use_cache=False | |
| # ) | |
| # # Extract and decode generated tokens | |
| # generated_ids = outputs[0][inputs["input_ids"].shape[-1]:] | |
| # generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip() | |
| # # Clean up final output | |
| # if "SOAP Notes:" in generated_text: | |
| # generated_text = generated_text.split("SOAP Notes:")[-1].strip() | |
| # return generated_text | |
| # except Exception as e: | |
| # return f"โ Error generating SOAP note: {str(e)}" | |
| # # ------------------------------------------------------------ | |
| # # ๐จ Gradio Interface | |
| # # ------------------------------------------------------------ | |
| # custom_css = """ | |
| # /* Import Google Fonts */ | |
| # @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap'); | |
| # /* Global Styles */ | |
| # body { | |
| # background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); | |
| # font-family: 'Inter', sans-serif; | |
| # margin: 0; | |
| # padding: 0; | |
| # min-height: 100vh; | |
| # } | |
| # .gradio-container { | |
| # max-width: 1200px !important; | |
| # margin: 0 auto !important; | |
| # padding: 2rem !important; | |
| # } | |
| # /* Main Container Styling */ | |
| # .contain { | |
| # background: rgba(255, 255, 255, 0.95) !important; | |
| # backdrop-filter: blur(10px); | |
| # border-radius: 24px !important; | |
| # box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3) !important; | |
| # padding: 3rem !important; | |
| # animation: fadeIn 0.6s ease-in-out; | |
| # } | |
| # @keyframes fadeIn { | |
| # from { | |
| # opacity: 0; | |
| # transform: translateY(20px); | |
| # } | |
| # to { | |
| # opacity: 1; | |
| # transform: translateY(0); | |
| # } | |
| # } | |
| # /* Title Styles */ | |
| # h1 { | |
| # background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); | |
| # -webkit-background-clip: text; | |
| # -webkit-text-fill-color: transparent; | |
| # background-clip: text; | |
| # font-size: 3rem !important; | |
| # font-weight: 700 !important; | |
| # text-align: center !important; | |
| # margin-bottom: 1rem !important; | |
| # letter-spacing: -0.5px; | |
| # } | |
| # /* Description Styles */ | |
| # .description { | |
| # background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); | |
| # color: #2c3e50 !important; | |
| # font-size: 1.1rem !important; | |
| # padding: 1.5rem 2rem !important; | |
| # border-radius: 16px !important; | |
| # text-align: center !important; | |
| # margin: 0 auto 2.5rem !important; | |
| # border-left: 4px solid #667eea; | |
| # box-shadow: 0 4px 6px rgba(0, 0, 0, 0.07); | |
| # line-height: 1.6; | |
| # } | |
| # /* Label Styles */ | |
| # label { | |
| # font-weight: 600 !important; | |
| # color: #2c3e50 !important; | |
| # font-size: 0.95rem !important; | |
| # margin-bottom: 0.5rem !important; | |
| # display: block !important; | |
| # } | |
| # /* Input/Textarea Styles */ | |
| # textarea, input[type="text"] { | |
| # background: #ffffff !important; | |
| # border: 2px solid #e0e7ff !important; | |
| # border-radius: 12px !important; | |
| # padding: 1rem !important; | |
| # font-size: 1rem !important; | |
| # font-family: 'Inter', sans-serif !important; | |
| # transition: all 0.3s ease !important; | |
| # color: #2c3e50 !important; | |
| # text-align: left !important; | |
| # } | |
| # textarea:focus, input[type="text"]:focus { | |
| # border-color: #667eea !important; | |
| # outline: none !important; | |
| # box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1) !important; | |
| # transform: translateY(-2px); | |
| # } | |
| # textarea::placeholder, input::placeholder { | |
| # color: #94a3b8 !important; | |
| # font-style: italic; | |
| # } | |
| # /* Slider Container */ | |
| # .slider-container { | |
| # margin: 1.5rem 0 !important; | |
| # } | |
| # /* Slider Styles */ | |
| # input[type="range"] { | |
| # width: 100% !important; | |
| # height: 8px !important; | |
| # border-radius: 5px !important; | |
| # background: linear-gradient(to right, #667eea 0%, #764ba2 100%) !important; | |
| # outline: none !important; | |
| # opacity: 0.9 !important; | |
| # transition: opacity 0.2s !important; | |
| # } | |
| # input[type="range"]:hover { | |
| # opacity: 1 !important; | |
| # } | |
| # input[type="range"]::-webkit-slider-thumb { | |
| # width: 20px !important; | |
| # height: 20px !important; | |
| # border-radius: 50% !important; | |
| # background: #ffffff !important; | |
| # cursor: pointer !important; | |
| # box-shadow: 0 2px 6px rgba(102, 126, 234, 0.4) !important; | |
| # border: 3px solid #667eea !important; | |
| # } | |
| # input[type="range"]::-moz-range-thumb { | |
| # width: 20px !important; | |
| # height: 20px !important; | |
| # border-radius: 50% !important; | |
| # background: #ffffff !important; | |
| # cursor: pointer !important; | |
| # box-shadow: 0 2px 6px rgba(102, 126, 234, 0.4) !important; | |
| # border: 3px solid #667eea !important; | |
| # } | |
| # /* Info Text for Sliders */ | |
| # .info { | |
| # color: #64748b !important; | |
| # font-size: 0.875rem !important; | |
| # margin-top: 0.5rem !important; | |
| # font-style: italic; | |
| # } | |
| # /* Button Styles */ | |
| # button { | |
| # background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; | |
| # color: white !important; | |
| # border: none !important; | |
| # padding: 1rem 2.5rem !important; | |
| # border-radius: 12px !important; | |
| # cursor: pointer !important; | |
| # font-size: 1.1rem !important; | |
| # font-weight: 600 !important; | |
| # margin-top: 1.5rem !important; | |
| # width: 100% !important; | |
| # transition: all 0.3s ease !important; | |
| # box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important; | |
| # letter-spacing: 0.5px; | |
| # } | |
| # button:hover { | |
| # transform: translateY(-2px) !important; | |
| # box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important; | |
| # } | |
| # button:active { | |
| # transform: translateY(0) !important; | |
| # } | |
| # /* Output Container */ | |
| # .output-textbox { | |
| # background: #f8fafc !important; | |
| # border: 2px solid #e0e7ff !important; | |
| # border-radius: 12px !important; | |
| # padding: 1.5rem !important; | |
| # font-size: 1rem !important; | |
| # margin-top: 1.5rem !important; | |
| # line-height: 1.8 !important; | |
| # box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.06) !important; | |
| # text-align: left !important; | |
| # } | |
| # /* Example Section Styling */ | |
| # .examples { | |
| # background: #f8fafc !important; | |
| # border-radius: 16px !important; | |
| # padding: 1.5rem !important; | |
| # margin-top: 2rem !important; | |
| # border: 2px dashed #cbd5e1 !important; | |
| # } | |
| # .examples h4 { | |
| # color: #475569 !important; | |
| # font-weight: 600 !important; | |
| # margin-bottom: 1rem !important; | |
| # } | |
| # /* Loading Animation */ | |
| # .loading { | |
| # border: 3px solid #f3f4f6; | |
| # border-top: 3px solid #667eea; | |
| # border-radius: 50%; | |
| # width: 40px; | |
| # height: 40px; | |
| # animation: spin 1s linear infinite; | |
| # margin: 2rem auto; | |
| # } | |
| # @keyframes spin { | |
| # 0% { transform: rotate(0deg); } | |
| # 100% { transform: rotate(360deg); } | |
| # } | |
| # /* Card-like sections */ | |
| # .input-group { | |
| # background: #ffffff; | |
| # padding: 1.5rem; | |
| # border-radius: 12px; | |
| # margin-bottom: 1.5rem; | |
| # box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05); | |
| # border: 1px solid #f1f5f9; | |
| # } | |
| # /* Responsive Design */ | |
| # @media screen and (max-width: 768px) { | |
| # h1 { | |
| # font-size: 2rem !important; | |
| # } | |
| # .description { | |
| # font-size: 1rem !important; | |
| # padding: 1.25rem !important; | |
| # } | |
| # .contain { | |
| # padding: 1.5rem !important; | |
| # } | |
| # button { | |
| # font-size: 1rem !important; | |
| # padding: 0.875rem 2rem !important; | |
| # } | |
| # } | |
| # /* Smooth transitions for all interactive elements */ | |
| # * { | |
| # transition: all 0.2s ease; | |
| # } | |
| # /* Custom scrollbar */ | |
| # ::-webkit-scrollbar { | |
| # width: 10px; | |
| # } | |
| # ::-webkit-scrollbar-track { | |
| # background: #f1f5f9; | |
| # border-radius: 10px; | |
| # } | |
| # ::-webkit-scrollbar-thumb { | |
| # background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); | |
| # border-radius: 10px; | |
| # } | |
| # ::-webkit-scrollbar-thumb:hover { | |
| # background: linear-gradient(135deg, #764ba2 0%, #667eea 100%); | |
| # } | |
| # """ | |
| # # Gradio interface with Generate button | |
| # iface = gr.Interface( | |
| # fn=generate_soap, | |
| # title="๐ฉบ SOAP Note Generator", | |
| # description=( | |
| # "Transform doctor-patient dialogues into professional, structured SOAP notes instantly. " | |
| # "Powered by advanced AI to ensure accuracy and medical formatting standards." | |
| # ), | |
| # inputs=[ | |
| # gr.Textbox( | |
| # label="๐ Doctor-Patient Dialogue", | |
| # placeholder="Paste the complete conversation between doctor and patient here...\n\nExample:\nDoctor: Hello, what brings you in today?\nPatient: I've been having chest pain for the past week...", | |
| # lines=10, | |
| # max_lines=20, | |
| # show_label=True, | |
| # interactive=True, | |
| # elem_classes="input-group" | |
| # ), | |
| # gr.Slider( | |
| # minimum=0, | |
| # maximum=1, | |
| # step=0.05, | |
| # value=0.7, | |
| # label="๐จ Temperature (Creativity Level)", | |
| # info="Lower values = More focused and consistent | Higher values = More creative and varied output" | |
| # ), | |
| # gr.Slider( | |
| # minimum=128, | |
| # maximum=4096, | |
| # step=128, | |
| # value=512, | |
| # label="๐ Max Length (Tokens)", | |
| # info="Controls the maximum length of the generated SOAP note (1 token โ 0.75 words)" | |
| # ), | |
| # ], | |
| # outputs=[ | |
| # gr.Textbox( | |
| # label="๐ Generated SOAP Note", | |
| # placeholder="Your professionally formatted SOAP note will appear here...\n\nโ Subjective findings\nโ Objective observations\nโ Assessment\nโ Plan of care", | |
| # lines=18, | |
| # max_lines=25, | |
| # interactive=False, | |
| # show_label=True, | |
| # show_copy_button=True | |
| # ) | |
| # ], | |
| # allow_flagging="never", | |
| # live=False, | |
| # cache_examples=False, | |
| # examples=[ | |
| # ["""Doctor: Hello, can you please tell me about your past medical history? | |
| # Patient: Hi, I don't have any past medical history. | |
| # Doctor: Okay. What brings you in today? | |
| # Patient: I've been experiencing painless blurry vision in my right eye for a week now. I've also had intermittent fevers, headache, body aches, and a nonpruritic maculopapular rash on my lower legs for the past 6 months. | |
| # Doctor: Thank you for sharing that. Have you had any other symptoms such as neck stiffness, nausea, vomiting, Raynaud's phenomenon, oral ulcerations, chest pain, shortness of breath, abdominal pain, or photosensitivity? | |
| # Patient: No, only an isolated episode of left knee swelling and testicular swelling in the past. | |
| # Doctor: Do you work with any toxic substances or have any habits like smoking, drinking, or illicit drug use? | |
| # Patient: No, I work as a flooring installer and I don't have any toxic habits. | |
| # Doctor: Alright. We checked your vital signs and they were normal. During the physical exam, we found bilateral papilledema and optic nerve erythema in your right eye, which was greater than in your left eye. You also have a right inferior nasal quadrant visual field defect and a right afferent pupillary defect. Your muscle strength and reflexes were normal, and your sensation to light touch, pinprick, vibration, and proprioception was intact. We also noticed the maculopapular rash on your bilateral lower extremities. | |
| # Patient: Oh, I see. | |
| # Doctor: Your admitting labs showed some abnormal results. You have microcytic anemia with a hemoglobin of 11.6 gm/dL, hematocrit of 35.3%, and mean corpuscular volume of 76.9 fL. You also have hyponatremia with a sodium level of 133 mmol/L. Your erythrocyte sedimentation rate (ESR) is elevated at 33 mm/hr, and your C-reactive protein (CRP) is also elevated at 13.3 mg/L. Your urinalysis did not show any protein or blood. | |
| # Patient: Okay. What does that mean? | |
| # Doctor: These results could indicate an underlying inflammatory or infectious process. We also performed a lumbar puncture, which showed clear and colorless fluid, 2 red blood cells per microliter, and 56 white blood cells per microliter. | |
| # Patient: So, what's the next step? | |
| # Doctor: We need to investigate further to determine the cause of your symptoms. We'll run additional tests and consult with a specialist to get a clearer understanding of your condition. In the meantime, we'll monitor your symptoms and provide supportive care. We'll keep you informed about any new findings and discuss the best course of treatment. | |
| # Patient: Alright, thank you, Doctor.""", | |
| # 0.7, 512] | |
| # ], | |
| # theme=gr.themes.Soft( | |
| # primary_hue="indigo", | |
| # secondary_hue="purple", | |
| # neutral_hue="slate", | |
| # ), | |
| # css=custom_css | |
| # ) | |
| # # Print message to confirm interface launch | |
| # print("๐ Launching Enhanced Gradio Interface...") | |
| # print("โจ New features: Modern gradient design, smooth animations, better UX") | |
| # # Launch the Gradio interface | |
| # if __name__ == "__main__": | |
| # try: | |
| # iface.launch( | |
| # server_name="0.0.0.0", | |
| # server_port=7860, | |
| # debug=True | |
| # ) | |
| # except Exception as e: | |
| # print(f"โ Error launching Gradio: {e}") | |
| # raise | |
| # ============================================================ | |
| # 🧠 Medical Dialogue → SOAP Note Generator (Fine-tuned Phi-3) | |
| # ============================================================ | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import PeftModel, PeftConfig | |
| import os | |
| import gc | |
# ------------------------------------------------------------
# Environment setup and Hub model identifiers
# ------------------------------------------------------------
# Turn off Hugging Face `tokenizers` parallelism before any tokenizer is created
# below (commonly done to silence its fork/parallelism warning).
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Base checkpoint on the Hub, and the LoRA adapter repository fine-tuned on top of it.
base_model_name, fine_tuned_model_name = (
    "microsoft/phi-3-mini-4k-instruct",
    "raselmeya2194/med_dialogue2soap",
)
# ------------------------------------------------------------
# Load the base model (CPU only, full float32 precision)
# ------------------------------------------------------------
print("๐น Loading base model on CPU...")
try:
    fine_tuned_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        trust_remote_code=True,  # allow the checkpoint's own modeling code
        torch_dtype=torch.float32,  # full precision for CPU inference
        device_map=None,  # no automatic device dispatch: keep everything on CPU
    )
except Exception as err:
    # Log the reason, then let the failure abort start-up.
    print(f"โ Error loading base model: {err}")
    raise
else:
    print("โ Base model loaded successfully.")
# ------------------------------------------------------------
# Load the tokenizer that matches the base checkpoint
# ------------------------------------------------------------
print("๐น Loading tokenizer...")
try:
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    # Padding requires a pad token; fall back to EOS when the checkpoint defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
except Exception as err:
    # Log the reason, then let the failure abort start-up.
    print(f"โ Error loading tokenizer: {err}")
    raise
else:
    print("โ Tokenizer loaded successfully.")
# ------------------------------------------------------------
# Attach the LoRA adapter (PEFT fine-tuned weights) to the base model
# ------------------------------------------------------------
print("๐น Loading LoRA adapter...")
try:
    # Read the adapter's config first, then wrap the already-loaded base model with it.
    peft_config = PeftConfig.from_pretrained(fine_tuned_model_name)
    fine_tuned_model = PeftModel.from_pretrained(
        fine_tuned_model, fine_tuned_model_name, config=peft_config
    )
except Exception as err:
    # Log the reason, then let the failure abort start-up.
    print(f"โ Error loading LoRA adapter: {err}")
    raise
else:
    print("โ LoRA adapter loaded successfully.")
# ------------------------------------------------------------
# Prepare the model for inference
# ------------------------------------------------------------
# This version of the app runs entirely on the CPU.
device = torch.device("cpu")
# Switch to inference mode (disables training-only behavior such as dropout).
fine_tuned_model.eval()
fine_tuned_model.to(device)
# Best-effort cleanup after loading. The CUDA cache call has no effect unless CUDA
# was initialised, which presumably never happens in this CPU-only setup.
gc.collect()
torch.cuda.empty_cache()
# ------------------------------------------------------------
# 🧾 SOAP Note Generation Function
# ------------------------------------------------------------
# Token budget for the whole prompt (chat template + dialogue) fed to the model.
MAX_PROMPT_TOKENS = 2048

# Chat-format wrapper around the dialogue. The assistant turn is pre-seeded with
# "SOAP Notes:" so the model continues directly with the note body.
_PROMPT_PREFIX = (
    "<|user|>\n"
    "Generate a structured SOAP note based on the following doctor-patient dialogue:\n\n"
)
_PROMPT_SUFFIX = "\n<|end|>\n<|assistant|>SOAP Notes:\n"


def _build_prompt(dialogue: str) -> str:
    """Return the full chat prompt for *dialogue* (surrounding whitespace stripped)."""
    return f"{_PROMPT_PREFIX}{dialogue.strip()}{_PROMPT_SUFFIX}"


def _fit_dialogue(dialogue: str, max_prompt_tokens: int = MAX_PROMPT_TOKENS) -> str:
    """Trim *dialogue* so the complete prompt stays within *max_prompt_tokens*.

    Truncating the tokenized prompt as a whole would drop its tail -- the
    ``<|end|>`` / ``<|assistant|>SOAP Notes:`` markers -- for long dialogues, so
    the model would no longer be asked to answer. Instead only the dialogue is
    shortened (its beginning is kept) and the template stays intact.
    """
    template_tokens = len(tokenizer(_build_prompt(""))["input_ids"])
    budget = max(max_prompt_tokens - template_tokens, 0)
    dialogue_ids = tokenizer(dialogue, add_special_tokens=False)["input_ids"]
    if len(dialogue_ids) <= budget:
        return dialogue
    return tokenizer.decode(dialogue_ids[:budget], skip_special_tokens=True)


def _sampling_kwargs(temperature: float) -> dict:
    """Return the decoding arguments for ``generate`` for a given temperature.

    Sampling requires a strictly positive temperature (Transformers raises a
    ValueError for 0 with ``do_sample=True``), so ``temperature <= 0`` -- e.g. a
    UI slider set to 0 -- selects deterministic greedy decoding instead.
    """
    temperature = float(temperature)
    if temperature <= 0:
        return {"do_sample": False}
    return {"do_sample": True, "temperature": temperature, "top_p": 0.9, "top_k": 50}


def _stop_token_ids():
    """Return the token id(s) that end generation: EOS plus ``<|end|>`` when known.

    The prompt closes the user turn with ``<|end|>``; the model presumably closes
    its own turn the same way. Passing only EOS would let it run on into invented
    follow-up turns until ``max_new_tokens`` is exhausted.
    """
    stop_ids = [tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else []
    end_id = tokenizer.convert_tokens_to_ids("<|end|>")
    # Unknown tokens map to the unk id; only add <|end|> if it is a real vocab entry.
    if isinstance(end_id, int) and end_id != tokenizer.unk_token_id and end_id not in stop_ids:
        stop_ids.append(end_id)
    return stop_ids or None


def generate_soap(input_text: str, temperature: float = 0.7, max_new_tokens: int = 512):
    """Generate a SOAP note from a doctor-patient dialogue using the fine-tuned model (CPU).

    Args:
        input_text: Raw dialogue text (e.g. "Doctor: ...\\nPatient: ...").
        temperature: Sampling temperature; values <= 0 switch to greedy decoding.
        max_new_tokens: Upper bound on generated tokens. It is cast to ``int``
            because UI sliders may deliver floats, and it is at least 1.

    Returns:
        The generated note text, or a human-readable notice or error string.
        Errors are returned rather than raised so the UI can display them.
    """
    if not input_text or not input_text.strip():
        return "Please provide a doctor-patient dialogue to generate a SOAP note."
    try:
        prompt = _build_prompt(_fit_dialogue(input_text.strip()))
        # No truncation here: the budget was already enforced on the dialogue alone.
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = fine_tuned_model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=max(int(max_new_tokens), 1),
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=_stop_token_ids(),
                no_repeat_ngram_size=3,
                # NOTE(review): KV cache disabled -- presumably a workaround for cache
                # incompatibilities with the remote Phi-3 modeling code. It makes CPU
                # generation much slower; confirm before re-enabling.
                use_cache=False,
                **_sampling_kwargs(temperature),
            )
        # Keep only the newly generated tokens (drop the echoed prompt).
        generated_ids = outputs[0][inputs["input_ids"].shape[-1]:]
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        # The model may repeat the seeded header; keep only what follows the last one.
        if "SOAP Notes:" in generated_text:
            generated_text = generated_text.split("SOAP Notes:")[-1].strip()
        return generated_text
    except Exception as e:
        return f"❌ Error generating SOAP note: {str(e)}"
| # ------------------------------------------------------------ | |
| # 🎨 Custom CSS | |
| # ------------------------------------------------------------ | |
| custom_css = """ | |
| /* Import Google Fonts */ | |
| @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap'); | |
| /* Global Styles */ | |
| body { | |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); | |
| font-family: 'Inter', sans-serif; | |
| margin: 0; | |
| padding: 0; | |
| min-height: 100vh; | |
| } | |
| .gradio-container { | |
| max-width: 1200px !important; | |
| margin: 0 auto !important; | |
| padding: 2rem !important; | |
| } | |
| /* Main Container Styling */ | |
| .contain { | |
| background: rgba(255, 255, 255, 0.95) !important; | |
| backdrop-filter: blur(10px); | |
| border-radius: 24px !important; | |
| box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3) !important; | |
| padding: 3rem !important; | |
| animation: fadeIn 0.6s ease-in-out; | |
| } | |
| @keyframes fadeIn { | |
| from { | |
| opacity: 0; | |
| transform: translateY(20px); | |
| } | |
| to { | |
| opacity: 1; | |
| transform: translateY(0); | |
| } | |
| } | |
| /* Title Styles */ | |
| h1 { | |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); | |
| -webkit-background-clip: text; | |
| -webkit-text-fill-color: transparent; | |
| background-clip: text; | |
| font-size: 3rem !important; | |
| font-weight: 700 !important; | |
| text-align: center !important; | |
| margin-bottom: 1rem !important; | |
| letter-spacing: -0.5px; | |
| } | |
| /* Description Styles */ | |
| .description { | |
| background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); | |
| color: #2c3e50 !important; | |
| font-size: 1.1rem !important; | |
| padding: 1.5rem 2rem !important; | |
| border-radius: 16px !important; | |
| text-align: center !important; | |
| margin: 0 auto 2.5rem !important; | |
| border-left: 4px solid #667eea; | |
| box-shadow: 0 4px 6px rgba(0, 0, 0, 0.07); | |
| line-height: 1.6; | |
| } | |
| /* Label Styles */ | |
| label { | |
| font-weight: 600 !important; | |
| color: #2c3e50 !important; | |
| font-size: 0.95rem !important; | |
| margin-bottom: 0.5rem !important; | |
| display: block !important; | |
| } | |
| /* Input/Textarea Styles */ | |
| textarea, input[type="text"] { | |
| background: #ffffff !important; | |
| border: 2px solid #e0e7ff !important; | |
| border-radius: 12px !important; | |
| padding: 1rem !important; | |
| font-size: 1rem !important; | |
| font-family: 'Inter', sans-serif !important; | |
| transition: all 0.3s ease !important; | |
| color: #2c3e50 !important; | |
| text-align: left !important; | |
| } | |
| textarea:focus, input[type="text"]:focus { | |
| border-color: #667eea !important; | |
| outline: none !important; | |
| box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1) !important; | |
| transform: translateY(-2px); | |
| } | |
| textarea::placeholder, input::placeholder { | |
| color: #94a3b8 !important; | |
| font-style: italic; | |
| } | |
| /* Slider Container */ | |
| .slider-container { | |
| margin: 1.5rem 0 !important; | |
| } | |
| /* Slider Styles */ | |
| input[type="range"] { | |
| width: 100% !important; | |
| height: 8px !important; | |
| border-radius: 5px !important; | |
| background: linear-gradient(to right, #667eea 0%, #764ba2 100%) !important; | |
| outline: none !important; | |
| opacity: 0.9 !important; | |
| transition: opacity 0.2s !important; | |
| } | |
| input[type="range"]:hover { | |
| opacity: 1 !important; | |
| } | |
| input[type="range"]::-webkit-slider-thumb { | |
| width: 20px !important; | |
| height: 20px !important; | |
| border-radius: 50% !important; | |
| background: #ffffff !important; | |
| cursor: pointer !important; | |
| box-shadow: 0 2px 6px rgba(102, 126, 234, 0.4) !important; | |
| border: 3px solid #667eea !important; | |
| } | |
| input[type="range"]::-moz-range-thumb { | |
| width: 20px !important; | |
| height: 20px !important; | |
| border-radius: 50% !important; | |
| background: #ffffff !important; | |
| cursor: pointer !important; | |
| box-shadow: 0 2px 6px rgba(102, 126, 234, 0.4) !important; | |
| border: 3px solid #667eea !important; | |
| } | |
| /* Info Text for Sliders */ | |
| .info { | |
| color: #64748b !important; | |
| font-size: 0.875rem !important; | |
| margin-top: 0.5rem !important; | |
| font-style: italic; | |
| } | |
| /* Button Styles */ | |
| button { | |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; | |
| color: white !important; | |
| border: none !important; | |
| padding: 1rem 2.5rem !important; | |
| border-radius: 12px !important; | |
| cursor: pointer !important; | |
| font-size: 1.1rem !important; | |
| font-weight: 600 !important; | |
| margin-top: 1.5rem !important; | |
| transition: all 0.3s ease !important; | |
| box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important; | |
| letter-spacing: 0.5px; | |
| } | |
| button:hover { | |
| transform: translateY(-2px) !important; | |
| box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important; | |
| } | |
| button:active { | |
| transform: translateY(0) !important; | |
| } | |
| /* Output Container */ | |
| .output-textbox { | |
| background: #f8fafc !important; | |
| border: 2px solid #e0e7ff !important; | |
| border-radius: 12px !important; | |
| padding: 1.5rem !important; | |
| font-size: 1rem !important; | |
| margin-top: 1.5rem !important; | |
| line-height: 1.8 !important; | |
| box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.06) !important; | |
| text-align: left !important; | |
| } | |
| /* Example Section Styling */ | |
| .examples { | |
| background: #f8fafc !important; | |
| border-radius: 16px !important; | |
| padding: 1.5rem !important; | |
| margin-top: 2rem !important; | |
| border: 2px dashed #cbd5e1 !important; | |
| } | |
| .examples h4 { | |
| color: #475569 !important; | |
| font-weight: 600 !important; | |
| margin-bottom: 1rem !important; | |
| } | |
| /* Loading Animation */ | |
| .loading { | |
| border: 3px solid #f3f4f6; | |
| border-top: 3px solid #667eea; | |
| border-radius: 50%; | |
| width: 40px; | |
| height: 40px; | |
| animation: spin 1s linear infinite; | |
| margin: 2rem auto; | |
| } | |
| @keyframes spin { | |
| 0% { transform: rotate(0deg); } | |
| 100% { transform: rotate(360deg); } | |
| } | |
| /* Card-like sections */ | |
| .input-group { | |
| background: #ffffff; | |
| padding: 1.5rem; | |
| border-radius: 12px; | |
| margin-bottom: 1.5rem; | |
| box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05); | |
| border: 1px solid #f1f5f9; | |
| } | |
| /* Responsive Design */ | |
| @media screen and (max-width: 768px) { | |
| h1 { | |
| font-size: 2rem !important; | |
| } | |
| .description { | |
| font-size: 1rem !important; | |
| padding: 1.25rem !important; | |
| } | |
| .contain { | |
| padding: 1.5rem !important; | |
| } | |
| button { | |
| font-size: 1rem !important; | |
| padding: 0.875rem 2rem !important; | |
| } | |
| } | |
| /* Smooth transitions for all interactive elements */ | |
| * { | |
| transition: all 0.2s ease; | |
| } | |
| /* Custom scrollbar */ | |
| ::-webkit-scrollbar { | |
| width: 10px; | |
| } | |
| ::-webkit-scrollbar-track { | |
| background: #f1f5f9; | |
| border-radius: 10px; | |
| } | |
| ::-webkit-scrollbar-thumb { | |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); | |
| border-radius: 10px; | |
| } | |
| ::-webkit-scrollbar-thumb:hover { | |
| background: linear-gradient(135deg, #764ba2 0%, #667eea 100%); | |
| } | |
| """ | |
# ------------------------------------------------------------
# Gradio Blocks interface
# ------------------------------------------------------------
# Builds the app as `demo`: header, description card, dialogue input,
# sampling controls, a generate button, the read-only SOAP output and one
# worked example. Uses `gr`, `custom_css` and `generate_soap`, which are
# defined elsewhere in this file.

# Sample dialogue for the "Try this example" panel. Kept out of the layout
# code below so the component tree stays readable.
_EXAMPLE_DIALOGUE = """Doctor: Hello, can you please tell me about your past medical history?
Patient: Hi, I don't have any past medical history.
Doctor: Okay. What brings you in today?
Patient: I've been experiencing painless blurry vision in my right eye for a week now. I've also had intermittent fevers, headache, body aches, and a nonpruritic maculopapular rash on my lower legs for the past 6 months.
Doctor: Thank you for sharing that. Have you had any other symptoms such as neck stiffness, nausea, vomiting, Raynaud's phenomenon, oral ulcerations, chest pain, shortness of breath, abdominal pain, or photosensitivity?
Patient: No, only an isolated episode of left knee swelling and testicular swelling in the past.
Doctor: Do you work with any toxic substances or have any habits like smoking, drinking, or illicit drug use?
Patient: No, I work as a flooring installer and I don't have any toxic habits.
Doctor: Alright. We checked your vital signs and they were normal. During the physical exam, we found bilateral papilledema and optic nerve erythema in your right eye, which was greater than in your left eye. You also have a right inferior nasal quadrant visual field defect and a right afferent pupillary defect. Your muscle strength and reflexes were normal, and your sensation to light touch, pinprick, vibration, and proprioception was intact. We also noticed the maculopapular rash on your bilateral lower extremities.
Patient: Oh, I see.
Doctor: Your admitting labs showed some abnormal results. You have microcytic anemia with a hemoglobin of 11.6 gm/dL, hematocrit of 35.3%, and mean corpuscular volume of 76.9 fL. You also have hyponatremia with a sodium level of 133 mmol/L. Your erythrocyte sedimentation rate (ESR) is elevated at 33 mm/hr, and your C-reactive protein (CRP) is also elevated at 13.3 mg/L. Your urinalysis did not show any protein or blood.
Patient: Okay. What does that mean?
Doctor: These results could indicate an underlying inflammatory or infectious process. We also performed a lumbar puncture, which showed clear and colorless fluid, 2 red blood cells per microliter, and 56 white blood cells per microliter.
Patient: So, what's the next step?
Doctor: We need to investigate further to determine the cause of your symptoms. We'll run additional tests and consult with a specialist to get a clearer understanding of your condition. In the meantime, we'll monitor your symptoms and provide supportive care. We'll keep you informed about any new findings and discuss the best course of treatment.
Patient: Alright, thank you, Doctor."""

with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="indigo",
        secondary_hue="purple",
        neutral_hue="slate",
    ),
    css=custom_css,
    title="๐ฉบ SOAP Note Generator",
) as demo:
    # Header
    gr.Markdown("# ๐ฉบ SOAP Note Generator")

    # Description card. The `.description` rule from custom_css is applied
    # through the inner <div> only. Also passing elem_classes="description"
    # to this component would style the Markdown wrapper with the same
    # padding/margin/border/shadow and render a card nested inside a card.
    gr.Markdown(
        "<div class='description'>\n"
        "Transform doctor-patient dialogues into professional, structured SOAP notes instantly.\n"
        "Powered by advanced AI to ensure accuracy and medical formatting standards.\n"
        "</div>"
    )

    # Input section: dialogue text, sampling controls, generate button
    with gr.Row():
        with gr.Column():
            input_dialogue = gr.Textbox(
                label="๐ Doctor-Patient Dialogue",
                placeholder="Paste the complete conversation between doctor and patient here...\n\nExample:\nDoctor: Hello, what brings you in today?\nPatient: I've been having chest pain for the past week...",
                lines=10,
                max_lines=20,
                show_label=True,
                interactive=True,
                elem_classes="input-group",
            )
            with gr.Row():
                # NOTE(review): minimum=0 lets users pick temperature 0. Whether
                # that is accepted depends on how generate_soap configures
                # decoding. Hugging Face sampling rejects 0 when do_sample=True,
                # so confirm it falls back to greedy decoding.
                temperature_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    step=0.05,
                    value=0.7,
                    label="๐จ Temperature (Creativity Level)",
                    info="Lower values = More focused and consistent | Higher values = More creative and varied output",
                )
                # NOTE(review): the 4096 upper bound may exceed the model's
                # context window once the prompt is included. Whether it does
                # depends on whether generate_soap treats this as new tokens
                # or as total length; confirm.
                max_length_slider = gr.Slider(
                    minimum=128,
                    maximum=4096,
                    step=128,
                    value=512,
                    label="๐ Max Length (Tokens)",
                    info="Controls the maximum length of the generated SOAP note (1 token โ 0.75 words)",
                )
            generate_btn = gr.Button("๐ Generate SOAP Note", variant="primary", size="lg")

    # Output section: read-only result with a copy button
    with gr.Row():
        with gr.Column():
            output_soap = gr.Textbox(
                label="๐ Generated SOAP Note",
                placeholder="Your professionally formatted SOAP note will appear here...\n\nโ Subjective findings\nโ Objective observations\nโ Assessment\nโ Plan of care",
                lines=18,
                max_lines=25,
                interactive=False,
                show_label=True,
                show_copy_button=True,
            )

    # Example section. Selecting the example only fills the inputs; no fn is
    # attached here, so nothing is generated until the button is clicked.
    gr.Examples(
        examples=[[_EXAMPLE_DIALOGUE, 0.7, 512]],
        inputs=[input_dialogue, temperature_slider, max_length_slider],
        label="๐ก Try this example",
        examples_per_page=1,
    )

    # Generation is wired to the button's click event only.
    generate_btn.click(
        fn=generate_soap,
        inputs=[input_dialogue, temperature_slider, max_length_slider],
        outputs=output_soap,
    )
# Startup banner: tell the console the interface is about to launch.
# This runs at module import time; it is not inside the __main__ guard.
print(
    "\n".join(
        (
            "๐ Launching Enhanced Gradio Interface...",
            "โจ New features: Modern gradient design, smooth animations, better UX",
            "๐ Manual trigger: Generation only starts when you click the button",
        )
    )
)
# Start the Gradio server only when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    _launch_options = dict(
        server_name="0.0.0.0",  # bind on all network interfaces
        server_port=7860,
        debug=True,
    )
    try:
        demo.launch(**_launch_options)
    except Exception as launch_error:
        # Report the failure on the console, then re-raise so the error
        # still propagates and terminates the script.
        print(f"โ Error launching Gradio: {launch_error}")
        raise