"""ChatDoctor — Gradio chat UI around a fine-tuned LLaMA medical-advice model.

Loads the tokenizer/model once at import time, then serves a Blocks-based
chat interface with adjustable sampling settings.  For information only —
NOT a substitute for professional medical advice.
"""

import os
import gc

import torch
import gradio as gr
from transformers import (
    LlamaTokenizer,
    LlamaForCausalLM,
    StoppingCriteria,
    StoppingCriteriaList,
)

# =============================
# Configuration
# =============================
MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
MAX_NEW_TOKENS = 200       # default cap on generated tokens (overridable in UI)
TEMPERATURE = 0.5          # default sampling temperature (overridable in UI)
TOP_K = 50                 # default top-k sampling cutoff (overridable in UI)
REPETITION_PENALTY = 1.1

# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")

# =============================
# Load Tokenizer and Model
# =============================
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    # FIX: fp16 matmuls are unsupported/very slow on CPU — use fp16 only on CUDA.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    low_cpu_mem_usage=True,
)
generator = model.generate
print("āœ… ChatDoctor model loaded successfully!\n")

# =============================
# System Prompt
# =============================
SYSTEM_PROMPT = """
You are ChatDoctor — a friendly, professional, and caring virtual doctor.
Whenever a patient describes their symptoms:
1. Always include a recommendation for diet, fluids, and proteins appropriate for recovery.
   - Fruits: citrus (orange, lemon), kiwi, papaya
   - Vegetables: leafy greens, carrots, spinach
   - Fluids: warm soups, herbal teas, coconut water
   - Proteins: boiled eggs, lentils, fish, chicken soup
   - Extras: garlic, ginger, turmeric
2. Recommend safe over-the-counter medicines if applicable (e.g., paracetamol for fever).
3. Ask follow-up questions if needed to understand the patient's condition better.
4. Always encourage the patient to see a real doctor if symptoms persist, worsen, or are serious.
5. Provide clear, warm, and empathetic advice.
6. Make your response structured and easy to understand.
7. Even if the patient only mentions a symptom, always include diet, fluids, protein, and care suggestions automatically.
"""


# =============================
# Stopping Criteria
# =============================
class StopOnTokens(StoppingCriteria):
    """Stop generation when any of the configured token-id sequences appears
    at the end of the generated ids (used to stop before a fake 'Patient:' turn)."""

    def __init__(self, stop_ids):
        # stop_ids: list of token-id lists, one per stop word.
        self.stop_ids = stop_ids

    def __call__(self, input_ids, scores, **kwargs):
        for stop_id_seq in self.stop_ids:
            if len(stop_id_seq) == 1:
                # Single-token stop word: compare only the last generated id.
                if input_ids[0][-1] == stop_id_seq[0]:
                    return True
            elif len(input_ids[0]) >= len(stop_id_seq):
                # Multi-token stop word: compare the trailing id window.
                if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq:
                    return True
        return False


# =============================
# Get Response Function
# =============================
def get_response(user_input, history_context):
    """Generate a ChatDoctor reply for *user_input*.

    Args:
        user_input: the patient's latest message.
        history_context: list of [patient_msg, doctor_msg] pairs (Gradio history).

    Returns:
        The model's reply with any leaked 'Patient:' turn trimmed off.
    """
    # Build conversation transcript from history.
    history_text = []
    for human, assistant in history_context:
        if human:
            history_text.append("Patient: " + human)
        if assistant:
            history_text.append("ChatDoctor: " + assistant)

    # Add current user input.
    history_text.append("Patient: " + user_input)

    # Full prompt: system instructions + transcript + generation cue.
    prompt = (
        SYSTEM_PROMPT
        + "\n\nConversation so far:\n"
        + "\n".join(history_text)
        + "\nChatDoctor:"
    )

    # FIX: with device_map="auto" the embedding layer may not live on `device`;
    # model.device is the correct placement target for the inputs.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    # Define stop words and their token IDs.
    stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
    stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
    stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])

    # Generate model response.
    with torch.no_grad():
        output_ids = generator(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # FIX: decode only the newly generated tokens.  Slicing the decoded string
    # by len(prompt) is wrong — decode() does not round-trip the prompt
    # byte-for-byte (special tokens / whitespace normalization shift offsets).
    prompt_len = input_ids.shape[1]
    response = tokenizer.decode(
        output_ids[0][prompt_len:], skip_special_tokens=True
    ).strip()

    # Remove any "Patient:" turn that might have slipped through.
    for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
        if stop_word in response:
            response = response.split(stop_word)[0].strip()
            break

    # Free memory.
    del input_ids, output_ids
    gc.collect()
    # FIX: only touch the CUDA allocator when running on CUDA.
    if device == "cuda":
        torch.cuda.empty_cache()

    return response


# =============================
# Gradio Chat Function
# =============================
def chat_function(message, history):
    """Gradio-facing wrapper: validate input and trap generation errors."""
    if not message.strip():
        return ""
    try:
        return get_response(message, history)
    except Exception as e:
        # Surface the failure in the chat rather than crashing the UI.
        return f"Error: {str(e)}"


# =============================
# Custom CSS
# =============================
custom_css = """
#header {
    text-align: center;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
}
#header h1 { margin: 0; font-size: 2.5em; }
#header p { margin: 10px 0 0 0; font-size: 1.1em; opacity: 0.9; }
.disclaimer {
    background-color: #fff3cd;
    border: 1px solid #ffc107;
    border-radius: 8px;
    padding: 15px;
    margin: 20px 0;
    color: #856404;
}
.disclaimer h3 { margin-top: 0; color: #856404; }
footer { text-align: center; margin-top: 30px; color: #666; font-size: 0.9em; }
"""

# =============================
# Gradio Interface
# =============================
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Header (markup reconstructed to match the #header CSS above).
    gr.HTML("""
    <div id="header">
        <h1>🩺 ChatDoctor</h1>
        <p>Your AI health assistant</p>
    </div>
    """)

    # Disclaimer
    gr.HTML("""
    <div class="disclaimer">
        <h3>āš ļø Medical Disclaimer</h3>
        <p><b>Important:</b> This AI assistant is for informational and educational
        purposes only. It is NOT a substitute for professional medical advice,
        diagnosis, or treatment. Always seek the advice of your physician or other
        qualified health provider with any questions you may have regarding a
        medical condition. Never disregard professional medical advice or delay in
        seeking it because of something you have read here.</p>
    </div>
    """)

    # Chatbot Interface
    chatbot = gr.Chatbot(
        height=500,
        placeholder=(
            "<h3>šŸ‘‹ Welcome to ChatDoctor!</h3>"
            "<p>I'm here to discuss your health concerns. "
            "How can I assist you today?</p>"
        ),
        show_label=False,
        avatar_images=(None, "šŸ¤–"),
    )

    with gr.Row():
        msg = gr.Textbox(
            placeholder="Type your message here... (e.g., 'I have a headache')",
            show_label=False,
            scale=9,
            container=False,
        )
        submit_btn = gr.Button("Send šŸ“¤", scale=1, variant="primary")

    with gr.Row():
        clear_btn = gr.Button("šŸ—‘ļø Clear Chat", scale=1)
        retry_btn = gr.Button("šŸ”„ Retry", scale=1)

    # Examples
    gr.Examples(
        examples=[
            "I have a persistent headache for 3 days. What should I do?",
            "What are the symptoms of diabetes?",
            "How can I improve my sleep quality?",
            "I have a fever and sore throat. Should I be concerned?",
            "What are some natural ways to reduce stress?",
        ],
        inputs=msg,
        label="šŸ’” Example Questions",
    )

    # Settings (collapsed by default)
    with gr.Accordion("āš™ļø Advanced Settings", open=False):
        temperature_slider = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=TEMPERATURE,
            step=0.1,
            label="Temperature (Creativity)",
            info="Higher values make responses more creative but less focused",
        )
        max_tokens_slider = gr.Slider(
            minimum=50,
            maximum=500,
            value=MAX_NEW_TOKENS,
            step=50,
            label="Max Response Length",
            info="Maximum number of tokens in response",
        )
        top_k_slider = gr.Slider(
            minimum=1,
            maximum=100,
            value=TOP_K,
            step=1,
            label="Top K",
            info="Limits vocabulary selection",
        )

    # Footer
    gr.HTML("""
    <footer>Powered by ChatDoctor &middot; Always consult a real doctor for
    serious concerns.</footer>
    """)

    # Event handlers
    def user_message(user_msg, history):
        """Append the user's turn (bot slot pending) and clear the textbox."""
        return "", history + [[user_msg, None]]

    def bot_response(history, temp, max_tok, top_k_val):
        """Fill in the bot slot of the last turn using current slider settings."""
        global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
        TEMPERATURE = temp
        MAX_NEW_TOKENS = int(max_tok)
        TOP_K = int(top_k_val)
        user_msg = history[-1][0]
        bot_msg = chat_function(user_msg, history[:-1])
        history[-1][1] = bot_msg
        return history

    def retry_last(history, temp, max_tok, top_k_val):
        """FIX: actually regenerate the last reply instead of clearing the chat."""
        if not history:
            return history
        history[-1][1] = None
        return bot_response(history, temp, max_tok, top_k_val)

    # Connect events
    msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response,
        [chatbot, temperature_slider, max_tokens_slider, top_k_slider],
        chatbot,
    )
    submit_btn.click(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response,
        [chatbot, temperature_slider, max_tokens_slider, top_k_slider],
        chatbot,
    )
    clear_btn.click(lambda: None, None, chatbot, queue=False)
    retry_btn.click(
        retry_last,
        [chatbot, temperature_slider, max_tokens_slider, top_k_slider],
        chatbot,
    )

# =============================
# Launch Interface
# =============================
if __name__ == "__main__":
    print("\nšŸš€ Launching ChatDoctor Gradio Interface...")
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",  # Accessible from network
        server_port=7860,
        share=False,  # Set to True to create public link
        show_error=True,
    )