"""Dataset loader for Ref-AVS: referring expressions grounded in video
frames, per-video audio features, and per-frame segmentation masks."""

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

import cv2
import pandas as pd
from collections import defaultdict

from transformers import CLIPImageProcessor

from models.llava import conversation as conversation_lib
from models.segment_anything.utils.transforms import ResizeLongestSide


# Special-token conventions shared with the LLaVA-style conversation pipeline.
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_VIDEO_TOKEN = "<video>"

AUDIO_TOKEN_INDEX = -300
DEFAULT_AUDIO_TOKEN = "<audio>"
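
# Note: IGNORE_INDEX matches the default ignore_index of nn.CrossEntropyLoss,
# so positions labeled -100 contribute no loss. The negative token indices are
# out-of-vocabulary placeholders: in LLaVA-style preprocessing (an assumption
# about the surrounding codebase, not shown in this file) they are spliced
# into input_ids wherever the "<image>"/"<video>"/"<audio>" strings appear and
# are later replaced by the corresponding modality embeddings inside the model.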


class REFAVS(Dataset):
    """Ref-AVS samples: video frames, per-video audio features, referring
    expressions, and per-frame ground-truth segmentation masks."""

    def __init__(self, split='train', cfg=None, tokenizer=None, input_type='refer'):
        # 'refer' yields one sample per referring expression;
        # 'video' yields one sample bundling every expression of a video.
        self.input_type = input_type
        self.data_dir = cfg.data_dir

        meta_path = f'{self.data_dir}/metadata.csv'
        metadata = pd.read_csv(meta_path, header=0)
        self.split = split
        self.metadata = metadata[metadata['split'] == split]

        # Group metadata rows by video id; 'uid' has the form '<vid>_<...>_<...>',
        # so rsplit('_', 2)[0] recovers the video id.
        self.video_to_samples = defaultdict(list)
        for i in range(len(self.metadata)):
            row = self.metadata.iloc[i]
            vid = row['uid'].rsplit('_', 2)[0]
            self.video_to_samples[vid].append(i)
        self.all_vids = list(self.video_to_samples.keys())

        self.media_path = f'{self.data_dir}/media'
        self.label_path = f'{self.data_dir}/gt_mask'
        self.frame_num = cfg.frame_n
        self.text_max_len = cfg.text_max_len

        self.tokenizer = tokenizer

        # Conversation system prompt; template 1 additionally exposes the
        # reference audio stream.
        if cfg.conv_template == 0:
            self.system = "\nReference Video: <video> \nTarget Image: <image> \n"
        elif cfg.conv_template == 1:
            self.system = "\nReference Video: <video> \nReference Audio: <audio> \nTarget Image: <image> \n"

        self.question = "What is {sent} in the Reference Video? Please respond with segmentation mask in the Target Image."

        # CLIP preprocessing for the vision tower; SAM-style longest-side
        # resize (to 1024) plus normalize-and-pad for the mask branch.
        self.clip_image_processor = CLIPImageProcessor.from_pretrained(cfg.vision_tower)
        self.transform = ResizeLongestSide(1024)

        # Per-channel RGB mean/std applied in self.preprocess.
        self.pixel_mean = torch.Tensor([113.263, 99.370, 92.492]).view(-1, 1, 1)
        self.pixel_std = torch.Tensor([64.274, 61.068, 58.626]).view(-1, 1, 1)
        self.img_size = 1024

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square img_size x img_size input."""
        x = (x - self.pixel_mean) / self.pixel_std

        # Zero-pad on the bottom/right only, matching SAM's convention.
        h, w = x.shape[-2:]
        padh = self.img_size - h
        padw = self.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x
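
    # A minimal sketch (not executed here) of how __getitem__ below feeds
    # preprocess; shapes assume a hypothetical 720x1280 source frame:
    #
    #   frame = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)   # (720, 1280, 3)
    #   frame = self.transform.apply_image(frame)                   # (576, 1024, 3)
    #   x = torch.from_numpy(frame).permute(2, 0, 1).contiguous()   # (3, 576, 1024)
    #   x = self.preprocess(x)                                      # (3, 1024, 1024)
    #
    # The pre-padding shape (576, 1024) is what __getitem__ stores as 'resize',
    # so downstream code can crop SAM's padding away from predicted masks.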

    def __len__(self):
        if self.input_type == 'refer':
            return len(self.metadata)
        elif self.input_type == 'video':
            return len(self.all_vids)
        raise ValueError(f"unknown input_type: {self.input_type!r}")

    def __getitem__(self, idx):
        # Resolve the video id and which metadata rows feed this sample.
        if self.input_type == 'refer':
            vid = self.metadata.iloc[idx]['uid'].rsplit('_', 2)[0]
            indices = [idx]
        elif self.input_type == 'video':
            vid = self.all_vids[idx]
            indices = self.video_to_samples[vid]

        # Pre-extracted per-video audio and SAM image embeddings.
        feat_aud = torch.load(f'{self.data_dir}/audio_embed/{vid}.pt')
        image_feat = torch.load(f'{self.data_dir}/image_embed/{vid}.pt')

        img_clips = []
        masks = []
        images = []
        rec_texts = []
        target_ids = []

        # Build one multi-turn conversation covering all expressions.
        conv = conversation_lib.default_conversation.copy()
        conv.messages = []
        conv.system += self.system

        for meta_idx in indices:
            row = self.metadata.iloc[meta_idx]
            refer = row['exp'].lower().rstrip('.')
            fid = row['fid']

            # One QA turn per expression; '[SEG]' marks where the model is
            # expected to emit its segmentation token.
            conv.append_message(conv.roles[0], self.question.format(sent=refer))
            conv.append_message(conv.roles[1], "Sure, it is [SEG]")

            rec_texts.append(refer)
            target_ids.append(fid)

            # Binary ground-truth mask per frame; mask file names are assumed
            # zero-padded to five digits ('00000.png', '00001.png', ...).
            temp_mask = []
            for frame_idx in range(self.frame_num):
                path_mask = f'{self.label_path}/{vid}/fid_{fid}/{frame_idx:05d}.png'
                mask_cv2 = cv2.imread(path_mask)
                mask_cv2 = cv2.cvtColor(mask_cv2, cv2.COLOR_BGR2GRAY)
                gt_binary_mask = torch.as_tensor(mask_cv2 > 0, dtype=torch.float32)
                temp_mask.append(gt_binary_mask)
            masks.append(torch.stack(temp_mask, dim=0))

        orgsize = masks[0].shape[-2:]

        conversation = conv.get_prompt()

        # Two image streams per frame: CLIP-preprocessed pixels for the vision
        # tower, and SAM-style resized/normalized/padded pixels.
        for _idx in range(self.frame_num):
            path_frame = f'{self.media_path}/{vid}/frames/{_idx}.jpg'
            image = cv2.imread(path_frame)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            image_clip = self.clip_image_processor(image, return_tensors="pt")["pixel_values"][0]
            image = self.transform.apply_image(image)
            resize = image.shape[:2]  # pre-padding (h, w), for cropping predictions
            image = self.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())

            images.append(image)
            img_clips.append(image_clip)

        images = torch.stack(images, dim=0)        # (frame_num, 3, 1024, 1024)
        img_clips = torch.stack(img_clips, dim=0)  # (frame_num, 3, H_clip, W_clip)
        masks = torch.stack(masks, dim=0)          # (len(indices), frame_num, H, W)

        return {
            'vid': vid,
            'image': images,
            'img_clip': img_clips,
            'mask': masks,
            'conversation': conversation,
            'feat_aud': feat_aud,
            'resize': resize,
            'orgsize': orgsize,
            'feat_sam': image_feat,
            'ref': rec_texts,
            'fids': target_ids,
        }
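

# A minimal usage sketch (assumptions: the Ref-AVS data layout read above
# exists under ./data, and cfg only needs the fields consumed in __init__;
# the attribute values below are illustrative, not canonical).
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        data_dir="./data",                             # hypothetical path
        frame_n=10,                                    # frames per video
        text_max_len=512,
        conv_template=1,                               # include the audio stream
        vision_tower="openai/clip-vit-large-patch14",  # any CLIP checkpoint
    )
    ds = REFAVS(split='val', cfg=cfg, tokenizer=None, input_type='video')
    sample = ds[0]
    print(sample['vid'], sample['image'].shape, sample['mask'].shape)
    print(sample['conversation'])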