File size: 4,405 Bytes
0054032 373f237 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import os
import gc
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList
# =============================
# Configuration
# =============================
# Model identifier handed to the from_pretrained() loaders.
MODEL_PATH = r"zl111/ChatDoctor"

# Sampling settings used for every generated reply.
MAX_NEW_TOKENS = 200
TEMPERATURE = 0.5
TOP_K = 50
REPETITION_PENALTY = 1.1

# Prefer the GPU when PyTorch reports one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Loading model from {MODEL_PATH} on {device}...")
# =============================
# Load Tokenizer and Model
# =============================
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)

# Weights are requested in half precision and placed by device_map="auto".
# NOTE(review): float16 is requested even when `device` is "cpu" -- confirm
# the installed PyTorch build supports half-precision inference on CPU.
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True,
)

# Module-level alias for the model's generate method.
generator = model.generate

print("✅ ChatDoctor model loaded successfully!\n")
# =============================
# Stopping Criteria
# =============================
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires when the ids end with a stop sequence.

    Only the first sequence of the batch (``input_ids[0]``) is inspected.

    Args:
        stop_ids: A collection of token-id lists. Generation is stopped as
            soon as the inspected sequence ends with any one of them.
    """

    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids, scores, **kwargs):
        """Return True if the first sequence ends with any stop sequence.

        ``scores`` and any extra keyword arguments are accepted but unused.
        """
        generated = input_ids[0]
        for candidate in self.stop_ids:
            width = len(candidate)

            # Single-token stops only need the last generated id.
            if width == 1:
                if generated[-1] == candidate[0]:
                    return True
                continue

            # Too few ids so far for this stop sequence to match.
            if len(generated) < width:
                continue

            # Compare the trailing `width` ids against the stop sequence.
            if generated[-width:].tolist() == candidate:
                return True

        return False
# =============================
# Chat History
# =============================
# Conversation transcript, one "Speaker: text" entry per turn. It is seeded
# with ChatDoctor's opening line and extended by get_response().
history = [
    "ChatDoctor: I am ChatDoctor, your AI medical assistant. "
    "How can I help you today?"
]
# =============================
# Get Response Function
# =============================
def get_response(user_input):
    """Generate ChatDoctor's reply to ``user_input`` and record the exchange.

    The prompt is the whole conversation so far (one turn per line), followed
    by the new patient turn and an open ``ChatDoctor: `` line for the model
    to complete.

    Args:
        user_input: The patient's message, without the ``Patient: `` prefix.

    Returns:
        The cleaned reply text, without the ``ChatDoctor: `` prefix.

    Side effects:
        On success, appends the patient turn and the reply to the module-level
        ``history`` list. If tokenization or generation raises, ``history`` is
        left untouched so a retry does not duplicate the patient turn.
    """
    human_invitation = "Patient: "
    doctor_invitation = "ChatDoctor: "

    # Build the prompt from a temporary copy of the transcript; the new turns
    # are committed to `history` only after generation has succeeded.
    # NOTE(review): `history` is never trimmed, so a long chat keeps growing
    # the prompt -- confirm it stays within the model's context window.
    patient_turn = human_invitation + user_input
    prompt = "\n".join(history + [patient_turn]) + "\n" + doctor_invitation

    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    # Passed explicitly because pad_token_id is set to eos_token_id below, so
    # padding cannot be told apart from end-of-sequence by token id alone.
    attention_mask = encoded.attention_mask.to(device)

    # Token-level stop sequences are best-effort; the string cleanup further
    # down removes any speaker marker that still makes it into the reply.
    stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
    stop_ids = [
        tokenizer.encode(word, add_special_tokens=False) for word in stop_words
    ]
    stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])

    with torch.no_grad():
        output_ids = generator(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated ids. Slicing the decoded text by
    # len(prompt) is fragile: the decoded prompt is not guaranteed to match
    # the original prompt string character for character.
    new_token_ids = output_ids[0][input_ids.shape[-1]:]
    response = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

    # Cut the reply at the first marker (in this priority order) that shows
    # the model started writing the patient's next turn.
    # NOTE(review): the bare "Patient" fallback also truncates replies that
    # legitimately contain that word -- confirm this is intended.
    for stop_word in ("Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"):
        if stop_word in response:
            response = response.split(stop_word)[0].strip()
            break

    # Commit both turns together now that the reply is available.
    history.append(patient_turn)
    history.append(doctor_invitation + response)

    # Drop tensor references, then collect garbage and empty the CUDA cache.
    del encoded, input_ids, attention_mask, output_ids, new_token_ids
    gc.collect()
    torch.cuda.empty_cache()

    return response
# =============================
# Chat Loop
# =============================
def main():
    """Run the interactive console chat until the user leaves.

    Reads patient messages from stdin, prints ChatDoctor's replies, and stops
    on ``exit``/``quit``, Ctrl-C, or end of input. Errors raised while
    producing a reply are reported and the loop continues.
    """
    print("\n=== ChatDoctor is ready! ===")
    print("You (the human) = Patient ")
    print("AI = ChatDoctor")
    print("Type 'exit' or 'quit' to end the chat.\n")
    print("ChatDoctor: Hi there! How can I help you today?\n")

    while True:
        try:
            user_input = input("Patient: ").strip()
            if user_input.lower() in ("exit", "quit"):
                print("ChatDoctor: Take care! Goodbye ")
                break
            # Ignore blank lines instead of sending an empty turn to the model.
            if not user_input:
                continue
            response = get_response(user_input)
            print("ChatDoctor:", response, "\n")
        except (KeyboardInterrupt, EOFError):
            # EOFError (Ctrl-D or closed stdin) must end the chat too: input()
            # would keep raising it, and the generic handler below would
            # otherwise print an error and loop forever.
            print("\nChatDoctor: Take care! Goodbye")
            break
        except Exception as e:
            # Top-level boundary: report the failure and keep the chat alive.
            print(f"Error: {e}")
            print("Please try again.\n")


if __name__ == "__main__":
    main()