| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import copy |
| | import os |
| | import json |
| | from tqdm import tqdm |
| | import ipdb |
| | import random |
| | from torch.nn.utils.rnn import pad_sequence |
| | from dataclasses import dataclass, field |
| | from typing import Callable, Dict, Sequence |
| |
|
| | import torch |
| | import torch.distributed as dist |
| | import transformers |
| | import numpy as np |
| | from torch.utils.data import Dataset |
| | from .base_dataset import BaseDataset |
| | from tqdm import tqdm |
| | import pandas as pd |
| | from .utils import process_caption |
| |
|
| |
|
class AudioCapDataset(BaseDataset):
    """AudioCaps audio-captioning dataset.

    Loads a JSON array of annotation objects and builds three parallel lists:
    absolute media paths, normalized captions, and a per-sample dataset-type
    tag. Item access is presumably provided by ``BaseDataset`` (not visible
    here) — NOTE(review): confirm.
    """

    def __init__(self, data_path: str, mm_root_path: str, embed_path: str, dataset_type: str):
        """Read caption annotations from ``data_path``.

        Args:
            data_path: Path to a UTF-8 JSON file whose top-level value is
                iterated row by row; each row must provide ``"audio_name"``
                and ``"caption"`` keys.
            mm_root_path: Directory that each ``"audio_name"`` is joined onto
                with ``os.path.join`` to form ``self.mm_path_list`` entries.
            embed_path: Forwarded to ``BaseDataset.__init__`` and also stored
                as ``self.embed_path``.
            dataset_type: Tag repeated once per sample in
                ``self.dataset_type_list``.

        Raises:
            KeyError: If a row lacks ``"audio_name"`` or ``"caption"``.
        """
        super().__init__(data_path, mm_root_path, embed_path, dataset_type)
        self.embed_path = embed_path

        print('Load Audiocap dataset ...')
        self.mm_path_list, self.caption_list = [], []
        # Only the JSON parse needs the open handle; close it before the
        # (potentially long) per-row processing.
        with open(data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        for row in tqdm(data, total=len(data)):
            audio_id, one_caption = row["audio_name"], row["caption"]
            self.mm_path_list.append(os.path.join(mm_root_path, audio_id))
            self.caption_list.append(process_caption(one_caption))

        print(f'[!] collect {len(self.mm_path_list)} samples for training')
        # One tag per sample. List repetition is safe because the element is
        # an immutable str, so every slot may share the same object.
        self.dataset_type_list = [dataset_type] * len(self.caption_list)
| |
|
| |
|
| |
|