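# Interactive command-line chat with a locally stored ChatDoctor (LLaMA) checkpoint.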
import os
import gc

import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
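# Local checkpoint path and the sampling parameters used for every reply.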
MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
MAX_NEW_TOKENS = 200       # upper bound on the length of each generated reply
TEMPERATURE = 0.5          # lower temperature -> more conservative answers
TOP_K = 50                 # sample only from the 50 most likely next tokens
REPETITION_PENALTY = 1.1   # discourage the model from repeating itself
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")
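# Load the tokenizer and the model weights once at startup.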
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)

model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    # half precision is only reliable on GPU; fall back to float32 on CPU
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    low_cpu_mem_usage=True,
)
generator = model.generate  # shorthand alias used in get_response()

print("✅ Model loaded successfully!\n")
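# Running transcript of the conversation; each entry is prefixed with its speaker.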
history = ["ChatDoctor: I am ChatDoctor, what medical questions do you have?"]
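# Build the prompt from the transcript, sample a reply, and append it to the history.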
def get_response(user_input):
    global history

    human_invitation = "Patient: "
    doctor_invitation = "ChatDoctor: "

    # Record the patient's turn, then rebuild the full prompt from the transcript.
    history.append(human_invitation + user_input)
    prompt = "\n".join(history) + "\n" + doctor_invitation

    # model.device is the safe target when device_map="auto" decides placement.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    with torch.no_grad():
        output_ids = generator(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            pad_token_id=tokenizer.eos_token_id,  # LLaMA has no pad token; reuse EOS
        )

    # Decode only the newly generated tokens; slicing the decoded string by
    # len(prompt) can drift because decoding does not reproduce the prompt verbatim.
    new_tokens = output_ids[0][input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # If the model starts writing the patient's next turn, keep only the reply.
    if "Patient:" in response:
        response = response.split("Patient:")[0].strip()

    history.append(doctor_invitation + response)

    # Release the generation tensors between turns.
    del input_ids, output_ids
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    return response
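# Simple read-eval-print loop: type "exit" or "quit" (or press Ctrl+C) to leave.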
if __name__ == "__main__":
    print("\n=== ChatDoctor is ready! Type your questions. ===\n")
    while True:
        try:
            user_input = input("Patient: ").strip()
            if user_input.lower() in ["exit", "quit"]:
                print("Exiting ChatDoctor. Goodbye!")
                break
            response = get_response(user_input)
            print("ChatDoctor: " + response + "\n")
        except KeyboardInterrupt:
            print("\nExiting ChatDoctor. Goodbye!")
            break