import gc
import os

import torch
from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList
from huggingface_hub import login

# Authenticate with the Hugging Face Hub; expects HUGGINGFACE_TOKEN in the environment.
login(token=os.getenv("HUGGINGFACE_TOKEN"))
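# Illustrative shell setup for the token above (hf_xxx is a placeholder, not a real token):
#   export HUGGINGFACE_TOKEN=hf_xxx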

# Generation settings.
MODEL_PATH = "zl111/ChatDoctor"
MAX_NEW_TOKENS = 200       # cap on tokens generated per reply
TEMPERATURE = 0.5          # lower values give more focused, deterministic answers
TOP_K = 50                 # sample only from the 50 most likely next tokens
REPETITION_PENALTY = 1.1   # mildly discourage repeated phrases

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")

# Load the tokenizer and the LLaMA-based ChatDoctor weights in half precision;
# device_map="auto" places the layers on whatever accelerators are available.
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
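
# If VRAM is tight, 8-bit loading is an alternative to the float16 load above.
# This is an optional sketch, assuming the bitsandbytes package is installed;
# it is not required for the script to run as-is.
# from transformers import BitsAndBytesConfig
# model = LlamaForCausalLM.from_pretrained(
#     MODEL_PATH,
#     device_map="auto",
#     quantization_config=BitsAndBytesConfig(load_in_8bit=True),
# )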

generator = model.generate  # alias used below; calling model.generate directly is equivalent

print("✅ ChatDoctor model loaded successfully!\n")


class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the model starts writing the patient's next turn."""

    def __init__(self, stop_ids):
        self.stop_ids = stop_ids  # list of token-id sequences that should end generation

    def __call__(self, input_ids, scores, **kwargs):
        for stop_id_seq in self.stop_ids:
            if len(stop_id_seq) == 1:
                # Single-token stop word: compare against the last generated id.
                if input_ids[0][-1] == stop_id_seq[0]:
                    return True
            else:
                # Multi-token stop word: the tail of the output must match the whole sequence.
                if len(input_ids[0]) >= len(stop_id_seq):
                    if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq:
                        return True
        return False
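
# Quick sanity check (a sketch; uncomment to see what each stop word tokenizes to,
# since a multi-token sequence must match the generated tail in full to trigger a stop):
# for word in ["Patient:", "\nPatient:"]:
#     print(repr(word), tokenizer.encode(word, add_special_tokens=False))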

# Seed the running transcript with the assistant's opening line. Every turn is
# appended as "Patient: ..." or "ChatDoctor: ...", and the whole transcript is
# replayed as the prompt for each new reply.
history = ["ChatDoctor: I am ChatDoctor, your AI medical assistant. How can I help you today?"]


def get_response(user_input):
    global history
    human_invitation = "Patient: "
    doctor_invitation = "ChatDoctor: "

    history.append(human_invitation + user_input)

    # Rebuild the full dialogue as one prompt, ending with the doctor's cue so the
    # model continues in the ChatDoctor role.
    prompt = "\n".join(history) + "\n" + doctor_invitation
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Halt generation if the model starts speaking for the patient. Several
    # spellings are listed because the same text can tokenize differently
    # depending on the surrounding context.
    stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
    stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
    stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])

    with torch.no_grad():
        output_ids = generator(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. Slicing the decoded string by
    # len(prompt) is fragile because detokenization does not always round-trip
    # the prompt text exactly.
    new_tokens = output_ids[0][input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # Trim any partial patient turn that slipped past the stopping criteria.
    for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
        if stop_word in response:
            response = response.split(stop_word)[0].strip()
            break

    history.append(doctor_invitation + response)

    # Release this turn's tensors; keeps memory flat during long chats on small GPUs.
    del input_ids, output_ids, new_tokens
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return response
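
# Example of a one-off call, bypassing the interactive loop below (a sketch;
# the symptom text is illustrative only):
# print(get_response("I have had a sore throat and a mild fever for three days."))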


if __name__ == "__main__":
    print("\n=== ChatDoctor is ready! ===")
    print("You (the human) = Patient")
    print("AI = ChatDoctor")
    print("Type 'exit' or 'quit' to end the chat.\n")

    print("ChatDoctor: Hi there! How can I help you today?\n")

    while True:
        try:
            user_input = input("Patient: ").strip()
            if user_input.lower() in ["exit", "quit"]:
                print("ChatDoctor: Take care! Goodbye")
                break

            if not user_input:
                continue

            response = get_response(user_input)
            print("ChatDoctor:", response, "\n")

        except KeyboardInterrupt:
            print("\nChatDoctor: Take care! Goodbye")
            break
        except Exception as e:
            print(f"Error: {e}")
            print("Please try again.\n")