Spaces:
Sleeping
Sleeping
| ''' | |
| Ke Chen | knutchen@ucsd.edu & Nikita Srivatsan | nsrivats@cmu.edu | |
| Load the mp3 format data from audiostock-full dataset | |
| ''' | |
| import json | |
| import numpy as np | |
| import os | |
| import pandas as pd | |
| from pathlib import PurePosixPath | |
| import random | |
| import torch | |
| import torchaudio | |
| from torch.utils.data import Dataset | |
| import sys | |
| from lib import * | |
| from utils import * | |
| import torch.utils.data | |
def int16_to_float32(x):
    '''Rescale a tensor holding int16-range values to float32 by dividing by 32767.'''
    scaled = x / 32767.0
    return scaled.to(torch.float)
def float32_to_int16(x):
    '''Clamp a float tensor to [-1, 1] and scale it onto the int16 range (cast truncates toward zero).'''
    clamped = x.clamp(min=-1.0, max=1.0)
    return (clamped * 32767.0).to(torch.int16)
def my_collate(batch):
    '''
    DataLoader collate function that discards samples the dataset returned as None.

    Returns the (empty) filtered list when no sample survives; otherwise the
    surviving samples are merged with torch's default_collate.
    '''
    kept = list(filter(lambda sample: sample is not None, batch))
    if not kept:
        return kept
    return torch.utils.data.dataloader.default_collate(kept)
class AudiostockDataset(Dataset):
    '''
    Audiostock music-captioning dataset: audio clips paired with tokenized captions.

    Each item is a dict with keys 'waveform', 'id', 'short_text', 'long_text',
    'tokens', 'mask', 'tweet_text_len' and 'tweet_text', or None when the sample is
    unusable (unreadable/empty audio, or a None id/short_text in its label file);
    use my_collate to drop those from a batch.

    Args:
        dataset_path (str): the dataset folder path; audio is listed from its
            'audiostock-part-1' ... 'audiostock-part-8' subfolders and labels are read
            from 'audiostock-full-label/<id>.json'
        tweet_prefix (bool): if True, the neighbour caption stored in self.id2neighbor
            is prepended to the caption tokens and max_seq_len doubles from 150 to 300
        prefix_length (int): number of always-on positions prepended to the token mask
        normalize (bool): if True, labelled captions are re-joined from muscaps_tokenize
        dupefile (str | None): pickle whose `.is_rep` maps id -> whether the track is kept
            by deduplication, or 'both' to merge 'dupes.pkl' and 'dupes_audio.pkl'
        train (bool): if True, we randomly return a 10-sec chunk from each audio file;
            if False, we return the middle 10-sec chunk (fixed)
        split (str | None): a txt file listing the ids in this split (for training,
            validation and testing)
        factor (float): how many times to loop over the whole dataset per epoch, to
            increase the number of training batches in each epoch
        whole_track (bool): if True, read the whole file instead of a chunk (this means
            batch_size = 1, usually in the test/validation case); the resampled
            waveform is still cut to its first 441000 samples
        verbose (bool): print a summary once the file list is built
        dedup (bool): keep only files whose id is marked True in the dupefile mapping
        file_list (list[str] | None): explicit audio paths; when None or empty the
            dataset folders are scanned instead
    '''
    # frames read per chunk from the source file, and samples kept after resampling
    _CHUNK_LEN = 441000
    # sample rate every waveform is resampled to
    _TARGET_SR = 48000

    def __init__(self, dataset_path, tweet_prefix=True, prefix_length=10, normalize=False, dupefile='dupes.pkl', train=True, split=None, factor=1.0, whole_track=False, verbose=True, dedup=True, file_list=None):
        super().__init__()
        # set up parameters
        self.max_seq_len = 150
        self.tweet_prefix = tweet_prefix
        if self.tweet_prefix:
            self.max_seq_len *= 2  # room for the neighbour caption plus the caption
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2', local_files_only=True)
        self.prefix_length = prefix_length
        self.normalize = normalize
        # neighbour caption per id; replaced by the precompute_* methods
        self.id2neighbor = defaultdict(lambda: '')
        if dedup:
            self.is_rep = self._load_is_rep(dupefile)
        subfolders = [f'audiostock-part-{i}' for i in range(1, 9)]
        self.label_path = os.path.join(dataset_path, 'audiostock-full-label')
        self.whole_track = whole_track
        # None default avoids sharing one mutable default list between instances
        self.file_list = [] if file_list is None else file_list
        # select out the elements for this split
        if self.file_list == []:
            candidates = []
            for subfolder in subfolders:
                folder = os.path.join(dataset_path, subfolder)
                candidates += [os.path.join(folder, name) for name in os.listdir(folder)
                               if not dedup or self.is_rep[self._stem(name)]]
            if split is not None:
                split_ids = self._load_split_ids(split)
                self.file_list = [path for path in candidates if self._stem(path) in split_ids]
            else:
                self.file_list = candidates
        self.train = train
        self.total_len = int(len(self.file_list) * factor)
        if verbose:
            print(f'Dataset Loaded | File Num.: {len(self.file_list)} | Batches per epoch: {self.total_len}')

    @staticmethod
    def _stem(path):
        '''Return the file name of `path` up to its first dot; used as the track id.'''
        return os.path.basename(path).split('.')[0]

    @staticmethod
    def _load_split_ids(split_path):
        '''Read the whitespace-separated ids listed in `split_path` into a set of str.'''
        # np.loadtxt returns a 0-d array for a single-entry file, which set() cannot iterate
        return set(np.atleast_1d(np.loadtxt(split_path, dtype=str)).tolist())

    @staticmethod
    def _load_is_rep(dupefile):
        '''
        Return the id -> "keep this track" mapping used for deduplication.

        Exits the process (sys.exit) when no usable duplicate file is found.
        NOTE(review): these are pickle files; only load files you produced yourself.
        '''
        if dupefile is not None and os.path.exists(dupefile):
            with open(dupefile, 'rb') as fh:
                return pickle.load(fh).is_rep
        if dupefile == 'both':
            # ids absent from both files default to True (kept); the second file wins on conflicts
            merged = defaultdict(lambda: True)
            for path in ('dupes.pkl', 'dupes_audio.pkl'):
                with open(path, 'rb') as fh:
                    merged.update(pickle.load(fh).is_rep.items())
            return merged
        sys.exit('Could not find duplicate file')

    @staticmethod
    def _loader(dataset):
        '''Sequential loader (batch 32, 32 workers, no shuffling) that drops None samples.'''
        return DataLoader(dataset, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)

    def _batches(self, dataset, show_progress=False):
        '''Yield the collated batches of `dataset`, skipping batches in which every sample was None.'''
        loader = self._loader(dataset)
        progress = tqdm(total=len(loader), dynamic_ncols=True) if show_progress else None
        try:
            for batch in loader:
                if progress is not None:
                    progress.update()
                # my_collate returns [] when every sample of the batch was None
                if not batch:
                    continue
                yield batch
        finally:
            if progress is not None:
                progress.close()

    def _gather(self, dataset, *keys):
        '''Concatenate the per-sample lists stored under `keys` across all batches of `dataset`.'''
        gathered = {key: [] for key in keys}
        for batch in self._batches(dataset):
            for key in keys:
                gathered[key] += batch[key]
        return gathered

    def _embed(self, model, dataset, with_captions=True):
        '''
        Embed every waveform of `dataset` with model.embed_waveform on the GPU.

        Returns (features, ids, captions): features are the per-batch embeddings
        concatenated along dim 0 (None if no batch survived); captions stays empty
        unless with_captions.
        '''
        chunks, ids, captions = [], [], []
        for batch in self._batches(dataset, show_progress=True):
            with torch.no_grad():
                chunks.append(model.embed_waveform(batch['waveform'].cuda()))
            ids += batch['id']
            if with_captions:
                captions += batch['short_text']
        # one concatenation at the end instead of re-copying a growing tensor per batch
        features = torch.cat(chunks) if chunks else None
        return features, ids, captions

    def precompute_rand(self, candidate_set=None):
        '''
        Give every id of this dataset a random caption as its neighbour.

        Captions are drawn with replacement from this dataset when candidate_set is
        None (train), otherwise from candidate_set (test).
        '''
        self.id2neighbor = defaultdict(lambda: '')
        if candidate_set is None:
            gathered = self._gather(self, 'id', 'short_text')
            my_ids, candidate_caps = gathered['id'], gathered['short_text']
        else:
            candidate_caps = self._gather(candidate_set, 'short_text')['short_text']
            my_ids = self._gather(self, 'id')['id']
        for audio_id in my_ids:
            self.id2neighbor[audio_id] = random.choice(candidate_caps)

    def precompute_gold(self):
        '''Use each track's own short_text as its neighbour caption.'''
        self.id2neighbor = defaultdict(lambda: '')
        gathered = self._gather(self, 'id', 'short_text')
        self.id2neighbor.update(zip(gathered['id'], gathered['short_text']))

    def precompute_blank(self):
        '''Make every neighbour caption a single newline.'''
        self.id2neighbor = defaultdict(lambda: '\n')

    def precompute_neighbors(self, model, candidate_set=None):
        '''
        Set each id's neighbour to the caption of the most similar candidate track,
        scoring candidates by features @ cand_features.T on model.embed_waveform outputs.

        Without candidate_set the candidates are this dataset's own tracks (a track
        is never matched with itself). With candidate_set, the candidate embeddings
        are cached in 'nn_features.pkl' in the working directory and reused.
        '''
        print('Precomputing neighbors')
        self.id2neighbor = defaultdict(lambda: '')
        if candidate_set is None:
            # train: candidates are this dataset itself
            cand_features, cand_ids, cand_caps = self._embed(model, self)
            my_features, my_ids = cand_features, cand_ids
        else:
            # test: reuse cached candidate embeddings when present.
            # NOTE(review): the cache ignores which candidate_set/model produced it,
            # so delete nn_features.pkl when either changes; it is also unpickled blindly.
            pickle_filename = 'nn_features.pkl'
            if os.path.isfile(pickle_filename):
                with open(pickle_filename, 'rb') as f:
                    cand_features, cand_ids, cand_caps = pickle.load(f)
            else:
                cand_features, cand_ids, cand_caps = self._embed(model, candidate_set)
                # dump to pickle so we don't have to redo this each time
                with open(pickle_filename, 'wb') as f:
                    pickle.dump((cand_features, cand_ids, cand_caps), f)
            my_features, my_ids, _ = self._embed(model, self, with_captions=False)
        # identical id lists mean row i of my_features is candidate i
        is_self_sim = my_ids == cand_ids
        for idx, audio_id in tqdm(enumerate(my_ids), total=len(my_ids), dynamic_ncols=True):
            similarities = my_features[idx] @ cand_features.T
            if is_self_sim:
                similarities[idx] = float('-inf')  # remove identical matches
            best_idx = int(torch.argmax(similarities))
            self.id2neighbor[audio_id] = cand_caps[best_idx]

    def pad_tokens(self, tokens, tokens_tweet):
        '''
        Prepend the neighbour-caption tokens (tweet_prefix only) and pad/truncate to max_seq_len.

        Returns (tokens, mask, tweet_text_len): negative (padding) positions become 0
        in tokens and 0.0 in mask; mask is preceded by prefix_length ones;
        tweet_text_len is the number of neighbour tokens kept (0 without tweet_prefix).
        '''
        tweet_text_len = 0
        if self.tweet_prefix:
            # the neighbour caption may take at most half of the sequence budget
            tweet_text_len = tokens_tweet[:self.max_seq_len // 2].shape[0]
            tokens = torch.cat((tokens_tweet[:tweet_text_len], tokens))
        padding = self.max_seq_len - tokens.shape[0]
        if padding > 0:
            # pad with -1 so the mask below can tell padding from real token ids
            tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            tokens = tokens[:self.max_seq_len]
        mask = tokens.ge(0)  # mask is zero where we are out of sequence
        tokens[~mask] = 0
        mask = mask.float()
        mask = torch.cat((torch.ones(self.prefix_length), mask), dim=0)  # adding prefix mask
        return tokens, mask, tweet_text_len

    def _chunk_start(self, num_frames):
        '''
        Frame offset at which read_wav starts reading: 0 for whole tracks, a random
        offset in train mode, the centred offset otherwise.

        Clamped at 0 so files shorter than a chunk no longer make random.randint
        raise or hand a negative frame_offset to torchaudio.load.
        '''
        if self.whole_track:
            return 0
        # latest start the chunking allows (keeps one spare frame at the end of the file)
        last_start = num_frames - (self._CHUNK_LEN + 1)
        if self.train:
            return random.randint(0, max(0, last_start))
        return max(0, last_start // 2)

    def read_wav(self, filename):
        '''
        Load `filename` as a mono float waveform resampled to 48 kHz.

        Returns None when the file cannot be probed or has no frames. Unless
        whole_track, only a _CHUNK_LEN-frame chunk is read and shorter results are
        zero-padded so all clips have equal length (default_collate stacks them).
        The result is always cut to its first _CHUNK_LEN samples and round-tripped
        through int16 quantization.
        '''
        # pickling functionality removed since it shouldn't be necessary
        try:
            num_frames = torchaudio.info(filename).num_frames
        except Exception:  # unreadable/corrupt file: skip this sample
            return None
        # make sure it wasn't empty, if so skip it
        if num_frames == 0:
            return None
        sta = self._chunk_start(num_frames)
        if not self.whole_track:
            num_frames = self._CHUNK_LEN
        y, sr = torchaudio.load(filename, frame_offset=sta, num_frames=num_frames)
        # resample
        y = torchaudio.functional.resample(y, sr, self._TARGET_SR)
        y = y[:, :self._CHUNK_LEN]
        # mono: average the channel axis
        y = y.mean(dim=0)
        if not self.whole_track and y.shape[0] < self._CHUNK_LEN:
            y = torch.nn.functional.pad(y, (0, self._CHUNK_LEN - y.shape[0]))
        # quantize to int16 and back
        return int16_to_float32(float32_to_int16(y))

    def __getitem__(self, index):
        '''
        Return the sample dict for `index` (taken modulo the file count, so factor > 1
        loops over the files), or None when the sample is unusable.
        '''
        idx = index % len(self.file_list)
        f = self.file_list[idx]
        stem = self._stem(f)
        waveform = self.read_wav(f)
        if waveform is None:
            return None
        data_dict = {'waveform': waveform}
        lf = os.path.join(self.label_path, stem + '.json')
        if os.path.isfile(lf):
            with open(lf, 'r', encoding='utf-8') as label_file:
                label_data = json.load(label_file)
            # bail out before normalize/preproc, which cannot handle None
            if label_data['id'] is None or label_data['short_text'] is None:
                return None
            data_dict['id'] = label_data['id']
            data_dict['short_text'] = label_data['short_text']
            if self.normalize:
                data_dict['short_text'] = ' '.join(muscaps_tokenize(data_dict['short_text']))
            long_text = label_data.get('long_text')
            data_dict['long_text'] = '' if long_text is None else long_text
        else:
            # unlabelled file: fall back to the file stem and empty captions
            data_dict['id'] = stem
            data_dict['short_text'] = ''
            data_dict['long_text'] = ''
        # tokenize the caption and the neighbour caption used as prefix
        tokens = torch.tensor(preproc(data_dict['short_text'], self.tokenizer), dtype=torch.int64)
        tweet_text = self.id2neighbor[data_dict['id']] if self.tweet_prefix else ''
        tokens_tweet = torch.tensor(preproc(tweet_text, self.tokenizer, stop=False), dtype=torch.int64)
        tokens, mask, tweet_text_len = self.pad_tokens(tokens, tokens_tweet)
        data_dict['tokens'] = tokens
        data_dict['mask'] = mask
        data_dict['tweet_text_len'] = tweet_text_len
        data_dict['tweet_text'] = tweet_text
        return data_dict

    def __len__(self):
        '''Number of samples per epoch: len(file_list) scaled by `factor`.'''
        return self.total_len
class MusicCapsDataset(AudiostockDataset):
    '''
    MusicCaps variant of AudiostockDataset.

    Reads '<dataset_path>/musiccaps-resplit.csv' and keeps the rows whose
    'is_audioset_eval' value matches `split`: 0 for 'musiccaps_train',
    1 for 'musiccaps_eval', 2 for 'musiccaps_dev' (any other split keeps no rows).
    Audio for each row is read from '<dataset_path>/audio/<ytid>.wav'; chunking,
    token padding and the precompute_* methods are inherited.

    Args:
        dataset_path (str): MusicCaps root folder
        args: object providing tweet_prefix, prefix_length and normalize attributes
        train (bool): random chunk if True, centred chunk otherwise
        split (str | None): one of the split names above
        factor (float): how many times to loop over the files per epoch
        whole_track (bool): read whole files instead of chunks
        verbose (bool): print a summary once loaded
        dedup (bool): accepted for signature compatibility; not used here
    '''
    # 'is_audioset_eval' value selected by each split name
    _SPLIT_CODES = {'musiccaps_train': 0, 'musiccaps_eval': 1, 'musiccaps_dev': 2}

    def __init__(self, dataset_path, args, train=True, split=None, factor=1.0, whole_track=False, verbose=True, dedup=True):
        # deliberately skip AudiostockDataset.__init__ (it scans the audiostock folders)
        super(AudiostockDataset, self).__init__()
        # set up parameters
        self.max_seq_len = 150
        self.tweet_prefix = args.tweet_prefix
        if self.tweet_prefix:
            self.max_seq_len *= 2
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2', local_files_only=True)
        self.prefix_length = args.prefix_length
        self.normalize = args.normalize
        self.whole_track = whole_track
        # the parent __init__ is skipped, so create the neighbour map here; without it
        # __getitem__ raises AttributeError with tweet_prefix before any precompute_* call
        self.id2neighbor = defaultdict(lambda: '')
        self.label_path = os.path.join(dataset_path, 'audio')
        self.file_list = []
        self.label_data = []
        wanted_code = self._SPLIT_CODES.get(split)
        label_reader = pd.read_csv(f'{dataset_path}/musiccaps-resplit.csv')
        for _, row in label_reader.iterrows():
            if wanted_code is None or row['is_audioset_eval'] != wanted_code:
                continue
            data_dict = {}
            data_dict['id'] = row['ytid']
            self.file_list.append(f"{dataset_path}/audio/{data_dict['id']}.wav")
            data_dict['short_text'] = row['caption']
            if self.normalize:
                data_dict['short_text'] = ' '.join(muscaps_tokenize(data_dict['short_text']))
            data_dict['long_text'] = ''
            data_dict['tag'] = row['aspect_list']
            self.label_data.append(data_dict)
        self.train = train
        self.total_len = int(len(self.file_list) * factor)
        if verbose:
            print(f'Dataset Loaded | File Num.: {len(self.file_list)} | Batches per epoch: {self.total_len}')

    def __getitem__(self, index):
        '''
        Return the sample dict for `index` (taken modulo the file count), holding the
        waveform, the stored label fields and the padded tokens/mask, or None when
        the audio is unusable or the id/caption is None.
        '''
        idx = index % len(self.file_list)
        waveform = self.read_wav(self.file_list[idx])
        if waveform is None:
            return None
        data_dict = {'waveform': waveform}
        data_dict.update(self.label_data[idx])
        if data_dict['id'] is None or data_dict['short_text'] is None:
            return None
        # tokenize the caption and the neighbour caption used as prefix
        tokens = torch.tensor(preproc(data_dict['short_text'], self.tokenizer), dtype=torch.int64)
        tweet_text = self.id2neighbor[data_dict['id']] if self.tweet_prefix else ''
        tokens_tweet = torch.tensor(preproc(tweet_text, self.tokenizer, stop=False), dtype=torch.int64)
        tokens, mask, tweet_text_len = self.pad_tokens(tokens, tokens_tweet)
        data_dict['tokens'] = tokens
        data_dict['mask'] = mask
        data_dict['tweet_text_len'] = tweet_text_len
        data_dict['tweet_text'] = tweet_text
        return data_dict