# chat.py
import gc
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM

# =============================
# Configuration
# =============================
MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
MAX_NEW_TOKENS = 200
TEMPERATURE = 0.5
TOP_K = 50
REPETITION_PENALTY = 1.1

# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")

# =============================
# Load Tokenizer and Model
# =============================
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",           # automatically dispatch weights to GPU
    # half precision for faster GPU inference; fall back to float32 on CPU,
    # where float16 generation is poorly supported
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    low_cpu_mem_usage=True       # optimize CPU memory during loading
)
# DO NOT call model.to(device) when using device_map="auto"

print("✅ Model loaded successfully!\n")

# =============================
# Chat History
# =============================
history = ["ChatDoctor: I am ChatDoctor, what medical questions do you have?"]

# =============================
# Response Function
# =============================
def get_response(user_input):
    global history
    human_invitation = "Patient: "
    doctor_invitation = "ChatDoctor: "

    # Append user input to the running conversation
    history.append(human_invitation + user_input)

    # Build the prompt from the full history and move inputs to the
    # model's device (model.device is correct even with device_map="auto")
    prompt = "\n".join(history) + "\n" + doctor_invitation
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate response
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY
        )

    # Decode only the newly generated tokens. Slicing the decoded string
    # with len(prompt) is unreliable because decode() does not always
    # reproduce the prompt text verbatim.
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(
        output_ids[0][prompt_length:], skip_special_tokens=True
    ).strip()

    # Trim if the model continues with the patient's next turn
    if "Patient:" in response:
        response = response.split("Patient:")[0].strip()

    # Append model response to history
    history.append(doctor_invitation + response)

    # Free memory
    del inputs, output_ids
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return response

# =============================
# CLI Chat
# =============================
if __name__ == "__main__":
    print("\n=== ChatDoctor is ready! Type your questions. ===\n")
    while True:
        try:
            user_input = input("Patient: ").strip()
            if user_input.lower() in ["exit", "quit"]:
                print("Exiting ChatDoctor. Goodbye!")
                break
            response = get_response(user_input)
            print("ChatDoctor: " + response + "\n")
        except KeyboardInterrupt:
            print("\nExiting ChatDoctor. Goodbye!")
            break