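"""Interactive command-line chat with ChatDoctor.

Loads the zl111/ChatDoctor checkpoint (a LLaMA-based model fine-tuned for
medical dialogue) and runs a turn-taking Patient/ChatDoctor loop in the
terminal.
"""
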
import os
import gc
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList

# =============================
# Configuration
# =============================
MODEL_PATH = "zl111/ChatDoctor"  # Hugging Face Hub model id
MAX_NEW_TOKENS = 200
TEMPERATURE = 0.5
TOP_K = 50
REPETITION_PENALTY = 1.1
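# (TEMPERATURE < 1 sharpens the sampling distribution, TOP_K restricts each
# step to the 50 most likely tokens, and REPETITION_PENALTY > 1 discourages
# the model from repeating itself.)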

# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")

# =============================
# Load Tokenizer and Model
# =============================
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)
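# NOTE: fp16 weights for a 7B LLaMA checkpoint need roughly 13-14 GB; with
# device_map="auto", layers that do not fit in VRAM are offloaded to CPU RAM.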

model.eval()  # inference only: disable dropout
print("✅ ChatDoctor model loaded successfully!\n")

# =============================
# Stopping Criteria
# =============================
class StopOnTokens(StoppingCriteria):
    """Stop generation once the output ends with any configured stop sequence."""

    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids, scores, **kwargs):
        generated = input_ids[0].tolist()
        for stop_id_seq in self.stop_ids:
            if len(generated) >= len(stop_id_seq) and generated[-len(stop_id_seq):] == stop_id_seq:
                return True
        return False
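
# Recent transformers releases (roughly 4.39+) also ship a built-in
# StopStringCriteria that matches plain strings instead of token ids, e.g.:
#   from transformers import StopStringCriteria
#   StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings=["Patient:"])])
# The hand-rolled class above avoids pinning a newer transformers version.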

# =============================
# Chat History
# =============================
history = ["ChatDoctor: I am ChatDoctor, your AI medical assistant. How can I help you today?"]

# =============================
# Get Response Function
# =============================
def get_response(user_input):
    global history
    human_invitation = "Patient: "
    doctor_invitation = "ChatDoctor: "

    # Add user input to history
    history.append(human_invitation + user_input)

    # Build conversation prompt
    prompt = "\n".join(history) + "\n" + doctor_invitation
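    # e.g. after one exchange the prompt looks like:
    #   ChatDoctor: I am ChatDoctor, your AI medical assistant. ...
    #   Patient: I have a headache.
    #   ChatDoctor: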
    # device_map="auto" may shard or offload the model, so send inputs to the
    # device of its first layers rather than the global `device` string
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    # Define stop words and their token IDs. SentencePiece tokenization is
    # context sensitive (the same word can encode differently after a newline),
    # so several spellings of "Patient:" are listed to cover boundary cases.
    stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
    stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
    stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])

    # Generate model response
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
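    # Sampling is stochastic; seeding with torch.manual_seed(...) beforehand
    # would make a run reproducible, and do_sample=False would be greedy.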

    # Decode only the newly generated tokens; slicing the decoded string by
    # len(prompt) is fragile because decoding does not reproduce the prompt
    # character-for-character (special tokens, whitespace normalization).
    new_tokens = output_ids[0][input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # Remove any "Patient:" turn marker that slipped past the stopping criteria
    for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
        if stop_word in response:
            response = response.split(stop_word)[0].strip()
            break

    history.append(doctor_invitation + response)

    # Free memory
    del input_ids, output_ids
    gc.collect()
    torch.cuda.empty_cache()

    return response
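
# A minimal sketch (not part of the original script): LLaMA-style models have
# a bounded context window (2048 tokens for the original 7B checkpoint), so a
# long conversation will eventually overflow the prompt. One simple mitigation
# is to drop the oldest turns while keeping the opening greeting; calling this
# hypothetical helper at the top of get_response would keep prompts bounded.
def trim_history(max_tokens=1500):
    """Hypothetical helper: drop the oldest turns until the prompt fits."""
    global history
    while len(history) > 2:
        if len(tokenizer("\n".join(history)).input_ids) <= max_tokens:
            break
        del history[1]  # keep history[0], the initial greeting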

# =============================
# Chat Loop
# =============================
if __name__ == "__main__":
    print("\n=== ChatDoctor is ready! ===")
    print("You (the human) = Patient ")
    print("AI = ChatDoctor")
    print("Type 'exit' or 'quit' to end the chat.\n")

    print("ChatDoctor: Hi there! How can I help you today?\n")

    while True:
        try:
            user_input = input("Patient: ").strip()
            if user_input.lower() in ["exit", "quit"]:
                print("ChatDoctor: Take care! Goodbye ")
                break

            if not user_input:
                continue

            response = get_response(user_input)
            print("ChatDoctor:", response, "\n")

        except KeyboardInterrupt:
            print("\nChatDoctor: Take care! Goodbye")
            break
        except Exception as e:
            print(f"Error: {e}")
            print("Please try again.\n")