# Source: Hugging Face Space file app.py, uploaded by Muhammadidrees.
# Commit b100d07 ("Update app.py", verified), 22 kB.
import os
import gc
import re
import time
import torch
import gradio as gr
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
from transformers import pipeline
from collections import defaultdict
from datetime import datetime, timedelta
import tempfile
# =============================
# Configuration
# =============================
# Hugging Face Hub repo id of the causal language model that writes the replies.
MODEL_PATH = r"Muhammadidrees/JayConverstionalModel"
# Checkpoint handed to the speech-to-text pipeline.
WHISPER_MODEL = "openai/whisper-small" # Change to "openai/whisper-base" for faster, or "openai/whisper-medium" for better accuracy
# Checkpoint handed to the text-to-speech pipeline.
TTS_MODEL = "suno/bark-small" # Alternative: "facebook/mms-tts-eng" for faster TTS
# Sampling defaults forwarded to model.generate(). NOTE: the Gradio handlers
# overwrite TEMPERATURE, MAX_NEW_TOKENS and TOP_K at runtime from the sliders.
MAX_NEW_TOKENS = 200
TEMPERATURE = 0.5
TOP_K = 50
REPETITION_PENALTY = 1.1
# Number of most recent (user, assistant) pairs replayed into the prompt.
MAX_HISTORY_TURNS = 5
# Device string used when moving the tokenized prompt before generation.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"πŸš€ Loading models on {device}...")
# =============================
# Rate Limiting
# =============================
# Per-session request timestamps; missing sessions start with an empty list.
rate_limit_store = defaultdict(list)
# Maximum number of accepted requests per session within any 60-second window.
MAX_REQUESTS_PER_MINUTE = 10


def check_rate_limit(session_id):
    """Record a request for ``session_id`` unless its one-minute quota is used up.

    Timestamps older than one minute are discarded on every call.

    Returns:
        True when the request was accepted (and recorded), False when the
        session already has MAX_REQUESTS_PER_MINUTE requests in the window.
    """
    current = datetime.now()
    window_start = current - timedelta(minutes=1)

    # Keep only the requests that are still inside the sliding window.
    recent = [stamp for stamp in rate_limit_store[session_id] if stamp > window_start]
    rate_limit_store[session_id] = recent

    allowed = len(recent) < MAX_REQUESTS_PER_MINUTE
    if allowed:
        recent.append(current)
    return allowed
# ==========================
# Load Models
# =============================
# Load the chat model, the Whisper ASR pipeline and (best effort) the TTS
# pipeline at import time. A failure of the chat model or Whisper is fatal;
# a TTS failure only disables voice output via TTS_AVAILABLE.
try:
    cuda_ready = torch.cuda.is_available()
    # transformers pipelines take a GPU index, or -1 for CPU.
    pipeline_device = 0 if cuda_ready else -1

    # Load ChatDoctor Model
    print("Loading ChatDoctor model...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        device_map="auto",
        torch_dtype=torch.float16 if cuda_ready else torch.float32,
        low_cpu_mem_usage=True,
    )
    print("βœ… ChatDoctor model loaded!")

    # Load Whisper (Speech-to-Text)
    print("Loading Whisper ASR model...")
    whisper_pipe = pipeline(
        "automatic-speech-recognition",
        model=WHISPER_MODEL,
        device=pipeline_device,
    )
    print("βœ… Whisper model loaded!")

    # Load TTS Model (optional)
    print("Loading TTS model...")
    try:
        tts_pipe = pipeline(
            "text-to-speech",
            model=TTS_MODEL,
            device=pipeline_device,
        )
    except Exception as tts_error:
        print(f"⚠️ TTS model not available: {tts_error}")
        TTS_AVAILABLE = False
    else:
        print("βœ… TTS model loaded!")
        TTS_AVAILABLE = True
except Exception as load_error:
    print(f"❌ Error loading models: {load_error}")
    raise
# =============================
# Stop Criteria
# =============================
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires when the sequence ends with a stop phrase.

    ``stop_ids`` is a list of token-id lists. Only the first sequence in the
    batch (``input_ids[0]``) is inspected.
    """

    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids, scores, **kwargs):
        """Return True if ``input_ids[0]`` currently ends with any stop sequence."""
        sequence = input_ids[0]
        for stop_seq in self.stop_ids:
            width = len(stop_seq)
            if width == 1:
                # Single-token stop phrase: compare with the last token only.
                if sequence[-1] == stop_seq[0]:
                    return True
                continue
            if len(sequence) < width:
                continue
            if sequence[-width:].tolist() == stop_seq:
                return True
        return False
# =============================
# Medical Keywords and Validation
# =============================
# Substrings that mark a message as health-related. is_medical_query() tests
# them with a plain `in` check on the lowercased message, so each entry also
# matches inside longer words.
MEDICAL_KEYWORDS = [
"pain", "ache", "symptom", "hurt", "sore", "discomfort", "fever", "cough", "flu",
"infection", "allergy", "diabetes", "pressure", "asthma", "migraine", "vomit",
"stomach", "head", "chest", "throat", "heart", "lung", "liver", "kidney", "brain",
"doctor", "hospital", "medicine", "treatment", "therapy", "surgery", "disease",
"illness", "blood", "test", "scan", "health", "diet", "nutrition", "stress", "sleep",
"weight", "vitamin", "fatigue", "anxiety", "depression", "nausea", "dizziness",
"rash", "swelling", "injury", "bruise", "cold", "sneeze", "tired", "weak"
]
# Lowercase phrases that make is_emergency_query() return True, which causes
# get_response() to answer with the fixed emergency message.
EMERGENCY_KEYWORDS = [
"suicide", "kill myself", "end my life", "chest pain", "can't breathe",
"severe bleeding", "overdose", "poisoning", "unconscious", "seizure",
"stroke", "heart attack", "choking"
]
# Anchored regexes (applied with re.match to the lowercased, stripped message)
# that recognise messages consisting of nothing but a greeting.
CASUAL_PATTERNS = [
r"^(hey|hi|hello|sup|yo|wassup|hiya)\s*[\?\!\.]*$",
r"^good\s+(morning|evening|afternoon|night)\s*[\?\!\.]*$",
r"^how\s+are\s+you\s*[\?\!\.]*$",
r"^what'?s\s+up\s*[\?\!\.]*$",
]
# Regexes searched in the lowercased model output by contains_dangerous_advice();
# a hit makes get_response() replace the reply with a refusal.
DANGEROUS_PATTERNS = [
r"take\s+\d+\s+(pills|tablets|capsules)",
r"inject\s+(yourself|myself)",
r"(don't|do not)\s+go\s+to\s+(hospital|doctor|emergency)",
r"ignore\s+(doctor|medical|professional)",
]
def is_emergency_query(message):
    """Return True when ``message`` contains any phrase from EMERGENCY_KEYWORDS.

    Matching is case-insensitive and ignores apostrophes, so "can't breathe",
    "can’t breathe" (typographic apostrophe, common on phones and in ASR
    output) and "cant breathe" are all detected. Empty or None input is never
    an emergency.
    """
    if not message:
        return False

    # Drop ASCII and typographic apostrophe variants from both sides of the
    # comparison so punctuation style cannot hide an emergency phrase.
    drop_apostrophes = str.maketrans("", "", "'`\u2018\u2019\u02bc")
    normalized = message.lower().translate(drop_apostrophes)
    return any(
        keyword.lower().translate(drop_apostrophes) in normalized
        for keyword in EMERGENCY_KEYWORDS
    )
def is_medical_query(message):
    """Heuristically decide whether ``message`` is a health-related question.

    True when the lowercased message contains any MEDICAL_KEYWORDS entry as a
    substring, or when one of its first four words is a question word and the
    message has more than five words.
    """
    lowered = message.lower()
    if any(term in lowered for term in MEDICAL_KEYWORDS):
        return True

    question_starters = (
        "what", "how", "why", "when", "where", "can", "should",
        "is", "are", "do", "does", "could", "would",
    )
    tokens = lowered.split()
    opening = tokens[:4]
    looks_like_question = any(starter in opening for starter in question_starters)
    return looks_like_question and len(tokens) > 5
def is_only_greeting(message):
    """Return True when ``message`` is nothing more than a casual greeting.

    The message is lowercased, stripped of surrounding whitespace and trailing
    ``!``, ``?`` and ``.`` characters, then matched against CASUAL_PATTERNS.
    """
    bare = re.sub(r'[!?.]+$', '', message.lower().strip())
    return any(re.match(greeting, bare) for greeting in CASUAL_PATTERNS)
def contains_dangerous_advice(response):
    """Return True when the lowercased ``response`` matches any DANGEROUS_PATTERNS regex."""
    lowered = response.lower()
    return any(re.search(risky, lowered) for risky in DANGEROUS_PATTERNS)
# =============================
# Speech Processing Functions
# =============================
def transcribe_audio(audio):
    """Convert recorded speech to text with the Whisper pipeline.

    Args:
        audio: ``None``, a ``(sample_rate, samples)`` tuple as produced by a
            ``gr.Audio(type="numpy")`` component, or any other input that
            ``whisper_pipe`` accepts directly (passed through unchanged).

    Returns:
        The stripped transcription, or ``""`` when there is no audio or the
        transcription fails (the error is printed, not raised).
    """
    if audio is None:
        return ""
    try:
        sample_rate = None
        if isinstance(audio, tuple):
            sample_rate, audio_data = audio
        else:
            audio_data = audio

        asr_input = audio_data
        if isinstance(audio_data, np.ndarray):
            if audio_data.size == 0:
                return ""
            if np.issubdtype(audio_data.dtype, np.integer):
                # Scale integer PCM into [-1, 1]; np.iinfo is only valid for
                # integer dtypes, so float input must skip this branch.
                scale = float(np.iinfo(audio_data.dtype).max)
                audio_data = audio_data.astype(np.float32) / scale
            else:
                audio_data = audio_data.astype(np.float32)
            if audio_data.ndim > 1:
                # Down-mix to mono; assumes (samples, channels) layout as
                # delivered by Gradio -- TODO confirm for other callers.
                audio_data = audio_data.mean(axis=1)
            asr_input = audio_data
            if sample_rate is not None:
                # Passing the rate lets the pipeline resample to what Whisper
                # expects; a bare array is assumed to already be at the
                # model's rate. Resampling needs torchaudio to be installed.
                asr_input = {"sampling_rate": int(sample_rate), "raw": audio_data}

        result = whisper_pipe(asr_input)
        return result["text"].strip()
    except Exception as e:
        print(f"Error in transcription: {e}")
        return ""
def text_to_speech(text):
    """Synthesize ``text`` with the TTS pipeline.

    Returns a ``(sampling_rate, audio)`` tuple, or ``None`` when TTS is
    unavailable, ``text`` is empty, or synthesis fails (the error is printed).
    """
    if not (TTS_AVAILABLE and text):
        return None
    try:
        # Cap the input at 500 characters to keep synthesis time bounded.
        spoken = text if len(text) <= 500 else text[:500] + "..."
        synthesized = tts_pipe(spoken)
        waveform = synthesized["audio"]
        rate = synthesized["sampling_rate"]
        return (rate, waveform)
    except Exception as e:
        print(f"Error in TTS: {e}")
        return None
# =============================
# Get Response
# =============================
def get_response(user_input, history_context, session_id="default"):
    """Generate ChatDoctor's reply to ``user_input`` with safety checks.

    Processing order:
      1. Emergency phrases get a fixed emergency message. This runs before
         rate limiting so a throttled user is never denied emergency guidance.
      2. Rate limiting per ``session_id``.
      3. Pure greetings and non-medical messages get canned replies.
      4. Otherwise the prompt (system instruction + recent history + the new
         message) is sent to the language model and the output is filtered.

    Args:
        user_input: The user's message.
        history_context: Sequence of ``(user, assistant)`` pairs from earlier
            turns; only the last MAX_HISTORY_TURNS pairs are used. ``None`` is
            treated as an empty history.
        session_id: Key used for rate limiting.

    Returns:
        The reply text. Generation errors are printed and converted into an
        apology message rather than raised.
    """
    if is_emergency_query(user_input):
        return (
            "🚨 **EMERGENCY DETECTED** 🚨\n\n"
            "If you are experiencing a medical emergency, please:\n"
            "β€’ Call emergency services immediately (911 in US, 999 in UK, 112 in EU)\n"
            "β€’ Go to the nearest emergency room\n"
            "β€’ Contact your local emergency hotline\n\n"
            "This AI cannot provide emergency medical care. Please seek immediate professional help."
        )
    if not check_rate_limit(session_id):
        return "⏰ You've made too many requests. Please wait a minute before trying again."
    if is_only_greeting(user_input):
        return "πŸ‘‹ Hello! I'm ChatDoctor β€” your AI medical assistant. Please tell me about any health symptoms or medical concerns you'd like to discuss."
    if not is_medical_query(user_input):
        return (
            "Hello! I'm ChatDoctor, an AI medical assistant specialized in health and wellness.\n\n"
            "I can help you with:\n"
            "β€’ Symptoms and medical conditions\n"
            "β€’ Treatment and prevention advice\n"
            "β€’ Fitness, diet, and mental health tips\n\n"
            "Please describe your health concern in detail to get started."
        )

    human_prefix = "Patient:"
    doctor_prefix = "ChatDoctor:"
    system_instruction = (
        "You are ChatDoctor, a professional medical AI assistant. "
        "You provide accurate, concise, and empathetic responses to health-related questions only.\n"
        "Always recommend consulting a healthcare professional for serious conditions.\n"
        "Never provide dosage instructions or tell patients to avoid seeking professional help.\n\n"
    )

    # Replay only the most recent turns to keep the prompt short.
    limited_history = list(history_context or [])[-MAX_HISTORY_TURNS:]
    history_text = [system_instruction]
    for human, assistant in limited_history:
        if human:
            history_text.append(f"{human_prefix} {human}")
        if assistant:
            history_text.append(f"{doctor_prefix} {assistant}")
    history_text.append(f"{human_prefix} {user_input}")
    prompt = "\n".join(history_text) + f"\n{doctor_prefix} "

    clarification_reply = (
        "I apologize for the confusion. I'm ChatDoctor, trained to assist with medical and health-related topics. "
        "Please tell me more about your symptoms or health concerns so I can help you better."
    )

    # Pre-bind so the cleanup in `finally` is valid even if tokenization fails.
    input_ids = None
    output_ids = None
    try:
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
        stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
        stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=True,
                temperature=TEMPERATURE,
                top_k=TOP_K,
                repetition_penalty=REPETITION_PENALTY,
                stopping_criteria=stopping_criteria,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens. Slicing the decoded text by
        # len(prompt) is unreliable because decode(encode(prompt)) is not
        # guaranteed to reproduce the prompt character-for-character.
        new_tokens = output_ids[0][input_ids.shape[-1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        # Cut off anything the model wrote on behalf of the patient.
        for stop_word in ["Patient:", "Patient :", "\nPatient", "Patient"]:
            if stop_word in response:
                response = response.split(stop_word)[0].strip()
                break
        response = response.strip()

        if contains_dangerous_advice(response):
            response = (
                "I apologize, but I cannot provide that specific medical advice. "
                "Please consult with a qualified healthcare professional who can properly evaluate your situation."
            )
        # An empty generation is treated like an off-topic one so the user
        # never receives a blank chat bubble.
        if not response or any(
            x in response.lower()
            for x in ["chatbot", "api key", "error", "cloud", "sorry, i don't have"]
        ):
            response = clarification_reply

        serious_conditions = ["cancer", "tumor", "heart disease", "stroke", "diabetes complications"]
        if any(condition in response.lower() for condition in serious_conditions):
            response += "\n\n⚠️ **Important:** Please consult a healthcare professional for proper diagnosis and treatment."
        return response
    except Exception as e:
        print(f"Error generating response: {e}")
        return "I apologize, but I encountered an error processing your request. Please try rephrasing your question or try again later."
    finally:
        # Release tensors on both the success and the failure path.
        del input_ids, output_ids
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# =============================
# Gradio Interface
# =============================
# Stylesheet passed to gr.Blocks(css=...). It styles the #header banner, the
# .disclaimer and .emergency-warning notices, the .voice-section card and the
# page footer.
custom_css = """
#header {
text-align: center;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 25px;
border-radius: 12px;
margin-bottom: 20px;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
}
#header h1 { margin: 0; font-size: 2.5em; font-weight: 700; }
#header p { margin: 5px 0 0; font-size: 1.1em; opacity: 0.95; }
.disclaimer {
background-color: #fff3cd;
border-left: 4px solid #ffc107;
border-radius: 8px;
padding: 18px;
margin: 20px 0;
color: #856404;
}
.disclaimer h3 { margin-top: 0; color: #d39e00; }
.emergency-warning {
background-color: #f8d7da;
border-left: 4px solid #dc3545;
border-radius: 8px;
padding: 15px;
margin: 15px 0;
color: #721c24;
}
.voice-section {
background: linear-gradient(135deg, #e0c3fc 0%, #8ec5fc 100%);
border-radius: 10px;
padding: 20px;
margin: 15px 0;
}
footer {
margin-top: 30px;
padding: 15px;
text-align: center;
color: #6c757d;
font-size: 0.9em;
}
"""
# Build the page layout: header and notices, a text-chat tab, a voice-chat tab
# and the shared sampling controls.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # A callable initial value is evaluated on each page load, giving every
    # visitor a distinct id. A plain string would be computed once at build
    # time and shared by all visitors, collapsing per-session rate limiting
    # into a single global bucket.
    session_state = gr.State(value=lambda: f"{time.time()}-{os.urandom(4).hex()}")

    gr.HTML("""
<div id="header">
<h1>🩺 ChatDoctor AI Assistant</h1>
<p>🎀 Voice-Enabled Medical Consultation Partner</p>
</div>
""")
    gr.HTML("""
<div class="disclaimer">
<h3>⚠️ Medical Disclaimer</h3>
<p><strong>This AI assistant is for informational purposes only.</strong>
It is NOT a substitute for professional medical advice, diagnosis, or treatment.
Always seek the advice of your physician or qualified health provider with any questions
you may have regarding a medical condition.</p>
</div>
""")
    gr.HTML("""
<div class="emergency-warning">
<h4>🚨 In Case of Emergency</h4>
<p>If you are experiencing a medical emergency, call emergency services immediately
(911 in US, 999 in UK, 112 in EU) or go to the nearest emergency room.</p>
</div>
""")

    with gr.Tab("πŸ’¬ Text Chat"):
        chatbot = gr.Chatbot(
            height=500,
            placeholder="<div style='text-align:center;padding:50px;'><h3>πŸ‘‹ Welcome to ChatDoctor!</h3><p style='color:#6c757d;'>Describe your symptoms or ask a health-related question to begin.</p></div>",
            show_label=False,
            avatar_images=(None, "πŸ€–"),
        )
        with gr.Row():
            msg = gr.Textbox(
                placeholder="Type your medical concern here...",
                show_label=False,
                scale=9,
                container=False,
                lines=1,
            )
            send_btn = gr.Button("Send πŸ“€", scale=1, variant="primary")
        with gr.Row():
            clear_btn = gr.Button("πŸ—‘οΈ Clear Chat", scale=1)
            retry_btn = gr.Button("πŸ”„ Retry", scale=1)

    with gr.Tab("🎀 Voice Chat"):
        gr.HTML('<div class="voice-section"><h3>πŸŽ™οΈ Voice Interaction</h3><p>Record your medical question and get voice responses!</p></div>')
        voice_chatbot = gr.Chatbot(
            height=400,
            placeholder="<div style='text-align:center;padding:40px;'><h3>🎀 Voice Chat Mode</h3><p>Click the microphone to record your question</p></div>",
            show_label=False,
            avatar_images=(None, "πŸ€–"),
        )
        with gr.Row():
            audio_input = gr.Audio(
                sources=["microphone"],
                type="numpy",
                label="🎀 Record Your Question",
                scale=8,
            )
            voice_send_btn = gr.Button("Send Voice πŸŽ™οΈ", scale=2, variant="primary")
        audio_output = gr.Audio(
            label="πŸ”Š Voice Response",
            autoplay=True,
            visible=TTS_AVAILABLE,
        )
        transcribed_text = gr.Textbox(
            label="πŸ“ Transcribed Text",
            interactive=False,
            visible=True,
        )
        with gr.Row():
            voice_clear_btn = gr.Button("πŸ—‘οΈ Clear Voice Chat", scale=1)
        if not TTS_AVAILABLE:
            # gr.Warning only produces a toast from inside an event handler;
            # at build time it is never shown, so render the notice as a
            # regular component instead.
            gr.Markdown("⚠️ TTS model not available. Voice responses disabled. Text responses will still work.")

    with gr.Accordion("βš™οΈ Advanced Settings", open=False):
        # `step` is passed by keyword: it is keyword-only in newer Gradio
        # releases, and the keyword form works on older ones as well.
        temp_slider = gr.Slider(
            minimum=0.1, maximum=1.0, value=TEMPERATURE, step=0.1,
            label="Temperature (Lower = More Focused)",
        )
        max_tok_slider = gr.Slider(
            minimum=50, maximum=500, value=MAX_NEW_TOKENS, step=50,
            label="Max Tokens",
        )
        top_k_slider = gr.Slider(
            minimum=1, maximum=100, value=TOP_K, step=1,
            label="Top-K Sampling",
        )
# =============================
# Text Chat Functions
# =============================
def user_message(user_msg, history):
    """Append the user's message to the chat history as a pending turn.

    Args:
        user_msg: Text from the input box; ``None`` or blank text is ignored.
        history: List of ``[user, assistant]`` pairs; ``None`` is treated as
            an empty history when a message is appended.

    Returns:
        ``("", new_history)``. The empty string clears the input box, and a
        new pending turn is ``[user_msg, None]``. For blank input the history
        is returned unchanged.
    """
    if not user_msg or not user_msg.strip():
        return "", history
    # Build a new list rather than mutating the caller's history in place.
    return "", list(history or []) + [[user_msg, None]]
def bot_response(history, temp, max_tok, topk, session_id):
    """Fill in the assistant reply for the last, still unanswered, chat turn.

    The slider values are written to the module-level sampling settings before
    the reply is generated. History is returned untouched when it is empty or
    its last turn already has a reply.
    """
    global TEMPERATURE, MAX_NEW_TOKENS, TOP_K

    awaiting_reply = bool(history) and history[-1][1] is None
    if not awaiting_reply:
        return history

    # Convert both integers first so a bad value leaves all settings untouched.
    token_budget = int(max_tok)
    k_value = int(topk)
    TEMPERATURE = temp
    MAX_NEW_TOKENS = token_budget
    TOP_K = k_value

    pending_turn = history[-1]
    pending_turn[1] = get_response(pending_turn[0], history[:-1], session_id)
    return history
def retry_last(history, temp, max_tok, topk, session_id):
    """Regenerate the assistant reply for the most recent chat turn.

    The current slider values are applied to the module-level sampling
    settings first, exactly as ``bot_response`` does, so a retry honours the
    temperature, token and top-k settings the user sees.

    Returns:
        The history with the last turn's reply replaced, or the history
        unchanged when it is empty.
    """
    if not history:
        return history
    global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
    TEMPERATURE, MAX_NEW_TOKENS, TOP_K = temp, int(max_tok), int(topk)
    user_msg = history[-1][0]
    bot_msg = get_response(user_msg, history[:-1], session_id)
    history[-1][1] = bot_msg
    return history
# =============================
# Voice Chat Functions
# =============================
def text_to_speech(text):
    """Convert ``text`` to speech and return audio that ``gr.Audio`` can play.

    This definition shadows the earlier ``text_to_speech`` in this module and
    is the one the voice handler actually calls. It uses the module-level
    ``tts_pipe`` so model weights are not loaded again on every request.

    Returns:
        ``(sampling_rate, samples)`` with float32 samples clipped to
        [-1, 1], or ``None`` when TTS is unavailable, ``text`` is empty, or
        synthesis fails (the error is printed, not raised).
    """
    if not TTS_AVAILABLE or not text:
        return None
    try:
        # Limit text length for TTS (to avoid timeout)
        if len(text) > 500:
            text = text[:500] + "..."

        speech = tts_pipe(text)
        sampling_rate = int(speech["sampling_rate"])

        # Drop singleton axes -- presumably the pipeline returns Bark audio
        # with a leading channel/batch axis of size 1; TODO confirm.
        audio_data = np.asarray(speech["audio"], dtype=np.float32).squeeze()
        if audio_data.ndim == 0 or audio_data.size == 0:
            return None

        # Clip to the valid float range so the later conversion to integer
        # PCM cannot overflow.
        audio_data = np.clip(np.nan_to_num(audio_data), -1.0, 1.0).astype(np.float32)
        return (sampling_rate, audio_data)
    except Exception as e:
        print(f"Error in TTS: {e}")
        return None
def process_voice_input(audio, history, temp, max_tok, topk, session_id):
    """Process voice input: transcribe, get response, convert to speech.

    Args:
        audio: Recording from the microphone component, or ``None``.
        history: Voice-chat history as ``[user, assistant]`` pairs; ``None``
            is treated as empty once there is audio to process.
        temp, max_tok, topk: Slider values written to the module-level
            sampling settings before generating the reply.
        session_id: Key used for rate limiting.

    Returns:
        ``(history, transcribed_text, audio_response)``. ``audio_response`` is
        ``None`` when TTS is disabled or fails; the text reply is still
        delivered in that case.
    """
    if audio is None:
        return history, "", None

    history = list(history or [])

    # Transcribe audio to text
    transcribed = transcribe_audio(audio)
    if not transcribed:
        return history, "⚠️ Could not transcribe audio. Please try again.", None

    # Get bot response
    global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
    TEMPERATURE, MAX_NEW_TOKENS, TOP_K = temp, int(max_tok), int(topk)
    bot_msg = get_response(transcribed, history, session_id)
    history = history + [[transcribed, bot_msg]]

    # Convert response to speech. A synthesis failure must not discard the
    # text reply that was already generated.
    audio_response = None
    if TTS_AVAILABLE:
        try:
            audio_response = text_to_speech(bot_msg)
        except Exception as e:
            print(f"Error in TTS: {e}")
    return history, transcribed, audio_response
# Event listeners and the footer must be created inside the Blocks context.
# Re-entering `demo` keeps this wiring valid while the handler functions
# above live at module level.
with demo:
    # Text Chat Events: first echo the user's message (unqueued, so it shows
    # immediately), then generate the reply.
    msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response, [chatbot, temp_slider, max_tok_slider, top_k_slider, session_state], chatbot
    )
    send_btn.click(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response, [chatbot, temp_slider, max_tok_slider, top_k_slider, session_state], chatbot
    )
    clear_btn.click(lambda: None, None, chatbot, queue=False)
    retry_btn.click(retry_last, [chatbot, temp_slider, max_tok_slider, top_k_slider, session_state], chatbot)

    # Voice Chat Events
    voice_send_btn.click(
        process_voice_input,
        [audio_input, voice_chatbot, temp_slider, max_tok_slider, top_k_slider, session_state],
        [voice_chatbot, transcribed_text, audio_output]
    )
    voice_clear_btn.click(lambda: (None, "", None), None, [voice_chatbot, transcribed_text, audio_output], queue=False)

    gr.HTML(f"""
<footer>
<p><strong>🧠 Powered by LLaMA + Whisper + TTS</strong></p>
<p>Device: {device.upper()} | Rate Limit: {MAX_REQUESTS_PER_MINUTE} requests/minute</p>
<p>🎀 Voice: Whisper ASR | πŸ”Š TTS: {"Enabled" if TTS_AVAILABLE else "Disabled"}</p>
<p style='font-size:0.85em;margin-top:10px;'>
This AI provides general health information only. Always consult healthcare professionals for medical advice.
</p>
</footer>
""")
# =============================
# Launch App
# =============================
if __name__ == "__main__":
    # Print a short startup summary, then serve the app on all interfaces.
    startup_report = (
        "\nπŸ’‘ Launching ChatDoctor with Voice Support...",
        "πŸ“Š Configuration:",
        f" - Device: {device.upper()}",
        f" - Whisper Model: {WHISPER_MODEL}",
        f" - TTS Available: {TTS_AVAILABLE}",
        f" - Rate Limit: {MAX_REQUESTS_PER_MINUTE} requests/minute",
    )
    for report_line in startup_report:
        print(report_line)

    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)