Resonate / eval.py
AndreasXi's picture
Update app without the binary image
e471447
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
import logging
import math
from argparse import ArgumentParser
from pathlib import Path
import torch
import torchaudio
from hydra import compose, initialize
import pandas as pd
from tqdm import tqdm
from resonate.eval_utils import generate_fm, setup_eval_logging
from resonate.model.flow_matching import FlowMatching
from resonate.model.networks import FluxAudio, get_model
from resonate.model.utils.features_utils import FeaturesUtils
from resonate.model.sequence_config import CONFIG_16K, CONFIG_44K
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from tqdm import tqdm
log = logging.getLogger()
@torch.inference_mode()
def main():
setup_eval_logging()
parser = ArgumentParser()
parser.add_argument('--config_name', type=str, required=True, help='config file name under config/ (e.g., train_config_online_feature_umt5.yaml)')
parser.add_argument('--eval_dataset', type=str, required=True, help='eval dataset name, e.g. librispeech-pc')
parser.add_argument('--cfg_strength', type=float, default=4.5)
parser.add_argument('--num_steps', type=int, default=25)
parser.add_argument('--output', type=Path, help='Output directory', default='./output')
parser.add_argument('--seed', type=int, help='Random seed', default=42)
parser.add_argument('--full_precision', action='store_true')
parser.add_argument('--model_path', type=str, help='Path of trained model')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--enable_speech_prompt', action='store_true', help='Whether or not explicitly instruct the speech')
args = parser.parse_args()
if args.debug:
import debugpy
debugpy.listen(6666)
print("Waiting for debugger attach (rank 0)...")
debugpy.wait_for_client()
with initialize(version_base="1.3.2", config_path="config"):
cfg = compose(config_name=args.config_name)
if cfg.audio_sample_rate == 16000:
seq_cfg = CONFIG_16K
elif cfg.audio_sample_rate == 44100:
seq_cfg = CONFIG_44K
else:
raise ValueError(f'Invalid audio sample rate: {cfg.audio_sample_rate}')
output_dir: str = args.output.expanduser()
seed: int = args.seed
num_steps: int = args.num_steps
cfg_strength: float = args.cfg_strength
device = 'cpu'
if torch.cuda.is_available():
device = 'cuda'
elif torch.backends.mps.is_available():
device = 'mps'
else:
log.warning('CUDA/MPS are not available, running on CPU')
dtype = torch.float32 if args.full_precision else torch.bfloat16
output_dir.mkdir(parents=True, exist_ok=True)
# load a pretrained model with train-style params (only overwrite model-loading related fields)
use_rope = cfg.get('use_rope', True)
text_dim = cfg.get('text_dim', None)
text_c_dim = cfg.get('text_c_dim', None)
net: FluxAudio = get_model(cfg.model,
use_rope=use_rope,
text_dim=text_dim,
text_c_dim=text_c_dim).to(device, dtype).eval()
net.load_weights(torch.load(args.model_path, map_location=device, weights_only=True))
log.info(f'Loaded weights from {args.model_path}')
net.update_seq_lengths(seq_cfg.latent_seq_len)
# misc setup
rng = torch.Generator(device=device)
rng.manual_seed(seed)
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
encoder_name = cfg.get('text_encoder_name', 'flan-t5')
if cfg.audio_sample_rate == 16000:
feature_utils = FeaturesUtils(tod_vae_ckpt=cfg.get('vae_16k_ckpt'),
enable_conditions=True,
encoder_name=encoder_name,
mode='16k',
bigvgan_vocoder_ckpt=cfg.get('bigvgan_vocoder_ckpt'),
need_vae_encoder=False)
elif cfg.audio_sample_rate == 44100:
feature_utils = FeaturesUtils(tod_vae_ckpt=cfg.get('vae_44k_ckpt'),
enable_conditions=True,
encoder_name=encoder_name,
mode='44k',
need_vae_encoder=False)
feature_utils = feature_utils.to(device, dtype).eval()
metadata_file = 'sets/acc_prompt.json'
import json
data = json.load(open(metadata_file, 'r'))
bsz = 16
for i in tqdm(range(0, len(data), bsz)):
batch = data[i:i + bsz]
audio_ids = [d['id'] for d in batch]
prompts = [d['prompt_text'] for d in batch]
for audio_id, prompt in zip(audio_ids, prompts):
log.info(f'Audio id: {audio_id} Prompt: {prompt}')
# batch generate
audios = generate_fm(
prompts,
feature_utils=feature_utils,
net=net,
fm=fm,
rng=rng,
cfg_strength=cfg_strength
)
audios = audios.float().cpu()
for audio_id, audio in zip(audio_ids, audios):
save_path = output_dir / f'{audio_id}.wav'
audio = audio.detach().cpu()
if audio.ndim == 1:
audio = audio.unsqueeze(0) # [1, T]
elif audio.ndim == 2:
pass # already [C, T]
elif audio.ndim == 3:
audio = audio.squeeze(0) # [1, T] or [C, T]
else:
raise RuntimeError(f"Unexpected audio shape: {audio.shape}")
torchaudio.save(save_path, audio, seq_cfg.sampling_rate)
if __name__ == '__main__':
main()