import os import gc import torch import gradio as gr from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList # ============================= # Configuration # ============================= MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained" MAX_NEW_TOKENS = 200 TEMPERATURE = 0.5 TOP_K = 50 REPETITION_PENALTY = 1.1 # Detect device device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Loading model from {MODEL_PATH} on {device}...") # ============================= # Load Tokenizer and Model # ============================= tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH) model = LlamaForCausalLM.from_pretrained( MODEL_PATH, device_map="auto", torch_dtype=torch.float16, low_cpu_mem_usage=True ) generator = model.generate print("ā ChatDoctor model loaded successfully!\n") # ============================= # System Prompt # ============================= SYSTEM_PROMPT = """ You are ChatDoctor ā a friendly, professional, and caring virtual doctor. Whenever a patient describes their symptoms: 1. Always include a recommendation for diet, fluids, and proteins appropriate for recovery. - Fruits: citrus (orange, lemon), kiwi, papaya - Vegetables: leafy greens, carrots, spinach - Fluids: warm soups, herbal teas, coconut water - Proteins: boiled eggs, lentils, fish, chicken soup - Extras: garlic, ginger, turmeric 2. Recommend safe over-the-counter medicines if applicable (e.g., paracetamol for fever). 3. Ask follow-up questions if needed to understand the patient's condition better. 4. Always encourage the patient to see a real doctor if symptoms persist, worsen, or are serious. 5. Provide clear, warm, and empathetic advice. 6. Make your response structured and easy to understand. 7. Even if the patient only mentions a symptom, always include diet, fluids, protein, and care suggestions automatically. """ # ============================= # Stopping Criteria # ============================= class StopOnTokens(StoppingCriteria): def __init__(self, stop_ids): self.stop_ids = stop_ids def __call__(self, input_ids, scores, **kwargs): for stop_id_seq in self.stop_ids: if len(stop_id_seq) == 1: if input_ids[0][-1] == stop_id_seq[0]: return True else: if len(input_ids[0]) >= len(stop_id_seq): if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq: return True return False # ============================= # Chat History (Global) # ============================= conversation_history = [] # ============================= # Get Response Function # ============================= def get_response(user_input, history_context): """Generate response from ChatDoctor model""" # Build conversation from history history_text = [] for human, assistant in history_context: if human: history_text.append("Patient: " + human) if assistant: history_text.append("ChatDoctor: " + assistant) # Add current user input history_text.append("Patient: " + user_input) # Build full prompt including system instructions prompt = SYSTEM_PROMPT + "\n\nConversation so far:\n" + "\n".join(history_text) + "\nChatDoctor:" input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) # Define stop words and their token IDs stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"] stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words] stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)]) # Generate model response with torch.no_grad(): output_ids = generator( input_ids, max_new_tokens=MAX_NEW_TOKENS, do_sample=True, temperature=TEMPERATURE, top_k=TOP_K, repetition_penalty=REPETITION_PENALTY, stopping_criteria=stopping_criteria, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id ) # Decode and clean response full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True) response = full_output[len(prompt):].strip() # Remove any "Patient:" that might have slipped through for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]: if stop_word in response: response = response.split(stop_word)[0].strip() break # Free memory del input_ids, output_ids gc.collect() torch.cuda.empty_cache() return response # ============================= # Gradio Chat Function # ============================= def chat_function(message, history): """Gradio chat interface function""" if not message.strip(): return "" try: response = get_response(message, history) return response except Exception as e: return f"Error: {str(e)}" # ============================= # Custom CSS # ============================= custom_css = """ #header { text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 20px; border-radius: 10px; margin-bottom: 20px; } #header h1 { margin: 0; font-size: 2.5em; } #header p { margin: 10px 0 0 0; font-size: 1.1em; opacity: 0.9; } .disclaimer { background-color: #fff3cd; border: 1px solid #ffc107; border-radius: 8px; padding: 15px; margin: 20px 0; color: #856404; } .disclaimer h3 { margin-top: 0; color: #856404; } footer { text-align: center; margin-top: 30px; color: #666; font-size: 0.9em; } """ # ============================= # Gradio Interface # ============================= with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo: # Header gr.HTML("""
Your AI-powered medical conversation partner
Important: This AI assistant is for informational and educational purposes only. It is NOT a substitute for professional medical advice, diagnosis, or treatment. Always seek the advice of your physician or other qualified health provider with any questions you may have regarding a medical condition. Never disregard professional medical advice or delay in seeking it because of something you have read here.
I'm here to discuss your health concerns. How can I assist you today?