H-Liu1997 commited on
Commit
eabfc69
·
verified ·
1 Parent(s): c785bc6

Upload models/tools/t5.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/tools/t5.py +595 -0
models/tools/t5.py ADDED
@@ -0,0 +1,595 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from transformers.models.t5.modeling_t5
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import logging
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .tokenizers import HuggingfaceTokenizer
11
+
12
+ __all__ = [
13
+ "T5Model",
14
+ "T5Encoder",
15
+ "T5Decoder",
16
+ "T5EncoderModel",
17
+ ]
18
+
19
+
20
+ def fp16_clamp(x):
21
+ if x.dtype == torch.float16 and torch.isinf(x).any():
22
+ clamp = torch.finfo(x.dtype).max - 1000
23
+ x = torch.clamp(x, min=-clamp, max=clamp)
24
+ return x
25
+
26
+
27
+ def init_weights(m):
28
+ if isinstance(m, T5LayerNorm):
29
+ nn.init.ones_(m.weight)
30
+ elif isinstance(m, T5Model):
31
+ nn.init.normal_(m.token_embedding.weight, std=1.0)
32
+ elif isinstance(m, T5FeedForward):
33
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
34
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
35
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
36
+ elif isinstance(m, T5Attention):
37
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
38
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
39
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
40
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
41
+ elif isinstance(m, T5RelativeEmbedding):
42
+ nn.init.normal_(
43
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5
44
+ )
45
+
46
+
47
+ class GELU(nn.Module):
48
+ def forward(self, x):
49
+ return (
50
+ 0.5
51
+ * x
52
+ * (
53
+ 1.0
54
+ + torch.tanh(
55
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
56
+ )
57
+ )
58
+ )
59
+
60
+
61
+ class T5LayerNorm(nn.Module):
62
+ def __init__(self, dim, eps=1e-6):
63
+ super(T5LayerNorm, self).__init__()
64
+ self.dim = dim
65
+ self.eps = eps
66
+ self.weight = nn.Parameter(torch.ones(dim))
67
+
68
+ def forward(self, x):
69
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
70
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
71
+ x = x.type_as(self.weight)
72
+ return self.weight * x
73
+
74
+
75
+ class T5Attention(nn.Module):
76
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
77
+ assert dim_attn % num_heads == 0
78
+ super(T5Attention, self).__init__()
79
+ self.dim = dim
80
+ self.dim_attn = dim_attn
81
+ self.num_heads = num_heads
82
+ self.head_dim = dim_attn // num_heads
83
+
84
+ # layers
85
+ self.q = nn.Linear(dim, dim_attn, bias=False)
86
+ self.k = nn.Linear(dim, dim_attn, bias=False)
87
+ self.v = nn.Linear(dim, dim_attn, bias=False)
88
+ self.o = nn.Linear(dim_attn, dim, bias=False)
89
+ self.dropout = nn.Dropout(dropout)
90
+
91
+ def forward(self, x, context=None, mask=None, pos_bias=None):
92
+ """
93
+ x: [B, L1, C].
94
+ context: [B, L2, C] or None.
95
+ mask: [B, L2] or [B, L1, L2] or None.
96
+ """
97
+ # check inputs
98
+ context = x if context is None else context
99
+ b, n, c = x.size(0), self.num_heads, self.head_dim
100
+
101
+ # compute query, key, value
102
+ q = self.q(x).view(b, -1, n, c)
103
+ k = self.k(context).view(b, -1, n, c)
104
+ v = self.v(context).view(b, -1, n, c)
105
+
106
+ # attention bias
107
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
108
+ if pos_bias is not None:
109
+ attn_bias += pos_bias
110
+ if mask is not None:
111
+ assert mask.ndim in [2, 3]
112
+ mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
113
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
114
+
115
+ # compute attention (T5 does not use scaling)
116
+ attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
117
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
118
+ x = torch.einsum("bnij,bjnc->binc", attn, v)
119
+
120
+ # output
121
+ x = x.reshape(b, -1, n * c)
122
+ x = self.o(x)
123
+ x = self.dropout(x)
124
+ return x
125
+
126
+
127
+ class T5FeedForward(nn.Module):
128
+ def __init__(self, dim, dim_ffn, dropout=0.1):
129
+ super(T5FeedForward, self).__init__()
130
+ self.dim = dim
131
+ self.dim_ffn = dim_ffn
132
+
133
+ # layers
134
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
135
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
136
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
137
+ self.dropout = nn.Dropout(dropout)
138
+
139
+ def forward(self, x):
140
+ x = self.fc1(x) * self.gate(x)
141
+ x = self.dropout(x)
142
+ x = self.fc2(x)
143
+ x = self.dropout(x)
144
+ return x
145
+
146
+
147
+ class T5SelfAttention(nn.Module):
148
+ def __init__(
149
+ self,
150
+ dim,
151
+ dim_attn,
152
+ dim_ffn,
153
+ num_heads,
154
+ num_buckets,
155
+ shared_pos=True,
156
+ dropout=0.1,
157
+ ):
158
+ super(T5SelfAttention, self).__init__()
159
+ self.dim = dim
160
+ self.dim_attn = dim_attn
161
+ self.dim_ffn = dim_ffn
162
+ self.num_heads = num_heads
163
+ self.num_buckets = num_buckets
164
+ self.shared_pos = shared_pos
165
+
166
+ # layers
167
+ self.norm1 = T5LayerNorm(dim)
168
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
169
+ self.norm2 = T5LayerNorm(dim)
170
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
171
+ self.pos_embedding = (
172
+ None
173
+ if shared_pos
174
+ else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
175
+ )
176
+
177
+ def forward(self, x, mask=None, pos_bias=None):
178
+ e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
179
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
180
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
181
+ return x
182
+
183
+
184
+ class T5CrossAttention(nn.Module):
185
+ def __init__(
186
+ self,
187
+ dim,
188
+ dim_attn,
189
+ dim_ffn,
190
+ num_heads,
191
+ num_buckets,
192
+ shared_pos=True,
193
+ dropout=0.1,
194
+ ):
195
+ super(T5CrossAttention, self).__init__()
196
+ self.dim = dim
197
+ self.dim_attn = dim_attn
198
+ self.dim_ffn = dim_ffn
199
+ self.num_heads = num_heads
200
+ self.num_buckets = num_buckets
201
+ self.shared_pos = shared_pos
202
+
203
+ # layers
204
+ self.norm1 = T5LayerNorm(dim)
205
+ self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
206
+ self.norm2 = T5LayerNorm(dim)
207
+ self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
208
+ self.norm3 = T5LayerNorm(dim)
209
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
210
+ self.pos_embedding = (
211
+ None
212
+ if shared_pos
213
+ else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
214
+ )
215
+
216
+ def forward(
217
+ self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None
218
+ ):
219
+ e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
220
+ x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
221
+ x = fp16_clamp(
222
+ x
223
+ + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)
224
+ )
225
+ x = fp16_clamp(x + self.ffn(self.norm3(x)))
226
+ return x
227
+
228
+
229
+ class T5RelativeEmbedding(nn.Module):
230
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
231
+ super(T5RelativeEmbedding, self).__init__()
232
+ self.num_buckets = num_buckets
233
+ self.num_heads = num_heads
234
+ self.bidirectional = bidirectional
235
+ self.max_dist = max_dist
236
+
237
+ # layers
238
+ self.embedding = nn.Embedding(num_buckets, num_heads)
239
+
240
+ def forward(self, lq, lk):
241
+ device = self.embedding.weight.device
242
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
243
+ # torch.arange(lq).unsqueeze(1).to(device)
244
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(
245
+ lq, device=device
246
+ ).unsqueeze(1)
247
+ rel_pos = self._relative_position_bucket(rel_pos)
248
+ rel_pos_embeds = self.embedding(rel_pos)
249
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk]
250
+ return rel_pos_embeds.contiguous()
251
+
252
+ def _relative_position_bucket(self, rel_pos):
253
+ # preprocess
254
+ if self.bidirectional:
255
+ num_buckets = self.num_buckets // 2
256
+ rel_buckets = (rel_pos > 0).long() * num_buckets
257
+ rel_pos = torch.abs(rel_pos)
258
+ else:
259
+ num_buckets = self.num_buckets
260
+ rel_buckets = 0
261
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
262
+
263
+ # embeddings for small and large positions
264
+ max_exact = num_buckets // 2
265
+ rel_pos_large = (
266
+ max_exact
267
+ + (
268
+ torch.log(rel_pos.float() / max_exact)
269
+ / math.log(self.max_dist / max_exact)
270
+ * (num_buckets - max_exact)
271
+ ).long()
272
+ )
273
+ rel_pos_large = torch.min(
274
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)
275
+ )
276
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
277
+ return rel_buckets
278
+
279
+
280
+ class T5Encoder(nn.Module):
281
+ def __init__(
282
+ self,
283
+ vocab,
284
+ dim,
285
+ dim_attn,
286
+ dim_ffn,
287
+ num_heads,
288
+ num_layers,
289
+ num_buckets,
290
+ shared_pos=True,
291
+ dropout=0.1,
292
+ ):
293
+ super(T5Encoder, self).__init__()
294
+ self.dim = dim
295
+ self.dim_attn = dim_attn
296
+ self.dim_ffn = dim_ffn
297
+ self.num_heads = num_heads
298
+ self.num_layers = num_layers
299
+ self.num_buckets = num_buckets
300
+ self.shared_pos = shared_pos
301
+
302
+ # layers
303
+ self.token_embedding = (
304
+ vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
305
+ )
306
+ self.pos_embedding = (
307
+ T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
308
+ if shared_pos
309
+ else None
310
+ )
311
+ self.dropout = nn.Dropout(dropout)
312
+ self.blocks = nn.ModuleList(
313
+ [
314
+ T5SelfAttention(
315
+ dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout
316
+ )
317
+ for _ in range(num_layers)
318
+ ]
319
+ )
320
+ self.norm = T5LayerNorm(dim)
321
+
322
+ # initialize weights
323
+ self.apply(init_weights)
324
+
325
+ def forward(self, ids, mask=None):
326
+ x = self.token_embedding(ids)
327
+ x = self.dropout(x)
328
+ e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
329
+ for block in self.blocks:
330
+ x = block(x, mask, pos_bias=e)
331
+ x = self.norm(x)
332
+ x = self.dropout(x)
333
+ return x
334
+
335
+
336
+ class T5Decoder(nn.Module):
337
+ def __init__(
338
+ self,
339
+ vocab,
340
+ dim,
341
+ dim_attn,
342
+ dim_ffn,
343
+ num_heads,
344
+ num_layers,
345
+ num_buckets,
346
+ shared_pos=True,
347
+ dropout=0.1,
348
+ ):
349
+ super(T5Decoder, self).__init__()
350
+ self.dim = dim
351
+ self.dim_attn = dim_attn
352
+ self.dim_ffn = dim_ffn
353
+ self.num_heads = num_heads
354
+ self.num_layers = num_layers
355
+ self.num_buckets = num_buckets
356
+ self.shared_pos = shared_pos
357
+
358
+ # layers
359
+ self.token_embedding = (
360
+ vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
361
+ )
362
+ self.pos_embedding = (
363
+ T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
364
+ if shared_pos
365
+ else None
366
+ )
367
+ self.dropout = nn.Dropout(dropout)
368
+ self.blocks = nn.ModuleList(
369
+ [
370
+ T5CrossAttention(
371
+ dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout
372
+ )
373
+ for _ in range(num_layers)
374
+ ]
375
+ )
376
+ self.norm = T5LayerNorm(dim)
377
+
378
+ # initialize weights
379
+ self.apply(init_weights)
380
+
381
+ def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
382
+ b, s = ids.size()
383
+
384
+ # causal mask
385
+ if mask is None:
386
+ mask = torch.tril(torch.ones(1, s, s).to(ids.device))
387
+ elif mask.ndim == 2:
388
+ mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
389
+
390
+ # layers
391
+ x = self.token_embedding(ids)
392
+ x = self.dropout(x)
393
+ e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
394
+ for block in self.blocks:
395
+ x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
396
+ x = self.norm(x)
397
+ x = self.dropout(x)
398
+ return x
399
+
400
+
401
+ class T5Model(nn.Module):
402
+ def __init__(
403
+ self,
404
+ vocab_size,
405
+ dim,
406
+ dim_attn,
407
+ dim_ffn,
408
+ num_heads,
409
+ encoder_layers,
410
+ decoder_layers,
411
+ num_buckets,
412
+ shared_pos=True,
413
+ dropout=0.1,
414
+ ):
415
+ super(T5Model, self).__init__()
416
+ self.vocab_size = vocab_size
417
+ self.dim = dim
418
+ self.dim_attn = dim_attn
419
+ self.dim_ffn = dim_ffn
420
+ self.num_heads = num_heads
421
+ self.encoder_layers = encoder_layers
422
+ self.decoder_layers = decoder_layers
423
+ self.num_buckets = num_buckets
424
+
425
+ # layers
426
+ self.token_embedding = nn.Embedding(vocab_size, dim)
427
+ self.encoder = T5Encoder(
428
+ self.token_embedding,
429
+ dim,
430
+ dim_attn,
431
+ dim_ffn,
432
+ num_heads,
433
+ encoder_layers,
434
+ num_buckets,
435
+ shared_pos,
436
+ dropout,
437
+ )
438
+ self.decoder = T5Decoder(
439
+ self.token_embedding,
440
+ dim,
441
+ dim_attn,
442
+ dim_ffn,
443
+ num_heads,
444
+ decoder_layers,
445
+ num_buckets,
446
+ shared_pos,
447
+ dropout,
448
+ )
449
+ self.head = nn.Linear(dim, vocab_size, bias=False)
450
+
451
+ # initialize weights
452
+ self.apply(init_weights)
453
+
454
+ def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
455
+ x = self.encoder(encoder_ids, encoder_mask)
456
+ x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
457
+ x = self.head(x)
458
+ return x
459
+
460
+
461
+ def _t5(
462
+ name,
463
+ encoder_only=False,
464
+ decoder_only=False,
465
+ return_tokenizer=False,
466
+ tokenizer_kwargs={},
467
+ dtype=torch.float32,
468
+ device="cpu",
469
+ **kwargs,
470
+ ):
471
+ # sanity check
472
+ assert not (encoder_only and decoder_only)
473
+
474
+ # params
475
+ if encoder_only:
476
+ model_cls = T5Encoder
477
+ kwargs["vocab"] = kwargs.pop("vocab_size")
478
+ kwargs["num_layers"] = kwargs.pop("encoder_layers")
479
+ _ = kwargs.pop("decoder_layers")
480
+ elif decoder_only:
481
+ model_cls = T5Decoder
482
+ kwargs["vocab"] = kwargs.pop("vocab_size")
483
+ kwargs["num_layers"] = kwargs.pop("decoder_layers")
484
+ _ = kwargs.pop("encoder_layers")
485
+ else:
486
+ model_cls = T5Model
487
+
488
+ # init model
489
+ with torch.device(device):
490
+ model = model_cls(**kwargs)
491
+
492
+ # set device
493
+ model = model.to(dtype=dtype, device=device)
494
+
495
+ # init tokenizer
496
+ if return_tokenizer:
497
+ from .tokenizers import HuggingfaceTokenizer
498
+
499
+ tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs)
500
+ return model, tokenizer
501
+ else:
502
+ return model
503
+
504
+
505
+ def umt5_xxl(**kwargs):
506
+ cfg = dict(
507
+ vocab_size=256384,
508
+ dim=4096,
509
+ dim_attn=4096,
510
+ dim_ffn=10240,
511
+ num_heads=64,
512
+ encoder_layers=24,
513
+ decoder_layers=24,
514
+ num_buckets=32,
515
+ shared_pos=False,
516
+ dropout=0.1,
517
+ )
518
+ cfg.update(**kwargs)
519
+ return _t5("umt5-xxl", **cfg)
520
+
521
+
522
+ def umt5_base(**kwargs):
523
+ cfg = dict(
524
+ vocab_size=256384,
525
+ dim=768,
526
+ dim_attn=768,
527
+ dim_ffn=2048,
528
+ num_heads=12,
529
+ encoder_layers=12,
530
+ decoder_layers=12,
531
+ num_buckets=32,
532
+ shared_pos=False,
533
+ dropout=0.1,
534
+ )
535
+ cfg.update(**kwargs)
536
+ return _t5("umt5-base", **cfg)
537
+
538
+
539
+ # Model factory mapping
540
+ _T5_MODEL_FACTORY = {
541
+ "xxl": umt5_xxl,
542
+ "base": umt5_base,
543
+ }
544
+
545
+
546
+ class T5EncoderModel:
547
+ def __init__(
548
+ self,
549
+ text_len,
550
+ dtype=torch.bfloat16,
551
+ device=torch.cuda.current_device(),
552
+ checkpoint_path=None,
553
+ tokenizer_path=None,
554
+ shard_fn=None,
555
+ t5_size="xxl",
556
+ ):
557
+ self.text_len = text_len
558
+ self.dtype = dtype
559
+ self.device = device
560
+ self.checkpoint_path = checkpoint_path
561
+ self.tokenizer_path = tokenizer_path
562
+ self.t5_size = t5_size
563
+
564
+ # init model
565
+ if t5_size not in _T5_MODEL_FACTORY:
566
+ raise ValueError(
567
+ f"Unknown t5_size: {t5_size}. Available: {list(_T5_MODEL_FACTORY.keys())}"
568
+ )
569
+ model_fn = _T5_MODEL_FACTORY[t5_size]
570
+ model = (
571
+ model_fn(
572
+ encoder_only=True, return_tokenizer=False, dtype=dtype, device=device
573
+ )
574
+ .eval()
575
+ .requires_grad_(False)
576
+ )
577
+ logging.info(f"loading {checkpoint_path}")
578
+ model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
579
+ self.model = model
580
+ if shard_fn is not None:
581
+ self.model = shard_fn(self.model, sync_module_states=False)
582
+ else:
583
+ self.model.to(self.device)
584
+ # init tokenizer
585
+ self.tokenizer = HuggingfaceTokenizer(
586
+ name=tokenizer_path, seq_len=text_len, clean="whitespace"
587
+ )
588
+
589
+ def __call__(self, texts, device):
590
+ ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
591
+ ids = ids.to(device)
592
+ mask = mask.to(device)
593
+ seq_lens = mask.gt(0).sum(dim=1).long()
594
+ context = self.model(ids, mask)
595
+ return [u[:v] for u, v in zip(context, seq_lens)]