Muhammadidrees commited on
Commit
0054032
·
verified ·
1 Parent(s): 8b28e5d

Update Coreectcodewithoutfronted.py

Browse files
Files changed (1) hide show
  1. Coreectcodewithoutfronted.py +140 -140
Coreectcodewithoutfronted.py CHANGED
@@ -1,141 +1,141 @@
1
- import os
2
- import gc
3
- import torch
4
- from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList
5
-
6
- # =============================
7
- # Configuration
8
- # =============================
9
- MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
10
- MAX_NEW_TOKENS = 200
11
- TEMPERATURE = 0.5
12
- TOP_K = 50
13
- REPETITION_PENALTY = 1.1
14
-
15
- # Detect device
16
- device = "cuda" if torch.cuda.is_available() else "cpu"
17
- print(f"Loading model from {MODEL_PATH} on {device}...")
18
-
19
- # =============================
20
- # Load Tokenizer and Model
21
- # =============================
22
- tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
23
- model = LlamaForCausalLM.from_pretrained(
24
- MODEL_PATH,
25
- device_map="auto",
26
- torch_dtype=torch.float16,
27
- low_cpu_mem_usage=True
28
- )
29
-
30
- generator = model.generate
31
- print("✅ ChatDoctor model loaded successfully!\n")
32
-
33
- # =============================
34
- # Stopping Criteria
35
- # =============================
36
- class StopOnTokens(StoppingCriteria):
37
- def __init__(self, stop_ids):
38
- self.stop_ids = stop_ids
39
-
40
- def __call__(self, input_ids, scores, **kwargs):
41
- for stop_id_seq in self.stop_ids:
42
- if len(stop_id_seq) == 1:
43
- if input_ids[0][-1] == stop_id_seq[0]:
44
- return True
45
- else:
46
- if len(input_ids[0]) >= len(stop_id_seq):
47
- if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq:
48
- return True
49
- return False
50
-
51
- # =============================
52
- # Chat History
53
- # =============================
54
- history = ["ChatDoctor: I am ChatDoctor, your AI medical assistant. How can I help you today?"]
55
-
56
- # =============================
57
- # Get Response Function
58
- # =============================
59
- def get_response(user_input):
60
- global history
61
- human_invitation = "Patient: "
62
- doctor_invitation = "ChatDoctor: "
63
-
64
- # Add user input to history
65
- history.append(human_invitation + user_input)
66
-
67
- # Build conversation prompt
68
- prompt = "\n".join(history) + "\n" + doctor_invitation
69
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
70
-
71
- # Define stop words and their token IDs
72
- stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
73
- stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
74
- stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
75
-
76
- # Generate model response
77
- with torch.no_grad():
78
- output_ids = generator(
79
- input_ids,
80
- max_new_tokens=MAX_NEW_TOKENS,
81
- do_sample=True,
82
- temperature=TEMPERATURE,
83
- top_k=TOP_K,
84
- repetition_penalty=REPETITION_PENALTY,
85
- stopping_criteria=stopping_criteria,
86
- pad_token_id=tokenizer.eos_token_id,
87
- eos_token_id=tokenizer.eos_token_id
88
- )
89
-
90
- # Decode and clean response
91
- full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
92
- response = full_output[len(prompt):].strip()
93
-
94
- # Remove any "Patient:" that might have slipped through
95
- for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
96
- if stop_word in response:
97
- response = response.split(stop_word)[0].strip()
98
- break
99
-
100
- # Remove any leading/trailing punctuation artifacts
101
- response = response.strip()
102
-
103
- history.append(doctor_invitation + response)
104
-
105
- # Free memory
106
- del input_ids, output_ids
107
- gc.collect()
108
- torch.cuda.empty_cache()
109
-
110
- return response
111
-
112
- # =============================
113
- # Chat Loop
114
- # =============================
115
- if __name__ == "__main__":
116
- print("\n=== ChatDoctor is ready! ===")
117
- print("You (the human) = Patient ")
118
- print("AI = ChatDoctor")
119
- print("Type 'exit' or 'quit' to end the chat.\n")
120
-
121
- print("ChatDoctor: Hi there! How can I help you today?\n")
122
-
123
- while True:
124
- try:
125
- user_input = input("Patient: ").strip()
126
- if user_input.lower() in ["exit", "quit"]:
127
- print("ChatDoctor: Take care! Goodbye ")
128
- break
129
-
130
- if not user_input:
131
- continue
132
-
133
- response = get_response(user_input)
134
- print("ChatDoctor:", response, "\n")
135
-
136
- except KeyboardInterrupt:
137
- print("\nChatDoctor: Take care! Goodbye")
138
- break
139
- except Exception as e:
140
- print(f"Error: {e}")
141
  print("Please try again.\n")
 
1
+ import os
2
+ import gc
3
+ import torch
4
+ from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList
5
+
6
+ # =============================
7
+ # Configuration
8
+ # =============================
9
+ MODEL_PATH = r"zl111/ChatDoctor"
10
+ MAX_NEW_TOKENS = 200
11
+ TEMPERATURE = 0.5
12
+ TOP_K = 50
13
+ REPETITION_PENALTY = 1.1
14
+
15
+ # Detect device
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ print(f"Loading model from {MODEL_PATH} on {device}...")
18
+
19
+ # =============================
20
+ # Load Tokenizer and Model
21
+ # =============================
22
+ tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
23
+ model = LlamaForCausalLM.from_pretrained(
24
+ MODEL_PATH,
25
+ device_map="auto",
26
+ torch_dtype=torch.float16,
27
+ low_cpu_mem_usage=True
28
+ )
29
+
30
+ generator = model.generate
31
+ print("✅ ChatDoctor model loaded successfully!\n")
32
+
33
+ # =============================
34
+ # Stopping Criteria
35
+ # =============================
36
+ class StopOnTokens(StoppingCriteria):
37
+ def __init__(self, stop_ids):
38
+ self.stop_ids = stop_ids
39
+
40
+ def __call__(self, input_ids, scores, **kwargs):
41
+ for stop_id_seq in self.stop_ids:
42
+ if len(stop_id_seq) == 1:
43
+ if input_ids[0][-1] == stop_id_seq[0]:
44
+ return True
45
+ else:
46
+ if len(input_ids[0]) >= len(stop_id_seq):
47
+ if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq:
48
+ return True
49
+ return False
50
+
51
+ # =============================
52
+ # Chat History
53
+ # =============================
54
+ history = ["ChatDoctor: I am ChatDoctor, your AI medical assistant. How can I help you today?"]
55
+
56
+ # =============================
57
+ # Get Response Function
58
+ # =============================
59
+ def get_response(user_input):
60
+ global history
61
+ human_invitation = "Patient: "
62
+ doctor_invitation = "ChatDoctor: "
63
+
64
+ # Add user input to history
65
+ history.append(human_invitation + user_input)
66
+
67
+ # Build conversation prompt
68
+ prompt = "\n".join(history) + "\n" + doctor_invitation
69
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
70
+
71
+ # Define stop words and their token IDs
72
+ stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
73
+ stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
74
+ stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
75
+
76
+ # Generate model response
77
+ with torch.no_grad():
78
+ output_ids = generator(
79
+ input_ids,
80
+ max_new_tokens=MAX_NEW_TOKENS,
81
+ do_sample=True,
82
+ temperature=TEMPERATURE,
83
+ top_k=TOP_K,
84
+ repetition_penalty=REPETITION_PENALTY,
85
+ stopping_criteria=stopping_criteria,
86
+ pad_token_id=tokenizer.eos_token_id,
87
+ eos_token_id=tokenizer.eos_token_id
88
+ )
89
+
90
+ # Decode and clean response
91
+ full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
92
+ response = full_output[len(prompt):].strip()
93
+
94
+ # Remove any "Patient:" that might have slipped through
95
+ for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
96
+ if stop_word in response:
97
+ response = response.split(stop_word)[0].strip()
98
+ break
99
+
100
+ # Remove any leading/trailing punctuation artifacts
101
+ response = response.strip()
102
+
103
+ history.append(doctor_invitation + response)
104
+
105
+ # Free memory
106
+ del input_ids, output_ids
107
+ gc.collect()
108
+ torch.cuda.empty_cache()
109
+
110
+ return response
111
+
112
+ # =============================
113
+ # Chat Loop
114
+ # =============================
115
+ if __name__ == "__main__":
116
+ print("\n=== ChatDoctor is ready! ===")
117
+ print("You (the human) = Patient ")
118
+ print("AI = ChatDoctor")
119
+ print("Type 'exit' or 'quit' to end the chat.\n")
120
+
121
+ print("ChatDoctor: Hi there! How can I help you today?\n")
122
+
123
+ while True:
124
+ try:
125
+ user_input = input("Patient: ").strip()
126
+ if user_input.lower() in ["exit", "quit"]:
127
+ print("ChatDoctor: Take care! Goodbye ")
128
+ break
129
+
130
+ if not user_input:
131
+ continue
132
+
133
+ response = get_response(user_input)
134
+ print("ChatDoctor:", response, "\n")
135
+
136
+ except KeyboardInterrupt:
137
+ print("\nChatDoctor: Take care! Goodbye")
138
+ break
139
+ except Exception as e:
140
+ print(f"Error: {e}")
141
  print("Please try again.\n")