| import math
|
| from typing import Optional, Tuple, Union
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
| from transformers.models.auto import AutoModel
|
| from transformers.modeling_utils import PreTrainedModel
|
|
|
| from transformers.activations import ACT2FN
|
| from transformers.utils import logging
|
|
|
| from .configuration_vibevoice import VibeVoiceDiffusionHeadConfig
|
|
|
|
|
# Module-level logger obtained through the transformers logging helper.
logger = logging.get_logger(__name__)
|
|
|
|
|
class RMSNorm(nn.Module):
    """
    Root-mean-square normalization over the last dimension.

    Args:
        dim (`int`): Size of the normalized (last) dimension; also the size of
            the learnable scale when `elementwise_affine` is enabled.
        eps (`float`, optional): Constant added to the mean square before the
            reciprocal square root is taken.
        elementwise_affine (`bool`, optional): When true, a learnable per-feature
            scale initialised to ones is applied after normalization.
        memory_efficient (`bool`, optional): Accepted for interface
            compatibility; this implementation does not read it.
    """

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        # Always register the `weight` slot so `self.weight` exists either as a
        # parameter or as None.
        scale = nn.Parameter(torch.ones(dim)) if elementwise_affine else None
        self.register_parameter('weight', scale)

    def _norm(self, x):
        """Scale `x` by the reciprocal RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalize `x` in float32, cast back to its dtype, then apply the scale if present."""
        normed = self._norm(x.float()).type_as(x)
        return normed if self.weight is None else normed * self.weight

    def extra_repr(self) -> str:
        """Return the configuration summary shown in the module repr."""
        return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
|
|
|
def modulate(x, shift, scale):
    """Scale `x` by `(1 + scale)` and then add `shift`."""
    scaled = x * (1 + scale)
    return scaled + shift
|
|
|
|
|
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A sinusoidal frequency embedding of the timestep is passed through a
    bias-free two-layer MLP with a SiLU activation in between.

    Args:
        hidden_size (`int`): Size of the output embedding
        frequency_embedding_size (`int`, optional): Size of the intermediate frequency embedding
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=False),
            ACT2FN['silu'],
            nn.Linear(hidden_size, hidden_size, bias=False),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        Args:
            t (`torch.Tensor`): A 1-D Tensor of N indices, one per batch element.
                These may be fractional or integer-valued.
            dim (`int`): The dimension of the output.
            max_period (`int`, optional): Controls the minimum frequency of the embeddings.

        Returns:
            `torch.Tensor`: An [N, D] Tensor of positional embeddings. The result
            has the dtype of `t` when `t` is floating point; for integer (or
            other non-floating) `t` it stays in float32, because casting the
            sin/cos values to an integer dtype would truncate them.
        """
        half = dim // 2
        # Build the frequency table directly on the target device instead of
        # creating it on the CPU and copying it over on every call.
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
            / half
        )
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd output size: pad one zero column so the width equals `dim`.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        if t.is_floating_point():
            # Follow the caller's floating dtype (e.g. half/bfloat16 timesteps).
            embedding = embedding.to(t.dtype)
        return embedding

    def forward(self, t):
        """Map timesteps `t` to embeddings: sinusoidal features followed by the MLP."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
|
|
|
|
|
class FeedForwardNetwork(nn.Module):
    """
    Gated feed-forward block (SwiGLU style) built from three bias-free linear maps.

    The input is projected twice to `ffn_dim`; the SiLU-activated gate branch
    multiplies the up branch element-wise, and the product is projected back
    to `embed_dim`.

    Args:
        embed_dim (`int`): Input and output dimension
        ffn_dim (`int`): Hidden dimension
    """
    def __init__(
        self,
        embed_dim,
        ffn_dim,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        # Creation order is kept (gate, up, down) so parameter initialisation
        # draws from the RNG in the same sequence.
        self.gate_proj = nn.Linear(embed_dim, ffn_dim, bias=False)
        self.up_proj = nn.Linear(embed_dim, ffn_dim, bias=False)
        self.down_proj = nn.Linear(ffn_dim, embed_dim, bias=False)
        self.act_fn = ACT2FN['silu']

    def forward(self, x):
        """Return `down_proj(act_fn(gate_proj(x)) * up_proj(x))`."""
        gated = self.gate_proj(x)
        lifted = self.up_proj(x)
        hidden = self.act_fn(gated) * lifted
        return self.down_proj(hidden)
|
|
|
|
|
class HeadLayer(nn.Module):
    """
    One residual block of the diffusion head.

    The input is RMS-normalized, modulated by a shift and scale derived from
    the condition, passed through the feed-forward network, gated, and added
    back to the input.

    Args:
        embed_dim (`int`): Input dimension
        ffn_dim (`int`): Hidden dimension
        cond_dim (`int`): Condition embedding dimension
        norm_eps (`float`, optional): Epsilon for normalization
    """
    def __init__(
        self,
        embed_dim,
        ffn_dim,
        cond_dim,
        norm_eps=1e-5,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_dim = ffn_dim
        self.cond_dim = cond_dim
        # Sub-modules are created in the order ffn -> norm -> adaLN_modulation.
        self.ffn = FeedForwardNetwork(embed_dim, ffn_dim)
        self.norm = RMSNorm(embed_dim, eps=norm_eps)
        # The condition yields three embed_dim-sized chunks: shift, scale, gate.
        self.adaLN_modulation = nn.Sequential(
            ACT2FN['silu'],
            nn.Linear(cond_dim, 3 * embed_dim, bias=False),
        )

    def forward(self, x, c):
        """Apply the conditioned, gated feed-forward update to `x` using condition `c`."""
        shift, scale, gate = self.adaLN_modulation(c).chunk(3, dim=-1)
        update = self.ffn(modulate(self.norm(x), shift, scale))
        return x + gate * update
|
|
|
|
|
class FinalLayer(nn.Module):
    """
    Output block of the diffusion head.

    The input is RMS-normalized without a learnable scale, modulated by a
    shift and scale derived from the condition, and projected to the output
    size by a bias-free linear layer.

    Args:
        hidden_size (`int`): Input dimension
        output_size (`int`): Output dimension
        cond_size (`int`): Condition embedding dimension
        norm_eps (`float`, optional): Epsilon for normalization
    """
    def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5):
        super().__init__()
        # Sub-modules are created in the order norm_final -> linear -> adaLN_modulation.
        self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
        self.linear = nn.Linear(hidden_size, output_size, bias=False)
        # The condition yields two hidden_size-sized chunks: shift and scale.
        self.adaLN_modulation = nn.Sequential(
            ACT2FN['silu'],
            nn.Linear(cond_size, 2 * hidden_size, bias=False),
        )

    def forward(self, x, c):
        """Normalize and modulate `x` with condition `c`, then project to the output size."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
|
|
|
|
|
class VibeVoiceDiffusionHead(PreTrainedModel):
    """
    Diffusion head model for vibevoice.

    The noisy input is projected to the hidden size, processed by a stack of
    `HeadLayer` blocks conditioned on the sum of the projected condition and
    the timestep embedding, and mapped back to the latent size by a
    `FinalLayer`.

    Args:
        config (`VibeVoiceDiffusionHeadConfig`): Model configuration. The
            constructor reads `hidden_size`, `latent_size`, `head_ffn_ratio`,
            `head_layers` and `rms_norm_eps` from it.
    """
    config_class = VibeVoiceDiffusionHeadConfig
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(
        self,
        config,
    ):
        super().__init__(config)
        self.config = config
        hidden = config.hidden_size
        latent = config.latent_size
        # The condition embedding shares the hidden width.
        self.cond_dim = hidden

        # Module creation order is kept so that seeded initialisation is unchanged.
        self.noisy_images_proj = nn.Linear(latent, hidden, bias=False)
        self.cond_proj = nn.Linear(hidden, self.cond_dim, bias=False)
        self.t_embedder = TimestepEmbedder(self.cond_dim)

        inner_dim = int(hidden * config.head_ffn_ratio)
        blocks = [
            HeadLayer(
                embed_dim=hidden,
                ffn_dim=inner_dim,
                cond_dim=self.cond_dim,
                norm_eps=config.rms_norm_eps,
            )
            for _ in range(config.head_layers)
        ]
        self.layers = nn.ModuleList(blocks)

        self.final_layer = FinalLayer(
            hidden_size=hidden,
            output_size=latent,
            cond_size=self.cond_dim,
            norm_eps=config.rms_norm_eps,
        )

        self.initialize_weights()

    def initialize_weights(self):
        """Initialize the weights of the model."""
        # Timestep MLP: normal(std=0.02) on both linear layers (indices 0 and 2
        # of the Sequential; index 1 is the activation).
        for idx in (0, 2):
            nn.init.normal_(self.t_embedder.mlp[idx].weight, std=0.02)

        # Zero the modulation projections of every block and of the final
        # layer, as well as the final output projection.
        for block in self.layers:
            nn.init.zeros_(block.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.final_layer.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.final_layer.linear.weight)

    def forward(
        self,
        noisy_images,
        timesteps,
        condition,
    ):
        """
        Forward pass of the prediction head.

        Args:
            noisy_images (`torch.Tensor`): Noisy images/latents to denoise
            timesteps (`torch.Tensor`): Timesteps for diffusion
            condition (`torch.Tensor`): Conditioning information

        Returns:
            `torch.Tensor`: The predicted noise/velocity
        """
        hidden_states = self.noisy_images_proj(noisy_images)
        time_emb = self.t_embedder(timesteps)
        # Conditioning signal shared by every block: projected condition plus timestep embedding.
        cond = self.cond_proj(condition) + time_emb

        for block in self.layers:
            hidden_states = block(hidden_states, cond)

        return self.final_layer(hidden_states, cond)
|
|
|
|
|
# Register the config/model pair with AutoModel so that a
# VibeVoiceDiffusionHeadConfig resolves to VibeVoiceDiffusionHead.
AutoModel.register(VibeVoiceDiffusionHeadConfig, VibeVoiceDiffusionHead)

# Names exported by `from <module> import *`.
__all__ = [
    "VibeVoiceDiffusionHead",
]