"""ChatDoctor: a Gradio chat UI around a LLaMA-based medical assistant."""

import os
import gc

import torch
import gradio as gr
from transformers import (
    LlamaTokenizer,
    LlamaForCausalLM,
    StoppingCriteria,
    StoppingCriteriaList,
)

# =============================
# Configuration
# =============================
MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
MAX_NEW_TOKENS = 200
TEMPERATURE = 0.5
TOP_K = 50
REPETITION_PENALTY = 1.1

# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")

# =============================
# Load Tokenizer and Model
# =============================
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    # float16 halves memory on GPU; fall back to float32 on CPU, where
    # half-precision LLaMA inference is very slow and some ops are unsupported.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    low_cpu_mem_usage=True,
)
generator = model.generate
print("āœ… ChatDoctor model loaded successfully!\n")

# =============================
# Stopping Criteria
# =============================
class StopOnTokens(StoppingCriteria):
    """Stop generation once any of the given token-id sequences appears at
    the end of the generated ids."""

    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids, scores, **kwargs):
        for stop_id_seq in self.stop_ids:
            if len(stop_id_seq) == 1:
                if input_ids[0][-1] == stop_id_seq[0]:
                    return True
            elif len(input_ids[0]) >= len(stop_id_seq):
                if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq:
                    return True
        return False

# =============================
# Get Response Function
# =============================
def get_response(user_input, history_context):
    """Generate a reply from the ChatDoctor model.

    `history_context` is a list of [user, assistant] pairs, the same format
    gr.Chatbot uses for its history.
    """
    human_invitation = "Patient: "
    doctor_invitation = "ChatDoctor: "

    # Build the conversation transcript from history
    history_text = []
    for human, assistant in history_context:
        if human:
            history_text.append(human_invitation + human)
        if assistant:
            history_text.append(doctor_invitation + assistant)

    # Add the current user input and the cue for the model's turn
    history_text.append(human_invitation + user_input)
    prompt = "\n".join(history_text) + "\n" + doctor_invitation

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Stop generation as soon as the model starts a new "Patient:" turn
    stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
    stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
    stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])

    # Generate the model response
    with torch.no_grad():
        output_ids = generator(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode and drop the echoed prompt
    full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    response = full_output[len(prompt):].strip()

    # Remove any "Patient:" turn that slipped past the stopping criteria
    for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
        if stop_word in response:
            response = response.split(stop_word)[0].strip()
            break

    # Free memory
    del input_ids, output_ids
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    return response

# =============================
# Gradio Chat Function
# =============================
def chat_function(message, history):
    """Bridge between the Gradio UI and get_response."""
    if not message.strip():
        return ""
    try:
        return get_response(message, history)
    except Exception as e:
        return f"Error: {str(e)}"
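# Minimal usage sketch. Assumptions: the model above loaded successfully, and
# the CHATDOCTOR_SMOKE_TEST environment variable is introduced here purely as
# an opt-in switch (it is not part of the original app). `get_response` takes
# the newest user message plus the running gr.Chatbot-style history.
if os.environ.get("CHATDOCTOR_SMOKE_TEST") == "1":
    _history = [["I have a cough.", "How long have you had the cough?"]]
    print(get_response("About two weeks now.", _history))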
# =============================
# Custom CSS
# =============================
custom_css = """
#header {
    text-align: center;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
}
#header h1 {
    margin: 0;
    font-size: 2.5em;
}
#header p {
    margin: 10px 0 0 0;
    font-size: 1.1em;
    opacity: 0.9;
}
.disclaimer {
    background-color: #fff3cd;
    border: 1px solid #ffc107;
    border-radius: 8px;
    padding: 15px;
    margin: 20px 0;
    color: #856404;
}
.disclaimer h3 {
    margin-top: 0;
    color: #856404;
}
footer {
    text-align: center;
    margin-top: 30px;
    color: #666;
    font-size: 0.9em;
}
"""

# =============================
# Gradio Interface
# =============================
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Header banner (the original markup was lost; the structure and wording
    # below are assumptions guided by the #header CSS rules above)
    gr.HTML("""
    <div id="header">
        <h1>🩺 ChatDoctor</h1>
        <p>Your AI medical assistant</p>
    </div>
    """)

    # Disclaimer
    gr.HTML("""
    <div class="disclaimer">
        <h3>āš ļø Medical Disclaimer</h3>
        <p><strong>Important:</strong> This AI assistant is for informational
        and educational purposes only. It is NOT a substitute for professional
        medical advice, diagnosis, or treatment. Always seek the advice of
        your physician or other qualified health provider with any questions
        you may have regarding a medical condition. Never disregard
        professional medical advice or delay in seeking it because of
        something you have read here.</p>
    </div>
    """)
""") # Chatbot Interface chatbot = gr.Chatbot( height=500, placeholder="

šŸ‘‹ Welcome to ChatDoctor!

I'm here to discuss your health concerns. How can I assist you today?

", show_label=False, avatar_images=(None, "šŸ¤–"), ) with gr.Row(): msg = gr.Textbox( placeholder="Type your message here... (e.g., 'I have a headache')", show_label=False, scale=9, container=False ) submit_btn = gr.Button("Send šŸ“¤", scale=1, variant="primary") with gr.Row(): clear_btn = gr.Button("šŸ—‘ļø Clear Chat", scale=1) retry_btn = gr.Button("šŸ”„ Retry", scale=1) # Examples gr.Examples( examples=[ "I have a persistent headache for 3 days. What should I do?", "What are the symptoms of diabetes?", "How can I improve my sleep quality?", "I have a fever and sore throat. Should I be concerned?", "What are some natural ways to reduce stress?", ], inputs=msg, label="šŸ’” Example Questions" ) # Settings (collapsed by default) with gr.Accordion("āš™ļø Advanced Settings", open=False): temperature_slider = gr.Slider( minimum=0.1, maximum=1.0, value=TEMPERATURE, step=0.1, label="Temperature (Creativity)", info="Higher values make responses more creative but less focused" ) max_tokens_slider = gr.Slider( minimum=50, maximum=500, value=MAX_NEW_TOKENS, step=50, label="Max Response Length", info="Maximum number of tokens in response" ) top_k_slider = gr.Slider( minimum=1, maximum=100, value=TOP_K, step=1, label="Top K", info="Limits vocabulary selection" ) # Footer gr.HTML(""" """) # Event handlers def user_message(user_msg, history): return "", history + [[user_msg, None]] def bot_response(history, temp, max_tok, top_k_val): global TEMPERATURE, MAX_NEW_TOKENS, TOP_K TEMPERATURE = temp MAX_NEW_TOKENS = int(max_tok) TOP_K = int(top_k_val) user_msg = history[-1][0] bot_msg = chat_function(user_msg, history[:-1]) history[-1][1] = bot_msg return history # Connect events msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then( bot_response, [chatbot, temperature_slider, max_tokens_slider, top_k_slider], chatbot ) submit_btn.click(user_message, [msg, chatbot], [msg, chatbot], queue=False).then( bot_response, [chatbot, temperature_slider, max_tokens_slider, top_k_slider], chatbot ) clear_btn.click(lambda: None, None, chatbot, queue=False) def retry_last(): return None retry_btn.click(retry_last, None, chatbot, queue=False) # ============================= # Launch Interface # ============================= if __name__ == "__main__": print("\nšŸš€ Launching ChatDoctor Gradio Interface...") demo.queue() demo.launch( server_name="0.0.0.0", # Accessible from network server_port=7860, share=False, # Set to True to create public link show_error=True )