# chat.py
import gc
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
# =============================
# Configuration
# =============================
MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
MAX_NEW_TOKENS = 200      # cap on tokens generated per reply
TEMPERATURE = 0.5         # < 1.0 makes sampling more conservative
TOP_K = 50                # sample only from the 50 most likely next tokens
REPETITION_PENALTY = 1.1  # > 1.0 discourages repeating earlier tokens
# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")
# =============================
# Load Tokenizer and Model
# =============================
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
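# LLaMA tokenizers typically ship without a padding token, so generate() will
# warn unless pad_token_id is set explicitly (done below via eos_token_id).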
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",          # let accelerate place weights on the available device(s)
    torch_dtype=torch.float16,  # half precision; prefer torch.float32 when running on CPU
    low_cpu_mem_usage=True,     # stream weights in to reduce peak RAM while loading
)
# DO NOT call model.to(device) when using device_map="auto"
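# (accelerate has already dispatched each module to its device; a blanket
# .to() is redundant on a single GPU and errors when weights are offloaded)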
print("✅ Model loaded successfully!\n")
# =============================
# Chat History
# =============================
history = ["ChatDoctor: I am ChatDoctor, what medical questions do you have?"]
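# `history` holds one "Speaker: text" line per turn; get_response() joins it
# with newlines to build each prompt.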
# =============================
# Response Function
# =============================
def get_response(user_input):
    global history
    human_invitation = "Patient: "
    doctor_invitation = "ChatDoctor: "

    # Append user input to the transcript
    history.append(human_invitation + user_input)
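    # Each prompt is the running transcript plus an empty doctor turn, e.g.:
    #   ChatDoctor: I am ChatDoctor, what medical questions do you have?
    #   Patient: I have a headache.
    #   ChatDoctor: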
    # Build prompt
    prompt = "\n".join(history) + "\n" + doctor_invitation
    # With device_map="auto", inputs must live on the model's device, which
    # may differ from the `device` string detected above
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    # Generate response
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            pad_token_id=tokenizer.eos_token_id,  # silence the missing-pad-token warning
        )
    # Decode only the newly generated tokens; slicing the decoded string by
    # len(prompt) is brittle because detokenization does not always round-trip
    new_tokens = output_ids[0][input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    # Drop a leading echoed "Patient:" label, then cut off any hallucinated
    # follow-up patient turn the model keeps writing
    if response.startswith("Patient:"):
        response = response[len("Patient:"):].strip()
    response = response.split("\nPatient:")[0].strip()

    # Append model response to history
    history.append(doctor_invitation + response)
    # Free memory between turns
    del input_ids, output_ids
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    return response
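
# Note: `history` grows without bound, so long chats will eventually exceed
# the model's context window. A minimal, hypothetical guard (the 2048-token
# budget below is an assumption; substitute the model's real context length).
# It could be called at the top of get_response().
def trim_history(max_prompt_tokens=2048):
    """Drop the oldest turns (keeping the opening ChatDoctor greeting)
    until the serialized transcript fits in the token budget."""
    global history
    while (len(history) > 2
           and len(tokenizer("\n".join(history)).input_ids) > max_prompt_tokens):
        del history[1]  # history[0] is the fixed ChatDoctor greeting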
# =============================
# CLI Chat
# =============================
if __name__ == "__main__":
    print("\n=== ChatDoctor is ready! Type your questions. ===\n")
    while True:
        try:
            user_input = input("Patient: ").strip()
            if user_input.lower() in ["exit", "quit"]:
                print("Exiting ChatDoctor. Goodbye!")
                break
            response = get_response(user_input)
            print("ChatDoctor: " + response + "\n")
        except KeyboardInterrupt:
            print("\nExiting ChatDoctor. Goodbye!")
            break