import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
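

# Linear projection followed by GELU, used as the first stage of FeedForward.
# On Apple MPS the activation is computed in fp32 and cast back, a common
# workaround for half-precision GELU on that backend.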
class GELU(nn.Module):
    def __init__(self, dim_in: int, dim_out: int, operations, device, dtype):
        super().__init__()
        self.proj = operations.Linear(dim_in, dim_out, device=device, dtype=dtype)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type == "mps":
            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
        return F.gelu(gate)

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states
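

# Transformer feed-forward block: GELU(dim -> inner_dim) -> Dropout ->
# Linear(inner_dim -> dim_out). inner_dim defaults to dim * mult and dim_out
# defaults to dim.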
class FeedForward(nn.Module):
    def __init__(self, dim: int, dim_out=None, mult: int = 4, dropout: float = 0.0,
                 inner_dim=None, operations=None, device=None, dtype=None):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        act_fn = GELU(dim, inner_dim, operations=operations, device=device, dtype=dtype)

        self.net = nn.ModuleList([])
        self.net.append(act_fn)
        self.net.append(nn.Dropout(dropout))
        self.net.append(operations.Linear(inner_dim, dim_out, device=device, dtype=dtype))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
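

# Straight-through autograd helper that attaches the MoE auxiliary
# (load-balancing) loss to the graph: forward returns x unchanged, while
# backward feeds a gradient of 1 into the loss so it trains the gate without
# altering activations.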
class AddAuxLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, loss):
        ctx.requires_aux_loss = loss.requires_grad
        ctx.dtype = loss.dtype
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_loss = None
        if ctx.requires_aux_loss:
            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
        return grad_output, grad_loss
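

# Top-k softmax router for the MoE layers. Each token is scored against a
# learned (num_experts x embed_dim) weight matrix; the top_k expert indices
# and gate weights are returned, plus a load-balancing auxiliary loss while
# training (mean routed fraction x mean gate probability per expert).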
class MoEGate(nn.Module):
    def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01, device=None, dtype=None):
        super().__init__()
        self.top_k = num_experts_per_tok
        self.n_routed_experts = num_experts
        self.alpha = aux_loss_alpha
        self.gating_dim = embed_dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), device=device, dtype=dtype))

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = hidden_states.view(-1, hidden_states.size(-1))

        logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), bias=None)
        scores = logits.softmax(dim=-1)

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        if self.training and self.alpha > 0.0:
            scores_for_aux = scores

            counts = torch.bincount(topk_idx.view(-1), minlength=self.n_routed_experts).float()
            ce = counts / topk_idx.numel()

            Pi = scores_for_aux.mean(0)

            aux_loss = (Pi * ce * self.n_routed_experts).sum() * self.alpha
        else:
            aux_loss = None

        return topk_idx, topk_weight, aux_loss
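

# Sparse mixture-of-experts feed-forward: each token goes through its top-k
# routed experts (outputs combined with the gate weights) plus an always-on
# shared expert. Training duplicates tokens top_k times and dispatches them
# densely so the auxiliary loss stays in the graph; inference takes the
# batched no-grad moe_infer path.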
class MoEBlock(nn.Module):
    def __init__(self, dim, num_experts: int = 6, moe_top_k: int = 2, dropout: float = 0.0,
                 ff_inner_dim: int = None, operations=None, device=None, dtype=None):
        super().__init__()
        self.moe_top_k = moe_top_k
        self.num_experts = num_experts

        self.experts = nn.ModuleList([
            FeedForward(dim, dropout=dropout, inner_dim=ff_inner_dim, operations=operations, device=device, dtype=dtype)
            for _ in range(num_experts)
        ])

        self.gate = MoEGate(dim, num_experts=num_experts, num_experts_per_tok=moe_top_k, device=device, dtype=dtype)
        self.shared_experts = FeedForward(dim, dropout=dropout, inner_dim=ff_inner_dim, operations=operations, device=device, dtype=dtype)
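
    # Route, run experts, recombine with gate weights, then add the shared
    # expert applied to the unrouted input.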
    def forward(self, hidden_states) -> torch.Tensor:
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight, aux_loss = self.gate(hidden_states)

        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        flat_topk_idx = topk_idx.view(-1)

        if self.training:
            hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim=0)
            y = torch.empty_like(hidden_states)

            for i, expert in enumerate(self.experts):
                tmp = expert(hidden_states[flat_topk_idx == i])
                y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)

            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)

            y = AddAuxLoss.apply(y, aux_loss)
        else:
            y = self.moe_infer(hidden_states, flat_expert_indices=flat_topk_idx, flat_expert_weights=topk_weight.view(-1, 1)).view(*orig_shape)

        y = y + self.shared_experts(identity)

        return y
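
    # Inference dispatch: sort the (token, expert) assignments by expert id so
    # every expert runs once over a contiguous slice, scale each output by its
    # gate weight, and scatter-add the results back to the source tokens
    # (token_idxs = idxs // top_k maps assignment slots back to tokens).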
    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()

        tokens_per_expert = flat_expert_indices.bincount().cumsum(0)
        token_idxs = idxs // self.moe_top_k

        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue

            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]

            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens)

            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            expert_cache.index_add_(0, exp_token_idx, expert_out)

        return expert_cache
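

# Sinusoidal timestep embedding: sin and cos at num_channels // 2
# geometrically spaced frequencies (as in DDPM); for odd num_channels a zero
# frequency is appended and the result truncated back to num_channels.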
class Timesteps(nn.Module):
    def __init__(self, num_channels: int, downscale_freq_shift: float = 0.0,
                 scale: float = 1.0, max_period: int = 10000):
        super().__init__()
        self.num_channels = num_channels
        half_dim = num_channels // 2

        exponent = -math.log(max_period) * torch.arange(
            half_dim, dtype=torch.float32
        ) / (half_dim - downscale_freq_shift)

        inv_freq = torch.exp(exponent)

        if num_channels % 2 == 1:
            inv_freq = torch.cat([inv_freq, inv_freq.new_zeros(1)])

        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.scale = scale

    def forward(self, timesteps: torch.Tensor):
        x = timesteps.float().unsqueeze(1) * self.inv_freq.to(timesteps.device).unsqueeze(0)

        sin_emb = x.sin()
        cos_emb = x.cos()

        emb = torch.cat([sin_emb, cos_emb], dim=1)

        if self.scale != 1.0:
            emb = emb * self.scale

        if emb.shape[1] > self.num_channels:
            emb = emb[:, :self.num_channels]

        return emb
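

# Timestep conditioning: sinusoidal embedding, optionally shifted by a
# projected guidance embedding, then a two-layer GELU MLP. The output gains a
# singleton sequence dimension so it can be prepended as a token.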
class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256, cond_proj_dim=None, operations=None, device=None, dtype=None):
        super().__init__()
        self.mlp = nn.Sequential(
            operations.Linear(hidden_size, frequency_embedding_size, bias=True, device=device, dtype=dtype),
            nn.GELU(),
            operations.Linear(frequency_embedding_size, hidden_size, bias=True, device=device, dtype=dtype),
        )
        self.frequency_embedding_size = frequency_embedding_size

        if cond_proj_dim is not None:
            self.cond_proj = operations.Linear(cond_proj_dim, frequency_embedding_size, bias=False, device=device, dtype=dtype)

        self.time_embed = Timesteps(hidden_size)

    def forward(self, timesteps, condition):
        timestep_embed = self.time_embed(timesteps).type(self.mlp[0].weight.dtype)

        if condition is not None:
            cond_embed = self.cond_proj(condition)
            timestep_embed = timestep_embed + cond_embed

        time_conditioned = self.mlp(timestep_embed)

        return time_conditioned.unsqueeze(1)
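

# Plain 4x-expansion GELU MLP used by the non-MoE transformer blocks.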
class MLP(nn.Module):
    def __init__(self, *, width: int, operations=None, device=None, dtype=None):
        super().__init__()
        self.width = width
        self.fc1 = operations.Linear(width, width * 4, device=device, dtype=dtype)
        self.fc2 = operations.Linear(width * 4, width, device=device, dtype=dtype)
        self.gelu = nn.GELU()

    def forward(self, x):
        return self.fc2(self.gelu(self.fc1(x)))
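

# Cross-attention from latent tokens (queries, qdim) to conditioning tokens
# (keys/values, kdim). With use_fp16 the normalization eps is raised to
# 1 / 65504 so it remains representable in half precision.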
class CrossAttention(nn.Module):
    def __init__(
        self,
        qdim,
        kdim,
        num_heads,
        qkv_bias=True,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
        use_fp16: bool = False,
        operations=None,
        dtype=None,
        device=None,
        **kwargs,
    ):
        super().__init__()
        self.qdim = qdim
        self.kdim = kdim

        self.num_heads = num_heads
        self.head_dim = self.qdim // num_heads
        self.scale = self.head_dim ** -0.5

        self.to_q = operations.Linear(qdim, qdim, bias=qkv_bias, device=device, dtype=dtype)
        self.to_k = operations.Linear(kdim, qdim, bias=qkv_bias, device=device, dtype=dtype)
        self.to_v = operations.Linear(kdim, qdim, bias=qkv_bias, device=device, dtype=dtype)

        if use_fp16:
            eps = 1.0 / 65504
        else:
            eps = 1e-6

        if norm_layer == nn.LayerNorm:
            norm_layer = operations.LayerNorm
        else:
            norm_layer = operations.RMSNorm

        self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=eps, device=device, dtype=dtype) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=eps, device=device, dtype=dtype) if qk_norm else nn.Identity()
        self.out_proj = operations.Linear(qdim, qdim, bias=True, device=device, dtype=dtype)
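
    # The cat/view/split below is not a no-op: it regroups the k/v channels
    # per head, presumably to undo the naive half-split of a fused kv
    # projection from the original checkpoint. QK-norm is applied per head
    # before optimized_attention.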
    def forward(self, x, y):
        b, s1, _ = x.shape
        _, s2, _ = y.shape

        y = y.to(next(self.to_k.parameters()).dtype)

        q = self.to_q(x)
        k = self.to_k(y)
        v = self.to_v(y)

        kv = torch.cat((k, v), dim=-1)
        split_size = kv.shape[-1] // self.num_heads // 2

        kv = kv.view(1, -1, self.num_heads, split_size * 2)
        k, v = torch.split(kv, split_size, dim=-1)

        q = q.view(b, s1, self.num_heads, self.head_dim)
        k = k.view(b, s2, self.num_heads, self.head_dim)
        v = v.reshape(b, s2, self.num_heads * self.head_dim)

        q = self.q_norm(q)
        k = self.k_norm(k)

        x = optimized_attention(
            q.reshape(b, s1, self.num_heads * self.head_dim),
            k.reshape(b, s2, self.num_heads * self.head_dim),
            v,
            heads=self.num_heads,
        )

        out = self.out_proj(x)
        return out
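

# Self-attention over the latent sequence, with the same per-head channel
# regrouping and optional QK-norm as CrossAttention above.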
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        qkv_bias=True,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
        use_fp16: bool = False,
        operations=None,
        device=None,
        dtype=None
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = self.dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.to_q = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
        self.to_k = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
        self.to_v = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)

        if use_fp16:
            eps = 1.0 / 65504
        else:
            eps = 1e-6

        if norm_layer == nn.LayerNorm:
            norm_layer = operations.LayerNorm
        else:
            norm_layer = operations.RMSNorm

        self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=eps, device=device, dtype=dtype) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=eps, device=device, dtype=dtype) if qk_norm else nn.Identity()
        self.out_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
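
    # Same per-head regrouping trick as CrossAttention.forward, here over the
    # concatenated q/k/v channels.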
    def forward(self, x):
        B, N, _ = x.shape

        query = self.to_q(x)
        key = self.to_k(x)
        value = self.to_v(x)

        qkv_combined = torch.cat((query, key, value), dim=-1)
        split_size = qkv_combined.shape[-1] // self.num_heads // 3

        qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
        query, key, value = torch.split(qkv, split_size, dim=-1)

        query = query.reshape(B, N, self.num_heads, self.head_dim)
        key = key.reshape(B, N, self.num_heads, self.head_dim)
        value = value.reshape(B, N, self.num_heads * self.head_dim)

        query = self.q_norm(query)
        key = self.k_norm(key)

        x = optimized_attention(
            query.reshape(B, N, self.num_heads * self.head_dim),
            key.reshape(B, N, self.num_heads * self.head_dim),
            value,
            heads=self.num_heads,
        )

        x = self.out_proj(x)
        return x
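

# One DiT block: optional U-Net-style skip fusion, self-attention,
# cross-attention to the conditioning tokens, and an MLP or MoE feed-forward,
# each wrapped in a residual connection.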
class HunYuanDiTBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        c_emb_size,
        num_heads,
        text_states_dim=1024,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
        qk_norm_layer=True,
        qkv_bias=True,
        skip_connection=True,
        timested_modulate=False,
        use_moe: bool = False,
        num_experts: int = 8,
        moe_top_k: int = 2,
        use_fp16: bool = False,
        operations=None,
        device=None, dtype=None
    ):
        super().__init__()

        if use_fp16:
            eps = 1.0 / 65504
        else:
            eps = 1e-6

        self.norm1 = norm_layer(hidden_size, elementwise_affine=True, eps=eps, device=device, dtype=dtype)

        self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
                               norm_layer=qk_norm_layer, use_fp16=use_fp16, device=device, dtype=dtype, operations=operations)

        self.norm2 = norm_layer(hidden_size, elementwise_affine=True, eps=eps, device=device, dtype=dtype)

        self.timested_modulate = timested_modulate
        if self.timested_modulate:
            self.default_modulation = nn.Sequential(
                nn.SiLU(),
                operations.Linear(c_emb_size, hidden_size, bias=True, device=device, dtype=dtype)
            )

        self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias,
                                    qk_norm=qk_norm, norm_layer=qk_norm_layer, use_fp16=use_fp16,
                                    device=device, dtype=dtype, operations=operations)

        self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=eps, device=device, dtype=dtype)

        if skip_connection:
            self.skip_norm = norm_layer(hidden_size, elementwise_affine=True, eps=eps, device=device, dtype=dtype)
            self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, device=device, dtype=dtype)
        else:
            self.skip_linear = None

        self.use_moe = use_moe

        if self.use_moe:
            self.moe = MoEBlock(
                hidden_size,
                num_experts=num_experts,
                moe_top_k=moe_top_k,
                dropout=0.0,
                ff_inner_dim=int(hidden_size * 4.0),
                device=device, dtype=dtype,
                operations=operations
            )
        else:
            self.mlp = MLP(width=hidden_size, operations=operations, device=device, dtype=dtype)
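
    # skip_tensor comes from the mirrored first-half block (see the skip stack
    # in HunYuanDiTPlain.forward) and is fused by concatenation plus a linear
    # projection before the attention layers.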
    def forward(self, hidden_states, conditioning=None, text_states=None, skip_tensor=None):
        if self.skip_linear is not None:
            combined = torch.cat([skip_tensor, hidden_states], dim=-1)
            hidden_states = self.skip_linear(combined)
            hidden_states = self.skip_norm(hidden_states)

        if self.timested_modulate:
            modulation_shift = self.default_modulation(conditioning).unsqueeze(dim=1)
            hidden_states = hidden_states + modulation_shift

        self_attn_out = self.attn1(self.norm1(hidden_states))
        hidden_states = hidden_states + self_attn_out

        hidden_states = hidden_states + self.attn2(self.norm2(hidden_states), text_states)

        mlp_input = self.norm3(hidden_states)

        if self.use_moe:
            hidden_states = hidden_states + self.moe(mlp_input)
        else:
            hidden_states = hidden_states + self.mlp(mlp_input)

        return hidden_states
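

# Output head: LayerNorm, drop the leading token (the prepended timestep
# token), then project to out_channels.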
class FinalLayer(nn.Module):
    def __init__(self, final_hidden_size, out_channels, operations, use_fp16: bool = False, device=None, dtype=None):
        super().__init__()

        if use_fp16:
            eps = 1.0 / 65504
        else:
            eps = 1e-6

        self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=True, eps=eps, device=device, dtype=dtype)
        self.linear = operations.Linear(final_hidden_size, out_channels, bias=True, device=device, dtype=dtype)

    def forward(self, x):
        x = self.norm_final(x)
        x = x[:, 1:]
        x = self.linear(x)
        return x
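

# Hunyuan3D DiT denoiser: a timestep token is prepended to the embedded
# latent sequence, the first half of the blocks push their outputs onto a
# stack, and the second half pop them as U-Net-style long skip connections.
# The last num_moe_layers blocks use MoE feed-forwards.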
class HunYuanDiTPlain(nn.Module):
    def __init__(
        self,
        in_channels: int = 64,
        hidden_size: int = 2048,
        context_dim: int = 1024,
        depth: int = 21,
        num_heads: int = 16,
        qk_norm: bool = True,
        qkv_bias: bool = False,
        num_moe_layers: int = 6,
        guidance_cond_proj_dim=2048,
        norm_type='layer',
        num_experts: int = 8,
        moe_top_k: int = 2,
        use_fp16: bool = False,
        dtype=None,
        device=None,
        operations=None,
        **kwargs
    ):
        super().__init__()
        self.dtype = dtype
        self.depth = depth

        self.in_channels = in_channels
        self.out_channels = in_channels

        self.num_heads = num_heads
        self.hidden_size = hidden_size

        norm = operations.LayerNorm if norm_type == 'layer' else operations.RMSNorm
        qk_norm = operations.RMSNorm

        self.context_dim = context_dim
        self.guidance_cond_proj_dim = guidance_cond_proj_dim

        self.x_embedder = operations.Linear(in_channels, hidden_size, bias=True, device=device, dtype=dtype)
        self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim=guidance_cond_proj_dim, device=device, dtype=dtype, operations=operations)

        self.blocks = nn.ModuleList([
            HunYuanDiTBlock(hidden_size=hidden_size,
                            c_emb_size=hidden_size,
                            num_heads=num_heads,
                            text_states_dim=context_dim,
                            qk_norm=qk_norm,
                            norm_layer=norm,
                            qk_norm_layer=qk_norm,
                            skip_connection=layer > depth // 2,
                            qkv_bias=qkv_bias,
                            use_moe=depth - layer <= num_moe_layers,
                            num_experts=num_experts,
                            moe_top_k=moe_top_k,
                            use_fp16=use_fp16,
                            device=device, dtype=dtype, operations=operations)
            for layer in range(depth)
        ])

        self.final_layer = FinalLayer(hidden_size, self.out_channels, use_fp16=use_fp16, operations=operations, device=device, dtype=dtype)
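
    # The chunk/cat pairs swap the cond and uncond halves of the CFG batch on
    # the way in and swap them back on the way out; the timestep is flipped
    # (t -> 1 - t) and the final output negated, translating between the
    # sampler's and the model's flow conventions.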
    def forward(self, x, t, context, transformer_options={}, **kwargs):
        x = x.movedim(-1, -2)
        uncond_emb, cond_emb = context.chunk(2, dim=0)

        context = torch.cat([cond_emb, uncond_emb], dim=0)
        main_condition = context

        t = 1.0 - t

        time_embedded = self.t_embedder(t, condition=kwargs.get('guidance_cond'))

        x = x.to(dtype=next(self.x_embedder.parameters()).dtype)
        x_embedded = self.x_embedder(x)

        combined = torch.cat([time_embedded, x_embedded], dim=1)

        def block_wrap(args):
            return block(
                args["x"],
                args["t"],
                args["cond"],
                skip_tensor=args.get("skip"))

        skip_stack = []
        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
        for idx, block in enumerate(self.blocks):
            if idx <= self.depth // 2:
                skip_input = None
            else:
                skip_input = skip_stack.pop()

            if ("block", idx) in blocks_replace:
                combined = blocks_replace[("block", idx)](
                    {
                        "x": combined,
                        "t": time_embedded,
                        "cond": main_condition,
                        "skip": skip_input,
                    },
                    {"original_block": block_wrap},
                )
            else:
                combined = block(combined, time_embedded, main_condition, skip_tensor=skip_input)

            if idx < self.depth // 2:
                skip_stack.append(combined)

        output = self.final_layer(combined)
        output = output.movedim(-2, -1) * (-1.0)

        cond_emb, uncond_emb = output.chunk(2, dim=0)
        return torch.cat([uncond_emb, cond_emb])