| import os |
| import datetime |
| import time |
| import json |
| import hashlib |
| import contextlib |
| import random |
| from collections import defaultdict |
| from concurrent import futures |
| from functools import partial |
| from resonate.model.flow_matching import FlowMatching |
|
|
| import torch |
| import torch.nn.functional as F |
| import torch.distributed as distributed |
| from torch.utils.data import Dataset, DataLoader, Sampler |
| import numpy as np |
| import wandb |
| import soundfile as sf |
| import hydra |
| from hydra.core.hydra_config import HydraConfig |
| from omegaconf import DictConfig |
| from accelerate import Accelerator |
| from accelerate.utils import set_seed, ProjectConfiguration |
| from accelerate.logging import get_logger |
| from peft import LoraConfig, get_peft_model, PeftModel |
|
|
| |
| from torch.distributed.elastic.multiprocessing.errors import record |
| from resonate.model.sequence_config import CONFIG_16K, CONFIG_44K |
| from resonate.model.utils.features_utils import FeaturesUtils |
| from resonate.model.networks import get_model |
|
|
| |
| from flow_grpo.stat_tracking import PerPromptStatTracker |
| from flow_grpo.fluxaudio_pipeline_with_logprob import pipeline_with_logprob |
| from flow_grpo.ema import EMAModuleWrapper |
| from flow_grpo.rewards import multi_score |
|
|
def pbar(iterable, desc=None, position=0, leave=True, disable=True, **kwargs):
    """Optionally wrap ``iterable`` in a tqdm progress bar.

    Args:
        iterable: Anything that can be iterated over.
        desc: Label shown in front of the bar.
        position: Row offset of the bar.
        leave: Whether the bar stays on screen once finished.
        disable: When true (the default) no bar is created and ``iterable``
            is handed back untouched; tqdm is not even imported.
        **kwargs: Extra keyword arguments forwarded to ``tqdm``.

    Returns:
        ``iterable`` itself when disabled, otherwise a ``tqdm.auto.tqdm``
        wrapper around it created with ``dynamic_ncols=True``.
    """
    if not disable:
        # Imported lazily so disabled callers never need tqdm.
        from tqdm.auto import tqdm

        bar_options = dict(desc=desc, position=position, leave=leave, dynamic_ncols=True)
        return tqdm(iterable, **bar_options, **kwargs)
    return iterable
|
|
# Module-level logger created through accelerate.logging.get_logger.
logger = get_logger(__name__)
|
|
| |
|
|
class AudioPromptDataset(Dataset):
    """Prompt dataset read from ``<dataset>/<split>_metadata.jsonl``.

    Every non-blank line of the file is parsed as one JSON object that must
    contain a ``prompt`` key. The whole object is kept as the item metadata.
    """

    def __init__(self, dataset, split='train'):
        """Load all records of ``split`` from the directory ``dataset``.

        Raises:
            OSError: If the metadata file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            KeyError: If a record has no ``prompt`` key.
        """
        self.file_path = os.path.join(dataset, f'{split}_metadata.jsonl')
        with open(self.file_path, 'r', encoding='utf-8') as f:
            # Blank / whitespace-only lines (e.g. a stray trailing newline)
            # are skipped instead of crashing json.loads.
            self.metadatas = [json.loads(line) for line in f if line.strip()]
        self.prompts = [item['prompt'] for item in self.metadatas]

    def __len__(self):
        """Return the number of loaded prompts."""
        return len(self.prompts)

    def __getitem__(self, idx):
        """Return ``{"prompt": str, "metadata": dict}`` for record ``idx``."""
        return {"prompt": self.prompts[idx], "metadata": self.metadatas[idx]}

    @staticmethod
    def collate_fn(examples):
        """Turn a list of items into ``(prompts, metadatas)`` parallel lists."""
        prompts = [example["prompt"] for example in examples]
        metadatas = [example["metadata"] for example in examples]
        return prompts, metadatas
| |
class AudioTemporalDataset(Dataset):
    """Dataset whose prompts are rendered on the fly from ``phrases`` records.

    Reads ``<dataset>/<split>_metadata.jsonl``; every non-blank line is one
    JSON object that must contain a ``phrases`` key. The prompt text is
    produced at access time by a formatter picked from ``self.format_fn``.
    """

    def __init__(self, dataset, split='train'):
        """Load all records of ``split`` from the directory ``dataset``.

        Raises:
            OSError: If the metadata file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            KeyError: If a record has no ``phrases`` key.
        """
        self.file_path = os.path.join(dataset, f'{split}_metadata.jsonl')
        with open(self.file_path, 'r', encoding='utf-8') as f:
            # Blank / whitespace-only lines (e.g. a stray trailing newline)
            # are skipped instead of crashing json.loads.
            self.data = [json.loads(line) for line in f if line.strip()]
        self.metadatas = [item['phrases'] for item in self.data]
        # Imported here (not at module top), as in the original code.
        from resonate.data.online_audio import format_variant1
        self.format_fn = [format_variant1]

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{"prompt": ..., "metadata": phrases}`` for record ``idx``.

        The formatter is drawn with ``random.choice`` on every access.
        """
        phrases = self.metadatas[idx]
        format_fn = random.choice(self.format_fn)
        return {"prompt": format_fn(phrases), "metadata": phrases}

    @staticmethod
    def collate_fn(examples):
        """Turn a list of items into ``(prompts, metadatas)`` parallel lists."""
        prompts = [example["prompt"] for example in examples]
        metadatas = [example["metadata"] for example in examples]
        return prompts, metadatas
|
|
class DistributedKRepeatSampler(Sampler):
    """Endless batch sampler that repeats each drawn index ``k`` times.

    Per iteration, ``m = num_replicas * batch_size // k`` dataset indices are
    drawn with a generator seeded by ``seed + epoch``, each index is repeated
    ``k`` times back to back, and the resulting list is cut into consecutive
    ``batch_size`` slices, one per replica. Only the slice of ``rank`` is
    yielded. Because every replica seeds identically, they all see the same
    permutation as long as ``set_epoch`` is called with the same value.
    """

    def __init__(self, dataset, batch_size, k, num_replicas, rank, seed=0):
        """Store the configuration and validate divisibility.

        Args:
            dataset: Sized object; only ``len(dataset)`` is used.
            batch_size: Number of indices yielded per batch on each replica.
            k: Number of consecutive repetitions of every drawn index.
            num_replicas: Number of participating processes.
            rank: Index of this process in ``[0, num_replicas)``.
            seed: Base seed; the current epoch is added to it.

        Raises:
            ValueError: If ``num_replicas * batch_size`` is not a multiple
                of ``k``.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.k = k
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.total_samples = self.num_replicas * self.batch_size
        # Explicit raise instead of `assert`, which is stripped under -O.
        if self.total_samples % self.k != 0:
            raise ValueError(
                f"num_replicas * batch_size ({self.total_samples}) must be divisible by k ({self.k})"
            )
        self.m = self.total_samples // self.k
        self.epoch = 0

    def __iter__(self):
        """Yield this rank's list of indices forever (never raises StopIteration)."""
        start = self.rank * self.batch_size
        end = start + self.batch_size
        while True:
            # Re-seeded on every pass so a set_epoch() call between two
            # batches changes the next permutation.
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g)[:self.m].tolist()
            repeated_indices = [idx for idx in indices for _ in range(self.k)]
            yield repeated_indices[start:end]

    def set_epoch(self, epoch):
        """Set the value added to ``seed`` for the next drawn batch."""
        self.epoch = epoch
|
|
| |
|
|
def create_generator(prompts, base_seed):
    """Build one seeded ``torch.Generator`` per prompt.

    The seed of a prompt is ``(base_seed + h) % 2**31`` where ``h`` is the
    big-endian integer made of the first four bytes of the prompt's SHA-256
    digest, so equal prompts get equal seeds for a given ``base_seed``.

    Args:
        prompts: Iterable of prompt strings.
        base_seed: Integer offset added to every prompt hash.

    Returns:
        List of generators, in the order of ``prompts``.
    """
    def _seed_for(text):
        digest = hashlib.sha256(text.encode()).digest()
        return (base_seed + int.from_bytes(digest[:4], 'big')) % (2**31)

    return [torch.Generator().manual_seed(_seed_for(text)) for text in prompts]
|
|
def unwrap_model(model, accelerator):
    """Return whatever ``accelerator.unwrap_model`` gives back for ``model``."""
    return accelerator.unwrap_model(model)
|
|
def save_ckpt(save_dir, transformer, global_step, accelerator, ema, transformer_trainable_parameters, config):
    """Write the current weights to ``save_dir`` from the main process.

    A ``PeftModel`` is stored with ``save_pretrained`` under
    ``lora_<global_step>``; any other model has its ``state_dict`` written to
    ``model_<global_step>.pth``. When ``config.train.ema`` is set, the EMA
    weights are copied into the trainable parameters for the duration of the
    save and the live weights are put back afterwards, even if saving fails.

    Args:
        save_dir: Target directory (created on every process if missing).
        transformer: Possibly wrapped model to save.
        global_step: Step number used in the output name.
        accelerator: Provides ``is_main_process`` and model unwrapping.
        ema: EMA wrapper exposing ``copy_ema_to`` / ``copy_temp_to``.
        transformer_trainable_parameters: Parameters tracked by ``ema``.
        config: Run configuration; only ``config.train.ema`` is read.
    """
    os.makedirs(save_dir, exist_ok=True)

    if not accelerator.is_main_process:
        return

    if config.train.ema:
        ema.copy_ema_to(transformer_trainable_parameters, store_temp=True)
    try:
        unwrapped = unwrap_model(transformer, accelerator)
        if isinstance(unwrapped, PeftModel):
            target = os.path.join(save_dir, f"lora_{global_step}")
            unwrapped.save_pretrained(target)
        else:
            target = os.path.join(save_dir, f'model_{global_step}.pth')
            torch.save(unwrapped.state_dict(), target)
        # Report the path that was actually written (adapter dir or .pth).
        logger.info(f'Network weights saved to {target}.')
    finally:
        # Always restore the live weights, even when saving raised.
        if config.train.ema:
            ema.copy_temp_to(transformer_trainable_parameters)
|
|
| |
def eval(pipeline, test_dataloader, features, config, accelerator, global_step, reward_fn, executor, autocast, ema, transformer_trainable_parameters):
    """Score the model on the test prompts and log the mean of every reward.

    For each test batch the prompts are encoded, latents are sampled with
    ``pipeline_with_logprob`` (``noise_level=0``), decoded and vocoded by
    ``features``, and scored by ``reward_fn`` on ``executor``. Per-key
    rewards are gathered across processes; the main process logs
    ``eval_reward_<key>`` to wandb when ``config.use_wandb`` is set.

    When ``config.train.ema`` is set, the EMA weights are swapped in for the
    evaluation and the live weights are restored afterwards, even if the
    evaluation raises.

    NOTE: the name shadows the builtin ``eval``; it is kept because callers
    in this file use it.

    Args:
        pipeline: Model object handed to ``pipeline_with_logprob``.
        test_dataloader: Yields ``(prompts, prompt_metadata)`` batches.
        features: Provides ``encode_text``, ``decode`` and ``vocode``.
        config: Run configuration.
        accelerator: Used for device placement, gathering and rank checks.
        global_step: Step value attached to the wandb log entry.
        reward_fn: Callable returning ``(rewards_dict, extra)``.
        executor: Executor on which ``reward_fn`` is submitted.
        autocast: Zero-argument context-manager factory.
        ema: EMA wrapper exposing ``copy_ema_to`` / ``copy_temp_to``.
        transformer_trainable_parameters: Parameters tracked by ``ema``.
    """
    if config.train.ema:
        ema.copy_ema_to(transformer_trainable_parameters, store_temp=True)

    try:
        # Loop-invariant: sample rate handed to the reward function.
        vae_sr = 16000 if config.audio_sample_rate == 16000 else 44100
        all_rewards = defaultdict(list)

        for test_batch in pbar(test_dataloader, desc="Eval: ", disable=not accelerator.is_main_process):
            prompts, prompt_metadata = test_batch
            prompt_embeds, pooled_prompt_embeds = features.encode_text(prompts)
            prompt_embeds = prompt_embeds.to(accelerator.device)
            pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)

            with autocast(), torch.no_grad():
                latents, _, _ = pipeline_with_logprob(
                    pipeline,
                    prompt_embeds=prompt_embeds,
                    pooled_prompt_embeds=pooled_prompt_embeds,
                    num_inference_steps=config.sample.eval_num_steps,
                    guidance_scale=config.sample.guidance_scale,
                    noise_level=0,
                )

                # Only the final latent is turned into audio.
                last_latent = latents[-1]
                mel = features.decode(last_latent)
                audios = features.vocode(mel).squeeze(1)

            rewards_future = executor.submit(reward_fn, audios, prompts, prompt_metadata, vae_sr=vae_sr, only_strict=False)
            rewards, _ = rewards_future.result()

            for key, value in rewards.items():
                rewards_gather = accelerator.gather(torch.as_tensor(value, device=accelerator.device)).cpu().numpy()
                all_rewards[key].append(rewards_gather)

        if accelerator.is_main_process and config.use_wandb:
            wandb.log({f"eval_reward_{k}": np.concatenate(v).mean() for k, v in all_rewards.items()}, step=global_step)
    finally:
        # Put the live (non-EMA) weights back no matter how evaluation ended.
        if config.train.ema:
            ema.copy_temp_to(transformer_trainable_parameters)
|
|
|
|
| |
def _draw_initial_noise(batch_size, seq_len, latent_dim, device, generators=None):
    """Return standard-normal noise of shape (batch_size, seq_len, latent_dim) on ``device``.

    ``torch.randn`` accepts a single generator only, so when a list of
    per-sample generators is given each row is drawn on its generator's own
    device and the rows are stacked and moved to ``device``.
    """
    if generators is None:
        return torch.randn(batch_size, seq_len, latent_dim, device=device)
    rows = [torch.randn(seq_len, latent_dim, generator=g, device=g.device) for g in generators]
    return torch.stack(rows).to(device)


def _generate_latents(model, batch_size, prompt_embeds, pooled_prompt_embeds, guidance_scale, num_steps, generators=None):
    """Sample one latent per prompt by Euler-integrating the guided flow ODE.

    Returns a one-element list holding the final latent, matching the
    ``latents[-1]`` access used by the caller.
    """
    # Reach through a wrapper that exposes the real model as `.module`.
    if hasattr(model, 'module'):
        model = model.module

    x0 = _draw_initial_noise(batch_size, model.latent_seq_len, model.latent_dim, model.device, generators)
    preprocessed_conditions = model.preprocess_conditions(prompt_embeds, pooled_prompt_embeds)
    empty_conditions = model.get_empty_conditions(batch_size)

    def cfg_ode_wrapper(t, x):
        return model.ode_wrapper(t, x, preprocessed_conditions, empty_conditions, guidance_scale)

    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
    return [fm.to_data(cfg_ode_wrapper, x0)]


def _write_rollout_records(save_dir, epoch, batch_idx, audios, prompts, rewards_dict, sample_rate):
    """Dump one sampled batch (wav files + JSONL rows) under ``save_dir/dpo_rollout``.

    The metadata file is truncated for the first batch of an epoch and
    appended to for later batches, so all batches of the epoch are kept.
    Returns the epoch directory.
    """
    rollout_root = os.path.join(save_dir, "dpo_rollout")
    os.makedirs(rollout_root, exist_ok=True)

    epoch_dir = os.path.join(rollout_root, f"epoch_{epoch:04d}")
    audio_dir = os.path.join(epoch_dir, "audio")
    os.makedirs(audio_dir, exist_ok=True)

    jsonl_path = os.path.join(epoch_dir, "metadata.jsonl")
    file_mode = "w" if batch_idx == 0 else "a"

    with open(jsonl_path, file_mode, encoding="utf-8") as f:
        for b, prompt in enumerate(prompts):
            audio_name = f"epoch{epoch:04d}_batch{batch_idx:03d}_sample{b:02d}.wav"
            audio_path = os.path.join(audio_dir, audio_name)

            sf.write(
                audio_path,
                audios[b].numpy(),
                samplerate=sample_rate,
            )

            entry = {
                "epoch": epoch,
                "batch_idx": batch_idx,
                "sample_idx": b,
                "audio_path": audio_path,
                "prompt": prompt,
            }

            # Indexable reward values are per-sample; anything else is a scalar.
            for r_key, r_val in rewards_dict.items():
                if hasattr(r_val, '__getitem__'):
                    entry[r_key] = float(r_val[b])
                else:
                    entry[r_key] = float(r_val)

            f.write(json.dumps(entry, ensure_ascii=False) + "\n")

    return epoch_dir


@record
@hydra.main(version_base='1.3.2', config_path='config', config_name='train_config.yaml')
def train(cfg: DictConfig):
    """Run DPO fine-tuning of the flow-matching audio transformer.

    Each epoch has three phases:

    1. Sampling: ``config.sample.num_batches_per_epoch`` prompt batches are
       drawn (every prompt repeated k times), latents are generated, decoded
       to audio and scored asynchronously by the reward function.
    2. Pairing: for every prompt the first two rollouts are ordered by their
       ``avg`` reward into a (winner, loser) pair.
    3. Training: the DPO loss compares the policy's and the reference's
       velocity-prediction errors on winner and loser under shared noise and
       timestep. With LoRA the reference is the same network with adapters
       disabled; otherwise it is a frozen deep copy.

    Evaluation and checkpointing run at the start of an epoch according to
    ``config.eval_freq`` / ``config.save_freq``.

    Args:
        cfg: Hydra configuration of the run.

    Raises:
        ValueError: For an odd k, a per-rank batch size that is not a
            multiple of k, or an unsupported ``audio_sample_rate``.
        NotImplementedError: For an unknown ``config.prompt_fn``.
    """
    if cfg.get("debug", False):
        import debugpy
        # A missing RANK (single-process launch) counts as rank 0.
        rank = int(os.environ.get("RANK", "0"))
        if rank == 0:
            debugpy.listen(6665)
            print(f'Waiting for debugger attach (rank {rank})...')
            debugpy.wait_for_client()

    config = cfg
    k = config.sample.num_audio_per_prompt
    if k % 2 != 0:
        raise ValueError("For DPO training, num_audio_per_prompt (k) must be even.")
    if config.sample.train_batch_size % k != 0:
        # The pairing step reshapes every per-rank batch to (num_prompts, k).
        raise ValueError("sample.train_batch_size must be a multiple of num_audio_per_prompt (k).")

    save_dir = os.path.join(HydraConfig.get().run.dir)
    accelerator_config = ProjectConfiguration(project_dir=save_dir, automatic_checkpoint_naming=True, total_limit=config.num_checkpoint_limit)
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        project_config=accelerator_config,
        gradient_accumulation_steps=config.train.gradient_accumulation_steps,
    )

    if accelerator.is_main_process and config.use_wandb:
        wandb.init(project="flow_dpo", name=cfg.exp_id, config=dict(config))

    logger.info(f"\n{config}")
    set_seed(config.seed, device_specific=True)

    if cfg.audio_sample_rate == 16000:
        mode, seq_cfg, vae_sr = '16k', CONFIG_16K, 16000
    elif config.audio_sample_rate == 44100:
        mode, seq_cfg, vae_sr = '44k', CONFIG_44K, 44100
    else:
        raise ValueError(f'Invalid sample rate: {cfg.audio_sample_rate}')

    inference_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        inference_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        inference_dtype = torch.bfloat16

    if mode == '16k':
        features = FeaturesUtils(tod_vae_ckpt=cfg['vae_16k_ckpt'], bigvgan_vocoder_ckpt=cfg['bigvgan_vocoder_ckpt'], encoder_name=cfg['text_encoder_name'], enable_conditions=True, mode=mode, need_vae_encoder=cfg.get('online_feature_extraction', False))
    else:
        features = FeaturesUtils(tod_vae_ckpt=cfg['vae_44k_ckpt'], encoder_name=cfg['text_encoder_name'], enable_conditions=True, mode=mode, need_vae_encoder=cfg.get('online_feature_extraction', False))

    features = features.eval().to(accelerator.device, dtype=torch.float32)
    if cfg.compile:
        features.compile()

    # Embedding of the empty prompt: given to the model at construction and
    # substituted for dropped conditions during training.
    with torch.no_grad():
        neg_prompt_embed, neg_pooled_prompt_embed = features.encode_text([''])
        neg_prompt_embed = neg_prompt_embed[0].to(accelerator.device)
        if neg_pooled_prompt_embed is not None:
            neg_pooled_prompt_embed = neg_pooled_prompt_embed[0].to(accelerator.device)
        else:
            neg_pooled_prompt_embed = neg_prompt_embed.mean(dim=0)

    latent_mean, latent_std = torch.load(cfg.latent_mean), torch.load(cfg.latent_std)
    transformer = get_model(cfg.model, text_dim=cfg.text_dim, text_c_dim=cfg.text_c_dim, latent_mean=latent_mean, latent_std=latent_std, empty_string_feat=neg_prompt_embed, empty_string_feat_c=neg_pooled_prompt_embed, use_rope=cfg.use_rope)
    transformer.load_weights(torch.load(cfg.weight, map_location=accelerator.device, weights_only=True))
    transformer.to(accelerator.device, dtype=inference_dtype)

    if config.use_lora:
        transformer_lora_config = LoraConfig(
            r=32, lora_alpha=64, init_lora_weights="gaussian",
            target_modules=["attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out", "attn.to_k", "attn.to_out.0", "attn.to_q", "attn.to_v"]
        )
        if config.train.lora_path:
            transformer = PeftModel.from_pretrained(transformer, config.train.lora_path)
            transformer.set_adapter("default")
        else:
            transformer = get_peft_model(transformer, transformer_lora_config)
        # With LoRA the reference is the same network with adapters disabled.
        ref_transformer = None
    else:
        import copy
        ref_transformer = copy.deepcopy(transformer)
        for p in ref_transformer.parameters():
            p.requires_grad = False
        ref_transformer.eval()

    transformer_trainable_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
    ema = EMAModuleWrapper(transformer_trainable_parameters, decay=0.9, update_step_interval=8, device=accelerator.device)

    if config.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
    if config.train.use_8bit_adam:
        import bitsandbytes as bnb
        optimizer = bnb.optim.AdamW8bit(transformer_trainable_parameters, lr=config.train.learning_rate, betas=(config.train.adam_beta1, config.train.adam_beta2), weight_decay=config.train.adam_weight_decay, eps=config.train.adam_epsilon)
    else:
        optimizer = torch.optim.AdamW(transformer_trainable_parameters, lr=config.train.learning_rate, betas=(config.train.adam_beta1, config.train.adam_beta2), weight_decay=config.train.adam_weight_decay, eps=config.train.adam_epsilon)

    if config.prompt_fn == "audioprompt":
        train_dataset = AudioPromptDataset(config.dataset, 'train')
        test_dataset = AudioPromptDataset(config.dataset, 'test')
        collate_fn = AudioPromptDataset.collate_fn
    elif config.prompt_fn == "audio_temporal_prompt":
        train_dataset = AudioTemporalDataset(config.dataset, 'train')
        test_dataset = AudioTemporalDataset(config.dataset, 'test')
        collate_fn = AudioTemporalDataset.collate_fn
    else:
        raise NotImplementedError

    train_sampler = DistributedKRepeatSampler(train_dataset, batch_size=config.sample.train_batch_size, k=k, num_replicas=accelerator.num_processes, rank=accelerator.process_index, seed=42)
    train_dataloader = DataLoader(train_dataset, batch_sampler=train_sampler, num_workers=1, collate_fn=collate_fn)
    test_dataloader = DataLoader(test_dataset, batch_size=config.sample.test_batch_size, collate_fn=collate_fn, shuffle=False, num_workers=4)

    reward_fn = multi_score(accelerator.device, config.reward_fn)
    eval_reward_fn = multi_score(accelerator.device, config.reward_fn)

    autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
    transformer, optimizer, train_dataloader, test_dataloader = accelerator.prepare(transformer, optimizer, train_dataloader, test_dataloader)
    executor = futures.ThreadPoolExecutor(max_workers=8)

    epoch = 0
    global_step = 0
    train_iter = iter(train_dataloader)
    online_dpo = config.get('online_dpo', True)
    logger.info(f"Starting DPO Training: Batch Size {config.train.batch_size}, K={config.sample.num_audio_per_prompt}, Online: {online_dpo}")

    while epoch < config.total_epoch:
        if epoch % config.eval_freq == 0 and config.do_eval and epoch >= 0:
            transformer.eval()
            eval(transformer, test_dataloader, features, config, accelerator, global_step, eval_reward_fn, executor, autocast, ema, transformer_trainable_parameters)
        if epoch % config.save_freq == 0 and epoch > 0:
            save_ckpt(save_dir, transformer, global_step, accelerator, ema, transformer_trainable_parameters, config)

        # ---- Phase 1: sample rollouts and launch reward computation ----
        transformer.eval()
        samples = []

        for i in pbar(range(config.sample.num_batches_per_epoch), desc=f"Epoch {epoch}: sampling", disable=not accelerator.is_main_process):
            train_sampler.set_epoch(epoch * config.sample.num_batches_per_epoch + i)
            try:
                prompts, prompt_metadata = next(train_iter)
            except StopIteration:
                train_iter = iter(train_dataloader)
                prompts, prompt_metadata = next(train_iter)

            prompt_embeds, pooled_prompt_embeds = features.encode_text(prompts)
            # NOTE(review): equal prompts get equal seeds, so with
            # same_latent the k rollouts of a prompt start from the same
            # noise - confirm this is intended for the sampler used here.
            generator = create_generator(prompts, base_seed=epoch * 10000 + i) if config.sample.same_latent else None

            with autocast(), torch.no_grad():
                sampling_kwargs = dict(
                    batch_size=len(prompts),
                    prompt_embeds=prompt_embeds,
                    pooled_prompt_embeds=pooled_prompt_embeds,
                    guidance_scale=config.sample.guidance_scale,
                    num_steps=config.sample.num_steps,
                    generators=generator,
                )
                if online_dpo:
                    latents = _generate_latents(transformer, **sampling_kwargs)
                elif ref_transformer is not None:
                    latents = _generate_latents(ref_transformer, **sampling_kwargs)
                else:
                    # LoRA run without a separate reference copy: sample
                    # from the base weights by disabling the adapters.
                    with accelerator.unwrap_model(transformer).disable_adapter():
                        latents = _generate_latents(transformer, **sampling_kwargs)

                last_latent = latents[-1]
                mel = features.decode(last_latent)
                audios = features.vocode(mel).squeeze(1)

            rewards_future = executor.submit(reward_fn, audios, prompts, prompt_metadata, vae_sr=vae_sr, only_strict=True)
            time.sleep(0)  # yield so the worker thread can start right away

            samples.append({
                "prompt_embeds": prompt_embeds,
                "pooled_prompt_embeds": pooled_prompt_embeds,
                "latents": last_latent.detach().cpu(),
                "rewards_future": rewards_future,
                "audios": audios.detach().cpu(),
                "prompts": prompts
            })

        # ---- Phase 2: collect rewards and build (winner, loser) pairs ----
        paired_samples = []
        for batch_idx, sample in enumerate(pbar(samples, desc="Pairing", disable=not accelerator.is_main_process)):
            rewards_dict, _ = sample["rewards_future"].result()

            # Every 100 epochs the main process dumps the rollouts to disk.
            if epoch % 100 == 0 and accelerator.is_main_process:
                epoch_dir = _write_rollout_records(save_dir, epoch, batch_idx, sample["audios"], sample["prompts"], rewards_dict, vae_sr)
                if batch_idx == 0:
                    logger.info(f"[DEBUG] Saving rollout samples to {epoch_dir}")

            scores = torch.tensor(rewards_dict['avg'], device=accelerator.device)

            if accelerator.is_main_process and batch_idx == 0:
                logger.info(f"Step {global_step} | reward mean: {scores.mean().item()}")
                if config.use_wandb:
                    wandb.log({"train_reward_mean": scores.mean().item()}, step=global_step)

            bsz_total = scores.shape[0]
            num_prompts = bsz_total // k

            scores = scores.reshape(num_prompts, k)
            latents = sample["latents"].to(accelerator.device).reshape(num_prompts, k, *sample["latents"].shape[1:])
            p_embeds = sample["prompt_embeds"].reshape(num_prompts, k, *sample["prompt_embeds"].shape[1:])
            pp_embeds = sample["pooled_prompt_embeds"].reshape(num_prompts, k, *sample["pooled_prompt_embeds"].shape[1:])

            # NOTE(review): only rollouts 0 and 1 of each prompt form the
            # pair; rollouts 2..k-1 are generated and scored but unused.
            swap_mask = scores[:, 1] > scores[:, 0]

            latents_w = torch.where(swap_mask.view(-1, 1, 1), latents[:, 1], latents[:, 0])
            latents_l = torch.where(swap_mask.view(-1, 1, 1), latents[:, 0], latents[:, 1])

            # Winner and loser share the prompt, hence the same embeddings.
            pe_w = p_embeds[:, 0]
            pe_l = p_embeds[:, 0]
            ppe_w = pp_embeds[:, 0]
            ppe_l = pp_embeds[:, 0]

            # Layout: first half winners, second half losers.
            batch_latents = torch.cat([latents_w, latents_l], dim=0)
            batch_pe = torch.cat([pe_w, pe_l], dim=0)
            batch_ppe = torch.cat([ppe_w, ppe_l], dim=0)

            paired_samples.append({
                "latents": batch_latents,
                "prompt_embeds": batch_pe,
                "pooled_prompt_embeds": batch_ppe
            })

        # ---- Phase 3: DPO updates ----
        transformer.train()
        for inner_epoch in range(config.train.num_inner_epochs):
            random.shuffle(paired_samples)
            for batch in pbar(paired_samples, desc=f"Epoch {epoch}.{inner_epoch}: Train", disable=not accelerator.is_main_process):
                model_input = batch["latents"].to(accelerator.device)
                embeds = batch["prompt_embeds"].to(accelerator.device)
                pooled_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)

                bsz = model_input.shape[0]
                N = bsz // 2

                # Winner and loser of a pair share noise and timestep.
                noise = torch.randn_like(model_input)
                noise_w = noise[:N]
                noise = torch.cat([noise_w, noise_w], dim=0)

                t = torch.rand((N,), device=accelerator.device)
                t = torch.cat([t, t], dim=0)
                t_expand = t.view(-1, 1, 1)

                # NOTE(review): here t=1 is pure noise and the target is
                # noise - data, while sampling integrates from noise via
                # FlowMatching.to_data - confirm both use the same time and
                # sign convention.
                noisy_latents = t_expand * noise + (1 - t_expand) * model_input
                target_v = noise - model_input

                if config.train.cfg:
                    cfg_drop_rate = config.train.get("cfg_drop_rate", 0.1)
                    drop_mask = torch.rand(bsz, device=accelerator.device) < cfg_drop_rate

                    if drop_mask.any():
                        neg_pe_batch = neg_prompt_embed.unsqueeze(0).expand(bsz, -1, -1)
                        neg_ppe_batch = neg_pooled_prompt_embed.unsqueeze(0).expand(bsz, -1)

                        # Clone before writing: `.to()` returns the stored
                        # tensor itself when it is already on the device, and
                        # paired_samples is reused across inner epochs.
                        embeds = embeds.clone()
                        pooled_embeds = pooled_embeds.clone()
                        embeds[drop_mask] = neg_pe_batch[drop_mask]
                        pooled_embeds[drop_mask] = neg_ppe_batch[drop_mask]

                with accelerator.accumulate(transformer):
                    with autocast():
                        model_pred = transformer(
                            latent=noisy_latents,
                            text_f=embeds,
                            text_f_c=pooled_embeds,
                            t=t
                        )

                    with torch.no_grad():
                        with autocast():
                            if config.use_lora:
                                # unwrap_model works with or without a
                                # distributed wrapper around the PeftModel.
                                with accelerator.unwrap_model(transformer).disable_adapter():
                                    ref_pred = transformer(
                                        latent=noisy_latents,
                                        text_f=embeds,
                                        text_f_c=pooled_embeds,
                                        t=t
                                    )
                            else:
                                ref_pred = ref_transformer(
                                    latent=noisy_latents,
                                    text_f=embeds,
                                    text_f_c=pooled_embeds,
                                    t=t
                                )

                    target_v = target_v.float()
                    model_pred = model_pred.float()
                    ref_pred = ref_pred.float()

                    raw_model_loss = (model_pred - target_v).pow(2).mean(dim=[1, 2])
                    raw_ref_loss = (ref_pred - target_v).pow(2).mean(dim=[1, 2])

                    model_losses_w, model_losses_l = raw_model_loss[:N], raw_model_loss[N:]
                    ref_losses_w, ref_losses_l = raw_ref_loss[:N], raw_ref_loss[N:]

                    w_diff = model_losses_w - ref_losses_w
                    l_diff = model_losses_l - ref_losses_l

                    inside_term = -0.5 * config.train.beta * (w_diff - l_diff)
                    loss = -F.logsigmoid(inside_term).mean()

                    # Fraction of pairs where the policy improves on the
                    # winner (relative to the reference) more than on the loser.
                    reward_accuracies = (inside_term > 0).float().mean()

                    accelerator.backward(loss)

                    if accelerator.sync_gradients:
                        accelerator.clip_grad_norm_(transformer.parameters(), config.train.max_grad_norm)

                    optimizer.step()
                    optimizer.zero_grad()

                if accelerator.sync_gradients:
                    if accelerator.is_main_process and config.use_wandb:
                        wandb.log({
                            "loss": loss.item(),
                            "accuracy": reward_accuracies.item(),
                            "margin": inside_term.mean().item(),
                            "lr": optimizer.param_groups[0]["lr"],
                            "epoch": epoch
                        }, step=global_step)

                    if config.train.ema:
                        ema.step(transformer_trainable_parameters, global_step)

                    global_step += 1

        epoch += 1
|
# Script entry point: invoke the Hydra-decorated training function.
if __name__ == "__main__":
    train()