Muhammadidrees committed on
Commit
ab6fe89
·
verified ·
1 Parent(s): a0a8be5

Rename frontend.py to app.py

Browse files
Files changed (1) hide show
  1. frontend.py → app.py +312 -312
frontend.py → app.py RENAMED
@@ -1,313 +1,313 @@
1
- import os
2
- import gc
3
- import torch
4
- import gradio as gr
5
- from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList
6
-
7
- # =============================
8
- # Configuration
9
- # =============================
10
- MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
11
- MAX_NEW_TOKENS = 200
12
- TEMPERATURE = 0.5
13
- TOP_K = 50
14
- REPETITION_PENALTY = 1.1
15
-
16
- # Detect device
17
- device = "cuda" if torch.cuda.is_available() else "cpu"
18
- print(f"Loading model from {MODEL_PATH} on {device}...")
19
-
20
- # =============================
21
- # Load Tokenizer and Model
22
- # =============================
23
- tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
24
- model = LlamaForCausalLM.from_pretrained(
25
- MODEL_PATH,
26
- device_map="auto",
27
- torch_dtype=torch.float16,
28
- low_cpu_mem_usage=True
29
- )
30
-
31
- generator = model.generate
32
- print("✅ ChatDoctor model loaded successfully!\n")
33
-
34
- # =============================
35
- # Stopping Criteria
36
- # =============================
37
- class StopOnTokens(StoppingCriteria):
38
- def __init__(self, stop_ids):
39
- self.stop_ids = stop_ids
40
-
41
- def __call__(self, input_ids, scores, **kwargs):
42
- for stop_id_seq in self.stop_ids:
43
- if len(stop_id_seq) == 1:
44
- if input_ids[0][-1] == stop_id_seq[0]:
45
- return True
46
- else:
47
- if len(input_ids[0]) >= len(stop_id_seq):
48
- if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq:
49
- return True
50
- return False
51
-
52
- # =============================
53
- # Chat History (Global)
54
- # =============================
55
- conversation_history = []
56
-
57
- # =============================
58
- # Get Response Function
59
- # =============================
60
- def get_response(user_input, history_context):
61
- """Generate response from ChatDoctor model"""
62
- human_invitation = "Patient: "
63
- doctor_invitation = "ChatDoctor: "
64
-
65
- # Build conversation from history
66
- history_text = []
67
- for human, assistant in history_context:
68
- if human:
69
- history_text.append(human_invitation + human)
70
- if assistant:
71
- history_text.append(doctor_invitation + assistant)
72
-
73
- # Add current user input
74
- history_text.append(human_invitation + user_input)
75
-
76
- # Build conversation prompt
77
- prompt = "\n".join(history_text) + "\n" + doctor_invitation
78
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
79
-
80
- # Define stop words and their token IDs
81
- stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
82
- stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
83
- stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
84
-
85
- # Generate model response
86
- with torch.no_grad():
87
- output_ids = generator(
88
- input_ids,
89
- max_new_tokens=MAX_NEW_TOKENS,
90
- do_sample=True,
91
- temperature=TEMPERATURE,
92
- top_k=TOP_K,
93
- repetition_penalty=REPETITION_PENALTY,
94
- stopping_criteria=stopping_criteria,
95
- pad_token_id=tokenizer.eos_token_id,
96
- eos_token_id=tokenizer.eos_token_id
97
- )
98
-
99
- # Decode and clean response
100
- full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
101
- response = full_output[len(prompt):].strip()
102
-
103
- # Remove any "Patient:" that might have slipped through
104
- for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
105
- if stop_word in response:
106
- response = response.split(stop_word)[0].strip()
107
- break
108
-
109
- response = response.strip()
110
-
111
- # Free memory
112
- del input_ids, output_ids
113
- gc.collect()
114
- torch.cuda.empty_cache()
115
-
116
- return response
117
-
118
- # =============================
119
- # Gradio Chat Function
120
- # =============================
121
- def chat_function(message, history):
122
- """Gradio chat interface function"""
123
- if not message.strip():
124
- return ""
125
-
126
- try:
127
- response = get_response(message, history)
128
- return response
129
- except Exception as e:
130
- return f"Error: {str(e)}"
131
-
132
- # =============================
133
- # Custom CSS
134
- # =============================
135
- custom_css = """
136
- #header {
137
- text-align: center;
138
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
139
- color: white;
140
- padding: 20px;
141
- border-radius: 10px;
142
- margin-bottom: 20px;
143
- }
144
-
145
- #header h1 {
146
- margin: 0;
147
- font-size: 2.5em;
148
- }
149
-
150
- #header p {
151
- margin: 10px 0 0 0;
152
- font-size: 1.1em;
153
- opacity: 0.9;
154
- }
155
-
156
- .disclaimer {
157
- background-color: #fff3cd;
158
- border: 1px solid #ffc107;
159
- border-radius: 8px;
160
- padding: 15px;
161
- margin: 20px 0;
162
- color: #856404;
163
- }
164
-
165
- .disclaimer h3 {
166
- margin-top: 0;
167
- color: #856404;
168
- }
169
-
170
- footer {
171
- text-align: center;
172
- margin-top: 30px;
173
- color: #666;
174
- font-size: 0.9em;
175
- }
176
- """
177
-
178
- # =============================
179
- # Gradio Interface
180
- # =============================
181
- with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
182
- # Header
183
- gr.HTML("""
184
- <div id="header">
185
- <h1>🩺 ChatDoctor AI Assistant</h1>
186
- <p>Your AI-powered medical conversation partner</p>
187
- </div>
188
- """)
189
-
190
- # Disclaimer
191
- gr.HTML("""
192
- <div class="disclaimer">
193
- <h3>⚠️ Medical Disclaimer</h3>
194
- <p><strong>Important:</strong> This AI assistant is for informational and educational purposes only.
195
- It is NOT a substitute for professional medical advice, diagnosis, or treatment.
196
- Always seek the advice of your physician or other qualified health provider with any questions
197
- you may have regarding a medical condition. Never disregard professional medical advice or
198
- delay in seeking it because of something you have read here.</p>
199
- </div>
200
- """)
201
-
202
- # Chatbot Interface
203
- chatbot = gr.Chatbot(
204
- height=500,
205
- placeholder="<div style='text-align: center; padding: 40px;'><h3>👋 Welcome to ChatDoctor!</h3><p>I'm here to discuss your health concerns. How can I assist you today?</p></div>",
206
- show_label=False,
207
- avatar_images=(None, "🤖"),
208
- )
209
-
210
- with gr.Row():
211
- msg = gr.Textbox(
212
- placeholder="Type your message here... (e.g., 'I have a headache')",
213
- show_label=False,
214
- scale=9,
215
- container=False
216
- )
217
- submit_btn = gr.Button("Send 📤", scale=1, variant="primary")
218
-
219
- with gr.Row():
220
- clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
221
- retry_btn = gr.Button("🔄 Retry", scale=1)
222
-
223
- # Examples
224
- gr.Examples(
225
- examples=[
226
- "I have a persistent headache for 3 days. What should I do?",
227
- "What are the symptoms of diabetes?",
228
- "How can I improve my sleep quality?",
229
- "I have a fever and sore throat. Should I be concerned?",
230
- "What are some natural ways to reduce stress?",
231
- ],
232
- inputs=msg,
233
- label="💡 Example Questions"
234
- )
235
-
236
- # Settings (collapsed by default)
237
- with gr.Accordion("⚙️ Advanced Settings", open=False):
238
- temperature_slider = gr.Slider(
239
- minimum=0.1,
240
- maximum=1.0,
241
- value=TEMPERATURE,
242
- step=0.1,
243
- label="Temperature (Creativity)",
244
- info="Higher values make responses more creative but less focused"
245
- )
246
- max_tokens_slider = gr.Slider(
247
- minimum=50,
248
- maximum=500,
249
- value=MAX_NEW_TOKENS,
250
- step=50,
251
- label="Max Response Length",
252
- info="Maximum number of tokens in response"
253
- )
254
- top_k_slider = gr.Slider(
255
- minimum=1,
256
- maximum=100,
257
- value=TOP_K,
258
- step=1,
259
- label="Top K",
260
- info="Limits vocabulary selection"
261
- )
262
-
263
- # Footer
264
- gr.HTML("""
265
- <footer>
266
- <p>Powered by ChatDoctor Model | Built with Gradio</p>
267
- <p>Device: """ + device.upper() + """ | Model: LLaMA-based Medical AI</p>
268
- </footer>
269
- """)
270
-
271
- # Event handlers
272
- def user_message(user_msg, history):
273
- return "", history + [[user_msg, None]]
274
-
275
- def bot_response(history, temp, max_tok, top_k_val):
276
- global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
277
- TEMPERATURE = temp
278
- MAX_NEW_TOKENS = int(max_tok)
279
- TOP_K = int(top_k_val)
280
-
281
- user_msg = history[-1][0]
282
- bot_msg = chat_function(user_msg, history[:-1])
283
- history[-1][1] = bot_msg
284
- return history
285
-
286
- # Connect events
287
- msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
288
- bot_response, [chatbot, temperature_slider, max_tokens_slider, top_k_slider], chatbot
289
- )
290
-
291
- submit_btn.click(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
292
- bot_response, [chatbot, temperature_slider, max_tokens_slider, top_k_slider], chatbot
293
- )
294
-
295
- clear_btn.click(lambda: None, None, chatbot, queue=False)
296
-
297
- def retry_last():
298
- return None
299
-
300
- retry_btn.click(retry_last, None, chatbot, queue=False)
301
-
302
- # =============================
303
- # Launch Interface
304
- # =============================
305
- if __name__ == "__main__":
306
- print("\n🚀 Launching ChatDoctor Gradio Interface...")
307
- demo.queue()
308
- demo.launch(
309
- server_name="0.0.0.0", # Accessible from network
310
- server_port=7860,
311
- share=False, # Set to True to create public link
312
- show_error=True
313
  )
 
1
+ import os
2
+ import gc
3
+ import torch
4
+ import gradio as gr
5
+ from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList
6
+
7
+ # =============================
8
+ # Configuration
9
+ # =============================
10
+ MODEL_PATH = r"Muhammadidrees/JayConverstionalModel"
11
+ MAX_NEW_TOKENS = 200
12
+ TEMPERATURE = 0.5
13
+ TOP_K = 50
14
+ REPETITION_PENALTY = 1.1
15
+
16
+ # Detect device
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ print(f"Loading model from {MODEL_PATH} on {device}...")
19
+
20
+ # =============================
21
+ # Load Tokenizer and Model
22
+ # =============================
23
+ tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
24
+ model = LlamaForCausalLM.from_pretrained(
25
+ MODEL_PATH,
26
+ device_map="auto",
27
+ torch_dtype=torch.float16,
28
+ low_cpu_mem_usage=True
29
+ )
30
+
31
+ generator = model.generate
32
+ print("✅ ChatDoctor model loaded successfully!\n")
33
+
34
+ # =============================
35
+ # Stopping Criteria
36
+ # =============================
37
+ class StopOnTokens(StoppingCriteria):
38
+ def __init__(self, stop_ids):
39
+ self.stop_ids = stop_ids
40
+
41
+ def __call__(self, input_ids, scores, **kwargs):
42
+ for stop_id_seq in self.stop_ids:
43
+ if len(stop_id_seq) == 1:
44
+ if input_ids[0][-1] == stop_id_seq[0]:
45
+ return True
46
+ else:
47
+ if len(input_ids[0]) >= len(stop_id_seq):
48
+ if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq:
49
+ return True
50
+ return False
51
+
52
+ # =============================
53
+ # Chat History (Global)
54
+ # =============================
55
+ conversation_history = []
56
+
57
+ # =============================
58
+ # Get Response Function
59
+ # =============================
60
+ def get_response(user_input, history_context):
61
+ """Generate response from ChatDoctor model"""
62
+ human_invitation = "Patient: "
63
+ doctor_invitation = "ChatDoctor: "
64
+
65
+ # Build conversation from history
66
+ history_text = []
67
+ for human, assistant in history_context:
68
+ if human:
69
+ history_text.append(human_invitation + human)
70
+ if assistant:
71
+ history_text.append(doctor_invitation + assistant)
72
+
73
+ # Add current user input
74
+ history_text.append(human_invitation + user_input)
75
+
76
+ # Build conversation prompt
77
+ prompt = "\n".join(history_text) + "\n" + doctor_invitation
78
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
79
+
80
+ # Define stop words and their token IDs
81
+ stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
82
+ stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
83
+ stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
84
+
85
+ # Generate model response
86
+ with torch.no_grad():
87
+ output_ids = generator(
88
+ input_ids,
89
+ max_new_tokens=MAX_NEW_TOKENS,
90
+ do_sample=True,
91
+ temperature=TEMPERATURE,
92
+ top_k=TOP_K,
93
+ repetition_penalty=REPETITION_PENALTY,
94
+ stopping_criteria=stopping_criteria,
95
+ pad_token_id=tokenizer.eos_token_id,
96
+ eos_token_id=tokenizer.eos_token_id
97
+ )
98
+
99
+ # Decode and clean response
100
+ full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
101
+ response = full_output[len(prompt):].strip()
102
+
103
+ # Remove any "Patient:" that might have slipped through
104
+ for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
105
+ if stop_word in response:
106
+ response = response.split(stop_word)[0].strip()
107
+ break
108
+
109
+ response = response.strip()
110
+
111
+ # Free memory
112
+ del input_ids, output_ids
113
+ gc.collect()
114
+ torch.cuda.empty_cache()
115
+
116
+ return response
117
+
118
+ # =============================
119
+ # Gradio Chat Function
120
+ # =============================
121
+ def chat_function(message, history):
122
+ """Gradio chat interface function"""
123
+ if not message.strip():
124
+ return ""
125
+
126
+ try:
127
+ response = get_response(message, history)
128
+ return response
129
+ except Exception as e:
130
+ return f"Error: {str(e)}"
131
+
132
+ # =============================
133
+ # Custom CSS
134
+ # =============================
135
+ custom_css = """
136
+ #header {
137
+ text-align: center;
138
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
139
+ color: white;
140
+ padding: 20px;
141
+ border-radius: 10px;
142
+ margin-bottom: 20px;
143
+ }
144
+
145
+ #header h1 {
146
+ margin: 0;
147
+ font-size: 2.5em;
148
+ }
149
+
150
+ #header p {
151
+ margin: 10px 0 0 0;
152
+ font-size: 1.1em;
153
+ opacity: 0.9;
154
+ }
155
+
156
+ .disclaimer {
157
+ background-color: #fff3cd;
158
+ border: 1px solid #ffc107;
159
+ border-radius: 8px;
160
+ padding: 15px;
161
+ margin: 20px 0;
162
+ color: #856404;
163
+ }
164
+
165
+ .disclaimer h3 {
166
+ margin-top: 0;
167
+ color: #856404;
168
+ }
169
+
170
+ footer {
171
+ text-align: center;
172
+ margin-top: 30px;
173
+ color: #666;
174
+ font-size: 0.9em;
175
+ }
176
+ """
177
+
178
+ # =============================
179
+ # Gradio Interface
180
+ # =============================
181
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
182
+ # Header
183
+ gr.HTML("""
184
+ <div id="header">
185
+ <h1>🩺 ChatDoctor AI Assistant</h1>
186
+ <p>Your AI-powered medical conversation partner</p>
187
+ </div>
188
+ """)
189
+
190
+ # Disclaimer
191
+ gr.HTML("""
192
+ <div class="disclaimer">
193
+ <h3>⚠️ Medical Disclaimer</h3>
194
+ <p><strong>Important:</strong> This AI assistant is for informational and educational purposes only.
195
+ It is NOT a substitute for professional medical advice, diagnosis, or treatment.
196
+ Always seek the advice of your physician or other qualified health provider with any questions
197
+ you may have regarding a medical condition. Never disregard professional medical advice or
198
+ delay in seeking it because of something you have read here.</p>
199
+ </div>
200
+ """)
201
+
202
+ # Chatbot Interface
203
+ chatbot = gr.Chatbot(
204
+ height=500,
205
+ placeholder="<div style='text-align: center; padding: 40px;'><h3>👋 Welcome to ChatDoctor!</h3><p>I'm here to discuss your health concerns. How can I assist you today?</p></div>",
206
+ show_label=False,
207
+ avatar_images=(None, "🤖"),
208
+ )
209
+
210
+ with gr.Row():
211
+ msg = gr.Textbox(
212
+ placeholder="Type your message here... (e.g., 'I have a headache')",
213
+ show_label=False,
214
+ scale=9,
215
+ container=False
216
+ )
217
+ submit_btn = gr.Button("Send 📤", scale=1, variant="primary")
218
+
219
+ with gr.Row():
220
+ clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
221
+ retry_btn = gr.Button("🔄 Retry", scale=1)
222
+
223
+ # Examples
224
+ gr.Examples(
225
+ examples=[
226
+ "I have a persistent headache for 3 days. What should I do?",
227
+ "What are the symptoms of diabetes?",
228
+ "How can I improve my sleep quality?",
229
+ "I have a fever and sore throat. Should I be concerned?",
230
+ "What are some natural ways to reduce stress?",
231
+ ],
232
+ inputs=msg,
233
+ label="💡 Example Questions"
234
+ )
235
+
236
+ # Settings (collapsed by default)
237
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
238
+ temperature_slider = gr.Slider(
239
+ minimum=0.1,
240
+ maximum=1.0,
241
+ value=TEMPERATURE,
242
+ step=0.1,
243
+ label="Temperature (Creativity)",
244
+ info="Higher values make responses more creative but less focused"
245
+ )
246
+ max_tokens_slider = gr.Slider(
247
+ minimum=50,
248
+ maximum=500,
249
+ value=MAX_NEW_TOKENS,
250
+ step=50,
251
+ label="Max Response Length",
252
+ info="Maximum number of tokens in response"
253
+ )
254
+ top_k_slider = gr.Slider(
255
+ minimum=1,
256
+ maximum=100,
257
+ value=TOP_K,
258
+ step=1,
259
+ label="Top K",
260
+ info="Limits vocabulary selection"
261
+ )
262
+
263
+ # Footer
264
+ gr.HTML("""
265
+ <footer>
266
+ <p>Powered by ChatDoctor Model | Built with Gradio</p>
267
+ <p>Device: """ + device.upper() + """ | Model: LLaMA-based Medical AI</p>
268
+ </footer>
269
+ """)
270
+
271
+ # Event handlers
272
+ def user_message(user_msg, history):
273
+ return "", history + [[user_msg, None]]
274
+
275
+ def bot_response(history, temp, max_tok, top_k_val):
276
+ global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
277
+ TEMPERATURE = temp
278
+ MAX_NEW_TOKENS = int(max_tok)
279
+ TOP_K = int(top_k_val)
280
+
281
+ user_msg = history[-1][0]
282
+ bot_msg = chat_function(user_msg, history[:-1])
283
+ history[-1][1] = bot_msg
284
+ return history
285
+
286
+ # Connect events
287
+ msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
288
+ bot_response, [chatbot, temperature_slider, max_tokens_slider, top_k_slider], chatbot
289
+ )
290
+
291
+ submit_btn.click(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
292
+ bot_response, [chatbot, temperature_slider, max_tokens_slider, top_k_slider], chatbot
293
+ )
294
+
295
+ clear_btn.click(lambda: None, None, chatbot, queue=False)
296
+
297
+ def retry_last():
298
+ return None
299
+
300
+ retry_btn.click(retry_last, None, chatbot, queue=False)
301
+
302
+ # =============================
303
+ # Launch Interface
304
+ # =============================
305
+ if __name__ == "__main__":
306
+ print("\n🚀 Launching ChatDoctor Gradio Interface...")
307
+ demo.queue()
308
+ demo.launch(
309
+ server_name="0.0.0.0", # Accessible from network
310
+ server_port=7860,
311
+ share=False, # Set to True to create public link
312
+ show_error=True
313
  )