# app.py
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq
import torch
# -------------------
# 1️⃣ Load Model
# -------------------
def load_model():
    """Load the RaiyaChatDoc processor and model once at startup."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    # Load model and processor from Hugging Face
    processor = AutoProcessor.from_pretrained("Muhammadidrees/RaiyaChatDoc", trust_remote_code=True)
    model = AutoModelForVision2Seq.from_pretrained(
        "Muhammadidrees/RaiyaChatDoc",
        torch_dtype=dtype,
        device_map="auto",  # automatically places weights on GPU if available
    )
    # device_map="auto" already dispatches the weights, so a separate
    # model.to(device) is unnecessary (and can error on accelerate-dispatched models).
    return processor, model, device

processor, model, device = load_model()
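# Minimal smoke test of the loaded pipeline (a sketch; assumes the checkpoint
# accepts text-only input, as the chat handler below does):
#   test = processor(text="Patient: I have a headache.\nDoctor:", images=None, return_tensors="pt").to(device)
#   out = model.generate(**test, max_new_tokens=20)
#   print(processor.batch_decode(out, skip_special_tokens=True)[0])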
# -------------------
# 2️⃣ Chat Logic
# -------------------
def process_message(message, history, question_count):
    if not message.strip():
        return history, question_count

    history.append([message, None])
    question_count += 1

    # Switch from interview mode to analysis mode after six patient turns,
    # or as soon as the patient explicitly asks for an assessment.
    should_analyze = question_count >= 6 or any(
        word in message.lower() for word in ["analysis", "diagnose", "what do you think", "causes"]
    )

    # System prompt
    system_prompt = (
        "You are a medical doctor. "
        "Provide a comprehensive analysis of potential causes for symptoms."
        if should_analyze else
        "You are a medical doctor conducting a patient interview. Ask ONE specific question."
    )

    # Build conversation context from all completed turns, then the new message
    dialogue = []
    for user_msg, bot_msg in history[:-1]:
        if user_msg:
            dialogue.append(f"Patient: {user_msg}")
        if bot_msg:
            dialogue.append(f"Doctor: {bot_msg}")
    dialogue.append(f"Patient: {message}")
    prompt = f"{system_prompt}\n\nConversation:\n" + "\n".join(dialogue) + "\nDoctor:"

    # Prepare input (text-only; no image is passed in this chat flow)
    inputs = processor(text=prompt, images=None, return_tensors="pt").to(device)
    max_tokens = 400 if should_analyze else 25

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=processor.tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt
    input_length = inputs["input_ids"].shape[1]
    response = processor.batch_decode(outputs[:, input_length:], skip_special_tokens=True)[0].strip()
    if response.lower().startswith("doctor:"):
        response = response[len("doctor:"):].strip()

    # In interview mode, keep only the first question of the reply
    if not should_analyze:
        response = response.split("?")[0].strip() + "?"

    history[-1][1] = response
    if should_analyze:
        question_count = 0  # restart the interview after each analysis
    return history, question_count
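# For illustration (hypothetical inputs): "it started yesterday" on turn 2
# yields one short follow-up question (at most 25 new tokens), while "what do
# you think causes this?" flips should_analyze and yields a long-form answer
# (up to 400 new tokens) before the counter resets to 0.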
def force_analysis(history, question_count):
    # Pushing the count above the threshold makes the NEXT reply an analysis.
    return history, 10

def clear_chat():
    return [], 0
# -------------------
# 3️⃣ Gradio Interface
# -------------------
with gr.Blocks(title="ChatDOC") as demo:
    question_count_state = gr.State(0)

    gr.Markdown("# 🩺 Chat with ChatDOC\nDescribe your symptoms and get guidance.")
    chatbot = gr.Chatbot(value=[], height=400, show_label=False)

    with gr.Row():
        msg = gr.Textbox(placeholder="Describe your symptoms...", scale=4, container=False, show_label=False)
        send_btn = gr.Button("Send", variant="primary", scale=1)

    with gr.Row():
        analysis_btn = gr.Button("Request Analysis", variant="secondary")
        clear_btn = gr.Button("Clear Chat", variant="stop")

    # Send on button click or Enter; clear the textbox after each send.
    send_btn.click(
        process_message,
        inputs=[msg, chatbot, question_count_state],
        outputs=[chatbot, question_count_state],
    ).then(lambda: "", outputs=[msg])
    msg.submit(
        process_message,
        inputs=[msg, chatbot, question_count_state],
        outputs=[chatbot, question_count_state],
    ).then(lambda: "", outputs=[msg])

    analysis_btn.click(force_analysis, inputs=[chatbot, question_count_state], outputs=[chatbot, question_count_state])
    clear_btn.click(clear_chat, outputs=[chatbot, question_count_state])
# -------------------
# 4️⃣ Launch
# -------------------
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)
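# For a quick local run (assuming the dependencies are installed):
#   pip install gradio transformers torch accelerate
#   python app.py
# then open http://localhost:7860 in a browser.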