"""Interactive command-line chat with the ChatDoctor (LLaMA) medical model.

Loads the model once at import time, then runs a REPL in which the user
plays "Patient" and the model replies as "ChatDoctor". Conversation state
is kept in the module-level ``history`` list and re-sent as the prompt on
every turn.
"""

import gc
import os

import torch
from huggingface_hub import login
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

# Login using the token stored in repository secrets.
login(token=os.getenv("HUGGINGFACE_TOKEN"))

# =============================
# Configuration
# =============================
MODEL_PATH = r"zl111/ChatDoctor"
MAX_NEW_TOKENS = 200
TEMPERATURE = 0.5
TOP_K = 50
REPETITION_PENALTY = 1.1

# Detect device (informational; device_map="auto" does the actual placement).
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")

# =============================
# Load Tokenizer and Model
# =============================
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
# Kept so any external code referencing this module-level alias still works.
generator = model.generate

print("✅ ChatDoctor model loaded successfully!\n")


# =============================
# Stopping Criteria
# =============================
class StopOnTokens(StoppingCriteria):
    """Stop generation when the tail of the sequence matches a stop-token id
    sequence (used to halt before the model starts writing the patient's
    next turn)."""

    def __init__(self, stop_ids):
        # stop_ids: list of token-id lists, one per stop word/phrase.
        self.stop_ids = stop_ids

    def __call__(self, input_ids, scores, **kwargs):
        for stop_id_seq in self.stop_ids:
            if len(stop_id_seq) == 1:
                if input_ids[0][-1] == stop_id_seq[0]:
                    return True
            elif len(input_ids[0]) >= len(stop_id_seq):
                # Compare the trailing slice against the multi-token stop seq.
                if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq:
                    return True
        return False


# =============================
# Chat History
# =============================
history = [
    "ChatDoctor: I am ChatDoctor, your AI medical assistant. "
    "How can I help you today?"
]


# =============================
# Get Response Function
# =============================
def get_response(user_input):
    """Append the user's turn to ``history``, generate the model's reply,
    append it to ``history``, and return the cleaned reply string.

    Side effects: mutates module-level ``history`` and frees GPU memory.
    """
    human_invitation = "Patient: "
    doctor_invitation = "ChatDoctor: "

    # Add user input to history.
    history.append(human_invitation + user_input)

    # Build conversation prompt from the full history.
    prompt = "\n".join(history) + "\n" + doctor_invitation

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Stop words that signal the model has started writing the patient's turn.
    stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
    stop_ids = [
        tokenizer.encode(word, add_special_tokens=False) for word in stop_words
    ]
    stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])

    # Generate model response.
    with torch.no_grad():
        output_ids = generator(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode ONLY the newly generated tokens. Slicing the decoded string by
    # len(prompt) is unreliable because tokenization does not round-trip the
    # prompt text exactly; slicing the token ids is exact.
    new_token_ids = output_ids[0][input_ids.shape[-1]:]
    response = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

    # Remove any "Patient:" fragment that slipped past the stopping criteria.
    for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
        if stop_word in response:
            response = response.split(stop_word)[0].strip()
            break

    response = response.strip()
    history.append(doctor_invitation + response)

    # Free memory between turns.
    del input_ids, output_ids
    gc.collect()
    if torch.cuda.is_available():  # empty_cache is CUDA-only; guard for CPU runs
        torch.cuda.empty_cache()

    return response


# =============================
# Chat Loop
# =============================
if __name__ == "__main__":
    print("\n=== ChatDoctor is ready! ===")
    print("You (the human) = Patient ")
    print("AI = ChatDoctor")
    print("Type 'exit' or 'quit' to end the chat.\n")
    print("ChatDoctor: Hi there! How can I help you today?\n")

    while True:
        try:
            user_input = input("Patient: ").strip()
            if user_input.lower() in ["exit", "quit"]:
                print("ChatDoctor: Take care! Goodbye ")
                break
            if not user_input:
                continue
            response = get_response(user_input)
            print("ChatDoctor:", response, "\n")
        except KeyboardInterrupt:
            print("\nChatDoctor: Take care! Goodbye")
            break
        except Exception as e:
            # Surface the error but keep the REPL alive.
            print(f"Error: {e}")
            print("Please try again.\n")