"""ChatDoctor: a Gradio medical-assistant chat app with text and voice modes.

A causal LM answers health questions; Whisper transcribes voice input and an
optional TTS pipeline speaks the replies.  Messages are screened (rate limit,
emergency keywords, greetings, off-topic) before they reach the model, and
model output is post-processed by a few safety heuristics.
"""
import os
import gc
import re
import time
import uuid
import tempfile
from collections import defaultdict
from datetime import datetime, timedelta

import numpy as np
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
from transformers import pipeline

# =============================
# Configuration
# =============================
MODEL_PATH = r"Muhammadidrees/JayConverstionalModel"
WHISPER_MODEL = "openai/whisper-small"  # Change to "openai/whisper-base" for faster, or "openai/whisper-medium" for better accuracy
TTS_MODEL = "suno/bark-small"  # Alternative: "facebook/mms-tts-eng" for faster TTS
MAX_NEW_TOKENS = 200
TEMPERATURE = 0.5
TOP_K = 50
REPETITION_PENALTY = 1.1
MAX_HISTORY_TURNS = 5
MAX_TTS_CHARS = 500  # longer replies are truncated before TTS to avoid timeouts

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Loading models on {device}...")

# =============================
# Rate Limiting
# =============================
rate_limit_store = defaultdict(list)
MAX_REQUESTS_PER_MINUTE = 10
RATE_LIMIT_WINDOW = timedelta(minutes=1)


def check_rate_limit(session_id):
    """Simple sliding-window rate limiting to prevent abuse.

    Drops timestamps older than RATE_LIMIT_WINDOW for *session_id*; if the
    remaining count is already at MAX_REQUESTS_PER_MINUTE returns False,
    otherwise records this request and returns True.
    """
    now = datetime.now()
    recent = [ts for ts in rate_limit_store[session_id] if now - ts < RATE_LIMIT_WINDOW]
    rate_limit_store[session_id] = recent
    if len(recent) >= MAX_REQUESTS_PER_MINUTE:
        return False
    recent.append(now)
    return True


def _new_session_id():
    """Return a fresh random id used to key one browser session's rate limit."""
    return uuid.uuid4().hex


# =============================
# Load Models
# =============================
_PIPELINE_DEVICE = 0 if torch.cuda.is_available() else -1

try:
    # Load ChatDoctor Model
    print("Loading ChatDoctor model...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        device_map="auto",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        low_cpu_mem_usage=True,
    )
    print("✅ ChatDoctor model loaded!")

    # Load Whisper (Speech-to-Text)
    print("Loading Whisper ASR model...")
    whisper_pipe = pipeline(
        "automatic-speech-recognition",
        model=WHISPER_MODEL,
        device=_PIPELINE_DEVICE,
    )
    print("✅ Whisper model loaded!")

    # Load TTS Model (optional: the app still works text-only without it)
    print("Loading TTS model...")
    try:
        tts_pipe = pipeline(
            "text-to-speech",
            model=TTS_MODEL,
            device=_PIPELINE_DEVICE,
        )
        print("✅ TTS model loaded!")
        TTS_AVAILABLE = True
    except Exception as e:
        print(f"⚠️ TTS model not available: {e}")
        tts_pipe = None
        TTS_AVAILABLE = False
except Exception as e:
    print(f"❌ Error loading models: {e}")
    raise


# =============================
# Stop Criteria
# =============================
class StopOnTokens(StoppingCriteria):
    """Stop generation once the sequence ends with any of the given token-id sequences."""

    def __init__(self, stop_ids):
        # Empty sequences can never match in the original check; drop them up front.
        self.stop_ids = [list(seq) for seq in stop_ids if len(seq) > 0]
        self._max_len = max((len(seq) for seq in self.stop_ids), default=0)

    def __call__(self, input_ids, scores, **kwargs):
        if not self.stop_ids:
            return False
        # Only the tail can match, so avoid converting the whole sequence each step.
        tail = input_ids[0][-self._max_len:].tolist()
        return any(
            len(tail) >= len(seq) and tail[-len(seq):] == seq
            for seq in self.stop_ids
        )


STOP_WORDS = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
# Tokenized once at start-up rather than on every request.
_STOP_IDS = [tokenizer.encode(word, add_special_tokens=False) for word in STOP_WORDS]

# =============================
# Medical Keywords and Validation
# =============================
MEDICAL_KEYWORDS = [
    "pain", "ache", "symptom", "hurt", "sore", "discomfort", "fever", "cough", "flu",
    "infection", "allergy", "diabetes", "pressure", "asthma", "migraine", "vomit",
    "stomach", "head", "chest", "throat", "heart", "lung", "liver", "kidney", "brain",
    "doctor", "hospital", "medicine", "treatment", "therapy", "surgery", "disease",
    "illness", "blood", "test", "scan", "health", "diet", "nutrition", "stress",
    "sleep", "weight", "vitamin", "fatigue", "anxiety", "depression", "nausea",
    "dizziness", "rash", "swelling", "injury", "bruise", "cold", "sneeze", "tired", "weak",
]

EMERGENCY_KEYWORDS = [
    "suicide", "kill myself", "end my life", "chest pain", "can't breathe",
    "severe bleeding", "overdose", "poisoning", "unconscious", "seizure",
    "stroke", "heart attack", "choking",
]

CASUAL_PATTERNS = [
    r"^(hey|hi|hello|sup|yo|wassup|hiya)\s*[\?\!\.]*$",
    r"^good\s+(morning|evening|afternoon|night)\s*[\?\!\.]*$",
    r"^how\s+are\s+you\s*[\?\!\.]*$",
    r"^what'?s\s+up\s*[\?\!\.]*$",
]

DANGEROUS_PATTERNS = [
    r"take\s+\d+\s+(pills|tablets|capsules)",
    r"inject\s+(yourself|myself)",
    r"(don't|do not)\s+go\s+to\s+(hospital|doctor|emergency)",
    r"ignore\s+(doctor|medical|professional)",
]

QUESTION_WORDS = ["what", "how", "why", "when", "where", "can", "should", "is", "are",
                  "do", "does", "could", "would"]


def is_emergency_query(message):
    """Return True if *message* contains any EMERGENCY_KEYWORDS phrase (case-insensitive)."""
    message_lower = message.lower()
    return any(keyword in message_lower for keyword in EMERGENCY_KEYWORDS)


def is_medical_query(message):
    """Heuristically decide whether *message* is on-topic.

    True when a MEDICAL_KEYWORDS substring is present, or when one of the first
    four words is a question word and the message has more than five words.
    """
    message_lower = message.lower()
    if any(keyword in message_lower for keyword in MEDICAL_KEYWORDS):
        return True
    words = message_lower.split()
    has_question = any(q in words[:4] for q in QUESTION_WORDS)
    return has_question and len(words) > 5


def is_only_greeting(message):
    """Return True if *message* (trailing !?. stripped) matches a CASUAL_PATTERNS greeting."""
    message_clean = re.sub(r'[!?.]+$', '', message.lower().strip())
    return any(re.match(pattern, message_clean) for pattern in CASUAL_PATTERNS)


def contains_dangerous_advice(response):
    """Return True if *response* matches any DANGEROUS_PATTERNS regex (case-insensitive)."""
    response_lower = response.lower()
    return any(re.search(pattern, response_lower) for pattern in DANGEROUS_PATTERNS)


# =============================
# Speech Processing Functions
# =============================
def _resample_linear(samples, orig_rate, target_rate):
    """Resample 1-D float *samples* from *orig_rate* to *target_rate* by linear interpolation."""
    if orig_rate == target_rate or samples.size == 0:
        return samples
    n_out = max(1, int(round(samples.shape[0] * target_rate / orig_rate)))
    src_t = np.arange(samples.shape[0], dtype=np.float64) / orig_rate
    dst_t = np.arange(n_out, dtype=np.float64) / target_rate
    return np.interp(dst_t, src_t, samples).astype(np.float32)


def transcribe_audio(audio):
    """Convert speech to text using Whisper.

    *audio* may be Gradio's numpy tuple ``(sample_rate, samples)``, a bare numpy
    array (assumed to already be at Whisper's sampling rate), or anything else
    the ASR pipeline accepts directly. Returns "" when there is nothing to
    transcribe or transcription fails.
    """
    if audio is None:
        return ""

    try:
        sample_rate = None
        if isinstance(audio, tuple):
            sample_rate, audio_data = audio
        else:
            audio_data = audio

        if isinstance(audio_data, np.ndarray):
            if audio_data.size == 0:
                return ""
            # Integer PCM -> float32 in [-1, 1]; float input is only cast
            # (np.iinfo would raise on a float dtype).
            if np.issubdtype(audio_data.dtype, np.integer):
                audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
            else:
                audio_data = audio_data.astype(np.float32)
            # Gradio delivers multi-channel audio as (samples, channels); Whisper wants mono.
            if audio_data.ndim > 1:
                audio_data = audio_data.mean(axis=1)
            # A bare array is interpreted at the feature extractor's rate (16 kHz for
            # Whisper); microphone audio is usually 44.1/48 kHz, so convert it first.
            if sample_rate:
                audio_data = _resample_linear(
                    audio_data, int(sample_rate), whisper_pipe.feature_extractor.sampling_rate
                )

        result = whisper_pipe(audio_data)
        return result["text"].strip()

    except Exception as e:
        print(f"Error in transcription: {e}")
        return ""


def text_to_speech(text):
    """Convert *text* to speech with the loaded TTS pipeline.

    Returns ``(sampling_rate, samples)`` with 1-D float32 samples clipped to
    [-1, 1] (the shape/range gr.Audio expects), or None when TTS is
    unavailable, *text* is empty, or synthesis fails.
    """
    if not TTS_AVAILABLE or not text:
        return None

    try:
        # Limit text length for TTS (to avoid timeout)
        if len(text) > MAX_TTS_CHARS:
            text = text[:MAX_TTS_CHARS] + "..."

        speech = tts_pipe(text)
        sampling_rate = int(speech["sampling_rate"])
        # Pipelines such as Bark return shape (1, n_samples); flatten to mono and clip
        # so Gradio's WAV conversion does not overflow.
        audio_data = np.clip(np.ravel(np.asarray(speech["audio"], dtype=np.float32)), -1.0, 1.0)
        return (sampling_rate, audio_data)

    except Exception as e:
        print(f"Error in TTS: {e}")
        return None


# =============================
# Get Response
# =============================
HUMAN_PREFIX = "Patient:"
DOCTOR_PREFIX = "ChatDoctor:"
SYSTEM_INSTRUCTION = (
    "You are ChatDoctor, a professional medical AI assistant. "
    "You provide accurate, concise, and empathetic responses to health-related questions only.\n"
    "Always recommend consulting a healthcare professional for serious conditions.\n"
    "Never provide dosage instructions or tell patients to avoid seeking professional help.\n\n"
)
SERIOUS_CONDITIONS = ["cancer", "tumor", "heart disease", "stroke", "diabetes complications"]

# The model tends to continue with the next "Patient:" turn; cut there. Only the
# role marker counts, so ordinary words like "Patients" do not truncate the reply.
_ROLE_MARKER_RE = re.compile(r"Patient\s*:")
# Signs the model leaked non-medical / meta text. Word boundaries keep medical
# phrases such as "cloudy urine" from tripping the filter.
_META_LEAK_RE = re.compile(r"\b(?:chatbot|api key|error|cloud)\b|sorry, i don't have")


def _screen_message(user_input, session_id):
    """Return a canned reply if the message should not reach the model, else None."""
    if not check_rate_limit(session_id):
        return "⏰ You've made too many requests. Please wait a minute before trying again."

    if is_emergency_query(user_input):
        return (
            "🚨 **EMERGENCY DETECTED** 🚨\n\n"
            "If you are experiencing a medical emergency, please:\n"
            "• Call emergency services immediately (911 in US, 999 in UK, 112 in EU)\n"
            "• Go to the nearest emergency room\n"
            "• Contact your local emergency hotline\n\n"
            "This AI cannot provide emergency medical care. Please seek immediate professional help."
        )

    if is_only_greeting(user_input):
        return "👋 Hello! I'm ChatDoctor — your AI medical assistant. Please tell me about any health symptoms or medical concerns you'd like to discuss."

    if not is_medical_query(user_input):
        return (
            "Hello! I'm ChatDoctor, an AI medical assistant specialized in health and wellness.\n\n"
            "I can help you with:\n"
            "• Symptoms and medical conditions\n"
            "• Treatment and prevention advice\n"
            "• Fitness, diet, and mental health tips\n\n"
            "Please describe your health concern in detail to get started."
        )

    return None


def _build_prompt(user_input, history_context):
    """Build the plain-text prompt: instruction, last MAX_HISTORY_TURNS turns, new message."""
    lines = [SYSTEM_INSTRUCTION]
    for human, assistant in list(history_context or [])[-MAX_HISTORY_TURNS:]:
        if human:
            lines.append(f"{HUMAN_PREFIX} {human}")
        if assistant:
            lines.append(f"{DOCTOR_PREFIX} {assistant}")
    lines.append(f"{HUMAN_PREFIX} {user_input}")
    return "\n".join(lines) + f"\n{DOCTOR_PREFIX} "


def _generate_reply(prompt, temperature, max_new_tokens, top_k):
    """Sample a continuation of *prompt* and return only the newly generated text."""
    # With device_map="auto" the model decides its placement; follow it.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    stopping_criteria = StoppingCriteriaList([StopOnTokens(_STOP_IDS)])

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=float(temperature),
            top_k=int(top_k),
            repetition_penalty=REPETITION_PENALTY,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode just the new tokens; slicing the decoded text by len(prompt) breaks
    # whenever decode() does not reproduce the prompt string exactly.
    return tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)


def _postprocess_response(response):
    """Trim the raw model text and apply the safety/quality substitutions."""
    response = _ROLE_MARKER_RE.split(response, maxsplit=1)[0].strip()

    if not response:
        return "I'm sorry, I couldn't generate a response. Could you please rephrase your question?"

    if contains_dangerous_advice(response):
        response = (
            "I apologize, but I cannot provide that specific medical advice. "
            "Please consult with a qualified healthcare professional who can properly evaluate your situation."
        )

    if _META_LEAK_RE.search(response.lower()):
        response = (
            "I apologize for the confusion. I'm ChatDoctor, trained to assist with medical and health-related topics. "
            "Please tell me more about your symptoms or health concerns so I can help you better."
        )

    if any(condition in response.lower() for condition in SERIOUS_CONDITIONS):
        response += "\n\n⚠️ **Important:** Please consult a healthcare professional for proper diagnosis and treatment."

    return response


def get_response(user_input, history_context, session_id="default",
                 temperature=None, max_new_tokens=None, top_k=None):
    """Generate a reply with safety screening and post-processing.

    Args:
        user_input: the new user message.
        history_context: prior turns as ``[user, assistant]`` pairs (may be None/empty).
        session_id: key for the per-session rate limiter.
        temperature, max_new_tokens, top_k: sampling overrides; None falls back to
            the module-level TEMPERATURE / MAX_NEW_TOKENS / TOP_K. Passing them per
            call avoids mutating globals that concurrent sessions share.

    Returns:
        The reply text (a canned message for rate-limited, emergency, greeting,
        off-topic or failed requests).
    """
    canned = _screen_message(user_input, session_id)
    if canned is not None:
        return canned

    temperature = TEMPERATURE if temperature is None else temperature
    max_new_tokens = MAX_NEW_TOKENS if max_new_tokens is None else max_new_tokens
    top_k = TOP_K if top_k is None else top_k

    prompt = _build_prompt(user_input, history_context)

    try:
        raw_response = _generate_reply(prompt, temperature, max_new_tokens, top_k)
    except Exception as e:
        print(f"Error generating response: {e}")
        return "I apologize, but I encountered an error processing your request. Please try rephrasing your question or try again later."
    finally:
        # Release memory on success *and* failure (e.g. after a CUDA OOM).
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return _postprocess_response(raw_response)


# =============================
# Gradio Interface
# =============================
custom_css = """
#header {
    text-align: center;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 25px;
    border-radius: 12px;
    margin-bottom: 20px;
    box-shadow: 0 4px 6px rgba(0,0,0,0.1);
}
#header h1 {
    margin: 0;
    font-size: 2.5em;
    font-weight: 700;
}
#header p {
    margin: 5px 0 0;
    font-size: 1.1em;
    opacity: 0.95;
}
.disclaimer {
    background-color: #fff3cd;
    border-left: 4px solid #ffc107;
    border-radius: 8px;
    padding: 18px;
    margin: 20px 0;
    color: #856404;
}
.disclaimer h3 {
    margin-top: 0;
    color: #d39e00;
}
.emergency-warning {
    background-color: #f8d7da;
    border-left: 4px solid #dc3545;
    border-radius: 8px;
    padding: 15px;
    margin: 15px 0;
    color: #721c24;
}
.voice-section {
    background: linear-gradient(135deg, #e0c3fc 0%, #8ec5fc 100%);
    border-radius: 10px;
    padding: 20px;
    margin: 15px 0;
}
footer {
    margin-top: 30px;
    padding: 15px;
    text-align: center;
    color: #6c757d;
    font-size: 0.9em;
}
"""

with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Fallback id only; demo.load below replaces it with a unique id per browser
    # session so each visitor gets their own rate-limit bucket.
    session_state = gr.State(value=str(time.time()))

    # NOTE(review): header markup is empty in this source — confirm intended content.
    gr.HTML(""" """)

    gr.HTML("""
    <div class="disclaimer">
        <h3>⚠️ Medical Disclaimer</h3>
        <p>This AI assistant is for informational purposes only. It is NOT a substitute for professional medical advice, diagnosis, or treatment. Always seek the advice of your physician or qualified health provider with any questions you may have regarding a medical condition.</p>
    </div>
    """)

    gr.HTML("""
    <div class="emergency-warning">
        <strong>🚨 In Case of Emergency</strong>
        <p>If you are experiencing a medical emergency, call emergency services immediately (911 in US, 999 in UK, 112 in EU) or go to the nearest emergency room.</p>
    </div>
    """)

    with gr.Tab("💬 Text Chat"):
        # avatar_images expects image paths/URLs, not an emoji, so it is left unset.
        chatbot = gr.Chatbot(
            height=500,
            placeholder=(
                "<div style='text-align: center;'>"
                "<h3>👋 Welcome to ChatDoctor!</h3>"
                "<p>Describe your symptoms or ask a health-related question to begin.</p>"
                "</div>"
            ),
            show_label=False,
        )

        with gr.Row():
            msg = gr.Textbox(
                placeholder="Type your medical concern here...",
                show_label=False,
                scale=9,
                container=False,
                lines=1,
            )
            send_btn = gr.Button("Send 📤", scale=1, variant="primary")

        with gr.Row():
            clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
            retry_btn = gr.Button("🔄 Retry", scale=1)

    with gr.Tab("🎤 Voice Chat"):
        gr.HTML(
            '<div class="voice-section">'
            '<h3>🎙️ Voice Interaction</h3>'
            '<p>Record your medical question and get voice responses!</p>'
            '</div>'
        )

        voice_chatbot = gr.Chatbot(
            height=400,
            placeholder=(
                "<div style='text-align: center;'>"
                "<h3>🎤 Voice Chat Mode</h3>"
                "<p>Click the microphone to record your question</p>"
                "</div>"
            ),
            show_label=False,
        )

        with gr.Row():
            audio_input = gr.Audio(
                sources=["microphone"],
                type="numpy",
                label="🎤 Record Your Question",
                scale=8,
            )
            voice_send_btn = gr.Button("Send Voice 🎙️", scale=2, variant="primary")

        audio_output = gr.Audio(
            label="🔊 Voice Response",
            autoplay=True,
            visible=TTS_AVAILABLE,
        )

        transcribed_text = gr.Textbox(
            label="📝 Transcribed Text",
            interactive=False,
            visible=True,
        )

        with gr.Row():
            voice_clear_btn = gr.Button("🗑️ Clear Voice Chat", scale=1)

        if not TTS_AVAILABLE:
            # gr.Warning only shows inside an event handler; render a static notice instead.
            gr.Markdown("⚠️ TTS model not available. Voice responses disabled. Text responses will still work.")

    with gr.Accordion("⚙️ Advanced Settings", open=False):
        temp_slider = gr.Slider(0.1, 1.0, TEMPERATURE, 0.1, label="Temperature (Lower = More Focused)")
        max_tok_slider = gr.Slider(50, 500, MAX_NEW_TOKENS, 50, label="Max Tokens")
        top_k_slider = gr.Slider(1, 100, TOP_K, 1, label="Top-K Sampling")

    # =============================
    # Text Chat Functions
    # =============================
    def user_message(user_msg, history):
        """Append the typed message as a pending turn and clear the textbox."""
        history = history or []
        if not user_msg or not user_msg.strip():
            return "", history
        return "", history + [[user_msg, None]]

    def bot_response(history, temp, max_tok, topk, session_id):
        """Fill in the reply for the last pending turn using the current slider values."""
        if not history or history[-1][1] is not None:
            return history
        user_msg = history[-1][0]
        history[-1][1] = get_response(
            user_msg, history[:-1], session_id,
            temperature=temp, max_new_tokens=max_tok, top_k=topk,
        )
        return history

    def retry_last(history, temp, max_tok, topk, session_id):
        """Regenerate the reply to the last user message with the current slider values."""
        if not history:
            return history
        user_msg = history[-1][0]
        history[-1][1] = get_response(
            user_msg, history[:-1], session_id,
            temperature=temp, max_new_tokens=max_tok, top_k=topk,
        )
        return history

    # =============================
    # Voice Chat Functions
    # =============================
    def process_voice_input(audio, history, temp, max_tok, topk, session_id):
        """Process voice input: transcribe, get response, convert to speech.

        Returns (updated history, transcription or error text, audio reply or None).
        """
        history = history or []
        if audio is None:
            return history, "", None

        transcribed = transcribe_audio(audio)
        if not transcribed:
            return history, "⚠️ Could not transcribe audio. Please try again.", None

        bot_msg = get_response(
            transcribed, history, session_id,
            temperature=temp, max_new_tokens=max_tok, top_k=topk,
        )
        history = history + [[transcribed, bot_msg]]

        audio_response = text_to_speech(bot_msg) if TTS_AVAILABLE else None
        return history, transcribed, audio_response

    # Give every browser session its own id (and thus its own rate-limit window).
    demo.load(_new_session_id, None, session_state, queue=False)

    # Text Chat Events
    msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response, [chatbot, temp_slider, max_tok_slider, top_k_slider, session_state], chatbot
    )
    send_btn.click(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response, [chatbot, temp_slider, max_tok_slider, top_k_slider, session_state], chatbot
    )
    clear_btn.click(lambda: None, None, chatbot, queue=False)
    retry_btn.click(retry_last, [chatbot, temp_slider, max_tok_slider, top_k_slider, session_state], chatbot)

    # Voice Chat Events
    voice_send_btn.click(
        process_voice_input,
        [audio_input, voice_chatbot, temp_slider, max_tok_slider, top_k_slider, session_state],
        [voice_chatbot, transcribed_text, audio_output],
    )
    voice_clear_btn.click(lambda: (None, "", None), None, [voice_chatbot, transcribed_text, audio_output], queue=False)

    # NOTE(review): footer markup is empty in this source — confirm intended content.
    gr.HTML(""" """)


# =============================
# Launch App
# =============================
if __name__ == "__main__":
    print("\n💡 Launching ChatDoctor with Voice Support...")
    print("📊 Configuration:")
    print(f"   - Device: {device.upper()}")
    print(f"   - Whisper Model: {WHISPER_MODEL}")
    print(f"   - TTS Available: {TTS_AVAILABLE}")
    print(f"   - Rate Limit: {MAX_REQUESTS_PER_MINUTE} requests/minute")
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)