zongzhex commited on
Commit
06acd95
·
verified ·
1 Parent(s): 1c0ea5c

Add source code

Browse files
src/open_clip/.ipynb_checkpoints/tokenizer-checkpoint.py ADDED
@@ -0,0 +1,621 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ CLIP tokenizer
2
+
3
+ Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ import gzip
6
+ import html
7
+ import os
8
+ import random
9
+ import string
10
+ from functools import lru_cache, partial
11
+ from typing import Callable, List, Optional, Union, Dict
12
+ import warnings
13
+
14
+ import ftfy
15
+ import numpy as np
16
+ import regex as re
17
+ import torch
18
+
19
+ # https://stackoverflow.com/q/62691279
20
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
21
+ _nltk_init = False
22
+
23
+ DEFAULT_CONTEXT_LENGTH = 77 # default context length for OpenAI CLIP
24
+
25
+
26
@lru_cache()
def default_bpe():
    """Return the path of the bundled BPE vocab archive next to this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, "bpe_simple_vocab_16e6.txt.gz")
29
+
30
+
31
@lru_cache()
def bytes_to_unicode():
    """
    Returns a dict mapping every utf-8 byte (0-255) to a unicode character.

    The reversible BPE codes operate on unicode strings, so every byte must
    map to a printable character the BPE code won't choke on.  Printable
    bytes map to themselves; the remaining bytes are assigned codepoints
    starting at 256.
    """
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    byte_values = printable[:]
    codepoints = printable[:]
    offset = 0
    for byte in range(2 ** 8):
        if byte not in byte_values:
            # Non-printable byte: shift it into the >=256 codepoint range.
            byte_values.append(byte)
            codepoints.append(2 ** 8 + offset)
            offset += 1
    return {b: chr(c) for b, c in zip(byte_values, codepoints)}
52
+
53
+
54
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Word is represented as a tuple of symbols (symbols being
    variable-length strings).  A word with fewer than two symbols
    yields an empty set.
    """
    # zip(word, word[1:]) walks each adjacent (prev, next) symbol pair.
    return set(zip(word, word[1:]))
64
+
65
+
66
def basic_clean(text):
    """Fix mojibake with ftfy and unescape doubly-escaped HTML entities."""
    fixed = ftfy.fix_text(text)
    # unescape twice to handle strings escaped twice upstream (e.g. &amp;amp;)
    unescaped = html.unescape(html.unescape(fixed))
    return unescaped.strip()
70
+
71
+
72
def whitespace_clean(text):
    """Collapse runs of whitespace to single spaces and trim the ends."""
    return " ".join(text.split()).strip()
76
+
77
+
78
def _clean_canonicalize(x):
    # basic cleanup, then canonicalize_text: lowercase + punctuation removal
    return canonicalize_text(basic_clean(x))
81
+
82
+
83
def _clean_lower(x):
    # basic cleanup + whitespace collapse, then lowercase
    cleaned = whitespace_clean(basic_clean(x))
    return cleaned.lower()
86
+
87
+
88
def _clean_whitespace(x):
    # basic cleanup + whitespace collapse only (case preserved)
    return whitespace_clean(basic_clean(x))
91
+
92
+
93
def get_clean_fn(type: str):
    """Look up a text-cleaning function by name.

    Args:
        type: one of 'canonicalize', 'lower', 'whitespace'.

    Returns:
        The matching ``_clean_*`` callable.

    Raises:
        ValueError: for an unrecognized name.  The original used
            ``assert False``, which is silently stripped under ``python -O``.
    """
    if type == 'canonicalize':
        return _clean_canonicalize
    if type == 'lower':
        return _clean_lower
    if type == 'whitespace':
        return _clean_whitespace
    raise ValueError(f"Invalid clean function ({type}).")
102
+
103
+
104
def canonicalize_text(
        text,
        *,
        keep_punctuation_exact_string=None,
        trans_punctuation: dict = str.maketrans("", "", string.punctuation),
):
    """Return canonicalized `text` (lowercase, punctuation removed).

    From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94

    Args:
        text: string to be canonicalized.
        keep_punctuation_exact_string: If provided, this exact string is kept.
            For example '{}' keeps any occurrence of '{}' (but still removes
            '{' and '}' appearing separately).
        trans_punctuation: translation table used to strip punctuation.
    """
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        # Strip punctuation around the protected string, then re-join with it.
        pieces = text.split(keep_punctuation_exact_string)
        stripped = (piece.translate(trans_punctuation) for piece in pieces)
        text = keep_punctuation_exact_string.join(stripped)
    else:
        text = text.translate(trans_punctuation)
    return " ".join(text.lower().split()).strip()
131
+
132
+
133
class SimpleTokenizer(object):
    """Byte-pair-encoding tokenizer matching OpenAI CLIP's vocabulary.

    Cleans input text, encodes it into BPE token ids, and frames sequences
    with the <start_of_text> / <end_of_text> special tokens.
    """

    def __init__(
            self,
            bpe_path: str = default_bpe(),
            additional_special_tokens: Optional[List[str]] = None,
            context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
            clean: str = 'lower',
            reduction_mask: str = ''
    ):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        # Keep only the merge rules backing CLIP's 49152-entry vocab
        # (49152 minus 256 byte tokens and 2 special tokens; skip header line).
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]  # word-final byte variants
        for merge in merges:
            vocab.append(''.join(merge))
        special_tokens = ['<start_of_text>', '<end_of_text>']
        if additional_special_tokens:
            special_tokens += additional_special_tokens
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # Special tokens map to themselves so bpe() never splits them.
        self.cache = {t: t for t in special_tokens}
        special = "|".join(special_tokens)
        self.pat = re.compile(
            special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )
        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]
        self.sot_token_id = self.all_special_ids[0]
        self.eot_token_id = self.all_special_ids[1]
        self.context_length = context_length
        self.clean_fn = get_clean_fn(clean)
        self.reduction_fn = get_reduction_mask_fn(reduction_mask) if reduction_mask else None

    def bpe(self, token):
        """Apply BPE merges to one token; returns space-joined subword units."""
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token + '</w>'

        while True:
            # Merge the lowest-ranked (earliest-learned) adjacent pair first.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # BUG FIX: was `except Exception`, which could mask real
                    # errors; tuple.index raises ValueError when not found.
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean `text` and encode it into a list of BPE token ids."""
        bpe_tokens = []
        text = self.clean_fn(text)
        for token in re.findall(self.pat, text):
            # Map raw utf-8 bytes into the reversible unicode alphabet first.
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Decode a sequence of token ids back into a text string."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text

    def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.LongTensor:
        """ Returns the tokenized representation of given input string(s)

        Parameters
        ----------
        texts : Union[str, List[str]]
            An input string or a list of input strings to tokenize
        context_length : int
            The context length to use; all CLIP models use 77 as the context length

        Returns
        -------
        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
        """
        if isinstance(texts, str):
            texts = [texts]

        context_length = context_length or self.context_length
        assert context_length, 'Please set a valid context length'

        if self.reduction_fn is not None:
            # use reduction strategy for tokenize if set, otherwise default to truncation below
            return self.reduction_fn(
                texts,
                context_length=context_length,
                sot_token_id=self.sot_token_id,
                eot_token_id=self.eot_token_id,
                encode_fn=self.encode,
            )

        all_tokens = [[self.sot_token_id] + self.encode(text) + [self.eot_token_id] for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                tokens = tokens[:context_length]  # truncate, but keep EoT last
                tokens[-1] = self.eot_token_id
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result
266
+
267
+
268
+ _tokenizer = SimpleTokenizer()
269
+
270
+
271
def decode(output_ids: torch.Tensor):
    """Decode a tensor of token ids to text with the module-level tokenizer."""
    ids = output_ids.cpu().numpy()
    return _tokenizer.decode(ids)
274
+
275
+
276
def tokenize(texts: Union[str, List[str]], context_length: int = DEFAULT_CONTEXT_LENGTH) -> torch.LongTensor:
    """Tokenize text(s) with the module-level default SimpleTokenizer."""
    return _tokenizer(texts, context_length=context_length)
278
+
279
+
280
def random_mask_tokenize(
        texts: Union[str, List[str]],
        context_length: int,
        sot_token_id: int,
        eot_token_id: int,
        encode_fn: Callable,
        shuffle: bool = False,
):
    """Tokenize, randomly dropping tokens that exceed the context window.

    Captions longer than ``context_length - 2`` (room for SoT/EoT) keep a
    random subset of tokens; original order is preserved unless ``shuffle``.
    """
    encoded = [encode_fn(t) for t in texts]
    out = torch.zeros(len(encoded), context_length, dtype=torch.long)

    budget = context_length - 2  # reserve 2 slots for sot and eot tokens
    for row, ids in enumerate(encoded):
        ids = torch.tensor(ids)
        n = len(ids)
        if n > budget:
            keep = torch.randperm(n)[:budget]
            if not shuffle:
                keep = keep.msort()  # restore the original token order
            ids = ids[keep]
            n = budget
        out[row, 0] = sot_token_id
        out[row, 1:n + 1] = ids
        out[row, n + 1] = eot_token_id

    return out
307
+
308
+
309
def simple_mask_tokenize(
        texts: Union[str, List[str]],
        context_length: int,
        sot_token_id: int,
        eot_token_id: int,
        encode_fn: Callable,
):
    """Tokenize, keeping a random contiguous block when over the context window."""
    encoded = [encode_fn(t) for t in texts]
    out = torch.zeros(len(encoded), context_length, dtype=torch.long)

    budget = context_length - 2  # reserve 2 slots for sot and eot tokens
    for row, ids in enumerate(encoded):
        if len(ids) > budget:
            # random.randint is inclusive on both ends
            start = random.randint(0, len(ids) - budget)
            ids = ids[start:start + budget]
        framed = [sot_token_id] + ids + [eot_token_id]
        out[row, :len(framed)] = torch.tensor(framed)

    return out
329
+
330
+
331
def syntax_mask_tokenize(
        texts: Union[str, List[str]],
        context_length: int,
        sot_token_id: int,
        eot_token_id: int,
        encode_fn: Callable,
) -> torch.LongTensor:
    """Tokenize after dropping words by syntactic priority.

    Words are POS-tagged with nltk and kept in priority order
    noun > adjective > verb > other until the context window is filled,
    then the survivors are re-joined and tokenized.
    """
    import nltk
    global _nltk_init
    if not _nltk_init:
        # Download tagger data once per process.
        nltk.download('punkt')
        nltk.download('averaged_perceptron_tagger')
        _nltk_init = True

    def priority(tag):
        # Lower value == higher keep-priority: nouns, adjectives, verbs, rest.
        if tag.startswith('NN'):
            return 1
        if tag.startswith('JJ'):
            return 2
        if tag.startswith('VB'):
            return 3
        return 4

    # syntax masking: keep the highest-priority words per caption
    masked_texts = []
    for text in texts:
        words = nltk.tokenize.word_tokenize(text)
        tags = nltk.pos_tag(words)
        ranks = np.array([priority(tag) for _, tag in tags])
        # stable argsort picks the best-ranked word indices; re-sort them to
        # preserve the original word order (2 slots kept for sot/eot tokens)
        keep_ids = sorted(np.argsort(ranks)[:context_length - 2])
        kept = np.take(np.array(words), keep_ids, axis=0)
        masked_texts.append(' '.join(str(w) for w in kept).strip())
    texts = masked_texts

    all_tokens = [[sot_token_id] + encode_fn(text) + [eot_token_id] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        # Some words expand to multiple BPE tokens, so truncation is still needed.
        if len(tokens) > context_length:
            tokens = tokens[:context_length]
            tokens[-1] = eot_token_id
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
388
+
389
+
390
def get_reduction_mask_fn(type: str):
    """Choose the strategy for dropping (masking) tokens to reach the target
    context length.

    Args:
        type: one of 'simple', 'random', 'shuffle', 'syntax'.

    Raises:
        ValueError: for an unknown strategy name.  The original used asserts,
            which are stripped under ``python -O`` and had an unreachable
            duplicate check in the final else branch.
    """
    strategies = {
        'simple': simple_mask_tokenize,                      # random block [start:end]
        'random': random_mask_tokenize,                      # random drop, keep order
        'shuffle': partial(random_mask_tokenize, shuffle=True),  # random drop, shuffled
        'syntax': syntax_mask_tokenize,                      # drop prioritized by syntax
    }
    if type not in strategies:
        raise ValueError(f'Unknown type {type}.')
    return strategies[type]
403
+
404
+
405
class HFTokenizer:
    """HuggingFace tokenizer wrapper with support for custom tokenization modes.

    Wraps a ``transformers.AutoTokenizer`` and applies the same text-cleaning
    pipeline as the default tokenizer before encoding.  ``tokenizer_mode``
    selects between standard HF encoding and a custom 'clips' framing mode.
    """

    def __init__(
            self,
            tokenizer_name: str,
            context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
            clean: str = 'whitespace',
            strip_sep_token: bool = False,
            language: Optional[str] = None,
            cache_dir: Optional[str] = None,
            tokenizer_mode: Optional[str] = None,  # None, 'clips'
            **kwargs
    ):
        # Normalize None -> '' so mode checks below are simple string compares.
        self.tokenizer_mode = tokenizer_mode or ''
        self.context_length = context_length
        self.clean_fn = get_clean_fn(clean)
        self.strip_sep_token = strip_sep_token

        # NOTE: Left as example of loading custom tokenizer from file for experimentation
        # if self.tokenizer_mode == 'bert_clips':
        # self.special_tokens = {
        # "bos_token": 1,
        # "eos_token": 2,
        # "cls_token": 101,
        # "pad_token": 0
        # }
        #
        # # For BERT CLIPS mode with vocab file
        # from tokenizers import BertWordPieceTokenizer
        # if tokenizer_name.startswith('hf-hub:'):
        # from huggingface_hub import hf_hub_download
        # # Format: hf-hub:repo_id/filename
        # repo_url = tokenizer_name[7:]
        # parts = repo_url.split('/')
        # filename = parts[-1]
        # repo_id = '/'.join(parts[:-1])
        # vocab_file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)
        # self.tokenizer = BertWordPieceTokenizer(lowercase=True)
        # self.tokenizer = self.tokenizer.from_file(vocab_file)
        # else:
        # # Assume tokenizer_name is a local path to a vocab file
        # self.tokenizer = BertWordPieceTokenizer(lowercase=True)
        # self.tokenizer = self.tokenizer.from_file(tokenizer_name)

        # Standard HuggingFace tokenizer initialization
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            cache_dir=cache_dir,
            **kwargs
        )

        # Set language function if available (e.g. multilingual NLLB/mBART
        # tokenizers expose set_src_lang_special_tokens).
        set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None)
        if callable(set_lang_fn):
            self.set_lang_fn = set_lang_fn
        if language is not None:
            self.set_language(language)

    def save_pretrained(self, dest):
        """Persist the wrapped HF tokenizer to the *dest* directory."""
        self.tokenizer.save_pretrained(dest)

    def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:
        """Tokenize input string(s) to a (batch, context_length) id tensor.

        Applies ``self.clean_fn`` to each text, then either the custom
        'clips' framing or standard padded/truncated HF encoding.
        """
        # same cleaning as for default tokenizer, except lowercasing
        # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
        if isinstance(texts, str):
            texts = [texts]

        context_length = context_length or self.context_length
        assert context_length, 'Please set a valid context length in class init or call.'

        texts = [self.clean_fn(text) for text in texts]

        # Handle different tokenization modes
        if self.tokenizer_mode == 'clips':
            return self._clips_tokenize(texts, context_length)
        else:
            # Standard tokenization: pad to max_length, truncate overflow.
            input_ids = self.tokenizer.batch_encode_plus(
                texts,
                return_tensors='pt',
                max_length=context_length,
                padding='max_length',
                truncation=True,
            ).input_ids

            if self.strip_sep_token:
                # Replace [SEP] ids with 0 rather than removing them, so the
                # tensor shape stays fixed.
                input_ids = torch.where(
                    input_ids == self.tokenizer.sep_token_id,
                    torch.zeros_like(input_ids),
                    input_ids,
                )

            return input_ids

    def set_language(self, src_lang):
        """Set the tokenizer's source language, if the tokenizer supports it."""
        if hasattr(self, 'set_lang_fn'):
            self.set_lang_fn(src_lang)
        else:
            warnings.warn('Cannot set language for the tokenizer.')

    def _clips_tokenize(self, texts: List[str], context_length: int) -> torch.Tensor:
        """Use standard HF tokenizer but apply custom post-processing:
        manual BOS/EOS framing, padding, and a trailing class token."""
        # Use standard tokenizer without special tokens - we'll add our own
        encoded_outputs = self.tokenizer.batch_encode_plus(
            texts,
            add_special_tokens=False,
            padding=False,
            truncation=False,
            return_tensors=None
        )

        encoded = []
        for tokens in encoded_outputs["input_ids"]:
            # Leave room for bos, eos and the class token appended later.
            tokens = tokens[:context_length - 3]
            tokens = [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
            encoded.append(tokens)

        # Create result tensor and handle padding + class token
        result = torch.zeros(len(encoded), context_length, dtype=torch.long)
        for i, tokens in enumerate(encoded):
            padded_tokens = self._pad_and_add_class_token(
                tokens,
                max_length=context_length,
                pad_token_id=self.tokenizer.pad_token_id,
                cls_token_id=self.tokenizer.cls_token_id,
            )
            result[i, :len(padded_tokens)] = torch.tensor(padded_tokens)

        return result

    def _pad_and_add_class_token(
            self,
            tokens: List[int],
            max_length: int,
            pad_token_id: int = 0,
            cls_token_id: int = 101,
    ) -> List[int]:
        """Pad `tokens` to max_length - 1 and append the class token last.

        Returns a list of exactly `max_length` ids with the class token at
        the final position.
        """
        if len(tokens) > max_length - 1:
            tokens = tokens[:max_length - 1]

        # Add padding to reach max_length-1
        if len(tokens) < max_length - 1:
            tokens = tokens + [pad_token_id] * (max_length - 1 - len(tokens))

        # Add class token at the end
        tokens = tokens + [cls_token_id]
        return tokens
555
+
556
+
557
class SigLipTokenizer:
    """HuggingFace tokenizer wrapper for SigLIP T5 compatible sentencepiece vocabs

    NOTE: this is not needed in normal library use, but is used to import new sentencepiece tokenizers
    into OpenCLIP. Leaving code here in case future models use new tokenizers.
    """
    VOCAB_FILES = {
        # english, vocab_size=32_000
        "c4-en": "http://storage.googleapis.com/t5-data/vocabs/cc_en.32000/sentencepiece.model",
        # used in multilingual models (mT5, PaLI), vocab_size=250_000
        "mc4": "http://storage.googleapis.com/t5-data/vocabs/mc4.250000.100extra/sentencepiece.model",
        # used in SigLIP2 models, vocab_size=256000
        "gemma": "http://storage.googleapis.com/big_vision/gemma_tokenizer.model",
    }

    def __init__(
            self,
            tokenizer_name: str,
            context_length: Optional[int] = 64,
    ):
        if 'gemma' in tokenizer_name:
            from transformers import GemmaTokenizerFast
            tokenizer_cls = partial(
                GemmaTokenizerFast, padding_side='right', add_bos_token=False, add_eos_token=True)
        else:
            from transformers import T5TokenizerFast
            tokenizer_cls = partial(T5TokenizerFast, extra_ids=0)

        if tokenizer_name in self.VOCAB_FILES:
            # FIXME temporary hack? download known vocab into a temp file.
            import tempfile
            import fsspec
            vocab_file = self.VOCAB_FILES[tokenizer_name]
            with tempfile.NamedTemporaryFile('wb') as dst:
                with fsspec.open(vocab_file, 'rb') as src:
                    dst.write(src.read())
                # BUG FIX: flush buffered bytes before the tokenizer re-opens
                # the file by name; without this it could read a truncated
                # (still-buffered) sentencepiece model.
                dst.flush()
                self.tokenizer = tokenizer_cls(dst.name, legacy=False)
        else:
            self.tokenizer = tokenizer_cls(tokenizer_name, legacy=False)

        # SigLIP convention: pad id is 0 for gemma vocabs, 1 otherwise; eos is 1.
        self.tokenizer.pad_token_id = 0 if 'gemma' in tokenizer_name else 1
        self.tokenizer.eos_token_id = 1
        self.context_length = context_length

    def save_pretrained(self, dest):
        """Persist the wrapped HF tokenizer to the *dest* directory."""
        self.tokenizer.save_pretrained(dest)

    def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:
        """Tokenize input string(s) to a (batch, context_length) id tensor.

        Same cleaning as for the default tokenizer, except lowercasing is
        applied via canonicalization (punctuation removed as well).
        """
        if isinstance(texts, str):
            texts = [texts]

        context_length = context_length or self.context_length
        assert context_length, 'Please set a valid context length in class init or call.'

        texts = [canonicalize_text(basic_clean(text)) for text in texts]
        output = self.tokenizer(
            texts,
            return_tensors='pt',
            max_length=context_length,
            padding='max_length',
            truncation=True,
        )
        return output.input_ids
src/open_clip/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .factory import create_model, load_checkpoint, get_tokenizer, get_input_dtype
2
+ from .tokenizer import SimpleTokenizer
src/open_clip/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (319 Bytes). View file
 
src/open_clip/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (330 Bytes). View file
 
src/open_clip/__pycache__/biosignals_coca_model.cpython-310.pyc ADDED
Binary file (44.3 kB). View file
 
src/open_clip/__pycache__/biosignals_coca_model.cpython-313.pyc ADDED
Binary file (70.3 kB). View file
 
src/open_clip/__pycache__/coca_model.cpython-310.pyc ADDED
Binary file (13 kB). View file
 
src/open_clip/__pycache__/coca_model.cpython-313.pyc ADDED
Binary file (21.2 kB). View file
 
src/open_clip/__pycache__/factory.cpython-310.pyc ADDED
Binary file (3.26 kB). View file
 
src/open_clip/__pycache__/factory.cpython-313.pyc ADDED
Binary file (5.05 kB). View file
 
src/open_clip/__pycache__/model.cpython-310.pyc ADDED
Binary file (24.5 kB). View file
 
src/open_clip/__pycache__/model.cpython-313.pyc ADDED
Binary file (42.6 kB). View file
 
src/open_clip/__pycache__/tokenizer.cpython-310.pyc ADDED
Binary file (18.6 kB). View file
 
src/open_clip/__pycache__/tokenizer.cpython-313.pyc ADDED
Binary file (28.6 kB). View file
 
src/open_clip/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (44.1 kB). View file
 
src/open_clip/__pycache__/transformer.cpython-313.pyc ADDED
Binary file (79.7 kB). View file
 
src/open_clip/biosignals_coca_model.py ADDED
@@ -0,0 +1,1807 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Biosignals-Text CoCa Model
3
+
4
+ Adapted from the original CoCa model to work with biosignals (time series) data
5
+ instead of images. This model is designed for biosignals-text contrastive learning.
6
+ """
7
+
8
+ from typing import Dict, List, Optional, Union, Tuple
9
+ import torch
10
+ from torch import nn
11
+ from torch.nn import functional as F
12
+ import numpy as np
13
+ import math
14
+ from dataclasses import dataclass, field
15
+
16
+ from .transformer import (
17
+ LayerNormFp32,
18
+ LayerNorm,
19
+ QuickGELU,
20
+ MultimodalTransformer,
21
+ ConcatMultimodalTransformer,
22
+ )
23
+ from .model import CLIPTextCfg, _build_text_tower
24
+ from .coca_model import MultimodalCfg, _build_text_decoder_tower, _token_to_tensor
25
+
26
+ try:
27
+ from transformers.generation.beam_search import BeamSearchScorer
28
+ from transformers.generation.logits_process import (
29
+ LogitsProcessorList,
30
+ TopPLogitsWarper,
31
+ TopKLogitsWarper,
32
+ RepetitionPenaltyLogitsProcessor,
33
+ MinLengthLogitsProcessor,
34
+ )
35
+ from transformers.generation.stopping_criteria import (
36
+ MaxLengthCriteria,
37
+ EosTokenCriteria,
38
+ StoppingCriteriaList,
39
+ )
40
+
41
+ GENERATION_TYPES = {
42
+ "top_k": TopKLogitsWarper,
43
+ "top_p": TopPLogitsWarper,
44
+ "beam_search": "beam_search"
45
+ }
46
+ _has_transformers = True
47
+ except ImportError as e:
48
+ GENERATION_TYPES = {
49
+ "top_k": None,
50
+ "top_p": None,
51
+ "beam_search": "beam_search"
52
+ }
53
+ _has_transformers = False
54
+
55
+
56
+ # ============================================================================
57
+ # Pure Transformer Architecture Components (from PureTransformerMAE)
58
+ # ============================================================================
59
+
60
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE).

    Encodes positions by rotating even/odd dimension pairs of queries/keys
    by position-dependent angles.  Frequencies are fixed (temporal attention)
    or learnable (channel attention).
    """

    def __init__(self, dim: int, theta: float = 10000.0, learned_freq: bool = False):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.learned_freq = learned_freq

        if learned_freq:
            # Learnable frequencies, used for channel attention.
            self.freqs = nn.Parameter(torch.randn(dim // 2) * 0.02)
        else:
            # Fixed inverse-frequency schedule, used for temporal attention.
            inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
            self.register_buffer('freqs', inv_freq)

    def rotate_queries_or_keys(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None):
        """Apply the rotary rotation to queries or keys.

        Args:
            x: (batch_size, num_heads, seq_len, head_dim) tensor.
            position_ids: optional (seq_len,) or (batch_size, seq_len)
                position indices; defaults to arange(seq_len).
        Returns:
            A tensor of the same shape with positions encoded by rotation.
        """
        _, _, seq_len, head_dim = x.shape
        assert head_dim == self.dim, f"head_dim {head_dim} != self.dim {self.dim}"

        if position_ids is None:
            positions = torch.arange(seq_len, device=x.device, dtype=torch.float)
        elif position_ids.ndim == 2:
            # All rows are assumed to share the same pattern; use the first.
            positions = position_ids[0].float()
        else:
            positions = position_ids.float()

        # angles: (seq_len, dim // 2); each pair of dims shares one angle,
        # so cos/sin are interleave-duplicated to full head_dim and reshaped
        # to (1, 1, seq_len, dim) for broadcasting.
        angles = torch.einsum('s,d->sd', positions, self.freqs)
        cos = torch.cos(angles).repeat_interleave(2, dim=-1)[None, None]
        sin = torch.sin(angles).repeat_interleave(2, dim=-1)[None, None]

        even = x[..., 0::2]
        odd = x[..., 1::2]

        # 2D rotation per pair: [even, odd] @ [[cos, -sin], [sin, cos]]
        rotated = torch.empty_like(x)
        rotated[..., 0::2] = even * cos[..., 0::2] - odd * sin[..., 0::2]
        rotated[..., 1::2] = even * sin[..., 1::2] + odd * cos[..., 1::2]
        return rotated
123
+
124
+
125
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction, no bias)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Divide by the RMS over the last dimension, via rsqrt of the mean square.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Normalize in float32 for numerical stability, then cast back.
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight
138
+
139
+
140
class SwiGLU(nn.Module):
    """SwiGLU gate: SiLU(x @ W1) elementwise-multiplied with (x @ W2)."""

    def __init__(self, dim_in: int, dim_out: int, bias: bool = False):
        super().__init__()
        self.w1 = nn.Linear(dim_in, dim_out, bias=bias)
        self.w2 = nn.Linear(dim_in, dim_out, bias=bias)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return gate * value
149
+
150
+
151
class MLP(nn.Module):
    """Feed-forward block with configurable activation and projection bias.

    For ``activation="swiglu"`` the up-projection is a gated SwiGLU pair;
    otherwise a standard Linear -> activation -> Linear stack is used.
    Dropout is applied after the hidden activation and again after the
    down-projection.
    """

    def __init__(self,
                 dim: int,
                 hidden_dim: int,
                 dropout: float = 0.0,
                 activation: str = "swiglu",  # "swiglu", "gelu", "relu"
                 bias: bool = False):
        super().__init__()
        self.activation = activation

        if activation == "swiglu":
            # Gated variant: two parallel projections fused inside SwiGLU.
            self.gate_proj = SwiGLU(dim, hidden_dim, bias=bias)
            self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
        else:
            # Standard up/down projection pair.
            self.up_proj = nn.Linear(dim, hidden_dim, bias=bias)
            self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
            if activation == "gelu":
                self.act_fn = nn.GELU()
            elif activation == "relu":
                self.act_fn = nn.ReLU()
            else:
                raise ValueError(f"Unknown activation: {activation}")

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        if self.activation == "swiglu":
            hidden = self.dropout(self.gate_proj(x))
        else:
            hidden = self.dropout(self.act_fn(self.up_proj(x)))
        return self.dropout(self.down_proj(hidden))
192
+
193
+
194
class ChannelPatching(nn.Module):
    """Tokenize each channel of a multi-channel signal independently.

    A single Conv1d (kernel = stride = patch_size, no padding) is shared
    across all channels, producing non-overlapping patch embeddings per
    channel.
    """

    def __init__(self,
                 patch_size: int = 32,
                 conv_embed_dim: int = 256,
                 num_channels: int = 21):
        super().__init__()
        self.patch_size = patch_size
        self.conv_embed_dim = conv_embed_dim
        self.num_channels = num_channels

        # One shared tokenizer conv; channels are folded into the batch dim
        # during forward so each channel is processed on its own.
        self.conv_patching = nn.Conv1d(
            in_channels=1,
            out_channels=conv_embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            padding=0,  # no padding -> clean non-overlapping patches
        )

    def forward(self, x):
        """
        Args:
            x: (batch_size, num_channels, signal_length) multi-channel signal
        Returns:
            (batch_size, num_channels, num_patches, conv_embed_dim)
        """
        bsz, n_ch, sig_len = x.shape

        # Fold channels into the batch so the shared conv sees each alone.
        flat = x.reshape(bsz * n_ch, 1, sig_len)
        patches = self.conv_patching(flat)  # (bsz * n_ch, conv_embed_dim, n_patches)

        # Unfold batch/channel and move the embedding dimension last.
        _, emb_dim, n_patches = patches.shape
        patches = patches.reshape(bsz, n_ch, emb_dim, n_patches).transpose(2, 3)
        return patches
237
+
238
+
239
class DualRoPEAttention(nn.Module):
    """Multi-head self-attention whose RoPE depends on the attention axis.

    ``attention_type="temporal"`` uses fixed-frequency RoPE over patch
    positions; ``attention_type="channel"`` uses a learnable-frequency RoPE
    over channel indices, optionally shared across blocks via
    ``shared_channel_rope``.
    """
    def __init__(self,
                 embed_dim: int = 256,
                 num_heads: int = 8,
                 dropout: float = 0.1,
                 attention_type: str = "temporal",  # "temporal" or "channel"
                 num_channels: int = 21,
                 shared_channel_rope: Optional[nn.Module] = None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.attention_type = attention_type

        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        # Separate Q/K/V projections (no bias) plus an output projection.
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # Axis-appropriate rotary embedding.
        if attention_type == "temporal":
            # Fixed-frequency RoPE along the time axis.
            self.rotary_emb = RotaryEmbedding(
                dim=self.head_dim,
                theta=10000,
                learned_freq=False
            )
        elif attention_type == "channel":
            if shared_channel_rope is not None:
                # Reuse a single learnable channel RoPE across blocks.
                self.rotary_emb = shared_channel_rope
            else:
                # Fallback: this layer owns its learnable channel RoPE.
                self.rotary_emb = RotaryEmbedding(
                    dim=self.head_dim,
                    theta=10000,
                    learned_freq=True
                )
        else:
            raise ValueError(f"Unknown attention_type: {attention_type}")

        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def _split_heads(self, t, bsz, length):
        # (B, L, E) -> (B, H, L, head_dim)
        return t.view(bsz, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, position_ids=None):
        """
        Args:
            x: (batch_size, seq_len, embed_dim)
            position_ids: optional (seq_len,) or (batch_size, seq_len)
                position indices for RoPE
        Returns:
            (batch_size, seq_len, embed_dim)
        """
        bsz, length, emb = x.shape

        q = self._split_heads(self.q_proj(x), bsz, length)
        k = self._split_heads(self.k_proj(x), bsz, length)
        v = self._split_heads(self.v_proj(x), bsz, length)

        # Rotate queries and keys with the axis-appropriate RoPE.
        q = self.rotary_emb.rotate_queries_or_keys(q, position_ids=position_ids)
        k = self.rotary_emb.rotate_queries_or_keys(k, position_ids=position_ids)

        # Scaled dot-product attention with dropout on the weights.
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.dropout(F.softmax(scores, dim=-1))
        context = torch.matmul(weights, v)

        # Merge heads back and project out.
        context = context.transpose(1, 2).contiguous().view(bsz, length, emb)
        return self.out_proj(context)
324
+
325
+
326
class DualTransformerBlock(nn.Module):
    """Biosignal transformer block with factorized channel + temporal attention.

    Input is kept as (batch, channels, patches, dim). One channel-attention
    pass (attending across channels at each patch position) is followed by
    ``num_temporal_layers`` temporal-attention passes (attending across patch
    positions within each channel). Each attention/MLP sub-layer uses a
    post-norm residual: norm(x + sublayer(x)).
    """
    def __init__(self,
                 embed_dim: int = 256,
                 num_heads: int = 8,
                 num_temporal_layers: int = 2,
                 dropout: float = 0.1,
                 mlp_ratio: float = 4.0,
                 num_channels: int = 21,
                 activation: str = "swiglu",
                 norm_type: str = "rmsnorm",
                 mlp_bias: bool = False,
                 shared_channel_rope: Optional[nn.Module] = None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_temporal_layers = num_temporal_layers

        # Helper function to create normalization layer ("rmsnorm" or "layernorm").
        def create_norm(dim):
            if norm_type == "rmsnorm":
                return RMSNorm(dim)
            elif norm_type == "layernorm":
                return nn.LayerNorm(dim)
            else:
                raise ValueError(f"Unknown norm_type: {norm_type}")

        # Channel-wise attention; uses the (optionally shared) learnable RoPE.
        self.channel_attention = DualRoPEAttention(
            embed_dim, num_heads, dropout,
            attention_type="channel", num_channels=num_channels,
            shared_channel_rope=shared_channel_rope
        )
        self.channel_norm = create_norm(embed_dim)

        # Temporal attention layers with standard (fixed-frequency) RoPE.
        self.temporal_attention_layers = nn.ModuleList([
            DualRoPEAttention(embed_dim, num_heads, dropout, attention_type="temporal")
            for _ in range(num_temporal_layers)
        ])
        self.temporal_norms = nn.ModuleList([
            create_norm(embed_dim)
            for _ in range(num_temporal_layers)
        ])

        # One MLP after channel attention, one per temporal layer.
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.channel_mlp = MLP(
            dim=embed_dim,
            hidden_dim=mlp_hidden_dim,
            dropout=dropout,
            activation=activation,
            bias=mlp_bias
        )

        self.temporal_mlps = nn.ModuleList([
            MLP(
                dim=embed_dim,
                hidden_dim=mlp_hidden_dim,
                dropout=dropout,
                activation=activation,
                bias=mlp_bias
            ) for _ in range(num_temporal_layers)
        ])

        self.channel_mlp_norm = create_norm(embed_dim)
        self.temporal_mlp_norms = nn.ModuleList([
            create_norm(embed_dim)
            for _ in range(num_temporal_layers)
        ])

    def forward(self, x, temporal_position_ids=None):
        """
        Args:
            x: (batch_size, num_channels, num_patches, embed_dim)
            temporal_position_ids: (batch_size, num_patches) or (num_patches,)
                position indices for temporal RoPE; channel attention uses
                default 0..C-1 positions.
        Returns:
            (batch_size, num_channels, num_patches, embed_dim)
        """
        batch_size, num_channels, num_patches, embed_dim = x.shape

        # 1. Channel-wise attention: fold (batch, patch) together so each patch
        #    position attends across its channels independently.
        #    (B, C, T, D) -> (B*T, C, D)
        x_for_channel_attn = x.permute(0, 2, 1, 3).contiguous().reshape(batch_size * num_patches, num_channels, embed_dim)

        # Apply channel attention with learnable RoPE (positions = channel indices).
        channel_attn_out = self.channel_attention(x_for_channel_attn)

        # Post-norm residual: norm(x + attn(x)).
        x_for_channel_attn = self.channel_norm(x_for_channel_attn + channel_attn_out)

        # MLP sub-layer with the same post-norm residual pattern.
        channel_mlp_out = self.channel_mlp(x_for_channel_attn)
        x_for_channel_attn = self.channel_mlp_norm(x_for_channel_attn + channel_mlp_out)

        # Undo the fold: (B*T, C, D) -> (B, C, T, D).
        x = x_for_channel_attn.reshape(batch_size, num_patches, num_channels, embed_dim).permute(0, 2, 1, 3)

        # 2. Temporal attention: fold (batch, channel) together so each channel
        #    attends across its patch positions independently. (B, C, T, D) -> (B*C, T, D)
        x_for_temporal_attn = x.reshape(batch_size * num_channels, num_patches, embed_dim)

        # Collapse 2D position ids to one row (RotaryEmbedding would do the
        # same; done here once instead of per layer).
        if temporal_position_ids is not None:
            if temporal_position_ids.ndim == 2:
                temporal_pos_ids_expanded = temporal_position_ids[0]
            else:
                temporal_pos_ids_expanded = temporal_position_ids
        else:
            temporal_pos_ids_expanded = None

        # Apply the stacked temporal attention + MLP layers.
        for i in range(self.num_temporal_layers):
            temporal_attn_out = self.temporal_attention_layers[i](x_for_temporal_attn, position_ids=temporal_pos_ids_expanded)
            x_for_temporal_attn = self.temporal_norms[i](x_for_temporal_attn + temporal_attn_out)

            temporal_mlp_out = self.temporal_mlps[i](x_for_temporal_attn)
            x_for_temporal_attn = self.temporal_mlp_norms[i](x_for_temporal_attn + temporal_mlp_out)

        # Undo the fold: (B*C, T, D) -> (B, C, T, D).
        x = x_for_temporal_attn.reshape(batch_size, num_channels, num_patches, embed_dim)

        return x
446
+
447
+
448
+ # ============================================================================
449
+ # End of Pure Transformer Architecture Components
450
+ # ============================================================================
451
+
452
+
453
def _build_signal_tower(
        embed_dim: int,
        signal_cfg,
        output_tokens: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Build a biosignals encoder tower.

    Args:
        embed_dim: Output embedding dimension.
        signal_cfg: BiosignalsCfg or dict with configuration.
        output_tokens: Whether to output tokens for a multimodal decoder.
        cast_dtype: Optional dtype for casting.

    Returns:
        BiosignalsEncoder or PureTransformerBiosignalsEncoder, selected by
        ``signal_cfg.architecture``.

    Raises:
        ValueError: If ``signal_cfg.architecture`` is not recognized.
    """
    if isinstance(signal_cfg, dict):
        signal_cfg = BiosignalsCfg(**signal_cfg)

    import logging
    architecture = getattr(signal_cfg, 'architecture', 'conv_transformer')
    # Lazy %-style logging args: formatting is skipped when INFO is disabled.
    logging.info("Building biosignals encoder with architecture: %s", architecture)

    if architecture == "pure_transformer":
        signal_encoder = PureTransformerBiosignalsEncoder(
            biosignals_cfg=signal_cfg,
            embed_dim=embed_dim,
            output_tokens=output_tokens,
            cast_dtype=cast_dtype
        )
        logging.info("Pure Transformer architecture:")
        logging.info("  Patch size: %s", signal_cfg.patch_size)
        logging.info("  Conv embed dim: %s", signal_cfg.conv_embed_dim)
        logging.info("  Transformer blocks: %s", signal_cfg.transformer_layers)
        logging.info("  Temporal layers per block: %s", signal_cfg.num_temporal_layers)
        logging.info("  Activation: %s", signal_cfg.activation)
        logging.info("  Norm type: %s", signal_cfg.norm_type)
        logging.info("  Share channel RoPE: %s", signal_cfg.share_channel_rope)
    elif architecture == "conv_transformer":
        signal_encoder = BiosignalsEncoder(
            biosignals_cfg=signal_cfg,
            embed_dim=embed_dim,
            output_tokens=output_tokens,
            cast_dtype=cast_dtype
        )
        logging.info("Conv-Transformer architecture:")
        logging.info("  Conv layers: %s", signal_cfg.conv_layers)
        logging.info("  Kernel sizes: %s", signal_cfg.kernel_sizes)
        logging.info("  Strides: %s", signal_cfg.strides)
        logging.info("  Transformer layers: %s", signal_cfg.transformer_layers)
    else:
        raise ValueError(f"Unknown architecture: {architecture}. Must be 'conv_transformer' or 'pure_transformer'")

    return signal_encoder
508
+
509
+
510
def _build_text_decoder_tower_v2(
        embed_dim,
        multimodal_cfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
        decoder_type: str = "cross_attention",
        prefix_len: int = 0,
):
    """Build a text decoder tower, selecting the decoder class by type.

    Args:
        embed_dim: Embedding dimension.
        multimodal_cfg: MultimodalCfg instance or dict.
        quick_gelu: Whether to use QuickGELU instead of nn.GELU.
        cast_dtype: Optional dtype; fp16/bf16 selects the fp32-safe LayerNorm.
        decoder_type: "cross_attention" (default CoCa, separate cross-attention
            layers) or "concat" (concatenates image/biosignals and text tokens).
        prefix_len: Number of prefix (condition) tokens prepended to text,
            used to pre-build the prefix-causal attention mask.

    Raises:
        ValueError: If ``decoder_type`` is not recognized.
    """
    if isinstance(multimodal_cfg, dict):
        multimodal_cfg = MultimodalCfg(**multimodal_cfg)

    act_layer = QuickGELU if quick_gelu else nn.GELU
    # Half-precision training needs the fp32-upcasting LayerNorm variant.
    needs_fp32_norm = cast_dtype in (torch.float16, torch.bfloat16)
    norm_layer = LayerNormFp32 if needs_fp32_norm else LayerNorm

    # Both decoder variants take identical constructor arguments.
    decoder_classes = {
        "cross_attention": MultimodalTransformer,
        "concat": ConcatMultimodalTransformer,
    }
    if decoder_type not in decoder_classes:
        raise ValueError(f"Unknown decoder_type: {decoder_type}. Must be 'cross_attention' or 'concat'")

    return decoder_classes[decoder_type](
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        mlp_ratio=multimodal_cfg.mlp_ratio,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=act_layer,
        norm_layer=norm_layer,
        prefix_len=prefix_len,
    )
567
+
568
+
569
@dataclass
class BiosignalsCfg:
    """Configuration for biosignals encoders.

    Fields typed ``Optional[List[int]]`` default to ``None`` and are filled
    with conv_transformer defaults in ``__post_init__`` (they stay ``None``
    for the pure_transformer architecture, which does not use them).
    """
    input_channels: int = 12         # Number of input channels (e.g., 12-lead ECG)
    signal_length: int = 1000        # Length of input time series
    sampling_rate: int = 500         # Sampling rate in Hz

    # Architecture selection: "conv_transformer" or "pure_transformer"
    architecture: str = "conv_transformer"

    # Architecture parameters for conv_transformer (None -> defaults in __post_init__)
    conv_layers: Optional[List[int]] = None    # Conv layer output dimensions
    kernel_sizes: Optional[List[int]] = None   # Kernel sizes for conv layers
    strides: Optional[List[int]] = None        # Strides for conv layers

    # Architecture parameters for pure_transformer
    patch_size: int = 32             # Temporal patch size
    conv_embed_dim: int = 256        # Per-patch conv embedding dimension
    num_temporal_layers: int = 2     # Temporal attention layers per block
    activation: str = "swiglu"       # "swiglu", "gelu", "relu"
    norm_type: str = "rmsnorm"       # "rmsnorm", "layernorm"
    mlp_bias: bool = False           # Whether MLP layers use bias
    share_channel_rope: bool = True  # Share channel RoPE across blocks
    decoder_tokens: int = 32         # Decoder tokens for dual-axis transformer

    # Transformer parameters (shared by both architectures)
    transformer_layers: int = 6      # Number of transformer layers/blocks
    transformer_width: int = 768     # Transformer width
    transformer_heads: int = 12      # Number of attention heads
    mlp_ratio: float = 4.0           # MLP expansion ratio

    # Pooling and output
    pool_type: str = 'attn'          # 'avg', 'max', 'cls', 'attn'
    dropout: float = 0.1

    def __post_init__(self):
        # Fill in conv_transformer defaults only when that architecture is
        # selected; pure_transformer leaves the conv fields untouched.
        if self.architecture == "conv_transformer":
            if self.conv_layers is None:
                self.conv_layers = [64, 128, 256, 512]
            if self.kernel_sizes is None:
                self.kernel_sizes = [7, 5, 3, 3]
            if self.strides is None:
                self.strides = [2, 2, 2, 2]
615
+
616
+
617
class BaseBiosignalsEncoder(nn.Module):
    """
    Base class for biosignals encoders that handles common pooling and projection logic.
    Child classes should implement _encode() to return features before pooling.

    Pooling convention: when a CLS token is present, it sits at the LAST
    sequence position (matching a causal text encoder), not the first.
    """

    def __init__(
        self,
        biosignals_cfg: BiosignalsCfg,
        embed_dim: int,
        output_tokens: bool,
        transformer_width: int,
        cast_dtype: Optional[torch.dtype] = None
    ):
        super().__init__()
        self.biosignals_cfg = biosignals_cfg
        self.embed_dim = embed_dim
        # When True, forward() also returns per-position tokens for a
        # multimodal decoder (CLS excluded).
        self.output_tokens = output_tokens
        self.transformer_width = transformer_width
        self.pool_type = biosignals_cfg.pool_type

        # Projection from transformer width to the output embedding dimension.
        self.proj_to_embed = nn.Linear(transformer_width, embed_dim)

        # Attention pooling head, only built when pool_type == 'attn'.
        if self.pool_type == 'attn':
            self.attn_pool = nn.MultiheadAttention(
                transformer_width,
                biosignals_cfg.transformer_heads,
                batch_first=True
            )

    def _pool_features(self, x: torch.Tensor, has_cls_token: bool) -> torch.Tensor:
        """
        Pool features using the configured pooling method.

        Args:
            x: Features of shape (batch_size, seq_len, width)
            has_cls_token: Whether the sequence includes a CLS token at the last position

        Returns:
            pooled: Pooled features of shape (batch_size, width)

        Raises:
            ValueError: If pool_type is not one of 'cls', 'avg', 'max', 'attn'.
        """
        if self.pool_type == 'cls':
            # Use class token (last position by convention, see class docstring).
            pooled = x[:, -1]
        elif self.pool_type == 'avg':
            # Average pooling over content tokens (CLS excluded if present).
            if has_cls_token:
                pooled = x[:, :-1].mean(dim=1)
            else:
                pooled = x.mean(dim=1)
        elif self.pool_type == 'max':
            # Max pooling over content tokens (CLS excluded if present).
            if has_cls_token:
                pooled = x[:, :-1].max(dim=1)[0]
            else:
                pooled = x.max(dim=1)[0]
        elif self.pool_type == 'attn':
            # Attention pooling: the CLS token (last position) is the single
            # query; keys/values are the content tokens only.
            query = x[:, -1:]  # CLS token as query
            pooled, _ = self.attn_pool(query, x[:, :-1], x[:, :-1])
            pooled = pooled.squeeze(1)
        else:
            raise ValueError(f"Unknown pool_type: {self.pool_type}")

        return pooled

    def _encode(self, biosignals: torch.Tensor) -> Tuple[torch.Tensor, bool]:
        """
        Encode biosignals to features. Must be implemented by child classes.

        Args:
            biosignals: Input biosignals tensor

        Returns:
            features: Encoded features of shape (batch_size, seq_len, transformer_width)
            has_cls_token: Whether the sequence includes a CLS token at the last position
        """
        raise NotImplementedError("Child classes must implement _encode()")

    def forward(self, biosignals: torch.Tensor):
        """
        Forward pass with encoding, pooling, and projection.

        Args:
            biosignals: Input biosignals tensor

        Returns:
            embedding: Global embedding (batch_size, embed_dim)
            tokens_for_decoder: Only when ``self.output_tokens`` is True:
                per-position tokens (batch_size, seq_len, transformer_width),
                with the CLS token stripped if present.
        """
        # Encode to per-position features plus a CLS-presence flag.
        features, has_cls_token = self._encode(biosignals)

        # Pool features to a single vector per sample.
        pooled = self._pool_features(features, has_cls_token)

        # Project to final embedding dimension.
        embedding = self.proj_to_embed(pooled)

        if self.output_tokens:
            # Return tokens for multimodal decoder (CLS excluded — it only
            # serves the global embedding).
            if has_cls_token:
                tokens_for_decoder = features[:, :-1]
            else:
                tokens_for_decoder = features
            return embedding, tokens_for_decoder
        else:
            return embedding

    def set_grad_checkpointing(self, enable=True):
        # No-op: kept so callers can treat this like the other towers that
        # support gradient checkpointing.
        pass
733
+
734
+
735
class Conv1dBlock(nn.Module):
    """Conv1d -> normalization -> activation -> dropout(0.1).

    Padding of kernel_size // 2 keeps the output length stable at stride 1
    (for odd kernels).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 norm_layer=nn.BatchNorm1d, act_layer=nn.ReLU):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=kernel_size // 2,
        )
        self.norm = norm_layer(out_channels)
        self.act = act_layer()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # Single fused pipeline: conv, norm, nonlinearity, dropout.
        return self.dropout(self.act(self.norm(self.conv(x))))
755
+
756
+
757
class BiosignalsEncoder(BaseBiosignalsEncoder):
    """
    Biosignals encoder that converts time series data to embeddings.
    Uses a combination of 1D convolutions and transformers.

    Pipeline: Conv1d stack (downsampling) -> linear projection to
    transformer width -> learned positional embeddings -> optional CLS token
    (appended last) -> transformer stack -> final LayerNorm. Pooling and
    output projection are inherited from BaseBiosignalsEncoder.
    """

    def __init__(
        self,
        biosignals_cfg: BiosignalsCfg,
        embed_dim: int = 512,
        output_tokens: bool = False,
        cast_dtype: Optional[torch.dtype] = None
    ):
        # Initialize base class with common pooling/projection logic
        super().__init__(
            biosignals_cfg=biosignals_cfg,
            embed_dim=embed_dim,
            output_tokens=output_tokens,
            transformer_width=biosignals_cfg.transformer_width,
            cast_dtype=cast_dtype
        )

        # Convolutional feature extraction: one Conv1dBlock per configured
        # (out_channels, kernel_size, stride) triple.
        conv_layers = []
        in_channels = biosignals_cfg.input_channels

        for i, (out_channels, kernel_size, stride) in enumerate(
            zip(biosignals_cfg.conv_layers, biosignals_cfg.kernel_sizes, biosignals_cfg.strides)
        ):
            conv_layers.append(
                Conv1dBlock(in_channels, out_channels, kernel_size, stride)
            )
            in_channels = out_channels

        self.conv_layers = nn.Sequential(*conv_layers)

        # Calculate the length after convolutions with padding - we'll use a dummy forward pass
        # to get the exact dimensions (avoids re-deriving the padded/strided
        # length arithmetic; runs once at construction time).
        with torch.no_grad():
            dummy_input = torch.randn(1, biosignals_cfg.input_channels, biosignals_cfg.signal_length)
            dummy_output = self.conv_layers(dummy_input)
            conv_output_length = dummy_output.shape[2]

        self.conv_output_length = conv_output_length
        self.conv_output_dim = biosignals_cfg.conv_layers[-1]

        # Projection from conv channels to transformer dimension.
        self.proj_conv_to_transformer = nn.Linear(
            self.conv_output_dim, biosignals_cfg.transformer_width
        )

        # Positional embeddings for sequence positions (excluding CLS token)
        # CLS token gets no positional embedding as it represents global context
        self.pos_embed = nn.Parameter(
            torch.randn(1, conv_output_length, biosignals_cfg.transformer_width)
        )

        # Add a class token for global representation (only used for 'cls' and 'attn' pooling)
        self.cls_token = nn.Parameter(
            torch.randn(1, 1, biosignals_cfg.transformer_width)
        )

        # Transformer layers; fp16/bf16 training selects the fp32-upcasting norm.
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        act_layer = QuickGELU

        self.transformer_layers = nn.ModuleList([
            TransformerBlock(
                biosignals_cfg.transformer_width,
                biosignals_cfg.transformer_heads,
                biosignals_cfg.mlp_ratio,
                act_layer=act_layer,
                norm_layer=norm_layer,
                dropout=biosignals_cfg.dropout
            )
            for _ in range(biosignals_cfg.transformer_layers)
        ])

        # Final layer norm
        self.ln_final = norm_layer(biosignals_cfg.transformer_width)

    def _encode(self, biosignals):
        """
        Encode biosignals to features before pooling.

        Args:
            biosignals: Tensor of shape (batch_size, channels, signal_length)
        Returns:
            features: Encoded features of shape (batch_size, seq_len, transformer_width)
            has_cls_token: Whether the sequence includes a CLS token at the last position
        """
        batch_size = biosignals.shape[0]

        # Apply convolutional layers
        x = self.conv_layers(biosignals)  # (batch_size, conv_dim, conv_length)

        # Transpose to (batch_size, conv_length, conv_dim) for the linear/transformer stack.
        x = x.transpose(1, 2)

        # Project to transformer dimension
        x = self.proj_conv_to_transformer(x)  # (batch_size, conv_length, transformer_width)

        # Add positional embeddings (content tokens only; CLS added below gets none).
        x = x + self.pos_embed

        # Add class token only if needed for pooling
        # For consistency with causal text encoder, append CLS token (not prepend)
        if self.pool_type in ['cls', 'attn']:
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)
            x = torch.cat([x, cls_tokens], dim=1)  # (batch_size, conv_length + 1, transformer_width)
            has_cls_token = True
        else:
            has_cls_token = False

        # Apply transformer layers
        for layer in self.transformer_layers:
            x = layer(x)

        # Apply final layer norm
        x = self.ln_final(x)

        return x, has_cls_token
879
+
880
+
881
class TransformerBlock(nn.Module):
    """Post-norm transformer block: self-attention then MLP.

    Each sub-layer uses the post-norm residual pattern norm(x + sublayer(x)).
    """

    def __init__(
        self,
        width: int,
        heads: int,
        mlp_ratio: float = 4.0,
        act_layer=QuickGELU,
        norm_layer=LayerNorm,
        dropout: float = 0.1
    ):
        super().__init__()
        hidden = int(width * mlp_ratio)
        self.attention = nn.MultiheadAttention(width, heads, dropout=dropout, batch_first=True)
        self.ln_1 = norm_layer(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, hidden),
            act_layer(),
            nn.Dropout(dropout),
            nn.Linear(hidden, width),
            nn.Dropout(dropout)
        )
        self.ln_2 = norm_layer(width)

    def forward(self, x):
        # Self-attention sub-layer (no mask: full bidirectional attention).
        attn_out, _ = self.attention(x, x, x)
        x = self.ln_1(x + attn_out)
        # Feed-forward sub-layer.
        return self.ln_2(x + self.mlp(x))
917
+
918
+
919
class AttnPooler(nn.Module):
    """
    CoCa-style attentional pooler.

    Holds ``n_query`` learned query vectors; a single multi-head attention
    layer attends from those queries over the full encoder sequence (as both
    keys and values). Typical uses:
        - n_query = 1 => one global embedding for the contrastive loss
        - n_query = N => a compressed token set for decoder cross-attention
    (CoCa uses task-specific attentional pooling with nquery=1 for the
    contrastive objective and a larger nquery for generation.)
    """
    def __init__(self, dim: int, num_heads: int, n_query: int):
        super().__init__()
        self.n_query = n_query
        # Small-init learned queries, broadcast over the batch at forward time.
        self.query_tokens = nn.Parameter(torch.randn(1, n_query, dim) * 0.02)
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            batch_first=True
        )

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x_seq: (B, L, D) encoder sequence.
        Returns:
            pooled: (B, n_query, D) query outputs after attending over x_seq.
        """
        batch = x_seq.size(0)
        queries = self.query_tokens.expand(batch, -1, -1)
        pooled, _ = self.attn(queries, x_seq, x_seq)
        return pooled
949
+
950
+
951
class PureTransformerBiosignalsEncoder(BaseBiosignalsEncoder):
    """
    Pure Transformer encoder for biosignals with channel + temporal attention.

    Uses CoCa-style task-specific attentional pooling:
      - contrastive_pooler (n_query=1)     -> 1 global token for contrastive / CLS use
      - decoder_pooler     (n_query=N_dec) -> small set of summary tokens for the text decoder

    Pipeline:
      1. Patch each channel independently (Conv1d tokenizer per channel).
      2. Alternate channel-attention and temporal-attention in DualTransformerBlocks
         (factorized attention).
      3. Keep (B, C, T, D) internally (cheap attention along channel or time separately).
      4. Flatten to (B, C*T, D) only at the end.
      5. Run two poolers:
         - 1-query pooler     -> global token
         - multi-query pooler -> decoder tokens
      6. Append the 1-query pooled token at the end of the feature sequence so
         BaseBiosignalsEncoder can keep using pool_type='cls' or 'attn' unchanged.
      7. The multi-query pooled tokens come first so, when output_tokens=True, they can
         be handed to the text decoder instead of the full ~C*T sequence.

    This mirrors CoCa's "task-specific attentional pooling", where one encoder supports
    both contrastive global alignment and caption-style generation at minimal extra cost.
    """

    def __init__(
            self,
            biosignals_cfg: BiosignalsCfg,
            embed_dim: int = 512,
            output_tokens: bool = False,
            cast_dtype: Optional[torch.dtype] = None
    ):
        """Build the encoder from `biosignals_cfg`.

        Args:
            biosignals_cfg: configuration holding patching, transformer, and norm
                settings (see BiosignalsCfg for field semantics).
            embed_dim: joint embedding dimension forwarded to the base class.
            output_tokens: whether callers receive token embeddings in addition to
                the pooled feature (handled by BaseBiosignalsEncoder).
            cast_dtype: when float16/bfloat16, use LayerNormFp32 for stability.
        """
        super().__init__(
            biosignals_cfg=biosignals_cfg,
            embed_dim=embed_dim,
            output_tokens=output_tokens,
            transformer_width=biosignals_cfg.transformer_width,
            cast_dtype=cast_dtype
        )

        # --- Sanity checks for RoPE dimensions ---
        assert biosignals_cfg.transformer_width % biosignals_cfg.transformer_heads == 0, (
            f"transformer_width ({biosignals_cfg.transformer_width}) must be divisible by "
            f"transformer_heads ({biosignals_cfg.transformer_heads})"
        )
        head_dim = biosignals_cfg.transformer_width // biosignals_cfg.transformer_heads
        # RoPE rotates coordinate pairs, so the per-head dimension must be even.
        assert head_dim % 2 == 0, (
            f"head_dim ({head_dim}) must be even for RoPE. "
            f"Got transformer_width={biosignals_cfg.transformer_width}, "
            f"transformer_heads={biosignals_cfg.transformer_heads}"
        )

        # 1. Channel patching (Conv1d tokenizer per channel)
        self.patching = ChannelPatching(
            patch_size=biosignals_cfg.patch_size,
            conv_embed_dim=biosignals_cfg.conv_embed_dim,
            num_channels=biosignals_cfg.input_channels
        )

        # number of temporal patches per channel (floor division: any remainder
        # samples beyond a whole patch are implicitly dropped)
        self.num_patches = biosignals_cfg.signal_length // biosignals_cfg.patch_size

        # 2. Project patch embeddings to transformer_width
        self.embed_projection = nn.Linear(
            biosignals_cfg.conv_embed_dim,
            biosignals_cfg.transformer_width
        )

        # 2a. Channel ID embedding (categorical channel identity)
        self.channel_id_embed = nn.Embedding(
            num_embeddings=biosignals_cfg.input_channels,
            embedding_dim=biosignals_cfg.transformer_width,
        )

        # 3. Shared learnable RoPE for channel attention (optional)
        if biosignals_cfg.share_channel_rope:
            shared_head_dim = biosignals_cfg.transformer_width // biosignals_cfg.transformer_heads
            self.shared_channel_rope = RotaryEmbedding(
                dim=shared_head_dim,
                theta=10000,
                learned_freq=True  # learnable for channel axis
            )
        else:
            self.shared_channel_rope = None

        # 4. Dual-axis Transformer blocks (channel attention + temporal attention)
        self.transformer_blocks = nn.ModuleList([
            DualTransformerBlock(
                embed_dim=biosignals_cfg.transformer_width,
                num_heads=biosignals_cfg.transformer_heads,
                num_temporal_layers=biosignals_cfg.num_temporal_layers,
                dropout=biosignals_cfg.dropout,
                mlp_ratio=biosignals_cfg.mlp_ratio,
                num_channels=biosignals_cfg.input_channels,
                activation=biosignals_cfg.activation,
                norm_type=biosignals_cfg.norm_type,
                mlp_bias=biosignals_cfg.mlp_bias,
                shared_channel_rope=self.shared_channel_rope if biosignals_cfg.share_channel_rope else None
            ) for _ in range(biosignals_cfg.transformer_layers)
        ])

        # 5. Final norm (fp32 LayerNorm under half-precision casts for stability)
        norm_layer = (
            LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        )
        if biosignals_cfg.norm_type == "rmsnorm":
            self.ln_final = RMSNorm(biosignals_cfg.transformer_width)
        else:
            self.ln_final = norm_layer(biosignals_cfg.transformer_width)

        # 6. CoCa-style attentional poolers
        #    - contrastive_pooler: n_query = 1 for global CLS token (contrastive head)
        #    - decoder_pooler: n_query = decoder_tokens (e.g. 32) for compressed memory
        #
        # `decoder_tokens` is read with getattr so older BiosignalsCfg instances
        # without the field still work (default 32).
        n_decoder_tokens = getattr(biosignals_cfg, "decoder_tokens", 32)

        self.contrastive_pooler = AttnPooler(
            dim=biosignals_cfg.transformer_width,
            num_heads=biosignals_cfg.transformer_heads,
            n_query=1
        )

        self.decoder_pooler = AttnPooler(
            dim=biosignals_cfg.transformer_width,
            num_heads=biosignals_cfg.transformer_heads,
            n_query=n_decoder_tokens
        )

    def _encode(self, biosignals: torch.Tensor):
        """
        Encode raw biosignals into a compact token sequence.

        Returns:
            features: (B, N_dec + 1, D)
                first N_dec tokens = pooled decoder tokens
                last token = global pooled token (contrastive CLS)
            has_cls_token: True (the last position is always the CLS-style token)
        """
        B = biosignals.shape[0]
        device = biosignals.device

        # 1. Patch per channel -> (B, C, T, conv_dim)
        x = self.patching(biosignals)

        # 2. Project to model dim -> (B, C, T, D)
        x = self.embed_projection(x)

        # 2a. Add channel ID embedding (same bias for every time step of a channel)
        _, C, T, D = x.shape
        channel_ids = torch.arange(C, device=device)        # (C,)
        channel_bias = self.channel_id_embed(channel_ids)   # (C, D)
        channel_bias = channel_bias.view(1, C, 1, D).expand(B, C, T, D)
        x = x + channel_bias

        # 3. Temporal RoPE positions
        # NOTE(review): assumes T == self.num_patches (i.e. inputs always carry
        # exactly signal_length samples) — confirm; a shorter input would make
        # pos_ids longer than the actual temporal axis.
        pos_ids = torch.arange(self.num_patches, device=device)  # (T,)

        # 4. Dual-axis transformer blocks (channel-attn + temporal-attn)
        for block in self.transformer_blocks:
            x = block(x, temporal_position_ids=pos_ids)  # stays (B, C, T, D)

        # 5. Final norm
        x = self.ln_final(x)  # (B, C, T, D)

        # 6. Flatten channels x time to a sequence for pooling (not for decoder!)
        x_seq = x.reshape(B, C * T, D)  # (B, L, D) with L = C*T

        # 7. Task-specific attentional pooling (CoCa-style)
        #    contrastive_pooler: n_query=1  -> global_token (B, 1, D)
        #    decoder_pooler:     n_query=Nd -> dec_tokens   (B, Nd, D)
        global_token = self.contrastive_pooler(x_seq)  # (B, 1, D)
        dec_tokens = self.decoder_pooler(x_seq)        # (B, N_dec, D)

        # 8. Build final feature sequence:
        #    [decoder tokens..., global token] so that:
        #      - features[:, :-1] = dec_tokens   (for decoder cross-attn)
        #      - features[:, -1]  = global_token (for contrastive / CLS pooling)
        features = torch.cat([dec_tokens, global_token], dim=1)  # (B, N_dec+1, D)

        has_cls_token = True
        return features, has_cls_token
1133
+
1134
+
1135
class SignalReconstructionDecoder(nn.Module):
    """Lightweight transformer head that reconstructs the raw biosignal.

    A small stack of self-attention layers (plain TransformerEncoder — no
    cross-attention is needed here) refines the encoder features; the result
    is mean-pooled and pushed through a bottlenecked MLP into a flat vector
    that is reshaped to ``(B, output_channels, output_length)``.
    """

    def __init__(
            self,
            input_dim: int = 768,
            num_layers: int = 2,
            num_heads: int = 4,  # small head count keeps the decoder cheap
            output_channels: int = 10,
            output_length: int = 1920,
    ):
        super().__init__()

        # Self-attention refinement stack. The 2x feed-forward width (instead
        # of the usual 4x) keeps this decoder deliberately lightweight.
        layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            dim_feedforward=2 * input_dim,  # 1536 when input_dim=768
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers)

        # Bottlenecked MLP projecting the pooled feature to signal space.
        hidden = input_dim // 2
        self.to_signal = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, output_channels * output_length),
        )

        self.output_channels = output_channels
        self.output_length = output_length

    def forward(self, encoder_features):
        """Reconstruct the signal.

        Args:
            encoder_features: (B, seq_len, input_dim) unprojected encoder features.

        Returns:
            Tensor of shape (B, output_channels, output_length).
        """
        refined = self.transformer(encoder_features)  # (B, seq_len, input_dim)
        pooled = refined.mean(dim=1)                  # global average pool -> (B, input_dim)
        flat = self.to_signal(pooled)                 # (B, output_channels * output_length)
        return flat.reshape(
            encoder_features.shape[0], self.output_channels, self.output_length
        )
1196
+
1197
+
1198
class BiosignalsCoCa(nn.Module):
    """
    CoCa model adapted for biosignals-text contrastive learning.
    Replaces the vision tower with a biosignals encoder.

    Supports two decoder types:
      - "cross_attention": Separate cross-attention between text and biosignals (default CoCa)
      - "concat": Concatenate biosignals and text tokens with prefix-causal masking
    """

    def __init__(
            self,
            embed_dim,
            multimodal_cfg: MultimodalCfg,
            text_cfg: CLIPTextCfg,
            biosignals_cfg: BiosignalsCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            nonscalar_logit_scale: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            pad_id: int = 0,
            decoder_type: str = "cross_attention",
            num_caption_channels: int = 12,  # Number of channel/modality embeddings (22 for channels, 4 for modalities)
            prefix_len: int = 0,
            use_signal_decoder: bool = False,  # NEW: Enable signal reconstruction
    ):
        """Assemble the text tower, biosignals tower, and multimodal decoder.

        Config arguments may be passed either as dataclass instances or plain
        dicts (dicts are coerced below).
        """
        super().__init__()
        multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
        biosignals_cfg = BiosignalsCfg(**biosignals_cfg) if isinstance(biosignals_cfg, dict) else biosignals_cfg

        self.decoder_type = decoder_type
        self.num_channels = num_caption_channels
        self.use_signal_decoder = use_signal_decoder

        # Debug logging for channel configuration
        import logging
        logging.info(f"BiosignalsCoCa initialized with num_caption_channels={num_caption_channels}, prefix_len={prefix_len}")
        if use_signal_decoder:
            logging.info(f"Signal reconstruction decoder enabled")

        self.text = _build_text_tower(
            embed_dim=embed_dim,
            text_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        vocab_size = (
            self.text.vocab_size  # for hf models
            if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
            else text_cfg.vocab_size
        )

        # Replace visual tower with biosignals tower
        self.biosignals = _build_signal_tower(
            embed_dim=embed_dim,
            signal_cfg=biosignals_cfg,
            output_tokens=True,  # Need tokens for multimodal decoder
            cast_dtype=cast_dtype,
        )

        self.text_decoder = _build_text_decoder_tower_v2(
            vocab_size,
            multimodal_cfg=multimodal_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
            decoder_type=decoder_type,
            prefix_len=prefix_len,
        )

        # Logit scale/bias may be scalar or shape-(1,) depending on loss impl.
        lshape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
        else:
            self.logit_bias = None
        self.pad_id = pad_id

        self.context_length = multimodal_cfg.context_length

        # Learnable channel/modality embeddings
        # num_caption_channels will be 23 for individual channel mode or 5 for modality mode
        # Dimension should match the decoder width (multimodal_cfg.width for text decoder input)
        self.channel_embeddings = nn.Parameter(
            torch.randn(num_caption_channels, multimodal_cfg.width) * 0.02
        )

        # Learnable padding embedding for -1 positions
        # This learns to be "neutral" or ignored during training (similar to [PAD] tokens)
        self.padding_embedding = nn.Parameter(
            torch.randn(multimodal_cfg.width) * 0.02
        )

        self.decoder_width = multimodal_cfg.width

        # Optional signal reconstruction decoder
        if use_signal_decoder:
            self.signal_decoder = SignalReconstructionDecoder(
                input_dim=biosignals_cfg.transformer_width,
                num_layers=2,  # Lightweight: 2 transformer layers
                num_heads=biosignals_cfg.transformer_heads,
                output_channels=biosignals_cfg.input_channels,
                output_length=biosignals_cfg.signal_length,
            )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        """Toggle gradient checkpointing on all three towers."""
        self.biosignals.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)
        self.text_decoder.set_grad_checkpointing(enable)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Lock the text encoder, optionally leaving the last N layers unlocked.

        Args:
            unlocked_layers: Number of layers to leave unlocked (from the end)
            freeze_layer_norm: Whether to freeze LayerNorm parameters in locked layers
        """
        if hasattr(self.text, 'lock'):
            # For HFTextEncoder (Pythia, etc.)
            self.text.lock(unlocked_layers, freeze_layer_norm)

            # IMPORTANT: Unfreeze newly added token embeddings (e.g., <pad>, <coca_cls>)
            # These were randomly initialized and need to be trained
            if hasattr(self.text, 'original_vocab_size'):
                import logging
                embedding_module = self.text.transformer.get_input_embeddings()
                original_size = self.text.original_vocab_size
                current_size = embedding_module.weight.shape[0]

                if current_size > original_size:
                    # Enable gradients for the embedding layer
                    embedding_module.weight.requires_grad = True

                    # Store metadata for optimizer configuration (zero weight decay)
                    self.text._new_token_start_idx = original_size

                    # Get actual embedding size (may be padded for Tensor Cores)
                    actual_embedding_size = embedding_module.weight.shape[0]
                    new_vocab_size = self.text.vocab_size  # Actual number of tokens (not padded)

                    # Register parameter-level hook to mask frozen token gradients
                    # IMPORTANT: This is registered BEFORE DDP wrapping to ensure it persists
                    def _zero_grad_frozen_tokens(grad):
                        """Zero out gradients for old (frozen) tokens and padding, keep only new tokens."""
                        if grad is not None:
                            # Zero out pretrained tokens [0:original_size]
                            grad[:original_size] = 0
                            # Zero out padding tokens [new_vocab_size:actual_embedding_size]
                            if actual_embedding_size > new_vocab_size:
                                grad[new_vocab_size:] = 0
                        return grad

                    embedding_module.weight.register_hook(_zero_grad_frozen_tokens)

                    num_new_tokens = new_vocab_size - original_size
                    num_padding_tokens = actual_embedding_size - new_vocab_size
                    logging.info(f"Embedding layer configuration:")
                    logging.info(f"  Trainable new tokens: {num_new_tokens} (indices {original_size}:{new_vocab_size})")
                    logging.info(f"  Frozen pretrained tokens: {original_size} (indices 0:{original_size})")
                    if num_padding_tokens > 0:
                        logging.info(f"  Frozen padding tokens: {num_padding_tokens} (indices {new_vocab_size}:{actual_embedding_size})")
                    logging.info(f"  Total embedding size: {actual_embedding_size}")
                    logging.info(f"Registered gradient masking hook before DDP wrapping")
                    logging.info(f"NOTE: Optimizer uses weight_decay=0 for embedding layer")
        else:
            # For standard TextTransformer
            assert False, "BiosignalsCoCa does not support locking standard TextTransformer"
            # NOTE(review): everything below the assert is unreachable dead code;
            # kept as-is from the original.
            from .transformer import lock_text_tower
            lock_text_tower(self, unlocked_layers)

    def _encode_biosignals(self, biosignals, normalize: bool = True):
        """Return (optionally L2-normalized latent, token embeddings) for biosignals."""
        biosignals_latent, tokens_embs = self.biosignals(biosignals)
        biosignals_latent = F.normalize(biosignals_latent, dim=-1) if normalize else biosignals_latent
        return biosignals_latent, tokens_embs

    def _encode_text(self, text, normalize: bool = True):
        """Return (optionally L2-normalized latent, token embeddings) for text ids."""
        text_latent, token_emb = self.text(text)
        text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
        return text_latent, token_emb

    def encode_image(self, biosignals, normalize: bool = True):
        """CLIP-compatible alias: encode biosignals in place of images."""
        biosignals_latent, _ = self._encode_biosignals(biosignals, normalize=normalize)
        return biosignals_latent

    def encode_text(self, text, normalize: bool = True):
        """Encode text ids to the joint embedding space."""
        text_latent, _ = self._encode_text(text, normalize=normalize)
        return text_latent

    def _get_channel_condition_embs(self, channel_indices: torch.Tensor) -> torch.Tensor:
        """Convert channel/modality indices to embeddings with learnable padding.

        Args:
            channel_indices: (batch_size, prefix_len) tensor of indices
                - Individual mode: indices into 23 channel embeddings (22 channels + 1 stage_event)
                - Modality mode: indices into 5 modality embeddings (4 modalities + 1 stage_event)
                - Padded with -1 for variable length (uses learnable padding_embedding for -1)

        Returns:
            condition_embs: (batch_size, prefix_len, decoder_width)
                Embeddings for all positions. -1 positions use learnable padding_embedding
                that learns to be neutral/ignored during training.
        """
        batch_size, prefix_len = channel_indices.shape

        # Create output tensor
        condition_embs = torch.zeros(batch_size, prefix_len, self.decoder_width,
                                     dtype=self.channel_embeddings.dtype,
                                     device=self.channel_embeddings.device)

        # Create mask for valid (non-padding) indices
        valid_mask = channel_indices >= 0    # (batch_size, prefix_len)
        padding_mask = channel_indices == -1  # (batch_size, prefix_len)

        # Gather channel embeddings for valid indices
        # Clamp to 0 for safe indexing (will be overwritten by padding where needed)
        indices_safe = channel_indices.clamp(min=0)

        # Expand embeddings for batching
        expanded_embeddings = self.channel_embeddings.unsqueeze(0).expand(batch_size, -1, -1)

        # Gather embeddings
        indices_expanded = indices_safe.unsqueeze(-1).expand(-1, -1, self.decoder_width)
        gathered_embs = torch.gather(expanded_embeddings, 1, indices_expanded)

        # Fill in valid positions with gathered embeddings
        condition_embs[valid_mask] = gathered_embs[valid_mask]

        # Fill in padding positions with learnable padding embedding
        if padding_mask.any():
            # Broadcast padding_embedding to all padding positions
            condition_embs[padding_mask] = self.padding_embedding

        return condition_embs

    def forward(
            self,
            biosignals,
            text: Optional[torch.Tensor] = None,
            biosignals_latent: Optional[torch.Tensor] = None,
            biosignals_embs: Optional[torch.Tensor] = None,
            channel_indices: Optional[torch.Tensor] = None,
            output_labels: bool = True,
    ):
        """Forward pass for BiosignalsCoCa model.

        Args:
            biosignals: Input biosignals tensor
            text: Optional text token ids
            biosignals_latent: Optional pre-computed biosignals latent features
            biosignals_embs: Optional pre-computed biosignals token embeddings
            channel_indices: Optional (batch_size, num_selected_channels) tensor of channel indices
                Used to select channel-specific condition embeddings. If provided, overrides condition_embs.
            output_labels: Whether to output labels for loss computation
        """
        if biosignals_latent is None or biosignals_embs is None:
            biosignals_latent, biosignals_embs = self._encode_biosignals(biosignals)

        if text is None:
            return {"image_features": biosignals_latent, "image_embs": biosignals_embs}

        text_latent, token_embs = self._encode_text(text)

        # FIXME this isn't an ideal solution, would like to improve -RW
        labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None
        if output_labels:
            # align text_embs and thus logits with labels for teacher-forcing caption loss
            token_embs = token_embs[:, :-1]

        # Convert channel indices to condition embeddings if provided
        if channel_indices is not None:
            condition_embs = self._get_channel_condition_embs(channel_indices)
        else:
            condition_embs = None

        logits = self.text_decoder(biosignals_embs, token_embs, condition_embs=condition_embs)
        out_dict = {
            "image_features": biosignals_latent,
            "text_features": text_latent,
            "logits": logits,
            "logit_scale": self.logit_scale.exp()
        }
        if labels is not None:
            out_dict["labels"] = labels
        if self.logit_bias is not None:
            out_dict["logit_bias"] = self.logit_bias

        # Optional signal reconstruction
        if self.use_signal_decoder:
            reconstructed_signal = self.signal_decoder(biosignals_embs)
            out_dict["reconstructed_signal"] = reconstructed_signal
            out_dict["original_signal"] = biosignals

        return out_dict

    def generate(
        self,
        biosignals,
        text=None,
        seq_len=30,
        max_seq_len=256,
        temperature=1.,
        generation_type="beam_search",
        top_p=0.1,
        top_k=1,
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        repetition_penalty=1.0,
        fixed_output_length=False,
        condition_embs=None,
        channel_indices=None,
    ):
        """Autoregressively generate caption token ids for `biosignals`.

        Supports beam search (delegated to `_generate_beamsearch`) or
        top-p / top-k sampling. Returns a (B, <=seq_len) id tensor; when
        `fixed_output_length` is True the output is padded to exactly seq_len.
        """
        # taking many ideas and components from HuggingFace GenerationMixin
        # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
        assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
        assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
        device = biosignals.device

        # Note: condition_embs parameter is for backward compatibility
        # We pass channel_indices directly to forward(), which handles the conversion internally

        with torch.no_grad():
            sot_token_id = _token_to_tensor(sot_token_id, device=device)
            eos_token_id = _token_to_tensor(eos_token_id, device=device)
            pad_token_id = pad_token_id  # NOTE(review): no-op self-assignment, kept from original
            logit_processor = LogitsProcessorList(
                [
                    MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                    RepetitionPenaltyLogitsProcessor(repetition_penalty),
                ]
            )

            if stopping_criteria is None:
                stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
            stopping_criteria = StoppingCriteriaList(stopping_criteria)

            if generation_type == "beam_search":
                output = self._generate_beamsearch(
                    biosignals_inputs=biosignals,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    sot_token_id=sot_token_id,
                    num_beams=num_beams,
                    num_beam_groups=num_beam_groups,
                    min_seq_len=min_seq_len,
                    stopping_criteria=stopping_criteria,
                    logit_processor=logit_processor,
                    channel_indices=channel_indices,
                )
                if fixed_output_length and output.shape[1] < seq_len:
                    pad_len = seq_len - output.shape[1]
                    return torch.cat((
                            output,
                            torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id
                        ),
                        dim=1
                    )
                return output

            elif generation_type == "top_p":
                logit_warper = GENERATION_TYPES[generation_type](top_p)
            elif generation_type == "top_k":
                logit_warper = GENERATION_TYPES[generation_type](top_k)
            else:
                raise ValueError(
                    f"generation_type has to be one of "
                    f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
                )

            biosignals_latent, biosignals_embs = self._encode_biosignals(biosignals)

            if text is None:
                text = torch.ones((biosignals.shape[0], 1), device=device, dtype=torch.long) * sot_token_id

            was_training = self.training
            num_dims = len(text.shape)

            if num_dims == 1:
                text = text[None, :]

            self.eval()
            out = text

            # Sampling loop: one new token per iteration until every sequence
            # has emitted EOS/PAD or a stopping criterion fires.
            while True:
                x = out[:, -max_seq_len:]
                cur_len = x.shape[1]
                logits = self(
                    biosignals,
                    x,
                    biosignals_latent=biosignals_latent,
                    biosignals_embs=biosignals_embs,
                    channel_indices=channel_indices,
                    output_labels=False,
                )["logits"][:, -1]
                mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

                if mask.all():
                    if not fixed_output_length:
                        break
                else:
                    logits = logits[~mask, :]
                    filtered_logits = logit_processor(x[~mask, :], logits)
                    filtered_logits = logit_warper(x[~mask, :], filtered_logits)
                    probs = F.softmax(filtered_logits / temperature, dim=-1)

                    if (cur_len + 1 == seq_len):
                        # Force EOS on the last allowed step so sequences terminate.
                        sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
                    else:
                        sample[~mask, :] = torch.multinomial(probs, 1)

                out = torch.cat((out, sample), dim=-1)

                cur_len += 1

                if all(stopping_criteria(out, None)):
                    break

            if num_dims == 1:
                out = out.squeeze(0)

            self.train(was_training)
            return out

    def _generate_beamsearch(
            self,
            biosignals_inputs,
            pad_token_id=None,
            eos_token_id=None,
            sot_token_id=None,
            num_beams=6,
            num_beam_groups=3,
            min_seq_len=5,
            stopping_criteria=None,
            logit_processor=None,
            logit_warper=None,  # NOTE(review): accepted but unused, kept for API parity
            channel_indices=None,
    ):
        """Diverse (grouped) beam search over caption tokens.

        Follows HuggingFace's group beam search: biosignal inputs are repeated
        per beam, each step scores all beams jointly, and BeamSearchScorer
        tracks/finalizes hypotheses. Returns the finalized id sequences.
        """
        device = biosignals_inputs.device
        batch_size = biosignals_inputs.shape[0]
        biosignals_inputs = torch.repeat_interleave(biosignals_inputs, num_beams, dim=0)
        biosignals_latent, biosignals_embs = self._encode_biosignals(biosignals_inputs)

        # Repeat channel indices for beam search if provided
        # forward() will convert them to condition embeddings internally
        if channel_indices is not None:
            channel_indices = torch.repeat_interleave(channel_indices, num_beams, dim=0)

        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
        input_ids = input_ids * sot_token_id
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=device,
            num_beam_groups=num_beam_groups,
        )
        # instantiate logits processors
        logits_processor = (
            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
            if logit_processor is None
            else logit_processor
        )

        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
        batch_beam_size, cur_len = input_ids.shape
        beam_indices = None

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
        # the same group don't produce same tokens everytime.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while True:

            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, biosignals_inputs=biosignals_inputs)
            outputs = self(
                model_inputs['biosignals'],
                model_inputs['text'],
                biosignals_latent=biosignals_latent,
                biosignals_embs=biosignals_embs,
                channel_indices=channel_indices,
                output_labels=False,
            )

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs['logits'][batch_group_indices, -1, :]
                vocab_size = next_token_logits.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                # keep 2*group_size candidates so finished (EOS) beams can be replaced
                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )

                next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
                next_tokens = next_tokens % vocab_size

                # stateless
                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=process_beam_indices,
                    group_index=beam_group_idx,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
                )

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            # increase cur_len
            cur_len = cur_len + 1
            if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
                break

        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=final_beam_indices,
        )
        return sequence_outputs['sequences']
1786
+
1787
+
1788
def prepare_inputs_for_generation(input_ids, biosignals_inputs, past=None, **kwargs):
    """Assemble the per-step inputs for `BiosignalsCoCa` decoding.

    Args:
        input_ids: (B, cur_len) token ids generated so far.
        biosignals_inputs: biosignals tensor, passed through unchanged.
        past: optional cached key/values; when truthy, only the newest token
            is fed to the decoder (earlier context lives in the cache).
        **kwargs: may contain `attention_mask` and/or `position_ids`.

    Returns:
        Dict with keys 'text', 'biosignals', 'past_key_values',
        'position_ids', 'attention_mask'.
    """
    if past:
        # With a KV cache only the newest token needs to be decoded.
        input_ids = input_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is not None and position_ids is None:
        # Create position_ids on the fly for batch generation: positions count
        # only non-padded tokens; padded slots get a dummy position of 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
    # Bug fix: the original had an `else: position_ids = None` here, which
    # silently discarded a caller-supplied `position_ids` (both when no
    # attention_mask was given and when both were given). Caller-supplied
    # position_ids are now preserved unchanged.

    return {
        "text": input_ids,
        "biosignals": biosignals_inputs,
        "past_key_values": past,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }
src/open_clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
src/open_clip/coca_model.py ADDED
@@ -0,0 +1,586 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional, Union
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ import numpy as np
7
+ from dataclasses import dataclass
8
+
9
+ from .transformer import (
10
+ LayerNormFp32,
11
+ LayerNorm,
12
+ QuickGELU,
13
+ MultimodalTransformer,
14
+ )
15
+ from .model import CLIPTextCfg, _build_text_tower
16
+
17
+ try:
18
+ from transformers import (
19
+ BeamSearchScorer,
20
+ LogitsProcessorList,
21
+ TopPLogitsWarper,
22
+ TopKLogitsWarper,
23
+ RepetitionPenaltyLogitsProcessor,
24
+ MinLengthLogitsProcessor,
25
+ MaxLengthCriteria,
26
+ StopStringCriteria,
27
+ EosTokenCriteria,
28
+ StoppingCriteriaList
29
+ )
30
+
31
+ GENERATION_TYPES = {
32
+ "top_k": TopKLogitsWarper,
33
+ "top_p": TopPLogitsWarper,
34
+ "beam_search": "beam_search"
35
+ }
36
+ _has_transformers = True
37
+ except ImportError as e:
38
+ GENERATION_TYPES = {
39
+ "top_k": None,
40
+ "top_p": None,
41
+ "beam_search": "beam_search"
42
+ }
43
+ _has_transformers = False
44
+
45
+
46
+ @dataclass
47
+ class MultimodalCfg(CLIPTextCfg):
48
+ mlp_ratio: int = 4
49
+ dim_head: int = 64
50
+ heads: int = 8
51
+ n_queries: int = 256
52
+ attn_pooler_heads: int = 8
53
+
54
+
55
+ def _build_text_decoder_tower(
56
+ embed_dim,
57
+ multimodal_cfg,
58
+ quick_gelu: bool = False,
59
+ cast_dtype: Optional[torch.dtype] = None,
60
+ ):
61
+ multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
62
+ act_layer = QuickGELU if quick_gelu else nn.GELU
63
+ norm_layer = (
64
+ LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
65
+ )
66
+
67
+ decoder = MultimodalTransformer(
68
+ context_length=multimodal_cfg.context_length,
69
+ width=multimodal_cfg.width,
70
+ heads=multimodal_cfg.heads,
71
+ layers=multimodal_cfg.layers,
72
+ ls_init_value=multimodal_cfg.ls_init_value,
73
+ output_dim=embed_dim,
74
+ act_layer=act_layer,
75
+ norm_layer=norm_layer,
76
+ )
77
+
78
+ return decoder
79
+
80
+
81
+ def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor:
82
+ if not isinstance(token_id, torch.Tensor):
83
+ if isinstance(token_id, int):
84
+ token_id = [token_id]
85
+ token_id = torch.tensor(token_id, device=device)
86
+ return token_id
87
+
88
+
89
+ class CoCa(nn.Module):
90
+ def __init__(
91
+ self,
92
+ embed_dim,
93
+ multimodal_cfg: MultimodalCfg,
94
+ text_cfg: CLIPTextCfg,
95
+ vision_cfg=None,
96
+ quick_gelu: bool = False,
97
+ init_logit_scale: float = np.log(1 / 0.07),
98
+ init_logit_bias: Optional[float] = None,
99
+ nonscalar_logit_scale: bool = False,
100
+ cast_dtype: Optional[torch.dtype] = None,
101
+ pad_id: int = 0,
102
+ ):
103
+ super().__init__()
104
+ multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
105
+ text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
106
+
107
+ self.text = _build_text_tower(
108
+ embed_dim=embed_dim,
109
+ text_cfg=text_cfg,
110
+ quick_gelu=quick_gelu,
111
+ cast_dtype=cast_dtype,
112
+ )
113
+
114
+ vocab_size = (
115
+ self.text.vocab_size
116
+ if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
117
+ else text_cfg.vocab_size
118
+ )
119
+
120
+ if vision_cfg is not None:
121
+ from .model import CLIPVisionCfg, _build_vision_tower
122
+ vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
123
+ self.visual = _build_vision_tower(
124
+ embed_dim=embed_dim,
125
+ vision_cfg=vision_cfg,
126
+ quick_gelu=quick_gelu,
127
+ cast_dtype=cast_dtype,
128
+ )
129
+ else:
130
+ self.visual = None
131
+
132
+ self.text_decoder = _build_text_decoder_tower(
133
+ vocab_size,
134
+ multimodal_cfg=multimodal_cfg,
135
+ quick_gelu=quick_gelu,
136
+ cast_dtype=cast_dtype,
137
+ )
138
+
139
+ lshape = [1] if nonscalar_logit_scale else []
140
+ self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
141
+ if init_logit_bias is not None:
142
+ self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
143
+ else:
144
+ self.logit_bias = None
145
+ self.pad_id = pad_id
146
+
147
+ self.context_length = multimodal_cfg.context_length
148
+
149
+ @torch.jit.ignore
150
+ def set_grad_checkpointing(self, enable: bool = True):
151
+ self.visual.set_grad_checkpointing(enable)
152
+ self.text.set_grad_checkpointing(enable)
153
+ self.text_decoder.set_grad_checkpointing(enable)
154
+
155
+ def _encode_image(self, images, normalize: bool = True):
156
+ image_latent, tokens_embs = self.visual(images)
157
+ image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
158
+ return image_latent, tokens_embs
159
+
160
+ def _encode_text(self, text, normalize: bool = True):
161
+ text_latent, token_emb = self.text(text)
162
+ text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
163
+ return text_latent, token_emb
164
+
165
+ def encode_image(self, images, normalize: bool = True):
166
+ image_latent, _ = self._encode_image(images, normalize=normalize)
167
+ return image_latent
168
+
169
+ def encode_text(self, text, normalize: bool = True):
170
+ text_latent, _ = self._encode_text(text, normalize=normalize)
171
+ return text_latent
172
+
173
+ def forward_intermediates(
174
+ self,
175
+ image: Optional[torch.Tensor] = None,
176
+ text: Optional[torch.Tensor] = None,
177
+ image_indices: Optional[Union[int, List[int]]] = None,
178
+ text_indices: Optional[Union[int, List[int]]] = None,
179
+ stop_early: bool = False,
180
+ normalize: bool = True,
181
+ normalize_intermediates: bool = False,
182
+ intermediates_only: bool = False,
183
+ image_output_fmt: str = 'NCHW',
184
+ image_output_extra_tokens: bool = False,
185
+ text_output_fmt: str = 'NLC',
186
+ text_output_extra_tokens: bool = False,
187
+ output_logits: bool = False,
188
+ output_logit_scale_bias: bool = False,
189
+ ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
190
+ """ Forward features that returns intermediates.
191
+
192
+ Args:
193
+ image: Input image tensor
194
+ text: Input text tensor
195
+ image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
196
+ text_indices: Take last n blocks if int, all if None, select matching indices if sequence
197
+ stop_early: Stop iterating over blocks when last desired intermediate hit
198
+ normalize: L2 Normalize final image and text features (if present)
199
+ normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)
200
+ intermediates_only: Only return intermediate features, do not return final features
201
+ image_output_fmt: Shape of intermediate image feature outputs
202
+ image_output_extra_tokens: Return both prefix and spatial intermediate tokens
203
+ text_output_fmt: Shape of intermediate text feature outputs
204
+ text_output_extra_tokens: Return both prefix and spatial intermediate tokens
205
+ output_logits: Include logits in output
206
+ output_logit_scale_bias: Include the logit scale bias in the output
207
+ Returns:
208
+
209
+ """
210
+ output = {}
211
+ if intermediates_only:
212
+ # intermediates only disables final feature normalization, and include logits
213
+ normalize = False
214
+ output_logits = False
215
+ if output_logits:
216
+ assert False, 'FIXME, needs implementing'
217
+
218
+ if image is not None:
219
+ image_output = self.visual.forward_intermediates(
220
+ image,
221
+ indices=image_indices,
222
+ stop_early=stop_early,
223
+ normalize_intermediates=normalize_intermediates,
224
+ intermediates_only=intermediates_only,
225
+ output_fmt=image_output_fmt,
226
+ output_extra_tokens=image_output_extra_tokens,
227
+ )
228
+ if normalize and "image_features" in image_output:
229
+ image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
230
+ output.update(image_output)
231
+
232
+ if text is not None:
233
+ text_output = self.text.forward_intermediates(
234
+ text,
235
+ indices=text_indices,
236
+ stop_early=stop_early,
237
+ normalize_intermediates=normalize_intermediates,
238
+ intermediates_only=intermediates_only,
239
+ output_fmt=text_output_fmt,
240
+ output_extra_tokens=text_output_extra_tokens,
241
+ )
242
+ if normalize and "text_features" in text_output:
243
+ text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1)
244
+ output.update(text_output)
245
+
246
+ # FIXME text decoder
247
+ logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None
248
+ if output_logit_scale_bias:
249
+ output["logit_scale"] = logit_scale_exp
250
+ if self.logit_bias is not None:
251
+ output['logit_bias'] = self.logit_bias
252
+
253
+ return output
254
+
255
+ def forward(
256
+ self,
257
+ image,
258
+ text: Optional[torch.Tensor] = None,
259
+ image_latent: Optional[torch.Tensor] = None,
260
+ image_embs: Optional[torch.Tensor] = None,
261
+ output_labels: bool = True,
262
+ ):
263
+ if image_latent is None or image_embs is None:
264
+ image_latent, image_embs = self._encode_image(image)
265
+
266
+ if text is None:
267
+ return {"image_features": image_latent, "image_embs": image_embs}
268
+
269
+ text_latent, token_embs = self._encode_text(text)
270
+
271
+ # FIXME this isn't an ideal solution, would like to improve -RW
272
+ labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None
273
+ if output_labels:
274
+ # align text_embs and thus logits with labels for teacher-forcing caption loss
275
+ token_embs = token_embs[:, :-1]
276
+
277
+ logits = self.text_decoder(image_embs, token_embs)
278
+ out_dict = {
279
+ "image_features": image_latent,
280
+ "text_features": text_latent,
281
+ "logits": logits,
282
+ "logit_scale": self.logit_scale.exp()
283
+ }
284
+ if labels is not None:
285
+ out_dict["labels"] = labels
286
+ if self.logit_bias is not None:
287
+ out_dict["logit_bias"] = self.logit_bias
288
+ return out_dict
289
+
290
+ def generate(
291
+ self,
292
+ image,
293
+ text=None,
294
+ seq_len=30,
295
+ max_seq_len=77,
296
+ temperature=1.,
297
+ generation_type="beam_search",
298
+ top_p=0.1, # keep tokens in the 1 - top_p quantile
299
+ top_k=1, # keeps the top_k most probable tokens
300
+ pad_token_id=None,
301
+ eos_token_id=None,
302
+ sot_token_id=None,
303
+ num_beams=6,
304
+ num_beam_groups=3,
305
+ min_seq_len=5,
306
+ stopping_criteria=None,
307
+ repetition_penalty=1.0,
308
+ fixed_output_length=False # if True output.shape == (batch_size, seq_len)
309
+ ):
310
+ # taking many ideas and components from HuggingFace GenerationMixin
311
+ # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
312
+ assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
313
+ assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
314
+ device = image.device
315
+
316
+ with torch.no_grad():
317
+ sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device)
318
+ eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device)
319
+ pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
320
+ logit_processor = LogitsProcessorList(
321
+ [
322
+ MinLengthLogitsProcessor(min_seq_len, eos_token_id),
323
+ RepetitionPenaltyLogitsProcessor(repetition_penalty),
324
+ ]
325
+ )
326
+
327
+ if stopping_criteria is None:
328
+ stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
329
+ stopping_criteria = StoppingCriteriaList(stopping_criteria)
330
+
331
+ if generation_type == "beam_search":
332
+ output = self._generate_beamsearch(
333
+ image_inputs=image,
334
+ pad_token_id=pad_token_id,
335
+ eos_token_id=eos_token_id,
336
+ sot_token_id=sot_token_id,
337
+ num_beams=num_beams,
338
+ num_beam_groups=num_beam_groups,
339
+ min_seq_len=min_seq_len,
340
+ stopping_criteria=stopping_criteria,
341
+ logit_processor=logit_processor,
342
+ )
343
+ if fixed_output_length and output.shape[1] < seq_len:
344
+ pad_len = seq_len - output.shape[1]
345
+ return torch.cat((
346
+ output,
347
+ torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id
348
+ ),
349
+ dim=1
350
+ )
351
+ return output
352
+
353
+ elif generation_type == "top_p":
354
+ logit_warper = GENERATION_TYPES[generation_type](top_p)
355
+ elif generation_type == "top_k":
356
+ logit_warper = GENERATION_TYPES[generation_type](top_k)
357
+ else:
358
+ raise ValueError(
359
+ f"generation_type has to be one of "
360
+ f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
361
+ )
362
+
363
+ image_latent, image_embs = self._encode_image(image)
364
+
365
+ if text is None:
366
+ text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
367
+
368
+ was_training = self.training
369
+ num_dims = len(text.shape)
370
+
371
+ if num_dims == 1:
372
+ text = text[None, :]
373
+
374
+ self.eval()
375
+ out = text
376
+
377
+ while True:
378
+ x = out[:, -max_seq_len:]
379
+ cur_len = x.shape[1]
380
+ logits = self(
381
+ image,
382
+ x,
383
+ image_latent=image_latent,
384
+ image_embs=image_embs,
385
+ output_labels=False,
386
+ )["logits"][:, -1]
387
+ mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
388
+ sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
389
+
390
+ if mask.all():
391
+ if not fixed_output_length:
392
+ break
393
+ else:
394
+ logits = logits[~mask, :]
395
+ filtered_logits = logit_processor(x[~mask, :], logits)
396
+ filtered_logits = logit_warper(x[~mask, :], filtered_logits)
397
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
398
+
399
+ if (cur_len + 1 == seq_len):
400
+ sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
401
+ else:
402
+ sample[~mask, :] = torch.multinomial(probs, 1)
403
+
404
+ out = torch.cat((out, sample), dim=-1)
405
+
406
+ cur_len += 1
407
+
408
+ if all(stopping_criteria(out, None)):
409
+ break
410
+
411
+ if num_dims == 1:
412
+ out = out.squeeze(0)
413
+
414
+ self.train(was_training)
415
+ return out
416
+
417
+ def _generate_beamsearch(
418
+ self,
419
+ image_inputs,
420
+ pad_token_id=None,
421
+ eos_token_id=None,
422
+ sot_token_id=None,
423
+ num_beams=6,
424
+ num_beam_groups=3,
425
+ min_seq_len=5,
426
+ stopping_criteria=None,
427
+ logit_processor=None,
428
+ logit_warper=None,
429
+ ):
430
+ device = image_inputs.device
431
+ batch_size = image_inputs.shape[0]
432
+ image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
433
+ image_latent, image_embs = self._encode_image(image_inputs)
434
+
435
+ input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
436
+ input_ids = input_ids * sot_token_id
437
+ beam_scorer = BeamSearchScorer(
438
+ batch_size=batch_size,
439
+ num_beams=num_beams,
440
+ device=device,
441
+ num_beam_groups=num_beam_groups,
442
+ )
443
+ # instantiate logits processors
444
+ logits_processor = (
445
+ LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
446
+ if logit_processor is None
447
+ else logit_processor
448
+ )
449
+
450
+ num_beams = beam_scorer.num_beams
451
+ num_beam_groups = beam_scorer.num_beam_groups
452
+ num_sub_beams = num_beams // num_beam_groups
453
+ batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
454
+ batch_beam_size, cur_len = input_ids.shape
455
+ beam_indices = None
456
+
457
+ if num_beams * batch_size != batch_beam_size:
458
+ raise ValueError(
459
+ f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
460
+ )
461
+
462
+ beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
463
+ # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
464
+ # the same group don't produce same tokens everytime.
465
+ beam_scores[:, ::num_sub_beams] = 0
466
+ beam_scores = beam_scores.view((batch_size * num_beams,))
467
+
468
+ while True:
469
+
470
+ # predicted tokens in cur_len step
471
+ current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
472
+
473
+ # indices which will form the beams in the next time step
474
+ reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
475
+
476
+ # do one decoder step on all beams of all sentences in batch
477
+ model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
478
+ outputs = self(
479
+ model_inputs['images'],
480
+ model_inputs['text'],
481
+ image_latent=image_latent,
482
+ image_embs=image_embs,
483
+ output_labels=False,
484
+ )
485
+
486
+ for beam_group_idx in range(num_beam_groups):
487
+ group_start_idx = beam_group_idx * num_sub_beams
488
+ group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
489
+ group_size = group_end_idx - group_start_idx
490
+
491
+ # indices of beams of current group among all sentences in batch
492
+ batch_group_indices = []
493
+
494
+ for batch_idx in range(batch_size):
495
+ batch_group_indices.extend(
496
+ [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
497
+ )
498
+ group_input_ids = input_ids[batch_group_indices]
499
+
500
+ # select outputs of beams of currentg group only
501
+ next_token_logits = outputs['logits'][batch_group_indices, -1, :]
502
+ vocab_size = next_token_logits.shape[-1]
503
+
504
+ next_token_scores_processed = logits_processor(
505
+ group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
506
+ )
507
+ next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
508
+ next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
509
+
510
+ # reshape for beam search
511
+ next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
512
+
513
+ next_token_scores, next_tokens = torch.topk(
514
+ next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
515
+ )
516
+
517
+ next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
518
+ next_tokens = next_tokens % vocab_size
519
+
520
+ # stateless
521
+ process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
522
+ beam_outputs = beam_scorer.process(
523
+ group_input_ids,
524
+ next_token_scores,
525
+ next_tokens,
526
+ next_indices,
527
+ pad_token_id=pad_token_id,
528
+ eos_token_id=eos_token_id,
529
+ beam_indices=process_beam_indices,
530
+ group_index=beam_group_idx,
531
+ )
532
+ beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
533
+ beam_next_tokens = beam_outputs["next_beam_tokens"]
534
+ beam_idx = beam_outputs["next_beam_indices"]
535
+
536
+ input_ids[batch_group_indices] = group_input_ids[beam_idx]
537
+ group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
538
+ current_tokens[batch_group_indices] = group_input_ids[:, -1]
539
+
540
+ # (beam_idx // group_size) -> batch_idx
541
+ # (beam_idx % group_size) -> offset of idx inside the group
542
+ reordering_indices[batch_group_indices] = (
543
+ num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
544
+ )
545
+
546
+ input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
547
+
548
+ # increase cur_len
549
+ cur_len = cur_len + 1
550
+ if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
551
+ break
552
+
553
+ final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
554
+ sequence_outputs = beam_scorer.finalize(
555
+ input_ids,
556
+ beam_scores,
557
+ next_tokens,
558
+ next_indices,
559
+ pad_token_id=pad_token_id,
560
+ eos_token_id=eos_token_id,
561
+ max_length=stopping_criteria.max_length,
562
+ beam_indices=final_beam_indices,
563
+ )
564
+ return sequence_outputs['sequences']
565
+
566
+
567
+ def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
568
+ if past:
569
+ input_ids = input_ids[:, -1].unsqueeze(-1)
570
+
571
+ attention_mask = kwargs.get("attention_mask", None)
572
+ position_ids = kwargs.get("position_ids", None)
573
+
574
+ if attention_mask is not None and position_ids is None:
575
+ # create position_ids on the fly for batch generation
576
+ position_ids = attention_mask.long().cumsum(-1) - 1
577
+ position_ids.masked_fill_(attention_mask == 0, 1)
578
+ else:
579
+ position_ids = None
580
+ return {
581
+ "text": input_ids,
582
+ "images": image_inputs,
583
+ "past_key_values": past,
584
+ "position_ids": position_ids,
585
+ "attention_mask": attention_mask,
586
+ }
src/open_clip/factory.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ from copy import deepcopy
4
+ from pathlib import Path
5
+ from typing import Optional, Tuple, Union
6
+
7
+ import torch
8
+
9
+ from .biosignals_coca_model import BiosignalsCoCa
10
+ from .model import get_cast_dtype, convert_weights_to_lp
11
+ from .tokenizer import SimpleTokenizer, DEFAULT_CONTEXT_LENGTH
12
+
13
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
14
+ _MODEL_CONFIGS = {}
15
+
16
+
17
+ def _rescan_model_configs():
18
+ global _MODEL_CONFIGS
19
+ config_files = []
20
+ for config_path in _MODEL_CONFIG_PATHS:
21
+ if config_path.is_dir():
22
+ config_files.extend(config_path.glob("*.json"))
23
+ for cf in config_files:
24
+ with open(cf, "r") as f:
25
+ model_cfg = json.load(f)
26
+ if all(a in model_cfg for a in ("embed_dim", "biosignals_cfg", "text_cfg")):
27
+ _MODEL_CONFIGS[cf.stem] = model_cfg
28
+
29
+
30
+ _rescan_model_configs()
31
+
32
+
33
+ def get_model_config(model_name: str):
34
+ return deepcopy(_MODEL_CONFIGS.get(model_name))
35
+
36
+
37
+ def create_model(
38
+ model_name: str,
39
+ precision: str = "fp32",
40
+ device: Union[str, torch.device] = "cpu",
41
+ **model_kwargs,
42
+ ) -> BiosignalsCoCa:
43
+ if isinstance(device, str):
44
+ device = torch.device(device)
45
+
46
+ model_cfg = get_model_config(model_name)
47
+ if model_cfg is None:
48
+ raise RuntimeError(f"Model config for '{model_name}' not found. Available: {list(_MODEL_CONFIGS.keys())}")
49
+
50
+ model_cfg.pop("custom_text", None)
51
+ model_cfg.update(model_kwargs)
52
+
53
+ cast_dtype = get_cast_dtype(precision)
54
+ model = BiosignalsCoCa(**model_cfg, cast_dtype=cast_dtype)
55
+
56
+ if precision in ("fp16", "bf16"):
57
+ dtype = torch.float16 if "fp16" in precision else torch.bfloat16
58
+ model.to(device=device)
59
+ convert_weights_to_lp(model, dtype=dtype)
60
+ elif precision in ("pure_fp16", "pure_bf16"):
61
+ dtype = torch.float16 if "fp16" in precision else torch.bfloat16
62
+ model.to(device=device, dtype=dtype)
63
+ else:
64
+ model.to(device=device)
65
+
66
+ model.output_dict = True
67
+ return model
68
+
69
+
70
+ def load_checkpoint(model, checkpoint_path: str, device="cpu"):
71
+ checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
72
+ state_dict = checkpoint.get("state_dict", checkpoint)
73
+ if next(iter(state_dict)).startswith("module."):
74
+ state_dict = {k[len("module."):]: v for k, v in state_dict.items()}
75
+ incompatible = model.load_state_dict(state_dict, strict=False)
76
+ return incompatible
77
+
78
+
79
+ def get_tokenizer(model_name: str = "", context_length: Optional[int] = None, **kwargs):
80
+ config = get_model_config(model_name) or {}
81
+ text_cfg = config.get("text_cfg", {})
82
+ if context_length is None:
83
+ context_length = text_cfg.get("context_length", DEFAULT_CONTEXT_LENGTH)
84
+ return SimpleTokenizer(context_length=context_length, **kwargs)
85
+
86
+
87
+ def get_input_dtype(precision: str):
88
+ input_dtype = None
89
+ if precision in ("bf16", "pure_bf16"):
90
+ input_dtype = torch.bfloat16
91
+ elif precision in ("fp16", "pure_fp16"):
92
+ input_dtype = torch.float16
93
+ return input_dtype
src/open_clip/model.py ADDED
@@ -0,0 +1,943 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ CLIP Model
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ import copy
6
+ import logging
7
+ import math
8
+ from dataclasses import dataclass
9
+ from typing import Any, Dict, List, Optional, Tuple, Union
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch import nn
15
+ from torch.utils.checkpoint import checkpoint
16
+ from functools import partial
17
+
18
+ from .transformer import (
19
+ LayerNormFp32,
20
+ LayerNorm,
21
+ QuickGELU,
22
+ Attention,
23
+ VisionTransformer,
24
+ TextTransformer,
25
+ text_global_pool,
26
+ lock_text_tower,
27
+ to_2tuple,
28
+ )
29
+
30
+
31
+ @dataclass
32
+ class CLIPVisionCfg:
33
+ layers: Union[Tuple[int, int, int, int], int] = 12
34
+ width: int = 768
35
+ head_width: int = 64
36
+ mlp_ratio: float = 4.0
37
+ patch_size: int = 16
38
+ image_size: Union[Tuple[int, int], int] = 224
39
+
40
+ ls_init_value: Optional[float] = None # layer scale initial value
41
+ patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
42
+ attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type)
43
+ attn_pooler_queries: int = 256 # n_queries for attentional pooler
44
+ attn_pooler_heads: int = 8 # n heads for attentional_pooling
45
+ no_ln_pre: bool = False # disable pre transformer LayerNorm
46
+ pos_embed_type: str = 'learnable'
47
+ final_ln_after_pool: bool = False # apply final LayerNorm after pooling
48
+ pool_type: str = 'tok'
49
+ output_tokens: bool = False
50
+ act_kwargs: Optional[dict] = None
51
+ norm_kwargs: Optional[dict] = None
52
+
53
+ # Custom attention block settings
54
+ block_type: Optional[str] = None # attention block type ('default', 'custom'), auto-selects 'custom' if any below features enabled
55
+ qk_norm: bool = False # apply layer norm to q and k in attention
56
+ scaled_cosine_attn: bool = False # use scaled cosine attention
57
+ scale_heads: bool = False # learnable head-specific scale applied to attention logits
58
+ scale_attn_inner: bool = False # apply layer norm on attention context, before output projection
59
+ scale_attn: bool = False # apply layer norm after full attention block
60
+ scale_fc: bool = False # apply layer norm in MLP block
61
+
62
+ timm_model_name: Optional[str] = None # a valid model name overrides layers, width, patch_size
63
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
64
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
65
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
66
+ timm_proj_bias: bool = False # enable bias final projection
67
+ timm_drop: float = 0. # head dropout
68
+ timm_drop_path: Optional[float] = None # backbone stochastic depth
69
+
70
+
71
+ @dataclass
72
+ class CLIPTextCfg:
73
+ context_length: int = 77
74
+ vocab_size: int = 49408
75
+ hf_tokenizer_name: Optional[str] = None
76
+ tokenizer_mode: Optional[str] = None
77
+ tokenizer_kwargs: Optional[dict] = None
78
+
79
+ width: int = 512
80
+ heads: int = 8
81
+ layers: int = 12
82
+ mlp_ratio: float = 4.0
83
+ ls_init_value: Optional[float] = None # layer scale initial value
84
+ embed_cls: bool = False
85
+ pad_id: int = 0
86
+ eos_id: int = 2 # only used for when pool_type == 'eos', must match tokenizer eos
87
+ no_causal_mask: bool = False # disable causal masking
88
+ final_ln_after_pool: bool = False # apply final LayerNorm after pooling
89
+ pool_type: str = 'argmax'
90
+ proj_bias: bool = False
91
+ proj_type: str = 'linear' # control final text projection, 'none' forces no projection
92
+ output_tokens: bool = False
93
+ act_kwargs: dict = None
94
+ norm_kwargs: dict = None
95
+
96
+ # Custom attention block settings
97
+ block_type: Optional[str] = None # attention block type ('default', 'custom'), auto-selects 'custom' if any custom features enabled
98
+ qk_norm: bool = False # apply layer norm to q and k in attention
99
+ scaled_cosine_attn: bool = False # use scaled cosine attention
100
+ scale_heads: bool = False # learnable head-specific scale applied to attention logits
101
+ scale_attn_inner: bool = False # apply layer norm on attention context, before output projection
102
+ scale_attn: bool = False # apply layer norm after full attention block
103
+ scale_fc: bool = False # apply layer norm in MLP block
104
+
105
+ # HuggingFace specific text tower config
106
+ hf_model_name: Optional[str] = None
107
+ hf_model_pretrained: bool = True
108
+ hf_proj_type: str = 'mlp'
109
+ hf_pooler_type: str = 'mean_pooler' # attentional pooling for HF models
110
+ special_tokens_to_add: Optional[dict] = None # special tokens to add to tokenizer (e.g., for Pythia)
111
+
112
+
113
+ def get_cast_dtype(precision: str):
114
+ cast_dtype = None
115
+ if precision == 'bf16':
116
+ cast_dtype = torch.bfloat16
117
+ elif precision == 'fp16':
118
+ cast_dtype = torch.float16
119
+ return cast_dtype
120
+
121
+
122
+ def get_input_dtype(precision: str):
123
+ input_dtype = None
124
+ if precision in ('bf16', 'pure_bf16'):
125
+ input_dtype = torch.bfloat16
126
+ elif precision in ('fp16', 'pure_fp16'):
127
+ input_dtype = torch.float16
128
+ return input_dtype
129
+
130
+
131
+ def _build_vision_tower(
132
+ embed_dim: int,
133
+ vision_cfg: CLIPVisionCfg,
134
+ quick_gelu: bool = False,
135
+ cast_dtype: Optional[torch.dtype] = None
136
+ ):
137
+ if isinstance(vision_cfg, dict):
138
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
139
+
140
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
141
+ # memory efficient in recent PyTorch releases (>= 1.10).
142
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
143
+ act_layer = QuickGELU if quick_gelu else nn.GELU
144
+
145
+ if vision_cfg.timm_model_name:
146
+ from .timm_model import TimmModel
147
+ visual = TimmModel(
148
+ vision_cfg.timm_model_name,
149
+ pretrained=vision_cfg.timm_model_pretrained,
150
+ pool=vision_cfg.timm_pool,
151
+ proj=vision_cfg.timm_proj,
152
+ proj_bias=vision_cfg.timm_proj_bias,
153
+ drop=vision_cfg.timm_drop,
154
+ drop_path=vision_cfg.timm_drop_path,
155
+ patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
156
+ embed_dim=embed_dim,
157
+ image_size=vision_cfg.image_size,
158
+ )
159
+ elif isinstance(vision_cfg.layers, (tuple, list)):
160
+ from .modified_resnet import ModifiedResNet
161
+ vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
162
+ visual = ModifiedResNet(
163
+ layers=vision_cfg.layers,
164
+ output_dim=embed_dim,
165
+ heads=vision_heads,
166
+ image_size=vision_cfg.image_size,
167
+ width=vision_cfg.width,
168
+ )
169
+ else:
170
+ vision_heads = vision_cfg.width // vision_cfg.head_width
171
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
172
+ if vision_cfg.norm_kwargs:
173
+ norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
174
+ if vision_cfg.act_kwargs is not None:
175
+ act_layer = partial(act_layer, **vision_cfg.act_kwargs)
176
+
177
+ visual = VisionTransformer(
178
+ image_size=vision_cfg.image_size,
179
+ patch_size=vision_cfg.patch_size,
180
+ width=vision_cfg.width,
181
+ layers=vision_cfg.layers,
182
+ heads=vision_heads,
183
+ mlp_ratio=vision_cfg.mlp_ratio,
184
+ ls_init_value=vision_cfg.ls_init_value,
185
+ patch_dropout=vision_cfg.patch_dropout,
186
+ attentional_pool=vision_cfg.attentional_pool,
187
+ attn_pooler_queries=vision_cfg.attn_pooler_queries,
188
+ attn_pooler_heads=vision_cfg.attn_pooler_heads,
189
+ pos_embed_type=vision_cfg.pos_embed_type,
190
+ no_ln_pre=vision_cfg.no_ln_pre,
191
+ final_ln_after_pool=vision_cfg.final_ln_after_pool,
192
+ pool_type=vision_cfg.pool_type,
193
+ output_tokens=vision_cfg.output_tokens,
194
+ output_dim=embed_dim,
195
+ act_layer=act_layer,
196
+ norm_layer=norm_layer,
197
+ block_type=vision_cfg.block_type,
198
+ qk_norm=vision_cfg.qk_norm,
199
+ scaled_cosine_attn=vision_cfg.scaled_cosine_attn,
200
+ scale_heads=vision_cfg.scale_heads,
201
+ scale_attn_inner=vision_cfg.scale_attn_inner,
202
+ scale_attn=vision_cfg.scale_attn,
203
+ scale_fc=vision_cfg.scale_fc,
204
+ )
205
+
206
+ return visual
207
+
208
+
209
+
210
+
211
+
212
+ def _build_text_tower(
213
+ embed_dim: int,
214
+ text_cfg: CLIPTextCfg,
215
+ quick_gelu: bool = False,
216
+ cast_dtype: Optional[torch.dtype] = None,
217
+ ):
218
+ if isinstance(text_cfg, dict):
219
+ text_cfg = CLIPTextCfg(**text_cfg)
220
+
221
+ if text_cfg.hf_model_name:
222
+ from .hf_model import HFTextEncoder
223
+ text = HFTextEncoder(
224
+ text_cfg.hf_model_name,
225
+ output_dim=embed_dim,
226
+ proj_type=text_cfg.hf_proj_type,
227
+ pooler_type=text_cfg.hf_pooler_type,
228
+ pretrained=text_cfg.hf_model_pretrained,
229
+ output_tokens=text_cfg.output_tokens,
230
+ )
231
+
232
+ # Handle special tokens if configured (e.g., for Pythia)
233
+ special_tokens_cfg = getattr(text_cfg, 'special_tokens_to_add', None)
234
+ if special_tokens_cfg:
235
+ from transformers import AutoTokenizer
236
+ import logging
237
+
238
+ # Load tokenizer from local cache only (ensures consistency with get_tokenizer())
239
+ # get_tokenizer() is called first and downloads/caches, we just reuse that exact version
240
+ tokenizer = AutoTokenizer.from_pretrained(
241
+ text_cfg.hf_model_name,
242
+ local_files_only=True
243
+ )
244
+
245
+ # Store original vocab size before adding new tokens
246
+ # This is needed to unfreeze new token embeddings after locking
247
+ original_vocab_size = len(tokenizer)
248
+ text.original_vocab_size = original_vocab_size
249
+
250
+ tokenizer.add_special_tokens(special_tokens_cfg)
251
+
252
+ # Resize model embeddings to accommodate new tokens
253
+ # pad_to_multiple_of=64 ensures optimal Tensor Core performance for embedding lookups
254
+ new_vocab_size = len(tokenizer)
255
+ text.transformer.resize_token_embeddings(new_vocab_size, pad_to_multiple_of=64)
256
+
257
+ # Store token IDs for use in forward pass
258
+ if 'additional_special_tokens' in special_tokens_cfg:
259
+ for token in special_tokens_cfg['additional_special_tokens']:
260
+ if token == '<coca_cls>':
261
+ text.coca_cls_token_id = tokenizer.convert_tokens_to_ids(token)
262
+
263
+ if 'pad_token' in special_tokens_cfg:
264
+ text.config.pad_token_id = tokenizer.pad_token_id
265
+ text.pad_token_id = tokenizer.pad_token_id
266
+
267
+ text.config.vocab_size = new_vocab_size
268
+ text.vocab_size = new_vocab_size
269
+
270
+ logging.info(f"Added special tokens to {text_cfg.hf_model_name}:")
271
+ logging.info(f" Original vocab size: {original_vocab_size}")
272
+ logging.info(f" New vocab size: {new_vocab_size}")
273
+ logging.info(f" Added {new_vocab_size - original_vocab_size} new tokens")
274
+ if text.coca_cls_token_id is not None:
275
+ logging.info(f" CoCa CLS token ID: {text.coca_cls_token_id}")
276
+ if text.pad_token_id is not None:
277
+ logging.info(f" Pad token ID: {text.pad_token_id}")
278
+ else:
279
+ act_layer = QuickGELU if quick_gelu else nn.GELU
280
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
281
+ if text_cfg.norm_kwargs:
282
+ norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)
283
+ if text_cfg.act_kwargs is not None:
284
+ act_layer = partial(act_layer, **text_cfg.act_kwargs)
285
+
286
+ text = TextTransformer(
287
+ context_length=text_cfg.context_length,
288
+ vocab_size=text_cfg.vocab_size,
289
+ width=text_cfg.width,
290
+ heads=text_cfg.heads,
291
+ layers=text_cfg.layers,
292
+ mlp_ratio=text_cfg.mlp_ratio,
293
+ ls_init_value=text_cfg.ls_init_value,
294
+ output_dim=embed_dim,
295
+ embed_cls=text_cfg.embed_cls,
296
+ no_causal_mask=text_cfg.no_causal_mask,
297
+ pad_id=text_cfg.pad_id,
298
+ eos_id=text_cfg.eos_id,
299
+ pool_type=text_cfg.pool_type,
300
+ proj_type=text_cfg.proj_type,
301
+ proj_bias=text_cfg.proj_bias,
302
+ output_tokens=text_cfg.output_tokens,
303
+ act_layer=act_layer,
304
+ norm_layer=norm_layer,
305
+ block_type=text_cfg.block_type,
306
+ qk_norm=text_cfg.qk_norm,
307
+ scaled_cosine_attn=text_cfg.scaled_cosine_attn,
308
+ scale_heads=text_cfg.scale_heads,
309
+ scale_attn_inner=text_cfg.scale_attn_inner,
310
+ scale_attn=text_cfg.scale_attn,
311
+ scale_fc=text_cfg.scale_fc,
312
+ )
313
+ return text
314
+
315
+
316
+ class CLIP(nn.Module):
317
+ output_dict: torch.jit.Final[bool]
318
+
319
+ def __init__(
320
+ self,
321
+ embed_dim: int,
322
+ vision_cfg: CLIPVisionCfg,
323
+ text_cfg: CLIPTextCfg,
324
+ quick_gelu: bool = False,
325
+ init_logit_scale: float = np.log(1 / 0.07),
326
+ init_logit_bias: Optional[float] = None,
327
+ nonscalar_logit_scale: bool = False,
328
+ cast_dtype: Optional[torch.dtype] = None,
329
+ output_dict: bool = False,
330
+ ):
331
+ super().__init__()
332
+ self.output_dict = output_dict
333
+
334
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
335
+
336
+ text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
337
+ self.transformer = text.transformer
338
+ self.context_length = text.context_length
339
+ self.vocab_size = text.vocab_size
340
+ self.token_embedding = text.token_embedding
341
+ self.positional_embedding = text.positional_embedding
342
+ self.ln_final = text.ln_final
343
+ self.text_projection = text.text_projection
344
+ self.text_pool_type = text.pool_type
345
+ self.text_eos_id = text.eos_id
346
+ self.register_buffer('attn_mask', text.attn_mask, persistent=False)
347
+
348
+ lshape = [1] if nonscalar_logit_scale else []
349
+ self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
350
+ if init_logit_bias is not None:
351
+ self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
352
+ else:
353
+ self.logit_bias = None
354
+
355
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
356
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
357
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
358
+
359
+ def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
360
+ assert freeze_layer_norm, 'Unfreezing LayerNorm is not supported. LayerNorm treated like other weights.'
361
+ lock_text_tower(self, unlocked_layers)
362
+
363
+ @torch.jit.ignore
364
+ def set_grad_checkpointing(self, enable=True):
365
+ self.visual.set_grad_checkpointing(enable)
366
+ self.transformer.grad_checkpointing = enable
367
+
368
+ @torch.jit.ignore
369
+ def no_weight_decay(self):
370
+ # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
371
+ no_wd = {'positional_embedding'}
372
+ if hasattr(self.visual, 'no_weight_decay'):
373
+ for n in self.visual.no_weight_decay():
374
+ no_wd.add('visual.' + n)
375
+ return no_wd
376
+
377
+ def encode_image(self, image, normalize: bool = False):
378
+ features = self.visual(image)
379
+ return F.normalize(features, dim=-1) if normalize else features
380
+
381
+ def encode_text(self, text, normalize: bool = False):
382
+ cast_dtype = self.transformer.get_cast_dtype()
383
+
384
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
385
+
386
+ x = x + self.positional_embedding.to(cast_dtype)
387
+ x = self.transformer(x, attn_mask=self.attn_mask)
388
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
389
+ x = text_global_pool(x, text, self.text_pool_type, eos_token_id=getattr(self, "text_eos_id", None))
390
+ if self.text_projection is not None:
391
+ if isinstance(self.text_projection, nn.Linear):
392
+ x = self.text_projection(x)
393
+ else:
394
+ x = x @ self.text_projection
395
+
396
+ return F.normalize(x, dim=-1) if normalize else x
397
+
398
+ def get_logits(self, image, text):
399
+ image_features = self.encode_image(image, normalize=True)
400
+ text_features = self.encode_text(text, normalize=True)
401
+ image_logits = self.logit_scale.exp() * image_features @ text_features.T
402
+ if self.logit_bias is not None:
403
+ image_logits += self.logit_bias
404
+ text_logits = image_logits.T
405
+ return image_logits, text_logits
406
+
407
+ def forward_intermediates(
408
+ self,
409
+ image: Optional[torch.Tensor] = None,
410
+ text: Optional[torch.Tensor] = None,
411
+ image_indices: Optional[Union[int, List[int]]] = None,
412
+ text_indices: Optional[Union[int, List[int]]] = None,
413
+ stop_early: bool = False,
414
+ normalize: bool = True,
415
+ normalize_intermediates: bool = False,
416
+ intermediates_only: bool = False,
417
+ image_output_fmt: str = 'NCHW',
418
+ image_output_extra_tokens: bool = False,
419
+ text_output_fmt: str = 'NLC',
420
+ text_output_extra_tokens: bool = False,
421
+ output_logits: bool = False,
422
+ output_logit_scale_bias: bool = False,
423
+ ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
424
+ """ Forward features that returns intermediates.
425
+
426
+ Args:
427
+ image: Input image tensor
428
+ text: Input text tensor
429
+ image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
430
+ text_indices: Take last n blocks if int, all if None, select matching indices if sequence
431
+ stop_early: Stop iterating over blocks when last desired intermediate hit
432
+ normalize_intermediates: Apply final norm layer to all intermediates
433
+ normalize: L2 Normalize final features
434
+ intermediates_only: Only return intermediate features, do not return final features
435
+ image_output_fmt: Shape of intermediate image feature outputs
436
+ image_output_extra_tokens: Return both prefix and spatial intermediate tokens
437
+ text_output_fmt: Shape of intermediate text feature outputs (ignored for this model)
438
+ text_output_extra_tokens: Return both prefix and spatial intermediate tokens (ignored for this model)
439
+ output_logits: Include logits in output
440
+ output_logit_scale_bias: Include the logit scale bias in the output
441
+ Returns:
442
+
443
+ """
444
+ output = {}
445
+ if intermediates_only:
446
+ # intermediates only disables final feature normalization, and include logits
447
+ normalize = False
448
+ output_logits = False
449
+ if output_logits:
450
+ assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'
451
+
452
+ if image is not None:
453
+ image_output = self.visual.forward_intermediates(
454
+ image,
455
+ indices=image_indices,
456
+ stop_early=stop_early,
457
+ normalize_intermediates=normalize_intermediates,
458
+ intermediates_only=intermediates_only,
459
+ output_fmt=image_output_fmt,
460
+ output_extra_tokens=image_output_extra_tokens,
461
+ )
462
+ if normalize and "image_features" in image_output:
463
+ image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
464
+ output.update(image_output)
465
+
466
+ if text is not None:
467
+ cast_dtype = self.transformer.get_cast_dtype()
468
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
469
+ x = x + self.positional_embedding.to(cast_dtype)
470
+ x, intermediates = self.transformer.forward_intermediates(
471
+ x,
472
+ attn_mask=self.attn_mask,
473
+ indices=text_indices
474
+ )
475
+ if normalize_intermediates:
476
+ intermediates = [self.ln_final(xi) for xi in intermediates]
477
+
478
+ # NOTE this model doesn't support cls embed in text transformer, no need for extra intermediate tokens
479
+ output["text_intermediates"] = intermediates
480
+
481
+ if not intermediates_only:
482
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
483
+ x = text_global_pool(x, text, self.text_pool_type, eos_token_id=getattr(self, "text_eos_id", None))
484
+ if self.text_projection is not None:
485
+ if isinstance(self.text_projection, nn.Linear):
486
+ x = self.text_projection(x)
487
+ else:
488
+ x = x @ self.text_projection
489
+ if normalize:
490
+ x = F.normalize(x, dim=-1)
491
+ output["text_features"] = x
492
+
493
+ logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None
494
+
495
+ if output_logits:
496
+ image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T
497
+ if self.logit_bias is not None:
498
+ image_logits += self.logit_bias
499
+ text_logits = image_logits.T
500
+ output["image_logits"] = image_logits
501
+ output["text_logits"] = text_logits
502
+
503
+ if output_logit_scale_bias:
504
+ output["logit_scale"] = logit_scale_exp
505
+ if self.logit_bias is not None:
506
+ output['logit_bias'] = self.logit_bias
507
+
508
+ return output
509
+
510
+ def forward(
511
+ self,
512
+ image: Optional[torch.Tensor] = None,
513
+ text: Optional[torch.Tensor] = None,
514
+ ):
515
+ image_features = self.encode_image(image, normalize=True) if image is not None else None
516
+ text_features = self.encode_text(text, normalize=True) if text is not None else None
517
+
518
+ if self.output_dict:
519
+ out_dict = {
520
+ "image_features": image_features,
521
+ "text_features": text_features,
522
+ "logit_scale": self.logit_scale.exp()
523
+ }
524
+ if self.logit_bias is not None:
525
+ out_dict['logit_bias'] = self.logit_bias
526
+ return out_dict
527
+
528
+ if self.logit_bias is not None:
529
+ return image_features, text_features, self.logit_scale.exp(), self.logit_bias
530
+ return image_features, text_features, self.logit_scale.exp()
531
+
532
+
533
+ class CustomTextCLIP(nn.Module):
534
+ output_dict: torch.jit.Final[bool]
535
+
536
+ def __init__(
537
+ self,
538
+ embed_dim: int,
539
+ vision_cfg: CLIPVisionCfg,
540
+ text_cfg: CLIPTextCfg,
541
+ quick_gelu: bool = False,
542
+ init_logit_scale: float = np.log(1 / 0.07),
543
+ init_logit_bias: Optional[float] = None,
544
+ nonscalar_logit_scale: bool = False,
545
+ cast_dtype: Optional[torch.dtype] = None,
546
+ output_dict: bool = False,
547
+ ):
548
+ super().__init__()
549
+ self.output_dict = output_dict
550
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
551
+ self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
552
+ self.context_length = self.text.context_length
553
+ self.vocab_size = self.text.vocab_size
554
+
555
+ lshape = [1] if nonscalar_logit_scale else []
556
+ self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
557
+ if init_logit_bias is not None:
558
+ self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
559
+ else:
560
+ self.logit_bias = None
561
+
562
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
563
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
564
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
565
+
566
+ def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
567
+ self.text.lock(unlocked_layers, freeze_layer_norm)
568
+
569
+ @torch.jit.ignore
570
+ def set_grad_checkpointing(self, enable=True):
571
+ self.visual.set_grad_checkpointing(enable)
572
+ self.text.set_grad_checkpointing(enable)
573
+
574
+ @torch.jit.ignore
575
+ def no_weight_decay(self):
576
+ # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
577
+ no_wd = set()
578
+ if hasattr(self.visual, 'no_weight_decay'):
579
+ for n in self.visual.no_weight_decay():
580
+ no_wd.add('visual.' + n)
581
+ if hasattr(self.text, 'no_weight_decay'):
582
+ for n in self.text.no_weight_decay():
583
+ no_wd.add('text.' + n)
584
+ return no_wd
585
+
586
+ def encode_image(self, image, normalize: bool = False):
587
+ features = self.visual(image)
588
+ return F.normalize(features, dim=-1) if normalize else features
589
+
590
+ def encode_text(self, text, normalize: bool = False):
591
+ features = self.text(text)
592
+ return F.normalize(features, dim=-1) if normalize else features
593
+
594
+ def get_logits(self, image, text):
595
+ image_features = self.encode_image(image, normalize=True)
596
+ text_features = self.encode_text(text, normalize=True)
597
+ image_logits = self.logit_scale.exp() * image_features @ text_features.T
598
+ if self.logit_bias is not None:
599
+ image_logits += self.logit_bias
600
+ text_logits = image_logits.T
601
+ return image_logits, text_logits
602
+
603
+ def forward_intermediates(
604
+ self,
605
+ image: Optional[torch.Tensor] = None,
606
+ text: Optional[torch.Tensor] = None,
607
+ image_indices: Optional[Union[int, List[int]]] = None,
608
+ text_indices: Optional[Union[int, List[int]]] = None,
609
+ stop_early: bool = False,
610
+ normalize: bool = True,
611
+ normalize_intermediates: bool = False,
612
+ intermediates_only: bool = False,
613
+ image_output_fmt: str = 'NCHW',
614
+ image_output_extra_tokens: bool = False,
615
+ text_output_fmt: str = 'NLC',
616
+ text_output_extra_tokens: bool = False,
617
+ output_logits: bool = False,
618
+ output_logit_scale_bias: bool = False,
619
+ ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
620
+ """ Forward features that returns intermediates.
621
+
622
+ Args:
623
+ image: Input image tensor
624
+ text: Input text tensor
625
+ image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
626
+ text_indices: Take last n blocks if int, all if None, select matching indices if sequence
627
+ stop_early: Stop iterating over blocks when last desired intermediate hit
628
+ normalize: L2 Normalize final image and text features (if present)
629
+ normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)
630
+ intermediates_only: Only return intermediate features, do not return final features
631
+ image_output_fmt: Shape of intermediate image feature outputs
632
+ image_output_extra_tokens: Return both prefix and spatial intermediate tokens
633
+ text_output_fmt: Shape of intermediate text feature outputs
634
+ text_output_extra_tokens: Return both prefix and spatial intermediate tokens
635
+ output_logits: Include logits in output
636
+ output_logit_scale_bias: Include the logit scale bias in the output
637
+ Returns:
638
+
639
+ """
640
+ output = {}
641
+ if intermediates_only:
642
+ # intermediates only disables final feature normalization, and include logits
643
+ normalize = False
644
+ output_logits = False
645
+ if output_logits:
646
+ assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'
647
+
648
+ if image is not None:
649
+ image_output = self.visual.forward_intermediates(
650
+ image,
651
+ indices=image_indices,
652
+ stop_early=stop_early,
653
+ normalize_intermediates=normalize_intermediates,
654
+ intermediates_only=intermediates_only,
655
+ output_fmt=image_output_fmt,
656
+ output_extra_tokens=image_output_extra_tokens,
657
+ )
658
+ if normalize and "image_features" in image_output:
659
+ image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
660
+ output.update(image_output)
661
+
662
+ if text is not None:
663
+ text_output = self.text.forward_intermediates(
664
+ text,
665
+ indices=text_indices,
666
+ stop_early=stop_early,
667
+ normalize_intermediates=normalize_intermediates,
668
+ intermediates_only=intermediates_only,
669
+ output_fmt=text_output_fmt,
670
+ output_extra_tokens=text_output_extra_tokens,
671
+ )
672
+ if normalize and "text_features" in text_output:
673
+ text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1)
674
+ output.update(text_output)
675
+
676
+ logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None
677
+
678
+ if output_logits:
679
+ image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T
680
+ if self.logit_bias is not None:
681
+ image_logits += self.logit_bias
682
+ text_logits = image_logits.T
683
+ output["image_logits"] = image_logits
684
+ output["text_logits"] = text_logits
685
+
686
+ if output_logit_scale_bias:
687
+ output["logit_scale"] = logit_scale_exp
688
+ if self.logit_bias is not None:
689
+ output['logit_bias'] = self.logit_bias
690
+
691
+ return output
692
+
693
+ def forward(
694
+ self,
695
+ image: Optional[torch.Tensor] = None,
696
+ text: Optional[torch.Tensor] = None,
697
+ ):
698
+ image_features = self.encode_image(image, normalize=True) if image is not None else None
699
+ text_features = self.encode_text(text, normalize=True) if text is not None else None
700
+
701
+ if self.output_dict:
702
+ out_dict = {
703
+ "image_features": image_features,
704
+ "text_features": text_features,
705
+ "logit_scale": self.logit_scale.exp()
706
+ }
707
+ if self.logit_bias is not None:
708
+ out_dict['logit_bias'] = self.logit_bias
709
+ return out_dict
710
+
711
+ if self.logit_bias is not None:
712
+ return image_features, text_features, self.logit_scale.exp(), self.logit_bias
713
+ return image_features, text_features, self.logit_scale.exp()
714
+
715
+
716
+ def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
717
+ """Convert applicable model parameters to low-precision (bf16 or fp16)"""
718
+
719
+ def _convert_weights(l):
720
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
721
+ l.weight.data = l.weight.data.to(dtype)
722
+ if l.bias is not None:
723
+ l.bias.data = l.bias.data.to(dtype)
724
+
725
+ if isinstance(l, (nn.MultiheadAttention, Attention)):
726
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
727
+ tensor = getattr(l, attr, None)
728
+ if tensor is not None:
729
+ tensor.data = tensor.data.to(dtype)
730
+
731
+ if isinstance(l, (CLIP, TextTransformer)):
732
+ # convert text nn.Parameter projections
733
+ attr = getattr(l, "text_projection", None)
734
+ if attr is not None:
735
+ attr.data = attr.data.to(dtype)
736
+
737
+ if isinstance(l, VisionTransformer):
738
+ # convert vision nn.Parameter projections
739
+ attr = getattr(l, "proj", None)
740
+ if attr is not None:
741
+ attr.data = attr.data.to(dtype)
742
+
743
+ model.apply(_convert_weights)
744
+
745
+
746
+ convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
747
+
748
+
749
# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
    """Remap an old-format CLIP state_dict (text tower keys at top level)
    to the CustomTextCLIP layout, where text weights live under a `text.`
    prefix. State dicts already in the new format are returned unchanged.
    """
    if 'text_projection' not in state_dict:
        # already in the new layout (or not a text-bearing checkpoint)
        return state_dict

    text_key_prefixes = (
        'text_projection',
        'positional_embedding',
        'token_embedding',
        'transformer',
        'ln_final',
    )
    remapped = {}
    for key, value in state_dict.items():
        new_key = 'text.' + key if key.startswith(text_key_prefixes) else key
        remapped[new_key] = value
    return remapped
766
+
767
+
768
+ def build_model_from_openai_state_dict(
769
+ state_dict: dict,
770
+ quick_gelu=True,
771
+ cast_dtype=torch.float16,
772
+ ):
773
+ vit = "visual.proj" in state_dict
774
+
775
+ if vit:
776
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
777
+ vision_layers = len(
778
+ [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
779
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
780
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
781
+ image_size = vision_patch_size * grid_size
782
+ else:
783
+ counts: list = [
784
+ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
785
+ vision_layers = tuple(counts)
786
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
787
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
788
+ vision_patch_size = None
789
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
790
+ image_size = output_width * 32
791
+
792
+ embed_dim = state_dict["text_projection"].shape[1]
793
+ context_length = state_dict["positional_embedding"].shape[0]
794
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
795
+ transformer_width = state_dict["ln_final.weight"].shape[0]
796
+ transformer_heads = transformer_width // 64
797
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
798
+
799
+ vision_cfg = CLIPVisionCfg(
800
+ layers=vision_layers,
801
+ width=vision_width,
802
+ patch_size=vision_patch_size,
803
+ image_size=image_size,
804
+ )
805
+ text_cfg = CLIPTextCfg(
806
+ context_length=context_length,
807
+ vocab_size=vocab_size,
808
+ width=transformer_width,
809
+ heads=transformer_heads,
810
+ layers=transformer_layers,
811
+ )
812
+ model = CLIP(
813
+ embed_dim,
814
+ vision_cfg=vision_cfg,
815
+ text_cfg=text_cfg,
816
+ quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
817
+ cast_dtype=cast_dtype,
818
+ )
819
+
820
+ for key in ["input_resolution", "context_length", "vocab_size"]:
821
+ state_dict.pop(key, None)
822
+ convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
823
+ model.load_state_dict(state_dict)
824
+ return model.eval()
825
+
826
+
827
+ def trace_model(model, batch_size=256, device=torch.device('cpu')):
828
+ model.eval()
829
+ image_size = model.visual.image_size
830
+ example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
831
+ example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
832
+ model = torch.jit.trace_module(
833
+ model,
834
+ inputs=dict(
835
+ forward=(example_images, example_text),
836
+ encode_text=(example_text,),
837
+ encode_image=(example_images,)
838
+ ))
839
+ model.visual.image_size = image_size
840
+ return model
841
+
842
+
843
+ def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
844
+ # Rescale the grid of position embeddings when loading from state_dict
845
+ old_pos_embed = state_dict.get('visual.positional_embedding', None)
846
+ if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
847
+ return
848
+ grid_size = to_2tuple(model.visual.grid_size)
849
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
850
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
851
+ if new_seq_len == old_pos_embed.shape[0]:
852
+ return
853
+
854
+ if extra_tokens:
855
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
856
+ else:
857
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
858
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
859
+
860
+ logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
861
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
862
+ pos_emb_img = F.interpolate(
863
+ pos_emb_img,
864
+ size=grid_size,
865
+ mode=interpolation,
866
+ antialias=antialias,
867
+ align_corners=False,
868
+ )
869
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
870
+ if pos_emb_tok is not None:
871
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
872
+ else:
873
+ new_pos_embed = pos_emb_img
874
+ state_dict['visual.positional_embedding'] = new_pos_embed
875
+
876
+
877
+ def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False):
878
+ pos_embed_key = 'positional_embedding' if 'positional_embedding' in state_dict else 'text.positional_embedding'
879
+ old_pos_embed = state_dict.get(pos_embed_key, None)
880
+ if old_pos_embed is None:
881
+ return
882
+ # FIXME add support for text cls_token
883
+ model_pos_embed = getattr(model, 'positional_embedding', None)
884
+ if model_pos_embed is None:
885
+ model_pos_embed = getattr(model.text, 'positional_embedding', None)
886
+
887
+ old_num_pos = old_pos_embed.shape[0]
888
+ old_width = old_pos_embed.shape[1]
889
+ num_pos = model_pos_embed.shape[0]
890
+ width = model_pos_embed.shape[1]
891
+ assert old_width == width, 'text pos_embed width changed!'
892
+ if old_num_pos == num_pos:
893
+ return
894
+
895
+ logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos)
896
+ old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1)
897
+ old_pos_embed = F.interpolate(
898
+ old_pos_embed,
899
+ size=num_pos,
900
+ mode=interpolation,
901
+ antialias=antialias,
902
+ align_corners=False,
903
+ )
904
+ old_pos_embed = old_pos_embed.permute(0, 2, 1)[0]
905
+ new_pos_embed = old_pos_embed
906
+
907
+ state_dict[pos_embed_key] = new_pos_embed
908
+
909
+
910
+ def get_model_preprocess_cfg(model):
911
+ module = getattr(model, 'visual', model)
912
+ preprocess_cfg = getattr(module, 'preprocess_cfg', {})
913
+ if not preprocess_cfg:
914
+ # use separate legacy attributes if preprocess_cfg dict not found
915
+ size = getattr(module, 'image_size')
916
+ if size is not None:
917
+ preprocess_cfg['size'] = size
918
+ mean = getattr(module, 'image_mean', None)
919
+ if mean is not None:
920
+ preprocess_cfg['mean'] = mean
921
+ std = getattr(module, 'image_std', None)
922
+ if std is not None:
923
+ preprocess_cfg['std'] = std
924
+ return preprocess_cfg
925
+
926
+
927
+ def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
928
+ module = getattr(model, 'visual', model)
929
+ module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat
930
+ module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat
931
+ module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict
932
+
933
+
934
+ def get_model_tokenize_cfg(model):
935
+ module = getattr(model, 'text', model)
936
+ cfg = {}
937
+ context_length = getattr(module, 'context_length', None)
938
+ if context_length is not None:
939
+ cfg['context_length'] = context_length
940
+ vocab_size = getattr(module, 'vocab_size', None)
941
+ if vocab_size is not None:
942
+ cfg['vocab_size'] = vocab_size
943
+ return cfg
src/open_clip/model_configs/sleep_coca_base_dualtransformer.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "multimodal_cfg": {
4
+ "width": 768,
5
+ "context_length": 256,
6
+ "mlp_ratio": 4,
7
+ "layers": 12,
8
+ "heads": 12
9
+ },
10
+ "biosignals_cfg": {
11
+ "architecture": "pure_transformer",
12
+ "input_channels": 10,
13
+ "signal_length": 1920,
14
+ "sampling_rate": 64,
15
+ "patch_size": 16,
16
+ "conv_embed_dim": 256,
17
+ "num_temporal_layers": 1,
18
+ "activation": "swiglu",
19
+ "norm_type": "rmsnorm",
20
+ "mlp_bias": false,
21
+ "share_channel_rope": true,
22
+ "transformer_layers": 3,
23
+ "transformer_width": 768,
24
+ "transformer_heads": 12,
25
+ "mlp_ratio": 3.0,
26
+ "pool_type": "attn",
27
+ "dropout": 0.1,
28
+ "decoder_tokens": 32
29
+ },
30
+ "text_cfg": {
31
+ "context_length": 256,
32
+ "vocab_size": 49408,
33
+ "layers": 12,
34
+ "heads": 12,
35
+ "width": 768,
36
+ "embed_cls": true,
37
+ "output_tokens": true
38
+ },
39
+ "custom_text": true,
40
+ "prefix_len": 1,
41
+ "num_caption_channels": 12,
42
+ "decoder_type": "cross_attention"
43
+ }
44
+
src/open_clip/tokenizer.py ADDED
@@ -0,0 +1,621 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ CLIP tokenizer
2
+
3
+ Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ import gzip
6
+ import html
7
+ import os
8
+ import random
9
+ import string
10
+ from functools import lru_cache, partial
11
+ from typing import Callable, List, Optional, Union, Dict
12
+ import warnings
13
+
14
+ import ftfy
15
+ import numpy as np
16
+ import regex as re
17
+ import torch
18
+
19
+ # https://stackoverflow.com/q/62691279
20
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
21
+ _nltk_init = False
22
+
23
+ DEFAULT_CONTEXT_LENGTH = 77 # default context length for OpenAI CLIP
24
+
25
+
26
@lru_cache()
def default_bpe():
    """Absolute path to the BPE vocab gzip bundled alongside this module."""
    here = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(here, "bpe_simple_vocab_16e6.txt.gz")
29
+
30
+
31
@lru_cache()
def bytes_to_unicode():
    """
    Returns a dict mapping each utf-8 byte value (0..255) to a unicode character.

    The reversible bpe codes work on unicode strings, so every byte needs a
    printable character: printable latin bytes keep their own character, and
    the remaining (whitespace/control) bytes are shifted up into the 256+
    codepoint range so the bpe code never sees characters it barfs on.
    """
    # Bytes whose own character is safe to use directly.
    base = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    codes = list(base)
    shift = 0
    for byte in range(2 ** 8):
        if byte not in base:
            # append in ascending byte order after the printable block,
            # preserving the original vocab ordering downstream relies on
            base.append(byte)
            codes.append(2 ** 8 + shift)
            shift += 1
    return {b: chr(c) for b, c in zip(base, codes)}
52
+
53
+
54
def get_pairs(word):
    """Return the set of adjacent symbol pairs in `word`.

    `word` is a tuple of symbols (symbols being variable-length strings).
    """
    return set(zip(word, word[1:]))
64
+
65
+
66
def basic_clean(text):
    """Fix mojibake with ftfy and unescape HTML entities, then trim."""
    fixed = ftfy.fix_text(text)
    # unescape twice to handle double-encoded entities such as '&amp;amp;'
    return html.unescape(html.unescape(fixed)).strip()
70
+
71
+
72
def whitespace_clean(text):
    """Collapse every whitespace run to a single space and trim the ends."""
    # str.split() with no args drops leading/trailing whitespace already,
    # so the join result needs no further stripping
    return " ".join(text.split())
76
+
77
+
78
def _clean_canonicalize(x):
    """Full canonicalization: ftfy fix, punctuation removal, lowercase, whitespace collapse."""
    # basic, remove whitespace, remove punctuation, lower case
    return canonicalize_text(basic_clean(x))
81
+
82
+
83
def _clean_lower(x):
    """Lowercasing clean: ftfy fix, whitespace collapse, lowercase (punctuation kept)."""
    # basic, remove whitespace, lower case
    return whitespace_clean(basic_clean(x)).lower()
86
+
87
+
88
def _clean_whitespace(x):
    """Minimal clean: ftfy fix and whitespace collapse only (case and punctuation kept)."""
    # basic, remove whitespace
    return whitespace_clean(basic_clean(x))
91
+
92
+
93
def get_clean_fn(type: str):
    """Look up a text-cleaning function by name.

    Args:
        type: one of 'canonicalize', 'lower', 'whitespace'.

    Returns:
        The matching cleaning callable (str -> str).

    Raises:
        ValueError: for an unrecognized cleaning type. (The original used
            `assert False`, which is silently skipped under `python -O` and
            would have returned None.)
    """
    if type == 'canonicalize':
        return _clean_canonicalize
    elif type == 'lower':
        return _clean_lower
    elif type == 'whitespace':
        return _clean_whitespace
    raise ValueError(f"Invalid clean function ({type}).")
102
+
103
+
104
def canonicalize_text(
        text,
        *,
        keep_punctuation_exact_string=None,
        trans_punctuation: dict = str.maketrans("", "", string.punctuation),
):
    """Returns canonicalized `text` (lowercase, punctuation removed, whitespace collapsed).

    From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94

    Args:
        text: string to be canonicalized.
        keep_punctuation_exact_string: If provided, exact occurrences of this
            string survive (e.g. '{}' is kept even though '{' and '}' appearing
            separately are removed).
    """
    # underscores become word separators before punctuation stripping
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        pieces = text.split(keep_punctuation_exact_string)
        stripped = [piece.translate(trans_punctuation) for piece in pieces]
        text = keep_punctuation_exact_string.join(stripped)
    else:
        text = text.translate(trans_punctuation)
    return " ".join(text.lower().split())
131
+
132
+
133
class SimpleTokenizer(object):
    """Byte-pair-encoding tokenizer used by OpenAI CLIP models.

    Cleans text, splits it with a regex pre-tokenizer, byte-encodes each piece,
    applies BPE merges, and maps the resulting symbols to integer ids with
    <start_of_text>/<end_of_text> markers. An optional reduction strategy
    handles captions longer than the context length.
    """

    def __init__(
            self,
            bpe_path: str = default_bpe(),
            additional_special_tokens: Optional[List[str]] = None,
            context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
            clean: str = 'lower',
            reduction_mask: str = ''
    ):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        # first line is a version header; entries beyond the CLIP vocab size are unused
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]  # end-of-word variants of every byte symbol
        for merge in merges:
            vocab.append(''.join(merge))
        special_tokens = ['<start_of_text>', '<end_of_text>']
        if additional_special_tokens:
            special_tokens += additional_special_tokens
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))  # token string -> id
        self.decoder = {v: k for k, v in self.encoder.items()}  # id -> token string
        self.bpe_ranks = dict(zip(merges, range(len(merges))))  # merge pair -> priority (lower merges first)
        self.cache = {t:t for t in special_tokens}  # BPE memo; special tokens map to themselves
        special = "|".join(special_tokens)
        self.pat = re.compile(
            special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )
        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]
        self.sot_token_id = self.all_special_ids[0]
        self.eot_token_id = self.all_special_ids[1]
        self.context_length = context_length
        self.clean_fn = get_clean_fn(clean)
        self.reduction_fn = get_reduction_mask_fn(reduction_mask) if reduction_mask else None

    def bpe(self, token):
        """Apply BPE merges to one pre-tokenized word; returns space-joined BPE symbols."""
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            # single-character token: just mark end-of-word
            return token+'</w>'

        while True:
            # repeatedly merge the highest-priority (lowest-rank) adjacent pair
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except Exception:
                    # `first` does not occur again; copy the remainder verbatim
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean `text` and encode it to a list of BPE token ids (no SOT/EOT markers)."""
        bpe_tokens = []
        text = self.clean_fn(text)
        for token in re.findall(self.pat, text):
            # map utf-8 bytes to the reversible unicode alphabet before BPE
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Decode a sequence of token ids back into a text string."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text

    def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.LongTensor:
        """ Returns the tokenized representation of given input string(s)

        Parameters
        ----------
        texts : Union[str, List[str]]
            An input string or a list of input strings to tokenize
        context_length : int
            The context length to use; all CLIP models use 77 as the context length

        Returns
        -------
        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
        """
        if isinstance(texts, str):
            texts = [texts]

        context_length = context_length or self.context_length
        assert context_length, 'Please set a valid context length'

        if self.reduction_fn is not None:
            # use reduction strategy for tokenize if set, otherwise default to truncation below
            return self.reduction_fn(
                texts,
                context_length=context_length,
                sot_token_id=self.sot_token_id,
                eot_token_id=self.eot_token_id,
                encode_fn=self.encode,
            )

        all_tokens = [[self.sot_token_id] + self.encode(text) + [self.eot_token_id] for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                tokens = tokens[:context_length]  # Truncate
                tokens[-1] = self.eot_token_id  # guarantee the sequence still ends with EOT
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result
266
+
267
+
268
# Module-level default tokenizer shared by the tokenize()/decode() convenience helpers.
_tokenizer = SimpleTokenizer()
269
+
270
+
271
def decode(output_ids: torch.Tensor):
    """Decode a 1D tensor of token ids back to text with the module default tokenizer."""
    output_ids = output_ids.cpu().numpy()
    return _tokenizer.decode(output_ids)
274
+
275
+
276
def tokenize(texts: Union[str, List[str]], context_length: int = DEFAULT_CONTEXT_LENGTH) -> torch.LongTensor:
    """Tokenize text(s) with the module default SimpleTokenizer (OpenAI CLIP semantics)."""
    return _tokenizer(texts, context_length=context_length)
278
+
279
+
280
def random_mask_tokenize(
        texts: Union[str, List[str]],
        context_length: int,
        sot_token_id: int,
        eot_token_id: int,
        encode_fn: Callable,
        shuffle: bool = False,
):
    """Tokenize with random token dropping for captions that exceed the context.

    When a caption has more than `context_length - 2` tokens, a random subset
    of that size is kept (original order preserved unless `shuffle=True`).
    The result is wrapped with SOT/EOT ids and zero-padded.
    """
    encoded = [encode_fn(text) for text in texts]
    out = torch.zeros(len(encoded), context_length, dtype=torch.long)

    for row, ids in enumerate(encoded):
        ids = torch.tensor(ids)
        count = len(ids)
        budget = context_length - 2  # reserve two slots for SOT and EOT
        if count > budget:
            keep = torch.randperm(count)[:budget]
            if not shuffle:
                keep = keep.msort()  # restore original token order
            ids = ids[keep]
            count = budget
        out[row, 0] = sot_token_id
        out[row, 1:count + 1] = ids
        out[row, count + 1] = eot_token_id

    return out
307
+
308
+
309
def simple_mask_tokenize(
        texts: Union[str, List[str]],
        context_length: int,
        sot_token_id: int,
        eot_token_id: int,
        encode_fn: Callable,
):
    """Tokenize, truncating over-long captions to a random contiguous window.

    Captions longer than `context_length - 2` tokens keep a randomly-placed
    contiguous slice of that size; SOT/EOT are added and rows are zero-padded.
    """
    encoded = [encode_fn(text) for text in texts]
    out = torch.zeros(len(encoded), context_length, dtype=torch.long)

    for row, ids in enumerate(encoded):
        budget = context_length - 2  # reserve two slots for SOT and EOT
        if len(ids) > budget:
            start = random.randint(0, len(ids) - budget)  # randint high is inclusive
            ids = ids[start: start + budget]
        ids = [sot_token_id] + ids + [eot_token_id]
        out[row, :len(ids)] = torch.tensor(ids)

    return out
329
+
330
+
331
def syntax_mask_tokenize(
        texts: Union[str, List[str]],
        context_length: int,
        sot_token_id: int,
        eot_token_id: int,
        encode_fn: Callable,
) -> torch.LongTensor:
    """ Returns the tokenized representation of given input string(s).
    Apply syntax masking before tokenize.

    Words are POS-tagged with NLTK and kept by priority
    (nouns > adjectives > verbs > everything else) until at most
    `context_length - 2` words remain; the reduced text is then tokenized.
    """
    import nltk  # local import: nltk is only needed when this strategy is selected
    global _nltk_init
    if not _nltk_init:
        # run them for the first time
        nltk.download('punkt')
        nltk.download('averaged_perceptron_tagger')
        _nltk_init = True

    def get_order(x):
        # lower number == higher keep-priority for the word's POS tag
        if x.startswith('NN'):
            return 1
        elif x.startswith('JJ'):
            return 2
        elif x.startswith('VB'):
            return 3
        else:
            return 4

    # syntax masking
    new_texts = []
    for text in texts:
        list_tokens = nltk.tokenize.word_tokenize(text)
        pos_tags = nltk.pos_tag(list_tokens)
        # sample the words by get_order method
        order_list = [get_order(tag) for _, tag in pos_tags]
        sorted_ids = np.argsort(np.array(order_list))
        sampled_ids = sorted(sorted_ids[:context_length - 2])  # need 2 slots for sot and eot tokens
        sampled_tokens = np.take(np.array(list_tokens), sampled_ids, axis=0)  # sample the tokens

        new_text = ''
        for token in sampled_tokens:
            new_text = new_text + str(token) + ' '
        new_text = new_text.strip()
        new_texts.append(new_text)
    texts = new_texts

    all_tokens = [[sot_token_id] + encode_fn(text) + [eot_token_id] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        # still need first truncate because some words produces two tokens
        if len(tokens) > context_length:
            tokens = tokens[:context_length]  # Truncate
            tokens[-1] = eot_token_id
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
388
+
389
+
390
def get_reduction_mask_fn(type: str):
    """Choose strategy for dropping (masking) tokens to achieve target context length.

    Args:
        type: one of 'simple', 'random', 'shuffle', 'syntax'.

    Returns:
        A tokenize function implementing the selected reduction behavior.

    Raises:
        ValueError: for an unknown strategy name. (The original used an
            `assert` — a no-op under `python -O` — plus an unreachable
            `else: assert False` duplicate.)
    """
    if type == 'simple':
        return simple_mask_tokenize  # randomly select block [start:end]
    elif type == 'random':
        return random_mask_tokenize  # randomly drop tokens (keep order)
    elif type == 'shuffle':
        return partial(random_mask_tokenize, shuffle=True)  # randomly drop tokens (shuffle order)
    elif type == 'syntax':
        return syntax_mask_tokenize  # randomly drop prioritized by syntax
    raise ValueError(f'Unknown reduction mask type {type}.')
403
+
404
+
405
class HFTokenizer:
    """HuggingFace tokenizer wrapper with support for custom tokenization modes.

    Wraps an `AutoTokenizer` and applies the same cleaning pipeline as the
    native tokenizer. `tokenizer_mode='clips'` enables a custom post-processing
    scheme that adds BOS/EOS manually and places a CLS token at the end of
    each padded row.
    """

    def __init__(
            self,
            tokenizer_name: str,
            context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
            clean: str = 'whitespace',
            strip_sep_token: bool = False,
            language: Optional[str] = None,
            cache_dir: Optional[str] = None,
            tokenizer_mode: Optional[str] = None,  # None, 'clips'
            **kwargs
    ):
        self.tokenizer_mode = tokenizer_mode or ''
        self.context_length = context_length
        self.clean_fn = get_clean_fn(clean)
        self.strip_sep_token = strip_sep_token

        # NOTE: Left as example of loading custom tokenizer from file for experimentation
        # if self.tokenizer_mode == 'bert_clips':
        #     self.special_tokens = {
        #         "bos_token": 1,
        #         "eos_token": 2,
        #         "cls_token": 101,
        #         "pad_token": 0
        #     }
        #
        #     # For BERT CLIPS mode with vocab file
        #     from tokenizers import BertWordPieceTokenizer
        #     if tokenizer_name.startswith('hf-hub:'):
        #         from huggingface_hub import hf_hub_download
        #         # Format: hf-hub:repo_id/filename
        #         repo_url = tokenizer_name[7:]
        #         parts = repo_url.split('/')
        #         filename = parts[-1]
        #         repo_id = '/'.join(parts[:-1])
        #         vocab_file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)
        #         self.tokenizer = BertWordPieceTokenizer(lowercase=True)
        #         self.tokenizer = self.tokenizer.from_file(vocab_file)
        #     else:
        #         # Assume tokenizer_name is a local path to a vocab file
        #         self.tokenizer = BertWordPieceTokenizer(lowercase=True)
        #         self.tokenizer = self.tokenizer.from_file(tokenizer_name)

        # Standard HuggingFace tokenizer initialization
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            cache_dir=cache_dir,
            **kwargs
        )

        # Set language function if available (e.g. multilingual NLLB/mBART tokenizers)
        set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None)
        if callable(set_lang_fn):
            self.set_lang_fn = set_lang_fn
        if language is not None:
            self.set_language(language)

    def save_pretrained(self, dest):
        """Persist the wrapped HF tokenizer files to `dest`."""
        self.tokenizer.save_pretrained(dest)

    def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:
        """Clean and tokenize `texts` to an id tensor of shape [batch, context_length]."""
        # same cleaning as for default tokenizer, except lowercasing
        # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
        if isinstance(texts, str):
            texts = [texts]

        context_length = context_length or self.context_length
        assert context_length, 'Please set a valid context length in class init or call.'

        texts = [self.clean_fn(text) for text in texts]

        # Handle different tokenization modes
        if self.tokenizer_mode == 'clips':
            return self._clips_tokenize(texts, context_length)
        else:
            # Standard tokenization
            input_ids = self.tokenizer.batch_encode_plus(
                texts,
                return_tensors='pt',
                max_length=context_length,
                padding='max_length',
                truncation=True,
            ).input_ids

            if self.strip_sep_token:
                # replace SEP ids with 0 (pad) in place
                input_ids = torch.where(
                    input_ids == self.tokenizer.sep_token_id,
                    torch.zeros_like(input_ids),
                    input_ids,
                )

            return input_ids

    def set_language(self, src_lang):
        """Set the tokenizer source language when the underlying tokenizer supports it."""
        if hasattr(self, 'set_lang_fn'):
            self.set_lang_fn(src_lang)
        else:
            warnings.warn('Cannot set language for the tokenizer.')

    def _clips_tokenize(self, texts: List[str], context_length: int) -> torch.Tensor:
        """Use standard HF tokenizer but apply custom post-processing"""
        # Use standard tokenizer without special tokens - we'll add our own
        encoded_outputs = self.tokenizer.batch_encode_plus(
            texts,
            add_special_tokens=False,
            padding=False,
            truncation=False,
            return_tensors=None
        )

        encoded = []
        for tokens in encoded_outputs["input_ids"]:
            tokens = tokens[:context_length - 3]  # Leave room for special tokens
            tokens = [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
            encoded.append(tokens)

        # Create result tensor and handle padding + class token
        result = torch.zeros(len(encoded), context_length, dtype=torch.long)
        for i, tokens in enumerate(encoded):
            padded_tokens = self._pad_and_add_class_token(
                tokens,
                max_length=context_length,
                pad_token_id=self.tokenizer.pad_token_id,
                cls_token_id=self.tokenizer.cls_token_id,
            )
            result[i, :len(padded_tokens)] = torch.tensor(padded_tokens)

        return result

    def _pad_and_add_class_token(
            self,
            tokens: List[int],
            max_length: int,
            pad_token_id: int = 0,
            cls_token_id: int = 101,
    ) -> List[int]:
        """ Add padding with class token at the end """
        if len(tokens) > max_length - 1:
            tokens = tokens[:max_length - 1]

        # Add padding to reach max_length-1
        if len(tokens) < max_length - 1:
            tokens = tokens + [pad_token_id] * (max_length - 1 - len(tokens))

        # Add class token at the end
        tokens = tokens + [cls_token_id]
        return tokens
555
+
556
+
557
class SigLipTokenizer:
    """HuggingFace tokenizer wrapper for SigLIP T5 compatible sentencepiece vocabs

    NOTE: this is not needed in normal library use, but is used to import new sentencepiece tokenizers
    into OpenCLIP. Leaving code here in case future models use new tokenizers.
    """
    # Remote sentencepiece vocab files, keyed by short name.
    VOCAB_FILES = {
        # english, vocab_size=32_000
        "c4-en": "http://storage.googleapis.com/t5-data/vocabs/cc_en.32000/sentencepiece.model",
        # used in multilingual models (mT5, PaLI), vocab_size=250_000
        "mc4": "http://storage.googleapis.com/t5-data/vocabs/mc4.250000.100extra/sentencepiece.model",
        # used in SigLIP2 models, vocab_size=256000
        "gemma": "http://storage.googleapis.com/big_vision/gemma_tokenizer.model",
    }

    def __init__(
            self,
            tokenizer_name: str,
            context_length: Optional[int] = 64,
    ):
        if 'gemma' in tokenizer_name:
            from transformers import GemmaTokenizerFast
            tokenizer_cls = partial(
                GemmaTokenizerFast, padding_side='right', add_bos_token=False, add_eos_token=True)
        else:
            from transformers import T5TokenizerFast
            tokenizer_cls = partial(T5TokenizerFast, extra_ids=0)

        if tokenizer_name in self.VOCAB_FILES:
            # FIXME temporary hack?
            import tempfile
            import fsspec
            vocab_file = self.VOCAB_FILES[tokenizer_name]
            with tempfile.NamedTemporaryFile('wb') as dst:
                with fsspec.open(vocab_file, 'rb') as src:
                    dst.write(src.read())
                # NOTE(review): the tokenizer reads dst.name while the temp file is still
                # open; assumes the write above is flushed to disk at this point — confirm
                # (an explicit dst.flush() would be safer)
                self.tokenizer = tokenizer_cls(dst.name, legacy=False)
        else:
            self.tokenizer = tokenizer_cls(tokenizer_name, legacy=False)

        # Override special ids to match SigLIP training conventions.
        self.tokenizer.pad_token_id = 0 if 'gemma' in tokenizer_name else 1
        self.tokenizer.eos_token_id = 1
        self.context_length = context_length

    def save_pretrained(self, dest):
        """Persist the wrapped HF tokenizer files to `dest`."""
        self.tokenizer.save_pretrained(dest)

    def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:
        """Canonicalize and tokenize `texts` to an id tensor [batch, context_length]."""
        # same cleaning as for default tokenizer, except lowercasing
        # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
        if isinstance(texts, str):
            texts = [texts]

        context_length = context_length or self.context_length
        assert context_length, 'Please set a valid context length in class init or call.'

        texts = [canonicalize_text(basic_clean(text)) for text in texts]
        output = self.tokenizer(
            texts,
            return_tensors='pt',
            max_length=context_length,
            padding='max_length',
            truncation=True,
        )
        return output.input_ids
src/open_clip/transformer.py ADDED
@@ -0,0 +1,1823 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ import math
3
+ from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
4
+
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ from torch.utils.checkpoint import checkpoint
9
+
10
+ import warnings
11
+ import numpy as np
12
+
13
+
14
def to_2tuple(x):
    """Return `x` unchanged if it is already a tuple/list, else duplicate it into a pair."""
    return x if isinstance(x, (tuple, list)) else (x, x)
18
+
19
+
20
def feature_take_indices(num_blocks, indices):
    """Normalize possibly-negative block indices and return them with their maximum."""
    normalized = [idx + num_blocks if idx < 0 else idx for idx in indices]
    return normalized, max(normalized)
24
+
25
+
26
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a (grid_size**2[, +1], embed_dim) 2D sin-cos positional embedding."""
    axis = np.arange(grid_size, dtype=np.float32)
    grid = np.stack(np.meshgrid(axis, axis), axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = _get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        # prepend an all-zero row for the class-token position
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
35
+
36
+
37
def _get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Split embed_dim between the H and W axes and concatenate their 1D embeddings."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    emb_h = _get_1d_sincos_pos_embed_from_grid(half, grid[0])
    emb_w = _get_1d_sincos_pos_embed_from_grid(half, grid[1])
    return np.concatenate([emb_h, emb_w], axis=1)
42
+
43
+
44
+ def _get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
45
+ assert embed_dim % 2 == 0
46
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
47
+ omega /= embed_dim / 2.
48
+ omega = 1. / 10000**omega
49
+ pos = pos.reshape(-1)
50
+ out = np.einsum('m,d->md', pos, omega)
51
+ return np.concatenate([np.sin(out), np.cos(out)], axis=1)
52
+
53
+
54
class LayerNormFp32(nn.LayerNorm):
    """LayerNorm that computes in float32 (for fp16/bf16 inputs) and casts
    the result back to the input dtype."""

    def forward(self, x: torch.Tensor):
        out = F.layer_norm(
            x.float(), self.normalized_shape, self.weight, self.bias, self.eps)
        return out.to(x.dtype)
61
+
62
+
63
class LayerNorm(nn.LayerNorm):
    """Standard LayerNorm with an explicit cast of the output back to the input dtype."""

    def forward(self, x: torch.Tensor):
        normed = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.to(x.dtype)
70
+
71
+
72
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""
    # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory

    def forward(self, x: torch.Tensor):
        return torch.sigmoid(1.702 * x) * x
76
+
77
+
78
class LayerScale(nn.Module):
    """Per-channel learnable scaling (gamma), optionally applied in place."""

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(torch.ones(dim) * init_values)

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
86
+
87
+
88
class PatchDropout(nn.Module):
    """
    https://arxiv.org/abs/2212.00794

    During training, randomly drops a fraction `prob` of patch tokens
    (identity in eval mode or when prob == 0). With `exclude_first_token`
    the CLS token is split off first and always kept.
    """

    def __init__(
            self,
            prob: float = 0.5,
            exclude_first_token: bool = True
    ):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob
        self.exclude_first_token = exclude_first_token  # exclude CLS token

    def forward(self, x):
        # identity when not training or dropout is disabled
        if not self.training or self.prob == 0.:
            return x

        if self.exclude_first_token:
            cls_tokens, x = x[:, :1], x[:, 1:]
        else:
            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])

        batch = x.size()[0]
        num_tokens = x.size()[1]

        batch_indices = torch.arange(batch)
        batch_indices = batch_indices[..., None]

        keep_prob = 1 - self.prob
        # keep at least one patch regardless of prob
        num_patches_keep = max(1, int(num_tokens * keep_prob))

        # random score per token; top-k indices select which patches survive
        rand = torch.randn(batch, num_tokens)
        patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices

        x = x[batch_indices, patch_indices_keep]

        if self.exclude_first_token:
            x = torch.cat((cls_tokens, x), dim=1)

        return x
130
+
131
+
132
class Attention(nn.Module):
    """Multi-head self-attention with optional extras.

    Options (mutually compatible unless noted):
      * ``qk_norm`` -- LayerNorm on per-head q/k (cannot combine with ``scaled_cosine``).
      * ``scaled_cosine`` -- cosine-similarity attention with learned, clamped
        per-head logit scale (Swin Transformer V2 style).
      * ``scale_heads`` -- learned per-head output scaling (NormFormer).
      * ``inner_norm`` -- normalization of the merged attention output before the
        final projection (Sub-LN style).

    The packed ``in_proj_weight``/``in_proj_bias`` layout matches
    ``nn.MultiheadAttention`` so checkpoints remain weight-compatible.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            scaled_cosine: bool = False,
            scale_heads: bool = False,
            inner_norm: bool = False,
            logit_scale_max: float = math.log(1. / 0.01),
            norm_layer: Type[nn.Module] = LayerNormFp32,
            attn_drop: float = 0.,
            proj_drop: float = 0.
    ):
        super().__init__()
        assert not (scaled_cosine and qk_norm), "Cannot activate both scaled cosine and QK normalization"
        self.scaled_cosine = scaled_cosine
        self.scale_heads = scale_heads
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.logit_scale_max = logit_scale_max
        # Use fused SDPA kernel when this torch build provides it.
        self.use_fsdpa = hasattr(nn.functional, 'scaled_dot_product_attention')

        # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None

        # QK normalization (with LN) from https://arxiv.org/abs/2106.04560 and related to other QK Norm ideas
        if qk_norm:
            self.ln_q = norm_layer(self.head_dim)
            self.ln_k = norm_layer(self.head_dim)
        else:
            self.ln_q = nn.Identity()
            self.ln_k = nn.Identity()

        # Scaled cosine attention (from Swin Transformer V2, https://arxiv.org/abs/2111.09883)
        if self.scaled_cosine:
            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
        else:
            self.logit_scale = None

        self.attn_drop = nn.Dropout(attn_drop)

        # Per-head attention logit scaling (from NormFormer, https://arxiv.org/abs/2110.09456)
        if self.scale_heads:
            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
        else:
            self.head_scale = None

        # Normalization of attention logits, before final projection.
        # Origin likely Sub-LN in (Foundation Transformers, https://arxiv.org/abs/2210.06423)
        if inner_norm:
            self.ln_inner = norm_layer(dim)
        else:
            self.ln_inner = nn.Identity()

        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        """Self-attend over x of shape (N, L, C); optional additive/boolean attn_mask."""
        N, L, C = x.shape
        # Packed qkv projection, then split into per-head (N, num_heads, L, head_dim).
        q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
        q = q.reshape(N, L, self.num_heads, -1).transpose(1, 2)
        k = k.reshape(N, L, self.num_heads, -1).transpose(1, 2)
        v = v.reshape(N, L, self.num_heads, -1).transpose(1, 2)

        if attn_mask is not None:
            if attn_mask.ndim == 3:
                # this module works with (L, L), or (N, num_heads, L, L) masks
                attn_mask = attn_mask.reshape(N, self.num_heads, L, L)
            if attn_mask.dtype == torch.bool:
                # Convert boolean mask (True == masked) into an additive -inf mask.
                new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
                new_attn_mask.masked_fill_(attn_mask, float("-inf"))
                attn_mask = new_attn_mask
            else:
                attn_mask = attn_mask.to(dtype=q.dtype)

        if self.logit_scale is not None:
            # Scaled-cosine path: cosine similarity logits times clamped learned scale.
            attn = torch.bmm(
                F.normalize(q, dim=-1),
                F.normalize(k, dim=-1).transpose(-1, -2)
            )
            logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
            attn = attn * logit_scale
            if attn_mask is not None:
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = torch.bmm(attn, v)
        else:
            # Standard (optionally QK-normed) dot-product attention.
            q = self.ln_q(q)
            k = self.ln_k(k)
            if self.use_fsdpa:
                x = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask=attn_mask,
                    dropout_p=self.attn_drop.p if self.training else 0.,
                )
            else:
                q = q * self.scale
                attn = torch.bmm(q, k.transpose(-1, -2))
                if attn_mask is not None:
                    attn += attn_mask
                attn = attn.softmax(dim=-1)
                attn = self.attn_drop(attn)
                x = torch.bmm(attn, v)

        # N, num_heads, L, head_dim
        if self.head_scale is not None:
            x = x * self.head_scale
        # Merge heads back to (N, L, C), optional inner norm, then output projection.
        x = x.transpose(1, 2).reshape(N, L, C)
        x = self.ln_inner(x)
        x = self.out_proj(x)
        x = self.out_drop(x)
        return x
253
+
254
+
255
class AttentionalPooler(nn.Module):
    """Pool a token sequence into a fixed number of learned query tokens via cross-attention."""

    def __init__(
            self,
            d_model: int,
            context_dim: int,
            n_head: int = 8,
            n_queries: int = 256,
            norm_layer: Callable = LayerNorm,
    ):
        super().__init__()
        self.query = nn.Parameter(torch.randn(n_queries, d_model))
        self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim, batch_first=True)
        self.ln_q = norm_layer(d_model)
        self.ln_k = norm_layer(context_dim)

    def forward(self, x: torch.Tensor):
        batch_size = x.shape[0]
        keys = self.ln_k(x)
        queries = self.ln_q(self.query).unsqueeze(0).expand(batch_size, -1, -1)
        attended, _ = self.attn(queries, keys, keys, need_weights=False)
        return attended
276
+
277
+
278
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: x + attn(ln(x)); x + mlp(ln(x)).

    Uses ``nn.MultiheadAttention``. With ``is_cross_attention`` an extra norm
    (``ln_1_kv``) is created for the key/value stream; ``forward`` then accepts
    separate ``k_x``/``v_x`` inputs. LayerScale is applied to both residual
    branches when ``ls_init_value`` is given.
    """

    def __init__(
            self,
            d_model: int,
            n_head: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            is_cross_attention: bool = False,
            batch_first: bool = True,
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=batch_first)
        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
        if is_cross_attention:
            # Only created for cross-attention; forward() probes via hasattr.
            self.ln_1_kv = norm_layer(d_model)

        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, mlp_width)),
            ("gelu", act_layer()),
            ("c_proj", nn.Linear(mlp_width, d_model))
        ]))
        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

    def get_weight_dtype(self) -> torch.dtype:
        """Dtype of block weights; honours int8 quantization wrapper if present."""
        if hasattr(self.mlp.c_fc, 'int8_original_dtype'):
            return self.mlp.c_fc.int8_original_dtype
        return self.mlp.c_fc.weight.dtype

    def attention(
            self,
            q_x: torch.Tensor,
            k_x: Optional[torch.Tensor] = None,
            v_x: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ):
        """Run MHA; self-attention when k_x/v_x are None."""
        k_x = k_x if k_x is not None else q_x
        v_x = v_x if v_x is not None else q_x

        attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
        return self.attn(
            q_x, k_x, v_x,
            need_weights=False,
            attn_mask=attn_mask
        )[0]

    def forward(
            self,
            q_x: torch.Tensor,
            k_x: Optional[torch.Tensor] = None,
            v_x: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ):
        # Cross-attention key/value norm only applies when ln_1_kv exists.
        k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
        v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
        x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x
341
+
342
+
343
class CustomResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block built on the custom ``Attention`` module.

    Adds optional NormFormer / Foundation-Transformer style extras on top of the
    standard block: QK norm, scaled-cosine attention, per-head scaling, inner
    attention norm, a post-attention norm (``scale_attn``) and a mid-MLP norm
    (``scale_fc``). Self-attention only (no cross-attention path).
    """

    def __init__(
            self,
            d_model: int,
            n_head: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = LayerNorm,
            qk_norm: bool = False,
            scale_cosine_attn: bool = False,
            scale_heads: bool = False,
            scale_attn_inner: bool = False,
            scale_attn: bool = False,
            scale_fc: bool = False,
            batch_first: bool = True,
    ):
        super().__init__()
        # Custom Attention operates on (N, L, D) only.
        assert batch_first, 'batch_first must be True for CustomResidualAttentionBlock'

        self.ln_1 = norm_layer(d_model)
        self.attn = Attention(
            d_model,
            n_head,
            qk_norm=qk_norm,
            scaled_cosine=scale_cosine_attn,
            scale_heads=scale_heads,
            inner_norm=scale_attn_inner,
            norm_layer=norm_layer,
        )
        self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, mlp_width)),
            ("gelu", act_layer()),
            ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),  # from NormFormer / Foundation Transformers
            ("c_proj", nn.Linear(mlp_width, d_model))
        ]))
        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

    def get_weight_dtype(self) -> torch.dtype:
        """Dtype of block weights; honours int8 quantization wrapper if present."""
        if hasattr(self.mlp.c_fc, 'int8_original_dtype'):
            return self.mlp.c_fc.int8_original_dtype
        return self.mlp.c_fc.weight.dtype

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x
395
+
396
+
397
class CustomTransformer(nn.Module):
    """ A custom transformer that can use different block types.

    Currently only ``'CustomResidualAttentionBlock'`` is supported as a block
    type; blocks all take ``(x, attn_mask)`` in their forward signature.
    """
    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = LayerNorm,
            batch_first: bool = True,
            block_types: Union[str, List[str]] = 'CustomResidualAttentionBlock',
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.batch_first = batch_first  # run transformer stack in batch first (N, L, D)
        self.grad_checkpointing = False

        # A single block-type string is replicated for every layer.
        if isinstance(block_types, str):
            block_types = [block_types] * layers
        assert len(block_types) == layers

        def _create_block(bt: str):
            # One-line purpose: map a block-type name to an instantiated block.
            if bt == 'CustomResidualAttentionBlock':
                return CustomResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio=mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    batch_first=batch_first,
                )
            else:
                assert False

        self.resblocks = nn.ModuleList([
            _create_block(bt)
            for bt in block_types
        ])

    def get_cast_dtype(self) -> torch.dtype:
        """Dtype to cast inputs to, derived from first block's weights."""
        return self.resblocks[0].get_weight_dtype()

    def forward_intermediates(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
    ):
        """Run the stack, also collecting intermediate activations.

        Args:
            x: Input tokens (N, L, D) when batch_first, else (L, N, D).
            attn_mask: Optional attention mask forwarded to each block.
            indices: Which block outputs to collect (int = last n, None = all).
            stop_early: Stop iterating once the last requested block is done.
        Returns:
            Tuple of (final output, list of collected intermediates), both in
            the caller's layout.
        """
        take_indices, max_index = feature_take_indices(len(self.resblocks), indices)

        if not self.batch_first:
            x = x.transpose(0, 1).contiguous()  # NLD -> LND

        intermediates = []
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            blocks = self.resblocks
        else:
            blocks = self.resblocks[:max_index + 1]
        for i, blk in enumerate(blocks):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # FIX: CustomResidualAttentionBlock.forward takes (x, attn_mask);
                # the previous (blk, x, None, None, attn_mask) call passed two
                # extra positional args and raised a TypeError.
                x = checkpoint(blk, x, attn_mask, use_reentrant=False)
            else:
                x = blk(x, attn_mask=attn_mask)

            if i in take_indices:
                intermediates.append(x.transpose(0, 1) if not self.batch_first else x)

        if not self.batch_first:
            x = x.transpose(0, 1)  # LND -> NLD

        return x, intermediates

    def prune_intermediate_layers(self, indices: Union[int, List[int]] = 1):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.resblocks), indices)
        self.resblocks = self.resblocks[:max_index + 1]  # truncate blocks
        return take_indices

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Run the full block stack; layout follows ``self.batch_first``."""
        if not self.batch_first:
            x = x.transpose(0, 1)  # NLD -> LND

        for r in self.resblocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # FIX: pass attn_mask positionally as the block's second arg;
                # checkpoint does not handle kwargs well
                # (https://github.com/pytorch/pytorch/issues/79887).
                x = checkpoint(r, x, attn_mask, use_reentrant=False)
            else:
                x = r(x, attn_mask=attn_mask)

        if not self.batch_first:
            x = x.transpose(0, 1)  # LND -> NLD
        return x
495
+
496
+
497
class Transformer(nn.Module):
    """Stack of residual attention blocks.

    Builds ``ResidualAttentionBlock`` by default, or
    ``CustomResidualAttentionBlock`` when ``block_type == 'custom'`` (also
    auto-selected when any of the custom attention flags are enabled).
    """

    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = LayerNorm,
            batch_first: bool = True,
            block_type: Optional[str] = None,
            qk_norm: bool = False,
            scaled_cosine_attn: bool = False,
            scale_heads: bool = False,
            scale_attn_inner: bool = False,
            scale_attn: bool = False,
            scale_fc: bool = False,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.batch_first = batch_first
        self.grad_checkpointing = False

        # Auto-select custom block if any custom features are enabled
        if block_type is None:
            if any([qk_norm, scaled_cosine_attn, scale_heads, scale_attn_inner, scale_attn, scale_fc]):
                block_type = 'custom'
            else:
                block_type = 'default'

        if block_type == 'custom':
            self.resblocks = nn.ModuleList([
                CustomResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    qk_norm=qk_norm,
                    scale_cosine_attn=scaled_cosine_attn,
                    scale_heads=scale_heads,
                    scale_attn_inner=scale_attn_inner,
                    scale_attn=scale_attn,
                    scale_fc=scale_fc,
                    batch_first=batch_first,
                )
                for _ in range(layers)
            ])
        else:
            self.resblocks = nn.ModuleList([
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    batch_first=batch_first,
                )
                for _ in range(layers)
            ])

    def get_cast_dtype(self) -> torch.dtype:
        """Dtype to cast inputs to, derived from first block's weights."""
        return self.resblocks[0].get_weight_dtype()

    def _checkpoint_block(self, blk, x: torch.Tensor, attn_mask: Optional[torch.Tensor]):
        """Gradient-checkpoint a block with the positional args its forward expects.

        FIX: ResidualAttentionBlock.forward is (q_x, k_x, v_x, attn_mask) while
        CustomResidualAttentionBlock.forward is (x, attn_mask); the previous
        one-size-fits-all call (blk, x, None, None, attn_mask) raised a
        TypeError for custom blocks. Args stay positional because checkpoint
        does not handle kwargs well
        (https://github.com/pytorch/pytorch/issues/79887).
        """
        if isinstance(blk, CustomResidualAttentionBlock):
            return checkpoint(blk, x, attn_mask, use_reentrant=False)
        return checkpoint(blk, x, None, None, attn_mask, use_reentrant=False)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
    ):
        """Run the stack, also collecting intermediate activations.

        Args:
            x: Input tokens (N, L, D) when batch_first, else (L, N, D).
            attn_mask: Optional attention mask forwarded to each block.
            indices: Which block outputs to collect (int = last n, None = all).
            stop_early: Stop iterating once the last requested block is done.
        Returns:
            Tuple of (final output, list of collected intermediates), both in
            the caller's layout.
        """
        take_indices, max_index = feature_take_indices(len(self.resblocks), indices)

        if not self.batch_first:
            x = x.transpose(0, 1).contiguous()  # NLD -> LND

        intermediates = []
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            blocks = self.resblocks
        else:
            blocks = self.resblocks[:max_index + 1]
        for i, blk in enumerate(blocks):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = self._checkpoint_block(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask=attn_mask)

            if i in take_indices:
                intermediates.append(x.transpose(0, 1) if not self.batch_first else x)

        if not self.batch_first:
            x = x.transpose(0, 1)  # LND -> NLD

        return x, intermediates

    def prune_intermediate_layers(self, indices: Union[int, List[int]] = 1):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.resblocks), indices)
        self.resblocks = self.resblocks[:max_index + 1]  # truncate blocks
        return take_indices

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Run the full block stack; layout follows ``self.batch_first``."""
        if not self.batch_first:
            x = x.transpose(0, 1).contiguous()  # NLD -> LND

        for r in self.resblocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = self._checkpoint_block(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)

        if not self.batch_first:
            x = x.transpose(0, 1)  # LND -> NLD
        return x
617
+
618
+
619
+ def _expand_token(token, batch_size: int):
620
+ return token.view(1, 1, -1).expand(batch_size, -1, -1)
621
+
622
+
623
class VisionTransformer(nn.Module):
    """CLIP-style ViT image tower: patchify -> cls+pos embed -> transformer -> pool -> project.

    Pooling is either simple ('tok'/'avg'/'none') or attentional; when
    ``attentional_pool`` is a string ('parallel'/'cascade') the CoCa-style dual
    pooler setup is built.
    """
    output_tokens: torch.jit.Final[bool]

    def __init__(
            self,
            image_size: int,
            patch_size: int,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float,
            ls_init_value: float = None,
            attentional_pool: bool = False,
            attn_pooler_queries: int = 256,
            attn_pooler_heads: int = 8,
            output_dim: int = 512,
            patch_dropout: float = 0.,
            no_ln_pre: bool = False,
            pos_embed_type: str = 'learnable',
            pool_type: str = 'tok',
            final_ln_after_pool: bool = False,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            output_tokens: bool = False,
            block_type: Optional[str] = None,
            qk_norm: bool = False,
            scaled_cosine_attn: bool = False,
            scale_heads: bool = False,
            scale_attn_inner: bool = False,
            scale_attn: bool = False,
            scale_fc: bool = False,
    ):
        super().__init__()
        assert pool_type in ('tok', 'avg', 'none')
        self.output_tokens = output_tokens
        image_height, image_width = self.image_size = to_2tuple(image_size)
        patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
        self.grid_size = (image_height // patch_height, image_width // patch_width)
        self.final_ln_after_pool = final_ln_after_pool  # currently ignored w/ attn pool enabled
        self.output_dim = output_dim

        # Non-overlapping patch embedding (stride == kernel == patch size).
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        # class embeddings and positional embeddings
        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        if pos_embed_type == 'learnable':
            self.positional_embedding = nn.Parameter(
                scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
        elif pos_embed_type == 'sin_cos_2d':
            # fixed sin-cos embedding
            assert self.grid_size[0] == self.grid_size[1],\
                'currently sin cos 2d pos embedding only supports square input'
            self.positional_embedding = nn.Parameter(
                torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False)
            # NOTE: pos_embed_type name is reused here to hold the generated array.
            pos_embed_type = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True)
            self.positional_embedding.data.copy_(torch.from_numpy(pos_embed_type).float())
        else:
            raise ValueError

        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
        self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()

        self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width)
        self.transformer = Transformer(
            width,
            layers,
            heads,
            mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            block_type=block_type,
            qk_norm=qk_norm,
            scaled_cosine_attn=scaled_cosine_attn,
            scale_heads=scale_heads,
            scale_attn_inner=scale_attn_inner,
            scale_attn=scale_attn,
            scale_fc=scale_fc,
        )

        if attentional_pool:
            if isinstance(attentional_pool, str):
                # CoCa-style dual pooler ('parallel' or 'cascade').
                self.attn_pool_type = attentional_pool
                self.pool_type = 'none'
                if attentional_pool in ('parallel', 'cascade'):
                    self.attn_pool = AttentionalPooler(
                        output_dim,
                        width,
                        n_head=attn_pooler_heads,
                        n_queries=attn_pooler_queries,
                    )
                    self.attn_pool_contrastive = AttentionalPooler(
                        output_dim,
                        width,
                        n_head=attn_pooler_heads,
                        n_queries=1,
                    )
                else:
                    assert False
            else:
                # Single attentional pooler (original OpenCLIP CoCa setup).
                self.attn_pool_type = ''
                self.pool_type = pool_type
                self.attn_pool = AttentionalPooler(
                    output_dim,
                    width,
                    n_head=attn_pooler_heads,
                    n_queries=attn_pooler_queries,
                )
                self.attn_pool_contrastive = None
            pool_dim = output_dim
        else:
            self.attn_pool = None
            pool_dim = width
            self.pool_type = pool_type

        self.ln_post = norm_layer(pool_dim)
        self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim))

        self.init_parameters()

    def lock(self, unlocked_groups: int = 0, freeze_bn_stats: bool = False):
        """Freeze all parameters, optionally re-enabling the last ``unlocked_groups`` groups."""
        for param in self.parameters():
            param.requires_grad = False

        if unlocked_groups != 0:
            # Groups ordered input -> output: stem, each block, final block+norm, proj.
            groups = [
                [
                    self.conv1,
                    self.class_embedding,
                    self.positional_embedding,
                    self.ln_pre,
                ],
                *self.transformer.resblocks[:-1],
                [
                    self.transformer.resblocks[-1],
                    self.ln_post,
                ],
                self.proj,
            ]

            def _unlock(x):
                # Recursively re-enable grads on modules/params in nested groups.
                if isinstance(x, Sequence):
                    for g in x:
                        _unlock(g)
                else:
                    if isinstance(x, torch.nn.Parameter):
                        x.requires_grad = True
                    else:
                        for p in x.parameters():
                            p.requires_grad = True

            _unlock(groups[-unlocked_groups:])

    def init_parameters(self):
        # FIXME OpenAI CLIP did not define an init for the VisualTransformer
        # TODO experiment if default PyTorch init, below, or alternate init is best.

        # nn.init.normal_(self.class_embedding, std=self.scale)
        # nn.init.normal_(self.positional_embedding, std=self.scale)
        #
        # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        # attn_std = self.transformer.width ** -0.5
        # fc_std = (2 * self.transformer.width) ** -0.5
        # for block in self.transformer.resblocks:
        #     nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
        #     nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
        #     nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
        #     nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        #
        # if self.text_projection is not None:
        #     nn.init.normal_(self.text_projection, std=self.scale)
        pass

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        self.transformer.grad_checkpointing = enable

    @torch.jit.ignore
    def no_weight_decay(self):
        # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
        no_wd = {'positional_embedding', 'class_embedding'}
        return no_wd

    def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pool per ``self.pool_type``; returns (pooled, spatial tokens)."""
        if self.pool_type == 'avg':
            pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]
        elif self.pool_type == 'tok':
            pooled, tokens = x[:, 0], x[:, 1:]
        else:
            pooled = tokens = x

        return pooled, tokens

    def _embeds(self, x: torch.Tensor) -> torch.Tensor:
        """Patchify an image batch and add class/positional embeddings + pre-norm."""
        x = self.conv1(x)  # shape = [*, dim, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        # class embeddings and positional embeddings
        x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
        # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)

        # patch dropout (if active)
        x = self.patch_dropout(x)

        # apply norm before transformer
        x = self.ln_pre(x)
        return x

    def _pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attentional or global pooling; returns (pooled, tokens)."""
        if self.attn_pool is not None:
            if self.attn_pool_contrastive is not None:
                # This is untested, WIP pooling that should match paper
                x = self.ln_post(x)  # TBD LN first or separate one after each pool?
                tokens = self.attn_pool(x)
                if self.attn_pool_type == 'parallel':
                    pooled = self.attn_pool_contrastive(x)
                else:
                    assert self.attn_pool_type == 'cascade'
                    pooled = self.attn_pool_contrastive(tokens)
            else:
                # this is the original OpenCLIP CoCa setup, does not match paper
                x = self.attn_pool(x)
                x = self.ln_post(x)
                pooled, tokens = self._global_pool(x)
        elif self.final_ln_after_pool:
            pooled, tokens = self._global_pool(x)
            pooled = self.ln_post(pooled)
        else:
            x = self.ln_post(x)
            pooled, tokens = self._global_pool(x)

        return pooled, tokens

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
            normalize_intermediates: bool = False,
            intermediates_only: bool = False,
            output_fmt: str = 'NCHW',
            output_extra_tokens: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            stop_early: Stop iterating over blocks when last desired intermediate hit
            intermediates_only: Only return intermediate features
            normalize_intermediates: Apply final norm layer to all intermediates
            output_fmt: Shape of intermediate feature outputs
            output_extra_tokens: Return both extra prefix class tokens
        Returns:
            Dict with 'image_intermediates' (and optionally
            'image_intermediates_prefix', 'image_features').
        """
        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
        reshape = output_fmt == 'NCHW'

        # forward pass
        B, _, height, width = x.shape
        x = self._embeds(x)
        x, intermediates = self.transformer.forward_intermediates(
            x,
            indices=indices,
            stop_early=stop_early,
        )

        # process intermediates
        if normalize_intermediates:
            # apply final norm to all intermediates
            intermediates = [self.ln_post(xi) for xi in intermediates]
        num_prefix_tokens = 1  # one class token that's always there (as of now)
        if num_prefix_tokens:
            # split prefix (e.g. class, distill) and spatial feature tokens
            prefix_tokens = [y[:, 0:num_prefix_tokens] for y in intermediates]
            intermediates = [y[:, num_prefix_tokens:] for y in intermediates]
        else:
            prefix_tokens = None
        if reshape:
            # reshape to BCHW output format
            H, W = height // self.patch_size[0], width // self.patch_size[1]
            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]

        output = {'image_intermediates': intermediates}
        if prefix_tokens is not None and output_extra_tokens:
            output['image_intermediates_prefix'] = prefix_tokens

        if intermediates_only:
            return output

        pooled, _ = self._pool(x)

        if self.proj is not None:
            pooled = pooled @ self.proj

        output['image_features'] = pooled

        return output

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices = self.transformer.prune_intermediate_layers(indices)
        if prune_norm:
            self.ln_post = nn.Identity()
        if prune_head:
            self.proj = None
        return take_indices

    def forward(self, x: torch.Tensor):
        """Embed, transform, pool and project; returns (pooled, tokens) if output_tokens."""
        x = self._embeds(x)
        x = self.transformer(x)
        pooled, tokens = self._pool(x)

        if self.proj is not None:
            pooled = pooled @ self.proj

        if self.output_tokens:
            return pooled, tokens

        return pooled
959
+
960
+
961
def text_global_pool(
        x: torch.Tensor,
        text: Optional[torch.Tensor] = None,
        pool_type: str = 'argmax',
        eos_token_id: Optional[int] = None,
) -> torch.Tensor:
    """Reduce per-token text features (N, L, D) to one embedding per sequence.

    pool_type: 'first' / 'last' take a fixed position; 'argmax' takes the token
    with the highest id per row of ``text`` (CLIP's EOT convention); 'eos' takes
    the first occurrence of ``eos_token_id``; anything else returns x unchanged.
    """
    if pool_type == 'first':
        return x[:, 0]
    if pool_type == 'last':
        return x[:, -1]
    if pool_type == 'argmax':
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        assert text is not None
        rows = torch.arange(x.shape[0], device=x.device)
        return x[rows, text.argmax(dim=-1)]
    if pool_type == 'eos':
        # take features from tokenizer specific eos
        assert text is not None
        assert eos_token_id is not None
        eos_pos = (text == eos_token_id).int().argmax(dim=-1)
        rows = torch.arange(x.shape[0], device=x.device)
        return x[rows, eos_pos]
    return x
985
+
986
+
987
+ class TextTransformer(nn.Module):
988
+ output_tokens: torch.jit.Final[bool]
989
+
990
+ def __init__(
991
+ self,
992
+ context_length: int = 77,
993
+ vocab_size: int = 49408,
994
+ width: int = 512,
995
+ heads: int = 8,
996
+ layers: int = 12,
997
+ mlp_ratio: float = 4.0,
998
+ ls_init_value: float = None,
999
+ output_dim: Optional[int] = 512,
1000
+ embed_cls: bool = False,
1001
+ no_causal_mask: bool = False,
1002
+ use_pad_mask: bool = False,
1003
+ correct_cls_mask: bool = False,
1004
+ pad_id: int = 0,
1005
+ eos_id: int = 2,
1006
+ pool_type: str = 'argmax',
1007
+ proj_type: str = 'linear',
1008
+ proj_bias: bool = False,
1009
+ act_layer: Type[nn.Module] = nn.GELU,
1010
+ norm_layer: Type[nn.Module] = LayerNorm,
1011
+ output_tokens: bool = False,
1012
+ block_type: Optional[str] = None,
1013
+ qk_norm: bool = False,
1014
+ scaled_cosine_attn: bool = False,
1015
+ scale_heads: bool = False,
1016
+ scale_attn_inner: bool = False,
1017
+ scale_attn: bool = False,
1018
+ scale_fc: bool = False,
1019
+ ):
1020
+ super().__init__()
1021
+ assert pool_type in ('first', 'last', 'argmax', 'eos', 'none')
1022
+ self.output_tokens = output_tokens
1023
+ self.num_pos = self.context_length = context_length
1024
+ self.vocab_size = vocab_size
1025
+ self.width = width
1026
+ self.output_dim = output_dim
1027
+ self.heads = heads
1028
+ self.pad_id = pad_id
1029
+ self.eos_id = eos_id
1030
+ self.pool_type = pool_type
1031
+ self.use_pad_mask = use_pad_mask and no_causal_mask # only use in bi‑dir mode
1032
+ self.correct_cls_mask = correct_cls_mask # use the correct cls mask for CoCa (original is wrong)
1033
+
1034
+ self.token_embedding = nn.Embedding(vocab_size, width)
1035
+ if embed_cls:
1036
+ self.cls_emb = nn.Parameter(torch.empty(width))
1037
+ self.num_pos += 1
1038
+ else:
1039
+ self.cls_emb = None
1040
+ self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
1041
+ self.transformer = Transformer(
1042
+ width=width,
1043
+ layers=layers,
1044
+ heads=heads,
1045
+ mlp_ratio=mlp_ratio,
1046
+ ls_init_value=ls_init_value,
1047
+ act_layer=act_layer,
1048
+ norm_layer=norm_layer,
1049
+ block_type=block_type,
1050
+ qk_norm=qk_norm,
1051
+ scaled_cosine_attn=scaled_cosine_attn,
1052
+ scale_heads=scale_heads,
1053
+ scale_attn_inner=scale_attn_inner,
1054
+ scale_attn=scale_attn,
1055
+ scale_fc=scale_fc,
1056
+ )
1057
+ self.ln_final = norm_layer(width)
1058
+
1059
+ if no_causal_mask:
1060
+ self.attn_mask = None # bi‑directional
1061
+ else:
1062
+ self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False)
1063
+
1064
+ if proj_type == 'none' or not output_dim:
1065
+ self.text_projection = None
1066
+ else:
1067
+ if proj_bias:
1068
+ self.text_projection = nn.Linear(width, output_dim)
1069
+ else:
1070
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
1071
+
1072
+ self.init_parameters()
1073
+
1074
+ def init_parameters(self):
1075
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
1076
+ nn.init.normal_(self.positional_embedding, std=0.01)
1077
+ if self.cls_emb is not None:
1078
+ nn.init.normal_(self.cls_emb, std=0.01)
1079
+
1080
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
1081
+ attn_std = self.transformer.width ** -0.5
1082
+ fc_std = (2 * self.transformer.width) ** -0.5
1083
+ for block in self.transformer.resblocks:
1084
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
1085
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
1086
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
1087
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
1088
+
1089
+ if self.text_projection is not None:
1090
+ if isinstance(self.text_projection, nn.Linear):
1091
+ nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5)
1092
+ if self.text_projection.bias is not None:
1093
+ nn.init.zeros_(self.text_projection.bias)
1094
+ else:
1095
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
1096
+
1097
+ @torch.jit.ignore
1098
+ def set_grad_checkpointing(self, enable=True):
1099
+ self.transformer.grad_checkpointing = enable
1100
+
1101
+ def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
1102
+ """
1103
+ Lock the text transformer layers, optionally leaving some layers unlocked.
1104
+
1105
+ Args:
1106
+ unlocked_layers: Number of layers to leave unlocked (from the end).
1107
+ freeze_layer_norm: LayerNorm freeze (only for API compatibility, not functional)
1108
+ """
1109
+ assert freeze_layer_norm, 'Unfreezing LayerNorm is not supported. LayerNorm treated like other weights.'
1110
+ lock_text_tower(self, unlocked_layers)
1111
+
1112
+ @torch.jit.ignore
1113
+ def no_weight_decay(self):
1114
+ # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
1115
+ no_wd = {'positional_embedding'}
1116
+ if self.cls_emb is not None:
1117
+ no_wd.add('cls_emb')
1118
+ return no_wd
1119
+
1120
+ def build_causal_mask(self):
1121
+ # lazily create causal attention mask, with full attention between the tokens
1122
+ # pytorch uses additive attention mask; fill with -inf
1123
+ mask = torch.empty(self.num_pos, self.num_pos)
1124
+ mask.fill_(float("-inf"))
1125
+ mask.triu_(1) # zero out the lower diagonal
1126
+ return mask
1127
+
1128
+ def _build_additive_mask(
1129
+ self,
1130
+ text: torch.Tensor, # [B, L] – original text ids without CLS yet
1131
+ seq_len: int, # L (+1 if CLS added)
1132
+ dtype: torch.dtype,
1133
+ ) -> torch.Tensor:
1134
+ """
1135
+ Returns an additive (-inf) mask of shape [B*heads, seq_len, seq_len] that
1136
+ simultaneously masks padding tokens and (optionally) the CLS token.
1137
+ """
1138
+ valid = text != self.pad_id # [B, L] (True = keep)
1139
+
1140
+ if self.cls_emb is not None:
1141
+ cls_valid = valid.new_ones(valid.size(0), 1) # [B, 1]
1142
+ # cls mask pos at end if correct or front for incorrect legacy mode in existing CoCa weights
1143
+ valid = torch.cat([valid, cls_valid] if self.correct_cls_mask else [cls_valid, valid], 1)
1144
+
1145
+ # broadcast over query dimension
1146
+ key_mask = valid.unsqueeze(1).expand(-1, seq_len, -1) # [B, Q, K]
1147
+ additive = torch.zeros_like(key_mask, dtype=dtype)
1148
+ additive.masked_fill_(~key_mask, float("-inf"))
1149
+ additive = additive.repeat_interleave(self.heads, 0) # [B*H, Q, K]
1150
+ return additive
1151
+
1152
+ def _embeds(self, text) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
1153
+ cast_dtype = self.transformer.get_cast_dtype()
1154
+ B, seq_len = text.shape
1155
+
1156
+ x = self.token_embedding(text).to(cast_dtype)
1157
+
1158
+ # Optional class token (always appended ala CoCa)
1159
+ if self.cls_emb is not None:
1160
+ x = torch.cat([x, _expand_token(self.cls_emb, x.size(0))], 1)
1161
+ seq_len += 1
1162
+
1163
+ attn_mask = self.attn_mask # Base causal mask (if any)
1164
+
1165
+ # Class + padding additive mask
1166
+ if self.use_pad_mask or self.cls_emb is not None:
1167
+ add_mask = self._build_additive_mask(text, seq_len, x.dtype)
1168
+ if attn_mask is not None:
1169
+ # Slice the causal mask to match current sequence length
1170
+ attn_mask = attn_mask[:seq_len, :seq_len].unsqueeze(0) + add_mask
1171
+ else:
1172
+ attn_mask = add_mask
1173
+
1174
+ x = x + self.positional_embedding[:seq_len].to(cast_dtype)
1175
+ return x, attn_mask
1176
+
1177
+ def forward_intermediates(
1178
+ self,
1179
+ text: torch.Tensor,
1180
+ indices: Optional[Union[int, List[int]]] = None,
1181
+ stop_early: bool = False,
1182
+ normalize_intermediates: bool = False,
1183
+ intermediates_only: bool = False,
1184
+ output_fmt: str = 'NCHW',
1185
+ output_extra_tokens: bool = False,
1186
+ ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
1187
+ """ Forward features that returns intermediates.
1188
+
1189
+ Args:
1190
+ text: Input text ids
1191
+ indices: Take last n blocks if int, all if None, select matching indices if sequence
1192
+ stop_early: Stop iterating over blocks when last desired intermediate hit
1193
+ normalize_intermediates: Apply norm layer to all intermediates
1194
+ intermediates_only: Only return intermediate features
1195
+ output_fmt: Shape of intermediate feature outputs
1196
+ output_extra_tokens: Return both prefix and intermediate tokens
1197
+ Returns:
1198
+
1199
+ """
1200
+ assert output_fmt in ('NLC',), 'Output format must be NLC.'
1201
+ # forward pass
1202
+ x, attn_mask = self._embeds(text)
1203
+ x, intermediates = self.transformer.forward_intermediates(
1204
+ x,
1205
+ attn_mask=attn_mask,
1206
+ indices=indices,
1207
+ stop_early=stop_early,
1208
+ )
1209
+
1210
+ # process intermediates
1211
+ if normalize_intermediates:
1212
+ # apply final norm to all intermediates
1213
+ intermediates = [self.ln_final(xi) for xi in intermediates]
1214
+
1215
+ output = {}
1216
+
1217
+ if self.cls_emb is not None:
1218
+ seq_intermediates = [xi[:, :-1] for xi in intermediates] # separate concat'd class token from sequence
1219
+ if output_extra_tokens:
1220
+ # return suffix class tokens separately
1221
+ cls_intermediates = [xi[:, -1:] for xi in intermediates]
1222
+ output['text_intermediates_suffix'] = cls_intermediates
1223
+ intermediates = seq_intermediates
1224
+ output['text_intermediates'] = intermediates
1225
+
1226
+ if intermediates_only:
1227
+ return output
1228
+
1229
+ if self.cls_emb is not None:
1230
+ # presence of appended cls embed (CoCa) overrides pool_type, always take last token
1231
+ pooled = text_global_pool(x, pool_type='last')
1232
+ pooled = self.ln_final(pooled) # final LN applied after pooling in this case
1233
+ else:
1234
+ x = self.ln_final(x)
1235
+ pooled = text_global_pool(x, text, pool_type=self.pool_type, eos_token_id=getattr(self, "eos_id", None))
1236
+
1237
+ if self.text_projection is not None:
1238
+ if isinstance(self.text_projection, nn.Linear):
1239
+ pooled = self.text_projection(pooled)
1240
+ else:
1241
+ pooled = pooled @ self.text_projection
1242
+
1243
+ output['text_features'] = pooled
1244
+
1245
+ return output
1246
+
1247
+ def prune_intermediate_layers(
1248
+ self,
1249
+ indices: Union[int, List[int]] = 1,
1250
+ prune_norm: bool = False,
1251
+ prune_head: bool = True,
1252
+ ):
1253
+ """ Prune layers not required for specified intermediates.
1254
+ """
1255
+ take_indices = self.transformer.prune_intermediate_layers(indices)
1256
+ if prune_norm:
1257
+ self.ln_final = nn.Identity()
1258
+ if prune_head:
1259
+ self.text_projection = None
1260
+ return take_indices
1261
+
1262
+ def forward(self, text):
1263
+ x, attn_mask = self._embeds(text)
1264
+
1265
+ x = self.transformer(x, attn_mask=attn_mask)
1266
+
1267
+ # x.shape = [batch_size, n_ctx, transformer.width]
1268
+ if self.cls_emb is not None:
1269
+ # presence of appended cls embed (CoCa) overrides pool_type, always take last token
1270
+ pooled = text_global_pool(x, pool_type='last')
1271
+ pooled = self.ln_final(pooled) # final LN applied after pooling in this case
1272
+ tokens = x[:, :-1]
1273
+ else:
1274
+ x = self.ln_final(x)
1275
+ pooled = text_global_pool(x, text, pool_type=self.pool_type, eos_token_id=getattr(self, "eos_id", None))
1276
+ tokens = x
1277
+
1278
+ if self.text_projection is not None:
1279
+ if isinstance(self.text_projection, nn.Linear):
1280
+ pooled = self.text_projection(pooled)
1281
+ else:
1282
+ pooled = pooled @ self.text_projection
1283
+
1284
+ if self.output_tokens:
1285
+ return pooled, tokens
1286
+
1287
+ return pooled
1288
+
1289
+
1290
class MultimodalTransformer(Transformer):
    """Cross-attention based multimodal decoder.

    Text and image/biosignal embeddings are kept as separate sequences.
    Every layer applies:
        1. causal self-attention over the text tokens
        2. cross-attention from text queries to image/biosignal keys/values

    When ``prefix_len`` > 0, optional condition tokens may be prepended to the
    text. Those tokens attend bidirectionally among themselves while text
    remains causal (prefix-causal masking).
    """
    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            context_length: int = 77,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = LayerNorm,
            output_dim: int = 512,
            batch_first: bool = True,
            prefix_len: int = 0,
    ):
        super().__init__(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            batch_first=batch_first,
        )
        self.context_length = context_length
        self.cross_attn = nn.ModuleList([
            ResidualAttentionBlock(
                width,
                heads,
                mlp_ratio,
                ls_init_value=ls_init_value,
                act_layer=act_layer,
                norm_layer=norm_layer,
                is_cross_attention=True,
                batch_first=batch_first,
            )
            for _ in range(layers)
        ])

        # Register the attention mask that matches the prefix configuration.
        self.prefix_len = prefix_len
        if prefix_len > 0:
            # Pre-built prefix-causal mask for condition tokens + max-length text.
            self.register_buffer(
                'prefix_causal_mask',
                self.build_prefix_causal_mask(prefix_len, context_length),
                persistent=False,
            )
        else:
            # Plain causal mask when no condition prefix is configured.
            self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)

        self.ln_final = norm_layer(width)
        self.text_projection = nn.Parameter(torch.empty(width, output_dim))

        self.init_parameters()

    def init_parameters(self):
        """CLIP-style scaled-normal init for self/cross blocks and projection."""
        proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)
        attn_std = self.width ** -0.5
        fc_std = (2 * self.width) ** -0.5
        for block in [*self.resblocks, *self.cross_attn]:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.width ** -0.5)

    def build_attention_mask(self):
        """Additive causal mask over the full context length.

        PyTorch uses additive attention masks: 0 = attend, -inf = blocked.
        """
        mask = torch.full((self.context_length, self.context_length), float("-inf"))
        return mask.triu_(1)  # zero out diagonal and below

    def build_prefix_causal_mask(self, prefix_len: int, text_len: int):
        """Additive prefix-causal mask; 0 = attend, large negative = block.

        Built in fp32; a large finite negative is used instead of -inf, which
        is safer with fp16/bf16 fused attention kernels. Attention pattern:
            - prefix <-> prefix: full bidirectional
            - text   -> prefix: allowed
            - text   -> text:   causal (previous tokens only)
            - prefix -> text:   blocked
        """
        total_len = prefix_len + text_len
        # fp32 on CPU; moved later with .to(device) without changing dtype.
        mask = torch.zeros(total_len, total_len, dtype=torch.float32)
        neg = -torch.finfo(mask.dtype).max

        # Prefix rows must not look at text columns.
        mask[:prefix_len, prefix_len:] = neg

        # Causal structure within the text block (block the future).
        future = torch.triu(torch.ones(text_len, text_len, dtype=torch.bool), diagonal=1)
        mask[prefix_len:, prefix_len:].masked_fill_(future, neg)
        return mask

    def forward_intermediates(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
    ):
        assert False, "Not currently implemented for MultimodalTransformer w/ xattn"

    def forward(self, image_embs, text_embs, condition_embs=None):
        """Forward pass with cross-attention between text and image.

        Args:
            image_embs: (batch_size, num_image_tokens, width)
            text_embs: (batch_size, num_text_tokens, width)
            condition_embs: Optional (batch_size, num_condition_tokens, width).
                Conditioning tokens prepended to the text; they get full
                bidirectional attention among themselves, then cross-attend
                to the image embeddings like the rest of the sequence.

        Returns:
            Text decoder outputs (batch_size, num_text_tokens, output_dim);
            condition-token outputs are excluded.
        """
        text_len = text_embs.shape[1]
        assert text_len <= self.context_length, "original_text_len must be less than or equal to context_length"

        # Prepend condition tokens to text if provided.
        cond_len = 0
        if condition_embs is not None:
            cond_len = condition_embs.shape[1]
            # condition_len must not exceed the pre-configured prefix_len.
            assert cond_len <= self.prefix_len, \
                f"condition_len ({cond_len}) exceeds prefix_len ({self.prefix_len})"
            text_embs = torch.cat([condition_embs, text_embs], dim=1)

        # Select the attention mask matching the prefix configuration.
        if self.prefix_len > 0:
            # The cached mask covers (prefix_len + context_length); when fewer
            # condition tokens are supplied, skip the leading rows/cols so the
            # prefix/causal boundary lines up with the actual sequence.
            start = self.prefix_len - cond_len
            stop = start + cond_len + text_len
            attn_mask = self.prefix_causal_mask[start:stop, start:stop].to(device=text_embs.device)
        else:
            attn_mask = self.attn_mask[:text_len, :text_len].to(device=text_embs.device)

        if not self.batch_first:
            image_embs = image_embs.permute(1, 0, 2)  # NLD -> LND
            text_embs = text_embs.permute(1, 0, 2)  # NLD -> LND

        use_ckpt = self.grad_checkpointing and not torch.jit.is_scripting()
        for self_block, cross_block in zip(self.resblocks, self.cross_attn):
            if use_ckpt:
                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                text_embs = checkpoint(
                    self_block, text_embs, None, None, attn_mask, use_reentrant=False)
                text_embs = checkpoint(
                    cross_block, text_embs, image_embs, image_embs, None, use_reentrant=False)
            else:
                text_embs = self_block(text_embs, attn_mask=attn_mask)
                text_embs = cross_block(text_embs, k_x=image_embs, v_x=image_embs)

        if not self.batch_first:
            text_embs = text_embs.permute(1, 0, 2)  # LND -> NLD

        out = self.ln_final(text_embs)
        if self.text_projection is not None:
            out = out @ self.text_projection

        # Drop condition-token outputs; only text positions are returned.
        if cond_len > 0:
            out = out[:, cond_len:, :]
        return out

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable
+
1517
+
1518
class ConcatMultimodalTransformer(Transformer):
    """Concatenation-based multimodal decoder.

    Builds a single sequence [condition_tokens (optional), image/biosignal
    tokens, text_tokens] and runs unified self-attention with a prefix-causal
    mask:

        - condition tokens attend to all condition + image tokens (bidirectional)
        - image/biosignal tokens attend to all condition + image tokens (bidirectional)
        - text tokens attend to the whole prefix (full attention)
        - text tokens attend to previous text tokens only (causal)

    This keeps full bidirectional conditioning on the prefix (condition +
    image) while preserving the causal property text generation needs.
    """
    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            context_length: int = 77,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = LayerNorm,
            output_dim: int = 512,
            batch_first: bool = True,
            prefix_len: int = 0,
    ):
        super().__init__(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            batch_first=batch_first,
        )
        self.context_length = context_length
        self.condition_prefix_len = prefix_len  # number of condition tokens (0, 1, or N)

        # The mask buffer is filled lazily on the first forward pass, once the
        # image token count (and therefore the true prefix length) is known.
        self.register_buffer('_cached_attn_mask', torch.empty(0), persistent=False)
        self._cached_prefix_len = None  # prefix length the cache was built for

        # Unified self-attention only; no cross-attention layers needed.
        self.ln_final = norm_layer(width)
        self.text_projection = nn.Parameter(torch.empty(width, output_dim))

        # NOTE(review): init_parameters() is deliberately not called in the
        # original code, which leaves text_projection as uninitialized
        # torch.empty storage unless weights are loaded — confirm intended.
        # self.init_parameters()

    def init_parameters(self):
        """CLIP-style scaled-normal init for blocks and the text projection."""
        proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)
        attn_std = self.width ** -0.5
        fc_std = (2 * self.width) ** -0.5
        for block in self.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.width ** -0.5)

    def build_prefix_causal_mask(self, prefix_len: int, text_len: int):
        """Additive prefix-causal mask; 0 = attend, large negative = block.

        Built in fp32 for numerical stability; a large finite negative is used
        instead of -inf, which is safer with fp16/bf16 fused kernels. Pattern:
            - prefix <-> prefix: full bidirectional
            - text   -> prefix: allowed
            - text   -> text:   causal (previous tokens only)
            - prefix -> text:   blocked
        """
        total_len = prefix_len + text_len
        # fp32 on CPU; move to device later with .to(device=...) w/o dtype cast.
        mask = torch.zeros(total_len, total_len, dtype=torch.float32)
        neg = -torch.finfo(mask.dtype).max

        # Prefix rows must not look at text columns.
        mask[:prefix_len, prefix_len:] = neg

        # Causal structure within the text block (block the future);
        # masked_fill_, never multiply by -inf (0 * -inf -> NaN).
        future = torch.triu(torch.ones(text_len, text_len, dtype=torch.bool), diagonal=1)
        mask[prefix_len:, prefix_len:].masked_fill_(future, neg)
        return mask

    def forward_intermediates(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
    ):
        assert False, "Not currently implemented for ConcatMultimodalTransformer"

    def forward(self, image_embs, text_embs, condition_embs=None):
        """Forward pass over the concatenated prefix + text sequence.

        Args:
            image_embs: (batch_size, num_image_tokens, width)
            text_embs: (batch_size, num_text_tokens, width)
            condition_embs: Optional (batch_size, num_condition_tokens, width),
                prepended before the image tokens and treated as part of the
                bidirectional prefix.

        Returns:
            Text decoder outputs (batch_size, num_text_tokens, output_dim);
            prefix outputs are discarded.
        """
        text_len = text_embs.shape[1]
        assert text_len <= self.context_length, \
            f"text_len ({text_len}) must be <= context_length ({self.context_length})"

        # Assemble the prefix: [condition_tokens (optional), image_tokens].
        if condition_embs is None:
            prefix = image_embs
        else:
            cond_len = condition_embs.shape[1]
            # condition_len must not exceed the pre-configured prefix length.
            assert cond_len <= self.condition_prefix_len, \
                f"condition_len ({cond_len}) exceeds condition_prefix_len ({self.condition_prefix_len})"
            prefix = torch.cat([condition_embs, image_embs], dim=1)

        prefix_len = prefix.shape[1]  # total prefix length (condition + image)
        x = torch.cat([prefix, text_embs], dim=1)  # (batch, prefix_len + text_len, width)

        if not self.batch_first:
            x = x.permute(1, 0, 2)  # NLD -> LND

        # (Re)build the cached mask whenever the prefix length changes; it is
        # sized for the maximum text length, so slicing handles shorter text.
        if self._cached_attn_mask.numel() == 0 or self._cached_prefix_len != prefix_len:
            self._cached_attn_mask = self.build_prefix_causal_mask(prefix_len, self.context_length)
            self._cached_prefix_len = prefix_len

        seq_len = prefix_len + text_len
        attn_mask = self._cached_attn_mask[:seq_len, :seq_len].to(device=x.device)

        # Unified self-attention over the whole sequence.
        for block in self.resblocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(block, x, None, None, attn_mask, use_reentrant=False)
            else:
                x = block(x, attn_mask=attn_mask)

        if not self.batch_first:
            x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_final(x)

        # Keep only the text portion, then project to the output dimension.
        text_out = x[:, prefix_len:, :]
        if self.text_projection is not None:
            text_out = text_out @ self.text_projection
        return text_out

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable
+
1727
+
1728
def lock_text_tower(
        model: nn.Module,
        unlocked_layers: int = 0,
):
    """
    Freeze text-tower parameters, optionally leaving trailing groups trainable.

    Works with both model architectures:
      - CustomTextCLIP, where text components live under ``model.text``
      - standard CLIP / bare TextTransformer, where they are direct attributes

    Unlock groups are ordered: embeddings, each transformer block, the last
    block together with the final LayerNorm, and finally the text projection.
    ``unlocked_layers`` counts groups from the end of that list.

    Args:
        model: The CLIP model or TextTransformer module.
        unlocked_layers: Number of layer groups to leave unlocked (from the end).
    """
    # Locate the module that actually holds the text components.
    text_module = model.text if hasattr(model, 'text') else model

    # Collect whichever components exist on this architecture.
    candidates = {
        'token_embedding': getattr(text_module, 'token_embedding', None),
        'positional_embedding': getattr(text_module, 'positional_embedding', None),
        'cls_emb': getattr(text_module, 'cls_emb', None),
        'transformer': getattr(text_module, 'transformer', None),
        'ln_final': getattr(text_module, 'ln_final', None),
        'text_projection': getattr(text_module, 'text_projection', None),
    }
    text_params = {k: v for k, v in candidates.items() if v is not None}

    # Freeze all text parameters first.
    for component in text_params.values():
        if isinstance(component, nn.Parameter):
            component.requires_grad = False
        elif isinstance(component, nn.Module):
            for param in component.parameters():
                param.requires_grad = False

    if unlocked_layers == 0:
        return

    # BUGFIX: use .get() — 'transformer' may be absent from text_params (it is
    # filtered out when None), and the original dict indexing raised KeyError
    # before the guard below could run.
    transformer = text_params.get('transformer')
    if transformer is None or not hasattr(transformer, 'resblocks'):
        return

    total_layers = len(transformer.resblocks)
    if total_layers == 0:
        return

    # Build ordered groups for selective unlocking.
    groups = []

    # Group 1: embeddings.
    embedding_group = [
        text_params[key]
        for key in ('token_embedding', 'positional_embedding', 'cls_emb')
        if key in text_params
    ]
    if embedding_group:
        groups.append(embedding_group)

    # Groups 2..N: one per transformer block, except the last.
    if total_layers > 1:
        for block in transformer.resblocks[:-1]:
            groups.append([block])

    # Penultimate group: last transformer block together with the final LN.
    last_group = [transformer.resblocks[-1]]
    if 'ln_final' in text_params:
        last_group.append(text_params['ln_final'])
    groups.append(last_group)

    # Final group: the projection only.
    if 'text_projection' in text_params:
        groups.append([text_params['text_projection']])

    def _unlock(item):
        # Recursively re-enable gradients for params / modules / lists thereof.
        # (list, tuple) instead of typing.Sequence: avoids an extra import and
        # the str-is-a-Sequence pitfall.
        if isinstance(item, (list, tuple)):
            for sub in item:
                _unlock(sub)
        elif isinstance(item, nn.Parameter):
            item.requires_grad = True
        elif isinstance(item, nn.Module):
            for param in item.parameters():
                param.requires_grad = True

    # Unlock the requested number of groups, counted from the end.
    num_groups_to_unlock = min(unlocked_layers, len(groups))
    for group in groups[-num_groups_to_unlock:]:
        _unlock(group)