# app.py
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq
import torch
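# Assumed runtime dependencies (not pinned in this file): gradio, torch,
# transformers; accelerate is also required for device_map="auto" below.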
# -------------------
# 1️⃣ Load Model
# -------------------
def load_model():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    # Load model and processor from Hugging Face
    processor = AutoProcessor.from_pretrained("Muhammadidrees/RaiyaChatDoc", trust_remote_code=True)
    model = AutoModelForVision2Seq.from_pretrained(
        "Muhammadidrees/RaiyaChatDoc",
        torch_dtype=dtype,
        device_map="auto",  # automatically assigns to GPU if available
    )
    # device_map="auto" already dispatches the weights, so an explicit
    # model.to(device) would conflict with it; `device` is kept for the inputs.
    return processor, model, device
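# Load once at import time so every Gradio request reuses the same weights
# instead of reloading them per call.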
processor, model, device = load_model()
# -------------------
# 2️⃣ Chat Logic
# -------------------
def process_message(message, history, question_count):
    """Handle one patient turn: ask a follow-up question or deliver an analysis."""
    if not message.strip():
        return history, question_count

    history.append([message, None])
    question_count += 1

    # Switch to a full analysis after six patient turns, or when the patient asks for one
    should_analyze = question_count >= 6 or any(
        word in message.lower() for word in ["analysis", "diagnose", "what do you think", "causes"]
    )
    # Pick the system prompt for this turn
    system_prompt = (
        "You are a medical doctor. "
        "Provide a comprehensive analysis of potential causes for symptoms."
        if should_analyze
        else "You are a medical doctor conducting a patient interview. Ask ONE specific question."
    )

    # Build the conversation context from prior turns
    dialogue = []
    for user_msg, bot_msg in history[:-1]:
        if user_msg:
            dialogue.append(f"Patient: {user_msg}")
        if bot_msg:
            dialogue.append(f"Doctor: {bot_msg}")
    dialogue.append(f"Patient: {message}")
    prompt = f"{system_prompt}\n\nConversation:\n" + "\n".join(dialogue) + "\nDoctor:"
    # Tokenize the prompt; this is a text-only turn, so no image is passed
    inputs = processor(text=prompt, images=None, return_tensors="pt").to(device)
    max_tokens = 400 if should_analyze else 25

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=processor.tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the prompt
    input_length = inputs["input_ids"].shape[1]
    response = processor.batch_decode(outputs[:, input_length:], skip_special_tokens=True)[0].strip()
    if response.lower().startswith("doctor:"):
        response = response[len("doctor:"):].strip()

    # In interview mode, keep only the first question the model produces
    if not should_analyze:
        response = response.split("?")[0].strip() + "?"

    history[-1][1] = response
    if should_analyze:
        question_count = 0  # reset the counter after delivering an analysis
    return history, question_count
def force_analysis(history, question_count):
    # Push the counter past the six-question threshold so the next
    # message triggers the analysis branch in process_message
    return history, 10

def clear_chat():
    return [], 0
# -------------------
# 3️⃣ Gradio Interface
# -------------------
with gr.Blocks(title="ChatDOC") as demo:
    question_count_state = gr.State(0)
    gr.Markdown("# 🩺 Chat with ChatDOC\nDescribe your symptoms and get guidance.")
    chatbot = gr.Chatbot(value=[], height=400, show_label=False)

    with gr.Row():
        msg = gr.Textbox(placeholder="Describe your symptoms...", scale=4, container=False, show_label=False)
        send_btn = gr.Button("Send", variant="primary", scale=1)
    with gr.Row():
        analysis_btn = gr.Button("Request Analysis", variant="secondary")
        clear_btn = gr.Button("Clear Chat", variant="stop")

    send_btn.click(
        process_message,
        inputs=[msg, chatbot, question_count_state],
        outputs=[chatbot, question_count_state],
    ).then(lambda: "", outputs=[msg])
    msg.submit(
        process_message,
        inputs=[msg, chatbot, question_count_state],
        outputs=[chatbot, question_count_state],
    ).then(lambda: "", outputs=[msg])
    analysis_btn.click(force_analysis, inputs=[chatbot, question_count_state], outputs=[chatbot, question_count_state])
    clear_btn.click(clear_chat, outputs=[chatbot, question_count_state])
# -------------------
# 4️⃣ Launch
# -------------------
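# To run locally (assuming the dependencies noted at the top are installed):
#   python app.py   ->   http://localhost:7860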
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)