# Reference: https://github.com/LTH14/JiT/blob/main/model_jit.py

import math

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
import torch.nn.functional as F

from .config import DenoiserConfig


# https://github.com/huggingface/diffusers/blob/66bf7ea5be7099c8a47b9cba135f276d55247447/src/diffusers/models/embeddings.py#L27
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
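
# Usage sketch (illustrative values, not from the reference):
#   emb = get_timestep_embedding(torch.tensor([0.0, 250.0, 999.0]), embedding_dim=256)
#   emb.shape  # -> torch.Size([3, 256]); first half sin, second half cos features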


class FP32RMSNorm(nn.RMSNorm):
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # compute the norm in float32 for numerical stability, then cast back;
        # the weight is cast too so mixed-precision (e.g. bf16) weights do not
        # cause a dtype mismatch inside F.rms_norm
        return F.rms_norm(
            hidden_states.to(torch.float32),
            self.normalized_shape,
            weight=self.weight.to(torch.float32) if self.weight is not None else None,
            eps=self.eps,
        ).to(hidden_states.dtype)


class BottleneckPatchEmbed(nn.Module):
    """Image to Patch Embedding"""

    def __init__(
        self,
        patch_size: int = 16,
        in_channels: int = 3,
        bottleneck_dim: int = 128,
        hidden_dim: int = 768,
        bias: bool = True,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.bottleneck_dim = bottleneck_dim
        self.hidden_dim = hidden_dim
        self.bias = bias

        self.proj_1 = nn.Conv2d(
            in_channels,
            bottleneck_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )
        self.proj_2 = nn.Conv2d(
            bottleneck_dim,
            hidden_dim,
            kernel_size=1,
            stride=1,
            bias=bias,
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # B, C, H, W = image.shape
        # [B, C, H, W]
        # -> [B, bottleneck_dim, H/patch_size, W/patch_size] (proj_1)
        # -> [B, hidden_dim, H/patch_size, W/patch_size] (proj_2)
        # -> [B, hidden_dim, num_patches] (flatten)
        # -> [B, num_patches, hidden_dim] (transpose)
        patches = (
            self.proj_2(
                self.proj_1(image),
            )
            .flatten(2)
            .transpose(1, 2)
        )
        return patches
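
# Shape sketch (illustrative): a 256x256 RGB image with patch_size=16 yields
# (256/16)**2 = 256 patches, so the output is [B, 256, hidden_dim].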


class TimestepEmbedder(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        freq_embedding_size: int = 256,
    ):
        super().__init__()
        self.freq_embedding_size = freq_embedding_size
        self.mlp = nn.Sequential(
            nn.Linear(freq_embedding_size, hidden_dim, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim, bias=True),
        )

    def forward(self, timestep: torch.Tensor) -> torch.Tensor:
        freq_emb = get_timestep_embedding(
            timestep,
            embedding_dim=self.freq_embedding_size,
            flip_sin_to_cos=True,
            downscale_freq_shift=0,
        )
        time_embed = self.mlp(freq_emb.to(dtype=self.mlp[0].weight.dtype))
        return time_embed
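
# Usage sketch: maps a [B] tensor of timesteps to [B, hidden_dim], e.g.
#   TimestepEmbedder(hidden_dim=768)(torch.tensor([0.1, 0.9]))  # -> [2, 768]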


def apply_rope(
    inputs: torch.Tensor,  # (batch_size, num_heads, seq_len, dim)
    freqs_cis: torch.Tensor,  # (batch_size, seq_len, dim//2) complex64
) -> torch.Tensor:
    batch_size, num_heads, seq_len, dim = inputs.shape
    # disable autocast so the complex multiplication runs in full precision
    # (device_type taken from the input rather than hardcoding "cuda")
    with torch.autocast(device_type=inputs.device.type, enabled=False):
        inputs_cis = torch.view_as_complex(
            inputs.float().view(batch_size, num_heads, seq_len, dim // 2, 2)
        )
        freqs_cis = freqs_cis.unsqueeze(1)  # (batch_size, 1, seq_len, dim//2)
        output = torch.view_as_real(inputs_cis * freqs_cis).flatten(3)
    return output.type_as(inputs)
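
# RoPE treats consecutive feature pairs as complex numbers and rotates them by a
# position-dependent angle: e.g. with dim=4, each token's features split into two
# complex pairs, each multiplied by a unit-magnitude phasor from freqs_cis.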


class RopeEmbedder:
    def __init__(
        self,
        rope_theta: float = 256.0,  # ref: Z-Image
        axes_dims: list[int] = [32, 64, 64],  # text, height, width
        axes_lens: list[int] = [256, 128, 128],  # text, height, width
        zero_centered: list[bool] = [False, True, True],
    ):
        self.rope_theta = rope_theta
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        self.zero_centered = zero_centered

        # text positions start at 0; image axes are zero-centered
        self.freqs_cis = self.precompute_freqs_cis(
            theta=self.rope_theta,
            dims=self.axes_dims,
            lens=self.axes_lens,
            zero_centered=self.zero_centered,
        )

    # declared static: it is called as RopeEmbedder.get_rope_freqs(...) below,
    # and without @staticmethod the keyword-only call would pass self as theta
    @staticmethod
    def get_rope_freqs(
        dim: int,
        min_position: int = 0,
        max_position: int = 128,
        theta: float = 10000.0,
    ) -> torch.Tensor:
        freqs = 1.0 / (
            theta
            ** (
                torch.arange(0, dim, 2, dtype=torch.float64, device=torch.device("cpu"))
                / dim
            )
        )
        positions = torch.arange(
            start=min_position,
            end=max_position,
            dtype=torch.float64,
            device=torch.device("cpu"),
        )
        freqs = torch.outer(positions, freqs).float()  # (max_position - min_position, dim//2)
        # ↓ pos, → dim//2
        # [ min_position * [1/θ^(0/dim), 1/θ^(2/dim), 1/θ^(4/dim), ..., 1/θ^((dim-2)/dim)]
        #   ...
        #   0            * [1/θ^(0/dim), 1/θ^(2/dim), 1/θ^(4/dim), ..., 1/θ^((dim-2)/dim)]
        #   1            * [1/θ^(0/dim), 1/θ^(2/dim), 1/θ^(4/dim), ..., 1/θ^((dim-2)/dim)]
        #   2            * [1/θ^(0/dim), 1/θ^(2/dim), 1/θ^(4/dim), ..., 1/θ^((dim-2)/dim)]
        #   ...
        #   max_position * [1/θ^(0/dim), 1/θ^(2/dim), 1/θ^(4/dim), ..., 1/θ^((dim-2)/dim)] ]
        freqs_cis = torch.polar(
            abs=torch.ones_like(freqs),
            angle=freqs,
        ).to(torch.complex64)  # (max_position - min_position, dim//2) complex64
        # unit-magnitude complex numbers: pure rotations, no scaling
        return freqs_cis

    @staticmethod
    def precompute_freqs_cis(
        theta: float,
        dims: list[int],
        lens: list[int],
        zero_centered: list[bool],
    ):
        freqs_cis = []
        for i, (dim, len_) in enumerate(zip(dims, lens)):
            freq_cis = RopeEmbedder.get_rope_freqs(
                dim=dim,
                min_position=(len_ // 2) - len_ if zero_centered[i] else 0,
                max_position=len_ // 2 if zero_centered[i] else len_,
                theta=theta,
            )  # (len_, dim//2) complex64
            freqs_cis.append(freq_cis)
        return freqs_cis

    # get frequencies for given position ids
    def __call__(self, position_ids: torch.Tensor):
        # move to device
        freqs_cis = [fc.to(position_ids.device) for fc in self.freqs_cis]
        result = []
        for i in range(len(self.axes_dims)):
            index = (
                position_ids[..., i : i + 1]
                .repeat(
                    # broadcast this axis's index column across the freq dimension
                    1,  # batch dim
                    1,  # sequence dim
                    freqs_cis[i].shape[-1],  # dim//2 for this axis
                )
                .to(torch.int64)
            )
            result.append(
                torch.gather(
                    freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1),
                    dim=1,
                    index=index,
                )
            )
        return torch.cat(result, dim=-1)
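
# Output sketch: with the default axes_dims=[32, 64, 64], the concatenated result
# has 16 + 32 + 32 = 80 complex entries per token, matching head_dim=160 real dims.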


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        qk_norm: bool = True,
        attn_dropout: float = 0.0,
        proj_dropout: float = 0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q_norm = FP32RMSNorm(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = FP32RMSNorm(self.head_dim) if qk_norm else nn.Identity()

        self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.to_k = nn.Linear(dim, dim, bias=qkv_bias)
        self.to_v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.to_o = nn.Linear(dim, dim)
        self.proj_dropout = nn.Dropout(proj_dropout)

    def _pre_attn_reshape(self, x: torch.Tensor):
        batch_size, seq_len, dim = x.shape
        # [B, N, D] -> [B, N, num_heads, D/num_heads] -> [B, num_heads, N, D/num_heads]
        x = x.view(
            batch_size,
            seq_len,
            self.num_heads,
            self.head_dim,
        ).permute(0, 2, 1, 3)  # [B, num_heads, N, head_dim]
        return x

    def _post_attn_reshape(self, x: torch.Tensor):
        batch_size, num_heads, seq_len, head_dim = x.shape
        # [B, num_heads, N, head_dim] -> [B, N, num_heads, head_dim] -> [B, N, D]
        x = (
            x.permute(0, 2, 1, 3)
            .contiguous()
            .view(batch_size, seq_len, num_heads * head_dim)
        )
        return x

    def forward(
        self,
        hidden_states: torch.Tensor,
        rope_freqs: torch.Tensor,
        mask: torch.Tensor | None = None,  # 1: attend, 0: ignore
    ) -> torch.Tensor:
        batch_size, seq_len, _dim = hidden_states.shape

        # QKV
        q = self.to_q(hidden_states)
        k = self.to_k(hidden_states)
        v = self.to_v(hidden_states)
        q = self._pre_attn_reshape(q)  # [B, num_heads, N, head_dim]
        k = self._pre_attn_reshape(k)
        v = self._pre_attn_reshape(v)

        # QKNorm
        q = self.q_norm(q)
        k = self.k_norm(k)

        # RoPE
        q = apply_rope(q, rope_freqs)
        k = apply_rope(k, rope_freqs)

        if mask is not None:
            # mask: (batch_size, seq_len) -> (batch_size, num_heads, seq_len, seq_len)
            mask = (
                mask.bool()
                .view(batch_size, 1, 1, seq_len)
                .expand(-1, self.num_heads, seq_len, -1)
            )
        attn = F.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=self.attn_dropout.p if self.training else 0.0,
            attn_mask=mask,
            is_causal=False,
        ).to(hidden_states.dtype)
        attn = self._post_attn_reshape(attn)

        # output
        out = self.to_o(attn)
        out = self.proj_dropout(out)
        return out
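
# Usage sketch (illustrative sizes): dim must be divisible by num_heads, e.g.
#   attn = Attention(dim=768, num_heads=12)  # head_dim = 64
# rope_freqs then needs shape (B, N, 32) complex64 to match head_dim // 2.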


class SwiGLU(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        # scale by 2/3 so the gated MLP has roughly the same parameter count
        # as a plain 2-layer MLP with the requested hidden_dim
        hidden_dim = int(hidden_dim * 2 / 3)
        self.w_1 = nn.Linear(dim, hidden_dim, bias=bias)
        self.w_2 = nn.Linear(dim, hidden_dim, bias=bias)
        self.w_3 = nn.Linear(hidden_dim, dim, bias=bias)
        self.ffn_dropout = nn.Dropout(dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x_1 = self.w_1(hidden_states)
        x_2 = self.w_2(hidden_states)
        x = F.silu(x_1) * x_2
        x = self.w_3(self.ffn_dropout(x))
        return x
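
# Gating sketch: SwiGLU(dim=768, hidden_dim=3072) allocates int(3072 * 2/3) = 2048
# units per branch; forward computes silu(w_1(x)) * w_2(x) before projecting back.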


class FinalLayer(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        mlp_ratio: float,
        patch_size: int,
        out_channels: int,
    ):
        super().__init__()
        self.norm_final = FP32RMSNorm(hidden_dim)
        self.mlp = SwiGLU(
            dim=hidden_dim,
            hidden_dim=int(hidden_dim * mlp_ratio),
            dropout=0.0,
            bias=True,
        )
        self.linear = nn.Linear(
            hidden_dim,
            patch_size * patch_size * out_channels,
            bias=True,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        x = self.norm_final(hidden_states)
        x = self.mlp(x)
        x = self.linear(x)
        return x


class JiTBlock(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_dropout: float = 0.0,
        proj_dropout: float = 0.0,
        ffn_dropout: float = 0.0,
        qkv_bias: bool = True,
        qk_norm: bool = True,
        bias: bool = True,
    ):
        super().__init__()
        self.norm1 = FP32RMSNorm(hidden_dim, eps=1e-6)
        self.attn = Attention(
            dim=hidden_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
        )
        self.norm2 = FP32RMSNorm(hidden_dim)
        self.mlp = SwiGLU(
            dim=hidden_dim,
            hidden_dim=int(hidden_dim * mlp_ratio),
            dropout=ffn_dropout,
            bias=bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        rope_freqs: torch.Tensor,
        mask: torch.Tensor | None = None,
    ):
        # attn
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
            rope_freqs,
            mask=mask,
        )
        # mlp
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states
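
# This is a standard pre-norm block: each sublayer normalizes its own input, and
# the residual stream (hidden_states) is only ever modified by addition.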


class JiT(nn.Module):
    def __init__(self, config: DenoiserConfig):
        super().__init__()
        self.config = config

        assert (config.hidden_size // config.num_heads) == sum(config.rope_axes_dims), (
            "The sum of rope_axes_dims must equal hidden_size // num_heads (= head_dim)."
        )
        self.num_axes = len(
            config.rope_axes_dims
        )  # 0: image_index, 1: height, 2: width

        # image patch embedder
        self.patch_embedder = BottleneckPatchEmbed(
            patch_size=config.patch_size,
            in_channels=config.in_channels,
            bottleneck_dim=config.bottleneck_dim,
            hidden_dim=config.hidden_size,
            bias=True,
        )

        # timestep embedder
        self.time_embedder = TimestepEmbedder(
            hidden_dim=config.hidden_size,
            freq_embedding_size=256,
        )
        self.time_position_embeds = nn.Parameter(
            torch.randn(
                config.num_time_tokens,
                config.hidden_size,
            ),
            requires_grad=True,
        )

        # RoPE embedder
        self.rope_embedder = RopeEmbedder(
            rope_theta=config.rope_theta,
            axes_dims=config.rope_axes_dims,
            axes_lens=config.rope_axes_lens,
            zero_centered=config.rope_zero_centered,
        )

        # class condition or text embedding
        self.context_embedder = nn.Linear(
            config.context_dim,
            config.hidden_size,
            bias=True,
        )

        self.blocks = nn.ModuleList(
            [
                JiTBlock(
                    hidden_dim=config.hidden_size,
                    num_heads=config.num_heads,
                    mlp_ratio=config.mlp_ratio,
                    attn_dropout=config.attn_dropout,
                    proj_dropout=config.proj_dropout,
                    ffn_dropout=0.0,
                    qkv_bias=True,
                    qk_norm=True,
                    bias=True,
                )
                for _ in range(config.depth)
            ]
        )
        self.final_layer = FinalLayer(
            hidden_dim=config.hidden_size,
            mlp_ratio=config.mlp_ratio,
            patch_size=config.patch_size,
            out_channels=config.in_channels,
        )

        self.gradient_checkpointing = False

    def initialize_weights(self):
        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.RMSNorm):
                nn.init.ones_(m.weight)

        # patch embed
        w_1 = self.patch_embedder.proj_1.weight
        nn.init.xavier_uniform_(w_1.view([w_1.shape[0], -1]))
        w_2 = self.patch_embedder.proj_2.weight
        nn.init.xavier_uniform_(w_2.view([w_2.shape[0], -1]))
        if self.patch_embedder.proj_2.bias is not None:
            nn.init.zeros_(self.patch_embedder.proj_2.bias)

        # time position embeds
        nn.init.normal_(
            self.time_position_embeds,
            std=0.02,
        )
        # time embedder
        nn.init.normal_(
            self.time_embedder.mlp[0].weight,  # type: ignore
            std=0.02,
        )
        nn.init.normal_(
            self.time_embedder.mlp[2].weight,  # type: ignore
            std=0.02,
        )

    def set_gradient_checkpointing(self, enable: bool = True):
        self.gradient_checkpointing = enable

    def prepare_image_position_ids(
        self,
        height: int,
        width: int,
        image_index: int,
    ) -> torch.Tensor:
        # [H/patch_size, W/patch_size]
        patch_size = self.config.patch_size
        h_patches = height // patch_size
        w_patches = width // patch_size

        position_ids = torch.zeros(
            h_patches,
            w_patches,
            self.num_axes,
        )
        # image_index
        position_ids[:, :, 0] = image_index  # image
        # height (y-index)
        position_ids[:, :, 1] = (
            torch.arange(
                h_patches,
            )
            .unsqueeze(1)
            .repeat(1, w_patches)
        )
        # width (x-index)
        position_ids[:, :, 2] = (
            torch.arange(
                w_patches,
            )
            .unsqueeze(0)
            .repeat(h_patches, 1)
        )
        return position_ids.view(-1, self.num_axes)  # (num_patches, n_axes)
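
    # Sketch: for a 64x64 input with patch_size=16 this returns 16 rows of
    # (image_index, y, x), with y and x each running over 0..3 in row-major order.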

    def prepare_context_position_ids(
        self,
        seq_len: int,
        context_start_index: int = 0,
        xy_position: int = 0,
    ) -> torch.Tensor:
        position_ids = torch.zeros(
            seq_len,
            self.num_axes,
        )
        # context_index (0, ..., seq_len-1)
        position_ids[:, 0] = torch.arange(
            context_start_index,
            context_start_index + seq_len,
        )  # text
        # spatial (y, x) positions stay fixed at (xy_position, xy_position)
        position_ids[:, 1] = xy_position
        position_ids[:, 2] = xy_position
        return position_ids

    def prepare_time_position_ids(
        self,
        seq_len: int,
        time_start_index: int,
        xy_position: int = 0,
    ) -> torch.Tensor:
        position_ids = torch.zeros(
            seq_len,
            self.num_axes,
        )
        # time_index
        position_ids[:, 0] = torch.arange(
            time_start_index, time_start_index + seq_len
        )  # time
        # spatial (y, x) positions stay fixed at (xy_position, xy_position)
        position_ids[:, 1] = xy_position
        position_ids[:, 2] = xy_position
        return position_ids

    def unpatchify(
        self,
        patches: torch.Tensor,
        height: int,
        width: int,
    ) -> torch.Tensor:
        batch_size, num_patches, _patch_dim = patches.shape
        patch_size = self.config.patch_size
        out_channels = self.config.out_channels
        h_patches = height // patch_size
        w_patches = width // patch_size
        assert num_patches == h_patches * w_patches, "Mismatch in number of patches"

        # [B, N, patch_size*patch_size*C] -> [B, H_patch, W_patch, patch_size, patch_size, C]
        patches = patches.view(
            batch_size,
            h_patches,
            w_patches,
            patch_size,
            patch_size,
            out_channels,
        )
        # [B, H_patch, W_patch, patch_size, patch_size, C]
        # -> [B, C, H_patch, patch_size, W_patch, patch_size]
        patches = patches.permute(0, 5, 1, 3, 2, 4)
        # -> [B, C, H_img, W_img]
        images = patches.reshape(batch_size, out_channels, height, width)
        return images
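
    # Inverse of the patch flattening done by FinalLayer: e.g. [B, 16, 16*16*3]
    # with height=width=64 and patch_size=16 reassembles to [B, 3, 64, 64].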

    def forward(
        self,
        image: torch.Tensor,  # [B, C, H, W]
        timestep: torch.Tensor,  # [B]
        context: torch.Tensor,  # [B, context_len, context_dim]
        context_mask: torch.Tensor | None = None,  # [B, context_len]
    ):
        batch_size, _in_channels, height, width = image.shape

        time_embed: torch.Tensor = self.time_embedder(timestep)  # [B, hidden_dim]
        time_tokens = time_embed.unsqueeze(1).repeat(  # add seq_len dim
            1,
            self.time_position_embeds.shape[0],  # num_time_tokens
            1,
        ) + self.time_position_embeds.unsqueeze(0).repeat(  # add batch dim
            batch_size,
            1,
            1,
        )  # [B, num_time_tokens, hidden_dim]
        num_time_tokens = time_tokens.shape[1]

        context_embed = self.context_embedder(context)
        context_len = context_embed.shape[1]

        patches = self.patch_embedder(image)  # [B, N, hidden_dim]
        patches_len = patches.shape[1]

        # index numbering along axis 0: context (from 0), then time, then patches
        context_position_ids = self.prepare_context_position_ids(
            seq_len=context_len,
            context_start_index=0,
        )
        time_position_ids = self.prepare_time_position_ids(
            seq_len=num_time_tokens,
            time_start_index=context_len,
        )
        patches_position_ids = self.prepare_image_position_ids(
            height=height,
            width=width,
            image_index=context_len + num_time_tokens,  # after context and time tokens
        )
        # concatenation order matches the token sequence: patches -> time -> context
        position_ids = torch.cat(
            [
                patches_position_ids,
                time_position_ids,
                context_position_ids,
            ],
            dim=0,
        ).view(1, -1, self.num_axes)  # (1, total_seq_len, n_axes)

        # prepare RoPE
        freqs_cis = (
            self.rope_embedder(position_ids=position_ids)
            .repeat(
                batch_size,
                1,
                1,
            )
            .to(device=image.device)
        )

        # attention mask
        if context_mask is not None:
            patches_mask = torch.ones(batch_size, patches_len, device=image.device)
            time_mask = torch.ones(batch_size, num_time_tokens, device=image.device)
            mask = torch.cat(
                [
                    patches_mask,
                    time_mask,
                    context_mask.to(image.device),
                ],
                dim=1,
            )
        else:
            # attend all
            mask = torch.ones(
                batch_size,
                patches_len + num_time_tokens + context_len,
                device=image.device,
            )

        for _i, block in enumerate(self.blocks):
            tokens = torch.cat(
                [
                    patches,  # e.g. 16x16 patch tokens
                    time_tokens,  # e.g. 4 time tokens
                    context_embed,  # e.g. 64 context tokens
                ],
                dim=1,  # cat in seq_len dimension
            )
            if self.gradient_checkpointing and self.training:
                patches = checkpoint.checkpoint(  # type: ignore
                    block,
                    tokens,
                    freqs_cis,
                    mask,
                    use_reentrant=False,  # non-reentrant checkpointing is the recommended mode
                )[:, :patches_len, :]
            else:
                patches = block(
                    tokens,
                    rope_freqs=freqs_cis,
                    mask=mask,
                )[:, :patches_len, :]  # only keep patch tokens

        patches = self.final_layer(patches)
        pred_image = self.unpatchify(
            patches,
            height=height,
            width=width,
        )
        return pred_image
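

# Minimal smoke-test sketch. This assumes DenoiserConfig is a dataclass-style
# config accepting the field names referenced above; adjust to the actual
# definition in .config before running.
if __name__ == "__main__":
    config = DenoiserConfig(
        hidden_size=640,
        num_heads=4,  # head_dim = 160 = sum(rope_axes_dims)
        depth=2,
        patch_size=16,
        in_channels=3,
        out_channels=3,
        bottleneck_dim=128,
        context_dim=512,
        mlp_ratio=4.0,
        attn_dropout=0.0,
        proj_dropout=0.0,
        num_time_tokens=4,
        rope_theta=256.0,
        rope_axes_dims=[32, 64, 64],
        rope_axes_lens=[256, 128, 128],
        rope_zero_centered=[False, True, True],
    )
    model = JiT(config)
    model.initialize_weights()
    out = model(
        image=torch.randn(2, 3, 64, 64),
        timestep=torch.rand(2),
        context=torch.randn(2, 16, 512),
    )
    print(out.shape)  # expected: torch.Size([2, 3, 64, 64])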