import torch
from torch import nn
from transformers.activations import ACT2FN


class TokenCompressionAdapter(nn.Module):
    """Compresses a variable-length token sequence into a fixed number of tokens.

    A set of learned query vectors cross-attends to the input hidden states, so
    the output always has `num_compressed_tokens` tokens regardless of the input
    length. The attended tokens are refined by an MLP with a residual connection
    and projected to `output_size`.
    """

    def __init__(
        self,
        num_compressed_tokens: int,
        hidden_size: int,
        intermediate_size: int,
        output_size: int,
        hidden_act: str,
        num_attention_heads: int,
        layer_norm_eps: float,
    ):
        super().__init__()
        # Learned queries: a single fixed set, shared across the batch.
        self.query = nn.Parameter(torch.randn(1, num_compressed_tokens, hidden_size))
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_attention_heads,
            batch_first=True,
        )
        self.layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.mlp = MLP(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            hidden_act=hidden_act,
        )
        self.projection = nn.Linear(hidden_size, output_size)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        # hidden_state: (batch_size, seq_len, hidden_size)
        batch_size = hidden_state.shape[0]
        # Broadcast the learned queries along the batch dimension.
        query = self.query.repeat(batch_size, 1, 1)
        key = self.key(hidden_state)
        value = self.value(hidden_state)
        # Cross-attention: the queries attend over the full input sequence,
        # yielding (batch_size, num_compressed_tokens, hidden_size).
        hidden_state = self.attention(query, key, value)[0]
        # Pre-norm MLP block with a residual connection.
        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = self.mlp(hidden_state)
        hidden_state = residual + hidden_state
        # Project to the target dimension, e.g. a downstream model's hidden size.
        hidden_state = self.projection(hidden_state)
        return hidden_state


class MLP(nn.Module):
    """Feed-forward block with two hidden layers and a configurable activation."""

    def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
        super().__init__()
        self.activation_fn = ACT2FN[hidden_act]
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc1_5 = nn.Linear(intermediate_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc1_5(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states
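

# Minimal usage sketch (not part of the original module). The hyperparameters
# below are illustrative assumptions only, not values taken from any specific
# model configuration; the point is to show the expected shapes end to end.
if __name__ == "__main__":
    adapter = TokenCompressionAdapter(
        num_compressed_tokens=64,   # compress any input sequence to 64 tokens
        hidden_size=768,
        intermediate_size=3072,
        output_size=1024,           # e.g. the hidden size of a downstream model
        hidden_act="gelu",
        num_attention_heads=12,
        layer_norm_eps=1e-6,
    )
    # Dummy encoder output: batch of 2 sequences, 576 tokens, width 768.
    features = torch.randn(2, 576, 768)
    compressed = adapter(features)
    print(compressed.shape)  # torch.Size([2, 64, 1024])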