File size: 6,882 Bytes
72f0197
c33a7b6
6068b3b
72f0197
 
4570e73
72f0197
 
4570e73
72f0197
 
 
4570e73
72f0197
4570e73
c33a7b6
4570e73
72f0197
 
 
4570e73
72f0197
 
 
4570e73
72f0197
 
4570e73
72f0197
 
 
 
 
4570e73
 
 
 
6068b3b
 
4570e73
 
26a6bd7
 
d477b36
 
 
 
 
4570e73
 
 
d477b36
 
 
 
 
 
 
 
4570e73
d477b36
4570e73
72f0197
 
4570e73
 
 
 
 
72f0197
4570e73
 
72f0197
 
6068b3b
4570e73
72f0197
 
 
 
 
 
 
 
4570e73
72f0197
 
 
4570e73
 
 
72f0197
 
4570e73
72f0197
4570e73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72f0197
 
4570e73
 
 
 
72f0197
 
 
 
 
 
 
 
 
6068b3b
72f0197
4570e73
72f0197
 
4570e73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72f0197
 
4570e73
 
 
 
 
 
72f0197
 
 
 
 
 
4570e73
 
 
 
 
 
6068b3b
4570e73
 
 
 
 
 
 
72f0197
 
4570e73
 
 
 
 
 
 
 
 
 
 
 
 
72f0197
4570e73
 
 
 
72f0197
 
4570e73
7fdf8f5
4570e73
7fdf8f5
4570e73
7fdf8f5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq
import torch

# -------------------
# 1️⃣ Load Model & Processor (Now from Hugging Face)
# -------------------
def load_model():
    """Fetch the RaiyaChatDoc vision-to-seq model and its processor from the Hub.

    Returns:
        tuple: ``(processor, model, device)`` where ``device`` is ``"cuda"``
        when a GPU is available, otherwise ``"cpu"``. The model itself is
        placed by ``device_map="auto"``.
    """
    model_id = "Muhammadidrees/RaiyaChatDoc"
    has_gpu = torch.cuda.is_available()
    device = "cuda" if has_gpu else "cpu"
    # Half precision only on GPU; CPU inference stays in full float32.
    dtype = torch.float16 if has_gpu else torch.float32

    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        torch_dtype=dtype,
        # Let Hugging Face Accelerate decide device placement.
        device_map="auto",
    )
    return processor, model, device

# Load the processor/model exactly once at import time so every Gradio
# request reuses the same in-memory instances (no per-request reload).
processor, model, device = load_model()

# -------------------
# 2️⃣ Chat Logic Functions
# -------------------
def process_message(message, history, question_count):
    """Append the user's message to the chat and generate the doctor's reply.

    Args:
        message: Raw text from the user's input box.
        history: Gradio chatbot history, a list of ``[user, bot]`` pairs.
            Mutated in place: a new ``[message, None]`` row is appended and
            its ``None`` slot is filled with the generated response.
        question_count: Number of interview questions asked so far in this
            round; drives the switch from follow-up questions to analysis.

    Returns:
        ``(history, history, question_count)`` — history appears twice
        because the Gradio wiring routes it to two outputs.
    """
    # Ignore empty / whitespace-only submissions.
    if not message.strip():
        return history, history, question_count
    
    history.append([message, None])
    question_count += 1
    
    # Switch to full-analysis mode after the 6th question, or immediately
    # when the user explicitly asks for an assessment.
    should_analyze = (
        question_count >= 6 or 
        any(word in message.lower() for word in ["analysis", "diagnose", "what do you think", "causes"])
    )

    if should_analyze:
        # Analysis persona: structured assessment + diet/lifestyle + referral.
        system_prompt = (
            "You are a highly experienced medical expert who combines the roles of a medical doctor, specialist, nutritionist, and medical teacher.\n"
            "Based only on the patient's provided information, give a clear and structured analysis:\n\n"
            "1. Possible health issues or conditions the patient might have (3–4 points).\n"
            "2. Dietary and lifestyle recommendations specific to the patient’s situation.\n"
            "3. Guidance on which type of doctor or specialist the patient should consult.\n\n"
            "Be concise, professional, and easy to understand for a non-medical person. "
            "If you mention complex medical terms, briefly explain them in simple language."
        )
    else:
        # Interview persona: ask exactly one targeted follow-up question.
        system_prompt = (
            "You are a medical expert conducting a patient interview. Follow these rules:\n"
            "1. If the user simply shares symptoms or health info, ask ONE direct and specific medical question "
            "to gather diagnostic details (e.g., age, medical history, medications, lifestyle, family history, or symptoms). "
            "Do not explain, just ask the question.\n"
            "2. If the user explicitly asks for a diet plan, provide a complete, practical diet plan. "
            "Avoid unnecessary disclaimers, but keep it safe and balanced.\n"
            "3. If the user asks about a complex medical term, give a clear and simple explanation.\n\n"
            "Always keep responses brief, clear, and professional."
        )

    
    # Rebuild the transcript (all rows except the just-appended one) as
    # alternating "Patient:" / "Doctor:" turns for the prompt.
    dialogue = []
    for user_msg, bot_msg in history[:-1]:
        if user_msg:
            dialogue.append(f"Patient: {user_msg}")
        if bot_msg:
            dialogue.append(f"Doctor: {bot_msg}")
    
    dialogue.append(f"Patient: {message}")
    conversation = "\n".join(dialogue)
    # Trailing "Doctor:" cues the model to continue in the doctor's voice.
    prompt = f"{system_prompt}\n\nConversation:\n{conversation}\nDoctor:"

    # NOTE(review): the model was loaded with device_map="auto"; moving the
    # inputs to `device` assumes the whole model sits on that single device —
    # confirm on multi-GPU / offloaded setups.
    inputs = processor(text=prompt, images=None, return_tensors="pt").to(device)
    # Tight 25-token budget keeps follow-ups to a single short question.
    max_tokens = 400 if should_analyze else 25
    
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=processor.tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[:, input_length:]
    response = processor.batch_decode(generated_tokens, skip_special_tokens=True)[0].strip()

    # Strip a leading "Doctor:" role tag if the model echoed it.
    if response.lower().startswith("doctor:"):
        response = response[7:].strip()
    
    if not should_analyze:
        # Keep only the first question if the model emitted several.
        sentences = response.split('?')
        if len(sentences) > 1:
            response = sentences[0].strip() + '?'
        
        # Boilerplate lead-ins to trim, keeping the text after the first comma.
        cleanup_starts = [
            "I need to ask",
            "Let me ask",
            "I would like to know",
            "Can you tell me",
            "It would help if",
        ]
        
        for phrase in cleanup_starts:
            if response.startswith(phrase):
                parts = response.split(',', 1)
                if len(parts) > 1:
                    response = parts[1].strip()
                    if not response.endswith('?'):
                        response += '?'

    # Fill the bot slot of the row appended at the top of this function.
    history[-1][1] = response
    
    # Reset the interview counter after delivering a full analysis.
    if should_analyze:
        question_count = 0
    
    return history, history, question_count

def force_analysis(history, question_count):
    """Force the next patient message to trigger a full analysis.

    Leaves the chat history untouched and pushes the question counter past
    the analysis threshold (>= 6) so ``process_message`` switches modes.
    The ``question_count`` argument is accepted only for Gradio wiring.
    """
    forced_count = 10  # any value >= 6 flips process_message into analysis mode
    return history, forced_count

def clear_chat():
    """Reset the conversation: empty chatbot, empty history state, zero counter."""
    return list(), list(), 0

# -------------------
# 3️⃣ Gradio Interface
# -------------------
# Build the Gradio UI: a chatbot pane, a message box with Send, and two
# control buttons (force analysis / clear). All event handlers share the
# per-session question counter held in gr.State.
with gr.Blocks(title="ChatDOC", theme=gr.themes.Soft()) as demo:
    # Per-session counter of interview questions asked so far.
    question_count_state = gr.State(0)
    
    gr.Markdown(
        """
        # 🩺 Chat with ChatDOC
        Welcome! I'm your AI medical assistant. Please describe your symptoms and I'll ask relevant questions to help understand your condition better.
        """
    )
    
    chatbot = gr.Chatbot(
        value=[],
        height=400,
        show_label=False,
         avatar_images=(
            r"user_msg.png",
            r"bot_msg.jpg"
        ),
        bubble_full_width=False
    )
    
    with gr.Row():
        msg = gr.Textbox(
            placeholder="Describe your symptoms...",
            scale=4,
            container=False,
            show_label=False
        )
        send_btn = gr.Button("Send", variant="primary", scale=1)
    
    with gr.Row():
        analysis_btn = gr.Button("Request Analysis", variant="secondary")
        clear_btn = gr.Button("Clear Chat", variant="stop")
    
    # Thin wrapper so both Send-click and Enter-submit share one handler.
    def user_submit(message, history, question_count):
        return process_message(message, history, question_count)
    
    # Returns "" to blank the textbox after a message is sent.
    def clear_input():
        return ""
    
    # Send button: run the model, then clear the input box.
    send_event = send_btn.click(
        user_submit,
        inputs=[msg, chatbot, question_count_state],
        outputs=[chatbot, chatbot, question_count_state]
    ).then(
        clear_input,
        outputs=[msg]
    )
    
    # Pressing Enter in the textbox mirrors the Send button behavior.
    msg.submit(
        user_submit,
        inputs=[msg, chatbot, question_count_state],
        outputs=[chatbot, chatbot, question_count_state]
    ).then(
        clear_input,
        outputs=[msg]
    )
    
    # Bumps the counter to 10 so the NEXT message triggers a full analysis.
    analysis_btn.click(
        force_analysis,
        inputs=[chatbot, question_count_state],
        outputs=[chatbot, question_count_state]
    )
    
    clear_btn.click(
        clear_chat,
        outputs=[chatbot, chatbot, question_count_state]
    )

if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (the conventional Hugging Face
    # Spaces port); share=False keeps the app local — no public tunnel.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )