"""Single-clip inference demo.

Loads pre-extracted conditioning features from ``demo.npz`` inside
``args.results_dir``, runs the diffusion sampler once, and writes the result
to ``demo.wav``.
"""
import json
import os
import random
import re
import time
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torchaudio
from lightning.pytorch import seed_everything
from prefigure.prefigure import get_all_args, push_wandb_config

from PrismAudio.inference.sampling import sample, sample_discrete_euler
from PrismAudio.models import create_model_from_config
from PrismAudio.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model

# Sample rate used when writing the output wav file.
_OUTPUT_SAMPLE_RATE = 44100

# Objectives that predict_step knows how to sample from.
_SUPPORTED_OBJECTIVES = ("v", "rectified_flow")


def _latent_length(duration):
    """Return the latent sequence length used for a clip of ``duration`` seconds."""
    # NOTE(review): 44100 / 64 / 32 is presumably sample_rate / (downsampling
    # factors); it is hard-coded here rather than read from the model config.
    return round(44100 / 64 / 32 * duration)


@torch.no_grad()
def predict_step(diffusion, batch, diffusion_objective, device='cuda:0'):
    """Run one sampling pass and return int16 audio on the CPU.

    Args:
        diffusion: Diffusion wrapper exposing ``conditioner``,
            ``get_conditioning_inputs``, ``io_channels``, ``model`` and
            ``pretransform``.
        batch: Pair ``(reals, metadata)``. ``reals.shape[0]`` is used as the
            batch size and ``reals.shape[2]`` as the noise length; ``metadata``
            is an iterable of dicts that each carry ``'id'`` and
            ``'video_exist'``.
        diffusion_objective: ``"v"`` or ``"rectified_flow"``.
        device: Device the model and the noise are moved to.

    Returns:
        The generated audio scaled by its peak absolute value and converted
        to an ``int16`` CPU tensor.

    Raises:
        ValueError: If ``diffusion_objective`` is not supported.
    """
    # Fail early: previously an unknown objective surfaced later as an
    # UnboundLocalError on ``fakes``.
    if diffusion_objective not in _SUPPORTED_OBJECTIVES:
        raise ValueError(
            f"Unsupported diffusion_objective {diffusion_objective!r}; "
            f"expected one of {_SUPPORTED_OBJECTIVES}"
        )

    diffusion = diffusion.to(device)

    reals, metadata = batch
    ids = [item['id'] for item in metadata]
    batch_size, length = reals.shape[0], reals.shape[2]
    print(f"Predicting {batch_size} samples with length {length} for ids: {ids}")

    with torch.amp.autocast('cuda'):
        conditioning = diffusion.conditioner(metadata, device)

    # Items without video get the model's "empty" feature placeholders.
    video_exist = torch.stack([item['video_exist'] for item in metadata], dim=0)
    if 'metaclip_features' in conditioning:
        conditioning['metaclip_features'][~video_exist] = diffusion.model.model.empty_clip_feat
    if 'sync_features' in conditioning:
        conditioning['sync_features'][~video_exist] = diffusion.model.model.empty_sync_feat

    cond_inputs = diffusion.get_conditioning_inputs(conditioning)

    if batch_size > 1:
        # Draw the noise one sample at a time so every draw advances the RNG
        # state separately.
        noise = torch.cat(
            [
                torch.randn([1, diffusion.io_channels, length]).to(device)
                for _ in range(batch_size)
            ],
            dim=0,
        )
    else:
        noise = torch.randn([batch_size, diffusion.io_channels, length]).to(device)

    with torch.amp.autocast('cuda'):
        model = diffusion.model
        if diffusion_objective == "v":
            fakes = sample(model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True)
        else:  # "rectified_flow" (validated above)
            start_time = time.perf_counter()
            fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
            execution_time = time.perf_counter() - start_time
            print(f"执行时间: {execution_time:.2f} 秒")

        if diffusion.pretransform is not None:
            fakes = diffusion.pretransform.decode(fakes)

    # Peak-normalise over the whole batch. The peak is taken in float32 and
    # floored so that an all-zero output yields silence instead of NaN.
    fakes = fakes.to(torch.float32)
    peak = fakes.abs().max().clamp(min=1e-8)
    audios = fakes.div(peak).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
    return audios


def load_file(filename, info, latent_length):
    """Load an ``.npz`` feature file into ``info`` and build a zero latent.

    Args:
        filename: Path to the ``.npz`` archive.
        info: Dict that is updated in place with every entry of the archive;
            numeric arrays are converted to torch tensors, other entries are
            stored as loaded.
        latent_length: Size of the last dimension of the returned tensor.

    Returns:
        ``(audio, info)`` where ``audio`` is a float32 zero tensor of shape
        ``(1, 64, latent_length)`` and ``info`` additionally has
        ``'video_exist'`` set to a true tensor.

    Raises:
        ValueError: If ``filename`` does not exist.
    """
    npz_file = filename
    if not os.path.exists(npz_file):
        raise ValueError(f'error load file: {filename}')

    # SECURITY: allow_pickle=True can execute arbitrary code while unpickling
    # object arrays; only load feature files from trusted sources.
    # The context manager closes the underlying archive handle.
    with np.load(npz_file, allow_pickle=True) as npz_data:
        data = {key: npz_data[key] for key in npz_data.files}

    for key, value in data.items():
        if isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.number):
            data[key] = torch.from_numpy(value)

    info.update(data)
    audio = torch.zeros((1, 64, latent_length), dtype=torch.float32)
    info['video_exist'] = torch.tensor(True)
    return audio, info


def load(filename, duration):
    """Load the demo feature file and attach path metadata.

    Args:
        filename: Path to the ``.npz`` feature file.
        duration: Clip duration in seconds, used to size the zero latent.

    Returns:
        ``(audio, info)`` as produced by ``load_file``, with ``'path'``,
        ``'id'`` (the file stem) and ``'relpath'`` added to ``info``.

    Raises:
        FileNotFoundError: If ``filename`` does not exist.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if not os.path.exists(filename):
        raise FileNotFoundError(filename)

    info = {}
    audio, info = load_file(filename, info, _latent_length(duration))
    info["path"] = filename
    info['id'] = Path(filename).stem
    info["relpath"] = 'demo.npz'
    return (audio, info)


def main():
    """Parse arguments, build and load the model, and synthesise ``demo.wav``.

    Reads ``save_dir``, ``results_dir``, ``seed``, ``model_config``,
    ``duration_sec``, ``compile``, ``ckpt_dir``, ``pretransform_ckpt_path``
    and ``test_batch_size`` from the parsed arguments.
    """
    args = get_all_args()

    if args.save_dir == '':
        args.save_dir = args.results_dir

    seed = args.seed
    # Set a different seed for each process if using SLURM
    slurm_procid = os.environ.get("SLURM_PROCID")
    if slurm_procid is not None:
        seed += int(slurm_procid)
    seed_everything(seed, workers=True)

    # Get JSON config from args.model_config
    if args.model_config == '':
        args.model_config = "PrismAudio/configs/model_configs/thinksound.json"
    with open(args.model_config, encoding="utf-8") as f:
        model_config = json.load(f)

    duration = float(args.duration_sec)
    model_config["sample_size"] = duration * model_config["sample_rate"]

    # Resize the sequence-length settings to the requested duration, but only
    # those that the config already defines.
    diffusion_config = model_config["model"]["diffusion"]["config"]
    if "sync_seq_len" in diffusion_config:
        diffusion_config["sync_seq_len"] = 24 * int(duration)
    if "clip_seq_len" in diffusion_config:
        diffusion_config["clip_seq_len"] = 8 * int(duration)
    if "latent_seq_len" in diffusion_config:
        diffusion_config["latent_seq_len"] = _latent_length(duration)

    model = create_model_from_config(model_config)

    # Optional speed-up via torch.compile.
    # NOTE(review): compiling before load_state_dict may change the expected
    # state-dict key names (wrapper prefix) — confirm against the checkpoint.
    if args.compile:
        model = torch.compile(model)

    # Load onto CPU first so the checkpoint does not depend on the device it
    # was saved from; predict_step moves the model to the target device.
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(args.ckpt_dir, map_location="cpu"))

    load_vae_state = load_ckpt_state_dict(args.pretransform_ckpt_path, prefix='autoencoder.')
    model.pretransform.load_state_dict(load_vae_state)
    # NOTE(review): the model is never switched to eval() here — confirm
    # whether train-mode layers (e.g. dropout) are intended during inference.

    device = 'cuda:0'
    audio, meta = load(os.path.join(args.results_dir, "demo.npz"), duration)
    meta = {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in meta.items()
    }

    audio = predict_step(
        model,
        batch=[audio, (meta,)],
        diffusion_objective=model_config["model"]["diffusion"]["diffusion_objective"],
        device=device,
    )

    formatted_date = datetime.now().strftime('%m%d')
    audio_dir = os.path.join(
        args.save_dir, f'{formatted_date}_batch_size{args.test_batch_size}'
    )
    os.makedirs(audio_dir, exist_ok=True)
    torchaudio.save(os.path.join(audio_dir, "demo.wav"), audio[0], _OUTPUT_SAMPLE_RATE)


if __name__ == '__main__':
    main()