| """Ref-AVS dataset: frames, masks, log-mel audio, and referring expressions.""" |
| import os |
| import numpy |
| import torch |
| import pandas |
|
|
| from dataloader.visual.visual_dataset import Visual |
| from dataloader.audio.audio_and_text_dataset import AudioAndText |
|
|
|
|
class AV(torch.utils.data.Dataset):
    """Pairs ``Visual`` with ``AudioAndText`` via REFAVS ``metadata.csv``.

    Every sample is described by a four-element record
    ``[video_id, fid, expression, null_tag]`` produced by
    :meth:`organise_files`. ``__getitem__`` hands the relevant fields of that
    record to the two underlying loaders and bundles their outputs in a dict.
    """

    def __init__(self, split, augmentation, param, root_path=''):
        """Create both loaders and index the samples that belong to ``split``.

        Args:
            split: Value compared against the ``split`` column of the CSV.
            augmentation: Mapping with ``'visual'`` and ``'audio'`` entries;
                each entry is forwarded unchanged to the matching loader.
            param: Object exposing ``image_size`` and
                ``image_embedding_size``.
            root_path: Directory holding ``metadata.csv``; it is also passed
                to both loaders.
        """
        self.visual_dataset = Visual(
            augmentation['visual'], root_path, split,
            param.image_size, param.image_embedding_size,
        )
        self.audio_and_text_dataset = AudioAndText(augmentation['audio'], root_path, split)
        self.split = split
        self.file_path = self.organise_files(self.split, root_path, csv_name_='metadata.csv')

    def __getitem__(self, index):
        """Load the sample whose record sits at position ``index``.

        Returns:
            A dict with the keys ``'frame'``, ``'label'`` and ``'prompts'``
            (the three values returned by ``Visual.load_data``),
            ``'spectrogram'`` and ``'text'`` (the two values returned by
            ``AudioAndText.load_audio_wave``) and ``'id'`` (the complete
            metadata record, including its fourth element).
        """
        # The fourth element of the record is not needed by either loader.
        vid, fid, exp, _ = self.file_path[index]
        frame, label, prompts = self.visual_dataset.load_data(vid, fid)
        audio_mel, text_feature = self.audio_and_text_dataset.load_audio_wave(vid, exp)
        return {
            'frame': frame,
            'label': label,
            'spectrogram': audio_mel,
            'text': text_feature,
            'id': self.file_path[index],
            'prompts': prompts,
        }

    def __len__(self):
        """Return the number of indexed records."""
        return len(self.file_path)

    @staticmethod
    def organise_files(split_, root_path_, csv_name_):
        """Read the metadata CSV and list the records of one split.

        Args:
            split_: Value to match in the ``split`` column.
            root_path_: Directory that contains the CSV file.
            csv_name_: File name of the CSV inside ``root_path_``.

        Returns:
            A list of ``[video_id, fid, expression, null_tag]`` lists in CSV
            row order.

            * ``'test_n'``: ``video_id`` is the ``uid`` with its last two
              ``_``-separated fields removed; ``null_tag`` is ``0``.
            * any other split: ``video_id`` is the ``vid`` column and
              ``null_tag`` is ``0``.
            * ``'train'`` additionally patches rows whose ``uid`` contains
              ``'null_'``: ``video_id`` becomes the ``uid`` prefix (as for
              ``'test_n'``) and ``null_tag`` becomes the second-to-last
              ``_``-separated field of the ``uid``, kept as a string.
        """
        metadata = pandas.read_csv(os.path.join(root_path_, csv_name_))
        # Filter once and reuse the subset instead of rebuilding the boolean
        # mask for every column that is read below.
        subset = metadata[metadata['split'] == split_]

        if split_ == 'test_n':
            return [
                [uid.rsplit('_', 2)[0], fid, expression, 0]
                for uid, fid, expression in zip(subset['uid'], subset['fid'], subset['exp'])
            ]

        records = [
            [vid, fid, expression, 0]
            for vid, fid, expression in zip(subset['vid'], subset['fid'], subset['exp'])
        ]

        if split_ == 'train':
            # ``records`` and the ``uid`` column come from the same filtered
            # frame, so they are aligned row by row.
            for record, uid in zip(records, subset['uid']):
                if 'null_' in uid:
                    parts = uid.rsplit('_', 2)
                    record[0] = parts[0]
                    record[-1] = parts[1]

        return records
|
|