import os import gc import torch import gradio as gr from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList # ============================= # Configuration # ============================= MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained" MAX_NEW_TOKENS = 200 TEMPERATURE = 0.5 TOP_K = 50 REPETITION_PENALTY = 1.1 # Detect device device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Loading model from {MODEL_PATH} on {device}...") # ============================= # Load Tokenizer and Model # ============================= tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH) model = LlamaForCausalLM.from_pretrained( MODEL_PATH, device_map="auto", torch_dtype=torch.float16, low_cpu_mem_usage=True ) generator = model.generate print("ā ChatDoctor model loaded successfully!\n") # ============================= # Stopping Criteria # ============================= class StopOnTokens(StoppingCriteria): def __init__(self, stop_ids): self.stop_ids = stop_ids def __call__(self, input_ids, scores, **kwargs): for stop_id_seq in self.stop_ids: if len(stop_id_seq) == 1: if input_ids[0][-1] == stop_id_seq[0]: return True else: if len(input_ids[0]) >= len(stop_id_seq): if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq: return True return False # ============================= # Chat History (Global) # ============================= conversation_history = [] # ============================= # Get Response Function # ============================= def get_response(user_input, history_context): """Generate response from ChatDoctor model""" human_invitation = "Patient: " doctor_invitation = "ChatDoctor: " # Build conversation from history history_text = [] for human, assistant in history_context: if human: history_text.append(human_invitation + human) if assistant: history_text.append(doctor_invitation + assistant) # Add current user input history_text.append(human_invitation + user_input) # Build conversation prompt prompt = "\n".join(history_text) + "\n" + doctor_invitation input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) # Define stop words and their token IDs stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"] stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words] stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)]) # Generate model response with torch.no_grad(): output_ids = generator( input_ids, max_new_tokens=MAX_NEW_TOKENS, do_sample=True, temperature=TEMPERATURE, top_k=TOP_K, repetition_penalty=REPETITION_PENALTY, stopping_criteria=stopping_criteria, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id ) # Decode and clean response full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True) response = full_output[len(prompt):].strip() # Remove any "Patient:" that might have slipped through for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]: if stop_word in response: response = response.split(stop_word)[0].strip() break response = response.strip() # Free memory del input_ids, output_ids gc.collect() torch.cuda.empty_cache() return response # ============================= # Gradio Chat Function # ============================= def chat_function(message, history): """Gradio chat interface function""" if not message.strip(): return "" try: response = get_response(message, history) return response except Exception as e: return f"Error: {str(e)}" # ============================= # Custom CSS # ============================= custom_css = """ #header { text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 20px; border-radius: 10px; margin-bottom: 20px; } #header h1 { margin: 0; font-size: 2.5em; } #header p { margin: 10px 0 0 0; font-size: 1.1em; opacity: 0.9; } .disclaimer { background-color: #fff3cd; border: 1px solid #ffc107; border-radius: 8px; padding: 15px; margin: 20px 0; color: #856404; } .disclaimer h3 { margin-top: 0; color: #856404; } footer { text-align: center; margin-top: 30px; color: #666; font-size: 0.9em; } """ # ============================= # Gradio Interface # ============================= with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo: # Header gr.HTML("""
Your AI-powered medical conversation partner
Important: This AI assistant is for informational and educational purposes only. It is NOT a substitute for professional medical advice, diagnosis, or treatment. Always seek the advice of your physician or other qualified health provider with any questions you may have regarding a medical condition. Never disregard professional medical advice or delay in seeking it because of something you have read here.
I'm here to discuss your health concerns. How can I assist you today?