File size: 5,177 Bytes
3cf4fff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
import os
from typing import Any, Callable, Dict, List, Optional
import numpy as np
import torch
from PIL import Image
from apps.plm.dataset_conf import DatasetConf
logger = logging.getLogger(__name__)
class VisionPreprocessor:
    """Turns raw dataset rows (image / multi-image / video / text-only) into
    tokenized multimodal training samples.

    NOTE(review): ``transform`` is annotated ``Optional[Callable]`` but is
    indexed like a mapping with "image" / "video" / "region" keys on every
    media path (and the text-only placeholder reads ``transform["image"].size``),
    so in practice a transform mapping appears to be required — confirm with
    callers.
    """

    def __init__(
        self,
        transform: Optional[Callable],
        tokenizer: Callable,
        max_video_frames: Optional[int],
        dataset_config: DatasetConf,
    ):
        """
        Args:
            transform: mapping from media kind ("image", "video", "region")
                to the corresponding transform callable.
            tokenizer: callable invoked as
                ``tokenizer(conversations=..., media=..., media_type=...)``.
            max_video_frames: cap on frames decoded per video, or None.
            dataset_config: supplies ``root_dir`` used to resolve relative
                media paths.
        """
        self.mllm_tokenizer = tokenizer
        self.transform = transform
        # Prefix for relative media paths; empty string means paths are used as-is.
        self.root_dir = ""
        if dataset_config.root_dir:
            self.root_dir = dataset_config.root_dir
        self.max_video_frames = max_video_frames

    def __call__(self, row: Dict[str, Any], rng: np.random.RandomState):
        """Process one row; any failure yields None so the training loop skips
        the sample instead of crashing."""
        try:
            return self.process(row, rng)
        except Exception as e:
            # logger.exception keeps the traceback (logger.error dropped it),
            # which makes bad rows debuggable from the logs.
            logger.exception(f"Error processing row: {e}")
            return None  # None will be skipped in training

    def process(self, row: Dict[str, Any], rng: np.random.RandomState):
        """Build the media tensor and tokenized text for a single sample.

        Args:
            row: annotation dict; recognized keys include "conversations",
                "text", "bbox"/"width"/"height", "image", "video",
                "start_time", "end_time", "bbox_map".
            rng: unused; kept so the signature matches the caller's contract.

        Returns:
            Dict with media, token ids and positions, or None when image
            loading fails or the tokenizer marks the sample invalid.
        """
        del rng  # preprocessing here is deterministic; rng exists only for the interface
        if "conversations" in row:
            conversations = row["conversations"]
        else:
            # Plain-caption rows are wrapped into a minimal two-turn conversation.
            conversations = self.get_conversation(caption=row["text"], prompt="")
        if "bbox" in row:
            # Region annotations need the original image size to interpret boxes.
            # Explicit raise instead of assert: asserts vanish under `python -O`.
            if "width" not in row or "height" not in row:
                raise ValueError(
                    "bbox is present in the annotation, however image width or height is not specified, which is not expected."
                )
            w, h = row["width"], row["height"]
            bboxes = row["bbox"]
            conversations = self.transform["region"](conversations, bboxes, w, h)
        media = None
        media_type = ""
        if "image" in row:
            processed_images = []
            image_files = row["image"]
            if isinstance(image_files, str):
                image_files = [image_files]  # single path -> one-element list
            pil_images = []
            for image_file in image_files:
                if self.root_dir:
                    image_file = os.path.join(self.root_dir, image_file)
                try:
                    image = Image.open(image_file).convert("RGB")
                    pil_images.append(image)
                except Exception as e:
                    # Missing/corrupt files are tolerated: log and skip the sample.
                    logger.info(
                        f"loading image failed because of the following error:\n {e}"
                    )
                    return None
            if self.transform:
                if len(pil_images) == 1:
                    transform = self.transform["image"]
                    processed_images, _ = transform(pil_images[0])
                else:
                    # Multiple images go through the video transform's
                    # multi-image path so they share one chunked layout.
                    transform = self.transform["video"]
                    processed_images, _ = transform._process_multiple_images_pil(
                        pil_images
                    )
            # Ensure a leading chunk dimension: (C, H, W) -> (1, C, H, W).
            # NOTE(review): if self.transform is falsy, processed_images is a
            # plain list and this raises (caught in __call__, sample skipped) —
            # further evidence the transform mapping is effectively required.
            if len(processed_images.shape) == 3:
                processed_images = processed_images.unsqueeze(0)
            media = processed_images
            media_type = "multi_image" if len(image_files) > 1 else "image"
        elif "video" in row:
            video_file = row["video"]
            start_time = row.get("start_time", None)
            bbox_map = row.get("bbox_map", None)
            end_time = row.get("end_time", None)
            if self.root_dir:
                video_file = os.path.join(self.root_dir, video_file)
            # The video transform expects a single packed tuple of decode args.
            video_info = (
                video_file,
                self.max_video_frames,
                start_time,
                end_time,
                bbox_map,
            )
            video, _ = self.transform["video"](video_info)
            media = video
            media_type = "video"
        else:
            # This is a text-only sample. We create a dummy white image to facilitate batch processing.
            # Note that this image serves solely as a placeholder and is never used as input to the VLM.
            media = torch.ones(
                1, 3, self.transform["image"].size, self.transform["image"].size
            )
            media_type = "text"
        tokenized_sample = self.mllm_tokenizer(
            conversations=conversations, media=media, media_type=media_type
        )
        out = (
            {
                "media": media,
                "text_ids": tokenized_sample.text_ids,
                "response_pos": tokenized_sample.response_pos,
                "image_pos": tokenized_sample.image_pos,
                "num_image_chunks": tokenized_sample.num_media_chunks,
                "media_type": media_type,
            }
            if tokenized_sample.is_valid
            else None
        )  # None will be skipped in training
        return out

    def get_conversation(
        self, caption: str, prompt: Optional[str] = None
    ) -> List[Dict[str, str]]:
        """
        Converts a plain caption to a two-turn conversation.

        Args:
            caption (str): plain caption, used as the assistant's reply.
            prompt (Optional[str]): human-turn text; None is treated as "".

        Returns:
            List[Dict[str, str]]: conversation in ``{"from", "value"}`` format.
        """
        conversations = [
            {"from": "human", "value": prompt if prompt is not None else ""},
            {"from": "assistant", "value": caption},
        ]
        return conversations
|