# Medical_Chatbot / app.py
# Amrender's picture
# Create app.py
# f7d54db verified
import os
import torch
import re
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from peft import PeftModel
from threading import Thread
# ==========================================
# 1. SETUP & AUTHENTICATION
# ==========================================
# The Mistral base checkpoint is downloaded with a Hugging Face access token.
# Store it as the 'HF_TOKEN' secret in the Space
# ("Settings" -> "Variables and secrets").
hf_token = os.getenv("HF_TOKEN")

BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
ADAPTER_REPO = "your-username/Medical-Mistral-7B-LoRA"  # <--- CHANGE THIS TO YOUR REPO

print("Booting up Clinical AI Server...")

# Tokenizer of the base model; the EOS token doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=hf_token)
tokenizer.pad_token = tokenizer.eos_token

# 4-bit NF4 quantization settings with float16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

# Quantized base model; device placement is delegated to device_map="auto".
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    token=hf_token,
    device_map="auto",
    quantization_config=bnb_config,
)

# Wrap the quantized base model with the trained LoRA adapter from the Hub
# and switch the combined model to evaluation mode.
print("Downloading and attaching LoRA Adapter from Hub...")
model = PeftModel.from_pretrained(base_model, ADAPTER_REPO, token=hf_token)
model.eval()
print("✅ Model Ready!")
# ==========================================
# 2. FORMATTING FUNCTION
# ==========================================
# Sign-off phrases (lower-case); the response is truncated at the earliest hit.
_END_TRIGGERS = ("regards", "hope this", "hope i", "let me know", "dr.", "thanks for")

# (pattern, replacement) pairs that turn plain section labels into bold
# Markdown headers. `\b` keeps a label from matching inside a longer word
# (e.g. "urinalysis:"), and the surrounding `\s*` swallows existing whitespace
# so the replacement does not produce doubled spaces or extra blank lines.
_SECTION_HEADERS = (
    (re.compile(r"\bassessment\s*:\s*", re.IGNORECASE), "**Assessment:** "),
    (re.compile(r"\s*\banalysis\s*:\s*", re.IGNORECASE), "\n\n**Analysis:** "),
    (
        re.compile(r"\s*\brecommended action\s*:\s*", re.IGNORECASE),
        "\n\n**Recommended Action:** ",
    ),
)


def format_and_clean(text):
    """Trim sign-offs from model output and format it as Markdown sections.

    The text is cut at the first sign-off phrase, asterisks are removed, the
    first letter of every ". "-separated sentence is upper-cased, and the
    "Assessment:", "Analysis:" and "Recommended Action:" labels are rewritten
    as bold headers separated by blank lines.

    Args:
        text: Raw (possibly partial) text generated by the model.

    Returns:
        The cleaned, Markdown-formatted string; empty if nothing remains
        before the first sign-off phrase.
    """
    lower_text = text.lower()
    hits = [idx for idx in map(lower_text.find, _END_TRIGGERS) if idx != -1]
    cutoff = min(hits, default=len(text))

    clean = text[:cutoff].strip().replace("*", "")

    # Upper-case only the first character. str.capitalize() would also
    # lower-case the rest of the sentence and mangle acronyms and names
    # such as "MRI" or "COVID-19".
    sentences = [s[0].upper() + s[1:] for s in clean.split(". ") if s]
    clean = ". ".join(sentences)

    for pattern, header in _SECTION_HEADERS:
        clean = pattern.sub(header, clean)
    return clean
# ==========================================
# 3. CHAT ENGINE & UI
# ==========================================
def clinical_chat(message, history):
    """Stream a formatted clinical answer for one user message.

    Generator intended as the chat callback of the Gradio interface. It wraps
    the message in an instruction prompt, runs `model.generate` on a worker
    thread, and yields the cleaned response text as it grows.

    Args:
        message: The user's latest message; inserted verbatim into the prompt.
        history: Earlier conversation turns passed by the chat UI. Unused:
            every message is answered without conversational context.

    Yields:
        The response accumulated so far, passed through `format_and_clean`.
    """
    # NOTE(review): the prompt spells out a literal "<s>" and the tokenizer is
    # called with its default special-token handling, which may prepend a
    # second BOS token. Confirm against the format the adapter was trained on
    # before changing either side.
    prompt = f"""<s>[INST] You are a highly intelligent clinical AI. The user says: "{message}"
If the user provides symptoms, diagnose them. If the user asks about a known condition, provide treatment advice.
You MUST format your response exactly like this:
**Assessment:** [Name of the condition]
**Analysis:** [Brief explanation]
**Recommended Action:** [Medical advice]
Do not include greetings, sign-offs, or links. [/INST]"""

    # Follow the model's own placement (it is loaded with device_map="auto")
    # instead of assuming a device literally named "cuda".
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10.0
    )
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=300,
        temperature=0.2,
        repetition_penalty=1.15,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # generate() blocks until completion, so it runs on a worker thread while
    # this generator consumes the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    partial_text = ""
    for new_token in streamer:
        partial_text += new_token
        yield format_and_clean(partial_text)

    # The streamer is exhausted, so generation is finishing; reap the worker.
    thread.join()
# Build the Gradio chat UI that is served from inside the Docker container.
_example_queries = [
    "I am having severe headache, body pain and strain in my neck.",
    "What medicine should I take if I have high cholesterol?",
    "I have been sneezing and coughing for 3 days.",
]

theme = gr.themes.Soft(neutral_hue="slate", primary_hue="blue")

demo = gr.ChatInterface(
    fn=clinical_chat,
    theme=theme,
    examples=_example_queries,
    title="⚕️ Clinical AI Diagnostic Assistant",
    description="**Enter your symptoms or medical queries below for a professional analysis.**",
)

if __name__ == "__main__":
    # Docker on Hugging Face Spaces strictly requires host 0.0.0.0 and port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)