import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
import time

# =======================================================
# Load Model
# =======================================================
model_name = "augtoma/qCammel-13"
print("Loading tokenizer and model...")

tokenizer = AutoTokenizer.from_pretrained(model_name)
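# Many causal-LM tokenizers (Llama's included) ship without a pad token;
# reusing EOS keeps padding and generation from erroring out.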
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

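# device_map="auto" (backed by the `accelerate` library) places layers across
# available GPUs and, if needed, CPU; float16 roughly halves memory vs. float32.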
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    low_cpu_mem_usage=True
)
model.eval()

print("Model loaded successfully!")
print(f"Device map: {model.hf_device_map}")
print(f"Model device: {next(model.parameters()).device}")


# =======================================================
# Generate Doctor Response (Stateless + Clean Replies)
# =======================================================
def generate_doctor_response(history):
    user_message = history[-1]["content"]

    if not user_message.strip():
        history.append({"role": "assistant", "content": "⚠️ Please describe your symptoms or ask a question."})
        yield history
        return

    # 🩺 Prompt: instruct the model to answer without 'Patient:'/'Doctor:' role labels
    prompt = f"""
You are a compassionate and professional medical expert.
Your role is to help users by providing clear, empathetic, and accurate medical information.

Guidelines:
1. Do NOT include words like 'Doctor:' or 'Patient:' in your replies.
2. Respond naturally and directly to the user's concern.
3. Keep answers short, clear, and medically sound.
4. Add a disclaimer when appropriate: 
   ⚕️ *This is AI-generated information and not a substitute for professional medical advice.*

Now, please respond to the user's message below:

User: {user_message}
Assistant:
"""

    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
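    # Note: this free-form prompt assumes a base-style checkpoint. If the model
    # ships a chat template, tokenizer.apply_chat_template() may match its
    # expected input format more closely.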

    gen_config = GenerationConfig(
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        max_new_tokens=500,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.2
    )
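    # Nucleus sampling: temperature flattens/sharpens the token distribution and
    # top_p keeps only the most likely mass; repetition_penalty=1.2 curbs loops.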

    input_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        output_ids = model.generate(**inputs, generation_config=gen_config)

    generated_ids = output_ids[0][input_len:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    # Keep output concise: first three sentences; re-add terminal punctuation if missing
    response = ". ".join(response.split(". ")[:3]).strip()
    if response and not response.endswith((".", "!", "?")):
        response += "."
    if response.lower().startswith("assistant:"):
        response = response[10:].strip()
    if len(response) < 10:
        response = "I understand your concern. Could you please provide more details about your symptoms?"

    # Simulate streaming: reveal the finished reply a few characters at a time
    history.append({"role": "assistant", "content": ""})
    for i in range(0, len(response), 4):
        chunk = response[:i + 4]
        history[-1]["content"] = chunk + "▌"
        yield history.copy()
        time.sleep(0.015)

    history[-1]["content"] = response
    yield history
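
    # Sketch of an alternative (untested here): stream real tokens as they are
    # generated, using transformers' TextIteratorStreamer with generate() in a
    # worker thread, instead of replaying a finished reply:
    #
    #   from threading import Thread
    #   from transformers import TextIteratorStreamer
    #   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    #   Thread(target=model.generate,
    #          kwargs={**inputs, "generation_config": gen_config, "streamer": streamer}).start()
    #   for new_text in streamer:
    #       history[-1]["content"] += new_text
    #       yield history.copy()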


# =======================================================
# Gradio Interface
# =======================================================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🩺 AI Doctor Chat Assistant")

    chatbot = gr.Chatbot(
        label="💬 Doctor Consultation",
        type='messages',
        avatar_images=(
            "https://cdn-icons-png.flaticon.com/512/706/706830.png",  # Patient
            "https://cdn-icons-png.flaticon.com/512/3774/3774299.png"   # Doctor
        ),
        height=500
    )

    with gr.Row():
        user_input = gr.Textbox(
            placeholder="Type your symptoms or question here...",
            label="🧍 Your Message",
            lines=2,
            scale=4
        )

    with gr.Row():
        send_btn = gr.Button("💬 Send", variant="primary", scale=1)
        clear_btn = gr.Button("🧹 Clear Chat", scale=1)

    gr.Examples(
        examples=[
            "I have a fever of 102°F since yesterday",
            "I've been having headaches for the past week",
            "I feel very tired all the time",
            "I have a sore throat and body aches",
        ],
        inputs=user_input,
        label="💡 Example Questions"
    )

    # =======================================================
    # Respond Function — Model forgets, Chat UI remembers
    # =======================================================
    def respond(message, history):
        user_message = message.strip()
        if not user_message:
            # respond() is a generator; yield the no-op update explicitly, then stop
            yield "", history
            return

        # Show user message in chat
        history.append({"role": "user", "content": user_message})

        # Model sees only current message (no memory)
        temp_history = [{"role": "user", "content": user_message}]

        for updated_history in generate_doctor_response(temp_history):
            if len(history) == 0 or history[-1]["role"] != "assistant":
                history.append({"role": "assistant", "content": updated_history[-1]["content"]})
            else:
                history[-1]["content"] = updated_history[-1]["content"]
            yield "", history

    # =======================================================
    # Button & Input Bindings
    # =======================================================
    send_btn.click(respond, [user_input, chatbot], [user_input, chatbot])
    user_input.submit(respond, [user_input, chatbot], [user_input, chatbot])
    clear_btn.click(lambda: [], None, chatbot, queue=False)


# =======================================================
# Launch App
# =======================================================
if __name__ == "__main__":
    demo.queue()
    demo.launch(share=True)
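    # share=True requests a temporary public *.gradio.live link in addition to
    # the local URL; drop it to keep the demo local-only.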