# Source: VQualA_GenAI_track2/feat_extract/all_dataloader.py
# Hugging Face upload header (zwx8981, "Upload 493 files", commit a6bc892 verified)
# -- commented out so the module parses as valid Python.
import torch
from torch.utils import data
import numpy as np
from torchvision import transforms
import cv2
import glob
import json
import os
from PIL import Image
import random
import decord
# Make decord return torch tensors instead of its native NDArray, so the
# `.numpy()` / `vid.get_batch(...).numpy()` calls in the datasets below work.
decord.bridge.set_bridge('torch')
class t2i_Dataset(data.Dataset):
    """Yields (frames, video_name, prompt) triples for text-to-image feature extraction.

    Every frame of each video is read with OpenCV, converted BGR -> RGB,
    passed through ``vis_processor["eval"]``, stacked, cast to half precision,
    and padded (by repeating the last frame) so the clip length is a
    multiple of 8.
    """

    def __init__(self, vids_dir, json_path, vis_processor):
        # BUG FIX: original called super(t2i_Dataset).__init__(), which builds an
        # unbound super object and never runs data.Dataset.__init__.
        super(t2i_Dataset, self).__init__()
        with open(json_path, "r") as f:
            self.dt = json.load(f)  # list of {"prompt": ..., "video": ...}
        self.vids_dir = vids_dir
        self.vis_processor = vis_processor

    def __len__(self):
        return len(self.dt)

    @staticmethod
    def _pad_to_multiple(frames, multiple=8):
        """Pad a (T, ...) tensor along dim 0 to a multiple of `multiple` by repeating the last frame."""
        remainder = frames.shape[0] % multiple
        if remainder == 0:
            return frames
        pad = frames[-1:].expand(multiple - remainder, *frames.shape[1:])
        return torch.cat([frames, pad], dim=0)

    def __getitem__(self, idx):
        item = self.dt[idx]
        prompt = item["prompt"]
        filename = os.path.join(self.vids_dir, item["video"])
        vid_nm = filename.split('/')[-1][:-4]  # basename without ".mp4"
        cap = cv2.VideoCapture(filename)
        video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_frame_rate = int(round(cap.get(cv2.CAP_PROP_FPS)))
        if video_frame_rate == 0:
            cap.release()  # release the capture before bailing out
            raise Exception('no frame detect')
        transformed_frame_all = []
        for _ in range(video_length):
            has_frames, frame = cap.read()
            if not has_frames:
                cap.release()
                raise Exception('no frame detect')
            read_frame = Image.fromarray(
                cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            transformed_frame_all.append(self.vis_processor["eval"](read_frame))
        cap.release()
        transformed_frame_all = torch.stack(transformed_frame_all).half()
        # BUG FIX: the original wrote into transformed_frame_all[video_length:],
        # an EMPTY slice of the just-stacked (video_length, ...) tensor, so the
        # intended round-up-to-multiple-of-8 padding never happened. Pad by
        # repeating the last frame instead.
        transformed_frame_all = self._pad_to_multiple(transformed_frame_all, 8)
        return transformed_frame_all, vid_nm, prompt
class tem_Dataset(data.Dataset):
    """Read data from the original dataset for temporal feature extraction.

    Videos come either from a json manifest (``json_path``) or from globbing
    ``data_dir`` for .mp4 files. Each item is a (frames, video_name) pair where
    frames is zero-initialized at a length rounded up to a multiple of 8, and
    the tail is filled with copies of the last real frame.
    """

    def __init__(self,
                 data_dir, transform, json_path=None, resize=224):
        super(tem_Dataset, self).__init__()
        if json_path is None:
            # no manifest: pick up every mp4 in the directory
            self.video_names = glob.glob(f'{data_dir}/*.mp4')
        else:
            with open(json_path, "r") as f:
                self.dt = json.load(f)
            self.video_names = [
                os.path.join(data_dir, entry["video"]) for entry in self.dt
            ]
        self.json_path = json_path
        self.videos_dir = data_dir
        self.transform = transform
        self.resize = resize

    def __len__(self):
        return len(self.video_names)

    def __getitem__(self, idx):
        filename = self.video_names[idx]
        # video name without the ".mp4" suffix
        if self.json_path is None:
            vid_nm = filename.split('/')[-1][:-4]
        else:
            vid_nm = self.dt[idx]["video"][:-4]
        cap = cv2.VideoCapture(filename)
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = int(round(cap.get(cv2.CAP_PROP_FPS)))
        # round the clip length up to the next multiple of 8
        padded = total if total % 8 == 0 else (total // 8 + 1) * 8
        if fps == 0:
            raise Exception('no frame detect')
        frames = torch.zeros([padded, 3, self.resize, self.resize])
        for pos in range(total):
            ok, raw = cap.read()
            if not ok:
                raise Exception('no frame detect')
            rgb = Image.fromarray(cv2.cvtColor(raw, cv2.COLOR_BGR2RGB))
            frames[pos] = self.transform(rgb)
        cap.release()
        if padded != total:
            # replicate the last real frame into the zero-padded tail
            frames[total:, :] = frames[total - 1, :]
        return frames, vid_nm
def pyramidsGL(image, num_levels, dim=224):
    ''' Creates Gaussian (G) and Laplacian (L) pyramids of level "num_levels" from image im.
    G and L are list where G[i], L[i] stores the i-th level of Gaussian and Laplacian pyramid, respectively. '''
    # NOTE(review): despite the docstring, only the Laplacian pyramid is
    # returned; the Gaussian pyramid is built internally and discarded.
    # Assumes `image` is an H x W (x C) numpy array in cv2 convention
    # (shape[0]=rows, shape[1]=cols) -- TODO confirm with callers.
    o_height, o_width = image.shape[:2]
    # Calculate target dimensions while maintaining aspect ratio
    if o_width > o_height:
        f_height = dim
        f_width = int((o_width * f_height) / o_height)
    elif o_height > o_width:
        f_width = dim
        f_height = int((o_height * f_width) / o_width)
    else:
        f_width = f_height = dim
    # Generate dimension lists for each level
    if o_width > (dim + num_levels) and o_height > (dim + num_levels):
        # Evenly spaced (integer-truncated) sizes from original down to target.
        # NOTE(review): because the step is floor-divided, these range() lists
        # can contain MORE than num_levels entries (never fewer, given the
        # guard above); only the first num_levels are consumed by the loop
        # below, and the smallest size reached may be slightly above f_*.
        height_step = int((o_height - f_height) / (num_levels - 1)) * (-1)
        width_step = int((o_width - f_width) / (num_levels - 1)) * (-1)
        height_list = list(range(o_height, f_height - 1, height_step))
        width_list = list(range(o_width, f_width - 1, width_step))
    else:
        # If dimensions are already close to target, just resize once
        image = cv2.resize(image, (f_width, f_height),
                           interpolation=cv2.INTER_CUBIC)
        height_list = [f_height] * num_levels
        width_list = [f_width] * num_levels
    # Initialize pyramids
    gaussian_pyramid = [image.copy()]
    laplacian_pyramid = []
    # Create pyramids
    for i in range(num_levels - 1):
        # Apply Gaussian blur
        blur = cv2.GaussianBlur(gaussian_pyramid[i], (5, 5), 5)
        # Create next level
        next_level = cv2.resize(blur, (width_list[i + 1], height_list[i + 1]),
                                interpolation=cv2.INTER_CUBIC)
        gaussian_pyramid.append(next_level)
        # Create Laplacian level: difference between the current Gaussian level
        # and its blurred version resized back to the current level's size, so
        # each Laplacian level keeps that level's dimensions (callers are
        # expected to normalize sizes later, e.g. via resizedpyramids).
        upsampled = cv2.resize(blur, (width_list[i], height_list[i]),
                               interpolation=cv2.INTER_CUBIC)
        laplacian = cv2.subtract(gaussian_pyramid[i], upsampled)
        laplacian_pyramid.append(laplacian)
    # Returns num_levels - 1 Laplacian levels (one per adjacent level pair).
    return laplacian_pyramid
def resizedpyramids(laplacian_pyramid, num_levels, width, height):
    '''Resize all levels of the Laplacian pyramid to the specified dimensions'''
    # (num_levels is kept for interface compatibility; every level is resized)
    resized = []
    for level in laplacian_pyramid:
        resized.append(
            cv2.resize(level, (width, height), interpolation=cv2.INTER_CUBIC))
    return resized
class spa_Dataset(data.Dataset):
    """Read data from the original dataset for spatial (Laplacian-pyramid) feature extraction.

    For each video, every frame is expanded into ``num_levels - 1`` Laplacian
    levels (resized back to the frame's dimensions), normalized with ImageNet
    statistics, and stacked into a single tensor.
    """

    def __init__(self, vids_dir, num_levels=6, json_path=None):
        super(spa_Dataset, self).__init__()
        self.json_path = json_path
        if self.json_path is not None:
            with open(self.json_path, "r") as f:
                dt = json.load(f)
            self.vids_dir = [os.path.join(
                vids_dir, item["video"]) for item in dt]
            self.dt = dt
        else:
            self.vids_dir = glob.glob(f'{vids_dir}/*.mp4')
        self.num_levels = num_levels
        # Hoisted out of __getitem__: the transform is stateless, so build it
        # once instead of once per item.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.vids_dir)

    def __getitem__(self, idx):
        vid_path = self.vids_dir[idx]
        if self.json_path is not None:
            vid_name = self.dt[idx]["video"][:-4]
        else:
            vid_name = vid_path.split('/')[-1][:-4]
        frames = decord.VideoReader(vid_path)
        final_vid = []
        for img in frames:
            img = img.numpy()
            video_height = img.shape[0]
            video_width = img.shape[1]
            laplacian_pyramid = pyramidsGL(img, self.num_levels)
            # BUG FIX: resizedpyramids' parameters are (width, height) and feed
            # cv2.resize's dsize=(cols, rows); the original passed
            # (video_height, video_width), transposing non-square frames.
            laplacian_pyramid_resized = resizedpyramids(
                laplacian_pyramid, self.num_levels, video_width, video_height)
            laplacian_pyramid_resized = np.array(laplacian_pyramid_resized)
            for frame in laplacian_pyramid_resized:
                lp = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                final_vid.append(self.transform(lp))
        final_vid = torch.stack(final_vid)
        return final_vid, vid_name
class Dataset1(data.Dataset):
    """Single-MOS dataset (GenAI / T2VQA).

    Splits prompts 70/10/20 into train/val/test with a seeded permutation,
    then expands each selected prompt into its ``loop`` consecutive videos in
    the mos file. Each item yields sampled frames plus precomputed temporal,
    spatial and text-to-image features and one MOS score.
    """

    def __init__(self, datatype, name, cfg, seed=42, frame_num=8):
        super(Dataset1, self).__init__()
        with open(cfg["mos_file"], "r") as f:
            dt = json.load(f)
        length = cfg["prompt_num"]
        random.seed(seed)
        np.random.seed(seed)
        random_idx = np.random.permutation(length)
        # BUG FIX: the original used two independent `if`s, so datatype=="val"
        # also fell through into the trailing else (the train slice). It was
        # harmless only because the val slice is shorter than 70% of length;
        # `elif` makes the split explicit and robust.
        if datatype == "val":
            random_idx = random_idx[int(length*0.7):int(length*0.8)]
        elif datatype == "test":
            random_idx = random_idx[int(length*0.8):]
        else:
            random_idx = random_idx[:int(length*0.7)]
        # each prompt owns `loop` consecutive entries in the mos file
        loop = 4 if name == "GenAI" else 10
        self.dt = []
        for idx in random_idx:
            base = idx * loop
            for offset in range(loop):
                self.dt.append(dt[base + offset])
        self.tem_prefix = cfg["tem_feat_dir"]
        self.spa_prefix = cfg["spa_feat_dir"]
        self.t2i_prefix = cfg["t2i_feat_dir"]
        self.vids_dir = cfg["vids_dir"]
        self.frame_num = frame_num
        self.datatype = datatype

    def __len__(self):
        return len(self.dt)

    def __getitem__(self, idx):
        item = self.dt[idx]
        prmt = item["prompt"]
        vid_path = item["video"]
        vid = decord.VideoReader(os.path.join(self.vids_dir, vid_path))
        frame_len = len(vid)
        score = torch.tensor([float(item["mos"])], dtype=torch.float32)
        if frame_len < self.frame_num:
            raise Exception('the vid is too short.')
        stem = vid_path[:-4]  # drop ".mp4"
        spa_feature = torch.from_numpy(
            np.load(os.path.join(self.spa_prefix, stem + '.npy')))
        tem_feature = torch.from_numpy(
            np.load(os.path.join(self.tem_prefix, stem + '.npy')))
        t2i_feature = torch.from_numpy(
            np.load(os.path.join(self.t2i_prefix, stem + '.npy')))
        interval = frame_len // self.frame_num  # >= 1 due to the length check
        select_idx = list(range(0, frame_len, interval))[:self.frame_num]
        final_imgs = vid.get_batch(select_idx).numpy()
        final_imgs = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                      for img in final_imgs]
        # spatial feat was extracted with a laplacian pyramid of level 5
        final_spa = torch.stack(
            [spa_feature[i:i+5, :].view(-1) for i in select_idx])
        final_t2i = t2i_feature[select_idx, :, :].squeeze()
        # temporal feature is averaged over a window of up to 32 frames
        # following each sampled frame
        sample_len = min(interval, 32)
        final_tem = []
        for start_idx in select_idx:
            end_idx = min(start_idx + sample_len, frame_len)
            sampled_tem = tem_feature[:, :, start_idx:end_idx, :, :]
            final_tem.append(sampled_tem.mean(dim=2).squeeze())
        final_tem = torch.stack(final_tem)
        return final_imgs, final_tem, final_spa, final_t2i, score, prmt
class Dataset3(data.Dataset):
    """Three-MOS dataset (FETV / LGVQ): like Dataset1 but the score tensor
    carries temporal, spatial and alignment MOS values.
    """

    def __init__(self, datatype, name, cfg, seed=42, frame_num=8):
        super(Dataset3, self).__init__()
        with open(cfg["mos_file"], "r") as f:
            dt = json.load(f)
        length = cfg["prompt_num"]
        random.seed(seed)
        np.random.seed(seed)
        random_idx = np.random.permutation(length)
        # BUG FIX: same independent-if fall-through as Dataset1 -- the "val"
        # case also executed the trailing train else. `elif` fixes the split.
        if datatype == "val":
            random_idx = random_idx[int(length*0.7):int(length*0.8)]
        elif datatype == "test":
            random_idx = random_idx[int(length*0.8):]
        else:
            random_idx = random_idx[:int(length*0.7)]
        # each prompt owns `loop` consecutive entries in the mos file
        loop = 4 if name == "FETV" else 6
        self.dt = []
        for idx in random_idx:
            base = idx * loop
            for offset in range(loop):
                self.dt.append(dt[base + offset])
        self.tem_prefix = cfg["tem_feat_dir"]
        self.spa_prefix = cfg["spa_feat_dir"]
        self.t2i_prefix = cfg["t2i_feat_dir"]
        self.vids_dir = cfg["vids_dir"]
        self.frame_num = frame_num

    def __len__(self):
        return len(self.dt)

    def __getitem__(self, idx):
        item = self.dt[idx]
        prmt = item["prompt"]
        vid_path = item["video"]
        vid = decord.VideoReader(os.path.join(self.vids_dir, vid_path))
        frame_len = len(vid)
        # three targets: temporal, spatial, alignment
        score = torch.tensor([float(item["tem"]),
                              float(item["spa"]),
                              float(item["ali"])], dtype=torch.float32)
        if frame_len < self.frame_num:
            raise Exception('the vid is too short.')
        stem = vid_path[:-4]  # drop ".mp4"
        spa_feature = torch.from_numpy(
            np.load(os.path.join(self.spa_prefix, stem + '.npy')))
        tem_feature = torch.from_numpy(
            np.load(os.path.join(self.tem_prefix, stem + '.npy')))
        t2i_feature = torch.from_numpy(
            np.load(os.path.join(self.t2i_prefix, stem + '.npy')))
        interval = frame_len // self.frame_num  # >= 1 due to the length check
        select_idx = list(range(0, frame_len, interval))[:self.frame_num]
        final_imgs = vid.get_batch(select_idx).numpy()
        final_imgs = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                      for img in final_imgs]
        # spatial feat was extracted with a laplacian pyramid of level 5
        final_spa = torch.stack(
            [spa_feature[i:i+5, :].view(-1) for i in select_idx])
        final_t2i = t2i_feature[select_idx, :, :].squeeze()
        # temporal feature is averaged over a window of up to 32 frames
        sample_len = min(interval, 32)
        final_tem = []
        for start_idx in select_idx:
            end_idx = min(start_idx + sample_len, frame_len)
            sampled_tem = tem_feature[:, :, start_idx:end_idx, :, :]
            final_tem.append(sampled_tem.mean(dim=2).squeeze())
        final_tem = torch.stack(final_tem)
        return final_imgs, final_tem, final_spa, final_t2i, score, prmt
class Dataset5(data.Dataset):
    """Five-MOS dataset (ICCV): overall/spatial/alignment/aesthetic/temporal
    scores, with per-frame spatial features loaded from one .npy per sampled
    frame (unlike Dataset1/Dataset3 which load one file per video).
    """

    def __init__(self, datatype, name, cfg, seed=42, frame_num=8):
        super(Dataset5, self).__init__()
        with open(cfg["mos_file"], "r") as f:
            dt = json.load(f)
        length = cfg["prompt_num"]
        random.seed(seed)
        np.random.seed(seed)
        random_idx = np.random.permutation(length)
        # BUG FIX (latent): the original used two independent `if`s, so the
        # "val" case also executed the trailing else; harmless only because
        # the second slice kept everything. NOTE: val and test intentionally
        # share the last-20% slice here, and train uses ALL prompts (the 0.8
        # train cut was deliberately disabled upstream).
        if datatype == "val":
            random_idx = random_idx[int(length*0.8):]
        elif datatype == "test":
            random_idx = random_idx[int(length*0.8):]
        else:
            random_idx = random_idx[:int(length*1.0)]
        # each prompt owns 5 consecutive entries in the mos file
        loop = 5
        self.dt = []
        for idx in random_idx:
            base = idx * loop
            for offset in range(loop):
                self.dt.append(dt[base + offset])
        self.tem_prefix = cfg["tem_feat_dir"]
        self.spa_prefix = cfg["spa_feat_dir"]
        self.t2i_prefix = cfg["t2i_feat_dir"]
        self.vids_dir = cfg["vids_dir"]
        self.frame_num = frame_num

    def __len__(self):
        return len(self.dt)

    def __getitem__(self, idx):
        item = self.dt[idx]
        prmt = item["prompt"]
        vid_path = item["video"]
        vid = decord.VideoReader(os.path.join(self.vids_dir, vid_path))
        frame_len = len(vid)
        score = torch.tensor([float(item["overall_mos"]),
                              float(item["spatial_mos"]),
                              float(item["alignment_mos"]),
                              float(item["aesthetic_mos"]),
                              float(item["temporal_mos"])], dtype=torch.float32)
        if frame_len < self.frame_num:
            raise Exception('the vid is too short.')
        stem = vid_path[:-4]  # drop ".mp4"
        tem_feature = torch.from_numpy(
            np.load(os.path.join(self.tem_prefix, stem + '.npy')))
        t2i_feature = torch.from_numpy(
            np.load(os.path.join(self.t2i_prefix, stem + '.npy')))
        interval = frame_len // self.frame_num  # >= 1 due to the length check
        select_idx = list(range(0, frame_len, interval))[:self.frame_num]
        final_imgs = []  # list of PIL images (self.frame_num entries)
        final_spa = []
        spa_dir = os.path.join(self.spa_prefix, stem)
        for i in select_idx:
            frame = cv2.cvtColor(vid[i].numpy(), cv2.COLOR_BGR2RGB)
            final_imgs.append(Image.fromarray(frame))
            # per-frame spatial feature: one flattened .npy per frame index
            final_spa.append(torch.from_numpy(
                np.load(os.path.join(spa_dir, f'{i}.npy'))).view(-1))
        final_spa = torch.stack(final_spa)  # self.frame_num x (levels*dim)
        final_t2i = t2i_feature[select_idx, :, :].squeeze()
        # temporal feature is averaged over a window of up to 32 frames
        sample_len = min(interval, 32)
        final_tem = []
        for start_idx in select_idx:
            end_idx = min(start_idx + sample_len, frame_len)
            sampled_tem = tem_feature[:, :, start_idx:end_idx, :, :]
            final_tem.append(sampled_tem.mean(dim=2).squeeze())
        final_tem = torch.stack(final_tem)
        return final_imgs, final_tem, final_spa, final_t2i, score, prmt
def custom_collate_fn(batch):
    """Collate multi-dataset samples: batch[sample][dataset] is a 6-tuple
    (vids, tem, spa, t2i, score, prompt); returns one per-dataset tuple with
    the tensor fields stacked along a new batch dimension."""
    collated = []
    # zip(*batch) transposes sample-major into dataset-major order
    for dataset_items in zip(*batch):
        vids, t, s, i, score, prmt = zip(*dataset_items)
        collated.append((vids,
                         torch.stack(t),
                         torch.stack(s),
                         torch.stack(i),
                         torch.stack(score),
                         prmt))
    return collated
def get_dataset(opt, seed):
    """Build (train, val, test) DataLoaders for whichever dataset key is
    present in opt["dataset"].

    Raises ValueError if no recognized dataset key is configured.
    """
    print("current seed is: ", seed)
    cfg = opt["dataset"]
    # map the configured dataset family to its Dataset class
    if cfg.get("GenAI") or cfg.get("T2VQA"):
        dataname = "GenAI" if cfg.get("GenAI") else "T2VQA"
        ds_cls = Dataset1
    elif cfg.get("FETV") or cfg.get("LGVQ"):
        dataname = "FETV" if cfg.get("FETV") else "LGVQ"
        ds_cls = Dataset3
    elif cfg.get("ICCV"):
        dataname = "ICCV"
        ds_cls = Dataset5
    else:
        # BUG FIX: the original fell through with no loaders defined and
        # raised NameError at the return statement.
        raise ValueError("no known dataset key in opt['dataset']")
    cfg = cfg[dataname]
    trainset = ds_cls('train', dataname, cfg, seed=seed)
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=opt["train_batch_size"], shuffle=True,
        num_workers=opt["num_workers"], drop_last=True,
        collate_fn=custom_collate_fn)
    valset = ds_cls('val', dataname, cfg, seed=seed)
    val_loader = torch.utils.data.DataLoader(
        valset, batch_size=1, shuffle=False, num_workers=0)
    testset = ds_cls('test', dataname, cfg, seed=seed)
    test_loader = torch.utils.data.DataLoader(
        testset, batch_size=1, shuffle=False, num_workers=0)
    return train_loader, val_loader, test_loader
def get_dataset_v2(opt, seed):
    """Build one mixed-dataset train loader plus per-dataset val/test loaders."""
    print("current seed is: ", seed)
    cfg = opt["dataset"]
    # training: a single VQAdataset bundling every configured dataset
    train_loader = torch.utils.data.DataLoader(
        VQAdataset(None, cfg, "train", seed),
        batch_size=opt["train_batch_size"], shuffle=True,
        num_workers=opt["num_workers"], drop_last=True,
        collate_fn=custom_collate_fn)
    # evaluation: one loader per dataset key
    val_loader = dict()
    test_loader = dict()
    for key in cfg:
        val_loader[key] = torch.utils.data.DataLoader(
            VQAdataset(key, cfg, "val", seed),
            batch_size=1, shuffle=False, num_workers=opt["num_workers"],
            collate_fn=custom_collate_fn)
    for key in cfg:
        test_loader[key] = torch.utils.data.DataLoader(
            VQAdataset(key, cfg, "test", seed),
            batch_size=1, shuffle=False, num_workers=opt["num_workers"],
            collate_fn=custom_collate_fn)
    return train_loader, val_loader, test_loader
class VQAdataset(data.Dataset):
    """Wraps one or several per-benchmark datasets.

    Indexing yields a list with one sample per wrapped dataset; the index
    wraps modulo each sub-dataset's length, so shorter datasets repeat.
    """

    def __init__(self, ds_name, args, strategy, seed):
        super(VQAdataset, self).__init__()
        dataset = dict()
        # BUG FIX (idiom): original compared `ds_name == None`; use `is None`.
        # ds_name None -> bundle every configured dataset; otherwise just one.
        names = list(args) if ds_name is None else [ds_name]
        for key in names:
            dataset[key] = self._build(strategy, key, args[key], seed)
        self.dataset = dataset

    @staticmethod
    def _build(strategy, key, cfg, seed):
        """Pick the Dataset class from the number of MOS targets in the config."""
        if cfg["mos_num"] == 3:
            return Dataset3(strategy, key, cfg, seed)
        if cfg["mos_num"] == 1:
            return Dataset1(strategy, key, cfg, seed)
        return Dataset5(strategy, key, cfg, seed)

    def __len__(self):
        # length of the longest sub-dataset; shorter ones wrap in __getitem__
        return max(len(v) for k, v in self.dataset.items())

    def __getitem__(self, idx):
        return_list = []
        for key, dataset in self.dataset.items():
            return_list.append(dataset[idx % len(dataset)])
        return return_list
#############################
def custom_collate_fn2(batch):
    """Collate variant for 8-tuple samples (two spatial feature streams plus
    the video path); mirrors custom_collate_fn otherwise."""
    collated = []
    # zip(*batch) transposes sample-major into dataset-major order
    for dataset_items in zip(*batch):
        vids, t, s, s2, i, score, prmt, vid_path = zip(*dataset_items)
        collated.append((vids,
                         torch.stack(t),
                         torch.stack(s),
                         torch.stack(s2),
                         torch.stack(i),
                         torch.stack(score),
                         prmt,
                         vid_path))
    return collated
class Dataset5_submission(data.Dataset):
    """Submission-time variant of Dataset5: uses the whole mos file (no split)
    and returns a second spatial feature stream plus the video path."""

    def __init__(self, datatype, name, cfg, seed=42, frame_num=8):
        super(Dataset5_submission, self).__init__()
        with open(cfg["mos_file"], "r") as f:
            dt = json.load(f)
        length = cfg["prompt_num"]
        # no train/val/test split at submission time: one entry per prompt
        self.dt = [dt[entry_id] for entry_id in range(length)]
        self.tem_prefix = cfg["tem_feat_dir"]
        self.spa_prefix = cfg["spa_feat_dir"]
        self.spa2_prefix = cfg["spa_feat_dir2"]
        self.t2i_prefix = cfg["t2i_feat_dir"]
        self.vids_dir = cfg["vids_dir"]
        self.frame_num = frame_num

    def __len__(self):
        return len(self.dt)

    def __getitem__(self, idx):
        item = self.dt[idx]
        prmt = item["prompt"]
        vid_path = item["video"]
        reader = decord.VideoReader(os.path.join(self.vids_dir, vid_path))
        frame_len = len(reader)
        score = torch.tensor([float(item["overall_mos"]),
                              float(item["spatial_mos"]),
                              float(item["alignment_mos"]),
                              float(item["aesthetic_mos"]),
                              float(item["temporal_mos"])],
                             dtype=torch.float32)
        if frame_len < self.frame_num:
            raise Exception('the vid is too short.')
        stem = vid_path[:-4]  # drop ".mp4"
        tem_feature = torch.from_numpy(
            np.load(os.path.join(self.tem_prefix, stem + '.npy')))
        t2i_feature = torch.from_numpy(
            np.load(os.path.join(self.t2i_prefix, stem + '.npy')))
        interval = frame_len // self.frame_num
        select_idx = list(range(0, frame_len, interval))[:self.frame_num]
        final_imgs = []   # PIL images of the sampled frames
        final_spa = []    # per-frame spatial features, stream 1
        final_spa2 = []   # per-frame spatial features, stream 2
        spa_dir = os.path.join(self.spa_prefix, stem)
        spa2_dir = os.path.join(self.spa2_prefix, stem)
        for frame_id in select_idx:
            rgb = cv2.cvtColor(reader[frame_id].numpy(), cv2.COLOR_BGR2RGB)
            final_imgs.append(Image.fromarray(rgb))
            final_spa.append(torch.from_numpy(
                np.load(os.path.join(spa_dir, f'{frame_id}.npy'))).view(-1))
            final_spa2.append(torch.from_numpy(
                np.load(os.path.join(spa2_dir, f'{frame_id}.npy'))).view(-1))
        final_spa = torch.stack(final_spa)
        final_spa2 = torch.stack(final_spa2)
        final_t2i = t2i_feature[select_idx, :, :].squeeze()
        # temporal feature is averaged over a window of up to 32 frames
        sample_len = interval if interval < 32 else 32
        final_tem = []
        for start in select_idx:
            end = min(start + sample_len, frame_len)
            final_tem.append(
                tem_feature[:, :, start:end, :, :].mean(dim=2).squeeze())
        final_tem = torch.stack(final_tem)
        return (final_imgs, final_tem, final_spa, final_spa2,
                final_t2i, score, prmt, vid_path)
def get_dataset_v3(opt, seed):
    """Build one val DataLoader per dataset key for submission/inference."""
    print("current seed is: ", seed)
    cfg = opt["dataset"]
    loaders = dict()
    for key in cfg:
        loaders[key] = torch.utils.data.DataLoader(
            VQAdataset_submission(key, cfg, "val", seed),
            batch_size=1, shuffle=False,
            num_workers=opt["num_workers"],
            collate_fn=custom_collate_fn2)
    return loaders
class VQAdataset_submission(data.Dataset):
    """Single-benchmark wrapper used at submission time; mirrors VQAdataset
    but always wraps exactly one Dataset5_submission."""

    def __init__(self, ds_name, args, strategy, seed):
        super(VQAdataset_submission, self).__init__()
        self.dataset = {
            ds_name: Dataset5_submission(strategy, ds_name, args[ds_name], seed)
        }

    def __len__(self):
        # length of the longest wrapped dataset (there is only one here)
        return max(len(v) for k, v in self.dataset.items())

    def __getitem__(self, idx):
        samples = []
        for key, ds in self.dataset.items():
            samples.append(ds[idx % len(ds)])
        return samples