import gc
import os
import re
import tempfile
import traceback

import gradio as gr
import torch
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

# =============================
# Configuration
# =============================
MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
MAX_NEW_TOKENS = 200
TEMPERATURE = 0.5
TOP_K = 50
REPETITION_PENALTY = 1.1

# Speaker labels used to build the prompt.
HUMAN_INVITATION = "Patient: "
DOCTOR_INVITATION = "ChatDoctor: "
# Text that signals the model has started inventing the next patient turn.
STOP_WORDS = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]

# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")

# =============================
# Load Tokenizer and Model
# =============================
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    # Half precision is only worthwhile (and broadly supported) on GPU;
    # many CPU kernels have no float16 implementation.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    low_cpu_mem_usage=True,
)
generator = model.generate
print("✅ ChatDoctor model loaded successfully!\n")


# =============================
# Stopping Criteria
# =============================
class StopOnTokens(StoppingCriteria):
    """Stop generation when the first sequence ends with any given token-ID sequence.

    Args:
        stop_ids: iterable of token-ID lists. Empty lists are ignored (they
            could never match).
    """

    def __init__(self, stop_ids):
        super().__init__()
        self.stop_ids = [list(seq) for seq in stop_ids if len(seq) > 0]
        self._max_len = max((len(seq) for seq in self.stop_ids), default=0)

    def __call__(self, input_ids, scores, **kwargs):
        if not self.stop_ids:
            return False
        # Only the tail can match; avoid converting the whole sequence each step.
        tail = input_ids[0, -self._max_len:].tolist()
        return any(
            len(tail) >= len(seq) and tail[-len(seq):] == seq
            for seq in self.stop_ids
        )


class StopOnText(StoppingCriteria):
    """Stop generation once the newly generated *text* contains a stop string.

    Matching decoded text is more reliable than matching token IDs. A stop word
    encoded on its own can tokenize differently from the same characters
    in the middle of a sequence (SentencePiece leading-space pieces).

    Args:
        tokenizer: tokenizer used to decode the generated tokens.
        stop_strings: strings whose appearance ends generation.
        prompt_length: number of prompt tokens to skip, so only new text is checked.
    """

    def __init__(self, tokenizer, stop_strings, prompt_length):
        super().__init__()
        self.tokenizer = tokenizer
        self.stop_strings = [s for s in stop_strings if s]
        self.prompt_length = prompt_length

    def __call__(self, input_ids, scores, **kwargs):
        new_text = self.tokenizer.decode(
            input_ids[0, self.prompt_length:], skip_special_tokens=True
        )
        return any(stop in new_text for stop in self.stop_strings)


# =============================
# Get Response Function
# =============================
# A hallucinated next turn: "Patient:" or "Patient :" anywhere in the reply.
_TURN_MARKER = re.compile(r"Patient\s?:")
# A dangling speaker label left when generation stopped on "\n\nPatient".
_TRAILING_SPEAKER = re.compile(r"\n\s*Patient\w*\s*$")


def _build_prompt(user_input, history_context):
    """Render past turns plus the new input as a ChatDoctor prompt.

    Empty/None messages in the history are skipped. The prompt ends with the
    doctor label so the model continues speaking as the doctor.
    """
    lines = []
    for human, assistant in history_context:
        if human:
            lines.append(HUMAN_INVITATION + human)
        if assistant:
            lines.append(DOCTOR_INVITATION + assistant)
    lines.append(HUMAN_INVITATION + user_input)
    return "\n".join(lines) + "\n" + DOCTOR_INVITATION


def _clean_response(text):
    """Cut *text* at the earliest invented patient turn and trim whitespace."""
    match = _TURN_MARKER.search(text)
    if match:
        text = text[: match.start()]
    text = _TRAILING_SPEAKER.sub("", text)
    return text.strip()


def _generate(encoded, max_new_tokens, temperature, top_k):
    """Sample a continuation of the tokenized prompt and return only the new text."""
    input_ids = encoded.input_ids.to(model.device)
    attention_mask = encoded.attention_mask.to(model.device)
    prompt_length = input_ids.shape[-1]
    stopping_criteria = StoppingCriteriaList(
        [StopOnText(tokenizer, STOP_WORDS, prompt_length)]
    )

    with torch.no_grad():
        output_ids = generator(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            repetition_penalty=REPETITION_PENALTY,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the generated tokens. Decoding the prompt again need not give
    # back the exact prompt text, so slicing the full output by len(prompt) can misalign.
    return tokenizer.decode(output_ids[0, prompt_length:], skip_special_tokens=True)


def get_response(user_input, history_context, temperature=None,
                 max_new_tokens=None, top_k=None):
    """Generate a ChatDoctor reply to *user_input*.

    Args:
        user_input: the new patient message.
        history_context: earlier turns as (patient_text, doctor_text) pairs;
            either element may be empty/None and is then skipped.
        temperature, max_new_tokens, top_k: optional per-call sampling
            overrides; None falls back to the module-level TEMPERATURE,
            MAX_NEW_TOKENS and TOP_K.

    Returns:
        The reply text, cut before any hallucinated next patient turn.
    """
    temperature = TEMPERATURE if temperature is None else temperature
    max_new_tokens = int(MAX_NEW_TOKENS if max_new_tokens is None else max_new_tokens)
    top_k = int(TOP_K if top_k is None else top_k)

    history = list(history_context or [])
    max_context = getattr(model.config, "max_position_embeddings", 2048)
    prompt_budget = max(max_context - max_new_tokens, 1)

    prompt = _build_prompt(user_input, history)
    encoded = tokenizer(prompt, return_tensors="pt")
    # Drop the oldest exchanges until prompt + reply fit in the context window.
    while history and encoded.input_ids.shape[-1] > prompt_budget:
        history.pop(0)
        prompt = _build_prompt(user_input, history)
        encoded = tokenizer(prompt, return_tensors="pt")

    try:
        return _clean_response(_generate(encoded, max_new_tokens, temperature, top_k))
    finally:
        # Device tensors die with _generate's frame; release cached GPU memory.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# =============================
# Gradio Chat Function
# =============================
def chat_function(message, history, **generation_kwargs):
    """Gradio chat callback: return the model reply for *message*.

    Blank messages yield "". Any failure is reported as an "Error: ..." string
    (callers check that prefix), and the traceback is printed to the console.
    Extra keyword arguments are forwarded to get_response.
    """
    if not message or not message.strip():
        return ""
    try:
        return get_response(message, history, **generation_kwargs)
    except Exception as e:  # UI boundary: show the failure instead of crashing
        traceback.print_exc()
        return f"Error: {str(e)}"


# =============================
# Text-to-Speech Function
# =============================
def text_to_speech(text):
    """Convert a text response to an MP3 file with gTTS.

    Returns the path of a new temporary .mp3 file (the caller owns it). Returns
    None for empty or "Error:" text, or if synthesis fails.
    """
    if not text or text.startswith("Error:"):
        return None
    try:
        from gtts import gTTS

        tts = gTTS(text=text, lang='en', slow=False)
        fd, path = tempfile.mkstemp(suffix='.mp3')
        os.close(fd)  # gTTS opens the path itself; don't leak our handle
        try:
            tts.save(path)
        except Exception:
            os.remove(path)  # don't leave an empty temp file behind
            raise
        return path
    except Exception as e:
        print(f"TTS Error: {e}")
        return None
# =============================
# Custom CSS
# =============================
custom_css = """
#header {
    text-align: center;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
}
#header h1 {
    margin: 0;
    font-size: 2.5em;
}
#header p {
    margin: 10px 0 0 0;
    font-size: 1.1em;
    opacity: 0.9;
}
.disclaimer {
    background-color: #fff3cd;
    border: 1px solid #ffc107;
    border-radius: 8px;
    padding: 15px;
    margin: 20px 0;
    color: #856404;
}
.disclaimer h3 {
    margin-top: 0;
    color: #856404;
}
.voice-section {
    background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
    padding: 20px;
    border-radius: 10px;
    margin: 20px 0;
}
footer {
    text-align: center;
    margin-top: 30px;
    color: #666;
    font-size: 0.9em;
}
"""

# =============================
# Gradio Interface
# =============================
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Header
    # NOTE(review): the original header markup was lost (HTML tags stripped);
    # this is a minimal reconstruction targeting the #header CSS above.
    gr.HTML(
        """
        <div id="header">
            <h1>ChatDoctor</h1>
            <p>AI medical assistant with voice support</p>
        </div>
        """
    )

    # Disclaimer
    gr.HTML(
        """
        <div class="disclaimer">
            <h3>⚠️ Medical Disclaimer</h3>
            <p>Important: This AI assistant is for informational and educational
            purposes only. It is NOT a substitute for professional medical advice,
            diagnosis, or treatment. Always seek the advice of your physician or
            other qualified health provider with any questions you may have
            regarding a medical condition. Never disregard professional medical
            advice or delay in seeking it because of something you have read here.</p>
        </div>
        """
    )

    with gr.Row():
        with gr.Column(scale=7):
            # Chatbot Interface
            # avatar_images is omitted: it expects image file paths/URLs, and an
            # emoji string would be treated as a (non-existent) file.
            chatbot = gr.Chatbot(
                height=500,
                placeholder=(
                    "<h3>👋 Welcome to ChatDoctor!</h3>"
                    "<p>I'm here to discuss your health concerns. "
                    "Type or speak your question!</p>"
                ),
                show_label=False,
            )

            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Type your message here... (e.g., 'I have a headache')",
                    show_label=False,
                    scale=9,
                    container=False,
                )
                submit_btn = gr.Button("Send 📤", scale=1, variant="primary")

            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
                retry_btn = gr.Button("🔄 Retry", scale=1)

        with gr.Column(scale=3):
            # Voice Input Section
            gr.HTML('<div class="voice-section"><h3>🎤 Voice Features</h3></div>')

            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="🎙️ Speak Your Question",
                show_download_button=False,
            )
            transcribed_text = gr.Textbox(
                label="📝 Transcribed Text",
                placeholder="Your speech will appear here...",
                interactive=False,
                lines=3,
            )
            send_voice_btn = gr.Button("Send Voice Message 🔊", variant="primary")

            gr.Markdown("---")

            # Voice Output
            tts_enabled = gr.Checkbox(
                label="🔊 Enable Text-to-Speech for responses",
                value=True,
                info="Hear the doctor's response",
            )
            audio_output = gr.Audio(
                label="🔈 AI Response Audio",
                autoplay=False,
                visible=True,
            )

    # Examples
    gr.Examples(
        examples=[
            "I have a persistent headache for 3 days. What should I do?",
            "What are the symptoms of diabetes?",
            "How can I improve my sleep quality?",
            "I have a fever and sore throat. Should I be concerned?",
            "What are some natural ways to reduce stress?",
        ],
        inputs=msg,
        label="💡 Example Questions",
    )

    # Settings (collapsed by default)
    with gr.Accordion("⚙️ Advanced Settings", open=False):
        temperature_slider = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=TEMPERATURE,
            step=0.1,
            label="Temperature (Creativity)",
            info="Higher values make responses more creative but less focused",
        )
        max_tokens_slider = gr.Slider(
            minimum=50,
            maximum=500,
            value=MAX_NEW_TOKENS,
            step=50,
            label="Max Response Length",
            info="Maximum number of tokens in response",
        )
        top_k_slider = gr.Slider(
            minimum=1,
            maximum=100,
            value=TOP_K,
            step=1,
            label="Top K",
            info="Limits vocabulary selection",
        )

    # Footer
    # NOTE(review): the original footer markup was lost (HTML tags stripped);
    # this is a minimal reconstruction targeting the footer CSS above.
    gr.HTML(
        """
        <footer>
            <p>ChatDoctor · For informational and educational purposes only.</p>
        </footer>
        """
    )

    # =============================
    # Event Handlers
    # =============================
    def _apply_generation_settings(temp, max_tok, top_k_val):
        """Copy the slider values into the module-level generation settings."""
        # NOTE(review): these are process-wide globals. That is safe only while the
        # queue runs one event at a time (presumably Gradio's default
        # concurrency_limit=1) — confirm before raising concurrency.
        global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
        TEMPERATURE = temp
        MAX_NEW_TOKENS = int(max_tok)
        TOP_K = int(top_k_val)

    def user_message(user_msg, history):
        """Append the typed message as a pending turn; clear the textbox and audio.

        Blank messages are not added to the history.
        """
        history = history or []
        if not user_msg or not user_msg.strip():
            return "", history, None
        return "", history + [[user_msg, None]], None

    def bot_response(history, temp, max_tok, top_k_val, tts_enabled_val):
        """Answer the last pending turn; return (history, audio_path_or_None).

        Does nothing when the history is empty or its last turn already has a
        reply (for example after a blank message was ignored).
        """
        if not history or history[-1][1] is not None:
            return history, None

        _apply_generation_settings(temp, max_tok, top_k_val)
        user_msg = history[-1][0]
        bot_msg = chat_function(user_msg, history[:-1])
        history[-1] = [user_msg, bot_msg]

        # Generate audio if TTS is enabled
        audio_file = None
        if tts_enabled_val and bot_msg and not bot_msg.startswith("Error:"):
            audio_file = text_to_speech(bot_msg)
        return history, audio_file

    def prepare_retry(history):
        """Clear the last reply so bot_response regenerates it."""
        if not history:
            return history, None
        history[-1] = [history[-1][0], None]
        return history, None

    _whisper_models = {}

    def _get_whisper_model(name="base"):
        """Load a Whisper model once and reuse it (loading is slow)."""
        if name not in _whisper_models:
            import whisper

            _whisper_models[name] = whisper.load_model(name)
        return _whisper_models[name]

    def transcribe_audio(audio_file):
        """Transcribe audio to text using Whisper.

        Returns "" for no audio. On failure it returns an error string that
        starts with "Error:" or "Transcription error:".
        """
        if audio_file is None:
            return ""
        try:
            result = _get_whisper_model().transcribe(audio_file)
            return result["text"].strip()
        except ImportError:
            return "Error: Please install whisper: pip install openai-whisper"
        except Exception as e:
            return f"Transcription error: {str(e)}"

    def process_voice_input(audio_file, history, temp, max_tok, top_k_val, tts_enabled_val):
        """Process voice input: transcribe -> send -> get response.

        Returns (history, transcription_text, cleared_audio_input, reply_audio).
        """
        if audio_file is None:
            return history, "", None, None

        transcribed = transcribe_audio(audio_file)
        # transcribe_audio reports failures as text; never send them to the model.
        if transcribed.startswith(("Error:", "Transcription error:")):
            return history, transcribed, None, None
        if not transcribed:
            return history, "", None, None

        history = (history or []) + [[transcribed, None]]
        history, reply_audio = bot_response(
            history, temp, max_tok, top_k_val, tts_enabled_val
        )
        return history, transcribed, None, reply_audio

    generation_inputs = [temperature_slider, max_tokens_slider, top_k_slider, tts_enabled]

    # Text input events
    msg.submit(
        user_message, [msg, chatbot], [msg, chatbot, audio_output], queue=False
    ).then(
        bot_response, [chatbot, *generation_inputs], [chatbot, audio_output]
    )
    submit_btn.click(
        user_message, [msg, chatbot], [msg, chatbot, audio_output], queue=False
    ).then(
        bot_response, [chatbot, *generation_inputs], [chatbot, audio_output]
    )

    # Voice input events
    audio_input.change(transcribe_audio, [audio_input], [transcribed_text])
    send_voice_btn.click(
        process_voice_input,
        [audio_input, chatbot, *generation_inputs],
        [chatbot, transcribed_text, audio_input, audio_output],
    )

    # Clear and retry
    clear_btn.click(
        lambda: (None, None, None),
        None,
        [chatbot, audio_output, transcribed_text],
        queue=False,
    )
    retry_btn.click(
        prepare_retry, [chatbot], [chatbot, audio_output], queue=False
    ).then(
        bot_response, [chatbot, *generation_inputs], [chatbot, audio_output]
    )

# =============================
# Launch Interface
# =============================
if __name__ == "__main__":
    print("\n🚀 Launching ChatDoctor Gradio Interface with Voice Support...")
    print("\n📦 Required packages:")
    print(" pip install gradio gTTS openai-whisper")
    print("\nNote: Whisper will download models on first use (~100MB for base model)\n")
    demo.queue()
    # server_name="0.0.0.0" listens on every network interface (reachable from
    # the LAN); use "127.0.0.1" to keep the app local-only.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )