# chat.py
import os
import gc
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
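# Note: device_map="auto" below relies on the `accelerate` package, and
# LlamaTokenizer needs `sentencepiece` (pip install accelerate sentencepiece).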

# =============================
# Configuration
# =============================
MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
MAX_NEW_TOKENS = 200
TEMPERATURE = 0.5
TOP_K = 50
REPETITION_PENALTY = 1.1
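# Sampling knobs: TEMPERATURE < 1 makes sampling more conservative, TOP_K restricts
# each step to the 50 most likely tokens, and REPETITION_PENALTY > 1 discourages
# the model from repeating itself.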

# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")

# =============================
# Load Tokenizer and Model
# =============================
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)

model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",          # let accelerate dispatch weights to the available device(s)
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,  # fp16 only on GPU
    low_cpu_mem_usage=True      # stream weights to keep peak CPU memory low
)

# Do NOT call model.to(device) when using device_map="auto";
# accelerate has already placed the weights.
model.eval()
print("✅ Model loaded successfully!\n")

# =============================
# Chat History
# =============================
systemprompt = ("""You are ChatDoctor — an intelligent, empathetic medical AI assistant. 
Your role is to carefully gather medical information, reason clinically, 
and provide safe, evidence-based guidance.

Follow these instructions strictly:
1. When a patient describes their illness, DO NOT diagnose immediately.
2. Ask relevant, targeted questions to collect all necessary details 
   such as symptoms, duration, severity, lifestyle habits, medical history, 
   medications, and any recent tests or changes.
3. Once you have enough information for a preliminary diagnosis, clearly 
   explain your reasoning and possible causes in simple medical language.
4. Then, provide a clear and structured response that includes:
   - **Diagnosis:** probable or confirmed condition(s)
   - **Dietary Advice:** foods to include and avoid
   - **Lifestyle Advice:** exercise, sleep, stress, and other habits
5. Be concise, empathetic, and professional at all times.
6. Never switch roles or generate “Patient:” responses. Always remain as ChatDoctor.
7. If symptoms suggest a serious or emergency condition, advise the patient 
   to seek immediate medical attention.""")

history = [systemprompt, "ChatDoctor: I am ChatDoctor, what medical questions do you have?"]
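# Note: the entire history is re-sent on every turn, so long conversations can
# exceed the model's context window (typically 2,048 tokens for LLaMA-1-based
# checkpoints); older turns would then need to be truncated.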

# =============================
# Response Function
# =============================
def get_response(user_input):
    global history
    human_invitation = "Patient: "
    doctor_invitation = "ChatDoctor: "

    # Append user input
    history.append(human_invitation + user_input)

    # Build prompt
    prompt = "\n".join(history) + "\n" + doctor_invitation
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Generate response
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens; slicing the decoded string by
    # len(prompt) is unreliable because detokenization does not always
    # reproduce the prompt character-for-character.
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # If the model starts writing the patient's next turn, cut it off there
    response = response.split("\nPatient:")[0].strip()
    if response.startswith("Patient:"):
        response = response[len("Patient:"):].strip()

    # Append model response to history
    history.append(doctor_invitation + response)

    # Free memory
    del inputs, output_ids
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    return response

# =============================
# CLI Chat
# =============================
if __name__ == "__main__":
    print("\n=== ChatDoctor is ready! Type your questions. ===\n")
    while True:
        try:
            user_input = input("Patient: ").strip()
            if not user_input:
                continue
            if user_input.lower() in ["exit", "quit"]:
                print("Exiting ChatDoctor. Goodbye!")
                break
            response = get_response(user_input)
            print("ChatDoctor: " + response + "\n")
        except (KeyboardInterrupt, EOFError):
            # Ctrl+C or Ctrl+D also exits cleanly
            print("\nExiting ChatDoctor. Goodbye!")
            break