import argparse
import os
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from tqdm import tqdm  # 导入 tqdm
import logging  # 导入 logging
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from data_utils.v2a_utils.audio_text_dataset import Audio_Text
from data_utils.v2a_utils.feature_utils_224_audio import FeaturesUtils
import torchaudio
from einops import rearrange
from torch.utils.data.dataloader import default_collate


# Configure root logging: timestamped INFO-level messages for all ranks.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def setup(rank, world_size):
    """Join the NCCL process group and pin this process to its GPU.

    Args:
        rank: Global rank of this process (0-based).
        world_size: Total number of participating processes.
    """
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    """Tear down the distributed process group created by setup()."""
    dist.destroy_process_group()

def error_avoidance_collate(batch):
    """Collate function that tolerates failed samples.

    Samples that failed to load arrive as None; they are dropped before
    collation. Returns None when every sample in the batch failed, which
    the training loop checks for and skips.
    """
    valid = [sample for sample in batch if sample is not None]
    if not valid:
        return None  # whole batch failed; caller skips it
    return default_collate(valid)
def main(args):
    """Distributed driver: extract T5 text features for an audio/caption dataset.

    Each rank loads a disjoint shard of the Audio_Text dataset via a
    DistributedSampler, runs the frozen FeaturesUtils extractor on it, and
    writes one ``{id}.pth`` file per sample containing the caption strings
    and their T5 features.

    Assumes it is launched by torchrun / torch.distributed.launch, which
    sets RANK and WORLD_SIZE in the environment — TODO confirm launcher.
    """
    rank = int(os.environ["RANK"])  # global rank supplied by the launcher
    world_size = int(os.environ["WORLD_SIZE"])  # total process count from the launcher
    setup(rank, world_size)
    tsv_path = args.tsv_path
    save_dir = args.save_dir
    root = args.root
    dataset = Audio_Text(
        root=root,
        tsv_path=tsv_path,
        sample_rate=args.sample_rate,
        duration_sec=args.duration_sec,
        audio_samples=args.audio_samples,
        start_row=args.start_row,
        end_row=args.end_row,
        save_dir=save_dir
    )
    os.makedirs(save_dir, exist_ok=True)
    # Distributed sampler shards the dataset so each rank sees a disjoint subset.
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=1, sampler=train_sampler, num_workers=16, drop_last=False,collate_fn=error_avoidance_collate)

    feature_extractor = FeaturesUtils(
        vae_ckpt=args.vae_ckpt,
        vae_config=args.vae_config,
        enable_conditions=True,
        synchformer_ckpt=args.synchformer_ckpt
    ).eval().cuda(rank)

    # Wrap in DDP for multi-GPU runs; the underlying model is reached via .module below.
    feature_extractor = torch.nn.parallel.DistributedDataParallel(feature_extractor, device_ids=[rank])

    # tqdm progress bar over this rank's shard.
    for i, data in enumerate(tqdm(dataloader, desc="Processing", unit="batch")):
        # Skip batches where every sample failed to load (collate returned None).
        if data is None:
            continue
        ids = data['id']  # sample IDs for the current batch
        with torch.no_grad():  # inference only — no gradients needed
            audio = data['audio'].cuda(rank, non_blocking=True)
            if audio.size(0) == 0:
                continue
            output = {
                'caption': data['caption'],
                'caption_cot': data['caption_cot']
            }
            # logging.info(f'Processing batch {i} with IDs: {ids}')  # per-batch trace logging

            # latent = feature_extractor.module.encode_audio(audio)
            # output['latent'] = latent.detach().cpu()

            caption = data['caption']
            # print(caption,'debug!!!!!!!!!')
            # metaclip_global_text_features, metaclip_text_features = feature_extractor.module.encode_text(caption)
            # output['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu()
            # output['metaclip_text_features'] = metaclip_text_features.detach().cpu()

            caption_cot = data['caption_cot']
            t5_features = feature_extractor.module.encode_t5_text(caption_cot)
            output['t5_features'] = t5_features.detach().cpu()

            # Persist one .pth per sample, keyed by its dataset ID.
            for j in range(audio.size(0)):  # iterate samples in the batch (batch_size=1 here)
                sample_output = {
                    'id': ids[j],
                    'caption': output['caption'][j],
                    'caption_cot': output['caption_cot'][j],
                    # 'latent': output['latent'][j],
                    # 'metaclip_global_text_features': output['metaclip_global_text_features'][j],
                    # 'metaclip_text_features': output['metaclip_text_features'][j],
                    't5_features': output['t5_features'][j]
                }
                torch.save(sample_output, f'{save_dir}/{ids[j]}.pth')

        ## test the sync between videos and audios
        # torchaudio.save(f'input_{i}.wav',data['audio'],sample_rate=44100)
        # recon_audio = feature_extractor.decode_audio(latent)
        # recon_audio = rearrange(recon_audio, "b d n -> d (b n)")
        # id = data['id']
        # torchaudio.save(f'recon_{i}.wav',recon_audio.cpu(),sample_rate=44100)
        # os.system(f'ffmpeg -y -i dataset/vggsound/video/train/{id}.mp4 -i recon_{i}.wav -t 9 -map 0:v -map 1:a -c:v copy -c:a aac -strict experimental -shortest out_{i}.mp4')
        # os.system(f'ffmpeg -y -i dataset/vggsound/video/train/{id}.mp4 -i input_{i}.wav -t 9 -map 0:v -map 1:a -c:v copy -c:a aac -strict experimental -shortest input_{i}.mp4')

    cleanup()

if __name__ == '__main__':
    # print('i am rank',os.environ["RANK"])
    parser = argparse.ArgumentParser(description='Extract Audio Training Latents')
    parser.add_argument('--root', type=str, default='dataset/vggsound/raw_audios/test', help='Root directory of the audio dataset')
    parser.add_argument('--tsv_path', type=str, default='cot_vgg_test_mix_coarse.csv', help='Path to the TSV file')
    parser.add_argument('--save-dir', type=str, default='vgg_cot_extra/cot_coarse', help='Save Directory')
    parser.add_argument('--sample_rate', type=int, default=44100, help='Sample rate of the audio')
    parser.add_argument('--duration_sec', type=float, default=9.0, help='Duration of the audio in seconds')
    parser.add_argument('--audio_samples', type=int, default=397312, help='Number of audio samples')
    parser.add_argument('--vae_ckpt', type=str, default='ckpts/vae.ckpt', help='Path to the VAE checkpoint')
    parser.add_argument('--vae_config', type=str, default='PrismAudio/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file')
    parser.add_argument('--synchformer_ckpt', type=str, default='ckpts/synchformer_state_dict.pth', help='Path to the Synchformer checkpoint')
    parser.add_argument('--start-row', type=int, default=0, help='start row')
    parser.add_argument('--end-row', type=int, default=None, help='end row')

    args = parser.parse_args()

    # Intended to be launched with torchrun / torch.distributed.launch;
    # main() reads RANK and WORLD_SIZE from the environment the launcher sets.
    main(args=args)