Spaces:
Sleeping
Sleeping
File size: 6,882 Bytes
72f0197 c33a7b6 6068b3b 72f0197 4570e73 72f0197 4570e73 72f0197 4570e73 72f0197 4570e73 c33a7b6 4570e73 72f0197 4570e73 72f0197 4570e73 72f0197 4570e73 72f0197 4570e73 6068b3b 4570e73 26a6bd7 d477b36 4570e73 d477b36 4570e73 d477b36 4570e73 72f0197 4570e73 72f0197 4570e73 72f0197 6068b3b 4570e73 72f0197 4570e73 72f0197 4570e73 72f0197 4570e73 72f0197 4570e73 72f0197 4570e73 72f0197 6068b3b 72f0197 4570e73 72f0197 4570e73 72f0197 4570e73 72f0197 4570e73 6068b3b 4570e73 72f0197 4570e73 72f0197 4570e73 72f0197 4570e73 7fdf8f5 4570e73 7fdf8f5 4570e73 7fdf8f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
import re

import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
# -------------------
# 1️⃣ Load Model & Processor (from the Hugging Face Hub)
# -------------------
def load_model(model_id="Muhammadidrees/RaiyaChatDoc"):
    """Load the ChatDOC processor and model from the Hugging Face Hub.

    Args:
        model_id: Hub repository to load. Defaults to the ChatDOC checkpoint;
            pass another id to try a different fine-tune without editing code.

    Returns:
        A ``(processor, model, device)`` tuple. ``device`` is ``"cuda"`` when
        torch reports a usable GPU and ``"cpu"`` otherwise; the chat code moves
        its tokenized inputs to this device before calling ``generate``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision on GPU to reduce memory use; full precision on CPU.
    dtype = torch.float16 if device == "cuda" else torch.float32
    # trust_remote_code allows the repository to supply its own processor code.
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    # NOTE(review): with device_map="auto" the weights may not all land on
    # `device` (e.g. partial CPU offload on a small GPU); if so, inputs moved
    # to `device` could mismatch the model - confirm on the target hardware.
    model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto",  # Let HF handle device placement
    )
    return processor, model, device


# Load the model once at import time so every chat turn reuses it.
processor, model, device = load_model()
# -------------------
# 2️⃣ Chat Logic Functions
# -------------------
# Patient turns after which the analysis is produced automatically.
_QUESTIONS_BEFORE_ANALYSIS = 6

# Words/phrases with which the patient explicitly asks for an analysis.
# Matched on word boundaries so that, e.g., "I was diagnosed with asthma"
# is not mistaken for a request to "diagnose".
_ANALYSIS_TRIGGER_RE = re.compile(
    r"\b(?:analysis|diagnose|what do you think|causes)\b"
)

_ANALYSIS_SYSTEM_PROMPT = (
    "You are a highly experienced medical expert who combines the roles of a medical doctor, specialist, nutritionist, and medical teacher.\n"
    "Based only on the patient's provided information, give a clear and structured analysis:\n\n"
    "1. Possible health issues or conditions the patient might have (3–4 points).\n"
    "2. Dietary and lifestyle recommendations specific to the patient’s situation.\n"
    "3. Guidance on which type of doctor or specialist the patient should consult.\n\n"
    "Be concise, professional, and easy to understand for a non-medical person. "
    "If you mention complex medical terms, briefly explain them in simple language."
)

# NOTE(review): rules 2 and 3 ask for full diet plans / explanations, but
# interview replies are capped at 25 new tokens and trimmed to a single
# question below, so such answers get truncated. Confirm intended behavior.
_INTERVIEW_SYSTEM_PROMPT = (
    "You are a medical expert conducting a patient interview. Follow these rules:\n"
    "1. If the user simply shares symptoms or health info, ask ONE direct and specific medical question "
    "to gather diagnostic details (e.g., age, medical history, medications, lifestyle, family history, or symptoms). "
    "Do not explain, just ask the question.\n"
    "2. If the user explicitly asks for a diet plan, provide a complete, practical diet plan. "
    "Avoid unnecessary disclaimers, but keep it safe and balanced.\n"
    "3. If the user asks about a complex medical term, give a clear and simple explanation.\n\n"
    "Always keep responses brief, clear, and professional."
)

# Verbose lead-ins removed from the front of interview questions.
_QUESTION_LEAD_INS = (
    "I need to ask",
    "Let me ask",
    "I would like to know",
    "Can you tell me",
    "It would help if",
)

# Shown instead of an empty bubble (or a bare "?") when nothing usable was generated.
_EMPTY_REPLY_FALLBACK = (
    "I'm sorry, I couldn't generate a reply. "
    "Could you rephrase or share a bit more detail?"
)


def _build_prompt(system_prompt, history, message):
    """Render the system prompt and the transcript so far as one text prompt.

    ``history`` is expected to already end with the current ``[message, None]``
    pair; that last pair is skipped and ``message`` is appended directly.
    """
    dialogue = []
    for user_msg, bot_msg in history[:-1]:
        if user_msg:
            dialogue.append(f"Patient: {user_msg}")
        if bot_msg:
            dialogue.append(f"Doctor: {bot_msg}")
    dialogue.append(f"Patient: {message}")
    conversation = "\n".join(dialogue)
    # Ending on "Doctor:" cues the model to answer as the doctor.
    return f"{system_prompt}\n\nConversation:\n{conversation}\nDoctor:"


def _generate(prompt, max_tokens):
    """Sample a continuation of ``prompt`` and return only the newly generated text."""
    inputs = processor(text=prompt, images=None, return_tensors="pt").to(device)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=processor.tokenizer.eos_token_id,
        )
    # The output sequence starts with the prompt tokens; keep only what follows.
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[:, input_length:]
    return processor.batch_decode(generated_tokens, skip_special_tokens=True)[0].strip()


def _drop_invented_turns(text):
    """Cut the reply at the first line where the model starts speaking as the patient."""
    kept = []
    for line in text.splitlines():
        if line.lstrip().lower().startswith("patient:"):
            break
        kept.append(line)
    return "\n".join(kept).strip()


def _tidy_question(text):
    """Reduce an interview reply to one directly worded question."""
    # Keep only the first question if several were asked.
    pieces = text.split("?")
    if len(pieces) > 1:
        text = pieces[0].strip() + "?"
    # "Can you tell me, what ...?" -> "what ...?"
    for phrase in _QUESTION_LEAD_INS:
        if text.startswith(phrase):
            parts = text.split(",", 1)
            if len(parts) > 1:
                text = parts[1].strip()
    # Never turn an empty reply into a bare "?"; the caller substitutes a fallback.
    if text and not text.endswith("?"):
        text += "?"
    return text


def process_message(message, history, question_count):
    """Handle one patient turn and record the doctor's reply in ``history``.

    While the interview is in progress the reply is a single short follow-up
    question. Once ``question_count`` reaches the threshold, or the patient
    explicitly asks for an analysis, a longer structured analysis is generated
    and the counter is reset to 0.

    Args:
        message: Text the patient just submitted; blank input is ignored.
        history: List of ``[patient_text, doctor_text]`` pairs, mutated in place.
        question_count: Patient turns since the last analysis.

    Returns:
        ``(history, history, question_count)`` - the history is returned twice
        to match the two chatbot outputs the UI wires this function to.
    """
    if not message.strip():
        return history, history, question_count

    history.append([message, None])
    question_count += 1

    should_analyze = (
        question_count >= _QUESTIONS_BEFORE_ANALYSIS
        or _ANALYSIS_TRIGGER_RE.search(message.lower()) is not None
    )
    if should_analyze:
        system_prompt, max_tokens = _ANALYSIS_SYSTEM_PROMPT, 400
    else:
        system_prompt, max_tokens = _INTERVIEW_SYSTEM_PROMPT, 25

    prompt = _build_prompt(system_prompt, history, message)
    response = _generate(prompt, max_tokens)

    # The model sometimes repeats the "Doctor:" cue at the start of its reply.
    if response.lower().startswith("doctor:"):
        response = response[7:].strip()
    response = _drop_invented_turns(response)
    if not should_analyze:
        response = _tidy_question(response)
    if not response:
        response = _EMPTY_REPLY_FALLBACK

    history[-1][1] = response
    if should_analyze:
        question_count = 0
    return history, history, question_count
def force_analysis(history, question_count):
    """Prime the session so the next patient message is answered with an analysis.

    The chat history is handed back unchanged. The incoming counter is ignored
    and replaced by 10, which is at or above the turn threshold that
    process_message checks, so the following message triggers the analysis.
    """
    primed_count = 10
    return history, primed_count
def clear_chat():
    """Return a fresh empty chatbot value, a fresh empty history and a zero counter."""
    blank_display = []
    blank_history = []
    return blank_display, blank_history, 0
# -------------------
# 3️⃣ Gradio Interface
# -------------------
with gr.Blocks(title="ChatDOC", theme=gr.themes.Soft()) as demo:
    # Per-session count of patient turns since the last analysis.
    question_count_state = gr.State(0)

    gr.Markdown(
        "\n"
        "# 🩺 Chat with ChatDOC\n"
        "Welcome! I'm your AI medical assistant. Please describe your symptoms "
        "and I'll ask relevant questions to help understand your condition better.\n"
    )

    chatbot = gr.Chatbot(
        value=[],
        height=400,
        show_label=False,
        avatar_images=(
            r"user_msg.png",
            r"bot_msg.jpg"
        ),
        bubble_full_width=False
    )

    with gr.Row():
        msg = gr.Textbox(
            placeholder="Describe your symptoms...",
            scale=4,
            container=False,
            show_label=False
        )
        send_btn = gr.Button("Send", variant="primary", scale=1)

    with gr.Row():
        analysis_btn = gr.Button("Request Analysis", variant="secondary")
        clear_btn = gr.Button("Clear Chat", variant="stop")

    def user_submit(message, history, question_count):
        """Gradio callback: forward the submitted message to process_message."""
        return process_message(message, history, question_count)

    def clear_input():
        """Gradio callback: empty the message textbox."""
        return ""

    # Clicking Send and pressing Enter run the same pipeline. The textbox is
    # cleared with .success() rather than .then(), so if generation raises,
    # the patient's typed message is kept instead of being discarded.
    for trigger in (send_btn.click, msg.submit):
        trigger(
            user_submit,
            inputs=[msg, chatbot, question_count_state],
            outputs=[chatbot, chatbot, question_count_state],
        ).success(clear_input, outputs=[msg])

    # Only primes the turn counter; the analysis appears on the next message.
    analysis_btn.click(
        force_analysis,
        inputs=[chatbot, question_count_state],
        outputs=[chatbot, question_count_state],
    )

    clear_btn.click(
        clear_chat,
        outputs=[chatbot, chatbot, question_count_state],
    )
if __name__ == "__main__":
    # Listen on every network interface at port 7860, without creating a
    # public Gradio share link.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_kwargs)
|