import dataclasses
import json
import math
import os
from typing import Optional, Literal

import safetensors.torch
import torch
from torch import nn
from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm as RMSNorm
    from liger_kernel.ops.geglu import LigerGELUMulFunction as GELUMulFunction
except ModuleNotFoundError:
    from torch.nn import RMSNorm

    class GELUMulFunction:
        """Fallback GeGLU when liger_kernel is not installed."""

        @staticmethod
        def apply(hidden_gelu, hidden_linear):
            return torch.nn.functional.gelu(hidden_gelu, approximate='tanh') * hidden_linear

@dataclasses.dataclass
class Pix2StructConfig:
    vision_patch_size: int = 768
    vision_hidden_size: int = 768
    vision_mlp_ff_size: int = 2048
    vision_layers: int = 12
    vision_attention_kv_size: int = 64
    vision_attention_heads: int = 12
    vision_max_rows: int = 4096
    vision_max_columns: int = 4096
    vision_page_mode: Literal['concat', 'index'] = 'index'
    vision_max_pages: int = 4096
    text_vocab_size: int = 50265
    text_layers: int = 12
    text_hidden_size: int = 768
    text_dense_act_ff_size: int = 2048
    text_attention_kv_size: int = 64
    text_attention_heads: int = 12
    rms_norm_eps: float = 1e-6
    dropout_rate: float = 0.1

    @staticmethod
    def from_transformers(config):
        return Pix2StructConfig(
            vision_patch_size=config.vision_config.patch_embed_hidden_size,
            vision_hidden_size=config.vision_config.hidden_size,
            vision_mlp_ff_size=config.vision_config.d_ff,
            vision_layers=config.vision_config.num_hidden_layers,
            vision_attention_kv_size=config.vision_config.d_kv,
            vision_attention_heads=config.vision_config.num_attention_heads,
            vision_max_rows=config.vision_config.seq_len,
            vision_max_columns=config.vision_config.seq_len,
            vision_page_mode='index',
            vision_max_pages=4096,
            text_vocab_size=config.text_config.vocab_size,
            text_layers=config.text_config.num_layers,
            text_hidden_size=config.text_config.hidden_size,
            text_dense_act_ff_size=config.text_config.d_ff,
            text_attention_kv_size=config.text_config.d_kv,
            text_attention_heads=config.text_config.num_heads,
            rms_norm_eps=config.vision_config.layer_norm_eps,
            dropout_rate=config.vision_config.dropout_rate,
        )

class Pix2StructModel(nn.Module):
    def __init__(self, config: Pix2StructConfig):
        super().__init__()
        self.config = config
        self.encoder = Pix2StructVisionEncoder(config)
        self.decoder = Pix2StructTextDecoder(config)

    def forward(
        self,
        flattened_patches: torch.Tensor,
        flattened_patches_cu_seq_lens: torch.LongTensor,
        flattened_patches_max_seq_len: int,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        decoder_cu_seq_lens: Optional[torch.LongTensor] = None,
        decoder_max_seq_len: Optional[int] = None,
        decoder_cross_cu_seq_lens: Optional[torch.LongTensor] = None,
        decoder_cross_max_seq_len: Optional[int] = None,
    ):
        encoder_hidden_states = self.encoder(
            flattened_patches=flattened_patches,
            cu_seq_lens=flattened_patches_cu_seq_lens,
            max_seq_len=flattened_patches_max_seq_len,
        )
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            labels=labels,
            cu_seq_lens=decoder_cu_seq_lens,
            max_seq_len=decoder_max_seq_len,
            encoder_hidden_states=encoder_hidden_states,
            encoder_cu_seq_lens=flattened_patches_cu_seq_lens,
            encoder_max_seq_len=flattened_patches_max_seq_len,
            cross_cu_seq_lens=decoder_cross_cu_seq_lens,
            cross_max_seq_len=decoder_cross_max_seq_len,
        )
        return decoder_outputs

    def get_encoder_kv_cache(
        self,
        flattened_patches: torch.Tensor,
        flattened_patches_cu_seq_lens: torch.LongTensor,
        flattened_patches_max_seq_len: int,
    ):
        encoder_hidden_states = self.encoder(
            flattened_patches=flattened_patches,
            cu_seq_lens=flattened_patches_cu_seq_lens,
            max_seq_len=flattened_patches_max_seq_len,
        )
        # Un-pack the varlen encoder output into a padded (batch, max_seq_len, hidden)
        # tensor so it can be consumed by flash_attn_with_kvcache during decoding.
        encoder_hidden_states_batch = []
        for i in range(1, flattened_patches_cu_seq_lens.shape[0]):
            sub_seq_len = int(flattened_patches_cu_seq_lens[i] - flattened_patches_cu_seq_lens[i - 1])
            encoder_hidden_states_batch.append(
                torch.nn.functional.pad(
                    encoder_hidden_states[flattened_patches_cu_seq_lens[i - 1]:flattened_patches_cu_seq_lens[i]],
                    (0, 0, 0, flattened_patches_max_seq_len - sub_seq_len),
                )
            )
        encoder_hidden_states_batch = torch.stack(encoder_hidden_states_batch)
        encoder_k_cache, encoder_v_cache = self.decoder.get_encoder_kv_cache(encoder_hidden_states_batch)
        return {
            'encoder_k_cache': encoder_k_cache,
            'encoder_v_cache': encoder_v_cache,
            'encoder_cache_seqlens': flattened_patches_cu_seq_lens[1:] - flattened_patches_cu_seq_lens[:-1],
        }

    def save(self, model_folder):
        os.makedirs(model_folder, exist_ok=True)
        safetensors.torch.save_file(self.state_dict(), model_folder + '/model.safetensors')
        with open(model_folder + '/config.json', 'w') as f:
            json.dump(dataclasses.asdict(self.config), f)

    @staticmethod
    def load(model_folder):
        if model_folder.startswith('transformers/'):
            return Pix2StructModel.from_transformers(model_folder.replace('transformers/', ''))
        with open(model_folder + '/config.json', 'r') as f:
            config = Pix2StructConfig(**json.load(f))
        model = Pix2StructModel(config)
        model.load_state_dict(safetensors.torch.load_file(model_folder + '/model.safetensors'))
        return model

    @staticmethod
    def from_transformers(model_id):
        import transformers
        donor_config = transformers.Pix2StructConfig.from_pretrained(model_id)
        donor_model = transformers.Pix2StructForConditionalGeneration.from_pretrained(model_id)
        config = Pix2StructConfig.from_transformers(donor_config)
        model = Pix2StructModel(config)
        weights = donor_model.state_dict()
        # Map this module's parameter names to the transformers checkpoint names.
        mapping = {
            'encoder.embeddings.patch_projection.weight': 'encoder.embeddings.patch_projection.weight',
            'encoder.embeddings.patch_projection.bias': 'encoder.embeddings.patch_projection.bias',
            'encoder.embeddings.row_embedder.weight': 'encoder.embeddings.row_embedder.weight',
            'encoder.embeddings.column_embedder.weight': 'encoder.embeddings.column_embedder.weight',
            'encoder.layer_norm.weight': 'encoder.layernorm.weight',
            'decoder.embed_tokens.weight': 'decoder.embed_tokens.weight',
            'decoder.layer_norm.weight': 'decoder.final_layer_norm.weight',
            'decoder.lm_head.weight': 'decoder.lm_head.weight',
        }
        for vision_layer_idx in range(config.vision_layers):
            prefix = f'encoder.layers.{vision_layer_idx}'
            s_prefix = f'encoder.encoder.layer.{vision_layer_idx}'
            mapping[f'{prefix}.attention.pre_layer_norm.weight'] = f'{s_prefix}.pre_attention_layer_norm.weight'
            mapping[f'{prefix}.attention.query.weight'] = f'{s_prefix}.attention.query.weight'
            mapping[f'{prefix}.attention.key.weight'] = f'{s_prefix}.attention.key.weight'
            mapping[f'{prefix}.attention.value.weight'] = f'{s_prefix}.attention.value.weight'
            mapping[f'{prefix}.attention.output.weight'] = f'{s_prefix}.attention.output.weight'
            mapping[f'{prefix}.mlp.pre_layer_norm.weight'] = f'{s_prefix}.pre_mlp_layer_norm.weight'
            mapping[f'{prefix}.mlp.wi_0.weight'] = f'{s_prefix}.mlp.wi_0.weight'
            mapping[f'{prefix}.mlp.wi_1.weight'] = f'{s_prefix}.mlp.wi_1.weight'
            mapping[f'{prefix}.mlp.wo.weight'] = f'{s_prefix}.mlp.wo.weight'
        for text_layer_idx in range(config.text_layers):
            prefix = f'decoder.layers.{text_layer_idx}'
            s_prefix = f'decoder.layer.{text_layer_idx}'
            mapping[f'{prefix}.encoder_decoder_attention.pre_layer_norm.weight'] = f'{s_prefix}.encoder_decoder_attention.layer_norm.weight'
            mapping[f'{prefix}.encoder_decoder_attention.query.weight'] = f'{s_prefix}.encoder_decoder_attention.attention.query.weight'
            mapping[f'{prefix}.encoder_decoder_attention.key.weight'] = f'{s_prefix}.encoder_decoder_attention.attention.key.weight'
            mapping[f'{prefix}.encoder_decoder_attention.value.weight'] = f'{s_prefix}.encoder_decoder_attention.attention.value.weight'
            mapping[f'{prefix}.encoder_decoder_attention.output.weight'] = f'{s_prefix}.encoder_decoder_attention.attention.output.weight'
            mapping[f'{prefix}.self_attention.pre_layer_norm.weight'] = f'{s_prefix}.self_attention.layer_norm.weight'
            mapping[f'{prefix}.self_attention.query.weight'] = f'{s_prefix}.self_attention.attention.query.weight'
            mapping[f'{prefix}.self_attention.key.weight'] = f'{s_prefix}.self_attention.attention.key.weight'
            mapping[f'{prefix}.self_attention.value.weight'] = f'{s_prefix}.self_attention.attention.value.weight'
            mapping[f'{prefix}.self_attention.output.weight'] = f'{s_prefix}.self_attention.attention.output.weight'
            mapping[f'{prefix}.mlp.pre_layer_norm.weight'] = f'{s_prefix}.mlp.layer_norm.weight'
            mapping[f'{prefix}.mlp.wi_0.weight'] = f'{s_prefix}.mlp.DenseReluDense.wi_0.weight'
            mapping[f'{prefix}.mlp.wi_1.weight'] = f'{s_prefix}.mlp.DenseReluDense.wi_1.weight'
            mapping[f'{prefix}.mlp.wo.weight'] = f'{s_prefix}.mlp.DenseReluDense.wo.weight'
        state_dict = {}
        for target_key, source_key in mapping.items():
            state_dict[target_key] = weights[source_key]
            del weights[source_key]
        load_info = model.load_state_dict(state_dict, strict=False)
        if len(load_info.missing_keys) > 0:
            print('The following keys are missing', load_info.missing_keys)
        if 'encoder.embeddings.page_embedder.weight' in load_info.missing_keys:
            model.encoder.embeddings.page_embedder.weight.data.normal_(mean=0.0, std=0.02)
        return model


# VISION

class Pix2StructVisionEncoder(nn.Module):
    def __init__(self, config: Pix2StructConfig):
        super().__init__()
        self.config = config
        self.embeddings = Pix2StructVisionEmbeddings(config)
        self.layers = nn.ModuleList([Pix2StructVisionLayer(config) for _ in range(config.vision_layers)])
        self.layer_norm = RMSNorm(config.vision_hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        flattened_patches: torch.Tensor,  # (sum_seq_lens, vision_patch_size + index columns)
        cu_seq_lens: torch.Tensor,  # (seq_count + 1) cumulative sequence lengths, starting at 0
        max_seq_len: int,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(flattened_patches)
        for layer_module in self.layers:
            hidden_states = layer_module(hidden_states, cu_seq_lens, max_seq_len)
        hidden_states = self.layer_norm(hidden_states)
        return hidden_states

class Pix2StructVisionLayer(nn.Module):
    def __init__(self, config: Pix2StructConfig) -> None:
        super().__init__()
        self.attention = Pix2StructVisionAttention(config)
        self.mlp = Pix2StructVisionMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,  # (sum_seq_lens, vision_hidden_size)
        cu_seq_lens: torch.Tensor,  # (seq_count + 1)
        max_seq_len: int,
    ) -> torch.Tensor:  # (sum_seq_lens, vision_hidden_size)
        hidden_states = hidden_states + self.attention(
            hidden_states=hidden_states,
            cu_seq_lens=cu_seq_lens,
            max_seq_len=max_seq_len,
        )
        hidden_states = hidden_states + self.mlp(hidden_states)
        return hidden_states

class Pix2StructVisionAttention(nn.Module):
    def __init__(self, config: Pix2StructConfig):
        super().__init__()
        self.config = config
        self.inner_dim = config.vision_attention_kv_size * config.vision_attention_heads
        self.pre_layer_norm = RMSNorm(config.vision_hidden_size, eps=config.rms_norm_eps)
        self.query = nn.Linear(config.vision_hidden_size, self.inner_dim, bias=False)
        self.key = nn.Linear(config.vision_hidden_size, self.inner_dim, bias=False)
        self.value = nn.Linear(config.vision_hidden_size, self.inner_dim, bias=False)
        self.output = nn.Linear(self.inner_dim, config.vision_hidden_size, bias=False)

    def forward(self, hidden_states, cu_seq_lens, max_seq_len):
        sum_seq_lens = hidden_states.size(0)

        def to_projection_shape(states: torch.Tensor) -> torch.Tensor:  # (sum_seq_lens, n_heads, dim_per_head)
            return states.contiguous().view(
                sum_seq_lens,
                self.config.vision_attention_heads,
                self.config.vision_attention_kv_size,
            )

        hidden_states = self.pre_layer_norm(hidden_states)
        query_states = to_projection_shape(self.query(hidden_states))
        key_states = to_projection_shape(self.key(hidden_states))
        value_states = to_projection_shape(self.value(hidden_states))
        attn_output = flash_attn_varlen_func(
            q=query_states,
            k=key_states,
            v=value_states,
            cu_seqlens_q=cu_seq_lens,
            cu_seqlens_k=cu_seq_lens,
            max_seqlen_q=max_seq_len,
            max_seqlen_k=max_seq_len,
            dropout_p=self.config.dropout_rate if self.training else 0.0,
            causal=False,
        )
        attn_output = attn_output.contiguous().view(sum_seq_lens, self.inner_dim)
        attn_output = self.output(attn_output)
        return attn_output

class Pix2StructVisionMLP(nn.Module):
    def __init__(self, config: Pix2StructConfig):
        super().__init__()
        self.pre_layer_norm = RMSNorm(config.vision_hidden_size, eps=config.rms_norm_eps)
        self.wi_0 = nn.Linear(config.vision_hidden_size, config.vision_mlp_ff_size, bias=False)
        self.wi_1 = nn.Linear(config.vision_hidden_size, config.vision_mlp_ff_size, bias=False)
        self.wo = nn.Linear(config.vision_mlp_ff_size, config.vision_hidden_size, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        # Gated GELU (GeGLU): gelu(wi_0(x), tanh) * wi_1(x), fused by liger_kernel when available.
        hidden_states = self.pre_layer_norm(hidden_states)
        hidden_gelu = self.wi_0(hidden_states)
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = GELUMulFunction.apply(hidden_gelu, hidden_linear)
        if self.training:
            hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        return hidden_states

class Pix2StructVisionEmbeddings(nn.Module):
    def __init__(self, config: Pix2StructConfig) -> None:
        super().__init__()
        self.config = config
        self.patch_projection = nn.Linear(config.vision_patch_size, config.vision_hidden_size)
        if config.vision_page_mode == 'index':
            self.page_embedder = nn.Embedding(config.vision_max_pages, config.vision_hidden_size)
        else:
            self.page_embedder = None
        self.row_embedder = nn.Embedding(config.vision_max_rows, config.vision_hidden_size)
        self.column_embedder = nn.Embedding(config.vision_max_columns, config.vision_hidden_size)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        flattened_patches: torch.Tensor,  # (sum_seq_lens, index columns + vision_patch_size)
    ) -> torch.Tensor:  # (sum_seq_lens, vision_hidden_size)
        # In 'index' page mode the first three columns carry page/row/column indices;
        # otherwise the first two columns carry row/column indices.
        if self.config.vision_page_mode == 'index':
            page_indices = flattened_patches[:, 0].long()
            row_indices = flattened_patches[:, 1].long()
            col_indices = flattened_patches[:, 2].long()
            flattened_patches = flattened_patches[:, 3:]
        else:
            page_indices = None
            row_indices = flattened_patches[:, 0].long()
            col_indices = flattened_patches[:, 1].long()
            flattened_patches = flattened_patches[:, 2:]
        embeddings = self.patch_projection(flattened_patches)
        row_embeddings = self.row_embedder(row_indices)
        col_embeddings = self.column_embedder(col_indices)
        if self.config.vision_page_mode == 'index':
            page_embeddings = self.page_embedder(page_indices)
            embeddings = embeddings + page_embeddings + row_embeddings + col_embeddings
        else:
            embeddings = embeddings + row_embeddings + col_embeddings
        if self.training:
            embeddings = self.dropout(embeddings)
        return embeddings


# TEXT

class Pix2StructTextDecoder(nn.Module):
    def __init__(self, config: Pix2StructConfig):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.text_vocab_size, config.text_hidden_size)
        # ALiBi slopes are only applied in the first layer's self-attention.
        self.layers = nn.ModuleList(
            [Pix2StructTextLayer(config, alibi=(i == 0)) for i in range(config.text_layers)]
        )
        self.layer_norm = RMSNorm(config.text_hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.text_hidden_size, config.text_vocab_size, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        input_ids: torch.LongTensor,
        cu_seq_lens: torch.LongTensor,
        max_seq_len: int,
        encoder_hidden_states: torch.Tensor,
        encoder_cu_seq_lens: torch.LongTensor,
        encoder_max_seq_len: int,
        cross_cu_seq_lens: torch.LongTensor,
        cross_max_seq_len: int,
        labels: Optional[torch.LongTensor] = None,
    ):
        hidden_states = self.embed_tokens(input_ids)
        hidden_states = self.dropout(hidden_states) if self.training else hidden_states
        for layer_module in self.layers:
            hidden_states = layer_module(
                hidden_states,
                cu_seq_lens,
                max_seq_len,
                encoder_hidden_states,
                encoder_cu_seq_lens,
                encoder_max_seq_len,
                cross_cu_seq_lens,
                cross_max_seq_len,
            )
        hidden_states = self.layer_norm(hidden_states)
        if self.training:
            hidden_states = self.dropout(hidden_states)
        logits = self.lm_head(hidden_states)
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")
            return loss_fct(logits, labels)
        return logits

    def predict(
        self,
        input_ids: torch.LongTensor,  # (batch_size, max_seq_len)
        decoder_k_cache: torch.Tensor,  # [(batch_size_cache, seqlen_cache, nheads_k, headdim)]
        decoder_v_cache: torch.Tensor,  # [(batch_size_cache, seqlen_cache, nheads_k, headdim)]
        decoder_cache_seqlens: torch.IntTensor,  # (batch_size_cache)
        encoder_k_cache: torch.Tensor,  # [(batch_size_cache, seqlen_cache, nheads_k, headdim)]
        encoder_v_cache: torch.Tensor,  # [(batch_size_cache, seqlen_cache, nheads_k, headdim)]
        encoder_cache_seqlens: torch.IntTensor,  # (batch_size_cache)
        encoder_cache_batch_idx: torch.IntTensor,  # (batch_size_cache)
    ):
        hidden_states = self.embed_tokens(input_ids)
        for i, layer_module in enumerate(self.layers):
            hidden_states = layer_module.predict(
                hidden_states,
                decoder_k_cache[i],
                decoder_v_cache[i],
                decoder_cache_seqlens,
                encoder_k_cache[i],
                encoder_v_cache[i],
                encoder_cache_seqlens,
                encoder_cache_batch_idx,
            )
        hidden_states = self.layer_norm(hidden_states)
        logits = self.lm_head(hidden_states)
        return logits

    def get_decoder_kv_cache(self, device, batch_size, seq_len, dtype=torch.float32):
        k_cache_layers = []
        v_cache_layers = []
        for _ in self.layers:
            k_cache_layers.append(
                torch.empty(
                    (batch_size, seq_len, self.config.text_attention_heads, self.config.text_attention_kv_size),
                    device=device, dtype=dtype,
                )
            )
            v_cache_layers.append(
                torch.empty(
                    (batch_size, seq_len, self.config.text_attention_heads, self.config.text_attention_kv_size),
                    device=device, dtype=dtype,
                )
            )
        return k_cache_layers, v_cache_layers

    def get_encoder_kv_cache(self, encoder_hidden_states: torch.Tensor):
        k_cache_layers = []
        v_cache_layers = []
        for layer in self.layers:
            k, v = layer.get_encoder_kv_cache(encoder_hidden_states)
            k_cache_layers.append(k)
            v_cache_layers.append(v)
        return k_cache_layers, v_cache_layers


class Pix2StructTextLayer(nn.Module):
    def __init__(self, config, alibi=False):
        super().__init__()
        self.self_attention = Pix2StructTextSelfAttention(config, alibi=alibi)
        self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(config)
        self.mlp = Pix2StructTextMLP(config)

    def forward(
        self,
        hidden_states,
        cu_seq_lens,
        max_seq_len,
        encoder_hidden_states,
        encoder_cu_seq_lens,
        encoder_max_seq_len,
        cross_cu_seq_lens,
        cross_max_seq_len,
    ):
        hidden_states = self.self_attention(
            hidden_states=hidden_states,
            cu_seq_lens=cu_seq_lens,
            max_seq_len=max_seq_len,
        )
        do_cross_attention = encoder_hidden_states is not None
        if do_cross_attention:
            hidden_states = self.encoder_decoder_attention(
                hidden_states=hidden_states,
                cu_seq_lens=cross_cu_seq_lens,
                max_seq_len=cross_max_seq_len,
                kv_hidden_states=encoder_hidden_states,
                kv_cu_seq_lens=encoder_cu_seq_lens,
                kv_max_seq_len=encoder_max_seq_len,
            )
        hidden_states = self.mlp(hidden_states)
        return hidden_states

    def predict(
        self,
        hidden_states,
        decoder_k_cache,
        decoder_v_cache,
        decoder_cache_seqlens,
        encoder_k_cache,
        encoder_v_cache,
        encoder_cache_seqlens,
        encoder_cache_batch_idx,
    ):
        hidden_states = self.self_attention.predict(
            hidden_states,
            decoder_k_cache,
            decoder_v_cache,
            decoder_cache_seqlens,
        )
        do_cross_attention = encoder_k_cache is not None
        if do_cross_attention:
            hidden_states = self.encoder_decoder_attention.predict(
                hidden_states,
                encoder_k_cache,
                encoder_v_cache,
                encoder_cache_seqlens,
                encoder_cache_batch_idx,
            )
        hidden_states = self.mlp(hidden_states)
        return hidden_states

    def get_encoder_kv_cache(self, encoder_hidden_states):
        return self.encoder_decoder_attention.get_kv_cache(encoder_hidden_states)


class Pix2StructTextSelfAttention(nn.Module):
    def __init__(self, config: Pix2StructConfig, alibi=False):
        super().__init__()
        self.config = config
        self.alibi = alibi
        if alibi:
            self.alibi_slopes = generate_alibi_slopes(config.text_attention_heads)
        self.inner_dim = config.text_attention_kv_size * config.text_attention_heads
        self.pre_layer_norm = RMSNorm(config.text_hidden_size, eps=config.rms_norm_eps)
        self.query = nn.Linear(config.text_hidden_size, self.inner_dim, bias=False)
        self.key = nn.Linear(config.text_hidden_size, self.inner_dim, bias=False)
        self.value = nn.Linear(config.text_hidden_size, self.inner_dim, bias=False)
        self.output = nn.Linear(self.inner_dim, config.text_hidden_size, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        cu_seq_lens,
        max_seq_len,
    ):
        normed_hidden_states = self.pre_layer_norm(hidden_states)
        query_states = self.to_projection_shape(self.query(normed_hidden_states))
        key_states = self.to_projection_shape(self.key(normed_hidden_states))
        value_states = self.to_projection_shape(self.value(normed_hidden_states))
        attention_output = flash_attn_varlen_func(
            q=query_states,
            k=key_states,
            v=value_states,
            cu_seqlens_q=cu_seq_lens,
            cu_seqlens_k=cu_seq_lens,
            max_seqlen_q=max_seq_len,
            max_seqlen_k=max_seq_len,
            dropout_p=self.config.dropout_rate if self.training else 0.0,
            causal=True,
            alibi_slopes=self.alibi_slopes.to(query_states.device) if self.alibi else None,
        )
        attention_output = attention_output.contiguous().view(-1, self.inner_dim)
        attention_output = self.output(attention_output)
        return hidden_states + (self.dropout(attention_output) if self.training else attention_output)

    def predict(
        self,
        hidden_states,
        decoder_k_cache,
        decoder_v_cache,
        decoder_cache_seqlens,
    ):
        normed_hidden_states = self.pre_layer_norm(hidden_states)
        query_states = self.to_projection_shape(self.query(normed_hidden_states))
        key_states = self.to_projection_shape(self.key(normed_hidden_states))
        value_states = self.to_projection_shape(self.value(normed_hidden_states))
        attention_output = flash_attn_with_kvcache(
            q=query_states,
            k_cache=decoder_k_cache,
            v_cache=decoder_v_cache,
            k=key_states,
            v=value_states,
            cache_seqlens=decoder_cache_seqlens,
            causal=True,
            alibi_slopes=self.alibi_slopes.to(query_states.device) if self.alibi else None,
        )
        attention_output = attention_output.contiguous().view(*hidden_states.shape)
        attention_output = self.output(attention_output)
        return hidden_states + attention_output

    def to_projection_shape(self, states):
        return states.contiguous().view(
            *states.shape[:-1], self.config.text_attention_heads, self.config.text_attention_kv_size
        )


class Pix2StructTextLayerCrossAttention(nn.Module):
    def __init__(self, config: Pix2StructConfig):
        super().__init__()
        self.config = config
        self.inner_dim = config.text_attention_kv_size * config.text_attention_heads
        self.pre_layer_norm = RMSNorm(config.text_hidden_size, eps=config.rms_norm_eps)
        self.query = nn.Linear(config.text_hidden_size, self.inner_dim, bias=False)
        self.key = nn.Linear(config.vision_hidden_size, self.inner_dim, bias=False)
        self.value = nn.Linear(config.vision_hidden_size, self.inner_dim, bias=False)
        self.output = nn.Linear(self.inner_dim, config.text_hidden_size, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        cu_seq_lens,
        max_seq_len,
        kv_hidden_states,
        kv_cu_seq_lens,
        kv_max_seq_len,
    ):
        normed_hidden_states = self.pre_layer_norm(hidden_states)
        query_states = self.to_projection_shape(self.query(normed_hidden_states))
        key_states = self.to_projection_shape(self.key(kv_hidden_states))
        value_states = self.to_projection_shape(self.value(kv_hidden_states))
        attention_output = flash_attn_varlen_func(
            q=query_states,
            k=key_states,
            v=value_states,
            cu_seqlens_q=cu_seq_lens,
            cu_seqlens_k=kv_cu_seq_lens,
            max_seqlen_q=max_seq_len,
            max_seqlen_k=kv_max_seq_len,
            dropout_p=self.config.dropout_rate if self.training else 0.0,
            causal=False,
        )
        attention_output = attention_output.contiguous().view(hidden_states.shape[0], self.inner_dim)
        attention_output = self.output(attention_output)
        hidden_states = hidden_states + (self.dropout(attention_output) if self.training else attention_output)
        return hidden_states

    def predict(
        self,
        hidden_states,
        encoder_k_cache,
        encoder_v_cache,
        encoder_cache_seqlens,
        encoder_cache_batch_idx,
    ):
        normed_hidden_states = self.pre_layer_norm(hidden_states)
        query_states = self.to_projection_shape(self.query(normed_hidden_states))
        attention_output = flash_attn_with_kvcache(
            q=query_states,
            k_cache=encoder_k_cache,
            v_cache=encoder_v_cache,
            cache_seqlens=encoder_cache_seqlens,
            cache_batch_idx=encoder_cache_batch_idx,
            causal=False,
        )
        attention_output = attention_output.contiguous().view(*hidden_states.shape)
        attention_output = self.output(attention_output)
        return hidden_states + attention_output

    def get_kv_cache(self, states):
        return self.to_projection_shape(self.key(states)), self.to_projection_shape(self.value(states))

    def to_projection_shape(self, states: torch.Tensor) -> torch.Tensor:
        return states.contiguous().view(
            *states.shape[:-1], self.config.text_attention_heads, self.config.text_attention_kv_size
        )


class Pix2StructTextMLP(nn.Module):
    def __init__(self, config: Pix2StructConfig):
        super().__init__()
        self.pre_layer_norm = RMSNorm(config.text_hidden_size, eps=config.rms_norm_eps)
        self.wi_0 = nn.Linear(config.text_hidden_size, config.text_dense_act_ff_size, bias=False)
        self.wi_1 = nn.Linear(config.text_hidden_size, config.text_dense_act_ff_size, bias=False)
        self.wo = nn.Linear(config.text_dense_act_ff_size, config.text_hidden_size, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        # Gated GELU (GeGLU): gelu(wi_0(x), tanh) * wi_1(x), fused by liger_kernel when available.
        residual = hidden_states
        hidden_states = self.pre_layer_norm(hidden_states)
        hidden_gelu = self.wi_0(hidden_states)
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = GELUMulFunction.apply(hidden_gelu, hidden_linear)
        if self.training:
            hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        hidden_states = residual + (self.dropout(hidden_states) if self.training else hidden_states)
        return hidden_states


def generate_alibi_slopes(num_heads):
    # Standard ALiBi slopes for up to 8 heads; beyond that the geometric
    # sequence is simply continued with the same ratio.
    def get_slopes_power_of_two(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        ratio = start
        return [start * ratio ** i for i in range(n)]

    if num_heads <= 8:
        slopes = get_slopes_power_of_two(num_heads)
    else:
        slopes = get_slopes_power_of_two(8)
        for _ in range(8, num_heads):
            slopes.append(slopes[-1] * slopes[0])
    return torch.tensor(slopes, dtype=torch.float32)
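

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): a minimal smoke test
# of the packed/varlen training path, assuming a CUDA device with flash-attn
# installed, fp16 weights, and the default 'index' page mode (three leading
# page/row/column index columns per flattened patch). The layer counts,
# sequence lengths, and tensor contents below are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    cfg = Pix2StructConfig(vision_layers=2, text_layers=2)
    net = Pix2StructModel(cfg).cuda().half().eval()

    # Two images packed into one varlen batch: 5 and 7 patches.
    patch_seq_lens = [5, 7]
    patches = torch.zeros(
        sum(patch_seq_lens), 3 + cfg.vision_patch_size, device='cuda', dtype=torch.half
    )
    patch_cu = torch.tensor([0, 5, 12], device='cuda', dtype=torch.int32)

    # Two decoder sequences of 3 tokens each, packed the same way.
    token_seq_lens = [3, 3]
    token_ids = torch.randint(0, cfg.text_vocab_size, (sum(token_seq_lens),), device='cuda')
    token_cu = torch.tensor([0, 3, 6], device='cuda', dtype=torch.int32)

    logits = net(
        flattened_patches=patches,
        flattened_patches_cu_seq_lens=patch_cu,
        flattened_patches_max_seq_len=max(patch_seq_lens),
        decoder_input_ids=token_ids,
        decoder_cu_seq_lens=token_cu,
        decoder_max_seq_len=max(token_seq_lens),
        decoder_cross_cu_seq_lens=token_cu,
        decoder_cross_max_seq_len=max(token_seq_lens),
    )
    print(logits.shape)  # expected: (sum of decoder token lengths, text_vocab_size)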