import dataclasses
import json
import math
import os
from typing import Optional, Literal

import safetensors.torch
import torch
from torch import nn
from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm as RMSNorm
    from liger_kernel.ops.geglu import LigerGELUMulFunction as GELUMulFunction
except ModuleNotFoundError:
    # Pure-PyTorch fallback when liger_kernel is not installed.
    from torch.nn import RMSNorm

    class GELUMulFunction:
        @staticmethod
        def apply(hidden_gelu, hidden_linear):
            return torch.nn.functional.gelu(hidden_gelu, approximate='tanh') * hidden_linear


@dataclasses.dataclass
class Pix2StructConfig:
    vision_patch_size: int = 768
    vision_hidden_size: int = 768
    vision_mlp_ff_size: int = 2048
    vision_layers: int = 12
    vision_attention_kv_size: int = 64
    vision_attention_heads: int = 12
    vision_max_rows: int = 4096
    vision_max_columns: int = 4096
    vision_page_mode: Literal['concat', 'index'] = 'index'
    vision_max_pages: int = 4096
    text_vocab_size: int = 50265
    text_layers: int = 12
    text_hidden_size: int = 768
    text_dense_act_ff_size: int = 2048
    text_attention_kv_size: int = 64
    text_attention_heads: int = 12
    rms_norm_eps: float = 1e-6
    dropout_rate: float = 0.1

    @staticmethod
    def from_transformers(config):
        return Pix2StructConfig(
            vision_patch_size=config.vision_config.patch_embed_hidden_size,
            vision_hidden_size=config.vision_config.hidden_size,
            vision_mlp_ff_size=config.vision_config.d_ff,
            vision_layers=config.vision_config.num_hidden_layers,
            vision_attention_kv_size=config.vision_config.d_kv,
            vision_attention_heads=config.vision_config.num_attention_heads,
            vision_max_rows=config.vision_config.seq_len,
            vision_max_columns=config.vision_config.seq_len,
            vision_page_mode='index',
            vision_max_pages=4096,
            text_vocab_size=config.text_config.vocab_size,
            text_layers=config.text_config.num_layers,
            text_hidden_size=config.text_config.hidden_size,
            text_dense_act_ff_size=config.text_config.d_ff,
            text_attention_kv_size=config.text_config.d_kv,
            text_attention_heads=config.text_config.num_heads,
            rms_norm_eps=config.vision_config.layer_norm_eps,
            dropout_rate=config.vision_config.dropout_rate,
        )
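# Illustrative sketch (not used by the model code): the default config is
# self-consistent in that each tower's hidden size equals heads * kv_size,
# which keeps the attention projections square. This is an observation about
# the defaults, not a requirement enforced anywhere in this file.
def _example_config():
    config = Pix2StructConfig()
    assert config.vision_hidden_size == config.vision_attention_heads * config.vision_attention_kv_size
    assert config.text_hidden_size == config.text_attention_heads * config.text_attention_kv_size
    return config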
class Pix2StructModel(nn.Module):
    def __init__(self, config: Pix2StructConfig):
        super().__init__()
        self.config = config
        self.encoder = Pix2StructVisionEncoder(config)
        self.decoder = Pix2StructTextDecoder(config)

    def forward(
        self,
        flattened_patches: torch.Tensor,
        flattened_patches_cu_seq_lens: torch.LongTensor,
        flattened_patches_max_seq_len: int,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        decoder_cu_seq_lens: Optional[torch.LongTensor] = None,
        decoder_max_seq_len: Optional[int] = None,
        decoder_cross_cu_seq_lens: Optional[torch.LongTensor] = None,
        decoder_cross_max_seq_len: Optional[int] = None,
    ):
        encoder_hidden_states = self.encoder(
            flattened_patches=flattened_patches,
            cu_seq_lens=flattened_patches_cu_seq_lens,
            max_seq_len=flattened_patches_max_seq_len,
        )
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            labels=labels,
            cu_seq_lens=decoder_cu_seq_lens,
            max_seq_len=decoder_max_seq_len,
            encoder_hidden_states=encoder_hidden_states,
            encoder_cu_seq_lens=flattened_patches_cu_seq_lens,
            encoder_max_seq_len=flattened_patches_max_seq_len,
            cross_cu_seq_lens=decoder_cross_cu_seq_lens,
            cross_max_seq_len=decoder_cross_max_seq_len,
        )
        return decoder_outputs

    def get_encoder_kv_cache(
        self,
        flattened_patches: torch.Tensor,
        flattened_patches_cu_seq_lens: torch.LongTensor,
        flattened_patches_max_seq_len: int,
    ):
        encoder_hidden_states = self.encoder(
            flattened_patches=flattened_patches,
            cu_seq_lens=flattened_patches_cu_seq_lens,
            max_seq_len=flattened_patches_max_seq_len,
        )
        # Unpack the varlen layout into a padded (batch, max_seq_len, hidden) tensor.
        encoder_hidden_states_batch = []
        for i in range(1, flattened_patches_cu_seq_lens.shape[0]):
            sub_seq_len = flattened_patches_cu_seq_lens[i] - flattened_patches_cu_seq_lens[i - 1]
            encoder_hidden_states_batch.append(
                torch.nn.functional.pad(
                    encoder_hidden_states[flattened_patches_cu_seq_lens[i - 1]:flattened_patches_cu_seq_lens[i]],
                    (0, 0, 0, flattened_patches_max_seq_len - sub_seq_len),
                )
            )
        encoder_hidden_states_batch = torch.stack(encoder_hidden_states_batch)
        encoder_k_cache, encoder_v_cache = self.decoder.get_encoder_kv_cache(encoder_hidden_states_batch)
        return {
            'encoder_k_cache': encoder_k_cache,
            'encoder_v_cache': encoder_v_cache,
            'encoder_cache_seqlens': flattened_patches_cu_seq_lens[1:] - flattened_patches_cu_seq_lens[:-1],
        }

    def save(self, model_folder):
        os.makedirs(model_folder, exist_ok=True)
        safetensors.torch.save_file(self.state_dict(), model_folder + '/model.safetensors')
        with open(model_folder + '/config.json', 'w') as f:
            json.dump(dataclasses.asdict(self.config), f)

    @staticmethod
    def load(model_folder):
        if model_folder.startswith('transformers/'):
            return Pix2StructModel.from_transformers(model_folder.replace('transformers/', ''))
        with open(model_folder + '/config.json', 'r') as f:
            config = Pix2StructConfig(**json.load(f))
        model = Pix2StructModel(config)
        model.load_state_dict(safetensors.torch.load_file(model_folder + '/model.safetensors'))
        return model
    @staticmethod
    def from_transformers(model_id):
        import transformers
        donor_config = transformers.Pix2StructConfig.from_pretrained(model_id)
        donor_model = transformers.Pix2StructForConditionalGeneration.from_pretrained(model_id)
        config = Pix2StructConfig.from_transformers(donor_config)
        model = Pix2StructModel(config)
        weights = donor_model.state_dict()
        mapping = {
            'encoder.embeddings.patch_projection.weight': 'encoder.embeddings.patch_projection.weight',
            'encoder.embeddings.patch_projection.bias': 'encoder.embeddings.patch_projection.bias',
            'encoder.embeddings.row_embedder.weight': 'encoder.embeddings.row_embedder.weight',
            'encoder.embeddings.column_embedder.weight': 'encoder.embeddings.column_embedder.weight',
            'encoder.layer_norm.weight': 'encoder.layernorm.weight',
            'decoder.embed_tokens.weight': 'decoder.embed_tokens.weight',
            'decoder.layer_norm.weight': 'decoder.final_layer_norm.weight',
            'decoder.lm_head.weight': 'decoder.lm_head.weight',
        }
        for vision_layer_idx in range(config.vision_layers):
            prefix = f'encoder.layers.{vision_layer_idx}'
            s_prefix = f'encoder.encoder.layer.{vision_layer_idx}'
            mapping[f'{prefix}.attention.pre_layer_norm.weight'] = f'{s_prefix}.pre_attention_layer_norm.weight'
            mapping[f'{prefix}.attention.query.weight'] = f'{s_prefix}.attention.query.weight'
            mapping[f'{prefix}.attention.key.weight'] = f'{s_prefix}.attention.key.weight'
            mapping[f'{prefix}.attention.value.weight'] = f'{s_prefix}.attention.value.weight'
            mapping[f'{prefix}.attention.output.weight'] = f'{s_prefix}.attention.output.weight'
            mapping[f'{prefix}.mlp.pre_layer_norm.weight'] = f'{s_prefix}.pre_mlp_layer_norm.weight'
            mapping[f'{prefix}.mlp.wi_0.weight'] = f'{s_prefix}.mlp.wi_0.weight'
            mapping[f'{prefix}.mlp.wi_1.weight'] = f'{s_prefix}.mlp.wi_1.weight'
            mapping[f'{prefix}.mlp.wo.weight'] = f'{s_prefix}.mlp.wo.weight'
        for text_layer_idx in range(config.text_layers):
            prefix = f'decoder.layers.{text_layer_idx}'
            s_prefix = f'decoder.layer.{text_layer_idx}'
            mapping[f'{prefix}.encoder_decoder_attention.pre_layer_norm.weight'] = f'{s_prefix}.encoder_decoder_attention.layer_norm.weight'
            mapping[f'{prefix}.encoder_decoder_attention.query.weight'] = f'{s_prefix}.encoder_decoder_attention.attention.query.weight'
            mapping[f'{prefix}.encoder_decoder_attention.key.weight'] = f'{s_prefix}.encoder_decoder_attention.attention.key.weight'
            mapping[f'{prefix}.encoder_decoder_attention.value.weight'] = f'{s_prefix}.encoder_decoder_attention.attention.value.weight'
            mapping[f'{prefix}.encoder_decoder_attention.output.weight'] = f'{s_prefix}.encoder_decoder_attention.attention.output.weight'
            mapping[f'{prefix}.self_attention.pre_layer_norm.weight'] = f'{s_prefix}.self_attention.layer_norm.weight'
            mapping[f'{prefix}.self_attention.query.weight'] = f'{s_prefix}.self_attention.attention.query.weight'
            mapping[f'{prefix}.self_attention.key.weight'] = f'{s_prefix}.self_attention.attention.key.weight'
            mapping[f'{prefix}.self_attention.value.weight'] = f'{s_prefix}.self_attention.attention.value.weight'
            mapping[f'{prefix}.self_attention.output.weight'] = f'{s_prefix}.self_attention.attention.output.weight'
            mapping[f'{prefix}.mlp.pre_layer_norm.weight'] = f'{s_prefix}.mlp.layer_norm.weight'
            mapping[f'{prefix}.mlp.wi_0.weight'] = f'{s_prefix}.mlp.DenseReluDense.wi_0.weight'
            mapping[f'{prefix}.mlp.wi_1.weight'] = f'{s_prefix}.mlp.DenseReluDense.wi_1.weight'
            mapping[f'{prefix}.mlp.wo.weight'] = f'{s_prefix}.mlp.DenseReluDense.wo.weight'
        state_dict = {}
        for target_key, source_key in mapping.items():
            state_dict[target_key] = weights[source_key]
            del weights[source_key]
        load_info = model.load_state_dict(state_dict, strict=False)
        if len(load_info.missing_keys) > 0:
            print('The following keys are missing', load_info.missing_keys)
        if 'encoder.embeddings.page_embedder.weight' in load_info.missing_keys:
            # The donor model has no page embedder; initialize it randomly.
            model.encoder.embeddings.page_embedder.weight.data.normal_(mean=0.0, std=0.02)
        return model


# VISION

class Pix2StructVisionEncoder(nn.Module):
    def __init__(self, config: Pix2StructConfig):
        super().__init__()
        self.config = config
        self.embeddings = Pix2StructVisionEmbeddings(config)
        self.layers = nn.ModuleList([Pix2StructVisionLayer(config) for _ in range(config.vision_layers)])
        self.layer_norm = RMSNorm(config.vision_hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        flattened_patches: torch.Tensor,  # (sum_seq_lens, vision_patch_size)
        cu_seq_lens: torch.Tensor,  # (seq_count + 1)
        max_seq_len: int,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(flattened_patches)
        for layer_module in self.layers:
            hidden_states = layer_module(hidden_states, cu_seq_lens, max_seq_len)
        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


class Pix2StructVisionLayer(nn.Module):
    def __init__(self, config: Pix2StructConfig) -> None:
        super().__init__()
        self.attention = Pix2StructVisionAttention(config)
        self.mlp = Pix2StructVisionMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,  # (sum_seq_lens, vision_hidden_size)
        cu_seq_lens: torch.Tensor,  # (seq_count + 1)
        max_seq_len: int,
    ) -> torch.Tensor:  # (sum_seq_lens, vision_hidden_size)
        hidden_states = hidden_states + self.attention(
            hidden_states=hidden_states,
            cu_seq_lens=cu_seq_lens,
            max_seq_len=max_seq_len,
        )
        hidden_states = hidden_states + self.mlp(hidden_states)
        return hidden_states
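# Illustrative sketch (not used by the model): packing variable-length patch
# sequences into the varlen layout the encoder expects. Patches from all
# images are concatenated along dim 0, and cu_seq_lens holds the cumulative
# offsets (length seq_count + 1) consumed by flash_attn_varlen_func.
def _example_pack_patches(per_image_patches):
    # per_image_patches: list of (seq_len_i, patch_feature_size) float tensors
    seq_lens = torch.tensor([p.shape[0] for p in per_image_patches], dtype=torch.int32)
    cu_seq_lens = torch.nn.functional.pad(seq_lens.cumsum(0, dtype=torch.int32), (1, 0))
    return torch.cat(per_image_patches, dim=0), cu_seq_lens, int(seq_lens.max())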
class Pix2StructVisionAttention(nn.Module):
    def __init__(self, config: Pix2StructConfig):
        super().__init__()
        self.config = config
        self.inner_dim = config.vision_attention_kv_size * config.vision_attention_heads
        self.pre_layer_norm = RMSNorm(config.vision_hidden_size, eps=config.rms_norm_eps)
        self.query = nn.Linear(config.vision_hidden_size, self.inner_dim, bias=False)
        self.key = nn.Linear(config.vision_hidden_size, self.inner_dim, bias=False)
        self.value = nn.Linear(config.vision_hidden_size, self.inner_dim, bias=False)
        self.output = nn.Linear(self.inner_dim, config.vision_hidden_size, bias=False)

    def forward(self, hidden_states, cu_seq_lens, max_seq_len):
        sum_seq_lens = hidden_states.size(0)

        def to_projection_shape(states: torch.Tensor) -> torch.Tensor:
            # (sum_seq_lens, n_heads, dim_per_head)
            return states.contiguous().view(
                sum_seq_lens, self.config.vision_attention_heads, self.config.vision_attention_kv_size
            )

        hidden_states = self.pre_layer_norm(hidden_states)
        query_states = to_projection_shape(self.query(hidden_states))
        key_states = to_projection_shape(self.key(hidden_states))
        value_states = to_projection_shape(self.value(hidden_states))
        attn_output = flash_attn_varlen_func(
            q=query_states,
            k=key_states,
            v=value_states,
            cu_seqlens_q=cu_seq_lens,
            cu_seqlens_k=cu_seq_lens,
            max_seqlen_q=max_seq_len,
            max_seqlen_k=max_seq_len,
            dropout_p=self.config.dropout_rate if self.training else 0.0,
            causal=False,
        )
        attn_output = attn_output.contiguous().view(sum_seq_lens, self.inner_dim)
        attn_output = self.output(attn_output)
        return attn_output


class Pix2StructVisionMLP(nn.Module):
    def __init__(self, config: Pix2StructConfig):
        super().__init__()
        self.pre_layer_norm = RMSNorm(config.vision_hidden_size, eps=config.rms_norm_eps)
        self.wi_0 = nn.Linear(config.vision_hidden_size, config.vision_mlp_ff_size, bias=False)
        self.wi_1 = nn.Linear(config.vision_hidden_size, config.vision_mlp_ff_size, bias=False)
        self.wo = nn.Linear(config.vision_mlp_ff_size, config.vision_hidden_size, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        hidden_states = self.pre_layer_norm(hidden_states)
        hidden_gelu = self.wi_0(hidden_states)
        hidden_linear = self.wi_1(hidden_states)
        # Gated GELU: gelu_tanh(wi_0(x)) * wi_1(x), fused when liger_kernel is available.
        hidden_states = GELUMulFunction.apply(hidden_gelu, hidden_linear)
        if self.training:
            hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        return hidden_states
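# Minimal reference (not used by the model) for the gated-GELU MLP above,
# matching the pure-PyTorch fallback path of GELUMulFunction:
# out = wo(gelu_tanh(wi_0(x)) * wi_1(x)).
def _example_geglu(x, wi_0, wi_1, wo):
    return wo(torch.nn.functional.gelu(wi_0(x), approximate='tanh') * wi_1(x))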
class Pix2StructVisionEmbeddings(nn.Module):
    def __init__(self, config: Pix2StructConfig) -> None:
        super().__init__()
        self.config = config
        self.patch_projection = nn.Linear(config.vision_patch_size, config.vision_hidden_size)
        if config.vision_page_mode == 'index':
            self.page_embedder = nn.Embedding(config.vision_max_pages, config.vision_hidden_size)
        else:
            self.page_embedder = None
        self.row_embedder = nn.Embedding(config.vision_max_rows, config.vision_hidden_size)
        self.column_embedder = nn.Embedding(config.vision_max_columns, config.vision_hidden_size)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        flattened_patches: torch.Tensor,  # (sum_seq_lens, vision_patch_size)
    ) -> torch.Tensor:  # (sum_seq_lens, vision_hidden_size)
        # The leading columns of each patch row carry its (page,) row, and
        # column indices; the remaining columns are the patch content.
        if self.config.vision_page_mode == 'index':
            page_indices = flattened_patches[:, 0].long()
            row_indices = flattened_patches[:, 1].long()
            col_indices = flattened_patches[:, 2].long()
            flattened_patches = flattened_patches[:, 3:]
        else:
            page_indices = None
            row_indices = flattened_patches[:, 0].long()
            col_indices = flattened_patches[:, 1].long()
            flattened_patches = flattened_patches[:, 2:]
        embeddings = self.patch_projection(flattened_patches)
        row_embeddings = self.row_embedder(row_indices)
        col_embeddings = self.column_embedder(col_indices)
        if self.config.vision_page_mode == 'index':
            page_embeddings = self.page_embedder(page_indices)
            embeddings = embeddings + page_embeddings + row_embeddings + col_embeddings
        else:
            embeddings = embeddings + row_embeddings + col_embeddings
        if self.training:
            embeddings = self.dropout(embeddings)
        return embeddings


# TEXT

class Pix2StructTextDecoder(nn.Module):
    def __init__(self, config: Pix2StructConfig):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.text_vocab_size, config.text_hidden_size)
        # ALiBi position bias is applied only in the first layer.
        self.layers = nn.ModuleList(
            [Pix2StructTextLayer(config, alibi=bool(i == 0)) for i in range(config.text_layers)]
        )
        self.layer_norm = RMSNorm(config.text_hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.text_hidden_size, config.text_vocab_size, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        input_ids: torch.LongTensor,
        cu_seq_lens: torch.LongTensor,
        max_seq_len: int,
        encoder_hidden_states: torch.Tensor,
        encoder_cu_seq_lens: torch.LongTensor,
        encoder_max_seq_len: int,
        cross_cu_seq_lens: torch.LongTensor,
        cross_max_seq_len: int,
        labels: Optional[torch.LongTensor] = None,
    ):
        hidden_states = self.embed_tokens(input_ids)
        hidden_states = self.dropout(hidden_states) if self.training else hidden_states
        for layer_module in self.layers:
            hidden_states = layer_module(
                hidden_states,
                cu_seq_lens,
                max_seq_len,
                encoder_hidden_states,
                encoder_cu_seq_lens,
                encoder_max_seq_len,
                cross_cu_seq_lens,
                cross_max_seq_len,
            )
        hidden_states = self.layer_norm(hidden_states)
        if self.training:
            hidden_states = self.dropout(hidden_states)
        logits = self.lm_head(hidden_states)
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')
            return loss_fct(logits, labels)
        return logits

    def predict(
        self,
        input_ids: torch.LongTensor,  # (batch_size, max_seq_len)
        decoder_k_cache: torch.Tensor,  # [(batch_size_cache, seqlen_cache, nheads_k, headdim)]
        decoder_v_cache: torch.Tensor,  # [(batch_size_cache, seqlen_cache, nheads_k, headdim)]
        decoder_cache_seqlens: torch.IntTensor,  # (batch_size_cache)
        encoder_k_cache: torch.Tensor,  # [(batch_size_cache, seqlen_cache, nheads_k, headdim)]
        encoder_v_cache: torch.Tensor,  # [(batch_size_cache, seqlen_cache, nheads_k, headdim)]
        encoder_cache_seqlens: torch.IntTensor,  # (batch_size_cache)
        encoder_cache_batch_idx: torch.IntTensor,  # (batch_size_cache)
    ):
        hidden_states = self.embed_tokens(input_ids)
        for i, layer_module in enumerate(self.layers):
            hidden_states = layer_module.predict(
                hidden_states,
                decoder_k_cache[i],
                decoder_v_cache[i],
                decoder_cache_seqlens,
                encoder_k_cache[i],
                encoder_v_cache[i],
                encoder_cache_seqlens,
                encoder_cache_batch_idx,
            )
        hidden_states = self.layer_norm(hidden_states)
        logits = self.lm_head(hidden_states)
        return logits

    def get_decoder_kv_cache(self, device, batch_size, seq_len, dtype=torch.float32):
        # Pre-allocates one (batch, seq_len, heads, head_dim) K and V buffer per layer.
        k_cache_layers = []
        v_cache_layers = []
        for _ in self.layers:
            k_cache_layers.append(
                torch.empty(
                    (batch_size, seq_len, self.config.text_attention_heads, self.config.text_attention_kv_size),
                    device=device, dtype=dtype,
                )
            )
            v_cache_layers.append(
                torch.empty(
                    (batch_size, seq_len, self.config.text_attention_heads, self.config.text_attention_kv_size),
                    device=device, dtype=dtype,
                )
            )
        return k_cache_layers, v_cache_layers

    def get_encoder_kv_cache(self, encoder_hidden_states: torch.Tensor):
        k_cache_layers = []
        v_cache_layers = []
        for layer in self.layers:
            k, v = layer.get_encoder_kv_cache(encoder_hidden_states)
            k_cache_layers.append(k)
            v_cache_layers.append(v)
        return k_cache_layers, v_cache_layers
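# Illustrative greedy-decoding sketch (not part of the model) against the
# cache API above, assuming a bf16 model (the flash-attn kernels require
# fp16/bf16) and assumed bos_token_id/eos_token_id values that are not
# defined anywhere in this file.
@torch.no_grad()
def _example_greedy_decode(model, encoder_cache, batch_size,
                           max_new_tokens=64, bos_token_id=0, eos_token_id=1):
    device = next(model.parameters()).device
    # Room for the BOS token plus every generated token.
    k_cache, v_cache = model.decoder.get_decoder_kv_cache(
        device, batch_size, max_new_tokens + 1, dtype=torch.bfloat16
    )
    cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
    batch_idx = torch.arange(batch_size, dtype=torch.int32, device=device)
    tokens = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)
    generated = []
    for _ in range(max_new_tokens):
        logits = model.decoder.predict(
            tokens, k_cache, v_cache, cache_seqlens,
            encoder_cache['encoder_k_cache'], encoder_cache['encoder_v_cache'],
            encoder_cache['encoder_cache_seqlens'].to(torch.int32), batch_idx,
        )
        tokens = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy pick
        generated.append(tokens)
        cache_seqlens += 1  # one new KV entry was written per step
        if (tokens == eos_token_id).all():
            break
    return torch.cat(generated, dim=1)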
class Pix2StructTextLayer(nn.Module):
    def __init__(self, config, alibi=False):
        super().__init__()
        self.self_attention = Pix2StructTextSelfAttention(config, alibi=alibi)
        self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(config)
        self.mlp = Pix2StructTextMLP(config)

    def forward(
        self,
        hidden_states,
        cu_seq_lens,
        max_seq_len,
        encoder_hidden_states,
        encoder_cu_seq_lens,
        encoder_max_seq_len,
        cross_cu_seq_lens,
        cross_max_seq_len,
    ):
        hidden_states = self.self_attention(
            hidden_states=hidden_states,
            cu_seq_lens=cu_seq_lens,
            max_seq_len=max_seq_len,
        )
        do_cross_attention = encoder_hidden_states is not None
        if do_cross_attention:
            hidden_states = self.encoder_decoder_attention(
                hidden_states=hidden_states,
                cu_seq_lens=cross_cu_seq_lens,
                max_seq_len=cross_max_seq_len,
                kv_hidden_states=encoder_hidden_states,
                kv_cu_seq_lens=encoder_cu_seq_lens,
                kv_max_seq_len=encoder_max_seq_len,
            )
        hidden_states = self.mlp(hidden_states)
        return hidden_states

    def predict(
        self,
        hidden_states,
        decoder_k_cache,
        decoder_v_cache,
        decoder_cache_seqlens,
        encoder_k_cache,
        encoder_v_cache,
        encoder_cache_seqlens,
        encoder_cache_batch_idx,
    ):
        hidden_states = self.self_attention.predict(
            hidden_states,
            decoder_k_cache,
            decoder_v_cache,
            decoder_cache_seqlens,
        )
        do_cross_attention = encoder_k_cache is not None
        if do_cross_attention:
            hidden_states = self.encoder_decoder_attention.predict(
                hidden_states,
                encoder_k_cache,
                encoder_v_cache,
                encoder_cache_seqlens,
                encoder_cache_batch_idx,
            )
        hidden_states = self.mlp(hidden_states)
        return hidden_states

    def get_encoder_kv_cache(self, encoder_hidden_states):
        return self.encoder_decoder_attention.get_kv_cache(encoder_hidden_states)
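# Illustrative sketch of a packed training call. Assuming one target sequence
# per image, the cross-attention query layout equals the decoder
# self-attention layout, so the same cu_seq_lens/max_seq_len pair is passed
# for both. labels is a flat LongTensor aligned with target_ids, with -100
# marking positions to ignore.
def _example_training_step(model, flattened, patch_cu, patch_max,
                           target_ids, target_cu, target_max, labels):
    loss = model(
        flattened_patches=flattened,
        flattened_patches_cu_seq_lens=patch_cu,
        flattened_patches_max_seq_len=patch_max,
        decoder_input_ids=target_ids,
        labels=labels,
        decoder_cu_seq_lens=target_cu,
        decoder_max_seq_len=target_max,
        decoder_cross_cu_seq_lens=target_cu,
        decoder_cross_max_seq_len=target_max,
    )
    return loss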
class Pix2StructTextSelfAttention(nn.Module):
    def __init__(self, config: Pix2StructConfig, alibi=False):
        super().__init__()
        self.config = config
        self.alibi = alibi
        if alibi:
            self.alibi_slopes = generate_alibi_slopes(config.text_attention_heads)
        self.inner_dim = config.text_attention_kv_size * config.text_attention_heads
        self.pre_layer_norm = RMSNorm(config.text_hidden_size, eps=config.rms_norm_eps)
        self.query = nn.Linear(config.text_hidden_size, self.inner_dim, bias=False)
        self.key = nn.Linear(config.text_hidden_size, self.inner_dim, bias=False)
        self.value = nn.Linear(config.text_hidden_size, self.inner_dim, bias=False)
        self.output = nn.Linear(self.inner_dim, config.text_hidden_size, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        cu_seq_lens,
        max_seq_len,
    ):
        normed_hidden_states = self.pre_layer_norm(hidden_states)
        query_states = self.to_projection_shape(self.query(normed_hidden_states))
        key_states = self.to_projection_shape(self.key(normed_hidden_states))
        value_states = self.to_projection_shape(self.value(normed_hidden_states))
        attention_output = flash_attn_varlen_func(
            q=query_states,
            k=key_states,
            v=value_states,
            cu_seqlens_q=cu_seq_lens,
            cu_seqlens_k=cu_seq_lens,
            max_seqlen_q=max_seq_len,
            max_seqlen_k=max_seq_len,
            dropout_p=self.config.dropout_rate if self.training else 0.0,
            causal=True,
            alibi_slopes=self.alibi_slopes.to(query_states.device) if self.alibi else None,
        )
        attention_output = attention_output.contiguous().view(-1, self.inner_dim)
        attention_output = self.output(attention_output)
        return hidden_states + (self.dropout(attention_output) if self.training else attention_output)

    def predict(
        self,
        hidden_states,
        decoder_k_cache,
        decoder_v_cache,
        decoder_cache_seqlens,
    ):
        normed_hidden_states = self.pre_layer_norm(hidden_states)
        query_states = self.to_projection_shape(self.query(normed_hidden_states))
        key_states = self.to_projection_shape(self.key(normed_hidden_states))
        value_states = self.to_projection_shape(self.value(normed_hidden_states))
        # New K/V entries are written into the caches in place at cache_seqlens.
        attention_output = flash_attn_with_kvcache(
            q=query_states,
            k_cache=decoder_k_cache,
            v_cache=decoder_v_cache,
            k=key_states,
            v=value_states,
            cache_seqlens=decoder_cache_seqlens,
            causal=True,
            alibi_slopes=self.alibi_slopes.to(query_states.device) if self.alibi else None,
        )
        attention_output = attention_output.contiguous().view(*hidden_states.shape)
        attention_output = self.output(attention_output)
        return hidden_states + attention_output

    def to_projection_shape(self, states):
        return states.contiguous().view(
            *states.shape[:-1], self.config.text_attention_heads, self.config.text_attention_kv_size
        )


class Pix2StructTextLayerCrossAttention(nn.Module):
    def __init__(self, config: Pix2StructConfig):
        super().__init__()
        self.config = config
        self.inner_dim = config.text_attention_kv_size * config.text_attention_heads
        self.pre_layer_norm = RMSNorm(config.text_hidden_size, eps=config.rms_norm_eps)
        self.query = nn.Linear(config.text_hidden_size, self.inner_dim, bias=False)
        self.key = nn.Linear(config.vision_hidden_size, self.inner_dim, bias=False)
        self.value = nn.Linear(config.vision_hidden_size, self.inner_dim, bias=False)
        self.output = nn.Linear(self.inner_dim, config.text_hidden_size, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        cu_seq_lens,
        max_seq_len,
        kv_hidden_states,
        kv_cu_seq_lens,
        kv_max_seq_len,
    ):
        normed_hidden_states = self.pre_layer_norm(hidden_states)
        query_states = self.to_projection_shape(self.query(normed_hidden_states))
        key_states = self.to_projection_shape(self.key(kv_hidden_states))
        value_states = self.to_projection_shape(self.value(kv_hidden_states))
        attention_output = flash_attn_varlen_func(
            q=query_states,
            k=key_states,
            v=value_states,
            cu_seqlens_q=cu_seq_lens,
            cu_seqlens_k=kv_cu_seq_lens,
            max_seqlen_q=max_seq_len,
            max_seqlen_k=kv_max_seq_len,
            dropout_p=self.config.dropout_rate if self.training else 0.0,
            causal=False,
        )
        attention_output = attention_output.contiguous().view(hidden_states.shape[0], self.inner_dim)
        attention_output = self.output(attention_output)
        hidden_states = hidden_states + (self.dropout(attention_output) if self.training else attention_output)
        return hidden_states

    def predict(
        self,
        hidden_states,
        encoder_k_cache,
        encoder_v_cache,
        encoder_cache_seqlens,
        encoder_cache_batch_idx,
    ):
        normed_hidden_states = self.pre_layer_norm(hidden_states)
        query_states = self.to_projection_shape(self.query(normed_hidden_states))
        # Encoder K/V are precomputed once per image; only queries are projected here.
        attention_output = flash_attn_with_kvcache(
            q=query_states,
            k_cache=encoder_k_cache,
            v_cache=encoder_v_cache,
            cache_seqlens=encoder_cache_seqlens,
            cache_batch_idx=encoder_cache_batch_idx,
            causal=False,
        )
        attention_output = attention_output.contiguous().view(*hidden_states.shape)
        attention_output = self.output(attention_output)
        return hidden_states + attention_output

    def get_kv_cache(self, states):
        return self.to_projection_shape(self.key(states)), self.to_projection_shape(self.value(states))

    def to_projection_shape(self, states: torch.Tensor) -> torch.Tensor:
        return states.contiguous().view(
            *states.shape[:-1], self.config.text_attention_heads, self.config.text_attention_kv_size
        )
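# Illustrative note (not used by the model) on cache_batch_idx in predict()
# above: several decoder rows, e.g. beam candidates, can share one encoder
# cache row so image KV is stored once per image; decoder row b then reads
# encoder cache row b // beams_per_image.
def _example_beam_batch_idx(num_images, beams_per_image):
    return torch.arange(num_images, dtype=torch.int32).repeat_interleave(beams_per_image)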
class Pix2StructTextMLP(nn.Module):
    def __init__(self, config: Pix2StructConfig):
        super().__init__()
        self.pre_layer_norm = RMSNorm(config.text_hidden_size, eps=config.rms_norm_eps)
        self.wi_0 = nn.Linear(config.text_hidden_size, config.text_dense_act_ff_size, bias=False)
        self.wi_1 = nn.Linear(config.text_hidden_size, config.text_dense_act_ff_size, bias=False)
        self.wo = nn.Linear(config.text_dense_act_ff_size, config.text_hidden_size, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.pre_layer_norm(hidden_states)
        hidden_gelu = self.wi_0(hidden_states)
        hidden_linear = self.wi_1(hidden_states)
        # Gated GELU: gelu_tanh(wi_0(x)) * wi_1(x), fused when liger_kernel is available.
        hidden_states = GELUMulFunction.apply(hidden_gelu, hidden_linear)
        if self.training:
            # T5-style: dropout both after the activation and after wo.
            hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        hidden_states = residual + (self.dropout(hidden_states) if self.training else hidden_states)
        return hidden_states


def generate_alibi_slopes(num_heads):
    # Geometric ALiBi slopes; exact for power-of-two head counts, extended by
    # repeated multiplication beyond 8 heads.
    def get_slopes_power_of_two(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        ratio = start
        return [start * ratio ** i for i in range(n)]

    if num_heads <= 8:
        slopes = get_slopes_power_of_two(num_heads)
    else:
        slopes = get_slopes_power_of_two(8)
        for _ in range(8, num_heads):
            slopes.append(slopes[-1] * slopes[0])
    return torch.tensor(slopes, dtype=torch.float32)
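if __name__ == '__main__':
    # Quick sanity check of generate_alibi_slopes for a power-of-two head
    # count: 8 heads should give the geometric sequence 1/2, 1/4, ..., 1/256.
    expected = torch.tensor([2.0 ** -(i + 1) for i in range(8)])
    assert torch.allclose(generate_alibi_slopes(8), expected)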