# MoTIF/utils/core/data/preprocessor.py
# Uploaded via huggingface_hub by P4ddyki (commit 3cf4fff, verified)
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
import os
from typing import Any, Callable, Dict, List, Optional
import numpy as np
import torch
from PIL import Image
from apps.plm.dataset_conf import DatasetConf
logger = logging.getLogger(__name__)
class VisionPreprocessor:
    """Turns annotation rows into tokenized (media, text) training samples.

    Handles four sample kinds: single image, multiple images, video, and
    text-only (which gets a dummy placeholder image so batches stay uniform).
    Any failure during processing yields ``None`` so the sample is skipped.
    """

    def __init__(
        self,
        transform: Optional[Callable],
        tokenizer: Callable,
        max_video_frames: Optional[int],
        dataset_config: "DatasetConf",
    ):
        """
        Args:
            transform: mapping-like object exposing "image", "video" and
                "region" transforms (accessed via ``transform[...]``).
            tokenizer: multimodal tokenizer called with conversations,
                media, and media_type keyword arguments.
            max_video_frames: cap on the number of frames decoded per video.
            dataset_config: dataset configuration; only ``root_dir`` is read.
        """
        self.mllm_tokenizer = tokenizer
        self.transform = transform
        # Media paths in annotations are resolved relative to this directory
        # (empty string means paths are used as-is).
        self.root_dir = dataset_config.root_dir or ""
        self.max_video_frames = max_video_frames

    def __call__(self, row: Dict[str, Any], rng: np.random.RandomState):
        """Process one row; return None on any failure so training skips it."""
        try:
            return self.process(row, rng)
        except Exception as e:
            # Broad catch is deliberate: one bad sample must not kill the
            # data-loading worker. Lazy %-args avoid eager f-string work.
            logger.error("Error processing row: %s", e)
            return None  # None will be skipped in training

    def process(self, row: Dict[str, Any], rng: np.random.RandomState):
        """Convert an annotation row into a tokenized training sample.

        Args:
            row: annotation dict; contains either "conversations" or "text",
                plus optional "bbox" (with "width"/"height"), "image", or
                "video" fields.
            rng: unused; kept for interface compatibility.

        Returns:
            Dict with the media tensor, token ids, and position metadata,
            or None when the tokenized sample is invalid or an image
            failed to load.

        Raises:
            ValueError: if "bbox" is present without "width"/"height".
        """
        del rng  # preprocessing is deterministic; rng kept for API parity

        if "conversations" in row:
            conversations = row["conversations"]
        else:
            conversations = self.get_conversation(caption=row["text"], prompt="")

        if "bbox" in row:
            # Explicit raise instead of `assert`: asserts vanish under
            # `python -O`, which would let malformed rows through silently.
            if "width" not in row or "height" not in row:
                raise ValueError(
                    "bbox is present in the annotation, however image width "
                    "or height is not specified, which is not expected."
                )
            w, h = row["width"], row["height"]
            conversations = self.transform["region"](
                conversations, row["bbox"], w, h
            )

        if "image" in row:
            media, media_type = self._load_image_media(row["image"])
            if media is None:
                return None  # an image failed to load; skip the sample
        elif "video" in row:
            media, media_type = self._load_video_media(row)
        else:
            media, media_type = self._placeholder_media()

        tokenized_sample = self.mllm_tokenizer(
            conversations=conversations, media=media, media_type=media_type
        )
        if not tokenized_sample.is_valid:
            return None  # None will be skipped in training
        return {
            "media": media,
            "text_ids": tokenized_sample.text_ids,
            "response_pos": tokenized_sample.response_pos,
            "image_pos": tokenized_sample.image_pos,
            "num_image_chunks": tokenized_sample.num_media_chunks,
            "media_type": media_type,
        }

    def _load_image_media(self, image_files):
        """Load and transform one or more images from disk.

        Args:
            image_files: a single path or a list of paths, relative to
                ``self.root_dir`` when that is set.

        Returns:
            (tensor of shape (N, C, H, W), "image" | "multi_image"), or
            (None, "") when any image fails to load.
        """
        if isinstance(image_files, str):
            image_files = [image_files]
        pil_images = []
        for image_file in image_files:
            if self.root_dir:
                image_file = os.path.join(self.root_dir, image_file)
            try:
                pil_images.append(Image.open(image_file).convert("RGB"))
            except Exception as e:
                logger.info(
                    f"loading image failed because of the following error:\n {e}"
                )
                return None, ""
        processed_images = []
        if self.transform:
            if len(pil_images) == 1:
                processed_images, _ = self.transform["image"](pil_images[0])
            else:
                # Multi-image samples reuse the video transform's frame path.
                transform = self.transform["video"]
                processed_images, _ = transform._process_multiple_images_pil(
                    pil_images
                )
        # Normalize single-image output (C, H, W) to (1, C, H, W).
        # NOTE(review): if self.transform is None, processed_images is still
        # a plain list and `.shape` raises; __call__ turns that into None —
        # same behavior as the original code.
        if len(processed_images.shape) == 3:
            processed_images = processed_images.unsqueeze(0)
        media_type = "multi_image" if len(image_files) > 1 else "image"
        return processed_images, media_type

    def _load_video_media(self, row: Dict[str, Any]):
        """Decode and transform the video clip referenced by the row.

        Returns:
            (video tensor, "video") as produced by the "video" transform.
        """
        video_file = row["video"]
        if self.root_dir:
            video_file = os.path.join(self.root_dir, video_file)
        # Tuple layout expected by the video transform:
        # (path, max_frames, start_time, end_time, bbox_map).
        video_info = (
            video_file,
            self.max_video_frames,
            row.get("start_time", None),
            row.get("end_time", None),
            row.get("bbox_map", None),
        )
        video, _ = self.transform["video"](video_info)
        return video, "video"

    def _placeholder_media(self):
        """Dummy all-white image for text-only samples.

        This is a placeholder to facilitate batch processing and is never
        used as input to the VLM.
        """
        size = self.transform["image"].size
        return torch.ones(1, 3, size, size), "text"

    def get_conversation(
        self, caption: str, prompt: Optional[str] = None
    ) -> List[Dict[str, str]]:
        """Convert a plain caption into a two-turn conversation.

        Args:
            caption: the assistant's answer text.
            prompt: optional human-turn text; empty string when None.

        Returns:
            List of {"from": ..., "value": ...} turns.
        """
        return [
            {"from": "human", "value": prompt if prompt is not None else ""},
            {"from": "assistant", "value": caption},
        ]