# app.py — ChatDOC: Gradio chat front-end for the RaiyaChatDoc model.
# (Hugging Face Space by Muhammadidrees, commit 26a6bd7)
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq
import torch
# -------------------
# 1️⃣ Load Model & Processor (Now from Hugging Face)
# -------------------
def load_model():
    """Fetch the RaiyaChatDoc processor and model from the Hugging Face Hub.

    Returns:
        tuple: ``(processor, model, device)`` where ``device`` is ``"cuda"``
        when a GPU is visible, otherwise ``"cpu"``.
    """
    model_id = "Muhammadidrees/RaiyaChatDoc"
    gpu_available = torch.cuda.is_available()
    device = "cuda" if gpu_available else "cpu"
    # Half precision only makes sense on GPU; CPU inference needs float32.
    dtype = torch.float16 if gpu_available else torch.float32

    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto",  # let HF/accelerate decide weight placement
    )
    return processor, model, device
# Load the processor/model a single time at import; every request reuses them.
processor, model, device = load_model()
# -------------------
# 2️⃣ Chat Logic Functions
# -------------------
# Instruction given when enough context has been gathered for a full analysis.
_ANALYSIS_PROMPT = (
    "You are a highly experienced medical expert who combines the roles of a medical doctor, specialist, nutritionist, and medical teacher.\n"
    "Based only on the patient's provided information, give a clear and structured analysis:\n\n"
    "1. Possible health issues or conditions the patient might have (3–4 points).\n"
    "2. Dietary and lifestyle recommendations specific to the patient’s situation.\n"
    "3. Guidance on which type of doctor or specialist the patient should consult.\n\n"
    "Be concise, professional, and easy to understand for a non-medical person. "
    "If you mention complex medical terms, briefly explain them in simple language."
)

# Instruction used while the model is still interviewing the patient.
_INTERVIEW_PROMPT = (
    "You are a medical expert conducting a patient interview. Follow these rules:\n"
    "1. If the user simply shares symptoms or health info, ask ONE direct and specific medical question "
    "to gather diagnostic details (e.g., age, medical history, medications, lifestyle, family history, or symptoms). "
    "Do not explain, just ask the question.\n"
    "2. If the user explicitly asks for a diet plan, provide a complete, practical diet plan. "
    "Avoid unnecessary disclaimers, but keep it safe and balanced.\n"
    "3. If the user asks about a complex medical term, give a clear and simple explanation.\n\n"
    "Always keep responses brief, clear, and professional."
)

# Keywords in a user message that request an analysis immediately.
_ANALYSIS_KEYWORDS = ("analysis", "diagnose", "what do you think", "causes")

# Filler openings stripped from interview questions to keep them direct.
_CLEANUP_STARTS = (
    "I need to ask",
    "Let me ask",
    "I would like to know",
    "Can you tell me",
    "It would help if",
)


def _build_prompt(message, history, should_analyze):
    """Assemble the system instruction plus the Patient/Doctor transcript."""
    system_prompt = _ANALYSIS_PROMPT if should_analyze else _INTERVIEW_PROMPT
    dialogue = []
    for user_msg, bot_msg in history[:-1]:
        if user_msg:
            dialogue.append(f"Patient: {user_msg}")
        if bot_msg:
            dialogue.append(f"Doctor: {bot_msg}")
    dialogue.append(f"Patient: {message}")
    conversation = "\n".join(dialogue)
    return f"{system_prompt}\n\nConversation:\n{conversation}\nDoctor:"


def _generate(prompt, max_tokens):
    """Run the model on *prompt* and return only the decoded continuation."""
    inputs = processor(text=prompt, images=None, return_tensors="pt").to(device)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=processor.tokenizer.eos_token_id,
        )
    # Drop the echoed prompt tokens so only the new text is decoded.
    input_length = inputs["input_ids"].shape[1]
    generated = outputs[:, input_length:]
    return processor.batch_decode(generated, skip_special_tokens=True)[0].strip()


def _tidy_question(response):
    """Trim an interview reply down to one direct question ending in '?'."""
    # Keep only the first question if the model produced several.
    sentences = response.split('?')
    if len(sentences) > 1:
        response = sentences[0].strip() + '?'
    # Drop chatty preambles like "Let me ask, ..." down to the question itself.
    for phrase in _CLEANUP_STARTS:
        if response.startswith(phrase):
            parts = response.split(',', 1)
            if len(parts) > 1:
                response = parts[1].strip()
    # NOTE(review): applied to every interview reply so each one reads as a
    # question, even when no preamble was stripped.
    if not response.endswith('?'):
        response += '?'
    return response


def process_message(message, history, question_count):
    """Handle one chat turn: append the user message, query the model, and
    return the updated conversation state.

    Args:
        message: Raw text from the input textbox.
        history: Chatbot history as mutable ``[user, bot]`` pairs.
        question_count: Interview questions asked since the last analysis.

    Returns:
        ``(history, history, question_count)`` — history appears twice because
        the Gradio wiring binds the chatbot component to two outputs.
    """
    if not message.strip():
        # Ignore blank submissions without touching the counter.
        return history, history, question_count

    history.append([message, None])
    question_count += 1

    # Switch to full analysis after 6 questions or on an explicit request.
    lowered = message.lower()
    should_analyze = (
        question_count >= 6
        or any(word in lowered for word in _ANALYSIS_KEYWORDS)
    )

    prompt = _build_prompt(message, history, should_analyze)
    # Short budget keeps interview questions terse; analysis gets room.
    response = _generate(prompt, max_tokens=400 if should_analyze else 25)

    # Strip a leading "Doctor:" if the model echoed the speaker tag.
    if response.lower().startswith("doctor:"):
        response = response[len("doctor:"):].strip()
    if not should_analyze:
        response = _tidy_question(response)

    history[-1][1] = response
    if should_analyze:
        question_count = 0  # start a fresh interview after an analysis
    return history, history, question_count
def force_analysis(history, question_count):
    """Arm the analysis trigger for the next turn.

    The history is returned unchanged; the counter is pushed past the
    six-question threshold so the next ``process_message`` call produces a
    full analysis. ``question_count`` is accepted (and ignored) because the
    Gradio event wiring passes it as an input.
    """
    forced_count = 10  # any value >= 6 forces analysis on the next message
    return history, forced_count
def clear_chat():
    """Wipe both chatbot bindings and reset the interview question counter."""
    return list(), list(), 0
# -------------------
# 3️⃣ Gradio Interface
# -------------------
with gr.Blocks(title="ChatDOC", theme=gr.themes.Soft()) as demo:
    # Per-session count of interview questions asked since the last analysis.
    question_count_state = gr.State(0)

    gr.Markdown(
        """
        # 🩺 Chat with ChatDOC
        Welcome! I'm your AI medical assistant. Please describe your symptoms and I'll ask relevant questions to help understand your condition better.
        """
    )

    chatbot = gr.Chatbot(
        value=[],
        height=400,
        show_label=False,
        avatar_images=(r"user_msg.png", r"bot_msg.jpg"),
        bubble_full_width=False,
    )

    with gr.Row():
        msg = gr.Textbox(
            placeholder="Describe your symptoms...",
            scale=4,
            container=False,
            show_label=False,
        )
        send_btn = gr.Button("Send", variant="primary", scale=1)

    with gr.Row():
        analysis_btn = gr.Button("Request Analysis", variant="secondary")
        clear_btn = gr.Button("Clear Chat", variant="stop")

    def _on_submit(message, history, question_count):
        """Delegate a chat turn to process_message (named fn for the wiring)."""
        return process_message(message, history, question_count)

    def _blank_textbox():
        """Empty the input box once a message has been sent."""
        return ""

    # Both the Send button and pressing Enter trigger the same pipeline.
    send_btn.click(
        _on_submit,
        inputs=[msg, chatbot, question_count_state],
        outputs=[chatbot, chatbot, question_count_state],
    ).then(_blank_textbox, outputs=[msg])

    msg.submit(
        _on_submit,
        inputs=[msg, chatbot, question_count_state],
        outputs=[chatbot, chatbot, question_count_state],
    ).then(_blank_textbox, outputs=[msg])

    analysis_btn.click(
        force_analysis,
        inputs=[chatbot, question_count_state],
        outputs=[chatbot, question_count_state],
    )

    clear_btn.click(
        clear_chat,
        outputs=[chatbot, chatbot, question_count_state],
    )
if __name__ == "__main__":
    # Bind on all interfaces on the standard Spaces port; no public share link.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)