File size: 4,208 Bytes
f7d54db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import torch
import re
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from peft import PeftModel
from threading import Thread

# ==========================================
# 1. SETUP & AUTHENTICATION
# ==========================================
# Mistral requires an HF token to download the base model.
# You must add your token in the HF Spaces "Settings" -> "Variables and secrets" as 'HF_TOKEN'
# NOTE(review): os.environ.get returns None when HF_TOKEN is unset; the downloads below
# then run unauthenticated and will presumably fail for the base model — consider
# failing fast here with a clear error message.
hf_token = os.environ.get("HF_TOKEN")

# Hub repository IDs: the base model, and the LoRA adapter trained on top of it.
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
ADAPTER_REPO = "your-username/Medical-Mistral-7B-LoRA" # <--- CHANGE THIS TO YOUR REPO

print("Booting up Clinical AI Server...")

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=hf_token)
# Reuse EOS as the padding token (presumably because the base tokenizer ships
# without a dedicated pad token — TODO confirm).
tokenizer.pad_token = tokenizer.eos_token

# Load the base model's weights quantized to 4-bit NF4 to reduce GPU memory;
# computation is carried out in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, 
    bnb_4bit_quant_type="nf4", 
    bnb_4bit_compute_dtype=torch.float16
)

# device_map="auto" lets Transformers/Accelerate decide where the layers are placed
# on the available hardware (callers should not assume a specific device).
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, 
    quantization_config=bnb_config, 
    device_map="auto",
    token=hf_token
)

# Wrap the base model with the trained LoRA adapter downloaded from ADAPTER_REPO.
# The adapter is attached as separate LoRA layers; the weights are NOT merged into
# the base model (that would require an explicit merge_and_unload()).
print("Downloading and attaching LoRA Adapter from Hub...")
model = PeftModel.from_pretrained(base_model, ADAPTER_REPO, token=hf_token)
# Switch to inference mode (disables dropout and other training-only behavior).
model.eval()
print("✅ Model Ready!")

# ==========================================
# 2. FORMATTING FUNCTION
# ==========================================
# Phrases that mark the start of sign-off boilerplate in model output
# (closings, "hope this helps", doctor signatures). Matched case-insensitively;
# everything from the earliest occurrence onward is discarded.
_END_TRIGGERS = ("regards", "hope this", "hope i", "let me know", "dr.", "thanks for")
# A single alternation finds the earliest occurrence of any trigger in one pass.
# Searching the ORIGINAL text (rather than text.lower()) keeps the match index
# valid even for characters whose lowercase form has a different length.
_END_TRIGGER_RE = re.compile("|".join(map(re.escape, _END_TRIGGERS)), re.IGNORECASE)


def _upper_first(sentence):
    """Return *sentence* with its first non-whitespace character uppercased.

    Unlike ``str.capitalize``, the remaining characters are left untouched, so
    acronyms and proper nouns (ECG, NSAIDs, COVID-19) survive formatting.
    """
    body = sentence.lstrip()
    lead = sentence[: len(sentence) - len(body)]
    return lead + body[:1].upper() + body[1:]


def format_and_clean(text):
    """Trim sign-off boilerplate from model output and render section headers as Markdown.

    Steps:
      1. Cut everything from the earliest end trigger (see ``_END_TRIGGERS``).
      2. Strip surrounding whitespace and remove every ``*`` (the model's own
         bold markers), then uppercase the first letter of each ". "-separated
         sentence. Empty sentences are dropped.
      3. Re-emit "Assessment:", "Analysis:" and "Recommended Action:" (any case)
         as bold Markdown headers, with a blank line before the latter two.

    Works on partial, still-streaming text; it is re-applied to the growing
    output after every streamed chunk.

    Args:
        text: Raw (possibly partial) model output.

    Returns:
        The cleaned Markdown string (``""`` for empty input).
    """
    match = _END_TRIGGER_RE.search(text)
    cutoff = match.start() if match else len(text)
    clean = text[:cutoff].strip().replace("*", "")

    sentences = [_upper_first(s) for s in clean.split(". ") if s]
    clean = ". ".join(sentences)

    clean = re.sub(r"(?i)assessment\s*:", "**Assessment:** ", clean)
    clean = re.sub(r"(?i)analysis\s*:", "\n\n**Analysis:** ", clean)
    clean = re.sub(r"(?i)recommended action\s*:", "\n\n**Recommended Action:** ", clean)
    return clean

# ==========================================
# 3. CHAT ENGINE & UI
# ==========================================
def clinical_chat(message, history):
    """Stream a formatted answer to one user message (Gradio ``ChatInterface`` callback).

    Args:
        message: The latest user message.
        history: Earlier conversation turns supplied by Gradio. Unused: every
            request is answered with a fresh single-turn prompt.

    Yields:
        The answer generated so far, passed through ``format_and_clean``,
        once per chunk emitted by the streamer.

    Raises:
        queue.Empty: propagated from the streamer if no new text arrives within
            10 seconds (e.g. generation stalls or the worker thread fails).
    """
    # NOTE(review): `message` is interpolated verbatim, so a user can inject their
    # own "[/INST]" and instructions. Acceptable for a demo; do not rely on it.
    prompt = f"""<s>[INST] You are a highly intelligent clinical AI. The user says: "{message}"
    
    If the user provides symptoms, diagnose them. If the user asks about a known condition, provide treatment advice.
    
    You MUST format your response exactly like this:
    **Assessment:** [Name of the condition]
    **Analysis:** [Brief explanation]
    **Recommended Action:** [Medical advice]
    
    Do not include greetings, sign-offs, or links. [/INST]"""

    # The prompt already begins with the BOS token "<s>", so add_special_tokens=False
    # stops the tokenizer from prepending a second one. Inputs are moved to the model's
    # actual device (chosen by device_map="auto") instead of a hard-coded "cuda".
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10.0)

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=300,
        temperature=0.2,
        repetition_penalty=1.15,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )

    # generate() blocks until it finishes, so run it on a worker thread and consume
    # the streamer here, yielding partial output to the UI as it arrives.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    partial_text = ""
    for new_token in streamer:
        partial_text += new_token
        yield format_and_clean(partial_text)
    # The streamer is exhausted once generate() has finished; reap the worker thread.
    thread.join()

# Build the Gradio chat UI at import time. It is only served when this file is
# run as a script (see the __main__ guard at the bottom).
theme = gr.themes.Soft(neutral_hue="slate", primary_hue="blue")

# Clickable starter prompts shown beneath the chat box.
_EXAMPLE_PROMPTS = [
    "I am having severe headache, body pain and strain in my neck.",
    "What medicine should I take if I have high cholesterol?",
    "I have been sneezing and coughing for 3 days.",
]

demo = gr.ChatInterface(
    fn=clinical_chat,
    examples=_EXAMPLE_PROMPTS,
    theme=theme,
    title="⚕️ Clinical AI Diagnostic Assistant",
    description="**Enter your symptoms or medical queries below for a professional analysis.**",
)

if __name__ == "__main__":
    # Bind every interface (0.0.0.0) so the server is reachable from outside the
    # Docker container; 7860 is the port Hugging Face Docker Spaces expect by default.
    demo.launch(server_port=7860, server_name="0.0.0.0")