File size: 3,892 Bytes
c6dfc69 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 | """Fused audio-visual dataset for AVSBench-style indexing."""
import os
import random
import PIL.Image
import numpy
import torch
from dataloader.visual.visual_dataset import Visual
from dataloader.audio.audio_dataset import Audio
import pandas
class AV(torch.utils.data.Dataset):
    """Pairs video frames + labels from `Visual` with log-mel spectrograms from `Audio` via `metadata.csv`.

    Samples of the requested split are collected from the ``v1s``, ``v1m`` and ``v2``
    subsets listed in ``<root_path>/avss_index/metadata.csv``.
    """

    def __init__(self, split, augmentation, param, root_path=''):
        """Build the visual/audio loaders and the list of sample paths.

        Args:
            split: split name matched against the ``split`` column of the metadata CSV
                (e.g. ``'train'`` / ``'test'``).
            augmentation: mapping with ``'visual'`` and ``'audio'`` entries, forwarded to
                `Visual` and `Audio` respectively.
            param: config object; ``param.image_size`` and ``param.image_embedding_size``
                are forwarded to `Visual`.
            root_path: dataset root holding ``avss_index/metadata.csv`` and the
                ``v1s`` / ``v1m`` / ``v2`` sample folders.
        """
        # v2.code entry: always merge v1s + v1m + v2 from `avss_index/metadata.csv` (artifacts v2 pool).
        # Visual/Audio get `root_path/v2` as base path; per-sample `load_data` uses full `file_path` (v1s|v1m|v2/uid).
        v2_root = os.path.join(root_path, 'v2')
        self.visual_dataset = Visual(
            augmentation['visual'],
            v2_root,
            split,
            param.image_size,
            param.image_embedding_size,
        )
        self.audio_dataset = Audio(augmentation['audio'], v2_root, split)
        self.augment = augmentation
        self.split = split
        self.file_path = self.organise_files(self.split, root_path, csv_name_='avss_index/metadata.csv')

    def __getitem__(self, index):
        """Load one sample by position in `self.file_path`.

        Returns:
            dict with keys ``'frame'``, ``'label'``, ``'spectrogram'``, ``'id'`` (the
            sample's full path) and ``'prompts'``.
        """
        mixing_prob = 0.  # we omit this option.
        # The random.random() draw is intentionally unconditional: skipping it would shift the
        # global `random` stream, which other components may also draw from.
        if random.random() < mixing_prob and self.split == 'train':
            other_index = random.randrange(len(self))
        else:
            other_index = None
        # Mixing is only ever enabled for the training split.
        assert other_index is None or self.split != 'test', 'no mix in validation.'

        sample_path = self.file_path[index]
        frame, label, prompts = self.visual_dataset.load_data(sample_path)
        other_path = None
        if other_index is not None:
            other_path = self.file_path[other_index]
            other_frame, other_label, other_prompts = self.visual_dataset.load_data(other_path)
            frame, label, prompts = self.visual_mix(frame, other_frame, label, other_label, prompts, other_prompts)
        # The audio loader receives the partner sample's path (or None when not mixing).
        audio_mel = self.audio_dataset.load_audio_wave(sample_path, other_path)
        return {'frame': frame, 'label': label, 'spectrogram': audio_mel, 'id': sample_path,
                'prompts': prompts}

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.file_path)

    @staticmethod
    def organise_files(split_, root_path_, csv_name_):
        """Return the full path of every sample in `split_`, grouped v1s, then v1m, then v2.

        Args:
            split_: value matched against the CSV's ``split`` column.
            root_path_: dataset root; both the CSV path and the sample paths are joined onto it.
            csv_name_: CSV path relative to `root_path_`; needs ``uid``, ``split`` and
                ``label`` columns.

        Returns:
            list of ``<root_path_>/<label>/<uid>`` paths; rows keep CSV order within each subset.
        """
        subsets = ('v1s', 'v1m', 'v2')
        # Read uids as strings so numeric-looking ids neither break os.path.join nor lose leading zeros.
        total_files = pandas.read_csv(os.path.join(root_path_, csv_name_), dtype={'uid': str})
        in_split = total_files[total_files['split'] == split_]
        return [
            os.path.join(root_path_, subset, uid)
            for subset in subsets
            for uid in in_split.loc[in_split['label'] == subset, 'uid']
        ]

    @staticmethod
    def visual_mix(frame1, frame2, label1, label2, prompts1, prompts2):
        """Paste the foreground of sample 2 onto copies of sample 1, frame by frame.

        Pixels where ``label2 > 0`` take their values from `frame2` / `label2`; every other
        pixel keeps `frame1` / `label1`. The input tensors are not modified.

        Args:
            frame1, frame2: frame tensors indexed as ``frame[t, :, y, x]``
                (presumably (T, C, H, W) -- confirm against `Visual.load_data`).
            label1, label2: label tensors indexed as ``label[t, y, x]``.
            prompts1, prompts2: prompt data; only `prompts1` is returned (`prompts2` is unused).

        Returns:
            tuple ``(mixed_frame, mixed_label, prompts1)``.
        """
        mix_frame = frame1.clone()
        mix_label = label1.clone()
        # The paste box spans the whole label plane. Slice ends are exclusive, so the bounds
        # must be the full height/width; ending at `shape - 1` would skip the last row/column.
        height, width = mix_label.shape[1], mix_label.shape[2]
        for t in range(mix_frame.shape[0]):
            foreground = label2[t, :height, :width] > 0.
            # Basic slicing yields views, so these masked assignments write into mix_frame/mix_label.
            mix_frame[t, :, :height, :width][:, foreground] = frame2[t, :, :height, :width][:, foreground]
            mix_label[t, :height, :width][foreground] = label2[t, :height, :width][foreground]
        return mix_frame, mix_label, prompts1
|