nikshep01 committed (verified)
Commit d1ecdea · 1 Parent(s): 227fc7a

Update app.py

Files changed (1):
  1. app.py +22 -7
app.py CHANGED

@@ -167,27 +167,38 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
 from threading import Thread
 
+# Load the tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained("thrishala/mental_health_chatbot")
+
+# Check if CUDA (GPU) is available, otherwise use CPU
+device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
 model = AutoModelForCausalLM.from_pretrained("thrishala/mental_health_chatbot", torch_dtype=torch.float16)
-model = model.to('cuda:0')
+model = model.to(device)
 
+# Custom stopping criteria to stop generation on specific tokens
 class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        stop_ids = [29, 0]
+        stop_ids = [29, 0]  # EOS token or any other token you want to stop on
         for stop_id in stop_ids:
             if input_ids[0][-1] == stop_id:
                 return True
         return False
 
 def predict(message, history):
+    # Prepare the message history for the model
     history_transformer_format = list(zip(history[:-1], history[1:])) + [[message, ""]]
     stop = StopOnTokens()
 
-    messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]])
-                        for item in history_transformer_format])
+    # Format the conversation for the model
+    messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]]) for item in history_transformer_format])
 
-    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
+    # Tokenize input and move to the correct device (GPU or CPU)
+    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
+
+    # Create a streamer to handle model outputs
     streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+
+    # Generation parameters
     generate_kwargs = dict(
         model_inputs,
         streamer=streamer,
@@ -198,15 +209,19 @@ def predict(message, history):
         temperature=1.0,
         num_beams=1,
         stopping_criteria=StoppingCriteriaList([stop])
-    )
+    )
+
+    # Run the generation in a separate thread
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
+    # Collect the generated tokens
     partial_message = ""
     for new_token in streamer:
-        if new_token != '<':
+        if new_token != '<':  # Avoid issues with special tokens
             partial_message += new_token
             yield partial_message
 
+# Launch the Gradio interface
 gr.ChatInterface(predict).launch()
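A note on the device fallback this commit introduces: the model is still loaded in torch.float16 even when it ends up on CPU, where half-precision inference is slow and some ops lack fp16 kernels. A minimal sketch of one common variant (not part of this commit) that also picks the dtype per device:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Sketch only: match the dtype to the device, since float16 on CPU
# is slow and not supported by every operator.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device.startswith("cuda") else torch.float32

tokenizer = AutoTokenizer.from_pretrained("thrishala/mental_health_chatbot")
model = AutoModelForCausalLM.from_pretrained(
    "thrishala/mental_health_chatbot", torch_dtype=dtype
).to(device)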
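For context on StopOnTokens: transformers invokes each StoppingCriteria after every newly generated token, and a True return halts generation. The IDs 29 and 0 are specific to this model's tokenizer and should be verified against it; a self-contained sketch of the same idea:

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):
    # Return True (stop generating) once the newest token is a stop token.
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [29, 0]  # model-specific; verify against the tokenizer
        return int(input_ids[0][-1]) in stop_ids

# Wired into generation exactly as in the diff:
# model.generate(..., stopping_criteria=StoppingCriteriaList([StopOnTokens()]))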
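The generation section follows the standard TextIteratorStreamer pattern: model.generate() blocks until it finishes, so it runs on a worker thread while the caller iterates the streamer and yields the growing reply to Gradio. A condensed sketch of that pattern, assuming the model, tokenizer, and device from above (stream_reply and max_new_tokens=256 are illustrative names, not from the commit):

from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(prompt):
    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks until done, so push it onto a background thread
    Thread(target=model.generate,
           kwargs=dict(model_inputs, streamer=streamer, max_new_tokens=256)).start()
    partial = ""
    for token in streamer:  # tokens arrive as they are produced
        partial += token
        yield partial  # Gradio re-renders the chat bubble on each yield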