| import torch |
| from torch.utils import data |
| import numpy as np |
| from torchvision import transforms |
| import cv2 |
| import glob |
| import json |
| import os |
| from PIL import Image |
| import random |
| import decord |
| decord.bridge.set_bridge('torch') |
|
|
class t2i_Dataset(data.Dataset):
    """Load (frames, video name, prompt) triples for t2i feature extraction.

    Each item decodes every frame of one video with OpenCV, converts it to
    RGB, runs it through ``vis_processor["eval"]`` and returns the stacked
    half-precision frame tensor, padded (by repeating the last frame) up to
    a multiple of 8 frames.
    """

    def __init__(self, vids_dir, json_path, vis_processor):
        # BUG FIX: original called super(t2i_Dataset).__init__(), which binds
        # an unbound super object and never initialises data.Dataset.
        super(t2i_Dataset, self).__init__()
        with open(json_path, "r") as f:
            self.dt = json.load(f)  # list of {"prompt": ..., "video": ...}
        self.vids_dir = vids_dir
        self.vis_processor = vis_processor

    def __len__(self):
        return len(self.dt)

    def __getitem__(self, idx):
        item = self.dt[idx]
        prompt = item["prompt"]
        filename = os.path.join(self.vids_dir, item["video"])
        vid_nm = filename.split('/')[-1][:-4]  # basename without extension
        cap = cv2.VideoCapture(filename)
        video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_frame_rate = int(round(cap.get(cv2.CAP_PROP_FPS)))

        # Round the frame count up to the next multiple of 8.
        video_length_round = video_length if video_length % 8 == 0 else (
            video_length // 8 + 1) * 8
        if video_frame_rate == 0:
            raise Exception('no frame detect')

        transformed_frame_all = []
        for i in range(video_length):
            has_frames, frame = cap.read()
            if not has_frames:
                raise Exception('no frame detect')
            read_frame = Image.fromarray(
                cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            image = self.vis_processor["eval"](read_frame)
            transformed_frame_all.append(image)
        cap.release()
        transformed_frame_all = torch.stack(transformed_frame_all).half()

        # BUG FIX: the original assigned into
        # transformed_frame_all[video_length:], but the stacked tensor only
        # has `video_length` frames, so that slice was empty and no padding
        # ever happened.  Actually pad by repeating the last frame.
        if video_length_round != video_length:
            pad = transformed_frame_all[-1:].expand(
                video_length_round - video_length,
                *transformed_frame_all.shape[1:]).clone()
            transformed_frame_all = torch.cat(
                [transformed_frame_all, pad], dim=0)

        return transformed_frame_all, vid_nm, prompt
|
|
|
|
class tem_Dataset(data.Dataset):
    """Read data from the original dataset for feature extraction"""

    def __init__(self, data_dir, transform, json_path=None, resize=224):
        super(tem_Dataset, self).__init__()

        # Video list comes either from a json manifest or a directory glob.
        if json_path is not None:
            with open(json_path, "r") as f:
                self.dt = json.load(f)
            self.video_names = [
                os.path.join(data_dir, entry["video"]) for entry in self.dt]
        else:
            self.video_names = glob.glob(f'{data_dir}/*.mp4')
        self.json_path = json_path
        self.videos_dir = data_dir
        self.transform = transform
        self.resize = resize

    def __len__(self):
        return len(self.video_names)

    def __getitem__(self, idx):
        filename = self.video_names[idx]
        # Video name: from the json entry when available, else from the path.
        if self.json_path is None:
            vid_nm = filename.split('/')[-1][:-4]
        else:
            vid_nm = self.dt[idx]["video"][:-4]

        cap = cv2.VideoCapture(filename)
        n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = int(round(cap.get(cv2.CAP_PROP_FPS)))
        if fps == 0:
            raise Exception('no frame detect')

        # Round the frame count up to a multiple of 8; the extra slots get a
        # copy of the last real frame after decoding.
        padded_len = n_frames if n_frames % 8 == 0 else (n_frames // 8 + 1) * 8
        frames = torch.zeros([padded_len, 3, self.resize, self.resize])

        for pos in range(n_frames):
            ok, raw = cap.read()
            if not ok:
                raise Exception('no frame detect')
            rgb = Image.fromarray(cv2.cvtColor(raw, cv2.COLOR_BGR2RGB))
            frames[pos] = self.transform(rgb)
        cap.release()

        if padded_len != n_frames:
            frames[n_frames:, :] = frames[n_frames - 1, :]

        return frames, vid_nm
|
|
|
|
def pyramidsGL(image, num_levels, dim=224):
    '''Creates Gaussian (G) and Laplacian (L) pyramids of level "num_levels" from image im.

    Only the Laplacian pyramid (a list of ``num_levels - 1`` images) is
    returned.  Per-level target sizes are linearly interpolated between the
    original resolution and an aspect-preserving resolution whose shorter
    side equals ``dim``.
    '''
    o_height, o_width = image.shape[:2]

    # Final (smallest) level size: shorter side becomes `dim`, aspect kept.
    if o_width > o_height:
        f_height = dim
        f_width = int((o_width * f_height) / o_height)
    elif o_height > o_width:
        f_width = dim
        f_height = int((o_height * f_width) / o_width)
    else:
        f_width = f_height = dim

    # BUG FIX: with num_levels == 1 the original divided by zero below.
    denom = max(num_levels - 1, 1)

    if o_width > (dim + num_levels) and o_height > (dim + num_levels):
        # Negative steps walk the sizes down from original to final.
        height_step = int((o_height - f_height) / denom) * (-1)
        width_step = int((o_width - f_width) / denom) * (-1)
        height_list = list(range(o_height, f_height - 1, height_step))
        width_list = list(range(o_width, f_width - 1, width_step))
    else:
        # Image already close to `dim`: resize once, all levels same size.
        image = cv2.resize(image, (f_width, f_height),
                           interpolation=cv2.INTER_CUBIC)
        height_list = [f_height] * num_levels
        width_list = [f_width] * num_levels

    gaussian_pyramid = [image.copy()]
    laplacian_pyramid = []

    for i in range(num_levels - 1):
        # Blur the current level, then downsample to the next level's size.
        blur = cv2.GaussianBlur(gaussian_pyramid[i], (5, 5), 5)
        next_level = cv2.resize(blur, (width_list[i + 1], height_list[i + 1]),
                                interpolation=cv2.INTER_CUBIC)
        gaussian_pyramid.append(next_level)

        # Laplacian level = current level minus its blurred reconstruction.
        upsampled = cv2.resize(blur, (width_list[i], height_list[i]),
                               interpolation=cv2.INTER_CUBIC)
        laplacian = cv2.subtract(gaussian_pyramid[i], upsampled)
        laplacian_pyramid.append(laplacian)

    return laplacian_pyramid
|
|
|
|
def resizedpyramids(laplacian_pyramid, num_levels, width, height):
    '''Resize all levels of the Laplacian pyramid to the specified dimensions'''
    # num_levels is unused but kept for interface compatibility.
    resized = []
    for level in laplacian_pyramid:
        resized.append(cv2.resize(level, (width, height),
                                  interpolation=cv2.INTER_CUBIC))
    return resized
|
|
|
|
class spa_Dataset(data.Dataset):
    """Read data from the original dataset for feature extraction.

    Each item yields a stack of normalized Laplacian-pyramid levels for every
    frame of one video, plus the video name (stem without extension).
    """

    def __init__(self, vids_dir, num_levels=6, json_path=None):
        super(spa_Dataset, self).__init__()
        self.json_path = json_path
        if self.json_path is not None:
            with open(self.json_path, "r") as f:
                dt = json.load(f)
            self.vids_dir = [os.path.join(
                vids_dir, item["video"]) for item in dt]
            self.dt = dt
        else:
            self.vids_dir = glob.glob(f'{vids_dir}/*.mp4')

        self.num_levels = num_levels
        # Hoisted out of __getitem__: the transform is stateless, so there is
        # no need to rebuild it for every item.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])])

    def __len__(self):
        return len(self.vids_dir)

    def __getitem__(self, idx):
        vid_path = self.vids_dir[idx]
        if self.json_path is not None:
            vid_name = self.dt[idx]["video"][:-4]
        else:
            vid_name = vid_path.split('/')[-1][:-4]

        frames = decord.VideoReader(vid_path)
        final_vid = []
        # (Removed a leftover debug print(frames) from the original.)
        for img in frames:
            img = img.numpy()
            video_height = img.shape[0]
            video_width = img.shape[1]
            laplacian_pyramid = pyramidsGL(img, self.num_levels)

            # BUG FIX: resizedpyramids takes (..., width, height); the
            # original passed (height, width), transposing non-square frames.
            laplacian_pyramid_resized = resizedpyramids(
                laplacian_pyramid, self.num_levels, video_width, video_height)
            laplacian_pyramid_resized = np.array(laplacian_pyramid_resized)
            for frame in laplacian_pyramid_resized:
                # NOTE(review): decord normally yields RGB frames, so this
                # BGR->RGB swap may be redundant -- confirm against the rest
                # of the feature-extraction pipeline before changing it.
                lp = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                lp = self.transform(lp)
                final_vid.append(lp)
        final_vid = torch.stack(final_vid)
        return final_vid, vid_name
|
|
|
|
class Dataset1(data.Dataset):
    """Dataset for single-MOS benchmarks (GenAI / T2VQA).

    Each item returns the selected raw frames, temporal / spatial / t2i
    features, a 1-element MOS score tensor and the prompt string.
    """

    def __init__(self, datatype, name, cfg, seed=42, frame_num=8):
        super(Dataset1, self).__init__()

        with open(cfg["mos_file"], "r") as f:
            dt = json.load(f)

        # Split prompt indices 70/10/20 into train/val/test with a fixed seed.
        length = cfg["prompt_num"]
        random.seed(seed)
        np.random.seed(seed)
        random_idx = np.random.permutation(length)
        # BUG FIX: the original used two independent `if`s, so for "val" the
        # trailing `else` re-sliced the already-sliced array (harmless only
        # because the slice clamped).  `elif` makes the branches exclusive.
        if datatype == "val":
            random_idx = random_idx[int(length*0.7):int(length*0.8)]
        elif datatype == "test":
            random_idx = random_idx[int(length*0.8):]
        else:
            random_idx = random_idx[:int(length*0.7)]

        # Each prompt owns `loop` consecutive entries in the mos file.
        loop = 4 if name == "GenAI" else 10
        self.dt = []
        for idx in random_idx:
            base = idx * loop
            for i in range(loop):
                self.dt.append(dt[base + i])

        self.tem_prefix = cfg["tem_feat_dir"]
        self.spa_prefix = cfg["spa_feat_dir"]
        self.t2i_prefix = cfg["t2i_feat_dir"]
        self.vids_dir = cfg["vids_dir"]
        self.frame_num = frame_num
        self.datatype = datatype

    def __len__(self):
        return len(self.dt)

    def __getitem__(self, idx):
        item = self.dt[idx]
        prmt = item["prompt"]
        vid_path = item["video"]
        vid = decord.VideoReader(os.path.join(self.vids_dir, vid_path))
        frame_len = len(vid)

        score = torch.tensor([float(item["mos"])], dtype=torch.float32)

        if frame_len < self.frame_num:
            raise Exception('the vid is too short.')
        spa_feature = torch.from_numpy(
            np.load(os.path.join(self.spa_prefix, vid_path[:-4] + '.npy')))
        tem_feature = torch.from_numpy(
            np.load(os.path.join(self.tem_prefix, vid_path[:-4] + '.npy')))
        t2i_feature = torch.from_numpy(
            np.load(os.path.join(self.t2i_prefix, vid_path[:-4] + '.npy')))

        # Uniformly sample frame_num frame indices.
        interval = frame_len // self.frame_num
        select_idx = list(range(0, frame_len, interval))[:self.frame_num]
        final_imgs = vid.get_batch(select_idx).numpy()
        # NOTE(review): decord usually returns RGB already -- confirm this
        # conversion is intended before changing it.
        final_imgs = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                      for img in final_imgs]

        # Spatial feature: 5 consecutive rows per selected frame, flattened.
        final_spa = torch.stack(
            [spa_feature[i:i+5, :].view(-1) for i in select_idx])
        final_t2i = t2i_feature[select_idx, :, :].squeeze()

        # Temporal feature: average over up to 32 frames after each index.
        sample_len = min(interval, 32)
        final_tem = []
        for i in range(self.frame_num):
            start_idx = select_idx[i]
            end_idx = min(start_idx + sample_len, frame_len)
            sampled_tem = tem_feature[:, :, start_idx:end_idx, :, :]
            final_tem.append(sampled_tem.mean(dim=2).squeeze())
        final_tem = torch.stack(final_tem)

        return final_imgs, final_tem, final_spa, final_t2i, score, prmt
|
|
|
|
class Dataset3(data.Dataset):
    """Dataset for three-MOS benchmarks (FETV / LGVQ): temporal, spatial and
    alignment scores per video."""

    def __init__(self, datatype, name, cfg, seed=42, frame_num=8):
        super(Dataset3, self).__init__()

        with open(cfg["mos_file"], "r") as f:
            dt = json.load(f)

        # Split prompt indices 70/10/20 into train/val/test with a fixed seed.
        length = cfg["prompt_num"]
        random.seed(seed)
        np.random.seed(seed)
        random_idx = np.random.permutation(length)
        # BUG FIX: two stacked `if`s meant the trailing `else` also ran for
        # "val" (harmless only because the slice clamped).  `elif` instead.
        if datatype == "val":
            random_idx = random_idx[int(length*0.7):int(length*0.8)]
        elif datatype == "test":
            random_idx = random_idx[int(length*0.8):]
        else:
            random_idx = random_idx[:int(length*0.7)]

        # Each prompt owns `loop` consecutive entries in the mos file.
        loop = 4 if name == "FETV" else 6
        self.dt = []
        for idx in random_idx:
            base = idx * loop
            for i in range(loop):
                self.dt.append(dt[base + i])

        self.tem_prefix = cfg["tem_feat_dir"]
        self.spa_prefix = cfg["spa_feat_dir"]
        self.t2i_prefix = cfg["t2i_feat_dir"]
        self.vids_dir = cfg["vids_dir"]
        self.frame_num = frame_num

    def __len__(self):
        return len(self.dt)

    def __getitem__(self, idx):
        item = self.dt[idx]
        prmt = item["prompt"]
        vid_path = item["video"]
        vid = decord.VideoReader(os.path.join(self.vids_dir, vid_path))
        frame_len = len(vid)

        score = torch.tensor([float(item["tem"]),
                              float(item["spa"]),
                              float(item["ali"])], dtype=torch.float32)

        if frame_len < self.frame_num:
            raise Exception('the vid is too short.')
        spa_feature = torch.from_numpy(
            np.load(os.path.join(self.spa_prefix, vid_path[:-4] + '.npy')))
        tem_feature = torch.from_numpy(
            np.load(os.path.join(self.tem_prefix, vid_path[:-4] + '.npy')))
        t2i_feature = torch.from_numpy(
            np.load(os.path.join(self.t2i_prefix, vid_path[:-4] + '.npy')))

        # Uniformly sample frame_num frame indices.
        interval = frame_len // self.frame_num
        select_idx = list(range(0, frame_len, interval))[:self.frame_num]
        final_imgs = vid.get_batch(select_idx).numpy()
        # NOTE(review): decord usually returns RGB already -- confirm this
        # conversion is intended before changing it.
        final_imgs = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                      for img in final_imgs]

        # Spatial feature: 5 consecutive rows per selected frame, flattened.
        final_spa = torch.stack(
            [spa_feature[i:i+5, :].view(-1) for i in select_idx])
        final_t2i = t2i_feature[select_idx, :, :].squeeze()

        # Temporal feature: average over up to 32 frames after each index.
        sample_len = min(interval, 32)
        final_tem = []
        for i in range(self.frame_num):
            start_idx = select_idx[i]
            end_idx = min(start_idx + sample_len, frame_len)
            sampled_tem = tem_feature[:, :, start_idx:end_idx, :, :]
            final_tem.append(sampled_tem.mean(dim=2).squeeze())
        final_tem = torch.stack(final_tem)

        return final_imgs, final_tem, final_spa, final_t2i, score, prmt
| |
| |
class Dataset5(data.Dataset):
    """Dataset for five-MOS benchmarks (ICCV): overall, spatial, alignment,
    aesthetic and temporal scores per video.

    Per-frame spatial features are stored one ``.npy`` per frame under
    ``<spa_feat_dir>/<video-stem>/``.
    """

    def __init__(self, datatype, name, cfg, seed=42, frame_num=8):
        super(Dataset5, self).__init__()

        with open(cfg["mos_file"], "r") as f:
            dt = json.load(f)

        length = cfg["prompt_num"]
        random.seed(seed)
        np.random.seed(seed)
        random_idx = np.random.permutation(length)
        # BUG FIX: `elif` instead of two stacked `if`s (the original's else
        # also ran for "val", harmless only because the slice clamped).
        # NOTE(review): "val" and "test" share the same 20% slice here and
        # "train" keeps every prompt -- confirm this overlap is intentional.
        if datatype == "val":
            random_idx = random_idx[int(length*0.8):]
        elif datatype == "test":
            random_idx = random_idx[int(length*0.8):]
        else:
            random_idx = random_idx[:int(length*1.0)]

        # Each prompt owns `loop` consecutive entries in the mos file.
        loop = 5
        self.dt = []
        for idx in random_idx:
            base = idx * loop
            for i in range(loop):
                self.dt.append(dt[base + i])

        self.tem_prefix = cfg["tem_feat_dir"]
        self.spa_prefix = cfg["spa_feat_dir"]
        self.t2i_prefix = cfg["t2i_feat_dir"]
        self.vids_dir = cfg["vids_dir"]
        self.frame_num = frame_num

    def __len__(self):
        return len(self.dt)

    def __getitem__(self, idx):
        item = self.dt[idx]
        prmt = item["prompt"]
        vid_path = item["video"]
        vid = decord.VideoReader(os.path.join(self.vids_dir, vid_path))
        frame_len = len(vid)

        score = torch.tensor([float(item["overall_mos"]),
                              float(item["spatial_mos"]),
                              float(item["alignment_mos"]),
                              float(item["aesthetic_mos"]),
                              float(item["temporal_mos"])], dtype=torch.float32)

        if frame_len < self.frame_num:
            raise Exception('the vid is too short.')

        tem_feature = torch.from_numpy(
            np.load(os.path.join(self.tem_prefix, vid_path[:-4] + '.npy')))
        t2i_feature = torch.from_numpy(
            np.load(os.path.join(self.t2i_prefix, vid_path[:-4] + '.npy')))

        # Uniformly sample frame_num frame indices.
        interval = frame_len // self.frame_num
        select_idx = list(range(0, frame_len, interval))[:self.frame_num]

        final_imgs = []
        final_spa = []
        for i in select_idx:
            cur_frame = vid[i].numpy()
            # NOTE(review): decord usually returns RGB already -- confirm
            # this conversion is intended before changing it.
            cur_frame = cv2.cvtColor(cur_frame, cv2.COLOR_BGR2RGB)
            final_imgs.append(Image.fromarray(cur_frame))

            # Per-frame spatial feature: <spa_feat_dir>/<video-stem>/<i>.npy
            spatial_feat_name = os.path.join(self.spa_prefix, vid_path[:-4])
            cur_spa = torch.from_numpy(
                np.load(os.path.join(spatial_feat_name, f'{i}.npy'))).view(-1)
            final_spa.append(cur_spa)

        final_spa = torch.stack(final_spa)
        final_t2i = t2i_feature[select_idx, :, :].squeeze()

        # Temporal feature: average over up to 32 frames after each index.
        sample_len = min(interval, 32)
        final_tem = []
        for i in range(self.frame_num):
            start_idx = select_idx[i]
            end_idx = min(start_idx + sample_len, frame_len)
            sampled_tem = tem_feature[:, :, start_idx:end_idx, :, :]
            final_tem.append(sampled_tem.mean(dim=2).squeeze())
        final_tem = torch.stack(final_tem)

        return final_imgs, final_tem, final_spa, final_t2i, score, prmt
| |
|
|
|
|
def custom_collate_fn(batch):
    """Collate a batch of multi-dataset samples.

    Each sample is a list with one (vids, tem, spa, t2i, score, prompt)
    tuple per dataset; the result is one collated tuple per dataset, with
    tensor fields stacked and the remaining fields left as tuples.
    """
    ds_nums = len(batch[0])
    # Transpose: gather the i-th dataset's samples from every batch element.
    per_dataset = [[sample[i] for sample in batch] for i in range(ds_nums)]

    collated = []
    for group in per_dataset:
        vids, t, s, i, score, prmt = zip(*group)
        collated.append((vids, torch.stack(t), torch.stack(s),
                         torch.stack(i), torch.stack(score), prmt))
    return collated
| |
| |
|
|
def _build_loaders(dataset_cls, dataname, cfg, opt, seed):
    """Build (train, val, test) DataLoaders for one dataset class/config."""
    trainset = dataset_cls('train', dataname, cfg, seed=seed)
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=opt["train_batch_size"], shuffle=True,
        num_workers=opt["num_workers"], drop_last=True,
        collate_fn=custom_collate_fn)

    valset = dataset_cls('val', dataname, cfg, seed=seed)
    val_loader = torch.utils.data.DataLoader(
        valset, batch_size=1, shuffle=False, num_workers=0)

    testset = dataset_cls('test', dataname, cfg, seed=seed)
    test_loader = torch.utils.data.DataLoader(
        testset, batch_size=1, shuffle=False, num_workers=0)

    return train_loader, val_loader, test_loader


def get_dataset(opt, seed):
    """Select the dataset class matching the configured benchmark and return
    (train_loader, val_loader, test_loader)."""
    print("current seed is: ", seed)
    cfg = opt["dataset"]
    if cfg.get("GenAI") or cfg.get("T2VQA"):
        dataname = "GenAI" if cfg.get("GenAI") else "T2VQA"
        return _build_loaders(Dataset1, dataname, cfg[dataname], opt, seed)
    elif cfg.get("FETV") or cfg.get("LGVQ"):
        dataname = "FETV" if cfg.get("FETV") else "LGVQ"
        return _build_loaders(Dataset3, dataname, cfg[dataname], opt, seed)
    elif cfg.get("ICCV"):
        return _build_loaders(Dataset5, "ICCV", cfg["ICCV"], opt, seed)
    # BUG FIX: the original fell through to an UnboundLocalError on the
    # return statement when no known dataset key was configured.
    raise ValueError("no known dataset key in opt['dataset']")
|
|
|
|
def get_dataset_v2(opt, seed):
    """Build one joint train loader over all configured datasets plus
    per-dataset val and test loaders."""
    print("current seed is: ", seed)
    cfg = opt["dataset"]

    # ds_name=None makes VQAdataset wrap every configured dataset at once.
    trainset = VQAdataset(None, cfg, "train", seed)
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=opt["train_batch_size"], shuffle=True,
        num_workers=opt["num_workers"], drop_last=True,
        collate_fn=custom_collate_fn)

    val_loader = {
        key: torch.utils.data.DataLoader(
            VQAdataset(key, cfg, "val", seed), batch_size=1, shuffle=False,
            num_workers=opt["num_workers"], collate_fn=custom_collate_fn)
        for key in cfg}

    test_loader = {
        key: torch.utils.data.DataLoader(
            VQAdataset(key, cfg, "test", seed), batch_size=1, shuffle=False,
            num_workers=opt["num_workers"], collate_fn=custom_collate_fn)
        for key in cfg}

    return train_loader, val_loader, test_loader
|
|
|
|
|
|
class VQAdataset(data.Dataset):
    """Wrap one or all benchmark datasets behind a single Dataset interface.

    When ``ds_name`` is None, every dataset in ``args`` is wrapped;
    ``__getitem__`` then returns one sample per dataset, wrapping indices
    around datasets shorter than ``__len__``.
    """

    def __init__(self, ds_name, args, strategy, seed):
        super(VQAdataset, self).__init__()

        def _build(key):
            # Dataset class is chosen by the benchmark's number of MOS targets.
            mos_num = args[key]["mos_num"]
            if mos_num == 3:
                return Dataset3(strategy, key, args[key], seed)
            if mos_num == 1:
                return Dataset1(strategy, key, args[key], seed)
            return Dataset5(strategy, key, args[key], seed)

        # Idiom fix: `is None` instead of `== None`; the two duplicated
        # branches of the original are collapsed into one build path.
        names = args.keys() if ds_name is None else [ds_name]
        self.dataset = {key: _build(key) for key in names}

    def __len__(self):
        # Length of the longest wrapped dataset.
        return max(len(v) for k, v in self.dataset.items())

    def __getitem__(self, idx):
        return_list = []
        for key, dataset in self.dataset.items():
            length = len(dataset)
            return_list.append(dataset[idx % length])
        return return_list
|
|
| |
|
|
def custom_collate_fn2(batch):
    """Collate multi-dataset samples that carry two spatial features and the
    video path (submission variant of custom_collate_fn)."""
    ds_nums = len(batch[0])
    # Transpose: gather the i-th dataset's samples from every batch element.
    per_dataset = [[sample[i] for sample in batch] for i in range(ds_nums)]

    collated = []
    for group in per_dataset:
        vids, t, s, s2, i, score, prmt, vid_path = zip(*group)
        collated.append((vids, torch.stack(t), torch.stack(s),
                         torch.stack(s2), torch.stack(i),
                         torch.stack(score), prmt, vid_path))
    return collated
|
|
| |
class Dataset5_submission(data.Dataset):
    """Submission-time variant of Dataset5: iterates every prompt exactly
    once, loads a second spatial feature and also returns the video path."""

    def __init__(self, datatype, name, cfg, seed=42, frame_num=8):
        super(Dataset5_submission, self).__init__()

        with open(cfg["mos_file"], "r") as f:
            dt = json.load(f)

        # No train/val split at submission time: every prompt, in order.
        length = cfg["prompt_num"]
        index_rd = np.arange(0, length)
        test_index = index_rd

        loop = 1  # one video per prompt in the submission mos file
        self.dt = []
        for base in test_index:
            for offset in range(loop):
                self.dt.append(dt[base * loop + offset])

        self.tem_prefix = cfg["tem_feat_dir"]
        self.spa_prefix = cfg["spa_feat_dir"]
        self.spa2_prefix = cfg["spa_feat_dir2"]
        self.t2i_prefix = cfg["t2i_feat_dir"]
        self.vids_dir = cfg["vids_dir"]
        self.frame_num = frame_num

    def __len__(self):
        return len(self.dt)

    def __getitem__(self, idx):
        entry = self.dt[idx]
        prmt = entry["prompt"]
        vid_path = entry["video"]
        vid = decord.VideoReader(os.path.join(self.vids_dir, vid_path))
        frame_len = len(vid)

        score = torch.tensor([float(entry["overall_mos"]),
                              float(entry["spatial_mos"]),
                              float(entry["alignment_mos"]),
                              float(entry["aesthetic_mos"]),
                              float(entry["temporal_mos"])],
                             dtype=torch.float32)

        if frame_len < self.frame_num:
            raise Exception('the vid is too short.')

        stem = vid_path[:-4]
        tem_feature = torch.from_numpy(
            np.load(os.path.join(self.tem_prefix, stem + '.npy')))
        t2i_feature = torch.from_numpy(
            np.load(os.path.join(self.t2i_prefix, stem + '.npy')))

        # Uniformly spaced frame indices.
        interval = frame_len // self.frame_num
        select_idx = list(range(0, frame_len, interval))[:self.frame_num]

        final_imgs, final_spa, final_spa2 = [], [], []
        for frame_idx in select_idx:
            rgb = cv2.cvtColor(vid[frame_idx].numpy(), cv2.COLOR_BGR2RGB)
            final_imgs.append(Image.fromarray(rgb))

            # Per-frame spatial features: <prefix>/<video-stem>/<idx>.npy
            final_spa.append(torch.from_numpy(np.load(os.path.join(
                self.spa_prefix, stem, f'{frame_idx}.npy'))).view(-1))
            final_spa2.append(torch.from_numpy(np.load(os.path.join(
                self.spa2_prefix, stem, f'{frame_idx}.npy'))).view(-1))

        final_spa = torch.stack(final_spa)
        final_spa2 = torch.stack(final_spa2)

        final_t2i = t2i_feature[select_idx, :, :].squeeze()

        # Temporal feature: mean over up to 32 frames following each index.
        sample_len = min(interval, 32)
        final_tem = []
        for pos in range(self.frame_num):
            start_idx = select_idx[pos]
            end_idx = min(start_idx + sample_len, frame_len)
            chunk = tem_feature[:, :, start_idx:end_idx, :, :]
            final_tem.append(chunk.mean(dim=2).squeeze())
        final_tem = torch.stack(final_tem)

        return (final_imgs, final_tem, final_spa, final_spa2,
                final_t2i, score, prmt, vid_path)
| |
| |
def get_dataset_v3(opt, seed):
    """Build per-dataset submission val loaders (no train/test loaders)."""
    print("current seed is: ", seed)
    cfg = opt["dataset"]
    val_loader = {
        key: torch.utils.data.DataLoader(
            VQAdataset_submission(key, cfg, "val", seed),
            batch_size=1, shuffle=False, num_workers=opt["num_workers"],
            collate_fn=custom_collate_fn2)
        for key in cfg}
    return val_loader
| |
class VQAdataset_submission(data.Dataset):
    """Single-dataset wrapper used at submission time (mirrors VQAdataset)."""

    def __init__(self, ds_name, args, strategy, seed):
        super(VQAdataset_submission, self).__init__()
        self.dataset = {
            ds_name: Dataset5_submission(strategy, ds_name,
                                         args[ds_name], seed)}

    def __len__(self):
        # Length of the longest wrapped dataset (only one here).
        return max(len(v) for k, v in self.dataset.items())

    def __getitem__(self, idx):
        samples = []
        for key, dataset in self.dataset.items():
            samples.append(dataset[idx % len(dataset)])
        return samples
|
|