File size: 2,626 Bytes
c6dfc69 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 | """Ref-AVS dataset: frames, masks, log-mel audio, and referring expressions."""
import os
import numpy
import torch
import pandas
from dataloader.visual.visual_dataset import Visual
from dataloader.audio.audio_and_text_dataset import AudioAndText
class AV(torch.utils.data.Dataset):
    """Pairs ``Visual`` with ``AudioAndText`` via REFAVS ``metadata.csv``.

    Every sample is addressed by one index record
    ``[video_id, fid, expression, flag]`` produced by ``organise_files``.
    """

    def __init__(self, split, augmentation, param, root_path=''):
        """Build both sub-datasets and the index of samples for ``split``.

        Args:
            split: Value matched against the ``split`` column of
                ``metadata.csv`` (``'train'`` and ``'test_n'`` get special
                handling in ``organise_files``).
            augmentation: Mapping with a ``'visual'`` entry forwarded to
                ``Visual`` and an ``'audio'`` entry forwarded to
                ``AudioAndText``.
            param: Object exposing ``image_size`` and
                ``image_embedding_size``, both forwarded to ``Visual``.
            root_path: Directory that contains ``metadata.csv``; it is also
                handed to both sub-datasets.
        """
        self.visual_dataset = Visual(
            augmentation['visual'], root_path, split,
            param.image_size, param.image_embedding_size,
        )
        self.audio_and_text_dataset = AudioAndText(augmentation['audio'], root_path, split)
        self.split = split
        self.file_path = self.organise_files(self.split, root_path, csv_name_='metadata.csv')

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A dict with keys ``'frame'``, ``'label'`` and ``'prompts'`` (the
            three values returned by ``Visual.load_data``), ``'spectrogram'``
            and ``'text'`` (the two values returned by
            ``AudioAndText.load_audio_wave``) and ``'id'`` (the index record
            itself -- the very list stored in ``self.file_path``, not a copy).
        """
        record = self.file_path[index]
        # The fourth field of the record is not needed to load the data.
        vid, fid, exp, _ = record
        frame, label, prompts = self.visual_dataset.load_data(vid, fid)
        audio_mel, text_feature = self.audio_and_text_dataset.load_audio_wave(vid, exp)
        return {
            'frame': frame,
            'label': label,
            'spectrogram': audio_mel,
            'text': text_feature,
            'id': record,
            'prompts': prompts,
        }

    def __len__(self):
        """Return the number of index records for the selected split."""
        return len(self.file_path)

    @staticmethod
    def organise_files(split_, root_path_, csv_name_):
        """Build the list of index records for one split of the metadata CSV.

        The CSV must provide the columns ``split``, ``fid`` and ``exp``, plus
        ``uid`` (used for ``'test_n'`` and ``'train'``) and ``vid`` (used for
        every split except ``'test_n'``).

        Args:
            split_: Split name; rows whose ``split`` column equals it are kept.
            root_path_: Directory containing the CSV file.
            csv_name_: File name of the CSV inside ``root_path_``.

        Returns:
            A list of ``[video_id, fid, expression, flag]`` lists in CSV row
            order, empty when no row matches. ``flag`` is ``0`` except for
            ``'train'`` rows whose ``uid`` contains ``'null_'``.

            * ``'test_n'``: ``video_id`` is the ``uid`` with its last two
              ``'_'``-separated fields removed.
            * any other split: ``video_id`` is the ``vid`` column.
            * ``'train'`` rows whose ``uid`` contains ``'null_'``:
              ``video_id`` becomes element 0 and ``flag`` becomes element 1
              (a string) of ``uid.rsplit('_', 2)``.
        """
        table = pandas.read_csv(os.path.join(root_path_, csv_name_))
        # Filter once; every column below is read from the same subset.
        subset = table[table['split'] == split_]

        if split_ == 'test_n':
            return [
                [uid.rsplit('_', 2)[0], fid, expression, 0]
                for uid, fid, expression in zip(subset['uid'], subset['fid'], subset['exp'])
            ]

        file_path = [
            [vid, fid, expression, 0]
            for vid, fid, expression in zip(subset['vid'], subset['fid'], subset['exp'])
        ]

        if split_ == 'train':
            # Records and uids come from the same subset, so they line up 1:1.
            for row, uid in zip(file_path, subset['uid']):
                if 'null_' in uid:
                    parts = uid.rsplit('_', 2)
                    row[0] = parts[0]
                    row[-1] = parts[1]
        return file_path
|