import gc

import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
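
# Model location and sampling hyperparameters.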
MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
MAX_NEW_TOKENS = 200
TEMPERATURE = 0.5
TOP_K = 50
REPETITION_PENALTY = 1.1

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")
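
# Load the tokenizer and the model in fp16; device_map="auto" lets
# accelerate place the weights on the available device(s).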
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

generator = model.generate
print("✅ Model loaded successfully!\n")
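
# System prompt that pins the model to the ChatDoctor role and sets the
# consultation protocol.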
systemprompt = """You are ChatDoctor — an intelligent, empathetic medical AI assistant.
Your role is to carefully gather medical information, reason clinically,
and provide safe, evidence-based guidance.

Follow these instructions strictly:
1. When a patient describes their illness, DO NOT diagnose immediately.
2. Ask relevant, targeted questions to collect all necessary details
   such as symptoms, duration, severity, lifestyle habits, medical history,
   medications, and any recent tests or changes.
3. Once you have enough information for a preliminary diagnosis, clearly
   explain your reasoning and possible causes in simple medical language.
4. Then, provide a clear and structured response that includes:
   - **Diagnosis:** probable or confirmed condition(s)
   - **Dietary Advice:** foods to include and avoid
   - **Lifestyle Advice:** exercise, sleep, stress, and other habits
5. Be concise, empathetic, and professional at all times.
6. Never switch roles or generate "Patient:" responses. Always remain as ChatDoctor.
7. If symptoms suggest a serious or emergency condition, advise the patient
   to seek immediate medical attention."""
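
# Seed the conversation: system prompt plus an opening line from the doctor.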
history = [systemprompt, "ChatDoctor: I am ChatDoctor, what medical questions do you have?"]


def get_response(user_input):
    """Generate ChatDoctor's reply to user_input and append both turns to history."""
    global history
    human_invitation = "Patient: "
    doctor_invitation = "ChatDoctor: "

    history.append(human_invitation + user_input)
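
    # Rebuild the prompt from every turn so far, ending with the doctor tag
    # so the model completes ChatDoctor's next line.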
    prompt = "\n".join(history) + "\n" + doctor_invitation
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
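
    # Sample the reply; no_grad avoids building an autograd graph we never use.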
    with torch.no_grad():
        output_ids = generator(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            pad_token_id=tokenizer.eos_token_id,  # LLaMA has no pad token
        )

    # Decode only the newly generated tokens; slicing the decoded text with
    # len(prompt) is brittle because decoding may not reproduce the prompt
    # exactly (special tokens, whitespace normalization).
    new_tokens = output_ids[0][input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # If the model starts writing the patient's next turn, truncate there
    # rather than passing it off as the doctor's reply.
    if "Patient:" in response:
        response = response.split("Patient:", 1)[0].strip()

    history.append(doctor_invitation + response)
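
    # Drop this turn's tensors so long conversations don't accumulate memory.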
    del input_ids, output_ids
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    return response


if __name__ == "__main__":
    print("\n=== ChatDoctor is ready! Type your questions. ===\n")
    while True:
        try:
            user_input = input("Patient: ").strip()
            if not user_input:
                continue  # ignore empty lines
            if user_input.lower() in ["exit", "quit"]:
                print("Exiting ChatDoctor. Goodbye!")
                break
            response = get_response(user_input)
            print("ChatDoctor: " + response + "\n")
        except (KeyboardInterrupt, EOFError):
            print("\nExiting ChatDoctor. Goodbye!")
            break