shahafw committed
Commit bfc1785 · verified · 1 Parent(s): 881e9ba

Upload token_compression.py with huggingface_hub

Files changed (1)
token_compression.py +66 -0
token_compression.py ADDED
@@ -0,0 +1,66 @@
import torch
from torch import nn
from transformers.activations import ACT2FN


class TokenCompressionAdapter(nn.Module):
    """Compresses a variable-length token sequence into a fixed number of
    tokens by cross-attending learned queries over the inputs, then refines
    the result with a residual MLP block and projects it to the output size."""

    def __init__(
        self,
        num_compressed_tokens: int,
        hidden_size: int,
        intermediate_size: int,
        output_size: int,
        hidden_act: str,
        num_attention_heads: int,
        layer_norm_eps: float,
    ):
        super().__init__()
        # Learned queries: one per compressed output token.
        self.query = nn.Parameter(torch.randn(1, num_compressed_tokens, hidden_size))
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_attention_heads,
            batch_first=True,
        )
        self.layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.mlp = MLP(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            hidden_act=hidden_act,
        )
        self.projection = nn.Linear(hidden_size, output_size)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        # Cross-attention: the learned queries attend over the input tokens,
        # reducing (batch, seq_len, hidden) to (batch, num_compressed_tokens, hidden).
        batch_size = hidden_state.shape[0]
        query = self.query.repeat(batch_size, 1, 1)
        key = self.key(hidden_state)
        value = self.value(hidden_state)
        hidden_state = self.attention(query, key, value)[0]

        # Pre-norm MLP block with a residual connection.
        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = self.mlp(hidden_state)
        hidden_state = residual + hidden_state

        # Project the compressed tokens to the target embedding size.
        hidden_state = self.projection(hidden_state)
        return hidden_state


class MLP(nn.Module):
    """Two-hidden-layer feed-forward block: hidden -> intermediate ->
    intermediate -> hidden, with the configured activation between layers."""

    def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
        super().__init__()
        self.activation_fn = ACT2FN[hidden_act]
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc1_5 = nn.Linear(intermediate_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc1_5(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states
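
A minimal usage sketch (not part of the commit): the hyperparameter values and tensor shapes below are illustrative assumptions chosen only to show the expected input and output shapes, not the actual configuration of the model this adapter ships with.

# Illustrative configuration only; the real model's hyperparameters are
# not specified in this commit.
adapter = TokenCompressionAdapter(
    num_compressed_tokens=64,
    hidden_size=768,
    intermediate_size=3072,
    output_size=4096,
    hidden_act="gelu",
    num_attention_heads=12,
    layer_norm_eps=1e-5,
)
tokens = torch.randn(2, 576, 768)  # (batch, seq_len, hidden_size)
compressed = adapter(tokens)       # cross-attention pools 576 tokens down to 64
print(compressed.shape)            # torch.Size([2, 64, 4096])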