Update modelling_magi.py
Browse files- modelling_magi.py +2 -2
modelling_magi.py
CHANGED
|
@@ -181,7 +181,7 @@ class MagiModel(PreTrainedModel):
|
|
| 181 |
|
| 182 |
return crop_embeddings_for_batch
|
| 183 |
|
| 184 |
-
def predict_ocr(self, images, crop_bboxes, move_to_device_fn=None, use_tqdm=False, batch_size=32):
|
| 185 |
assert not self.config.disable_ocr
|
| 186 |
move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
|
| 187 |
|
|
@@ -207,7 +207,7 @@ class MagiModel(PreTrainedModel):
|
|
| 207 |
pbar = range(0, len(crops_per_image), batch_size)
|
| 208 |
for i in pbar:
|
| 209 |
crops = crops_per_image[i:i+batch_size]
|
| 210 |
-
generated_ids = self.ocr_model.generate(crops)
|
| 211 |
generated_texts = self.processor.postprocess_ocr_tokens(generated_ids)
|
| 212 |
all_generated_texts.extend(generated_texts)
|
| 213 |
|
|
|
|
| 181 |
|
| 182 |
return crop_embeddings_for_batch
|
| 183 |
|
| 184 |
+
def predict_ocr(self, images, crop_bboxes, move_to_device_fn=None, use_tqdm=False, batch_size=32, max_new_tokens=64):
|
| 185 |
assert not self.config.disable_ocr
|
| 186 |
move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
|
| 187 |
|
|
|
|
| 207 |
pbar = range(0, len(crops_per_image), batch_size)
|
| 208 |
for i in pbar:
|
| 209 |
crops = crops_per_image[i:i+batch_size]
|
| 210 |
+
generated_ids = self.ocr_model.generate(crops, max_new_tokens=max_new_tokens)
|
| 211 |
generated_texts = self.processor.postprocess_ocr_tokens(generated_ids)
|
| 212 |
all_generated_texts.extend(generated_texts)
|
| 213 |
|