from collections import defaultdict
import contextlib
import os
import datetime
from concurrent import futures
import time
import json
import hashlib
from accelerate import Accelerator
from accelerate.utils import set_seed, ProjectConfiguration
from accelerate.logging import get_logger
from diffusers.utils.torch_utils import is_compiled_module
import numpy as np
# import flow_grpo.prompts
# import flow_grpo.rewards
from flow_grpo.stat_tracking import PerPromptStatTracker
from flow_grpo.fluxaudio_pipeline_with_logprob import pipeline_with_logprob
from flow_grpo.fluxaudio_sde_with_logprob import sde_step_with_logprob
import torch
import wandb
import tempfile
from PIL import Image
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict, PeftModel
import random
from torch.utils.data import Dataset, DataLoader, Sampler
from flow_grpo.ema import EMAModuleWrapper
from flow_grpo.rewards import multi_score
import soundfile as sf
from resonate.data.online_audio import format_variant1, format_variant2, format_variant3
import hydra
import torch.distributed as distributed
from hydra import compose
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, open_dict
from torch.distributed.elastic.multiprocessing.errors import record
from resonate.model.sequence_config import CONFIG_16K, CONFIG_44K
from resonate.model.utils.features_utils import FeaturesUtils
from resonate.model.flow_matching import FlowMatching
from resonate.model.networks import get_model
logger = get_logger(__name__)
def pbar(iterable, desc=None, position=0, leave=True, disable=True, **kwargs):
if disable:
return iterable
else:
from tqdm.auto import tqdm
return tqdm(iterable, desc=desc, position=position, leave=leave, dynamic_ncols=True, **kwargs)
class AudioPromptDataset(Dataset):
def __init__(self, dataset, split='train'):
self.file_path = os.path.join(dataset, f'{split}_metadata.jsonl')
with open(self.file_path, 'r', encoding='utf-8') as f:
self.metadatas = [json.loads(line) for line in f]
self.prompts = [item['prompt'] for item in self.metadatas]
def __len__(self):
return len(self.prompts)
def __getitem__(self, idx):
return {"prompt": self.prompts[idx], "metadata": self.metadatas[idx]}
@staticmethod
def collate_fn(examples):
prompts = [example["prompt"] for example in examples]
metadatas = [example["metadata"] for example in examples]
return prompts, metadatas
class AudioTemporalDataset(Dataset):
def __init__(self, dataset, split='train'):
self.file_path = os.path.join(dataset, f'{split}_metadata.jsonl')
with open(self.file_path, 'r', encoding='utf-8') as f:
self.data = [json.loads(line) for line in f]
self.metadatas = [item['phrases'] for item in self.data]
self.format_fn = [format_variant1]
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
phrases = self.metadatas[idx]
format_fn = random.choice(self.format_fn)
prompt = format_fn(phrases)
return {"prompt": prompt, "metadata": phrases}
@staticmethod
def collate_fn(examples):
prompts = [example["prompt"] for example in examples]
metadatas = [example["metadata"] for example in examples]
return prompts, metadatas
class DistributedKRepeatSampler(Sampler):
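    """Infinite sampler that, each iteration, draws m unique samples and repeats each
    k times across all replicas from a seed-synchronized pool; GRPO uses these k
    repeats per prompt to form groups for per-prompt reward baselines."""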
def __init__(self, dataset, batch_size, k, num_replicas, rank, seed=0):
self.dataset = dataset
self.batch_size = batch_size # Batch size per replica
self.k = k # Number of repetitions per sample
self.num_replicas = num_replicas # Total number of replicas
self.rank = rank # Current replica rank
self.seed = seed # Random seed for synchronization
# Compute the number of unique samples needed per iteration
self.total_samples = self.num_replicas * self.batch_size
        assert self.total_samples % self.k == 0, f"num_replicas * batch_size ({self.total_samples}) must be divisible by k ({k})"
self.m = self.total_samples // self.k # Number of unique samples
self.epoch = 0
def __iter__(self):
while True:
# Generate a deterministic random sequence to ensure all replicas are synchronized
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
# Randomly select m unique samples
indices = torch.randperm(len(self.dataset), generator=g)[:self.m].tolist()
# Repeat each sample k times to generate n*b total samples
repeated_indices = [idx for idx in indices for _ in range(self.k)]
# Shuffle to ensure uniform distribution
shuffled_indices = torch.randperm(len(repeated_indices), generator=g).tolist()
shuffled_samples = [repeated_indices[i] for i in shuffled_indices]
# Split samples to each replica
per_card_samples = []
for i in range(self.num_replicas):
start = i * self.batch_size
end = start + self.batch_size
per_card_samples.append(shuffled_samples[start:end])
# Return current replica's sample indices
yield per_card_samples[self.rank]
def set_epoch(self, epoch):
self.epoch = epoch # Used to synchronize random state across epochs
def calculate_zero_std_ratio(prompts, gathered_rewards):
"""
Calculate the proportion of unique prompts whose reward standard deviation is zero.
Args:
prompts: List of prompts.
gathered_rewards: Dictionary containing rewards, must include the key 'ori_avg'.
    Returns:
        zero_std_ratio: Proportion of unique prompts whose rewards have zero standard deviation.
        mean_std: Mean of the per-prompt reward standard deviations.
"""
# Convert prompt list to NumPy array
prompt_array = np.array(prompts)
# Get unique prompts and their group information
unique_prompts, inverse_indices, counts = np.unique(
prompt_array,
return_inverse=True,
return_counts=True
)
# Group rewards for each prompt
grouped_rewards = gathered_rewards['ori_avg'][np.argsort(inverse_indices)]
split_indices = np.cumsum(counts)[:-1]
reward_groups = np.split(grouped_rewards, split_indices)
# Calculate standard deviation for each group
prompt_std_devs = np.array([np.std(group) for group in reward_groups])
# Calculate the ratio of zero standard deviation
zero_std_count = np.count_nonzero(prompt_std_devs == 0)
zero_std_ratio = zero_std_count / len(prompt_std_devs)
return zero_std_ratio, prompt_std_devs.mean()
def create_generator(prompts, base_seed):
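    # One torch.Generator per prompt, seeded from a stable hash of the prompt text,
    # so identical prompts start from identical noise when config.sample.same_latent
    # is set; within-group reward differences are then driven by the sampling path
    # rather than the initial latent.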
generators = []
for prompt in prompts:
# Use a stable hash (SHA256), then convert it to an integer seed
hash_digest = hashlib.sha256(prompt.encode()).digest()
prompt_hash_int = int.from_bytes(hash_digest[:4], 'big') # Take the first 4 bytes as part of the seed
seed = (base_seed + prompt_hash_int) % (2**31) # Ensure the number is within a valid range
gen = torch.Generator().manual_seed(seed)
generators.append(gen)
return generators
def compute_log_prob(transformer, timesteps, sample, j, embeds, pooled_embeds, config):
if config.train.cfg:
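        # Classifier-free guidance: double the batch to evaluate the conditional and
        # unconditional branches in a single forward pass, then recombine below.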
noise_pred = transformer(
latent=torch.cat([sample["latents"][:, j]] * 2),
text_f=embeds,
text_f_c=pooled_embeds,
t=torch.cat([sample["timesteps"][:, j]] * 2),
)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = (
noise_pred_uncond
+ config.sample.guidance_scale
* (noise_pred_text - noise_pred_uncond)
)
    else:
        noise_pred = transformer(
            latent=sample["latents"][:, j],
            text_f=embeds,
            text_f_c=pooled_embeds,
            t=sample["timesteps"][:, j],
        )
# compute the log prob of next_latents given latents under the current model
prev_sample, log_prob, prev_sample_mean, std_dev_t = sde_step_with_logprob(
timesteps[0],
noise_pred.float(),
sample["timesteps"][:, j],
sample["latents"][:, j].float(),
noise_level=config.sample.noise_level,
prev_sample=sample["next_latents"][:, j].float(),
)
return prev_sample, log_prob, prev_sample_mean, std_dev_t
def evaluate(transformer, features, test_dataloader, config, accelerator, global_step, reward_fn, executor, autocast, ema, transformer_trainable_parameters, vae_sr):
    # Evaluation mirrors the training-time sampling path: sample latents with the
    # current policy, decode to mel, vocode to waveform, then score with the reward fn.
    if config.train.ema:
        ema.copy_ema_to(transformer_trainable_parameters, store_temp=True)
    all_rewards = defaultdict(list)
    last_batch_audios, last_batch_prompts, last_batch_rewards = None, None, None
    for test_batch in pbar(
        test_dataloader,
        desc="Eval: ",
        disable=not accelerator.is_main_process,
        position=0,
    ):
        prompts, prompt_metadata = test_batch
        with torch.no_grad():
            prompt_embeds, pooled_prompt_embeds = features.encode_text(prompts)
        with autocast():
            with torch.no_grad():
                latents, _, _ = pipeline_with_logprob(
                    transformer,
                    prompt_embeds=prompt_embeds,
                    pooled_prompt_embeds=pooled_prompt_embeds,
                    num_inference_steps=config.sample.eval_num_steps,
                    guidance_scale=config.sample.guidance_scale,
                    noise_level=0,
                )
                mel = features.decode(latents[-1])
                audios = features.vocode(mel).squeeze(1)
        rewards = executor.submit(reward_fn, audios, prompts, prompt_metadata, vae_sr=vae_sr, only_strict=False)
        # yield to make sure reward computation starts
        time.sleep(0)
        rewards, reward_metadata = rewards.result()
        for key, value in rewards.items():
            rewards_gather = accelerator.gather(torch.as_tensor(value, device=accelerator.device)).cpu().numpy()
            all_rewards[key].append(rewards_gather)
        last_batch_audios = audios.detach().float().cpu().numpy()
        last_batch_prompts = prompts
        last_batch_rewards = rewards
    all_rewards = {key: np.concatenate(value) for key, value in all_rewards.items()}
    if accelerator.is_main_process and config.use_wandb:
        # Log a few samples from the main process's last eval batch plus mean rewards;
        # the -10 sentinel marks invalid scores and is excluded from captions and means.
        num_samples = min(15, len(last_batch_audios))
        wandb.log(
            {
                "eval_audios": [
                    wandb.Audio(
                        last_batch_audios[idx],
                        sample_rate=vae_sr,
                        caption=f"{last_batch_prompts[idx]:.1000} | " + " | ".join(
                            f"{k}: {last_batch_rewards[k][idx]:.2f}"
                            for k in last_batch_rewards
                            if last_batch_rewards[k][idx] != -10
                        ),
                    )
                    for idx in range(num_samples)
                ],
                **{f"eval_reward_{key}": np.mean(value[value != -10]) for key, value in all_rewards.items()},
            },
            step=global_step,
        )
    if config.train.ema:
        ema.copy_temp_to(transformer_trainable_parameters)
def unwrap_model(model, accelerator):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
def save_ckpt(save_dir, transformer, global_step, accelerator, ema, transformer_trainable_parameters, config):
# save_root = os.path.join(save_dir, "checkpoints", f"checkpoint-{global_step}")
# save_root_lora = os.path.join(save_root, "lora")
# os.makedirs(save_root_lora, exist_ok=True)
# if accelerator.is_main_process:
# if config.train.ema:
# ema.copy_ema_to(transformer_trainable_parameters, store_temp=True)
# unwrap_model(transformer, accelerator).save_pretrained(save_root_lora)
# if config.train.ema:
# ema.copy_temp_to(transformer_trainable_parameters)
    model_path = os.path.join(save_dir, f'model_{global_step}.pth')
    torch.save(unwrap_model(transformer, accelerator).state_dict(), model_path)
logger.info(f'Network weights saved to {model_path}.')
@record
@hydra.main(version_base='1.3.2', config_path='config', config_name='train_config.yaml')
def train(cfg: DictConfig):
if cfg.get("debug", False):
import debugpy
if "RANK" not in os.environ or int(os.environ["RANK"]) == 0:
debugpy.listen(6665)
print(f'Waiting for debugger attach (rank {os.environ["RANK"]})...')
debugpy.wait_for_client()
#### 1. Basic setup
config = cfg
unique_id = datetime.datetime.now().strftime("%Y.%m.%d_%H.%M.%S")
# if not config.exp_id:
# config.exp_id = unique_id
# else:
# config.exp_id += "_" + unique_id
num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction) # number of timesteps within each trajectory to train on
save_dir = os.path.join(HydraConfig.get().run.dir)
accelerator_config = ProjectConfiguration(
project_dir=save_dir,
automatic_checkpoint_naming=True,
total_limit=config.num_checkpoint_limit,
)
accelerator = Accelerator(
# log_with="wandb",
mixed_precision=config.mixed_precision,
project_config=accelerator_config,
# we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
# number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
# the total number of optimizer steps to accumulate across.
gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps,
)
if accelerator.is_main_process and config.use_wandb:
wandb.init(
project="flow_grpo",
name = cfg.exp_id,
)
# accelerator.init_trackers(
# project_name="flow-grpo",
# config=config.to_dict(),
# init_kwargs={"wandb": {"name": config.run_name}},
# )
logger.info(f"\n{config}")
set_seed(config.seed, device_specific=True) # set seed (device_specific is very important to get different prompts on different devices)
##### 2. load scheduler, tokenizer and models.
    if cfg.audio_sample_rate == 16000:
        mode = '16k'
        seq_cfg = CONFIG_16K # for 10s audio
        logger.info('Using 16k mode for sequence config')
    elif cfg.audio_sample_rate == 44100:
        mode = '44k'
        seq_cfg = CONFIG_44K # for 10s audio
        logger.info('Using 44k mode for sequence config')
    else:
        raise ValueError(f'Invalid audio_sample_rate: {cfg.audio_sample_rate}')
    sample_rate = seq_cfg.sampling_rate
    duration_sec = seq_cfg.duration
inference_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
inference_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
inference_dtype = torch.bfloat16
text_encoder_name = cfg['text_encoder_name']
need_vae_encoder = cfg.get('online_feature_extraction', False)
if mode == '16k':
features = FeaturesUtils(
tod_vae_ckpt=cfg['vae_16k_ckpt'],
bigvgan_vocoder_ckpt=cfg['bigvgan_vocoder_ckpt'],
encoder_name=text_encoder_name,
enable_conditions=True,
mode=mode,
need_vae_encoder=need_vae_encoder,
)
        vae_sr = sample_rate  # 16000
elif mode == '44k':
features = FeaturesUtils(
tod_vae_ckpt=cfg['vae_44k_ckpt'],
encoder_name=text_encoder_name,
enable_conditions=True,
mode=mode,
need_vae_encoder=need_vae_encoder,
)
        vae_sr = sample_rate  # 44100
features = features.eval()
features.to(accelerator.device, dtype=torch.float32)
with torch.no_grad():
neg_prompt_embed, neg_pooled_prompt_embed = features.encode_text([''])
neg_prompt_embed = neg_prompt_embed[0]
if neg_pooled_prompt_embed is not None:
neg_pooled_prompt_embed = neg_pooled_prompt_embed[0]
else:
neg_pooled_prompt_embed = neg_prompt_embed.mean(dim=0)
train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1)
train_neg_pooled_prompt_embeds = neg_pooled_prompt_embed.repeat(config.train.batch_size, 1)
if cfg.compile:
features.compile()
logger.info(f'Computed empty string feature for {text_encoder_name}')
latent_mean, latent_std = torch.load(cfg.latent_mean), torch.load(cfg.latent_std)
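    # Per-channel latent statistics used by the model to normalize/denormalize VAE
    # latents; the empty-string features act as the unconditional (CFG) embedding.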
transformer = get_model(cfg.model,
text_dim=cfg.text_dim,
text_c_dim=cfg.text_c_dim,
latent_mean=latent_mean,
latent_std=latent_std,
empty_string_feat=neg_prompt_embed,
empty_string_feat_c=neg_pooled_prompt_embed,
use_rope=cfg.use_rope)
transformer.load_weights(torch.load(cfg.weight, map_location=accelerator.device, weights_only=True))
transformer.to(accelerator.device, dtype=inference_dtype)
logger.info(f'Loaded weights from {cfg.weight}')
fm = FlowMatching(cfg.sampling.min_sigma,
inference_mode=cfg.sampling.method,
num_steps=cfg.sampling.num_steps)
if config.use_lora:
# Set correct lora layers
target_modules = [
"attn.add_k_proj",
"attn.add_q_proj",
"attn.add_v_proj",
"attn.to_add_out",
"attn.to_k",
"attn.to_out.0",
"attn.to_q",
"attn.to_v",
]
transformer_lora_config = LoraConfig(
r=32,
lora_alpha=64,
init_lora_weights="gaussian",
target_modules=target_modules,
)
if config.train.lora_path:
transformer = PeftModel.from_pretrained(transformer, config.train.lora_path)
            # After PeftModel.from_pretrained, all parameters have requires_grad=False;
            # call set_adapter to re-enable gradients on the adapter parameters.
transformer.set_adapter("default")
else:
transformer = get_peft_model(transformer, transformer_lora_config)
else:
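        # Without LoRA adapters to disable, keep a frozen copy of the initial policy
        # as the reference model for the KL penalty (config.train.beta > 0).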
import copy
ref_transformer = copy.deepcopy(transformer)
for p in ref_transformer.parameters():
p.requires_grad = False
ref_transformer.eval()
transformer_trainable_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
# This ema setting affects the previous 20 × 8 = 160 steps on average.
ema = EMAModuleWrapper(transformer_trainable_parameters, decay=0.9, update_step_interval=8, device=accelerator.device)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if config.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
# Initialize the optimizer
if config.train.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
)
optimizer_cls = bnb.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW
optimizer = optimizer_cls(
transformer_trainable_parameters,
lr=config.train.learning_rate,
betas=(config.train.adam_beta1, config.train.adam_beta2),
weight_decay=config.train.adam_weight_decay,
eps=config.train.adam_epsilon,
)
# prepare prompt and reward fn
if config.reward_fn == "qwen3_omni_thinking_semantic_align_score":
reward_fn = multi_score("auto", config.reward_fn)
else:
reward_fn = multi_score(accelerator.device, config.reward_fn)
eval_reward_fn = multi_score(accelerator.device, config.reward_fn)
    if config.prompt_fn == "audioprompt":
        dataset_cls = AudioPromptDataset
    elif config.prompt_fn == "audio_temporal_prompt":
        dataset_cls = AudioTemporalDataset
    else:
        raise NotImplementedError(f"Unrecognized prompt_fn: {config.prompt_fn}")
    train_dataset = dataset_cls(config.dataset, 'train')
    test_dataset = dataset_cls(config.dataset, 'test')
    train_sampler = DistributedKRepeatSampler(
        dataset=train_dataset,
        batch_size=config.sample.train_batch_size,
        k=config.sample.num_audio_per_prompt,
        num_replicas=accelerator.num_processes,
        rank=accelerator.process_index,
        seed=42,
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=1,
        collate_fn=dataset_cls.collate_fn,
        # persistent_workers=True
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=config.sample.test_batch_size,
        collate_fn=dataset_cls.collate_fn,
        shuffle=False,
        num_workers=8,
    )
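    # GRPO needs at least two samples per prompt for a meaningful within-group std,
    # so per-prompt stat tracking is disabled when each prompt is sampled once.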
if config.sample.num_audio_per_prompt == 1:
config.per_prompt_stat_tracking = False
# initialize stat tracker
if config.per_prompt_stat_tracking:
stat_tracker = PerPromptStatTracker(config.sample.global_std)
    # autocast is needed for non-LoRA training; for LoRA training it is unnecessary
    # and only increases memory use, so fall back to a null context.
autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
# autocast = accelerator.autocast
# Prepare everything with our `accelerator`.
transformer, optimizer, train_dataloader, test_dataloader = accelerator.prepare(transformer, optimizer, train_dataloader, test_dataloader)
    # executor to run reward computation asynchronously; this is beneficial for reward
    # models served remotely, since sampling can continue while scores arrive.
executor = futures.ThreadPoolExecutor(max_workers=8)
# Train!
samples_per_epoch = (
config.sample.train_batch_size
* accelerator.num_processes
* config.sample.num_batches_per_epoch
)
total_train_batch_size = (
config.train.batch_size
* accelerator.num_processes
* config.train.gradient_accumulation_steps
)
logger.info("***** Running training *****")
logger.info(f" Sample batch size per device = {config.sample.train_batch_size}")
logger.info(f" Train batch size per device = {config.train.batch_size}")
logger.info(
f" Gradient Accumulation steps = {config.train.gradient_accumulation_steps}"
)
logger.info("")
logger.info(f" Total number of samples per epoch = {samples_per_epoch}")
logger.info(
f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}"
)
logger.info(
f" Number of gradient updates per inner epoch = {samples_per_epoch // total_train_batch_size}"
)
logger.info(f" Number of inner epochs = {config.train.num_inner_epochs}")
# assert config.sample.train_batch_size >= config.train.batch_size
# assert config.sample.train_batch_size % config.train.batch_size == 0
# assert samples_per_epoch % total_train_batch_size == 0
epoch = 0
global_step = 0
train_iter = iter(train_dataloader)
while epoch < config.total_epoch:
#################### EVAL ####################
transformer.eval()
        if epoch % config.eval_freq == 0 and config.do_eval:
            evaluate(transformer, features, test_dataloader, config, accelerator, global_step, eval_reward_fn, executor, autocast, ema, transformer_trainable_parameters, vae_sr)
if epoch % config.save_freq == 0 and epoch > 0 and accelerator.is_main_process:
save_ckpt(save_dir, transformer, global_step, accelerator, ema, transformer_trainable_parameters, config)
#################### SAMPLING ####################
transformer.eval()
samples = []
prompts = []
for i in pbar(
range(config.sample.num_batches_per_epoch),
desc=f"Epoch {epoch}: sampling",
disable=not accelerator.is_main_process,
position=0,
):
train_sampler.set_epoch(epoch * config.sample.num_batches_per_epoch + i)
prompts, prompt_metadata = next(train_iter)
if epoch < 3 and accelerator.is_main_process:
logger.info(f"[DEBUG] Sampled prompts: {prompts}")
prompt_embeds, pooled_prompt_embeds = features.encode_text(prompts)
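            # Token ids are stored only so prompts can be gathered across processes
            # and decoded later for per-prompt advantage grouping.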
prompt_ids = features.tokenizer(
prompts,
padding="max_length",
max_length=77,
truncation=True,
return_tensors="pt",
).input_ids.to(accelerator.device)
# sample
if config.sample.same_latent:
generator = create_generator(prompts, base_seed=epoch*10000+i)
else:
generator = None
with autocast():
with torch.no_grad():
latents, log_probs, timesteps = pipeline_with_logprob(
transformer,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
num_inference_steps=config.sample.num_steps,
guidance_scale=config.sample.guidance_scale,
noise_level=config.sample.noise_level,
generator=generator
)
last_latent = latents[-1]
mel = features.decode(last_latent)
audios = features.vocode(mel)
audios = audios.squeeze(1)
            latents = torch.stack(
                latents, dim=1
            )  # (batch_size, num_steps + 1, *latent_shape)
log_probs = torch.stack(log_probs, dim=1) # shape after stack (batch_size, num_steps)
            timesteps = timesteps.repeat(config.sample.train_batch_size, 1)  # (batch_size, num_steps + 1)
# compute rewards asynchronously
rewards = executor.submit(reward_fn, audios, prompts, prompt_metadata, vae_sr=vae_sr, only_strict=True)
            # yield to make sure reward computation starts
time.sleep(0)
samples.append(
{
"prompt_ids": prompt_ids,
"prompt_embeds": prompt_embeds,
"pooled_prompt_embeds": pooled_prompt_embeds,
"timesteps": timesteps,
"latents": latents[
:, :-1
], # each entry is the latent before timestep t
"next_latents": latents[
:, 1:
], # each entry is the latent after timestep t
"log_probs": log_probs,
"rewards": rewards,
}
)
# wait for all rewards to be computed
for sample in pbar(
samples,
desc="Waiting for rewards",
disable=not accelerator.is_main_process,
position=0,
):
rewards, reward_metadata = sample["rewards"].result()
# accelerator.print(reward_metadata)
sample["rewards"] = {
key: torch.as_tensor(value, device=accelerator.device).float()
for key, value in rewards.items()
}
# collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
samples = {
k: torch.cat([s[k] for s in samples], dim=0)
if not isinstance(samples[0][k], dict)
else {
sub_key: torch.cat([s[k][sub_key] for s in samples], dim=0)
for sub_key in samples[0][k]
}
for k in samples[0].keys()
}
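        # Keep the raw per-sample reward ("ori_avg") for logging and zero-std
        # diagnostics before broadcasting the training reward across timesteps.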
samples["rewards"]["ori_avg"] = samples["rewards"]["avg"]
# The purpose of repeating `adv` along the timestep dimension here is to make it easier to introduce timestep-dependent advantages later, such as adding a KL reward.
samples["rewards"]["avg"] = samples["rewards"]["avg"].unsqueeze(1).repeat(1, num_train_timesteps)
# gather rewards across processes
gathered_rewards = {key: accelerator.gather(value) for key, value in samples["rewards"].items()}
gathered_rewards = {key: value.cpu().numpy() for key, value in gathered_rewards.items()}
if epoch < 3 and accelerator.is_main_process:
rollout_root = os.path.join(save_dir, "flowgrpo_rollout")
os.makedirs(rollout_root, exist_ok=True)
epoch_dir = os.path.join(rollout_root, f"epoch_{epoch:04d}")
audio_dir = os.path.join(epoch_dir, "audio")
os.makedirs(audio_dir, exist_ok=True)
jsonl_path = os.path.join(epoch_dir, "metadata.jsonl")
with open(jsonl_path, "a", encoding="utf-8") as f:
for b in range(audios.shape[0]):
audio_path = os.path.join(
audio_dir,
f"epoch{epoch:04d}_batch{i:03d}_sample{b:02d}.wav"
)
sf.write(
audio_path,
audios[b].detach().cpu().numpy(),
samplerate=vae_sr,
)
record = {
"epoch": epoch,
"batch_idx": i,
"sample_idx": b,
"audio_path": audio_path,
"prompt": prompts[b]
}
                    for reward_key in config.reward_fn.keys():
                        # `audios` and `prompts` hold only the last sampling batch, so
                        # offset into the epoch-level reward tensor accordingly.
                        record[reward_key] = float(samples['rewards'][reward_key][i * config.sample.train_batch_size + b].item())
f.write(json.dumps(record, ensure_ascii=False) + "\n")
logger.info(f"[DEBUG] Samples saved for epoch {epoch}")
# log rewards and images
if accelerator.is_main_process and config.use_wandb:
wandb.log(
{
"epoch": epoch,
**{f"reward_{key}": value.mean() for key, value in gathered_rewards.items() if '_strict_accuracy' not in key and '_accuracy' not in key},
},
step=global_step,
)
        if global_step % config.log_text_interval == 0 and accelerator.is_main_process:
            for reward_key in config.reward_fn.keys():
                logger.info(f"Global step: {global_step}, mean {reward_key} reward: {gathered_rewards[reward_key].mean()}")
# per-prompt mean/std tracking
if config.per_prompt_stat_tracking:
# gather the prompts across processes
prompt_ids = accelerator.gather(samples["prompt_ids"]).cpu().numpy()
prompts = features.tokenizer.batch_decode(
prompt_ids, skip_special_tokens=True
)
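            # Advantage = (reward - per-prompt mean) / per-prompt std (or the global
            # std if config.sample.global_std is set), computed by the stat tracker.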
advantages = stat_tracker.update(prompts, gathered_rewards['avg'])
if accelerator.is_main_process:
print("len(prompts)", len(prompts))
print("len unique prompts", len(set(prompts)))
group_size, trained_prompt_num = stat_tracker.get_stats()
zero_std_ratio, reward_std_mean = calculate_zero_std_ratio(prompts, gathered_rewards)
if accelerator.is_main_process and config.use_wandb:
wandb.log(
{
"group_size": group_size,
"trained_prompt_num": trained_prompt_num,
"zero_std_ratio": zero_std_ratio,
"reward_std_mean": reward_std_mean,
},
step=global_step,
)
stat_tracker.clear()
else:
advantages = (gathered_rewards['avg'] - gathered_rewards['avg'].mean()) / (gathered_rewards['avg'].std() + 1e-4)
# ungather advantages; we only need to keep the entries corresponding to the samples on this process
advantages = torch.as_tensor(advantages)
samples["advantages"] = (
advantages.reshape(accelerator.num_processes, -1, advantages.shape[-1])[accelerator.process_index]
.to(accelerator.device)
)
if accelerator.is_main_process:
print("advantages: ", samples["advantages"].abs().mean())
del samples["rewards"]
del samples["prompt_ids"]
# Get the mask for samples where all advantages are zero across the time dimension
mask = (samples["advantages"].abs().sum(dim=1) != 0)
# If the number of True values in mask is not divisible by config.sample.num_batches_per_epoch,
# randomly change some False values to True to make it divisible
num_batches = config.sample.num_batches_per_epoch
true_count = mask.sum()
if true_count % num_batches != 0:
false_indices = torch.where(~mask)[0]
num_to_change = num_batches - (true_count % num_batches)
if len(false_indices) >= num_to_change:
random_indices = torch.randperm(len(false_indices))[:num_to_change]
mask[false_indices[random_indices]] = True
if accelerator.is_main_process and config.use_wandb:
wandb.log(
{
"actual_batch_size": mask.sum().item()//config.sample.num_batches_per_epoch,
},
step=global_step,
)
# Filter out samples where the entire time dimension of advantages is zero
samples = {k: v[mask] for k, v in samples.items()}
total_batch_size, num_timesteps = samples["timesteps"].shape
# assert (
# total_batch_size
# == config.sample.train_batch_size * config.sample.num_batches_per_epoch
# )
assert num_timesteps == config.sample.num_steps+1
#################### TRAINING ####################
for inner_epoch in range(config.train.num_inner_epochs):
# shuffle samples along batch dimension
perm = torch.randperm(total_batch_size, device=accelerator.device)
samples = {k: v[perm] for k, v in samples.items()}
# rebatch for training
samples_batched = {
k: v.reshape(-1, total_batch_size//config.sample.num_batches_per_epoch, *v.shape[1:])
for k, v in samples.items()
}
# dict of lists -> list of dicts for easier iteration
samples_batched = [
dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())
]
# train
transformer.train()
info = defaultdict(list)
for i, sample in pbar(
list(enumerate(samples_batched)),
desc=f"Epoch {epoch}.{inner_epoch}: training",
position=0,
disable=not accelerator.is_main_process,
):
if config.train.cfg:
# concat negative prompts to sample prompts to avoid two forward passes
embeds = torch.cat(
[train_neg_prompt_embeds[:len(sample["prompt_embeds"])], sample["prompt_embeds"]]
)
pooled_embeds = torch.cat(
[train_neg_pooled_prompt_embeds[:len(sample["pooled_prompt_embeds"])], sample["pooled_prompt_embeds"]]
)
else:
embeds = sample["prompt_embeds"]
pooled_embeds = sample["pooled_prompt_embeds"]
train_timesteps = [step_index for step_index in range(num_train_timesteps)]
# for j in tqdm(
# train_timesteps,
# desc="Timestep",
# position=1,
# leave=False,
# disable=not accelerator.is_main_process,
# ):
for j in train_timesteps:
with accelerator.accumulate(transformer):
with autocast():
prev_sample, log_prob, prev_sample_mean, std_dev_t = compute_log_prob(transformer, timesteps, sample, j, embeds, pooled_embeds, config)
if config.train.beta > 0:
with torch.no_grad():
if config.use_lora:
with transformer.module.disable_adapter():
_, _, prev_sample_mean_ref, _ = compute_log_prob(transformer, timesteps, sample, j, embeds, pooled_embeds, config)
else:
_, _, prev_sample_mean_ref, _ = compute_log_prob(ref_transformer, timesteps, sample, j, embeds, pooled_embeds, config)
# grpo logic
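                    # PPO-style clipped surrogate: importance ratio between the current
                    # policy and the behavior policy that generated the rollout, clipped
                    # to [1 - clip_range, 1 + clip_range]; take the pessimistic maximum.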
advantages = torch.clamp(
sample["advantages"][:, j],
-config.train.adv_clip_max,
config.train.adv_clip_max,
)
ratio = torch.exp(log_prob - sample["log_probs"][:, j])
unclipped_loss = -advantages * ratio
clipped_loss = -advantages * torch.clamp(
ratio,
1.0 - config.train.clip_range,
1.0 + config.train.clip_range,
)
policy_loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss))
if config.train.beta > 0:
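                        # KL between two Gaussians with shared std reduces to the squared
                        # difference of the SDE step means, scaled by 1 / (2 * std^2).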
kl_loss = ((prev_sample_mean - prev_sample_mean_ref) ** 2).mean(dim=(1,2), keepdim=True) / (2 * std_dev_t ** 2)
kl_loss = torch.mean(kl_loss)
loss = policy_loss + config.train.beta * kl_loss
else:
loss = policy_loss
info["approx_kl"].append(
0.5
* torch.mean((log_prob - sample["log_probs"][:, j]) ** 2)
)
info["clipfrac"].append(
torch.mean(
(
torch.abs(ratio - 1.0) > config.train.clip_range
).float()
)
)
info["clipfrac_gt_one"].append(
torch.mean(
(
ratio - 1.0 > config.train.clip_range
).float()
)
)
info["clipfrac_lt_one"].append(
torch.mean(
(
1.0 - ratio > config.train.clip_range
).float()
)
)
info["policy_loss"].append(policy_loss)
if config.train.beta > 0:
info["kl_loss"].append(kl_loss)
info["loss"].append(loss)
# backward pass
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(
transformer.parameters(), config.train.max_grad_norm
)
optimizer.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
# assert (j == train_timesteps[-1]) and (
# i + 1
# ) % config.train.gradient_accumulation_steps == 0
# log training-related stuff
info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
info = accelerator.reduce(info, reduction="mean")
info.update({"epoch": epoch, "inner_epoch": inner_epoch})
if accelerator.is_main_process and config.use_wandb:
wandb.log(info, step=global_step)
global_step += 1
info = defaultdict(list)
if config.train.ema:
ema.step(transformer_trainable_parameters, global_step)
# make sure we did an optimization step at the end of the inner epoch
# assert accelerator.sync_gradients
        epoch += 1
if __name__ == "__main__":
train()