| from collections import defaultdict |
| import contextlib |
| import os |
| import datetime |
| from concurrent import futures |
| import time |
| import json |
| import hashlib |
| from accelerate import Accelerator |
| from accelerate.utils import set_seed, ProjectConfiguration |
| from accelerate.logging import get_logger |
| from diffusers.utils.torch_utils import is_compiled_module |
| import numpy as np |
| |
| |
| from flow_grpo.stat_tracking import PerPromptStatTracker |
| from flow_grpo.fluxaudio_pipeline_with_logprob import pipeline_with_logprob |
| from flow_grpo.fluxaudio_sde_with_logprob import sde_step_with_logprob |
| import torch |
| import wandb |
| import tempfile |
| from PIL import Image |
| from peft import LoraConfig, get_peft_model, set_peft_model_state_dict, PeftModel |
| import random |
| from torch.utils.data import Dataset, DataLoader, Sampler |
| from flow_grpo.ema import EMAModuleWrapper |
| from flow_grpo.rewards import multi_score |
|
|
| import soundfile as sf |
| from resonate.data.online_audio import format_variant1, format_variant2, format_variant3 |
| import hydra |
| import numpy as np |
| import torch |
| import torch.distributed as distributed |
| from hydra import compose |
| from hydra.core.hydra_config import HydraConfig |
| from omegaconf import DictConfig, open_dict |
| from torch.distributed.elastic.multiprocessing.errors import record |
| from resonate.model.sequence_config import CONFIG_16K, CONFIG_44K |
| from resonate.model.utils.features_utils import FeaturesUtils |
| from resonate.model.flow_matching import FlowMatching |
| from resonate.model.networks import get_model |
|
|
| logger = get_logger(__name__) |
|
|
def pbar(iterable, desc=None, position=0, leave=True, disable=True, **kwargs):
    """Wrap *iterable* in a tqdm progress bar unless *disable* is set.

    Disabled by default so that non-main ranks simply pass the iterable
    through untouched; tqdm is imported lazily only when actually needed.
    """
    if disable:
        return iterable
    from tqdm.auto import tqdm
    return tqdm(iterable, desc=desc, position=position, leave=leave, dynamic_ncols=True, **kwargs)
|
|
class AudioPromptDataset(Dataset):
    """Prompt/metadata pairs read from ``{split}_metadata.jsonl`` under *dataset*.

    Each JSONL record must carry a ``prompt`` field; the full record is kept
    as the sample's metadata.
    """

    def __init__(self, dataset, split='train'):
        self.file_path = os.path.join(dataset, f'{split}_metadata.jsonl')
        with open(self.file_path, 'r', encoding='utf-8') as handle:
            records = [json.loads(row) for row in handle]
        self.metadatas = records
        self.prompts = [record['prompt'] for record in records]

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return {"prompt": self.prompts[idx], "metadata": self.metadatas[idx]}

    @staticmethod
    def collate_fn(examples):
        """Batch dict samples into parallel (prompts, metadatas) lists."""
        prompts = []
        metadatas = []
        for example in examples:
            prompts.append(example["prompt"])
            metadatas.append(example["metadata"])
        return prompts, metadatas
| |
class AudioTemporalDataset(Dataset):
    """Temporal (phrase-level) prompts built from ``{split}_metadata.jsonl``.

    Each JSONL record must carry a ``phrases`` field; a prompt string is
    produced on the fly by one of the registered formatting functions.
    """

    def __init__(self, dataset, split='train'):
        self.file_path = os.path.join(dataset, f'{split}_metadata.jsonl')
        with open(self.file_path, 'r', encoding='utf-8') as handle:
            self.data = [json.loads(row) for row in handle]
        self.metadatas = [record['phrases'] for record in self.data]
        # Single active variant today; random.choice below keeps the door
        # open for mixing several prompt formats later.
        self.format_fn = [format_variant1]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        phrases = self.metadatas[idx]
        formatter = random.choice(self.format_fn)
        return {"prompt": formatter(phrases), "metadata": phrases}

    @staticmethod
    def collate_fn(examples):
        """Batch dict samples into parallel (prompts, metadatas) lists."""
        prompts = []
        metadatas = []
        for example in examples:
            prompts.append(example["prompt"])
            metadatas.append(example["metadata"])
        return prompts, metadatas
|
|
class DistributedKRepeatSampler(Sampler):
    """Infinite sampler that yields each chosen dataset index k times per
    round, sharded evenly across *num_replicas* processes.

    Every round draws m = num_replicas * batch_size / k distinct indices,
    repeats each exactly k times, shuffles the repeats, and hands this rank
    its contiguous slice of size *batch_size*.  All ranks must call
    :meth:`set_epoch` with the same value so they share one RNG stream.
    """

    def __init__(self, dataset, batch_size, k, num_replicas, rank, seed=0):
        self.dataset = dataset
        self.batch_size = batch_size
        self.k = k
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.total_samples = self.num_replicas * self.batch_size
        assert self.total_samples % self.k == 0, f"k can not divide n*b, k{k}-num_replicas{num_replicas}-batch_size{batch_size}"
        # Number of distinct indices drawn per round.
        self.m = self.total_samples // self.k
        self.epoch = 0

    def __iter__(self):
        while True:
            # Seed from (seed, epoch) so every replica draws identically.
            rng = torch.Generator()
            rng.manual_seed(self.seed + self.epoch)
            # m distinct dataset indices for this round.
            chosen = torch.randperm(len(self.dataset), generator=rng)[:self.m].tolist()
            # Each chosen index appears exactly k times overall.
            repeated = [index for index in chosen for _ in range(self.k)]
            # Shuffle so repeats are spread across replicas.
            order = torch.randperm(len(repeated), generator=rng).tolist()
            mixed = [repeated[pos] for pos in order]
            # This rank's contiguous slice.
            start = self.rank * self.batch_size
            yield mixed[start:start + self.batch_size]

    def set_epoch(self, epoch):
        # Advance the shared RNG stream; callers pass the same value on all ranks.
        self.epoch = epoch
|
|
|
|
def calculate_zero_std_ratio(prompts, gathered_rewards):
    """Fraction of unique prompts whose reward std is zero, plus the mean std.

    Args:
        prompts: list of prompt strings, one per sample.
        gathered_rewards: dict of reward arrays; only the 'ori_avg' entry is used.

    Returns:
        Tuple of (zero_std_ratio, mean std across unique prompts).
    """
    labels = np.array(prompts)
    _, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
    # Sort rewards so samples of the same prompt become contiguous, then cut
    # at the group boundaries given by the cumulative group sizes.
    ordered = gathered_rewards['ori_avg'][np.argsort(inverse)]
    groups = np.split(ordered, np.cumsum(counts)[:-1])
    stds = np.array([np.std(group) for group in groups])
    ratio = np.count_nonzero(stds == 0) / len(stds)
    return ratio, stds.mean()
|
|
def create_generator(prompts, base_seed):
    """Build one ``torch.Generator`` per prompt.

    The seed mixes *base_seed* with a stable sha256-derived hash of the prompt
    text, so identical prompts get identical noise regardless of batch order.
    """
    def _seed_for(text):
        digest = hashlib.sha256(text.encode()).digest()
        return (base_seed + int.from_bytes(digest[:4], 'big')) % (2**31)

    return [torch.Generator().manual_seed(_seed_for(prompt)) for prompt in prompts]
|
|
|
|
def compute_log_prob(transformer, timesteps, sample, j, embeds, pooled_embeds, config):
    """Re-run the transformer at denoising step *j* of a stored rollout and
    score the recorded transition under the current policy.

    Args:
        transformer: velocity/noise prediction network (callable with
            latent/text_f/text_f_c/t keyword args).
        timesteps: schedule tensor; ``timesteps[0]`` is handed to the SDE step.
        sample: dict of rollout tensors ("latents", "next_latents",
            "timesteps", ...), each indexed [batch, step].
        j: denoising-step index to evaluate.
        embeds / pooled_embeds: text conditioning.  When ``config.train.cfg``
            is on these already hold [negative; positive] stacked along the
            batch dimension.
        config: experiment config (reads train.cfg, sample.guidance_scale,
            sample.noise_level).

    Returns:
        ``(prev_sample, log_prob, prev_sample_mean, std_dev_t)`` as produced
        by ``sde_step_with_logprob``.
    """
    if config.train.cfg:
        # Classifier-free guidance: duplicate latents/timesteps to match the
        # [negative; positive] conditioning, then recombine the two halves.
        noise_pred = transformer(
            latent=torch.cat([sample["latents"][:, j]] * 2),
            text_f=embeds,
            text_f_c=pooled_embeds,
            t=torch.cat([sample["timesteps"][:, j]] * 2),
        )
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = (
            noise_pred_uncond
            + config.sample.guidance_scale
            * (noise_pred_text - noise_pred_uncond)
        )
    else:
        # No CFG: single forward pass on the un-duplicated batch.
        # Bug fix: the previous version still did torch.cat([...] * 2) here,
        # doubling latents/timesteps while the conditioning stayed single —
        # a batch mismatch, and a doubled prediction fed into the SDE step.
        noise_pred = transformer(
            latent=sample["latents"][:, j],
            text_f=embeds,
            text_f_c=pooled_embeds,
            t=sample["timesteps"][:, j],
        )

    # Likelihood of the recorded next_latents under the current policy's
    # SDE transition kernel at this step.
    prev_sample, log_prob, prev_sample_mean, std_dev_t = sde_step_with_logprob(
        timesteps[0],
        noise_pred.float(),
        sample["timesteps"][:, j],
        sample["latents"][:, j].float(),
        noise_level=config.sample.noise_level,
        prev_sample=sample["next_latents"][:, j].float(),
    )

    return prev_sample, log_prob, prev_sample_mean, std_dev_t
|
|
def eval(pipeline, test_dataloader, text_encoders, tokenizers, config, accelerator, global_step, reward_fn, executor, autocast, num_train_timesteps, ema, transformer_trainable_parameters):
    """Score the test split with *reward_fn* and log rewards + samples to wandb.

    NOTE(review): this shadows the builtin ``eval`` (renaming would require
    updating the call in ``train``).  The body still follows an image
    pipeline (``images``, PIL resizing, SD3-style text encoders) rather than
    the audio path used elsewhere in this file, and ``compute_text_embeddings``
    is not defined in this module — confirm it is provided before enabling
    ``config.do_eval``.
    """
    if config.train.ema:
        # Temporarily evaluate with EMA weights; restored at the bottom.
        ema.copy_ema_to(transformer_trainable_parameters, store_temp=True)
    # Negative (empty-string) conditioning shared across all eval batches.
    neg_prompt_embed, neg_pooled_prompt_embed = compute_text_embeddings([""], text_encoders, tokenizers, max_sequence_length=128, device=accelerator.device)

    sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.test_batch_size, 1, 1)
    sample_neg_pooled_prompt_embeds = neg_pooled_prompt_embed.repeat(config.sample.test_batch_size, 1)

    all_rewards = defaultdict(list)
    for test_batch in pbar(
        test_dataloader,
        desc="Eval: ",
        disable=not accelerator.is_main_process,
        position=0,
    ):
        prompts, prompt_metadata = test_batch
        prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
            prompts,
            text_encoders,
            tokenizers,
            max_sequence_length=128,
            device=accelerator.device
        )
        # The final batch may be smaller than test_batch_size; truncate the
        # precomputed negative embeddings to match.
        if len(prompt_embeds)<len(sample_neg_prompt_embeds):
            sample_neg_prompt_embeds = sample_neg_prompt_embeds[:len(prompt_embeds)]
            sample_neg_pooled_prompt_embeds = sample_neg_pooled_prompt_embeds[:len(prompt_embeds)]
        with autocast():
            with torch.no_grad():
                # noise_level=0: deterministic (ODE-like) sampling for eval.
                images, _, _ = pipeline_with_logprob(
                    pipeline,
                    prompt_embeds=prompt_embeds,
                    pooled_prompt_embeds=pooled_prompt_embeds,
                    negative_prompt_embeds=sample_neg_prompt_embeds,
                    negative_pooled_prompt_embeds=sample_neg_pooled_prompt_embeds,
                    num_inference_steps=config.sample.eval_num_steps,
                    guidance_scale=config.sample.guidance_scale,
                    output_type="pt",
                    height=config.resolution,
                    width=config.resolution,
                    noise_level=0,
                )
        # Submitted to the executor but .result() is awaited immediately, so
        # scoring is effectively synchronous here (parity with the train path).
        rewards = executor.submit(reward_fn, images, prompts, prompt_metadata, only_strict=False)
        time.sleep(0)  # yield so the worker thread can start
        rewards, reward_metadata = rewards.result()

        for key, value in rewards.items():
            rewards_gather = accelerator.gather(torch.as_tensor(value, device=accelerator.device)).cpu().numpy()
            all_rewards[key].append(rewards_gather)

    # Only the final batch's samples are kept for the wandb image panel;
    # `images`, `prompts`, `rewards` still hold the last loop iteration's values.
    last_batch_images_gather = accelerator.gather(torch.as_tensor(images, device=accelerator.device)).cpu().numpy()
    last_batch_prompt_ids = tokenizers[0](
        prompts,
        padding="max_length",
        max_length=256,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(accelerator.device)
    last_batch_prompt_ids_gather = accelerator.gather(last_batch_prompt_ids).cpu().numpy()
    last_batch_prompts_gather = pipeline.tokenizer.batch_decode(
        last_batch_prompt_ids_gather, skip_special_tokens=True
    )
    last_batch_rewards_gather = {}
    for key, value in rewards.items():
        last_batch_rewards_gather[key] = accelerator.gather(torch.as_tensor(value, device=accelerator.device)).cpu().numpy()

    all_rewards = {key: np.concatenate(value) for key, value in all_rewards.items()}
    if accelerator.is_main_process:
        with tempfile.TemporaryDirectory() as tmpdir:
            num_samples = min(15, len(last_batch_images_gather))
            sample_indices = range(num_samples)
            for idx, index in enumerate(sample_indices):
                image = last_batch_images_gather[index]
                # CHW float in [0,1] -> HWC uint8; presumably the pipeline's
                # "pt" output convention — confirm for the audio variant.
                pil = Image.fromarray(
                    (image.transpose(1, 2, 0) * 255).astype(np.uint8)
                )
                pil = pil.resize((config.resolution, config.resolution))
                pil.save(os.path.join(tmpdir, f"{idx}.jpg"))
            sampled_prompts = [last_batch_prompts_gather[index] for index in sample_indices]
            sampled_rewards = [{k: last_batch_rewards_gather[k][index] for k in last_batch_rewards_gather} for index in sample_indices]
            for key, value in all_rewards.items():
                print(key, value.shape)
            # -10 appears to act as an "invalid reward" sentinel and is
            # excluded from captions and means — confirm against reward_fn.
            wandb.log(
                {
                    "eval_images": [
                        wandb.Image(
                            os.path.join(tmpdir, f"{idx}.jpg"),
                            caption=f"{prompt:.1000} | " + " | ".join(f"{k}: {v:.2f}" for k, v in reward.items() if v != -10),
                        )
                        for idx, (prompt, reward) in enumerate(zip(sampled_prompts, sampled_rewards))
                    ],
                    **{f"eval_reward_{key}": np.mean(value[value != -10]) for key, value in all_rewards.items()},
                },
                step=global_step,
            )
    if config.train.ema:
        # Restore the live training weights stashed by copy_ema_to above.
        ema.copy_temp_to(transformer_trainable_parameters)
|
|
def unwrap_model(model, accelerator):
    """Strip the accelerate wrapper and, if present, the torch.compile wrapper."""
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        return unwrapped._orig_mod
    return unwrapped
|
|
def save_ckpt(save_dir, transformer, global_step, accelerator, ema, transformer_trainable_parameters, config):
    """Save the transformer weights as ``model_{global_step}.pth`` in *save_dir*.

    Only the raw ``state_dict`` is stored; optimizer/EMA state is not
    checkpointed here.  Callers gate this on ``accelerator.is_main_process``.

    Args:
        save_dir: target directory (created if missing).
        transformer: possibly accelerate/DDP-wrapped model.
        global_step: step number used to name the checkpoint file.
        accelerator: used to unwrap the model before serialization.
        ema, transformer_trainable_parameters, config: unused here; kept so
            the call site in ``train`` stays unchanged.
    """
    os.makedirs(save_dir, exist_ok=True)
    model_path = os.path.join(save_dir, f'model_{global_step}.pth')
    # Bug fix: the previous version read transformer.module.state_dict(),
    # which raises AttributeError whenever the model is not DDP-wrapped
    # (e.g. single-process runs). unwrap_model handles both cases.
    model_to_save = accelerator.unwrap_model(transformer)
    torch.save(model_to_save.state_dict(), model_path)
    logger.info(f'Network weights saved to {model_path}.')
|
|
@record
@hydra.main(version_base='1.3.2', config_path='config', config_name='train_config.yaml')
def train(cfg: DictConfig):
    """Flow-GRPO training entry point for the text-to-audio flow-matching model.

    Per outer epoch: optionally evaluate and checkpoint, sample rollouts with
    per-step log-probabilities, score the generated audio asynchronously with
    the reward function, convert rewards into per-prompt advantages, then run
    PPO-style updates over the stored transitions.  Fully driven by the Hydra
    config *cfg*.
    """
    # Optional remote debugger; only rank 0 listens.
    if cfg.get("debug", False):
        import debugpy
        if "RANK" not in os.environ or int(os.environ["RANK"]) == 0:
            debugpy.listen(6665)
            print(f'Waiting for debugger attach (rank {os.environ["RANK"]})...')
            debugpy.wait_for_client()

    # `config` and `cfg` alias the same DictConfig and are used interchangeably.
    config = cfg
    unique_id = datetime.datetime.now().strftime("%Y.%m.%d_%H.%M.%S")  # currently unused

    # Number of denoising steps actually trained on per rollout.
    num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction)

    save_dir = os.path.join(HydraConfig.get().run.dir)
    accelerator_config = ProjectConfiguration(
        project_dir=save_dir,
        automatic_checkpoint_naming=True,
        total_limit=config.num_checkpoint_limit,
    )

    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        project_config=accelerator_config,
        # Accumulate across every trained timestep of a batch before stepping,
        # so one optimizer step corresponds to one (batch x timesteps) sweep.
        gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps,
    )
    if accelerator.is_main_process and config.use_wandb:
        wandb.init(
            project="flow_grpo",
            name = cfg.exp_id,
        )
    logger.info(f"\n{config}")
    set_seed(config.seed, device_specific=True)

    # Pick the sequence config matching the target audio sample rate.
    if cfg.audio_sample_rate == 16000:
        mode = '16k'
        seq_cfg = CONFIG_16K
        sample_rate = seq_cfg.sampling_rate
        duration_sec = seq_cfg.duration
        logger.info(f'Using 16k mode for sequence config')
    elif config.audio_sample_rate == 44100:
        mode = '44k'
        seq_cfg = CONFIG_44K
        sample_rate = seq_cfg.sampling_rate
        duration_sec = seq_cfg.duration
        logger.info(f'Using 44k mode for sequence config')
    else:
        # NOTE(review): `mode` is unbound on this branch, so this raises
        # NameError rather than the intended ValueError — it should report
        # cfg.audio_sample_rate instead.
        raise ValueError(f'Invalid mode: {mode}')

    inference_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        inference_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        inference_dtype = torch.bfloat16

    text_encoder_name = cfg['text_encoder_name']
    need_vae_encoder = cfg.get('online_feature_extraction', False)

    # Feature utilities: VAE (+ vocoder in 16k mode) and text encoder.
    if mode == '16k':
        features = FeaturesUtils(
            tod_vae_ckpt=cfg['vae_16k_ckpt'],
            bigvgan_vocoder_ckpt=cfg['bigvgan_vocoder_ckpt'],
            encoder_name=text_encoder_name,
            enable_conditions=True,
            mode=mode,
            need_vae_encoder=need_vae_encoder,
        )
        vae_sr = 16000
    elif mode == '44k':
        features = FeaturesUtils(
            tod_vae_ckpt=cfg['vae_44k_ckpt'],
            encoder_name=text_encoder_name,
            enable_conditions=True,
            mode=mode,
            need_vae_encoder=need_vae_encoder,
        )
        # NOTE(review): 44000 here vs. 44100 used to select this branch —
        # confirm whether the VAE truly runs at 44 kHz or this is a typo.
        vae_sr = 44000
    features = features.eval()
    features.to(accelerator.device, dtype=torch.float32)

    # Cache the empty-string (negative) text embedding once.
    with torch.no_grad():
        neg_prompt_embed, neg_pooled_prompt_embed = features.encode_text([''])
        neg_prompt_embed = neg_prompt_embed[0]
        if neg_pooled_prompt_embed is not None:
            neg_pooled_prompt_embed = neg_pooled_prompt_embed[0]
        else:
            # Fall back to mean-pooling the sequence embedding when the
            # encoder provides no pooled output.
            neg_pooled_prompt_embed = neg_prompt_embed.mean(dim=0)

    train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1)
    train_neg_pooled_prompt_embeds = neg_pooled_prompt_embed.repeat(config.train.batch_size, 1)

    if cfg.compile:
        features.compile()

    logger.info(f'Computed empty string feature for {text_encoder_name}')
    # Latent normalization statistics baked into the model.
    latent_mean, latent_std = torch.load(cfg.latent_mean), torch.load(cfg.latent_std)
    transformer = get_model(cfg.model,
                            text_dim=cfg.text_dim,
                            text_c_dim=cfg.text_c_dim,
                            latent_mean=latent_mean,
                            latent_std=latent_std,
                            empty_string_feat=neg_prompt_embed,
                            empty_string_feat_c=neg_pooled_prompt_embed,
                            use_rope=cfg.use_rope)
    transformer.load_weights(torch.load(cfg.weight, map_location=accelerator.device, weights_only=True))
    transformer.to(accelerator.device, dtype=inference_dtype)
    logger.info(f'Loaded weights from {cfg.weight}')

    # NOTE(review): `fm` is constructed but never referenced below — confirm
    # whether pipeline_with_logprob is expected to consume it.
    fm = FlowMatching(cfg.sampling.min_sigma,
                      inference_mode=cfg.sampling.method,
                      num_steps=cfg.sampling.num_steps)

    if config.use_lora:
        # LoRA on the attention projections only.
        target_modules = [
            "attn.add_k_proj",
            "attn.add_q_proj",
            "attn.add_v_proj",
            "attn.to_add_out",
            "attn.to_k",
            "attn.to_out.0",
            "attn.to_q",
            "attn.to_v",
        ]
        transformer_lora_config = LoraConfig(
            r=32,
            lora_alpha=64,
            init_lora_weights="gaussian",
            target_modules=target_modules,
        )
        if config.train.lora_path:
            # Resume from previously trained adapters.
            transformer = PeftModel.from_pretrained(transformer, config.train.lora_path)
            transformer.set_adapter("default")
        else:
            transformer = get_peft_model(transformer, transformer_lora_config)
    else:
        # Full fine-tuning: keep a frozen reference copy for the KL term.
        import copy
        ref_transformer = copy.deepcopy(transformer)
        for p in ref_transformer.parameters():
            p.requires_grad = False
        ref_transformer.eval()

    transformer_trainable_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
    ema = EMAModuleWrapper(transformer_trainable_parameters, decay=0.9, update_step_interval=8, device=accelerator.device)

    if config.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if config.train.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        transformer_trainable_parameters,
        lr=config.train.learning_rate,
        betas=(config.train.adam_beta1, config.train.adam_beta2),
        weight_decay=config.train.adam_weight_decay,
        eps=config.train.adam_epsilon,
    )

    # NOTE(review): later code calls config.reward_fn.keys(), which suggests
    # reward_fn is dict-like; this string equality would then never match —
    # confirm the config schema.
    if config.reward_fn == "qwen3_omni_thinking_semantic_align_score":
        reward_fn = multi_score("auto", config.reward_fn)
    else:
        reward_fn = multi_score(accelerator.device, config.reward_fn)
    eval_reward_fn = multi_score(accelerator.device, config.reward_fn)

    # Dataset / sampler / loader setup; the train sampler repeats each prompt
    # num_audio_per_prompt times so GRPO can form per-prompt groups.
    if config.prompt_fn == "audioprompt":
        train_dataset = AudioPromptDataset(config.dataset, 'train')
        test_dataset = AudioPromptDataset(config.dataset, 'test')

        train_sampler = DistributedKRepeatSampler(
            dataset=train_dataset,
            batch_size=config.sample.train_batch_size,
            k=config.sample.num_audio_per_prompt,
            num_replicas=accelerator.num_processes,
            rank=accelerator.process_index,
            seed=42
        )

        train_dataloader = DataLoader(
            train_dataset,
            batch_sampler=train_sampler,
            num_workers=1,
            collate_fn=AudioPromptDataset.collate_fn,
        )
        test_dataloader = DataLoader(
            test_dataset,
            batch_size=config.sample.test_batch_size,
            collate_fn=AudioPromptDataset.collate_fn,
            shuffle=False,
            num_workers=8,
        )
    elif config.prompt_fn == "audio_temporal_prompt":
        train_dataset = AudioTemporalDataset(config.dataset, 'train')
        test_dataset = AudioTemporalDataset(config.dataset, 'test')
        train_sampler = DistributedKRepeatSampler(
            dataset=train_dataset,
            batch_size=config.sample.train_batch_size,
            k=config.sample.num_audio_per_prompt,
            num_replicas=accelerator.num_processes,
            rank=accelerator.process_index,
            seed=42
        )
        train_dataloader = DataLoader(
            train_dataset,
            batch_sampler=train_sampler,
            num_workers=1,
            collate_fn=AudioTemporalDataset.collate_fn,
        )
        test_dataloader = DataLoader(
            test_dataset,
            batch_size=config.sample.test_batch_size,
            collate_fn=AudioTemporalDataset.collate_fn,
            shuffle=False,
            num_workers=8,
        )
    else:
        raise NotImplementedError(f"Unrecognized prompt_fn: {config.prompt_fn}")

    # With a single sample per prompt there are no groups to normalize over.
    if config.sample.num_audio_per_prompt == 1:
        config.per_prompt_stat_tracking = False
    if config.per_prompt_stat_tracking:
        stat_tracker = PerPromptStatTracker(config.sample.global_std)

    # LoRA keeps the base weights in their inference dtype, so no autocast is
    # needed; full fine-tuning relies on accelerate's autocast.
    autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast

    transformer, optimizer, train_dataloader, test_dataloader = accelerator.prepare(transformer, optimizer, train_dataloader, test_dataloader)

    # Thread pool so reward scoring overlaps with sampling.
    executor = futures.ThreadPoolExecutor(max_workers=8)

    samples_per_epoch = (
        config.sample.train_batch_size
        * accelerator.num_processes
        * config.sample.num_batches_per_epoch
    )
    total_train_batch_size = (
        config.train.batch_size
        * accelerator.num_processes
        * config.train.gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f" Sample batch size per device = {config.sample.train_batch_size}")
    logger.info(f" Train batch size per device = {config.train.batch_size}")
    logger.info(
        f" Gradient Accumulation steps = {config.train.gradient_accumulation_steps}"
    )
    logger.info("")
    logger.info(f" Total number of samples per epoch = {samples_per_epoch}")
    logger.info(
        f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}"
    )
    logger.info(
        f" Number of gradient updates per inner epoch = {samples_per_epoch // total_train_batch_size}"
    )
    logger.info(f" Number of inner epochs = {config.train.num_inner_epochs}")

    epoch = 0
    global_step = 0
    train_iter = iter(train_dataloader)

    while epoch < config.total_epoch:
        #################### EVAL / CHECKPOINT ####################
        transformer.eval()
        if epoch % config.eval_freq == 0 and config.do_eval:
            # NOTE(review): `pipeline`, `text_encoders` and `tokenizers` are
            # not defined anywhere in this function — enabling config.do_eval
            # raises NameError.  The eval() above also expects the image
            # pipeline; an audio-specific eval is presumably needed.
            eval(pipeline, test_dataloader, text_encoders, tokenizers, config, accelerator, global_step, eval_reward_fn, executor, autocast, num_train_timesteps, ema, transformer_trainable_parameters)
        if epoch % config.save_freq == 0 and epoch > 0 and accelerator.is_main_process:
            save_ckpt(save_dir, transformer, global_step, accelerator, ema, transformer_trainable_parameters, config)

        #################### SAMPLING ####################
        transformer.eval()
        samples = []
        prompts = []

        for i in pbar(
            range(config.sample.num_batches_per_epoch),
            desc=f"Epoch {epoch}: sampling",
            disable=not accelerator.is_main_process,
            position=0,
        ):
            # Advance the shared sampler RNG so every rank draws from the
            # same per-batch prompt set.
            train_sampler.set_epoch(epoch * config.sample.num_batches_per_epoch + i)
            prompts, prompt_metadata = next(train_iter)
            if epoch < 3 and accelerator.is_main_process:
                logger.info(f"[DEBUG] Sampled prompts: {prompts}")

            prompt_embeds, pooled_prompt_embeds = features.encode_text(prompts)
            # Token ids are kept only to recover prompt strings after the
            # cross-process gather.
            prompt_ids = features.tokenizer(
                prompts,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt",
            ).input_ids.to(accelerator.device)

            # same_latent: identical prompts share the same initial noise.
            if config.sample.same_latent:
                generator = create_generator(prompts, base_seed=epoch*10000+i)
            else:
                generator = None
            with autocast():
                with torch.no_grad():
                    latents, log_probs, timesteps = pipeline_with_logprob(
                        transformer,
                        prompt_embeds=prompt_embeds,
                        pooled_prompt_embeds=pooled_prompt_embeds,
                        num_inference_steps=config.sample.num_steps,
                        guidance_scale=config.sample.guidance_scale,
                        noise_level=config.sample.noise_level,
                        generator=generator
                    )
            # Decode the final latent to audio for reward scoring.
            last_latent = latents[-1]
            mel = features.decode(last_latent)
            audios = features.vocode(mel)
            audios = audios.squeeze(1)
            # Stack the per-step latent trajectory: (batch, num_steps + 1, ...).
            latents = torch.stack(
                latents, dim=1
            )
            log_probs = torch.stack(log_probs, dim=1)
            # Broadcast the shared schedule to one row per sample.
            timesteps = timesteps.repeat(config.sample.train_batch_size, 1)

            # Score asynchronously; the future is resolved after the sampling loop.
            rewards = executor.submit(reward_fn, audios, prompts, prompt_metadata, vae_sr=vae_sr, only_strict=True)
            time.sleep(0)  # yield so the worker thread can start

            samples.append(
                {
                    "prompt_ids": prompt_ids,
                    "prompt_embeds": prompt_embeds,
                    "pooled_prompt_embeds": pooled_prompt_embeds,
                    "timesteps": timesteps,
                    # States x_0..x_{T-1} (inputs of each denoising step)...
                    "latents": latents[
                        :, :-1
                    ],
                    # ...and x_1..x_T (outputs of each denoising step).
                    "next_latents": latents[
                        :, 1:
                    ],
                    "log_probs": log_probs,
                    "rewards": rewards,
                }
            )

        # Resolve all reward futures and convert to tensors.
        for sample in pbar(
            samples,
            desc="Waiting for rewards",
            disable=not accelerator.is_main_process,
            position=0,
        ):
            rewards, reward_metadata = sample["rewards"].result()
            sample["rewards"] = {
                key: torch.as_tensor(value, device=accelerator.device).float()
                for key, value in rewards.items()
            }

        # Collate list-of-dicts into a single dict of tensors (rewards is a
        # nested dict and is concatenated per key).
        samples = {
            k: torch.cat([s[k] for s in samples], dim=0)
            if not isinstance(samples[0][k], dict)
            else {
                sub_key: torch.cat([s[k][sub_key] for s in samples], dim=0)
                for sub_key in samples[0][k]
            }
            for k in samples[0].keys()
        }

        # Keep the raw per-sample reward, then tile "avg" over the trained
        # timesteps so each step gets a reward column.
        samples["rewards"]["ori_avg"] = samples["rewards"]["avg"]
        samples["rewards"]["avg"] = samples["rewards"]["avg"].unsqueeze(1).repeat(1, num_train_timesteps)
        gathered_rewards = {key: accelerator.gather(value) for key, value in samples["rewards"].items()}
        gathered_rewards = {key: value.cpu().numpy() for key, value in gathered_rewards.items()}
        # Debug dump of rollout audio + rewards for the first few epochs.
        if epoch < 3 and accelerator.is_main_process:
            rollout_root = os.path.join(save_dir, "flowgrpo_rollout")
            os.makedirs(rollout_root, exist_ok=True)
            epoch_dir = os.path.join(rollout_root, f"epoch_{epoch:04d}")
            audio_dir = os.path.join(epoch_dir, "audio")
            os.makedirs(audio_dir, exist_ok=True)

            jsonl_path = os.path.join(epoch_dir, "metadata.jsonl")

            with open(jsonl_path, "a", encoding="utf-8") as f:
                # NOTE(review): `audios`/`prompts` hold only the LAST sampling
                # batch, while samples['rewards'] is concatenated over all
                # batches — index b pairs last-batch audio with first-batch
                # rewards unless num_batches_per_epoch == 1; confirm.
                for b in range(audios.shape[0]):
                    audio_path = os.path.join(
                        audio_dir,
                        f"epoch{epoch:04d}_batch{i:03d}_sample{b:02d}.wav"
                    )

                    sf.write(
                        audio_path,
                        audios[b].detach().cpu().numpy(),
                        samplerate=vae_sr,
                    )
                    # Shadows the imported `record` decorator inside this
                    # function body (harmless here, but worth renaming).
                    record = {
                        "epoch": epoch,
                        "batch_idx": i,
                        "sample_idx": b,
                        "audio_path": audio_path,
                        "prompt": prompts[b]
                    }
                    for reward_key in config.reward_fn.keys():
                        record[reward_key] = float(samples['rewards'][reward_key][b].item())
                    f.write(json.dumps(record, ensure_ascii=False) + "\n")
            logger.info(f"[DEBUG] Samples saved for epoch {epoch}")

        if accelerator.is_main_process and config.use_wandb:
            wandb.log(
                {
                    "epoch": epoch,
                    **{f"reward_{key}": value.mean() for key, value in gathered_rewards.items() if '_strict_accuracy' not in key and '_accuracy' not in key},
                },
                step=global_step,
            )
        if global_step % config.log_text_interval == 0 and accelerator.is_main_process:
            for reward_key in config.reward_fn.keys():
                logger.info(f"Global step: {global_step}, gathered rewards: {gathered_rewards[reward_key].mean()}")

        #################### ADVANTAGES ####################
        if config.per_prompt_stat_tracking:
            # Recover prompt strings from token ids gathered across ranks so
            # per-prompt statistics cover the whole global batch.
            prompt_ids = accelerator.gather(samples["prompt_ids"]).cpu().numpy()
            prompts = features.tokenizer.batch_decode(
                prompt_ids, skip_special_tokens=True
            )
            advantages = stat_tracker.update(prompts, gathered_rewards['avg'])
            if accelerator.is_main_process:
                print("len(prompts)", len(prompts))
                print("len unique prompts", len(set(prompts)))

            group_size, trained_prompt_num = stat_tracker.get_stats()

            zero_std_ratio, reward_std_mean = calculate_zero_std_ratio(prompts, gathered_rewards)

            if accelerator.is_main_process and config.use_wandb:
                wandb.log(
                    {
                        "group_size": group_size,
                        "trained_prompt_num": trained_prompt_num,
                        "zero_std_ratio": zero_std_ratio,
                        "reward_std_mean": reward_std_mean,
                    },
                    step=global_step,
                )
            stat_tracker.clear()
        else:
            # Global whitening when per-prompt grouping is disabled.
            advantages = (gathered_rewards['avg'] - gathered_rewards['avg'].mean()) / (gathered_rewards['avg'].std() + 1e-4)

        # Take back this rank's shard; accelerator.gather concatenates in
        # rank order, so reshaping by num_processes recovers the local slice.
        advantages = torch.as_tensor(advantages)
        samples["advantages"] = (
            advantages.reshape(accelerator.num_processes, -1, advantages.shape[-1])[accelerator.process_index]
            .to(accelerator.device)
        )
        if accelerator.is_main_process:
            print("advantages: ", samples["advantages"].abs().mean())

        del samples["rewards"]
        del samples["prompt_ids"]

        # Drop samples whose advantage is identically zero (no gradient
        # signal), but pad the kept count back up so it stays divisible by
        # num_batches_per_epoch for the reshape below.
        mask = (samples["advantages"].abs().sum(dim=1) != 0)

        num_batches = config.sample.num_batches_per_epoch
        true_count = mask.sum()
        if true_count % num_batches != 0:
            false_indices = torch.where(~mask)[0]
            num_to_change = num_batches - (true_count % num_batches)
            if len(false_indices) >= num_to_change:
                random_indices = torch.randperm(len(false_indices))[:num_to_change]
                mask[false_indices[random_indices]] = True
        if accelerator.is_main_process and config.use_wandb:
            wandb.log(
                {
                    "actual_batch_size": mask.sum().item()//config.sample.num_batches_per_epoch,
                },
                step=global_step,
            )
        samples = {k: v[mask] for k, v in samples.items()}

        # Shapes after filtering; the stored schedule includes the terminal
        # timestep, hence num_steps + 1 columns.
        total_batch_size, num_timesteps = samples["timesteps"].shape
        assert num_timesteps == config.sample.num_steps+1

        #################### TRAINING ####################
        for inner_epoch in range(config.train.num_inner_epochs):
            # Shuffle transitions along the batch dimension.
            perm = torch.randperm(total_batch_size, device=accelerator.device)
            samples = {k: v[perm] for k, v in samples.items()}

            # Rebatch into num_batches_per_epoch training minibatches.
            samples_batched = {
                k: v.reshape(-1, total_batch_size//config.sample.num_batches_per_epoch, *v.shape[1:])
                for k, v in samples.items()
            }

            # dict of lists -> list of dicts, one per minibatch.
            samples_batched = [
                dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())
            ]

            transformer.train()
            info = defaultdict(list)
            for i, sample in pbar(
                list(enumerate(samples_batched)),
                desc=f"Epoch {epoch}.{inner_epoch}: training",
                position=0,
                disable=not accelerator.is_main_process,
            ):
                if config.train.cfg:
                    # Prepend the negative conditioning to match the CFG
                    # branch inside compute_log_prob.
                    embeds = torch.cat(
                        [train_neg_prompt_embeds[:len(sample["prompt_embeds"])], sample["prompt_embeds"]]
                    )
                    pooled_embeds = torch.cat(
                        [train_neg_pooled_prompt_embeds[:len(sample["pooled_prompt_embeds"])], sample["pooled_prompt_embeds"]]
                    )
                else:
                    embeds = sample["prompt_embeds"]
                    pooled_embeds = sample["pooled_prompt_embeds"]

                train_timesteps = [step_index for step_index in range(num_train_timesteps)]

                for j in train_timesteps:
                    with accelerator.accumulate(transformer):
                        with autocast():
                            # NOTE(review): `timesteps` here is the schedule
                            # tensor left over from the last sampling batch,
                            # reused across all minibatches — confirm the
                            # schedule is identical for every batch.
                            prev_sample, log_prob, prev_sample_mean, std_dev_t = compute_log_prob(transformer, timesteps, sample, j, embeds, pooled_embeds, config)
                            if config.train.beta > 0:
                                # Reference-policy mean for the KL penalty:
                                # adapters disabled under LoRA, otherwise the
                                # frozen deep copy.
                                with torch.no_grad():
                                    if config.use_lora:
                                        with transformer.module.disable_adapter():
                                            _, _, prev_sample_mean_ref, _ = compute_log_prob(transformer, timesteps, sample, j, embeds, pooled_embeds, config)
                                    else:
                                        _, _, prev_sample_mean_ref, _ = compute_log_prob(ref_transformer, timesteps, sample, j, embeds, pooled_embeds, config)

                        # PPO clipped surrogate on the per-step log-prob ratio.
                        advantages = torch.clamp(
                            sample["advantages"][:, j],
                            -config.train.adv_clip_max,
                            config.train.adv_clip_max,
                        )
                        ratio = torch.exp(log_prob - sample["log_probs"][:, j])
                        unclipped_loss = -advantages * ratio
                        clipped_loss = -advantages * torch.clamp(
                            ratio,
                            1.0 - config.train.clip_range,
                            1.0 + config.train.clip_range,
                        )
                        policy_loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss))
                        if config.train.beta > 0:
                            # Gaussian KL between current and reference step
                            # distributions (shared std), averaged per sample.
                            kl_loss = ((prev_sample_mean - prev_sample_mean_ref) ** 2).mean(dim=(1,2), keepdim=True) / (2 * std_dev_t ** 2)
                            kl_loss = torch.mean(kl_loss)
                            loss = policy_loss + config.train.beta * kl_loss
                        else:
                            loss = policy_loss

                        info["approx_kl"].append(
                            0.5
                            * torch.mean((log_prob - sample["log_probs"][:, j]) ** 2)
                        )
                        info["clipfrac"].append(
                            torch.mean(
                                (
                                    torch.abs(ratio - 1.0) > config.train.clip_range
                                ).float()
                            )
                        )
                        info["clipfrac_gt_one"].append(
                            torch.mean(
                                (
                                    ratio - 1.0 > config.train.clip_range
                                ).float()
                            )
                        )
                        info["clipfrac_lt_one"].append(
                            torch.mean(
                                (
                                    1.0 - ratio > config.train.clip_range
                                ).float()
                            )
                        )
                        info["policy_loss"].append(policy_loss)
                        if config.train.beta > 0:
                            info["kl_loss"].append(kl_loss)

                        info["loss"].append(loss)

                        # Backprop; optimizer actually steps only when
                        # accumulation completes (sync_gradients is True).
                        accelerator.backward(loss)
                        if accelerator.sync_gradients:
                            accelerator.clip_grad_norm_(
                                transformer.parameters(), config.train.max_grad_norm
                            )
                        optimizer.step()
                        optimizer.zero_grad()

                    # Log and reset accumulated metrics once per real
                    # optimizer step.
                    if accelerator.sync_gradients:
                        info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                        info = accelerator.reduce(info, reduction="mean")
                        info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                        if accelerator.is_main_process and config.use_wandb:
                            wandb.log(info, step=global_step)
                        global_step += 1
                        info = defaultdict(list)
            if config.train.ema:
                ema.step(transformer_trainable_parameters, global_step)

        epoch+=1
|
|
| if __name__ == "__main__": |
| train() |