Allow for batch sizes larger than 1.
Pull request #16, opened by sebbyjp.
File changed: processing_phi3_v.py (+22 −9)
|
@@ -20,14 +20,19 @@ import re
|
|
| 20 |
from typing import List, Optional, Union
|
| 21 |
|
| 22 |
import torch
|
| 23 |
-
|
| 24 |
import transformers
|
| 25 |
from transformers.feature_extraction_utils import BatchFeature
|
| 26 |
from transformers.image_utils import ImageInput
|
| 27 |
from transformers.processing_utils import ProcessorMixin
|
| 28 |
-
from transformers.tokenization_utils_base import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
from transformers.utils import TensorType
|
| 30 |
-
|
|
|
|
|
|
|
| 31 |
transformers.Phi3VImageProcessor = Phi3VImageProcessor
|
| 32 |
|
| 33 |
class Phi3VProcessor(ProcessorMixin):
|
|
@@ -144,13 +149,23 @@ class Phi3VProcessor(ProcessorMixin):
|
|
| 144 |
return self.tokenizer.convert_tokens_to_ids(self.special_image_token)
|
| 145 |
|
| 146 |
def _convert_images_texts_to_inputs(self, images, texts, padding=False, truncation=None, max_length=None, return_tensors=None):
|
| 147 |
-
|
| 148 |
if not len(images):
|
| 149 |
model_inputs = self.tokenizer(texts, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length)
|
| 150 |
return BatchFeature(data={**model_inputs})
|
| 151 |
|
| 152 |
pattern = r"<\|image_\d+\|>"
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
if 'num_img_tokens' in images:
|
| 156 |
num_img_tokens = images['num_img_tokens']
|
|
@@ -161,10 +176,8 @@ class Phi3VProcessor(ProcessorMixin):
|
|
| 161 |
|
| 162 |
images, image_sizes = images['pixel_values'], images['image_sizes']
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
# image_ids = [int(s.split("|")[1].split("_")[-1]) * -1 for s in image_tags]
|
| 167 |
-
# image_ids_pad = [[iid]*num_img_tokens[i] for i, iid in enumerate(image_ids)]
|
| 168 |
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
|
| 169 |
unique_image_ids = sorted(list(set(image_ids)))
|
| 170 |
# image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be [1, 4, 5]
|
|
|
|
| 20 |
from typing import List, Optional, Union
|
| 21 |
|
| 22 |
import torch
|
|
|
|
| 23 |
import transformers
|
| 24 |
from transformers.feature_extraction_utils import BatchFeature
|
| 25 |
from transformers.image_utils import ImageInput
|
| 26 |
from transformers.processing_utils import ProcessorMixin
|
| 27 |
+
from transformers.tokenization_utils_base import (
|
| 28 |
+
PaddingStrategy,
|
| 29 |
+
TextInput,
|
| 30 |
+
TruncationStrategy,
|
| 31 |
+
)
|
| 32 |
from transformers.utils import TensorType
|
| 33 |
+
|
| 34 |
+
from .image_processing_phi3_v import Phi3VImageProcessor
|
| 35 |
+
|
| 36 |
transformers.Phi3VImageProcessor = Phi3VImageProcessor
|
| 37 |
|
| 38 |
class Phi3VProcessor(ProcessorMixin):
|
|
|
|
| 149 |
return self.tokenizer.convert_tokens_to_ids(self.special_image_token)
|
| 150 |
|
| 151 |
def _convert_images_texts_to_inputs(self, images, texts, padding=False, truncation=None, max_length=None, return_tensors=None):
|
|
|
|
| 152 |
if not len(images):
|
| 153 |
model_inputs = self.tokenizer(texts, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length)
|
| 154 |
return BatchFeature(data={**model_inputs})
|
| 155 |
|
| 156 |
pattern = r"<\|image_\d+\|>"
|
| 157 |
+
|
| 158 |
+
# Don't over list-comprehend this, it's already hard to read.
|
| 159 |
+
prompt_chunks = []
|
| 160 |
+
image_tags = []
|
| 161 |
+
for text in texts:
|
| 162 |
+
chunks = re.split(pattern, text)
|
| 163 |
+
chunk_image_tags = re.findall(pattern, text)
|
| 164 |
+
for chunk, chunk_image_tag in zip(chunks, chunk_image_tags):
|
| 165 |
+
tokenized_chunk = self.tokenizer(chunk).input_ids
|
| 166 |
+
prompt_chunks.append(tokenized_chunk)
|
| 167 |
+
image_tags.append(chunk_image_tag)
|
| 168 |
+
|
| 169 |
|
| 170 |
if 'num_img_tokens' in images:
|
| 171 |
num_img_tokens = images['num_img_tokens']
|
|
|
|
| 176 |
|
| 177 |
images, image_sizes = images['pixel_values'], images['image_sizes']
|
| 178 |
|
| 179 |
+
|
| 180 |
+
# image_tags needs to start from 1 to num_images
|
|
|
|
|
|
|
| 181 |
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
|
| 182 |
unique_image_ids = sorted(list(set(image_ids)))
|
| 183 |
# image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be [1, 4, 5]
|