Update processing_phi3_v.py
Browse files
- processing_phi3_v.py +17 -20
processing_phi3_v.py
CHANGED
|
@@ -160,12 +160,12 @@ class Phi3VImageProcessor(BaseImageProcessor):
|
|
| 160 |
model_input_names = ["pixel_values"]
|
| 161 |
|
| 162 |
def __init__(
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
) -> None:
|
| 170 |
super().__init__(**kwargs)
|
| 171 |
self.num_crops = num_crops
|
|
@@ -174,8 +174,8 @@ class Phi3VImageProcessor(BaseImageProcessor):
|
|
| 174 |
self.do_convert_rgb = do_convert_rgb
|
| 175 |
|
| 176 |
def calc_num_image_tokens(
|
| 177 |
-
|
| 178 |
-
|
| 179 |
):
|
| 180 |
""" Calculate the number of image tokens for each image.
|
| 181 |
Args:
|
|
@@ -210,12 +210,12 @@ class Phi3VImageProcessor(BaseImageProcessor):
|
|
| 210 |
return num_img_tokens
|
| 211 |
|
| 212 |
def preprocess(
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
):
|
| 220 |
"""
|
| 221 |
Args:
|
|
@@ -276,8 +276,7 @@ class Phi3VImageProcessor(BaseImageProcessor):
|
|
| 276 |
# reshape to channel dimension -> (num_images, num_crops, 3, 336, 336)
|
| 277 |
# (1, 3, h//336, 336, w//336, 336) -> (1, h//336, w//336, 3, 336, 336) -> (h//336*w//336, 3, 336, 336)
|
| 278 |
hd_images_reshape = [
|
| 279 |
-
im.reshape(1, 3, h // 336, 336, w // 336, 336).permute(0, 2, 4, 1, 3, 5).reshape(-1, 3, 336,
|
| 280 |
-
336).contiguous() for
|
| 281 |
im, (h, w) in zip(hd_images, shapes)]
|
| 282 |
# concat global image and local image
|
| 283 |
hd_images_reshape = [torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in
|
|
@@ -443,11 +442,9 @@ class Phi3VProcessor(ProcessorMixin):
|
|
| 443 |
unique_image_ids = sorted(list(set(image_ids)))
|
| 444 |
# image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be [1, 4, 5]
|
| 445 |
# check the condition
|
| 446 |
-
assert unique_image_ids == list(range(1,
|
| 447 |
-
len(unique_image_ids) + 1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
|
| 448 |
# total images must be the same as the number of image tags
|
| 449 |
-
assert len(unique_image_ids) == len(
|
| 450 |
-
images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(images)} images"
|
| 451 |
|
| 452 |
image_ids_pad = [[-iid] * num_img_tokens[iid - 1] for iid in image_ids]
|
| 453 |
|
|
|
|
| 160 |
model_input_names = ["pixel_values"]
|
| 161 |
|
| 162 |
def __init__(
|
| 163 |
+
self,
|
| 164 |
+
num_crops: int = 1,
|
| 165 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 166 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 167 |
+
do_convert_rgb: bool = True,
|
| 168 |
+
**kwargs,
|
| 169 |
) -> None:
|
| 170 |
super().__init__(**kwargs)
|
| 171 |
self.num_crops = num_crops
|
|
|
|
| 174 |
self.do_convert_rgb = do_convert_rgb
|
| 175 |
|
| 176 |
def calc_num_image_tokens(
|
| 177 |
+
self,
|
| 178 |
+
images: ImageInput
|
| 179 |
):
|
| 180 |
""" Calculate the number of image tokens for each image.
|
| 181 |
Args:
|
|
|
|
| 210 |
return num_img_tokens
|
| 211 |
|
| 212 |
def preprocess(
|
| 213 |
+
self,
|
| 214 |
+
images: ImageInput,
|
| 215 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 216 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 217 |
+
do_convert_rgb: bool = None,
|
| 218 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 219 |
):
|
| 220 |
"""
|
| 221 |
Args:
|
|
|
|
| 276 |
# reshape to channel dimension -> (num_images, num_crops, 3, 336, 336)
|
| 277 |
# (1, 3, h//336, 336, w//336, 336) -> (1, h//336, w//336, 3, 336, 336) -> (h//336*w//336, 3, 336, 336)
|
| 278 |
hd_images_reshape = [
|
| 279 |
+
im.reshape(1, 3, h // 336, 336, w // 336, 336).permute(0, 2, 4, 1, 3, 5).reshape(-1, 3, 336, 336).contiguous() for
|
|
|
|
| 280 |
im, (h, w) in zip(hd_images, shapes)]
|
| 281 |
# concat global image and local image
|
| 282 |
hd_images_reshape = [torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in
|
|
|
|
| 442 |
unique_image_ids = sorted(list(set(image_ids)))
|
| 443 |
# image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be [1, 4, 5]
|
| 444 |
# check the condition
|
| 445 |
+
assert unique_image_ids == list(range(1, len(unique_image_ids) + 1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
|
|
|
|
| 446 |
# total images must be the same as the number of image tags
|
| 447 |
+
assert len(unique_image_ids) == len(images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(images)} images"
|
|
|
|
| 448 |
|
| 449 |
image_ids_pad = [[-iid] * num_img_tokens[iid - 1] for iid in image_ids]
|
| 450 |
|