ccclemenfff commited on
Commit
8e269ab
·
1 Parent(s): cff733f

change inference.py class StoppingCriteriaSub(StoppingCriteria) to fix the error

Browse files
Files changed (1) hide show
  1. inference.py +6 -1
inference.py CHANGED
@@ -176,11 +176,16 @@ class StoppingCriteriaSub(StoppingCriteria):
176
 
177
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
178
  for stop in self.stops:
 
 
 
 
179
  if torch.all((stop == input_ids[0][-len(stop):])).item():
180
  return True
181
-
182
  return False
183
 
 
 
184
  @torch.inference_mode()
185
  def generate_stream(
186
  model, tokenizer, image_processor, params, device
 
176
 
177
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
178
  for stop in self.stops:
179
+ if input_ids.shape[-1] < len(stop):
180
+ continue
181
+ # 把 stop 移动到 input_ids 的设备上
182
+ stop = stop.to(input_ids.device)
183
  if torch.all((stop == input_ids[0][-len(stop):])).item():
184
  return True
 
185
  return False
186
 
187
+
188
+
189
  @torch.inference_mode()
190
  def generate_stream(
191
  model, tokenizer, image_processor, params, device