Eli181927 committed on
Commit b0c70d9 · verified · 1 Parent(s): 4faf173

Upload 2 files

Files changed (2)
  1. encode.py +216 -0
  2. requirements.txt +4 -0
encode.py ADDED
@@ -0,0 +1,216 @@
+ # encode.py
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ @dataclass
+ class EncoderConfig:
+     # Vocabulary size for source language (set from tokenizer)
+     src_vocab_size: int
+     # Model dimensions
+     embed_dim: int = 512
+     ff_hidden_dim: int = 2048
+     num_heads: int = 8
+     num_layers: int = 6
+     # Regularization
+     dropout: float = 0.1
+     # Max sequence length for positional embeddings
+     max_position_embeddings: int = 1024
+     # Special tokens
+     pad_token_id: int = 0
+     # Initialization scale (optional, small init helps stability)
+     init_range: float = 0.02
+
+
+ class TokenPositionalEmbedding(nn.Module):
+     """
+     Token embedding + learned positional embedding.
+     Shapes:
+     - input_ids: [B, S]
+     - return: [B, S, D]
+     """
+     def __init__(self, vocab_size: int, embed_dim: int,
+                  max_position_embeddings: int, pad_token_id: int, dropout: float):
+         super().__init__()
+         self.token_embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id)
+         self.pos_embedding = nn.Embedding(max_position_embeddings, embed_dim)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
+         batch_size, seq_len = input_ids.shape
+         device = input_ids.device
+         # Absolute positions 0..S-1, expanded to [B, S]
+         positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, seq_len)
+         x = self.token_embedding(input_ids) + self.pos_embedding(positions)
+         return self.dropout(x)  # [B, S, D]
+
+
+ class MultiHeadSelfAttention(nn.Module):
+     """
+     Standard MHA (Q=K=V) with padding mask support.
+     Shapes:
+     - x: [B, S, D]
+     - key_padding_mask: [B, S], bool or 0/1, where True/1 marks tokens to keep; converted to a bool keep mask internally
+     - return: [B, S, D]
+     """
+     def __init__(self, embed_dim: int, num_heads: int, dropout: float):
+         super().__init__()
+         assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+         self.scale = 1.0 / math.sqrt(self.head_dim)
+
+         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+         self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+         self.attn_dropout = nn.Dropout(dropout)
+
+     def forward(self, x: torch.FloatTensor, key_padding_mask: torch.Tensor) -> torch.FloatTensor:
+         B, S, D = x.shape
+
+         # Project to multihead Q, K, V: [B, S, H*Hd] -> [B, H, S, Hd]
+         q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
+         k = self.k_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
+         v = self.v_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
+
+         # Attention scores: [B, H, S, Hd] @ [B, H, Hd, S] -> [B, H, S, S]
+         attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+
+         # Build broadcastable mask over keys dimension: [B, 1, 1, S]
+         # key_padding_mask is 1/True for valid tokens; 0/False for PADs.
+         if key_padding_mask.dtype != torch.bool:
+             keep_mask = key_padding_mask != 0
+         else:
+             keep_mask = key_padding_mask
+         keep_mask = keep_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, S]
+
+         # Mask PAD keys by setting their scores to -inf so they receive zero attention weight after softmax
+         attn_scores = attn_scores.masked_fill(~keep_mask, float("-inf"))
+
+         attn_weights = F.softmax(attn_scores, dim=-1)
+         attn_weights = self.attn_dropout(attn_weights)
+
+         # Weighted sum of values: [B, H, S, S] @ [B, H, S, Hd] -> [B, H, S, Hd]
+         attn_output = torch.matmul(attn_weights, v)
+
+         # Merge heads: [B, H, S, Hd] -> [B, S, H*Hd=D]
+         attn_output = attn_output.transpose(1, 2).contiguous().view(B, S, D)
+         return self.out_proj(attn_output)
+
+
+ class FeedForward(nn.Module):
+     """
+     Position-wise MLP applied to each position independently.
+     Shapes:
+     - x: [B, S, D] -> [B, S, D]
+     """
+     def __init__(self, embed_dim: int, hidden_dim: int, dropout: float):
+         super().__init__()
+         self.fc1 = nn.Linear(embed_dim, hidden_dim)
+         self.fc2 = nn.Linear(hidden_dim, embed_dim)
+         self.dropout = nn.Dropout(dropout)
+         self.activation = nn.GELU()
+
+     def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+         x = self.fc1(x)
+         x = self.activation(x)
+         x = self.dropout(x)
+         x = self.fc2(x)
+         return self.dropout(x)
+
+
+ class EncoderBlock(nn.Module):
+     """
+     One Pre-LN encoder block: LN -> MHA -> resid, then LN -> FFN -> resid.
+     """
+     def __init__(self, embed_dim: int, num_heads: int, ff_hidden_dim: int, dropout: float):
+         super().__init__()
+         self.ln1 = nn.LayerNorm(embed_dim)
+         self.self_attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
+         self.dropout1 = nn.Dropout(dropout)
+
+         self.ln2 = nn.LayerNorm(embed_dim)
+         self.ff = FeedForward(embed_dim, ff_hidden_dim, dropout)
+         self.dropout2 = nn.Dropout(dropout)
+
+     def forward(self, x: torch.FloatTensor, key_padding_mask: torch.Tensor) -> torch.FloatTensor:
+         # Self-attention sub-layer (Pre-LN)
+         attn_out = self.self_attn(self.ln1(x), key_padding_mask=key_padding_mask)
+         x = x + self.dropout1(attn_out)
+
+         # Feedforward sub-layer (Pre-LN)
+         ff_out = self.ff(self.ln2(x))
+         x = x + self.dropout2(ff_out)
+         return x
+
+
+ class Encoder(nn.Module):
+     """
+     Full encoder: embeddings -> N blocks -> final LayerNorm.
+     Forward signature:
+         encoder_hidden_states = Encoder(config)(src_input_ids, src_attention_mask)
+     """
+     def __init__(self, config: EncoderConfig):
+         super().__init__()
+         self.config = config
+         assert config.embed_dim % config.num_heads == 0, "embed_dim must be divisible by num_heads"
+
+         self.embeddings = TokenPositionalEmbedding(
+             vocab_size=config.src_vocab_size,
+             embed_dim=config.embed_dim,
+             max_position_embeddings=config.max_position_embeddings,
+             pad_token_id=config.pad_token_id,
+             dropout=config.dropout,
+         )
+
+         self.layers = nn.ModuleList([
+             EncoderBlock(
+                 embed_dim=config.embed_dim,
+                 num_heads=config.num_heads,
+                 ff_hidden_dim=config.ff_hidden_dim,
+                 dropout=config.dropout,
+             )
+             for _ in range(config.num_layers)
+         ])
+         self.final_ln = nn.LayerNorm(config.embed_dim)
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module: nn.Module) -> None:
+         if isinstance(module, nn.Linear):
+             nn.init.normal_(module.weight, mean=0.0, std=self.config.init_range)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, mean=0.0, std=self.config.init_range)
+             # Respect padding index: re-zero the pad embedding after init
+             if module.padding_idx is not None:
+                 with torch.no_grad():
+                     module.weight[module.padding_idx].fill_(0.0)
+
+     @torch.no_grad()
+     def _ensure_mask_dtype(self, mask: torch.Tensor) -> torch.Tensor:
+         # Accept bool or 0/1. Return bool where True means "keep".
+         return mask.bool() if mask.dtype != torch.bool else mask
+
+     def forward(
+         self,
+         src_input_ids: torch.LongTensor,   # [B, S]
+         src_attention_mask: torch.Tensor,  # [B, S] (1/True = token, 0/False = PAD)
+     ) -> torch.FloatTensor:
+         x = self.embeddings(src_input_ids)  # [B, S, D]
+         keep_mask = self._ensure_mask_dtype(src_attention_mask)
+
+         for layer in self.layers:
+             x = layer(x, key_padding_mask=keep_mask)
+
+         x = self.final_ln(x)
+         x = x * keep_mask.unsqueeze(-1)  # zero out hidden states at PAD positions
+         return x  # [B, S, D]
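
For reference, a minimal usage sketch (not part of the uploaded files) showing how the Encoder above is driven end to end. The import assumes encode.py is on the Python path; the vocabulary size, batch shape, and padding layout are illustrative choices, while the output shape and the zeroed PAD rows follow from the forward pass defined above.

import torch
from encode import Encoder, EncoderConfig  # assumes encode.py is importable

# Illustrative config; src_vocab_size would normally come from the tokenizer.
config = EncoderConfig(src_vocab_size=32000)
encoder = Encoder(config).eval()

# Batch of 2 sequences padded to length 10 (pad_token_id=0);
# the last 3 positions of the first sequence are PAD.
src_input_ids = torch.randint(1, config.src_vocab_size, (2, 10))
src_input_ids[0, 7:] = config.pad_token_id
src_attention_mask = (src_input_ids != config.pad_token_id).long()  # 1 = token, 0 = PAD

with torch.no_grad():
    hidden = encoder(src_input_ids, src_attention_mask)

print(hidden.shape)                      # torch.Size([2, 10, 512])
print(hidden[0, 7:].abs().sum().item())  # 0.0 -- PAD positions are zeroed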
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch>=2.0.0
+ gradio>=4.0.0
+ pandas>=1.5.0
+ numpy>=1.21.0
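
The pinned minimums above install with the standard command:

pip install -r requirements.txt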