Commit
·
8e269ab
1
Parent(s):
cff733f
Change the StoppingCriteriaSub(StoppingCriteria) class in inference.py to fix a device-mismatch error in its __call__ method
Browse files- inference.py +6 -1
inference.py
CHANGED
|
@@ -176,11 +176,16 @@ class StoppingCriteriaSub(StoppingCriteria):
|
|
| 176 |
|
| 177 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
    """Return True when the generated sequence ends with any stop sequence.

    Args:
        input_ids: Generated token ids of shape (batch, seq_len); only the
            first sequence in the batch (``input_ids[0]``) is inspected.
        scores: Per-token scores. Unused; required by the
            ``StoppingCriteria`` call signature.
        **kwargs: Ignored; accepted for interface compatibility.

    Returns:
        bool: True if the tail of ``input_ids[0]`` equals one of
        ``self.stops``, else False.
    """
    for stop in self.stops:
        # Not enough tokens generated yet to contain this stop sequence;
        # without this guard the negative slice below would silently
        # compare against the whole (too-short) sequence.
        if input_ids.shape[-1] < len(stop):
            continue
        # Move the stop tensor onto input_ids' device: comparing a CPU
        # stop tensor against CUDA input_ids raises a device-mismatch
        # RuntimeError.
        stop = stop.to(input_ids.device)
        if torch.all((stop == input_ids[0][-len(stop):])).item():
            return True
    return False
|
| 183 |
|
|
|
|
|
|
|
| 184 |
@torch.inference_mode()
|
| 185 |
def generate_stream(
|
| 186 |
model, tokenizer, image_processor, params, device
|
|
|
|
| 176 |
|
| 177 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
    """Return True when the generated sequence ends with any stop sequence.

    Args:
        input_ids: Generated token ids of shape (batch, seq_len); only the
            first sequence in the batch (``input_ids[0]``) is inspected.
        scores: Per-token scores. Unused; required by the
            ``StoppingCriteria`` call signature.
        **kwargs: Ignored; accepted for interface compatibility.

    Returns:
        bool: True if the tail of ``input_ids[0]`` equals one of
        ``self.stops``, else False.
    """
    for stop in self.stops:
        # Skip stop sequences longer than what has been generated so far.
        if input_ids.shape[-1] < len(stop):
            continue
        # Move stop onto input_ids' device so the comparison below does
        # not fail with a cross-device error.
        stop = stop.to(input_ids.device)
        if torch.all((stop == input_ids[0][-len(stop):])).item():
            return True
    return False
|
| 186 |
|
| 187 |
+
|
| 188 |
+
|
| 189 |
@torch.inference_mode()
|
| 190 |
def generate_stream(
|
| 191 |
model, tokenizer, image_processor, params, device
|