| from .base_pipeline import BasePipeline |
| import torch |
|
|
|
|
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
    """Compute a flow-matching supervised fine-tuning loss at one random timestep.

    A timestep index is drawn uniformly from the index window described by
    ``min_timestep_boundary`` / ``max_timestep_boundary`` (fractions of
    ``len(pipe.scheduler.timesteps)``, defaulting to 0 and 1). The clean
    latents are noised by the scheduler, the model is evaluated on them, and
    the prediction is compared against the scheduler's training target.

    Args:
        pipe: Pipeline providing ``scheduler``, ``model_fn``,
            ``in_iteration_models``, ``torch_dtype`` and ``device``.
        **inputs: Model inputs. ``input_latents`` is required. When
            ``first_frame_latents`` is present it overwrites index 0 along the
            third dimension of the noised latents, and that index is excluded
            from the loss. All entries are forwarded to ``pipe.model_fn``.

    Returns:
        The mean squared error multiplied by ``scheduler.training_weight``.
    """
    scheduler = pipe.scheduler

    # Fractional boundaries are truncated to integer indices into the timestep table.
    upper_id = int(inputs.get("max_timestep_boundary", 1) * len(scheduler.timesteps))
    lower_id = int(inputs.get("min_timestep_boundary", 0) * len(scheduler.timesteps))

    sampled_id = torch.randint(lower_id, upper_id, (1,))
    timestep = scheduler.timesteps[sampled_id].to(dtype=pipe.torch_dtype, device=pipe.device)

    clean_latents = inputs["input_latents"]
    noise = torch.randn_like(clean_latents)
    inputs["latents"] = scheduler.add_noise(clean_latents, noise, timestep)
    target = scheduler.training_target(clean_latents, noise, timestep)

    keeps_first_frame = "first_frame_latents" in inputs
    if keeps_first_frame:
        # Index 0 along dim 2 is replaced by the provided latents instead of noise.
        inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"]

    models = {}
    for model_name in pipe.in_iteration_models:
        models[model_name] = getattr(pipe, model_name)
    prediction = pipe.model_fn(**models, **inputs, timestep=timestep)

    if keeps_first_frame:
        # The overwritten slice carries no noise to predict, so drop it from both sides.
        prediction = prediction[:, :, 1:]
        target = target[:, :, 1:]

    mse = torch.nn.functional.mse_loss(prediction.float(), target.float())
    return mse * scheduler.training_weight(timestep)
|
|
|
|
def DirectDistillLoss(pipe: BasePipeline, **inputs):
    """Run the full sampling loop and penalise the distance to the target latents.

    The scheduler is configured for ``inputs["num_inference_steps"]`` steps
    and flagged as training. Starting from ``inputs["latents"]``, every
    timestep calls ``pipe.model_fn`` followed by ``pipe.step``; the latents
    left after the last step are compared with ``inputs["input_latents"]``.

    Args:
        pipe: Pipeline providing ``scheduler``, ``model_fn``, ``step``,
            ``in_iteration_models``, ``torch_dtype`` and ``device``.
        **inputs: Must contain ``num_inference_steps``, ``latents`` and
            ``input_latents``; all entries are forwarded to ``pipe.model_fn``
            and ``pipe.step``.

    Returns:
        The mean squared error between the final latents and ``input_latents``.
    """
    scheduler = pipe.scheduler
    scheduler.set_timesteps(inputs["num_inference_steps"])
    scheduler.training = True

    models = {}
    for model_name in pipe.in_iteration_models:
        models[model_name] = getattr(pipe, model_name)

    for step_index, raw_timestep in enumerate(pipe.scheduler.timesteps):
        timestep = raw_timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
        prediction = pipe.model_fn(**models, **inputs, timestep=timestep, progress_id=step_index)
        # Each step's output becomes the next step's input latents.
        inputs["latents"] = pipe.step(
            pipe.scheduler, progress_id=step_index, noise_pred=prediction, **inputs
        )

    final_latents = inputs["latents"].float()
    reference_latents = inputs["input_latents"].float()
    return torch.nn.functional.mse_loss(final_latents, reference_latents)
|
|
|
|
def _dpo_per_sample_mse(prediction, target):
    """Return the squared error of *prediction* vs *target*, averaged over every dim but the first."""
    squared_error = torch.nn.functional.mse_loss(prediction.float(), target.float(), reduction='none')
    return squared_error.mean(dim=list(range(1, squared_error.ndim)))


def DPOLoss(pipe: BasePipeline, **inputs):
    """Compute a DPO-style preference loss plus an SFT term on the preferred sample.

    The preferred (``input_latents_pos``) and rejected (``input_latents_neg``)
    latents are noised with the same noise sample at the same random timestep.
    The denoising error of the trained model is compared with that of a
    reference model on both samples, and the difference of the two
    improvements is passed through ``-logsigmoid(beta * diff)``.

    Args:
        pipe: Pipeline being trained; provides ``scheduler``, ``model_fn``,
            ``in_iteration_models``, ``torch_dtype`` and ``device``.
        **inputs: Model inputs, forwarded to ``model_fn``. Keys read here:

            * ``input_latents_pos`` / ``input_latents_neg`` (required).
            * ``min_timestep_boundary`` / ``max_timestep_boundary``: fractions
              of ``len(pipe.scheduler.timesteps)`` bounding the sampled
              timestep index (defaults 0 and 1).
            * ``first_frame_latents`` (optional): overwrites index 0 along the
              third dimension of both noised latents; that index is excluded
              from every loss term.
            * ``pipe_ref`` (optional): reference pipeline evaluated without
              gradients. When absent, the detached predictions of ``pipe``
              are used as the reference.
            * ``dpo_beta``: scale inside the log-sigmoid (default 10.0).
            * ``dpo_clamp``: symmetric bound applied to the preference margin
              before scaling (default 0.05); ``None`` disables clamping.
            * ``dpo_sft_weight``: weight of the SFT term (default 1.0).

    Returns:
        Scalar tensor ``loss_dpo + dpo_sft_weight * loss_sft``.
    """
    num_timesteps = len(pipe.scheduler.timesteps)
    max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * num_timesteps)
    min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * num_timesteps)

    timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
    timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)

    # A single noise sample is shared by both branches so they differ only in the clean latents.
    noise = torch.randn_like(inputs["input_latents_pos"])
    inputs_pos = {**inputs, "input_latents": inputs["input_latents_pos"]}
    inputs_neg = {**inputs, "input_latents": inputs["input_latents_neg"]}

    inputs_pos["latents"] = pipe.scheduler.add_noise(inputs_pos["input_latents"], noise, timestep)
    inputs_neg["latents"] = pipe.scheduler.add_noise(inputs_neg["input_latents"], noise, timestep)

    training_target_pos = pipe.scheduler.training_target(inputs_pos["input_latents"], noise, timestep)
    training_target_neg = pipe.scheduler.training_target(inputs_neg["input_latents"], noise, timestep)

    has_first_frame = "first_frame_latents" in inputs_pos
    if has_first_frame:
        inputs_pos["latents"][:, :, 0:1] = inputs_pos["first_frame_latents"]
        inputs_neg["latents"][:, :, 0:1] = inputs_neg["first_frame_latents"]

    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
    noise_pred_pos = pipe.model_fn(**models, **inputs_pos, timestep=timestep)
    noise_pred_neg = pipe.model_fn(**models, **inputs_neg, timestep=timestep)

    pipe_ref = inputs.get("pipe_ref", None)
    if pipe_ref is not None:
        ref_models = {name: getattr(pipe_ref, name) for name in pipe_ref.in_iteration_models}
        with torch.no_grad():
            ref_noise_pred_pos = pipe_ref.model_fn(**ref_models, **inputs_pos, timestep=timestep)
            ref_noise_pred_neg = pipe_ref.model_fn(**ref_models, **inputs_neg, timestep=timestep)
    else:
        # No reference pipeline: use the current predictions with gradients cut.
        # The margin is then zero in value, but it still carries gradient.
        ref_noise_pred_pos = noise_pred_pos.detach()
        ref_noise_pred_neg = noise_pred_neg.detach()

    if has_first_frame:
        # Index 0 along dim 2 was overwritten above, so it is left out of every loss term.
        noise_pred_pos = noise_pred_pos[:, :, 1:]
        noise_pred_neg = noise_pred_neg[:, :, 1:]
        ref_noise_pred_pos = ref_noise_pred_pos[:, :, 1:]
        ref_noise_pred_neg = ref_noise_pred_neg[:, :, 1:]
        training_target_pos = training_target_pos[:, :, 1:]
        training_target_neg = training_target_neg[:, :, 1:]

    loss_pos = _dpo_per_sample_mse(noise_pred_pos, training_target_pos)
    loss_neg = _dpo_per_sample_mse(noise_pred_neg, training_target_neg)
    ref_loss_pos = _dpo_per_sample_mse(ref_noise_pred_pos, training_target_pos)
    ref_loss_neg = _dpo_per_sample_mse(ref_noise_pred_neg, training_target_neg)

    # Positive values mean the trained model has a lower error than the reference.
    diff_pos = ref_loss_pos - loss_pos
    diff_neg = ref_loss_neg - loss_neg

    beta = inputs.get("dpo_beta", 10.0)
    dpo_clamp = inputs.get("dpo_clamp", 0.05)
    sft_weight = inputs.get("dpo_sft_weight", 1.0)

    diff_total_raw = diff_pos - diff_neg
    if dpo_clamp is None:
        diff_total = diff_total_raw
    else:
        # torch.clamp passes no gradient for entries outside the bound, so samples
        # whose margin is already beyond +/- dpo_clamp stop contributing to loss_dpo.
        diff_total = torch.clamp(diff_total_raw, min=-dpo_clamp, max=dpo_clamp)

    loss_dpo = -torch.nn.functional.logsigmoid(beta * diff_total).mean()

    # SFT term on the preferred sample only.
    loss_sft = loss_pos.mean()
    loss = loss_dpo + sft_weight * loss_sft

    # Per-step diagnostics are printed only when the pipeline reports no device or the first CUDA device.
    device = getattr(pipe, "device", None)
    if device is None or str(device) in ("cuda:0", "cuda"):
        print(f"\n[DPO] t: {timestep.item():.0f} | "
              f"L_pos: {loss_pos.mean().item():.5f} | L_neg: {loss_neg.mean().item():.5f} | "
              f"d_tot: {diff_total_raw.mean().item():.5f} | L_dpo: {loss_dpo.item():.4f} | L_sft: {loss_sft.item():.4f} | loss: {loss.item():.4f}")

    return loss
|
|
|
|
class TrajectoryImitationLoss(torch.nn.Module):
    """Distillation loss that makes a few-step student follow a many-step teacher trajectory.

    ``forward`` records the teacher's latents along its sampling trajectory,
    trains the student's per-step prediction to jump between the recorded
    latents (``align_trajectory``), and adds an LPIPS term between the decoded
    student result and the decoded teacher result (``compute_regularization``).

    Args:
        num_student_steps: Number of sampling steps used for the student.
        num_teacher_steps: Number of sampling steps used for the teacher.
        teacher_cfg_scale: ``cfg_scale`` passed to the teacher's guided model call.
        student_cfg_scale: ``cfg_scale`` passed to the student's guided model call.
    """

    def __init__(self, num_student_steps=8, num_teacher_steps=50, teacher_cfg_scale=2, student_cfg_scale=1):
        super().__init__()
        self.initialized = False
        self.num_student_steps = num_student_steps
        self.num_teacher_steps = num_teacher_steps
        self.teacher_cfg_scale = teacher_cfg_scale
        self.student_cfg_scale = student_cfg_scale

    def initialize(self, device):
        """Create the LPIPS (AlexNet) metric on *device*; deferred so ``lpips`` is imported only on first use."""
        import lpips
        self.loss_fn = lpips.LPIPS(net='alex').to(device)
        self.initialized = True

    @staticmethod
    def _in_iteration_models(pipe):
        """Return a name -> object dict for the entries of ``pipe.in_iteration_models``."""
        return {name: getattr(pipe, name) for name in pipe.in_iteration_models}

    def fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        """Sample with *pipe* and record the latents before the first step and after every step.

        ``inputs_shared["latents"]`` is overwritten with each step's output.

        Returns:
            ``(pipe.scheduler.timesteps, trajectory)`` where ``trajectory`` is a
            list holding one more entry than there are timesteps.
        """
        trajectory = [inputs_shared["latents"].clone()]

        # The student timesteps are handed to the scheduler as ``target_timesteps``.
        pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student)
        models = self._in_iteration_models(pipe)
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
            trajectory.append(inputs_shared["latents"].clone())
        return pipe.scheduler.timesteps, trajectory

    def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        """Sum, over the student timesteps, the weighted MSE between the student prediction and the teacher-implied target.

        At each student timestep the input latents are the recorded teacher
        latents whose teacher timestep is closest. The target is the
        difference to the teacher latents at the next student timestep (or the
        final teacher latents on the last step), divided by the sigma
        difference of that step.

        Returns:
            The accumulated loss (the integer ``0`` if the scheduler has no timesteps).
        """
        loss = 0
        pipe.scheduler.set_timesteps(num_inference_steps, training=True)
        models = self._in_iteration_models(pipe)
        num_steps = len(pipe.scheduler.timesteps)
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)

            # Start from the teacher state recorded nearest to this student timestep.
            progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs())
            inputs_shared["latents"] = trajectory_teacher[progress_id_teacher]

            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )

            sigma = pipe.scheduler.sigmas[progress_id]
            if progress_id + 1 >= num_steps:
                # Last step: the destination is the teacher's final latents, with sigma 0.
                sigma_ = 0
                latents_ = trajectory_teacher[-1]
            else:
                sigma_ = pipe.scheduler.sigmas[progress_id + 1]
                progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
                latents_ = trajectory_teacher[progress_id_teacher]

            target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma)
            loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
        return loss

    def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        """Sample with *pipe* from the teacher's initial latents and return LPIPS against the teacher's decoded result."""
        inputs_shared["latents"] = trajectory_teacher[0]
        pipe.scheduler.set_timesteps(num_inference_steps)
        models = self._in_iteration_models(pipe)
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            # NOTE(review): noise_pred is detached before stepping, so this term can only
            # back-propagate into the model if pipe.step or pipe.vae_decoder provide another
            # gradient path - confirm this is intended.
            inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)

        image_pred = pipe.vae_decoder(inputs_shared["latents"])
        image_real = pipe.vae_decoder(trajectory_teacher[-1])
        loss = self.loss_fn(image_pred.float(), image_real.float())
        return loss

    def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega):
        """Return the trajectory-alignment loss plus the LPIPS regularization.

        Args:
            pipe: Student pipeline.
            inputs_shared: Shared inputs; must contain ``"teacher"`` (the teacher
                pipeline) and ``"latents"``. Its ``"latents"`` entry is overwritten.
            inputs_posi: Inputs passed as the positive branch of the guided model call.
            inputs_nega: Inputs passed as the negative branch of the guided model call.
        """
        if not self.initialized:
            self.initialize(pipe.device)
        with torch.no_grad():
            # The student timesteps are computed first so the teacher's scheduler can receive them.
            pipe.scheduler.set_timesteps(self.num_student_steps)
            timesteps_teacher, trajectory_teacher = self.fetch_trajectory(
                inputs_shared["teacher"], pipe.scheduler.timesteps,
                inputs_shared, inputs_posi, inputs_nega,
                self.num_teacher_steps, self.teacher_cfg_scale,
            )
            timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device)
        loss_1 = self.align_trajectory(
            pipe, timesteps_teacher, trajectory_teacher,
            inputs_shared, inputs_posi, inputs_nega,
            self.num_student_steps, self.student_cfg_scale,
        )
        loss_2 = self.compute_regularization(
            pipe, trajectory_teacher,
            inputs_shared, inputs_posi, inputs_nega,
            self.num_student_steps, self.student_cfg_scale,
        )
        loss = loss_1 + loss_2
        return loss
|
|