Spaces:
Running on Zero
Running on Zero
File size: 6,620 Bytes
ddb382a 8031e67 ddb382a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 | import argparse
import os
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from tqdm import tqdm # 导入 tqdm
import logging # 导入 logging
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from data_utils.v2a_utils.audio_text_dataset import Audio_Text
from data_utils.v2a_utils.feature_utils_224_audio import FeaturesUtils
import torchaudio
from einops import rearrange
from torch.utils.data.dataloader import default_collate
# Root logger configuration: INFO and above, each record prefixed with a
# timestamp and its level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
def setup(rank, world_size):
    """Join the NCCL process group and bind this process to its GPU.

    Args:
        rank: Global rank of this process (0 .. world_size - 1).
        world_size: Total number of participating processes.

    The CUDA device index is taken from the ``LOCAL_RANK`` environment
    variable when the launcher sets it (torchrun does), falling back to
    ``rank``. The global rank is only a valid device index on a single node;
    on a second node it exceeds the number of local GPUs.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    # Select the device before creating the NCCL group so that all NCCL work
    # issued by this process targets its own GPU from the start.
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
    """Destroy the default process group if one is currently initialised.

    Safe to call unconditionally (e.g. from a ``finally`` block): when no
    process group exists this is a no-op instead of raising.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
def error_avoidance_collate(batch):
    """Collate a batch after discarding samples the dataset returned as None.

    Args:
        batch: List of samples produced by the dataset; entries may be None.

    Returns:
        ``default_collate`` applied to the non-None samples, or None when no
        sample survives (callers must skip None batches; ``{}`` would be an
        alternative sentinel).
    """
    kept = [sample for sample in batch if sample is not None]
    if not kept:
        return None
    return default_collate(kept)
def _save_batch(save_dir, data, t5_features):
    """Write one ``{save_dir}/{id}.pth`` file per sample of a collated batch.

    Args:
        save_dir: Output directory (must already exist).
        data: Collated batch dict with ``'id'``, ``'caption'`` and
            ``'caption_cot'`` entries, indexed by sample position.
        t5_features: Tensor whose first dimension follows the order of
            ``data['id']``.
    """
    for j, sample_id in enumerate(data['id']):
        sample_output = {
            'id': sample_id,
            'caption': data['caption'][j],
            'caption_cot': data['caption_cot'][j],
            # clone(): torch.save serialises the whole storage behind a view,
            # so saving t5_features[j] as-is would embed every sample of the
            # batch in each per-sample file.
            't5_features': t5_features[j].clone(),
        }
        torch.save(sample_output, f'{save_dir}/{sample_id}.pth')


def main(args):
    """Extract T5 features for ``caption_cot`` texts and save one file per sample.

    Expects one process per GPU under a launcher that exports ``RANK`` and
    ``WORLD_SIZE`` (``LOCAL_RANK`` is used for the device when present). Each
    rank processes a disjoint strided shard of the ``Audio_Text`` dataset and,
    for every sample, writes ``{save_dir}/{id}.pth`` containing ``id``,
    ``caption``, ``caption_cot`` and ``t5_features`` (the output of
    ``FeaturesUtils.encode_t5_text`` on the batch's ``caption_cot``).

    Args:
        args: Parsed command-line namespace from the ``__main__`` block.
    """
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # The global rank is only a valid CUDA ordinal on a single node;
    # LOCAL_RANK is the per-node index when the launcher provides it.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    setup(rank, world_size)
    try:
        save_dir = args.save_dir
        dataset = Audio_Text(
            root=args.root,
            tsv_path=args.tsv_path,
            sample_rate=args.sample_rate,
            duration_sec=args.duration_sec,
            audio_samples=args.audio_samples,
            start_row=args.start_row,
            end_row=args.end_row,
            save_dir=save_dir,
        )
        os.makedirs(save_dir, exist_ok=True)

        # Disjoint strided shard for this rank. DistributedSampler pads its
        # index list with repeated samples so shards are equal-sized, which
        # made two ranks process -- and concurrently write -- the same file.
        shard = range(rank, len(dataset), world_size)
        dataloader = DataLoader(
            dataset,
            batch_size=1,
            sampler=shard,
            num_workers=16,
            drop_last=False,
            collate_fn=error_avoidance_collate,
        )

        # Inference only, so no DistributedDataParallel wrapper: DDP exists to
        # all-reduce gradients during training and would only add a weight
        # broadcast and gradient-bucket memory here. Work is already split
        # across ranks by the shard above.
        feature_extractor = FeaturesUtils(
            vae_ckpt=args.vae_ckpt,
            vae_config=args.vae_config,
            enable_conditions=True,
            synchformer_ckpt=args.synchformer_ckpt,
        ).eval().cuda(local_rank)

        with torch.no_grad():
            # Only rank 0 draws a progress bar to keep the console readable.
            progress = tqdm(dataloader, desc="Processing", unit="batch", disable=rank != 0)
            for data in progress:
                # error_avoidance_collate returns None when the dataset gave
                # None for every sample of the batch. The audio itself is not
                # consumed here, so it is not copied to the GPU.
                if data is None or data['audio'].size(0) == 0:
                    continue
                t5_features = feature_extractor.encode_t5_text(data['caption_cot'])
                _save_batch(save_dir, data, t5_features.detach().cpu())
    finally:
        # Tear the process group down even if extraction fails.
        cleanup()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Extract Audio Training Latents')
    parser.add_argument('--root', type=str, default='dataset/vggsound/raw_audios/test', help='Root directory of the audio dataset')
    parser.add_argument('--tsv_path', type=str, default='cot_vgg_test_mix_coarse.csv', help='Path to the TSV file')
    parser.add_argument('--save-dir', type=str, default='vgg_cot_extra/cot_coarse', help='Save Directory')
    parser.add_argument('--sample_rate', type=int, default=44100, help='Sample rate of the audio')
    parser.add_argument('--duration_sec', type=float, default=9.0, help='Duration of the audio in seconds')
    parser.add_argument('--audio_samples', type=int, default=397312, help='Number of audio samples')
    parser.add_argument('--vae_ckpt', type=str, default='ckpts/vae.ckpt', help='Path to the VAE checkpoint')
    parser.add_argument('--vae_config', type=str, default='PrismAudio/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file')
    parser.add_argument('--synchformer_ckpt', type=str, default='ckpts/synchformer_state_dict.pth', help='Path to the Synchformer checkpoint')
    parser.add_argument('--start-row', type=int, default=0, help='start row')
    parser.add_argument('--end-row', type=int, default=None, help='end row')
    # torch.distributed.launch (without --use-env) appends --local_rank /
    # --local-rank to the script's arguments; accept it so argparse does not
    # reject the launch. The value is unused: main() reads the rank from the
    # RANK environment variable.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=None,
                        help='Accepted for torch.distributed.launch compatibility; ignored')
    args = parser.parse_args()
    # Launch one process per GPU with torchrun or torch.distributed.launch.
    # main() reads RANK and WORLD_SIZE from the environment set by the
    # launcher; they are not passed on the command line.
    main(args=args)
|