# app.py — Hugging Face Space "Muhammadidrees/RaiyaChatDoc"
# (header reconstructed from web-page scrape residue: commit 4570e73, 6.23 kB)
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq
import torch
# -------------------
# 1️⃣ Load Model & Processor (Now from Hugging Face)
# -------------------
def load_model():
    """Download the RaiyaChatDoc vision-to-seq model and its processor.

    Returns:
        tuple: ``(processor, model, device)`` where ``device`` is the string
        ``"cuda"`` or ``"cpu"`` used later for placing generation inputs.
    """
    model_id = "Muhammadidrees/RaiyaChatDoc"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # fp16 halves GPU memory; CPUs generally lack fast fp16 kernels.
    dtype = torch.float16 if device == "cuda" else torch.float32
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto"  # let accelerate decide device placement
    )
    # BUG FIX: the original called model.to(device) here. Calling .to() on a
    # model loaded with device_map="auto" raises a RuntimeError in recent
    # accelerate versions ("You can't move a model that has been dispatched
    # to multiple devices..."). Placement is already handled by device_map,
    # so the explicit move is dropped.
    return processor, model, device


# Load model once at startup
processor, model, device = load_model()
# -------------------
# 2️⃣ Chat Logic Functions
# -------------------
def process_message(message, history, question_count):
    """Process one patient turn and append the doctor's reply to ``history``.

    Args:
        message: Raw text from the input textbox.
        history: Chatbot history as ``[user, bot]`` pairs; mutated in place.
        question_count: Doctor questions asked since the last analysis.

    Returns:
        ``(history, history, question_count)`` — history appears twice because
        the Gradio wiring lists the chatbot component twice in ``outputs``.
    """
    # Ignore empty / whitespace-only submissions.
    if not message.strip():
        return history, history, question_count
    history.append([message, None])  # bot slot filled in at the end
    question_count += 1
    # Produce a full analysis after 6 patient turns, or immediately when the
    # patient explicitly asks for one.
    should_analyze = (
        question_count >= 6 or
        any(word in message.lower() for word in ["analysis", "diagnose", "what do you think", "causes"])
    )
    if should_analyze:
        system_prompt = (
            "You are a medical doctor. Based on the patient's responses, provide a comprehensive analysis "
            "of potential causes for their symptoms. Start with 'Based on the information provided by the patient, "
            "potential causes of [symptoms] could include:' and list 3-4 possible diagnoses with brief explanations. "
            "Format as numbered list with diagnosis name and short explanation."
        )
    else:
        system_prompt = (
            "You are a medical doctor conducting a patient interview. Ask ONE specific, direct medical question "
            "to gather important diagnostic information. Keep it brief - just ask the question without explanations. "
            "Focus on key areas like: age, medical history, medications, lifestyle, family history, or symptom details."
        )
    # Re-render the running conversation as a "Patient:"/"Doctor:" transcript.
    dialogue = []
    for user_msg, bot_msg in history[:-1]:  # exclude the just-appended turn
        if user_msg:
            dialogue.append(f"Patient: {user_msg}")
        if bot_msg:
            dialogue.append(f"Doctor: {bot_msg}")
    dialogue.append(f"Patient: {message}")
    conversation = "\n".join(dialogue)
    prompt = f"{system_prompt}\n\nConversation:\n{conversation}\nDoctor:"
    # NOTE(review): inputs go to `device` while the model was loaded with
    # device_map="auto"; on a single GPU these coincide — confirm for multi-GPU.
    inputs = processor(text=prompt, images=None, return_tensors="pt").to(device)
    # Generous budget for the analysis; tiny budget to force a terse question.
    max_tokens = 400 if should_analyze else 25
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=processor.tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens (strip the echoed prompt).
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[:, input_length:]
    response = processor.batch_decode(generated_tokens, skip_special_tokens=True)[0].strip()
    # Drop a leading "Doctor:" the model sometimes repeats (7 == len("Doctor:")).
    if response.lower().startswith("doctor:"):
        response = response[7:].strip()
    if not should_analyze:
        # Keep only the first question when the model produced several.
        sentences = response.split('?')
        if len(sentences) > 1:
            response = sentences[0].strip() + '?'
        # Strip boilerplate lead-ins so the reply starts at the question itself.
        cleanup_starts = [
            "I need to ask",
            "Let me ask",
            "I would like to know",
            "Can you tell me",
            "It would help if",
        ]
        for phrase in cleanup_starts:
            if response.startswith(phrase):
                # Split off the lead-in at the first comma, keep the remainder.
                parts = response.split(',', 1)
                if len(parts) > 1:
                    response = parts[1].strip()
                    if not response.endswith('?'):
                        response += '?'
    history[-1][1] = response
    if should_analyze:
        # Reset so the interview loop starts over after an analysis.
        question_count = 0
    return history, history, question_count
def force_analysis(history, question_count):
    """Prime the question counter so the NEXT patient turn yields an analysis.

    The transcript is passed back untouched; returning 10 for the counter
    pushes it past the >= 6 threshold that process_message checks, so no new
    message is generated by this handler itself.
    """
    forced_count = 10
    return history, forced_count
def clear_chat():
    """Return fresh values for the two chatbot outputs and the counter state."""
    fresh_history = []
    fresh_display = []
    reset_count = 0
    return fresh_history, fresh_display, reset_count
# -------------------
# 3️⃣ Gradio Interface
# -------------------
with gr.Blocks(title="ChatDOC", theme=gr.themes.Soft()) as demo:
    # Per-session count of doctor questions asked since the last analysis.
    question_count_state = gr.State(0)
    gr.Markdown(
        """
# 🩺 Chat with ChatDOC
Welcome! I'm your AI medical assistant. Please describe your symptoms and I'll ask relevant questions to help understand your condition better.
"""
    )
    # Transcript widget; history is [user, bot] pairs (tuples-style format).
    chatbot = gr.Chatbot(
        value=[],
        height=400,
        show_label=False,
        avatar_images=(
            r"user_msg.png",
            r"bot_msg.jpg"
        ),
        bubble_full_width=False  # NOTE(review): deprecated in Gradio 4.x — confirm installed version accepts it
    )
    with gr.Row():
        msg = gr.Textbox(
            placeholder="Describe your symptoms...",
            scale=4,
            container=False,
            show_label=False
        )
        send_btn = gr.Button("Send", variant="primary", scale=1)
    with gr.Row():
        analysis_btn = gr.Button("Request Analysis", variant="secondary")
        clear_btn = gr.Button("Clear Chat", variant="stop")

    def user_submit(message, history, question_count):
        # Thin wrapper so button click and textbox submit share one handler.
        return process_message(message, history, question_count)

    def clear_input():
        # Empties the textbox after a message is sent.
        return ""

    # NOTE(review): `chatbot` is listed twice in outputs; the handler returns
    # history twice and the second value simply overwrites the first.
    # Redundant but harmless.
    send_event = send_btn.click(
        user_submit,
        inputs=[msg, chatbot, question_count_state],
        outputs=[chatbot, chatbot, question_count_state]
    ).then(
        clear_input,
        outputs=[msg]
    )
    # Pressing Enter in the textbox mirrors the Send button.
    msg.submit(
        user_submit,
        inputs=[msg, chatbot, question_count_state],
        outputs=[chatbot, chatbot, question_count_state]
    ).then(
        clear_input,
        outputs=[msg]
    )
    # Sets the counter to 10 so the next patient turn triggers a full analysis.
    analysis_btn.click(
        force_analysis,
        inputs=[chatbot, question_count_state],
        outputs=[chatbot, question_count_state]
    )
    clear_btn.click(
        clear_chat,
        outputs=[chatbot, chatbot, question_count_state]
    )
# Launch the app locally when run as a script.
if __name__ == "__main__":
    demo.launch(
        server_name="127.0.0.1",  # NOTE(review): HF Spaces requires 0.0.0.0 — confirm deployment target
        server_port=7860,
        share=False,
        debug=True
    )