# app.py — Hugging Face Space by Muhammadidrees (commit 93dd230, 8.31 kB)
# NOTE(review): this comment header replaces Hugging Face web-viewer chrome
# ("raw / history / blame" links) accidentally captured when the file was
# copied from the web UI; those stray lines were not valid Python.
import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from PaitentVoiceToText import record_and_transcribe
from DocVoice import text_to_speech # Your TTS function
# -------------------
# 1️⃣ Load Model & Processor
# -------------------
def load_model():
    """Load the RaiyaChatDoc vision-language model and its processor.

    Returns:
        tuple: ``(processor, model, device)`` where ``device`` is
        ``"cuda"`` when a GPU is available, else ``"cpu"``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # fp16 only pays off on GPU; many CPU ops are slow/unsupported in half.
    dtype = torch.float16 if device == "cuda" else torch.float32

    # Load directly from the Hugging Face Hub.
    processor = AutoProcessor.from_pretrained(
        "Muhammadidrees/RaiyaChatDoc", trust_remote_code=True
    )
    model = AutoModelForVision2Seq.from_pretrained(
        "Muhammadidrees/RaiyaChatDoc",
        torch_dtype=dtype,
        device_map="auto",  # accelerate places weights on GPU if available
    )
    # BUG FIX: calling .to(device) on a model loaded with device_map="auto"
    # raises a RuntimeError in accelerate (the model is already dispatched).
    # Only move the model manually when accelerate did not place it.
    if not hasattr(model, "hf_device_map"):
        model.to(device)
    model.eval()  # inference only: disables dropout etc.
    return processor, model, device

# Load once at import time; the Gradio handlers below use these globals.
processor, model, device = load_model()
# -------------------
# 2️⃣ Chat Logic Functions
# -------------------
def process_message(message, history, question_count):
    """Run one doctor-patient dialogue turn through the model.

    Appends the patient's message to ``history``, builds a text prompt from
    the whole conversation, then generates either a single follow-up
    question or — after enough questions, or on explicit request — a full
    analysis, which is written into the last history entry.

    Args:
        message: The patient's latest utterance (plain text).
        history: Gradio chat history, a list of ``[user, bot]`` pairs;
            mutated in place.
        question_count: Doctor questions asked since the last analysis.

    Returns:
        tuple: ``(history, history, question_count)`` — the same mutated
        history twice (to match the Gradio output wiring) and the updated
        counter (reset to 0 after an analysis).
    """
    # Ignore empty / whitespace-only submissions.
    if not message.strip():
        return history, history, question_count
    history.append([message, None])  # bot slot is filled in at the end
    question_count += 1
    # Switch to analysis mode after 6 questions or on explicit request.
    should_analyze = (
        question_count >= 6 or
        any(word in message.lower() for word in ["analysis", "diagnose", "what do you think", "causes"])
    )
    if should_analyze:
        system_prompt = (
            "You are a medical doctor. Based on the patient's responses, provide a comprehensive analysis "
            "of potential causes for their symptoms. Start with 'Based on the information provided by the patient, "
            "potential causes of [symptoms] could include:' and list 3-4 possible diagnoses with brief explanations. "
            "Format as numbered list with diagnosis name and short explanation."
        )
    else:
        system_prompt = (
            "You are a medical doctor conducting a patient interview. Ask ONE specific, direct medical question "
            "to gather important diagnostic information. Keep it brief - just ask the question without explanations. "
            "Focus on key areas like: age, medical history, medications, lifestyle, family history, or symptom details."
        )
    # Replay the prior turns as "Patient:" / "Doctor:" lines, excluding the
    # just-appended entry (added back explicitly below).
    dialogue = []
    for user_msg, bot_msg in history[:-1]:
        if user_msg:
            dialogue.append(f"Patient: {user_msg}")
        if bot_msg:
            dialogue.append(f"Doctor: {bot_msg}")
    dialogue.append(f"Patient: {message}")
    conversation = "\n".join(dialogue)
    prompt = f"{system_prompt}\n\nConversation:\n{conversation}\nDoctor:"
    # The processor is multimodal, but no image is supplied on this path.
    inputs = processor(text=prompt, images=None, return_tensors="pt").to(device)
    # Short token budget for a single question, long budget for the analysis.
    max_tokens = 1000 if should_analyze else 25
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=processor.tokenizer.eos_token_id,
        )
    # Keep only the newly generated tokens (strip the echoed prompt).
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[:, input_length:]
    response = processor.batch_decode(generated_tokens, skip_special_tokens=True)[0].strip()
    # Drop a leading "Doctor:" the model sometimes emits (7 = len("Doctor:")).
    if response.lower().startswith("doctor:"):
        response = response[7:].strip()
    if not should_analyze:
        # Trim to the first question if the model produced more than one.
        sentences = response.split('?')
        if len(sentences) > 1:
            response = sentences[0].strip() + '?'
        # Strip boilerplate lead-ins ("Let me ask, ...") down to the
        # question after the first comma.
        cleanup_starts = [
            "I need to ask",
            "Let me ask",
            "I would like to know",
            "Can you tell me",
            "It would help if",
        ]
        for phrase in cleanup_starts:
            if response.startswith(phrase):
                parts = response.split(',', 1)
                if len(parts) > 1:
                    response = parts[1].strip()
        # NOTE(review): the original indentation was lost in this copy;
        # this re-adds the question mark after any trimming and is assumed
        # to sit at this level (once per response, not per phrase) — confirm
        # against the original file.
        if not response.endswith('?'):
            response += '?'
    history[-1][1] = response
    if should_analyze:
        question_count = 0  # start a fresh interview after the analysis
    return history, history, question_count
def force_analysis(history, question_count):
    """Arm the analysis trigger for the next message.

    Passes the chat history through untouched and returns a question count
    high enough that the next call to ``process_message`` produces the full
    analysis. The incoming ``question_count`` is intentionally discarded.
    """
    forced_count = 10  # comfortably above the threshold of 6
    return history, forced_count
def clear_chat():
    """Reset the conversation: empty history for both chatbot outputs and a
    zeroed question counter."""
    fresh_history = []
    return fresh_history, [], 0
# -------------------
# 3️⃣ TTS Helper
# -------------------
def play_assistant_audio(response_text):
    """Speak ``response_text`` aloud via the DocVoice TTS helper.

    Empty or ``None`` text is ignored. Always returns ``None`` so the Gradio
    callback updates no component.
    """
    if not response_text:
        return None
    text_to_speech(response_text)
    return None
# -------------------
# 4️⃣ Gradio Interface
# -------------------
# -------------------
# 4️⃣ Gradio Interface
# -------------------
with gr.Blocks(title="ChatDOC", theme=gr.themes.Soft()) as demo:
    # Per-session state: questions asked since the last analysis, and the
    # list of assistant replies kept for the TTS playback button.
    question_count_state = gr.State(0)
    assistant_responses_state = gr.State([])

    gr.Markdown(
        """
# 🩺 Chat with ChatDOC
Welcome! I'm your AI medical assistant. Please describe your symptoms and I'll ask relevant questions to help understand your condition better.
"""
    )

    chatbot = gr.Chatbot(
        value=[],
        height=400,
        show_label=False,
        # BUG FIX: avatar_images previously pointed at hard-coded
        # C:\Users\JAY\... paths, which fail on any machine but the author's
        # (including a deployed Hugging Face Space). Use default avatars.
        avatar_images=None,
        bubble_full_width=False
    )

    with gr.Row():
        msg = gr.Textbox(
            placeholder="Describe your symptoms...",
            scale=4,
            container=False,
            show_label=False
        )
        send_btn = gr.Button("Send", variant="primary", scale=1)
        mic_btn = gr.Button("🎤 Speak", variant="secondary", scale=1)

    with gr.Row():
        analysis_btn = gr.Button("Request Analysis", variant="secondary")
        clear_btn = gr.Button("Clear Chat", variant="stop")
        play_audio_btn = gr.Button("🔊 Play Assistant Response", variant="secondary")

    # -------------------
    # Event handlers
    # -------------------
    def update_assistant_responses(history, assistant_responses):
        # Record the newest bot reply (if one was produced) for TTS playback.
        if history and history[-1][1]:
            assistant_responses.append(history[-1][1])
        return assistant_responses

    def user_submit(message, history, question_count, assistant_responses):
        # Typed message: run the model turn and log the reply for TTS.
        history, updated_history, question_count = process_message(message, history, question_count)
        assistant_responses = update_assistant_responses(history, assistant_responses)
        return updated_history, question_count, assistant_responses

    def mic_submit(history, question_count, assistant_responses):
        # Spoken message: transcribe first, then treat like typed input.
        user_text = record_and_transcribe(duration=5)
        # BUG FIX: the original appended [user_text, None] here AND inside
        # process_message, duplicating every spoken message in the chat
        # (the first copy kept a dangling None reply). process_message is
        # solely responsible for appending now.
        history, updated_history, question_count = process_message(user_text, history, question_count)
        assistant_responses = update_assistant_responses(history, assistant_responses)
        return updated_history, question_count, assistant_responses

    def clear_input():
        # Empty the textbox after a message is sent.
        return ""

    # -------------------
    # Connect buttons
    # -------------------
    # BUG FIX: the output lists previously named `chatbot` twice (with the
    # handlers returning the history twice); each component now appears once.
    send_btn.click(
        user_submit,
        inputs=[msg, chatbot, question_count_state, assistant_responses_state],
        outputs=[chatbot, question_count_state, assistant_responses_state]
    ).then(clear_input, outputs=[msg])

    msg.submit(
        user_submit,
        inputs=[msg, chatbot, question_count_state, assistant_responses_state],
        outputs=[chatbot, question_count_state, assistant_responses_state]
    ).then(clear_input, outputs=[msg])

    mic_btn.click(
        mic_submit,
        inputs=[chatbot, question_count_state, assistant_responses_state],
        outputs=[chatbot, question_count_state, assistant_responses_state]
    )

    analysis_btn.click(
        force_analysis,
        inputs=[chatbot, question_count_state],
        outputs=[chatbot, question_count_state]
    )

    # BUG FIX: clear_chat returns three values; route the second one to the
    # stored assistant responses (previously `chatbot` was listed twice and
    # the TTS playback list was never reset).
    clear_btn.click(
        clear_chat,
        outputs=[chatbot, assistant_responses_state, question_count_state]
    )

    play_audio_btn.click(
        lambda assistant_responses: play_assistant_audio(assistant_responses[-1]) if assistant_responses else None,
        inputs=[assistant_responses_state],
        outputs=[]
    )
# -------------------
# 5️⃣ Launch
# -------------------
if __name__ == "__main__":
    demo.launch(
        # NOTE(review): loopback-only binding; a Hugging Face Space or any
        # LAN-accessible deployment needs "0.0.0.0" — confirm the intended
        # deployment target before changing.
        server_name="127.0.0.1",
        server_port=7860,  # Gradio's conventional default port
        share=False,       # no public gradio.live tunnel
        debug=True         # verbose errors; call blocks until server exits
    )