nikshep01 committed
Commit a59d8d3 · verified · 1 Parent(s): d1ecdea

Update app.py

Files changed (1):
  1. app.py +71 -61
app.py CHANGED
@@ -1,3 +1,74 @@
+import gradio as gr
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+from threading import Thread
+from queue import Empty
+
+# Load the tokenizer and model
+tokenizer = AutoTokenizer.from_pretrained("thrishala/mental_health_chatbot")
+model = AutoModelForCausalLM.from_pretrained("thrishala/mental_health_chatbot", torch_dtype=torch.float16)
+
+# Move model to GPU if available
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+model = model.to(device)
+
+class StopOnTokens(StoppingCriteria):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        stop_ids = [29, 0]  # Token IDs for stopping criteria
+        for stop_id in stop_ids:
+            if input_ids[0][-1] == stop_id:
+                return True
+        return False
+
+def predict(message, history):
+    # Prepare the input history in the expected format for the model
+    history_transformer_format = list(zip(history[:-1], history[1:])) + [[message, ""]]
+    stop = StopOnTokens()
+
+    # Concatenate conversation history
+    messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]]) for item in history_transformer_format])
+
+    # Tokenize and prepare model inputs
+    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
+
+    # Create streamer with longer timeout
+    streamer = TextIteratorStreamer(tokenizer, timeout=30., skip_prompt=True, skip_special_tokens=True)
+
+    # Define generation parameters
+    generate_kwargs = dict(
+        model_inputs,
+        streamer=streamer,
+        max_new_tokens=512,  # Reduced to avoid memory issues
+        do_sample=True,
+        top_p=0.85,  # Adjusted for faster generation
+        top_k=500,   # Adjusted for faster generation
+        temperature=1.0,
+        num_beams=1,
+        stopping_criteria=StoppingCriteriaList([stop])
+    )
+
+    # Run the generation in a separate thread
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    # Yield generated tokens
+    partial_message = ""
+    try:
+        for new_token in streamer:
+            print(f"Received token: {new_token}")  # Debugging output
+            if new_token != '<':
+                partial_message += new_token
+                yield partial_message
+    except Empty:
+        print("No tokens were generated within the timeout period.")
+
+# Gradio interface to run the chatbot
+gr.ChatInterface(predict).launch()
+
+
+
+
+
 
 
 # import gradio as gr
@@ -162,66 +233,5 @@
 
 
 
-import gradio as gr
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
-from threading import Thread
-
-# Load the tokenizer and model
-tokenizer = AutoTokenizer.from_pretrained("thrishala/mental_health_chatbot")
-
-# Check if CUDA (GPU) is available, otherwise use CPU
-device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
-model = AutoModelForCausalLM.from_pretrained("thrishala/mental_health_chatbot", torch_dtype=torch.float16)
-model = model.to(device)
-
-# Custom stopping criteria to stop generation on specific tokens
-class StopOnTokens(StoppingCriteria):
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        stop_ids = [29, 0]  # EOS token or any other token you want to stop on
-        for stop_id in stop_ids:
-            if input_ids[0][-1] == stop_id:
-                return True
-        return False
-
-def predict(message, history):
-    # Prepare the message history for the model
-    history_transformer_format = list(zip(history[:-1], history[1:])) + [[message, ""]]
-    stop = StopOnTokens()
-
-    # Format the conversation for the model
-    messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]]) for item in history_transformer_format])
-
-    # Tokenize input and move to the correct device (GPU or CPU)
-    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
-
-    # Create a streamer to handle model outputs
-    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
-
-    # Generation parameters
-    generate_kwargs = dict(
-        model_inputs,
-        streamer=streamer,
-        max_new_tokens=1024,
-        do_sample=True,
-        top_p=0.95,
-        top_k=1000,
-        temperature=1.0,
-        num_beams=1,
-        stopping_criteria=StoppingCriteriaList([stop])
-    )
-
-    # Run the generation in a separate thread
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    # Collect the generated tokens
-    partial_message = ""
-    for new_token in streamer:
-        if new_token != '<':  # Avoid issues with special tokens
-            partial_message += new_token
-            yield partial_message
 
-# Launch the Gradio interface
-gr.ChatInterface(predict).launch()
 
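
For reference, a minimal standalone sketch of the streaming pattern the updated app.py relies on: TextIteratorStreamer fed by model.generate() running on a background thread, with queue.Empty caught when the streamer timeout elapses before the next token arrives. This is not part of the commit; the one-turn prompt and max_new_tokens=64 below are illustrative assumptions, not values from app.py.

import torch
from queue import Empty
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "thrishala/mental_health_chatbot"  # same checkpoint as in the diff
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Illustrative one-turn prompt in the same <human>/<bot> format predict() builds
inputs = tokenizer(["\n<human>:Hello\n<bot>:"], return_tensors="pt").to(device)

# The streamer's iterator raises queue.Empty if no token arrives within `timeout` seconds
streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs on a worker thread while we consume the streamer
thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=64))
thread.start()

try:
    for text in streamer:
        print(text, end="", flush=True)
except Empty:
    print("\nTimed out waiting for the next token.")
thread.join()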