# Provenance: uploaded by ShanglinHG using huggingface_hub (commit ad79891, verified; 11.7 kB).
from .base_pipeline import BasePipeline
import torch
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
    """Supervised flow-matching loss evaluated at one randomly drawn timestep.

    A timestep index is drawn uniformly from
    ``[min_timestep_boundary * N, max_timestep_boundary * N)``, where ``N`` is
    ``len(pipe.scheduler.timesteps)``. The boundaries are read from ``inputs``
    as fractions and default to 0 and 1.

    The clean latents are noised to that timestep with fresh Gaussian noise.
    The model's prediction is then compared (MSE) with
    ``pipe.scheduler.training_target``, and the result is scaled by
    ``pipe.scheduler.training_weight``.

    When ``inputs`` contains ``first_frame_latents``, that tensor overwrites
    index 0 along dim 2 of the noisy latents (presumably the frame axis of a
    video latent -- confirm). That slice is also excluded from the loss.

    Args:
        pipe: pipeline supplying ``scheduler``, ``model_fn``,
            ``in_iteration_models``, ``torch_dtype`` and ``device``.
        **inputs: model inputs. Must contain ``input_latents``. The key
            ``latents`` is written into this dict.

    Returns:
        The mean-reduced MSE multiplied by ``training_weight(timestep)``.
    """
    schedule = pipe.scheduler.timesteps
    n_steps = len(schedule)
    upper = int(inputs.get("max_timestep_boundary", 1) * n_steps)
    lower = int(inputs.get("min_timestep_boundary", 0) * n_steps)
    sampled_index = torch.randint(lower, upper, (1,))
    t = schedule[sampled_index].to(dtype=pipe.torch_dtype, device=pipe.device)

    clean = inputs["input_latents"]
    eps = torch.randn_like(clean)
    inputs["latents"] = pipe.scheduler.add_noise(clean, eps, t)
    target = pipe.scheduler.training_target(clean, eps, t)

    pin_first_frame = "first_frame_latents" in inputs
    if pin_first_frame:
        # Replace the noisy first slice with the given conditioning latents.
        inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"]

    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
    prediction = pipe.model_fn(**models, **inputs, timestep=t)

    if pin_first_frame:
        # The pinned slice is given, not predicted, so it is left out of the loss.
        prediction, target = prediction[:, :, 1:], target[:, :, 1:]

    mse = torch.nn.functional.mse_loss(prediction.float(), target.float())
    return mse * pipe.scheduler.training_weight(t)
def DirectDistillLoss(pipe: BasePipeline, **inputs):
    """Unroll the whole sampler and penalize the distance of its output from the target.

    The scheduler is configured for ``inputs["num_inference_steps"]`` steps and
    flagged with ``training = True``. Every step is then run in turn.
    Predictions are not detached, and each step's result is written back to
    ``inputs["latents"]``.

    Presumably ``inputs["latents"]`` already holds the starting latents on
    entry, since it is forwarded to ``model_fn`` via ``**inputs``. Confirm this
    with the caller.

    Returns:
        The MSE between the final ``inputs["latents"]`` and ``inputs["input_latents"]``.
    """
    sched = pipe.scheduler
    sched.set_timesteps(inputs["num_inference_steps"])
    sched.training = True
    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}

    for step_index, raw_t in enumerate(sched.timesteps):
        t = raw_t.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
        prediction = pipe.model_fn(**models, **inputs, timestep=t, progress_id=step_index)
        inputs["latents"] = pipe.step(sched, progress_id=step_index, noise_pred=prediction, **inputs)

    produced = inputs["latents"].float()
    wanted = inputs["input_latents"].float()
    return torch.nn.functional.mse_loss(produced, wanted)
def _per_sample_mse(pred, target):
    """Return the squared error between ``pred`` and ``target``, averaged over every dim except dim 0."""
    err = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction='none')
    return err.mean(dim=list(range(1, err.ndim)))


def DPOLoss(pipe: BasePipeline, **inputs):
    """Diffusion-DPO loss on a (preferred, rejected) latent pair, plus an SFT regularizer.

    One timestep is drawn uniformly from the index window
    ``[min_timestep_boundary * N, max_timestep_boundary * N)`` of
    ``pipe.scheduler.timesteps`` (defaults 0 and 1). The same noise and
    timestep are used for both samples, so their denoising errors are
    directly comparable.

    Let ``L_x`` denote the per-sample MSE of the trained model and ``L_ref_x``
    that of the reference model. The DPO term is::

        -logsigmoid(beta * clamp((L_ref_pos - L_pos) - (L_ref_neg - L_neg)))

    Minimizing it pushes the model below the reference on the preferred
    sample and above it on the rejected sample. ``L_pos`` (averaged over the
    batch, weighted by ``dpo_sft_weight``) is added as an SFT term that keeps
    preferred-sample quality anchored.

    Required inputs:
        input_latents_pos, input_latents_neg: clean latents of the preferred and
            rejected samples. ``noise`` is drawn with the shape of the preferred
            latents, so both are assumed to share a shape.

    Optional inputs:
        min_timestep_boundary, max_timestep_boundary: timestep window as fractions.
        first_frame_latents: if present, overwrites index 0 along dim 2 of both
            noisy latents, and that slice is excluded from all losses
            (presumably the first video frame -- confirm).
        pipe_ref: reference pipeline, evaluated under ``torch.no_grad``. If
            absent, the model's own detached predictions serve as the reference.
            In that case the margin is 0 in value (so the DPO term equals
            ``log 2``), but its gradient is not zero.
        dpo_beta: scale applied to the margin (default 10.0).
        dpo_clamp: the margin is clamped to ``[-dpo_clamp, dpo_clamp]`` before
            scaling (default 0.05; ``None`` disables clamping). This clamps
            values rather than gradients: samples whose margin falls outside
            the band get zero DPO gradient.
        dpo_sft_weight: weight of the SFT term (default 1.0).

    Note that every key in ``inputs`` is also forwarded to ``model_fn``.

    Returns:
        Scalar loss ``loss_dpo + dpo_sft_weight * loss_sft``.
    """
    # --- sample a shared timestep and noise ---
    num_timesteps = len(pipe.scheduler.timesteps)
    max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * num_timesteps)
    min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * num_timesteps)
    timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
    timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
    noise = torch.randn_like(inputs["input_latents_pos"])

    # --- build noisy inputs and training targets for both samples ---
    inputs_pos = {**inputs, "input_latents": inputs["input_latents_pos"]}
    inputs_neg = {**inputs, "input_latents": inputs["input_latents_neg"]}
    inputs_pos["latents"] = pipe.scheduler.add_noise(inputs_pos["input_latents"], noise, timestep)
    inputs_neg["latents"] = pipe.scheduler.add_noise(inputs_neg["input_latents"], noise, timestep)
    training_target_pos = pipe.scheduler.training_target(inputs_pos["input_latents"], noise, timestep)
    training_target_neg = pipe.scheduler.training_target(inputs_neg["input_latents"], noise, timestep)

    has_first_frame = "first_frame_latents" in inputs_pos
    if has_first_frame:
        inputs_pos["latents"][:, :, 0:1] = inputs_pos["first_frame_latents"]
        inputs_neg["latents"][:, :, 0:1] = inputs_neg["first_frame_latents"]

    # --- trained-model predictions ---
    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
    noise_pred_pos = pipe.model_fn(**models, **inputs_pos, timestep=timestep)
    noise_pred_neg = pipe.model_fn(**models, **inputs_neg, timestep=timestep)

    # --- reference predictions ---
    pipe_ref = inputs.get("pipe_ref", None)
    if pipe_ref is not None:
        ref_models = {name: getattr(pipe_ref, name) for name in pipe_ref.in_iteration_models}
        with torch.no_grad():
            ref_noise_pred_pos = pipe_ref.model_fn(**ref_models, **inputs_pos, timestep=timestep)
            ref_noise_pred_neg = pipe_ref.model_fn(**ref_models, **inputs_neg, timestep=timestep)
    else:
        # No reference model: fall back to the model's own detached predictions.
        ref_noise_pred_pos = noise_pred_pos.detach()
        ref_noise_pred_neg = noise_pred_neg.detach()

    if has_first_frame:
        # The pinned first slice is given, not predicted, so drop it from every term.
        noise_pred_pos = noise_pred_pos[:, :, 1:]
        noise_pred_neg = noise_pred_neg[:, :, 1:]
        ref_noise_pred_pos = ref_noise_pred_pos[:, :, 1:]
        ref_noise_pred_neg = ref_noise_pred_neg[:, :, 1:]
        training_target_pos = training_target_pos[:, :, 1:]
        training_target_neg = training_target_neg[:, :, 1:]

    # --- per-sample denoising errors (one value per batch element) ---
    loss_pos = _per_sample_mse(noise_pred_pos, training_target_pos)
    loss_neg = _per_sample_mse(noise_pred_neg, training_target_neg)
    ref_loss_pos = _per_sample_mse(ref_noise_pred_pos, training_target_pos)
    ref_loss_neg = _per_sample_mse(ref_noise_pred_neg, training_target_neg)

    # Diffusion-DPO:
    #   L_DPO = -logsigmoid(beta * [(L_ref_pos - L_pos) - (L_ref_neg - L_neg)])
    # Minimizing this drives L_pos below L_ref_pos and L_neg above L_ref_neg.
    diff_pos = ref_loss_pos - loss_pos
    diff_neg = ref_loss_neg - loss_neg
    diff_total_raw = diff_pos - diff_neg

    # A smaller beta keeps the penalty from growing large enough that the
    # model takes shortcuts (e.g. a global color shift).
    beta = inputs.get("dpo_beta", 10.0)
    # Value clamp to tame extreme margins from outlier timesteps. Samples outside
    # the band receive zero DPO gradient; pass dpo_clamp=None to disable.
    clamp_bound = inputs.get("dpo_clamp", 0.05)
    if clamp_bound is None:
        diff_total = diff_total_raw
    else:
        diff_total = torch.clamp(diff_total_raw, min=-clamp_bound, max=clamp_bound)
    loss_dpo = -torch.nn.functional.logsigmoid(beta * diff_total).mean()

    # SFT regularizer on the preferred sample. It keeps the model from drifting
    # far enough to cause color shifts or structural collapse while the DPO
    # term widens the preference gap.
    sft_weight = inputs.get("dpo_sft_weight", 1.0)
    loss_sft = loss_pos.mean()
    loss = loss_dpo + sft_weight * loss_sft

    if getattr(pipe, "device", None) is None or str(pipe.device) == "cuda:0" or str(pipe.device) == "cuda":
        print(f"\n[DPO] t: {timestep.item():.0f} | "
              f"L_pos: {loss_pos.mean().item():.5f} | L_neg: {loss_neg.mean().item():.5f} | "
              f"d_tot: {diff_total_raw.mean().item():.5f} | L_dpo: {loss_dpo.item():.4f} | L_sft: {loss_sft.item():.4f} | loss: {loss.item():.4f}")
    return loss
class TrajectoryImitationLoss(torch.nn.Module):
    """Loss that trains a few-step student pipeline to follow a teacher's sampling trajectory.

    ``forward`` proceeds in three stages:

    1. Under ``torch.no_grad``, it samples a trajectory with the teacher
       pipeline taken from ``inputs_shared["teacher"]`` (50 steps, guidance
       scale 2).
    2. It regresses the student's per-step predictions (8 steps, guidance
       scale 1) onto finite differences of that trajectory.
    3. It adds an LPIPS distance between the decoded student sample and the
       decoded teacher sample.

    NOTE(review): the helpers overwrite ``inputs_shared["latents"]`` in place,
    so the caller's dict is modified by ``forward``.
    """

    def __init__(self):
        super().__init__()
        # The LPIPS network is built lazily on the first forward call, on the pipeline's device.
        self.initialized = False

    def initialize(self, device):
        """Create the LPIPS perceptual metric (``net='alex'``) on ``device`` and mark the loss ready."""
        import lpips # TODO: remove it
        self.loss_fn = lpips.LPIPS(net='alex').to(device)
        self.initialized = True

    def fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        """Sample with ``pipe`` and record the latents at every step.

        Args:
            pipe: pipeline to sample with (the teacher, when called from ``forward``).
            timesteps_student: forwarded to ``pipe.scheduler.set_timesteps`` as
                ``target_timesteps``. Presumably this aligns the sampling
                schedule with the student's timesteps -- confirm against the
                scheduler.
            inputs_shared, inputs_posi, inputs_nega: shared, positive and
                negative inputs for ``pipe.cfg_guided_model_fn``.
                ``inputs_shared["latents"]`` is the starting point and is
                overwritten after every step.
            num_inference_steps: number of steps requested from the scheduler.
            cfg_scale: guidance scale passed to ``cfg_guided_model_fn``.

        Returns:
            ``(timesteps, trajectory)``. ``timesteps`` is
            ``pipe.scheduler.timesteps``. ``trajectory[k]`` is a clone of the
            latents before step ``k``, and ``trajectory[-1]`` is the final
            sample, so ``len(trajectory) == len(timesteps) + 1``.
        """
        trajectory = [inputs_shared["latents"].clone()]
        pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            # Predictions are detached: the trajectory is a fixed target, not a graph.
            inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
            trajectory.append(inputs_shared["latents"].clone())
        return pipe.scheduler.timesteps, trajectory

    def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        """Regress the student's per-step predictions onto the teacher trajectory.

        For each student timestep, the student starts from the teacher latent
        ``x`` whose timestep is nearest. Its prediction is matched (MSE, scaled
        by ``training_weight``) against ``(x_next - x) / (sigma_next - sigma)``.
        Here ``x_next`` is the teacher latent nearest the student's next
        timestep. On the last step, ``x_next`` is the teacher's final sample and
        ``sigma_next`` is 0.

        NOTE(review): this target assumes ``pipe.step`` performs the update
        ``x + pred * (sigma_next - sigma)`` -- confirm against the scheduler.

        Returns:
            Sum over student steps of the weighted per-step losses.
        """
        loss = 0
        pipe.scheduler.set_timesteps(num_inference_steps, training=True)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
            # Nearest teacher timestep. This also indexes trajectory_teacher,
            # whose entry k holds the latents at teacher timestep k.
            progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs())
            inputs_shared["latents"] = trajectory_teacher[progress_id_teacher]
            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            sigma = pipe.scheduler.sigmas[progress_id]
            # On the final step the next noise level is taken to be 0.
            sigma_ = 0 if progress_id + 1 >= len(pipe.scheduler.timesteps) else pipe.scheduler.sigmas[progress_id + 1]
            if progress_id + 1 >= len(pipe.scheduler.timesteps):
                latents_ = trajectory_teacher[-1]
            else:
                progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
                latents_ = trajectory_teacher[progress_id_teacher]
            # Slope that would take the current teacher latent to the next one over this student step.
            target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma)
            loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
        return loss

    def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        """LPIPS distance between the student's decoded sample and the teacher's decoded sample.

        The student is sampled for ``num_inference_steps`` steps, starting from
        the teacher's initial latents ``trajectory_teacher[0]``. Both final
        latents are then decoded with ``pipe.vae_decoder`` and compared with
        ``self.loss_fn``.

        NOTE(review): every ``noise_pred`` here is detached, and the starting
        latents come from a ``torch.no_grad`` block in ``forward``. Unless
        ``pipe.step`` or ``pipe.vae_decoder`` carries trainable parameters, this
        term therefore contributes no gradient to the student's iteration
        models -- confirm intent.

        NOTE(review): the LPIPS output is returned unreduced; its shape may be
        per-sample rather than scalar -- confirm.
        """
        inputs_shared["latents"] = trajectory_teacher[0]
        pipe.scheduler.set_timesteps(num_inference_steps)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
        image_pred = pipe.vae_decoder(inputs_shared["latents"])
        image_real = pipe.vae_decoder(trajectory_teacher[-1])
        loss = self.loss_fn(image_pred.float(), image_real.float())
        return loss

    def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega):
        """Return the trajectory-alignment loss plus the LPIPS regularizer for the student ``pipe``.

        ``inputs_shared["teacher"]`` must hold the teacher pipeline. The
        student's 8-step schedule is computed first and passed to the teacher
        so the teacher's 50-step trajectory can be matched to it.
        """
        if not self.initialized:
            self.initialize(pipe.device)
        with torch.no_grad():
            pipe.scheduler.set_timesteps(8)
            timesteps_teacher, trajectory_teacher = self.fetch_trajectory(inputs_shared["teacher"], pipe.scheduler.timesteps, inputs_shared, inputs_posi, inputs_nega, 50, 2)
        # Match the student's dtype/device so the argmin comparisons in align_trajectory are well-defined.
        timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device)
        loss_1 = self.align_trajectory(pipe, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
        loss_2 = self.compute_regularization(pipe, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
        loss = loss_1 + loss_2
        return loss