import dataclasses
import json
import math
import os
from typing import Optional, Literal
import safetensors.torch
import torch
from torch import nn
from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
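# Prefer Liger's fused RMSNorm / GEGLU kernels when liger_kernel is installed;
# otherwise fall back to torch.nn.RMSNorm and an equivalent unfused
# tanh-GELU * linear gate exposed through the same .apply() interface.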
try:
from liger_kernel.transformers.rms_norm import LigerRMSNorm as RMSNorm
from liger_kernel.ops.geglu import LigerGELUMulFunction as GELUMulFunction
except ModuleNotFoundError:
from torch.nn import RMSNorm
class GELUMulFunction:
@staticmethod
def apply(hidden_gelu, hidden_linear):
return torch.nn.functional.gelu(hidden_gelu, approximate='tanh') * hidden_linear
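# Configuration for both towers: the Pix2Struct vision (patch) encoder and the
# T5-style text decoder. vision_page_mode controls the layout of each flattened
# patch row: 'index' expects page/row/column indices in the first three columns
# (and adds a learned page embedding), 'concat' expects only row/column indices.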
@dataclasses.dataclass
class Pix2StructConfig:
vision_patch_size: int = 768
vision_hidden_size: int = 768
vision_mlp_ff_size: int = 2048
vision_layers: int = 12
vision_attention_kv_size: int = 64
vision_attention_heads: int = 12
vision_max_rows: int = 4096
vision_max_columns: int = 4096
vision_page_mode: Literal['concat', 'index'] = 'index'
vision_max_pages: int = 4096
text_vocab_size: int = 50265
text_layers: int = 12
text_hidden_size: int = 768
text_dense_act_ff_size: int = 2048
text_attention_kv_size: int = 64
text_attention_heads: int = 12
rms_norm_eps: float = 1e-6
dropout_rate: float = 0.1
@staticmethod
def from_transformers(config):
return Pix2StructConfig(
vision_patch_size=config.vision_config.patch_embed_hidden_size,
vision_hidden_size=config.vision_config.hidden_size,
vision_mlp_ff_size=config.vision_config.d_ff,
vision_layers=config.vision_config.num_hidden_layers,
vision_attention_kv_size=config.vision_config.d_kv,
vision_attention_heads=config.vision_config.num_attention_heads,
vision_max_rows=config.vision_config.seq_len,
vision_max_columns=config.vision_config.seq_len,
vision_page_mode='index',
vision_max_pages=4096,
text_vocab_size=config.text_config.vocab_size,
text_layers=config.text_config.num_layers,
text_hidden_size=config.text_config.hidden_size,
text_dense_act_ff_size=config.text_config.d_ff,
text_attention_kv_size=config.text_config.d_kv,
text_attention_heads=config.text_config.num_heads,
rms_norm_eps=config.vision_config.layer_norm_eps,
dropout_rate=config.vision_config.dropout_rate,
)
class Pix2StructModel(nn.Module):
def __init__(self, config: Pix2StructConfig):
super().__init__()
self.config = config
self.encoder = Pix2StructVisionEncoder(config)
self.decoder = Pix2StructTextDecoder(config)
def forward(
self,
flattened_patches: torch.Tensor,
flattened_patches_cu_seq_lens: torch.LongTensor,
flattened_patches_max_seq_len: int,
decoder_input_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
decoder_cu_seq_lens: Optional[torch.LongTensor] = None,
decoder_max_seq_len: Optional[int] = None,
decoder_cross_cu_seq_lens: Optional[torch.LongTensor] = None,
decoder_cross_max_seq_len: Optional[int] = None,
):
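        # Unpadded ("varlen") layout: all sequences are packed along dim 0 and the
        # *_cu_seq_lens tensors hold cumulative sequence offsets (batch_size + 1
        # entries), matching the flash_attn_varlen_func convention.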
encoder_hidden_states = self.encoder(
flattened_patches=flattened_patches,
cu_seq_lens=flattened_patches_cu_seq_lens,
max_seq_len=flattened_patches_max_seq_len,
)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
labels=labels,
cu_seq_lens=decoder_cu_seq_lens,
max_seq_len=decoder_max_seq_len,
encoder_hidden_states=encoder_hidden_states,
encoder_cu_seq_lens=flattened_patches_cu_seq_lens,
encoder_max_seq_len=flattened_patches_max_seq_len,
cross_cu_seq_lens=decoder_cross_cu_seq_lens,
cross_max_seq_len=decoder_cross_max_seq_len,
)
return decoder_outputs
def get_encoder_kv_cache(
self,
flattened_patches: torch.Tensor,
flattened_patches_cu_seq_lens: torch.LongTensor,
flattened_patches_max_seq_len: int,
):
encoder_hidden_states = self.encoder(
flattened_patches=flattened_patches,
cu_seq_lens=flattened_patches_cu_seq_lens,
max_seq_len=flattened_patches_max_seq_len,
)
encoder_hidden_states_batch = []
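        # Re-pad the packed encoder output into a dense (batch, max_seq_len, hidden)
        # tensor so it can be projected into per-layer cross-attention K/V caches
        # for flash_attn_with_kvcache.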
for i in range(1, flattened_patches_cu_seq_lens.shape[0]):
sub_seq_len = flattened_patches_cu_seq_lens[i] - flattened_patches_cu_seq_lens[i - 1]
encoder_hidden_states_batch.append(
torch.nn.functional.pad(
encoder_hidden_states[flattened_patches_cu_seq_lens[i - 1]:flattened_patches_cu_seq_lens[i]],
(0, 0, 0, flattened_patches_max_seq_len - sub_seq_len),
)
)
encoder_hidden_states_batch = torch.stack(encoder_hidden_states_batch)
encoder_k_cache, encoder_v_cache = self.decoder.get_encoder_kv_cache(encoder_hidden_states_batch)
return {
'encoder_k_cache': encoder_k_cache,
'encoder_v_cache': encoder_v_cache,
'encoder_cache_seqlens': flattened_patches_cu_seq_lens[1:] - flattened_patches_cu_seq_lens[:-1],
}
def save(self, model_folder):
os.makedirs(model_folder, exist_ok=True)
safetensors.torch.save_file(self.state_dict(), model_folder + '/model.safetensors')
with open(model_folder + '/config.json', 'w') as f:
json.dump(dataclasses.asdict(self.config), f)
@staticmethod
def load(model_folder):
if model_folder.startswith('transformers/'):
return Pix2StructModel.from_transformers(model_folder.replace('transformers/', ''))
with open(model_folder + '/config.json', 'r') as f:
config = Pix2StructConfig(**json.load(f))
model = Pix2StructModel(config)
model.load_state_dict(safetensors.torch.load_file(model_folder + '/model.safetensors'))
return model
@staticmethod
def from_transformers(model_id):
import transformers
donor_config = transformers.Pix2StructConfig.from_pretrained(model_id)
donor_model = transformers.Pix2StructForConditionalGeneration.from_pretrained(model_id)
config = Pix2StructConfig.from_transformers(donor_config)
model = Pix2StructModel(config)
weights = donor_model.state_dict()
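        # Map parameter names: keys are this module's names, values are the
        # corresponding names in the Hugging Face Pix2Struct checkpoint.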
mapping = {
'encoder.embeddings.patch_projection.weight': 'encoder.embeddings.patch_projection.weight',
'encoder.embeddings.patch_projection.bias': 'encoder.embeddings.patch_projection.bias',
'encoder.embeddings.row_embedder.weight': 'encoder.embeddings.row_embedder.weight',
'encoder.embeddings.column_embedder.weight': 'encoder.embeddings.column_embedder.weight',
'encoder.layer_norm.weight': 'encoder.layernorm.weight',
'decoder.embed_tokens.weight': 'decoder.embed_tokens.weight',
'decoder.layer_norm.weight': 'decoder.final_layer_norm.weight',
'decoder.lm_head.weight': 'decoder.lm_head.weight'
}
for vision_layer_idx in range(config.vision_layers):
prefix = f'encoder.layers.{vision_layer_idx}'
s_prefix = f'encoder.encoder.layer.{vision_layer_idx}'
mapping[f'{prefix}.attention.pre_layer_norm.weight'] = f'{s_prefix}.pre_attention_layer_norm.weight'
mapping[f'{prefix}.attention.query.weight'] = f'{s_prefix}.attention.query.weight'
mapping[f'{prefix}.attention.key.weight'] = f'{s_prefix}.attention.key.weight'
mapping[f'{prefix}.attention.value.weight'] = f'{s_prefix}.attention.value.weight'
mapping[f'{prefix}.attention.output.weight'] = f'{s_prefix}.attention.output.weight'
mapping[f'{prefix}.mlp.pre_layer_norm.weight'] = f'{s_prefix}.pre_mlp_layer_norm.weight'
mapping[f'{prefix}.mlp.wi_0.weight'] = f'{s_prefix}.mlp.wi_0.weight'
mapping[f'{prefix}.mlp.wi_1.weight'] = f'{s_prefix}.mlp.wi_1.weight'
mapping[f'{prefix}.mlp.wo.weight'] = f'{s_prefix}.mlp.wo.weight'
for vision_layer_idx in range(config.text_layers):
prefix = f'decoder.layers.{vision_layer_idx}'
s_prefix = f'decoder.layer.{vision_layer_idx}'
mapping[f'{prefix}.encoder_decoder_attention.pre_layer_norm.weight'] = f'{s_prefix}.encoder_decoder_attention.layer_norm.weight'
mapping[f'{prefix}.encoder_decoder_attention.query.weight'] = f'{s_prefix}.encoder_decoder_attention.attention.query.weight'
mapping[f'{prefix}.encoder_decoder_attention.key.weight'] = f'{s_prefix}.encoder_decoder_attention.attention.key.weight'
mapping[f'{prefix}.encoder_decoder_attention.value.weight'] = f'{s_prefix}.encoder_decoder_attention.attention.value.weight'
mapping[f'{prefix}.encoder_decoder_attention.output.weight'] = f'{s_prefix}.encoder_decoder_attention.attention.output.weight'
mapping[f'{prefix}.self_attention.pre_layer_norm.weight'] = f'{s_prefix}.self_attention.layer_norm.weight'
mapping[f'{prefix}.self_attention.query.weight'] = f'{s_prefix}.self_attention.attention.query.weight'
mapping[f'{prefix}.self_attention.key.weight'] = f'{s_prefix}.self_attention.attention.key.weight'
mapping[f'{prefix}.self_attention.value.weight'] = f'{s_prefix}.self_attention.attention.value.weight'
mapping[f'{prefix}.self_attention.output.weight'] = f'{s_prefix}.self_attention.attention.output.weight'
mapping[f'{prefix}.mlp.pre_layer_norm.weight'] = f'{s_prefix}.mlp.layer_norm.weight'
mapping[f'{prefix}.mlp.wi_0.weight'] = f'{s_prefix}.mlp.DenseReluDense.wi_0.weight'
mapping[f'{prefix}.mlp.wi_1.weight'] = f'{s_prefix}.mlp.DenseReluDense.wi_1.weight'
mapping[f'{prefix}.mlp.wo.weight'] = f'{s_prefix}.mlp.DenseReluDense.wo.weight'
state_dict = {}
for target_key, source_key in mapping.items():
state_dict[target_key] = weights[source_key]
del weights[source_key]
load_info = model.load_state_dict(state_dict, strict=False)
if len(load_info.missing_keys) > 0:
print('The following keys are missing', load_info.missing_keys)
if 'encoder.embeddings.page_embedder.weight' in load_info.missing_keys:
model.encoder.embeddings.page_embedder.weight.data.normal_(mean=0.0, std=0.02)
return model
# VISION
class Pix2StructVisionEncoder(nn.Module):
def __init__(self, config: Pix2StructConfig):
super().__init__()
self.config = config
self.embeddings = Pix2StructVisionEmbeddings(config)
self.layers = nn.ModuleList([Pix2StructVisionLayer(config) for _ in range(config.vision_layers)])
self.layer_norm = RMSNorm(config.vision_hidden_size, eps=config.rms_norm_eps)
def forward(
self,
flattened_patches: torch.Tensor, # (sum_seq_lens, vision_patch_size)
        cu_seq_lens: torch.Tensor, # (seq_count + 1)
max_seq_len: int,
) -> torch.Tensor:
hidden_states = self.embeddings(flattened_patches)
for i, layer_module in enumerate(self.layers):
hidden_states = layer_module(hidden_states, cu_seq_lens, max_seq_len)
hidden_states = self.layer_norm(hidden_states)
return hidden_states
class Pix2StructVisionLayer(nn.Module):
def __init__(self, config: Pix2StructConfig) -> None:
super().__init__()
self.attention = Pix2StructVisionAttention(config)
self.mlp = Pix2StructVisionMLP(config)
def forward(
self,
hidden_states: torch.Tensor, # (sum_seq_lens, vision_hidden_size)
        cu_seq_lens: torch.Tensor, # (seq_count + 1)
max_seq_len: int,
) -> torch.Tensor: # (sum_seq_lens, vision_hidden_size)
hidden_states = hidden_states + self.attention(
hidden_states=hidden_states,
cu_seq_lens=cu_seq_lens,
max_seq_len=max_seq_len,
)
hidden_states = hidden_states + self.mlp(hidden_states)
return hidden_states
class Pix2StructVisionAttention(nn.Module):
def __init__(self, config: Pix2StructConfig):
super().__init__()
self.config = config
self.inner_dim = config.vision_attention_kv_size * config.vision_attention_heads
self.pre_layer_norm = RMSNorm(config.vision_hidden_size, eps=config.rms_norm_eps)
self.query = nn.Linear(config.vision_hidden_size, self.inner_dim, bias=False)
self.key = nn.Linear(config.vision_hidden_size, self.inner_dim, bias=False)
self.value = nn.Linear(config.vision_hidden_size, self.inner_dim, bias=False)
self.output = nn.Linear(self.inner_dim, config.vision_hidden_size, bias=False)
def forward(self, hidden_states, cu_seq_lens, max_seq_len):
sum_seq_lens = hidden_states.size(0)
def to_projection_shape(states: torch.Tensor) -> torch.Tensor: # (sum_seq_lens, n_heads, dim_per_head)
return states.contiguous().view(
sum_seq_lens,
self.config.vision_attention_heads,
self.config.vision_attention_kv_size
)
hidden_states = self.pre_layer_norm(hidden_states)
query_states = to_projection_shape(self.query(hidden_states))
key_states = to_projection_shape(self.key(hidden_states))
value_states = to_projection_shape(self.value(hidden_states))
attn_output = flash_attn_varlen_func(
q=query_states,
k=key_states,
v=value_states,
cu_seqlens_q=cu_seq_lens,
cu_seqlens_k=cu_seq_lens,
max_seqlen_q=max_seq_len,
max_seqlen_k=max_seq_len,
dropout_p=self.config.dropout_rate if self.training else 0.0,
causal=False,
)
attn_output = attn_output.contiguous().view(sum_seq_lens, self.inner_dim)
attn_output = self.output(attn_output)
return attn_output
class Pix2StructVisionMLP(nn.Module):
def __init__(self, config: Pix2StructConfig):
super().__init__()
self.pre_layer_norm = RMSNorm(config.vision_hidden_size, eps=config.rms_norm_eps)
self.wi_0 = nn.Linear(config.vision_hidden_size, config.vision_mlp_ff_size, bias=False)
self.wi_1 = nn.Linear(config.vision_hidden_size, config.vision_mlp_ff_size, bias=False)
self.wo = nn.Linear(config.vision_mlp_ff_size, config.vision_hidden_size, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
self.act = nn.GELU(approximate='tanh')
def forward(self, hidden_states):
hidden_states = self.pre_layer_norm(hidden_states)
hidden_gelu = self.wi_0(hidden_states)
hidden_linear = self.wi_1(hidden_states)
hidden_states = GELUMulFunction.apply(hidden_gelu, hidden_linear)
# hidden_gelu = self.act(self.wi_0(hidden_states))
# hidden_linear = self.wi_1(hidden_states)
# hidden_states = hidden_gelu * hidden_linear
if self.training:
hidden_states = self.dropout(hidden_states)
hidden_states = self.wo(hidden_states)
return hidden_states
class Pix2StructVisionEmbeddings(nn.Module):
def __init__(self, config: Pix2StructConfig) -> None:
super().__init__()
self.config = config
self.patch_projection = nn.Linear(config.vision_patch_size, config.vision_hidden_size)
if config.vision_page_mode == 'index':
self.page_embedder = nn.Embedding(config.vision_max_pages, config.vision_hidden_size)
else:
self.page_embedder = None
self.row_embedder = nn.Embedding(config.vision_max_rows, config.vision_hidden_size)
self.column_embedder = nn.Embedding(config.vision_max_columns, config.vision_hidden_size)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
flattened_patches: torch.Tensor, # (sum_seq_lens, vision_patch_size)
) -> torch.Tensor: # (sum_seq_lens, vision_hidden_size)
if self.config.vision_page_mode == 'index':
page_indices = flattened_patches[:, 0].long()
row_indices = flattened_patches[:, 1].long()
col_indices = flattened_patches[:, 2].long()
flattened_patches = flattened_patches[:, 3:]
else:
page_indices = None
row_indices = flattened_patches[:, 0].long()
col_indices = flattened_patches[:, 1].long()
flattened_patches = flattened_patches[:, 2:]
embeddings = self.patch_projection(flattened_patches)
row_embeddings = self.row_embedder(row_indices)
col_embeddings = self.column_embedder(col_indices)
if self.config.vision_page_mode == 'index':
page_embeddings = self.page_embedder(page_indices)
embeddings = embeddings + page_embeddings + row_embeddings + col_embeddings
else:
embeddings = embeddings + row_embeddings + col_embeddings
if self.training:
embeddings = self.dropout(embeddings)
return embeddings
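# Illustrative sketch (not part of the original code) of how inputs can be packed
# for the encoder when vision_page_mode == 'index'. `patches_list` is a
# hypothetical list of (num_patches_i, 3 + vision_patch_size) tensors whose first
# three columns are page, row and column indices:
#
#   seq_lens = torch.tensor([p.shape[0] for p in patches_list], dtype=torch.int32)
#   cu_seq_lens = torch.nn.functional.pad(torch.cumsum(seq_lens, 0), (1, 0)).to(torch.int32)
#   flattened_patches = torch.cat(patches_list, dim=0)
#   hidden_states = model.encoder(flattened_patches, cu_seq_lens, int(seq_lens.max()))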
# TEXT
class Pix2StructTextDecoder(nn.Module):
def __init__(self, config: Pix2StructConfig):
super().__init__()
self.config = config
self.embed_tokens = nn.Embedding(config.text_vocab_size, config.text_hidden_size)
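        # Positional bias (ALiBi slopes) is applied only in the first decoder layer,
        # standing in for the single relative-attention-bias layer of the reference
        # Pix2Struct/T5 decoder.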
self.layers = nn.ModuleList(
[Pix2StructTextLayer(config, alibi=bool(i == 0)) for i in range(config.text_layers)]
)
self.layer_norm = RMSNorm(config.text_hidden_size, eps=config.rms_norm_eps)
self.lm_head = nn.Linear(config.text_hidden_size, config.text_vocab_size, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
input_ids: torch.LongTensor,
cu_seq_lens: torch.LongTensor,
max_seq_len: int,
encoder_hidden_states: torch.Tensor,
encoder_cu_seq_lens: torch.LongTensor,
encoder_max_seq_len: int,
cross_cu_seq_lens: torch.LongTensor,
cross_max_seq_len: int,
labels: Optional[torch.LongTensor] = None,
):
hidden_states = self.embed_tokens(input_ids)
hidden_states = self.dropout(hidden_states) if self.training else hidden_states
for i, layer_module in enumerate(self.layers):
hidden_states = layer_module(
hidden_states,
cu_seq_lens,
max_seq_len,
encoder_hidden_states,
encoder_cu_seq_lens,
encoder_max_seq_len,
cross_cu_seq_lens,
cross_max_seq_len,
)
hidden_states = self.layer_norm(hidden_states)
if self.training:
hidden_states = self.dropout(hidden_states)
if labels is not None:
logits = self.lm_head(hidden_states)
loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")
loss = loss_fct(logits, labels)
return loss
else:
logits = self.lm_head(hidden_states)
return logits
def predict(
self,
input_ids: torch.LongTensor, # (batch_size, max_seq_len),
decoder_k_cache: torch.Tensor, # [(batch_size_cache, seqlen_cache, nheads_k, headdim)]
decoder_v_cache: torch.Tensor, # [(batch_size_cache, seqlen_cache, nheads_k, headdim)]
decoder_cache_seqlens: torch.IntTensor, # (batch_size_cache)
encoder_k_cache: torch.Tensor, # [(batch_size_cache, seqlen_cache, nheads_k, headdim)]
encoder_v_cache: torch.Tensor, # [(batch_size_cache, seqlen_cache, nheads_k, headdim)]
encoder_cache_seqlens: torch.IntTensor, # (batch_size_cache)
encoder_cache_batch_idx: torch.IntTensor, # (batch_size_cache)
):
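        # Single-step decoding: input_ids hold only the newest token per sequence;
        # flash_attn_with_kvcache appends the new K/V into the preallocated decoder
        # caches in place and attends over the precomputed encoder K/V caches.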
hidden_states = self.embed_tokens(input_ids)
for i, layer_module in enumerate(self.layers):
hidden_states = layer_module.predict(
hidden_states,
decoder_k_cache[i],
decoder_v_cache[i],
decoder_cache_seqlens,
encoder_k_cache[i],
encoder_v_cache[i],
encoder_cache_seqlens,
encoder_cache_batch_idx,
)
hidden_states = self.layer_norm(hidden_states)
logits = self.lm_head(hidden_states)
return logits
def get_decoder_kv_cache(self, device, batch_size, seq_len, dtype=torch.float32):
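        # Preallocate empty per-layer self-attention caches of shape
        # (batch_size, seq_len, text_attention_heads, text_attention_kv_size).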
k_cache_layers = []
v_cache_layers = []
for _ in self.layers:
k_cache_layers.append(
torch.empty(
(batch_size, seq_len, self.config.text_attention_heads, self.config.text_attention_kv_size),
device=device, dtype=dtype
)
)
v_cache_layers.append(
torch.empty(
(batch_size, seq_len, self.config.text_attention_heads, self.config.text_attention_kv_size),
device=device, dtype=dtype
)
)
return k_cache_layers, v_cache_layers
def get_encoder_kv_cache(self, encoder_hidden_states: torch.Tensor):
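        # Project the (padded, batched) encoder output into per-layer cross-attention
        # K/V once, so each decoding step only needs to compute queries.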
k_cache_layers = []
v_cache_layers = []
for layer in self.layers:
k, v = layer.get_encoder_kv_cache(encoder_hidden_states)
k_cache_layers.append(k)
v_cache_layers.append(v)
return k_cache_layers, v_cache_layers
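# Illustrative greedy decoding sketch (not part of the original code), assuming
# `decoder` is a Pix2StructTextDecoder, `enc` is the dict returned by
# Pix2StructModel.get_encoder_kv_cache, and `bos_id` is a hypothetical start token:
#
#   dec_k, dec_v = decoder.get_decoder_kv_cache(device, batch_size, max_new_tokens, dtype)
#   dec_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
#   tokens = torch.full((batch_size, 1), bos_id, dtype=torch.long, device=device)
#   for _ in range(max_new_tokens):
#       logits = decoder.predict(tokens, dec_k, dec_v, dec_seqlens,
#                                enc['encoder_k_cache'], enc['encoder_v_cache'],
#                                enc['encoder_cache_seqlens'].to(torch.int32), None)
#       tokens = logits[:, -1:].argmax(-1)
#       dec_seqlens += 1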
class Pix2StructTextLayer(nn.Module):
def __init__(self, config, alibi=False):
super().__init__()
self.self_attention = Pix2StructTextSelfAttention(config, alibi=alibi)
self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(config)
self.mlp = Pix2StructTextMLP(config)
def forward(
self,
hidden_states,
cu_seq_lens,
max_seq_len,
encoder_hidden_states,
encoder_cu_seq_lens,
encoder_max_seq_len,
cross_cu_seq_lens,
cross_max_seq_len,
):
hidden_states = self.self_attention(
hidden_states=hidden_states,
cu_seq_lens=cu_seq_lens,
max_seq_len=max_seq_len,
)
do_cross_attention = encoder_hidden_states is not None
if do_cross_attention:
hidden_states = self.encoder_decoder_attention(
hidden_states=hidden_states,
cu_seq_lens=cross_cu_seq_lens,
max_seq_len=cross_max_seq_len,
kv_hidden_states=encoder_hidden_states,
kv_cu_seq_lens=encoder_cu_seq_lens,
kv_max_seq_len=encoder_max_seq_len,
)
hidden_states = self.mlp(hidden_states)
return hidden_states
def predict(
self,
hidden_states,
decoder_k_cache,
decoder_v_cache,
decoder_cache_seqlens,
encoder_k_cache,
encoder_v_cache,
encoder_cache_seqlens,
encoder_cache_batch_idx,
):
hidden_states = self.self_attention.predict(
hidden_states,
decoder_k_cache,
decoder_v_cache,
decoder_cache_seqlens,
)
do_cross_attention = encoder_k_cache is not None
if do_cross_attention:
hidden_states = self.encoder_decoder_attention.predict(
hidden_states,
encoder_k_cache,
encoder_v_cache,
encoder_cache_seqlens,
encoder_cache_batch_idx,
)
hidden_states = self.mlp(hidden_states)
return hidden_states
def get_encoder_kv_cache(self, encoder_hidden_states):
return self.encoder_decoder_attention.get_kv_cache(encoder_hidden_states)
class Pix2StructTextSelfAttention(nn.Module):
def __init__(self, config: Pix2StructConfig, alibi=False):
super().__init__()
self.config = config
self.alibi = alibi
if alibi:
self.alibi_slopes = generate_alibi_slopes(config.text_attention_heads)
self.inner_dim = config.text_attention_kv_size * config.text_attention_heads
self.pre_layer_norm = RMSNorm(config.text_hidden_size, eps=config.rms_norm_eps)
self.query = nn.Linear(config.text_hidden_size, self.inner_dim, bias=False)
self.key = nn.Linear(config.text_hidden_size, self.inner_dim, bias=False)
self.value = nn.Linear(config.text_hidden_size, self.inner_dim, bias=False)
self.output = nn.Linear(self.inner_dim, config.text_hidden_size, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
hidden_states,
cu_seq_lens,
max_seq_len,
):
normed_hidden_states = self.pre_layer_norm(hidden_states)
query_states = self.to_projection_shape(self.query(normed_hidden_states))
key_states = self.to_projection_shape(self.key(normed_hidden_states))
value_states = self.to_projection_shape(self.value(normed_hidden_states))
attention_output = flash_attn_varlen_func(
q=query_states,
k=key_states,
v=value_states,
cu_seqlens_q=cu_seq_lens,
cu_seqlens_k=cu_seq_lens,
max_seqlen_q=max_seq_len,
max_seqlen_k=max_seq_len,
dropout_p=self.config.dropout_rate if self.training else 0.0,
causal=True,
alibi_slopes=self.alibi_slopes.to(query_states.device) if self.alibi else None,
)
attention_output = attention_output.contiguous().view(-1, self.inner_dim)
attention_output = self.output(attention_output)
return hidden_states + (self.dropout(attention_output) if self.training else attention_output)
def predict(
self,
hidden_states,
decoder_k_cache,
decoder_v_cache,
decoder_cache_seqlens,
):
normed_hidden_states = self.pre_layer_norm(hidden_states)
query_states = self.to_projection_shape(self.query(normed_hidden_states))
key_states = self.to_projection_shape(self.key(normed_hidden_states))
value_states = self.to_projection_shape(self.value(normed_hidden_states))
attention_output = flash_attn_with_kvcache(
q=query_states,
k_cache=decoder_k_cache,
v_cache=decoder_v_cache,
k=key_states,
v=value_states,
cache_seqlens=decoder_cache_seqlens,
causal=True,
alibi_slopes=self.alibi_slopes.to(query_states.device) if self.alibi else None,
)
attention_output = attention_output.contiguous().view(*hidden_states.shape)
attention_output = self.output(attention_output)
return hidden_states + attention_output
def to_projection_shape(self, states):
return states.contiguous().view(
*states.shape[:-1], self.config.text_attention_heads, self.config.text_attention_kv_size
)
class Pix2StructTextLayerCrossAttention(nn.Module):
def __init__(self, config: Pix2StructConfig):
super().__init__()
self.config = config
self.inner_dim = config.text_attention_kv_size * config.text_attention_heads
self.pre_layer_norm = RMSNorm(config.text_hidden_size, eps=config.rms_norm_eps)
self.query = nn.Linear(config.text_hidden_size, self.inner_dim, bias=False)
self.key = nn.Linear(config.vision_hidden_size, self.inner_dim, bias=False)
self.value = nn.Linear(config.vision_hidden_size, self.inner_dim, bias=False)
self.output = nn.Linear(self.inner_dim, config.text_hidden_size, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
hidden_states,
cu_seq_lens,
max_seq_len,
kv_hidden_states,
kv_cu_seq_lens,
kv_max_seq_len,
):
normed_hidden_states = self.pre_layer_norm(hidden_states)
query_states = self.to_projection_shape(self.query(normed_hidden_states))
key_states = self.to_projection_shape(self.key(kv_hidden_states))
value_states = self.to_projection_shape(self.value(kv_hidden_states))
attention_output = flash_attn_varlen_func(
q=query_states,
k=key_states,
v=value_states,
cu_seqlens_q=cu_seq_lens,
cu_seqlens_k=kv_cu_seq_lens,
max_seqlen_q=max_seq_len,
max_seqlen_k=kv_max_seq_len,
dropout_p=self.config.dropout_rate if self.training else 0.0,
causal=False,
)
attention_output = attention_output.contiguous().view(hidden_states.shape[0], self.inner_dim)
attention_output = self.output(attention_output)
hidden_states = hidden_states + (self.dropout(attention_output) if self.training else attention_output)
return hidden_states
def predict(
self,
hidden_states,
encoder_k_cache,
encoder_v_cache,
encoder_cache_seqlens,
encoder_cache_batch_idx,
):
normed_hidden_states = self.pre_layer_norm(hidden_states)
query_states = self.to_projection_shape(self.query(normed_hidden_states))
attention_output = flash_attn_with_kvcache(
q=query_states,
k_cache=encoder_k_cache,
v_cache=encoder_v_cache,
cache_seqlens=encoder_cache_seqlens,
cache_batch_idx=encoder_cache_batch_idx,
causal=False,
)
attention_output = attention_output.contiguous().view(*hidden_states.shape)
attention_output = self.output(attention_output)
return hidden_states + attention_output
def get_kv_cache(self, states):
return self.to_projection_shape(self.key(states)), self.to_projection_shape(self.value(states))
def to_projection_shape(self, states: torch.Tensor) -> torch.Tensor:
return states.contiguous().view(
*states.shape[:-1], self.config.text_attention_heads, self.config.text_attention_kv_size
)
class Pix2StructTextMLP(nn.Module):
def __init__(self, config: Pix2StructConfig):
super().__init__()
self.pre_layer_norm = RMSNorm(config.text_hidden_size, eps=config.rms_norm_eps)
self.wi_0 = nn.Linear(config.text_hidden_size, config.text_dense_act_ff_size, bias=False)
self.wi_1 = nn.Linear(config.text_hidden_size, config.text_dense_act_ff_size, bias=False)
self.wo = nn.Linear(config.text_dense_act_ff_size, config.text_hidden_size, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
self.act = nn.GELU(approximate='tanh')
def forward(self, hidden_states):
residual = hidden_states
hidden_states = self.pre_layer_norm(hidden_states)
hidden_gelu = self.wi_0(hidden_states)
hidden_linear = self.wi_1(hidden_states)
hidden_states = GELUMulFunction.apply(hidden_gelu, hidden_linear)
# hidden_gelu = self.act(self.wi_0(hidden_states))
# hidden_linear = self.wi_1(hidden_states)
# hidden_states = hidden_gelu * hidden_linear
if self.training:
hidden_states = self.dropout(hidden_states)
hidden_states = self.wo(hidden_states)
hidden_states = residual + (self.dropout(hidden_states) if self.training else hidden_states)
return hidden_states
def generate_alibi_slopes(num_heads):
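    # Per-head ALiBi slopes: a geometric progression 2^(-8/n), 2^(-16/n), ... as in
    # the ALiBi paper. For more than 8 heads the first-8 progression is simply
    # extended with the same ratio (a simplification of the paper's interleaving
    # scheme for non-power-of-two head counts).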
def get_slopes_power_of_two(n):
start = (2 ** (-2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio ** i for i in range(n)]
if num_heads <= 8:
slopes = get_slopes_power_of_two(num_heads)
else:
slopes = get_slopes_power_of_two(8)
for i in range(8, num_heads):
slopes.append(slopes[-1] * slopes[0])
return torch.tensor(slopes, dtype=torch.float32)