gberton committed on
Commit
6a109bf
·
verified ·
1 Parent(s): f4ff2b3

Upload text_encoder.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. text_encoder.py +344 -0
text_encoder.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Text encoder implementation in PyTorch."""
17
+
18
+ import typing as t
19
+
20
+ import numpy as np
21
+ import sentencepiece as spm
22
+ import torch
23
+ from torch import nn
24
+ import torch.nn.functional as F
25
+
26
+
27
class Tokenizer(object):
    """SentencePiece-backed tokenizer that emits fixed-width id batches."""

    def __init__(self, tokenizer_path: str):
        # Load the serialized SentencePiece model from `tokenizer_path`.
        self.sp = spm.SentencePieceProcessor(model_file=tokenizer_path)
        # Clear any extra encode options so the behaviour mirrors
        # tensorflow_text.SentencepieceTokenizer(add_bos=False, add_eos=False).
        self.sp.SetEncodeExtraOptions("")
        # BOS/EOS stay off to agree with the reference Colab implementation.
        self._add_bos = False
        self._add_eos = False

    def tokenize(self, input_texts, max_len=64):
        """Encodes one string or a sequence of strings into padded id rows.

        Args:
          input_texts: a single string or an iterable of strings. Each text is
            lower-cased before encoding.
          max_len: width of every output row; longer encodings are truncated
            and shorter ones are right-padded with zeros.

        Returns:
          A `(tokens, is_padding)` pair of numpy arrays, both shaped
          `[num_texts, max_len]`. `tokens` is int64; `is_padding` is int32 and
          holds 1 wherever `tokens` equals 0.
        """
        texts = [input_texts] if isinstance(input_texts, str) else input_texts
        encoded = [
            self.sp.encode(
                text.lower(), add_bos=self._add_bos, add_eos=self._add_eos
            )
            for text in texts
        ]

        tokens = np.zeros((len(encoded), max_len), dtype=np.int64)
        for row, ids in zip(tokens, encoded):
            # `row` is a view into `tokens`, so this writes in place.
            kept = ids[:max_len]
            row[: len(kept)] = kept

        # Padding is inferred from the id value itself: every zero entry is
        # flagged, matching how the rows were initialised.
        is_padding = (tokens == 0).astype(np.int32)
        return tokens, is_padding
51
+
52
+
53
class PositionalEmbedding(nn.Module):
    """Generates sinusoidal position embeddings for a 1-d sequence.

    Attributes:
      min_timescale: Start of the geometric index. Determines the periodicity
        of the added signal.
      max_timescale: End of the geometric index. Determines the frequency of
        the added signal.
      embedding_dim: Dimension of the embedding to be generated.
    """

    min_timescale: int = 1
    max_timescale: int = 10_000
    embedding_dim: int = 0

    def __init__(self, embedding_dim: int):
        super().__init__()
        self.embedding_dim = embedding_dim

    # Implemented as `forward` (not `__call__`) so that `nn.Module.__call__`
    # keeps running registered hooks; `module(...)` call sites are unchanged.
    def forward(
        self,
        seq_length: t.Optional[int] = None,
        position: t.Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Generates a tensor of sinusoids with different frequencies.

        Args:
          seq_length: Python int defining the output sequence length. Required
            when `position` is not given; ignored when `position` is given.
          position: [B, seq_length], optional position for each token in the
            sequence, only required when the sequence is packed.

        Returns:
          [B, seqlen, D] if `position` is specified, else [1, seqlen, D]. The
          first D // 2 channels are sines, the next D // 2 are cosines, and a
          single trailing zero channel is appended when D is odd.
        """
        if position is None:
            assert seq_length is not None, (
                'seq_length is required when position is not given'
            )
            # [1, seqlen]
            position = torch.arange(seq_length, dtype=torch.float32)[None, :]
        else:
            assert position.ndim == 2, position.shape

        num_timescales = self.embedding_dim // 2
        # Dividing by a Python number keeps the computation in float32, the
        # same precision as the tensor-by-tensor division it replaces.
        log_timescale_increment = torch.log(
            torch.tensor(float(self.max_timescale) / float(self.min_timescale))
        ) / max(num_timescales - 1, 1)
        inv_timescales = self.min_timescale * torch.exp(
            torch.arange(num_timescales, dtype=torch.float32)
            * -log_timescale_increment
        )
        # The timescales are built on the default device; move them next to
        # `position` so caller-supplied positions on another device work.
        inv_timescales = inv_timescales.to(position.device)

        scaled_time = position[:, :, None] * inv_timescales[None, None, :]
        signal = torch.cat(
            (torch.sin(scaled_time), torch.cos(scaled_time)), dim=2
        )
        # Pad the channel axis by one zero when embedding_dim is odd.
        signal = F.pad(signal, (0, self.embedding_dim % 2))
        return signal
107
+
108
+
109
class MlpBlockWithMask(nn.Module):
    """Transformer MLP / feed-forward block that supports masking.

    Applies `c_fc`, the activation, and `c_proj`, zeroing masked positions
    after the activation and again after the output projection.
    """

    def __init__(
        self,
        mlp_dim: int,
        d_model: int,
        use_bias: bool = True,
        dtype: torch.dtype = torch.float32,
        activation_fn: t.Type[nn.Module] = nn.GELU,
    ):
        """Builds the two linear layers.

        Args:
          mlp_dim: width of the hidden (expanded) layer.
          d_model: width of the block input and output.
          use_bias: whether both linear layers carry a bias term.
          dtype: parameter dtype of both linear layers.
          activation_fn: activation *class* (not instance); it is instantiated
            with no arguments each time the block is applied.
        """
        super().__init__()

        self.mlp_dim = mlp_dim
        self.d_model = d_model
        self.use_bias = use_bias
        self.dtype = dtype
        self.activation_fn = activation_fn

        self.c_fc = nn.Linear(
            in_features=self.d_model,
            out_features=self.mlp_dim,
            dtype=self.dtype,
            bias=self.use_bias,
        )
        self.c_proj = nn.Linear(
            in_features=self.mlp_dim,
            out_features=self.d_model,
            dtype=self.dtype,
            bias=self.use_bias,
        )

    # Implemented as `forward` (not `__call__`) so that `nn.Module.__call__`
    # keeps running registered hooks; `module(...)` call sites are unchanged.
    def forward(
        self, inputs: torch.Tensor, mlp_mask: torch.Tensor
    ) -> torch.Tensor:
        """Applies Transformer MlpBlock with mask module.

        Args:
          inputs: tensor whose last axis has size `d_model`.
          mlp_mask: tensor with the shape of `inputs` minus its last axis;
            positions holding 0 are zeroed in the output.

        Returns:
          Tensor with the same shape as `inputs`.
        """
        x = self.c_fc(inputs)
        x = self.activation_fn()(x)
        # Cast the mask to the activation dtype so that multiplying never
        # promotes `x` to a dtype the following linear layer cannot consume.
        keep = mlp_mask[..., None].to(x.dtype)
        x = x * keep  # First masking.
        x = self.c_proj(x)
        x = x * keep  # Second masking (removes the projection bias too).
        return x
151
+
152
+
153
class ResidualAttentionBlock(nn.Module):
    """Transformer residual attention block.

    Pre-norm layout: layer norm, self-attention and a residual add, followed
    by layer norm, a masked ReLU MLP and a second residual add.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_dim: int,
        dtype: torch.dtype = torch.float32,
    ):
        """Builds the attention, MLP and layer-norm sub-modules.

        Args:
          d_model: model (embedding) width.
          n_head: number of attention heads.
          mlp_dim: hidden width of the feed-forward block.
          dtype: parameter dtype for every sub-module.
        """
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.mlp_dim = mlp_dim
        self.dtype = dtype

        self.attn = nn.MultiheadAttention(d_model, n_head, dtype=self.dtype)
        self.ln_1 = nn.LayerNorm(d_model, dtype=self.dtype)
        self.mlp = MlpBlockWithMask(
            self.mlp_dim,
            d_model,
            use_bias=True,
            dtype=self.dtype,
            activation_fn=nn.ReLU,
        )
        self.ln_2 = nn.LayerNorm(d_model, dtype=self.dtype)

    def attention(self, x: torch.Tensor, mask: torch.Tensor):
        """Runs self-attention over `x`, hiding masked key positions.

        Args:
          x: [L, N, D] sequence-first activations.
          mask: [N, L] tensor; entries equal to 0 mark key positions that must
            not be attended to, every other value marks a visible position.

        Returns:
          [L, N, D] attention output.
        """
        seq_len = x.shape[0]
        # Build the additive mask from scratch in x's dtype and on x's device
        # instead of overwriting a copy of `mask` in place; this also accepts
        # integer or boolean masks, which cannot store -inf.
        hidden_key = (mask == 0).to(x.device)
        additive = torch.zeros(
            hidden_key.shape, dtype=x.dtype, device=x.device
        ).masked_fill(hidden_key, float('-inf'))
        # [N, S] -> [N, n_head, L, S] -> [N * n_head, L, S]: each key mask is
        # shared by every head and every query position of its sequence.
        attn_mask = (
            additive[:, None, None, :]
            .expand(-1, self.n_head, seq_len, -1)
            .reshape(-1, seq_len, additive.shape[-1])
        )
        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        """Applies the block.

        Args:
          x: [L, N, D] sequence-first activations.
          mask: [L, N] tensor with 0 at padded positions.

        Returns:
          `(x, mask)`: the updated activations together with the unchanged
          mask, so that blocks can be chained by a multi-input sequential.
        """
        # The attention helper expects the mask batch-first, hence the permute.
        x = x + self.attention(self.ln_1(x), mask.permute(1, 0))
        x = x + self.mlp(self.ln_2(x), mask)
        return x, mask
194
+
195
+
196
class SequentialMultiInput(nn.Sequential):
    """`nn.Sequential` variant whose stages may exchange several values.

    A stage that returns a tuple has that tuple unpacked into positional
    arguments for the next stage; any other return value is passed on as a
    single argument.
    """

    def forward(self, *inputs):
        """Threads `inputs` through every child module in registration order."""
        state = inputs
        for stage in self:
            state = stage(*state) if isinstance(state, tuple) else stage(state)
        return state
206
+
207
+
208
class Transformer(nn.Module):
    """Stack of residual attention blocks applied one after another."""

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_dim: int,
        dtype: torch.dtype = torch.float32,
    ):
        """Creates `layers` identical blocks.

        Args:
          width: model width handed to every block.
          layers: number of blocks in the stack.
          heads: attention heads per block.
          mlp_dim: feed-forward hidden width per block.
          dtype: parameter dtype for every block.
        """
        super().__init__()
        self.width = width
        self.layers = layers
        self.heads = heads
        self.mlp_dim = mlp_dim
        self.dtype = dtype

        blocks = []
        for _ in range(self.layers):
            blocks.append(
                ResidualAttentionBlock(
                    self.width, self.heads, self.mlp_dim, self.dtype
                )
            )
        self.resblocks = SequentialMultiInput(*blocks)

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        """Runs `x` through the stack and returns only the activations."""
        # The stack yields (activations, mask); the mask is dropped here.
        outputs = self.resblocks(x, mask)
        return outputs[0]
233
+
234
+
235
class GlobalAvgPooling(nn.Module):
    """Performs a simple global pooling over the input with optional paddings.

    Attributes:
      pooling_dims: A list of dims to perform pooling over.
      epsilon: Small constant added to the count of valid positions so that a
        fully padded input divides by a non-zero number.
    """

    pooling_dims: t.Sequence[int]
    epsilon: float = 1e-8

    def __init__(
        self, pooling_dims: t.Sequence[int], epsilon: float = 1e-8
    ):
        """Stores the configuration.

        Raises:
          ValueError: if any entry of `pooling_dims` is negative.
        """
        super().__init__()
        self.pooling_dims = pooling_dims
        self.epsilon = epsilon

        if not all(p_dim >= 0 for p_dim in self.pooling_dims):
            raise ValueError('pooling_dims must be non-negative integers.')

    # Implemented as `forward` (not `__call__`) so that `nn.Module.__call__`
    # keeps running registered hooks; `module(...)` call sites are unchanged.
    def forward(
        self,
        inputs: torch.Tensor,
        compatible_paddings: torch.Tensor,
    ) -> torch.Tensor:
        """Applies global average spatial pooling to inputs.

        Args:
          inputs: An input tensor.
          compatible_paddings: paddings of inputs with shapes compatible with
            inputs, e.g. compatible_paddings with shape [B, 1] for inputs with
            shape [B, D]. Entries greater than 0 mark padded positions.

        Returns:
          Output tensor with global pooling applied and the pooled dims
          removed.
        """
        dims = tuple(self.pooling_dims)

        # Zero out padded positions so they do not contribute to the sum.
        inputs = torch.where(
            compatible_paddings > 0, torch.zeros_like(inputs), inputs
        )
        # Count valid positions. Casting first lets boolean paddings through,
        # which `1.0 - paddings` alone would reject.
        valid_inputs = (
            torch.sum(
                1.0 - compatible_paddings.to(inputs.dtype),
                dims,
                keepdim=True,
                dtype=inputs.dtype,
            )
            + self.epsilon
        )
        inputs_sum = torch.sum(inputs, dims, keepdim=True)
        outputs = torch.divide(inputs_sum, valid_inputs).type(inputs.dtype)
        outputs = torch.squeeze(outputs, dim=dims)
        return outputs
288
+
289
+
290
class TextEncoder(nn.Module):
    """Text encoder implementation.

    Embeds token ids, adds sinusoidal position embeddings, runs a transformer
    stack, applies a final layer norm and average-pools over non-padded
    positions.
    """

    def __init__(
        self,
        config: t.Dict[str, int],
        vocab_size: int,
        dtype: torch.dtype = torch.float32,
        scale_sqrt_depth: bool = True,
    ):
        """Builds the encoder.

        Args:
          config: mapping with the keys 'num_layers', 'hidden_size', 'mlp_dim'
            and 'num_heads'.
          vocab_size: number of rows in the token embedding table.
          dtype: parameter dtype for every sub-module.
          scale_sqrt_depth: if True, token embeddings are multiplied by
            sqrt(hidden_size) before position embeddings are added.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.dtype = dtype
        self.scale_sqrt_depth = scale_sqrt_depth

        # The text tower layers are fixed independent of vision tower size.
        self.transformer_layers = config['num_layers']
        self.embedding_dim = config['hidden_size']
        self.transformer_width = config['hidden_size']
        self.mlp_dim = config['mlp_dim']
        self.transformer_heads = config['num_heads']

        self.token_embedding = nn.Embedding(
            self.vocab_size, self.embedding_dim, dtype=self.dtype
        )
        self.pos_embedder = PositionalEmbedding(
            embedding_dim=self.embedding_dim
        )
        self.transformer = Transformer(
            width=self.transformer_width,
            layers=self.transformer_layers,
            heads=self.transformer_heads,
            mlp_dim=self.mlp_dim,
            dtype=self.dtype,
        )
        self.pooling = GlobalAvgPooling(pooling_dims=[1])
        self.ln_final = nn.LayerNorm(self.transformer_width, dtype=self.dtype)

    # Implemented as `forward` (not `__call__`) so that `nn.Module.__call__`
    # keeps running registered hooks; `module(...)` call sites are unchanged.
    def forward(
        self,
        ids: torch.Tensor,
        paddings: torch.Tensor,
    ) -> torch.Tensor:
        """Applies TextEncoder module.

        Args:
          ids: [N, L] token ids.
          paddings: [N, L] tensor that is non-zero at padded positions.

        Returns:
          [N, hidden_size] pooled text embedding.
        """
        _, seq_length = ids.shape
        x = self.token_embedding(ids)
        if self.scale_sqrt_depth:
            x = x * (self.embedding_dim**0.5)
        # Match the embedding's device *and* dtype: the positional signal is
        # produced in float32, and adding it unconverted would promote `x`
        # whenever the encoder is built with another dtype.
        pos = self.pos_embedder(seq_length=seq_length)
        x = x + pos.to(device=x.device, dtype=x.dtype)
        # Same reasoning for the mask, which multiplies activations downstream.
        mask = (paddings == 0).to(x.dtype)
        mask = mask.permute(1, 0)  # NL -> LN
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)
        x = self.pooling(x, compatible_paddings=paddings[:, :, None])
        return x