import logging
import os
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import torch
from PIL import Image

from apps.plm.dataset_conf import DatasetConf

logger = logging.getLogger(__name__)


class VisionPreprocessor:
    def __init__(
        self,
        transform: Optional[Callable],
        tokenizer: Callable,
        max_video_frames: Optional[int],
        dataset_config: DatasetConf,
    ):
        self.mllm_tokenizer = tokenizer
        self.transform = transform
        self.root_dir = ""
        if dataset_config.root_dir:
            self.root_dir = dataset_config.root_dir
        self.max_video_frames = max_video_frames

    def __call__(self, row: Dict[str, Any], rng: np.random.RandomState):
        try:
            return self.process(row, rng)
        except Exception as e:
            logger.error(f"Error processing row: {e}")
            return None

    def process(self, row: Dict[str, Any], rng: np.random.RandomState):
        """Convert one dataset row into a tokenized multimodal sample.

        Handles text-only, single-image, multi-image, and video rows;
        returns None when the sample cannot be processed or is invalid.
        """
        del rng
        if "conversations" in row:
            conversations = row["conversations"]
        else:
            conversations = self.get_conversation(caption=row["text"], prompt="")

        if "bbox" in row:
            assert (
                "width" in row and "height" in row
            ), "bbox is present in the annotation, but image width or height is not specified."
            w, h = row["width"], row["height"]
            bboxes = row["bbox"]
            conversations = self.transform["region"](conversations, bboxes, w, h)

        media = None
        media_type = ""
        if "image" in row:
            processed_images = []
            image_files = row["image"]
            if isinstance(image_files, str):
                image_files = [image_files]
            pil_images = []
            for image_file in image_files:
                if self.root_dir:
                    image_file = os.path.join(self.root_dir, image_file)
                try:
                    image = Image.open(image_file).convert("RGB")
                    pil_images.append(image)
                except Exception as e:
                    logger.warning(
                        f"Loading image {image_file} failed with the following error:\n {e}"
                    )
                    return None
            if self.transform:
                if len(pil_images) == 1:
                    transform = self.transform["image"]
                    processed_images, _ = transform(pil_images[0])
                else:
                    transform = self.transform["video"]
                    processed_images, _ = transform._process_multiple_images_pil(
                        pil_images
                    )
                if len(processed_images.shape) == 3:
                    processed_images = processed_images.unsqueeze(0)
            media = processed_images
            media_type = "multi_image" if len(image_files) > 1 else "image"
        elif "video" in row:
            video_file = row["video"]
            start_time = row.get("start_time", None)
            bbox_map = row.get("bbox_map", None)
            end_time = row.get("end_time", None)
            if self.root_dir:
                video_file = os.path.join(self.root_dir, video_file)
            video_info = (
                video_file,
                self.max_video_frames,
                start_time,
                end_time,
                bbox_map,
            )
            video, _ = self.transform["video"](video_info)
            media = video
            media_type = "video"
        else:
            # Text-only sample: use a dummy all-ones image so the batch keeps
            # a consistent media tensor shape downstream.
            media = torch.ones(
                1, 3, self.transform["image"].size, self.transform["image"].size
            )
            media_type = "text"

        tokenized_sample = self.mllm_tokenizer(
            conversations=conversations, media=media, media_type=media_type
        )
        out = (
            {
                "media": media,
                "text_ids": tokenized_sample.text_ids,
                "response_pos": tokenized_sample.response_pos,
                "image_pos": tokenized_sample.image_pos,
                "num_image_chunks": tokenized_sample.num_media_chunks,
                "media_type": media_type,
            }
            if tokenized_sample.is_valid
            else None
        )
        return out

    def get_conversation(
        self, caption: str, prompt: Optional[str] = None
    ) -> List[Dict[str, str]]:
        """
        Converts a plain caption into a two-turn conversation.

        Args:
            caption (str): plain caption used as the assistant response
            prompt (Optional[str]): human prompt; an empty string is used
                when no prompt is given

        Returns:
            List[Dict[str, str]]: conversation
        """
        conversations = [
            {"from": "human", "value": prompt if prompt is not None else ""},
            {"from": "assistant", "value": caption},
        ]
        return conversations
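

# --- Usage sketch (illustrative, not part of the original module) ---
# Exercises the text-only path with stand-in objects, since the real
# transforms and MLLM tokenizer are wired up elsewhere in the codebase.
# All _Stub* names below are hypothetical; SimpleNamespace stands in for a
# DatasetConf instance, of which only the root_dir field is used here.
if __name__ == "__main__":
    from dataclasses import dataclass, field
    from types import SimpleNamespace

    @dataclass
    class _StubTokenized:
        # Mirrors the fields that process() reads off the tokenizer output.
        text_ids: list = field(default_factory=list)
        response_pos: list = field(default_factory=list)
        image_pos: list = field(default_factory=list)
        num_media_chunks: int = 0
        is_valid: bool = True

    class _StubImageTransform:
        size = 224  # square side used by the dummy-media path in process()

        def __call__(self, image):
            return torch.zeros(3, self.size, self.size), None

    def _stub_tokenizer(conversations, media, media_type):
        return _StubTokenized(text_ids=[0], response_pos=[0])

    preprocessor = VisionPreprocessor(
        transform={"image": _StubImageTransform()},
        tokenizer=_stub_tokenizer,
        max_video_frames=None,
        dataset_config=SimpleNamespace(root_dir=""),
    )
    sample = preprocessor({"text": "a photo of a cat"}, np.random.RandomState(0))
    print(sample["media_type"], tuple(sample["media"].shape))
    # -> text (1, 3, 224, 224)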