Add source code
Browse files- src/open_clip/.ipynb_checkpoints/tokenizer-checkpoint.py +621 -0
- src/open_clip/__init__.py +2 -0
- src/open_clip/__pycache__/__init__.cpython-310.pyc +0 -0
- src/open_clip/__pycache__/__init__.cpython-313.pyc +0 -0
- src/open_clip/__pycache__/biosignals_coca_model.cpython-310.pyc +0 -0
- src/open_clip/__pycache__/biosignals_coca_model.cpython-313.pyc +0 -0
- src/open_clip/__pycache__/coca_model.cpython-310.pyc +0 -0
- src/open_clip/__pycache__/coca_model.cpython-313.pyc +0 -0
- src/open_clip/__pycache__/factory.cpython-310.pyc +0 -0
- src/open_clip/__pycache__/factory.cpython-313.pyc +0 -0
- src/open_clip/__pycache__/model.cpython-310.pyc +0 -0
- src/open_clip/__pycache__/model.cpython-313.pyc +0 -0
- src/open_clip/__pycache__/tokenizer.cpython-310.pyc +0 -0
- src/open_clip/__pycache__/tokenizer.cpython-313.pyc +0 -0
- src/open_clip/__pycache__/transformer.cpython-310.pyc +0 -0
- src/open_clip/__pycache__/transformer.cpython-313.pyc +0 -0
- src/open_clip/biosignals_coca_model.py +1807 -0
- src/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- src/open_clip/coca_model.py +586 -0
- src/open_clip/factory.py +93 -0
- src/open_clip/model.py +943 -0
- src/open_clip/model_configs/sleep_coca_base_dualtransformer.json +44 -0
- src/open_clip/tokenizer.py +621 -0
- src/open_clip/transformer.py +1823 -0
src/open_clip/.ipynb_checkpoints/tokenizer-checkpoint.py
ADDED
|
@@ -0,0 +1,621 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" CLIP tokenizer
|
| 2 |
+
|
| 3 |
+
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
| 4 |
+
"""
|
| 5 |
+
import gzip
|
| 6 |
+
import html
|
| 7 |
+
import os
|
| 8 |
+
import random
|
| 9 |
+
import string
|
| 10 |
+
from functools import lru_cache, partial
|
| 11 |
+
from typing import Callable, List, Optional, Union, Dict
|
| 12 |
+
import warnings
|
| 13 |
+
|
| 14 |
+
import ftfy
|
| 15 |
+
import numpy as np
|
| 16 |
+
import regex as re
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
# https://stackoverflow.com/q/62691279
|
| 20 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 21 |
+
_nltk_init = False
|
| 22 |
+
|
| 23 |
+
DEFAULT_CONTEXT_LENGTH = 77 # default context length for OpenAI CLIP
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@lru_cache()
def default_bpe():
    """Return the absolute path of the bundled CLIP BPE vocab archive.

    The gzip'd merges file ships alongside this module; the result is
    cached since the path never changes during a run.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(here, "bpe_simple_vocab_16e6.txt.gz")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@lru_cache()
def bytes_to_unicode():
    """
    Returns a dict mapping each of the 256 utf-8 byte values to a unicode string.

    The reversible byte-level BPE codes operate on unicode strings, so every
    byte needs a printable character representation. Bytes that are already
    printable (the three latin ranges below) map to themselves; the remaining
    bytes (whitespace/control chars the BPE code barfs on) are shifted up past
    0xFF so they get distinct, printable stand-ins.
    """
    # bytes that are safe to represent as themselves
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    byte_values = printable[:]
    char_codes = printable[:]
    offset = 0
    for byte in range(2 ** 8):
        if byte not in byte_values:
            # unprintable byte: assign the next codepoint above 255
            byte_values.append(byte)
            char_codes.append(2 ** 8 + offset)
            offset += 1
    return dict(zip(byte_values, (chr(code) for code in char_codes)))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def get_pairs(word):
    """Return the set of adjacent symbol pairs in `word`.

    `word` is a tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev = word[0]
    for current in word[1:]:
        pairs.add((prev, current))
        prev = current
    return pairs
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def basic_clean(text):
    """Fix mojibake with ftfy and undo (possibly double) HTML escaping."""
    fixed = ftfy.fix_text(text)
    # unescape twice: handles text that was HTML-escaped two times over
    return html.unescape(html.unescape(fixed)).strip()
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def whitespace_clean(text):
    """Collapse every whitespace run to a single space and trim the ends."""
    return " ".join(text.split()).strip()
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _clean_canonicalize(x):
    """Clean pipeline: ftfy/HTML fix, then remove punctuation and lowercase."""
    return canonicalize_text(basic_clean(x))
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _clean_lower(x):
    """Clean pipeline: ftfy/HTML fix, collapse whitespace, then lowercase."""
    cleaned = whitespace_clean(basic_clean(x))
    return cleaned.lower()
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _clean_whitespace(x):
    """Clean pipeline: ftfy/HTML fix and whitespace collapse (case preserved)."""
    return whitespace_clean(basic_clean(x))
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def get_clean_fn(type: str):
    """Return the text-cleaning callable for the given strategy name.

    Args:
        type: one of 'canonicalize', 'lower', 'whitespace'.

    Returns:
        A callable str -> str implementing the requested cleaning.

    Raises:
        ValueError: if `type` is not a known cleaning strategy.
    """
    if type == 'canonicalize':
        return _clean_canonicalize
    elif type == 'lower':
        return _clean_lower
    elif type == 'whitespace':
        return _clean_whitespace
    # raise instead of `assert False` so the guard survives `python -O`
    raise ValueError(f"Invalid clean function ({type}).")
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def canonicalize_text(
        text,
        *,
        keep_punctuation_exact_string=None,
        trans_punctuation: dict = str.maketrans("", "", string.punctuation),
):
    """Returns canonicalized `text` (lowercase and punctuation removed).

    From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94

    Args:
        text: string to be canonicalized.
        keep_punctuation_exact_string: If provided, then this exact string kept.
            For example providing '{}' will keep any occurrences of '{}' (but will
            still remove '{' and '}' that appear separately).
        trans_punctuation: translation table deleting punctuation (never mutated,
            so the shared default is safe).
    """
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        # strip punctuation within each segment, preserving the exact string itself
        segments = text.split(keep_punctuation_exact_string)
        text = keep_punctuation_exact_string.join(
            segment.translate(trans_punctuation) for segment in segments
        )
    else:
        text = text.translate(trans_punctuation)
    # lowercase, collapse whitespace, trim
    return " ".join(text.lower().split()).strip()
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer ported from the original OpenAI CLIP code.

    Cleans text, applies byte-level BPE using the bundled merges file, and
    produces fixed-length id tensors bracketed by <start_of_text> /
    <end_of_text> special tokens.
    """

    def __init__(
            self,
            bpe_path: str = default_bpe(),
            additional_special_tokens: Optional[List[str]] = None,
            context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
            clean: str = 'lower',
            reduction_mask: str = ''
    ):
        # byte <-> printable-unicode maps for byte-level BPE
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        # keep the merge rules for the 49152-token CLIP vocab: slot 0 is a
        # header line; 256+256 byte tokens and 2 specials are not merges
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        # '</w>' marks word-final byte tokens
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        special_tokens = ['<start_of_text>', '<end_of_text>']
        if additional_special_tokens:
            special_tokens += additional_special_tokens
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))  # token str -> id
        self.decoder = {v: k for k, v in self.encoder.items()}  # id -> token str
        self.bpe_ranks = dict(zip(merges, range(len(merges))))  # merge pair -> priority
        # seed the BPE cache so special tokens pass through unmodified
        self.cache = {t:t for t in special_tokens}
        special = "|".join(special_tokens)
        # split pattern: special tokens, common contractions, letter runs,
        # single digits, then any other non-space run
        self.pat = re.compile(
            special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )
        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]
        self.sot_token_id = self.all_special_ids[0]
        self.eot_token_id = self.all_special_ids[1]
        self.context_length = context_length
        self.clean_fn = get_clean_fn(clean)
        # optional strategy for shrinking long inputs instead of plain truncation
        self.reduction_fn = get_reduction_mask_fn(reduction_mask) if reduction_mask else None

    def bpe(self, token):
        """Apply BPE merges to one pre-tokenized word; returns space-joined sub-tokens."""
        if token in self.cache:
            return self.cache[token]
        # last character carries the word-final marker
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            # single-character token: nothing to merge
            return token+'</w>'

        while True:
            # greedily merge the lowest-ranked (most frequent) pair
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    # copy symbols up to the next occurrence of `first`
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except Exception:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word  # memoize per-word result
        return word

    def encode(self, text):
        """Clean `text` and return its list of BPE token ids (no sot/eot)."""
        bpe_tokens = []
        text = self.clean_fn(text)
        for token in re.findall(self.pat, text):
            # map raw bytes to the printable alphabet BPE operates on
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Invert `encode`: token ids back to text ('</w>' becomes a space)."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text

    def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.LongTensor:
        """ Returns the tokenized representation of given input string(s)

        Parameters
        ----------
        texts : Union[str, List[str]]
            An input string or a list of input strings to tokenize
        context_length : int
            The context length to use; all CLIP models use 77 as the context length

        Returns
        -------
        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
        """
        if isinstance(texts, str):
            texts = [texts]

        context_length = context_length or self.context_length
        assert context_length, 'Please set a valid context length'

        if self.reduction_fn is not None:
            # use reduction strategy for tokenize if set, otherwise default to truncation below
            return self.reduction_fn(
                texts,
                context_length=context_length,
                sot_token_id=self.sot_token_id,
                eot_token_id=self.eot_token_id,
                encode_fn=self.encode,
            )

        all_tokens = [[self.sot_token_id] + self.encode(text) + [self.eot_token_id] for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                tokens = tokens[:context_length]  # Truncate
                tokens[-1] = self.eot_token_id  # always end with eot
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
# Module-level default tokenizer instance shared by tokenize()/decode() below.
_tokenizer = SimpleTokenizer()
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def decode(output_ids: torch.Tensor):
    """Decode a tensor of token ids back to text via the module-level tokenizer."""
    return _tokenizer.decode(output_ids.cpu().numpy())
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def tokenize(texts: Union[str, List[str]], context_length: int = DEFAULT_CONTEXT_LENGTH) -> torch.LongTensor:
    """Tokenize `texts` with the module-level CLIP tokenizer, padded/truncated to `context_length`."""
    return _tokenizer(texts, context_length=context_length)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def random_mask_tokenize(
        texts: Union[str, List[str]],
        context_length: int,
        sot_token_id: int,
        eot_token_id: int,
        encode_fn: Callable,
        shuffle: bool = False,
):
    """Tokenize `texts`, randomly dropping tokens from over-long sequences.

    Sequences longer than `context_length - 2` keep a random subset of tokens
    (in original order unless `shuffle` is True); shorter sequences pass
    through unchanged. Output is [len(texts), context_length] with sot/eot
    bracketing and zero padding.
    """
    encoded = [encode_fn(t) for t in texts]
    out = torch.zeros(len(encoded), context_length, dtype=torch.long)
    budget = context_length - 2  # room left after sot and eot tokens

    for row, ids in enumerate(encoded):
        ids = torch.tensor(ids)
        n = len(ids)
        if n > budget:
            keep = torch.randperm(n)[:budget]
            if not shuffle:
                keep = keep.msort()  # restore original token order
            ids = ids[keep]
            n = budget
        out[row, 0] = sot_token_id
        out[row, 1:n + 1] = ids
        out[row, n + 1] = eot_token_id

    return out
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def simple_mask_tokenize(
        texts: Union[str, List[str]],
        context_length: int,
        sot_token_id: int,
        eot_token_id: int,
        encode_fn: Callable,
):
    """Tokenize `texts`, cropping over-long sequences to a random contiguous block.

    Sequences longer than `context_length - 2` keep a randomly positioned
    contiguous window of that size; shorter sequences pass through. Output is
    [len(texts), context_length] with sot/eot bracketing and zero padding.
    """
    encoded = [encode_fn(t) for t in texts]
    out = torch.zeros(len(encoded), context_length, dtype=torch.long)
    budget = context_length - 2  # room left after sot and eot tokens

    for row, ids in enumerate(encoded):
        if len(ids) > budget:
            # random window start; randint's high bound is inclusive
            start = random.randint(0, len(ids) - budget)
            ids = ids[start: start + budget]
        framed = [sot_token_id] + ids + [eot_token_id]
        out[row, :len(framed)] = torch.tensor(framed)

    return out
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def syntax_mask_tokenize(
        texts: Union[str, List[str]],
        context_length: int,
        sot_token_id: int,
        eot_token_id: int,
        encode_fn: Callable,
) -> torch.LongTensor:
    """ Returns the tokenized representation of given input string(s).
    Over-long inputs are reduced by syntax-aware masking (keep nouns, then
    adjectives, then verbs, then everything else) before tokenizing.
    """
    import nltk
    global _nltk_init
    if not _nltk_init:
        # one-time download of the POS tagging resources
        nltk.download('punkt')
        nltk.download('averaged_perceptron_tagger')
        _nltk_init = True

    def _priority(tag):
        # lower rank = kept first: nouns, adjectives, verbs, then the rest
        for rank, prefix in enumerate(('NN', 'JJ', 'VB'), start=1):
            if tag.startswith(prefix):
                return rank
        return 4

    # rebuild each text from its highest-priority words, original order kept
    reduced_texts = []
    for text in texts:
        words = nltk.tokenize.word_tokenize(text)
        tags = nltk.pos_tag(words)
        priorities = [_priority(tag) for _, tag in tags]
        chosen = sorted(np.argsort(np.array(priorities))[:context_length - 2])  # need 2 slots for sot/eot
        kept_words = np.take(np.array(words), chosen, axis=0)
        reduced_texts.append(' '.join(str(w) for w in kept_words))

    all_tokens = [[sot_token_id] + encode_fn(text) + [eot_token_id] for text in reduced_texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        # still need to truncate: some words expand to multiple BPE tokens
        if len(tokens) > context_length:
            tokens = tokens[:context_length]
            tokens[-1] = eot_token_id
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def get_reduction_mask_fn(type: str):
    """ Choose strategy for dropping (masking) tokens to achieve target context length.

    Args:
        type: one of 'simple', 'random', 'shuffle', 'syntax'.

    Returns:
        A tokenize function with the SimpleTokenizer.reduction_fn signature.

    Raises:
        ValueError: if `type` is not a known reduction strategy.
    """
    if type == 'simple':
        return simple_mask_tokenize  # randomly select block [start:end]
    elif type == 'random':
        return random_mask_tokenize  # randomly drop tokens (keep order)
    elif type == 'shuffle':
        return partial(random_mask_tokenize, shuffle=True)  # randomly drop tokens (shuffle order)
    elif type == 'syntax':
        return syntax_mask_tokenize  # randomly drop prioritized by syntax
    # raise instead of assert so the guard survives `python -O`; this also
    # replaces the redundant membership assert the branches already covered
    raise ValueError(f'Unknown type {type}.')
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
class HFTokenizer:
    """HuggingFace tokenizer wrapper with support for custom tokenization modes.

    Wraps an AutoTokenizer and produces fixed-length [batch, context_length]
    id tensors. `tokenizer_mode='clips'` switches to a custom scheme that adds
    bos/eos manually and places a cls token in the final slot.
    """

    def __init__(
            self,
            tokenizer_name: str,
            context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
            clean: str = 'whitespace',
            strip_sep_token: bool = False,
            language: Optional[str] = None,
            cache_dir: Optional[str] = None,
            tokenizer_mode: Optional[str] = None,  # None, 'clips'
            **kwargs
    ):
        self.tokenizer_mode = tokenizer_mode or ''
        self.context_length = context_length
        self.clean_fn = get_clean_fn(clean)
        self.strip_sep_token = strip_sep_token

        # NOTE: Left as example of loading custom tokenizer from file for experimentation
        # if self.tokenizer_mode == 'bert_clips':
        #     self.special_tokens = {
        #         "bos_token": 1,
        #         "eos_token": 2,
        #         "cls_token": 101,
        #         "pad_token": 0
        #     }
        #
        #     # For BERT CLIPS mode with vocab file
        #     from tokenizers import BertWordPieceTokenizer
        #     if tokenizer_name.startswith('hf-hub:'):
        #         from huggingface_hub import hf_hub_download
        #         # Format: hf-hub:repo_id/filename
        #         repo_url = tokenizer_name[7:]
        #         parts = repo_url.split('/')
        #         filename = parts[-1]
        #         repo_id = '/'.join(parts[:-1])
        #         vocab_file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)
        #         self.tokenizer = BertWordPieceTokenizer(lowercase=True)
        #         self.tokenizer = self.tokenizer.from_file(vocab_file)
        #     else:
        #         # Assume tokenizer_name is a local path to a vocab file
        #         self.tokenizer = BertWordPieceTokenizer(lowercase=True)
        #         self.tokenizer = self.tokenizer.from_file(tokenizer_name)

        # Standard HuggingFace tokenizer initialization
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            cache_dir=cache_dir,
            **kwargs
        )

        # Set language function if available (e.g. multilingual NLLB/mBART tokenizers)
        set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None)
        if callable(set_lang_fn):
            self.set_lang_fn = set_lang_fn
        if language is not None:
            self.set_language(language)

    def save_pretrained(self, dest):
        """Persist the wrapped HF tokenizer to `dest`."""
        self.tokenizer.save_pretrained(dest)

    def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:
        """Tokenize `texts` to a [batch, context_length] LongTensor of ids."""
        # same cleaning as for default tokenizer, except lowercasing
        # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
        if isinstance(texts, str):
            texts = [texts]

        context_length = context_length or self.context_length
        assert context_length, 'Please set a valid context length in class init or call.'

        texts = [self.clean_fn(text) for text in texts]

        # Handle different tokenization modes
        if self.tokenizer_mode == 'clips':
            return self._clips_tokenize(texts, context_length)
        else:
            # Standard tokenization
            input_ids = self.tokenizer.batch_encode_plus(
                texts,
                return_tensors='pt',
                max_length=context_length,
                padding='max_length',
                truncation=True,
            ).input_ids

            if self.strip_sep_token:
                # zero out [SEP] so it doesn't contribute to pooled features
                input_ids = torch.where(
                    input_ids == self.tokenizer.sep_token_id,
                    torch.zeros_like(input_ids),
                    input_ids,
                )

            return input_ids

    def set_language(self, src_lang):
        """Switch the tokenizer's source language, if the tokenizer supports it."""
        if hasattr(self, 'set_lang_fn'):
            self.set_lang_fn(src_lang)
        else:
            warnings.warn('Cannot set language for the tokenizer.')

    def _clips_tokenize(self, texts: List[str], context_length: int) -> torch.Tensor:
        """Use standard HF tokenizer but apply custom post-processing."""
        # Use standard tokenizer without special tokens - we'll add our own
        encoded_outputs = self.tokenizer.batch_encode_plus(
            texts,
            add_special_tokens=False,
            padding=False,
            truncation=False,
            return_tensors=None
        )

        encoded = []
        for tokens in encoded_outputs["input_ids"]:
            tokens = tokens[:context_length - 3]  # Leave room for special tokens
            # NOTE(review): assumes the wrapped tokenizer defines bos/eos token
            # ids (None for some vocabs, e.g. plain BERT) — confirm per model
            tokens = [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
            encoded.append(tokens)

        # Create result tensor and handle padding + class token
        result = torch.zeros(len(encoded), context_length, dtype=torch.long)
        for i, tokens in enumerate(encoded):
            padded_tokens = self._pad_and_add_class_token(
                tokens,
                max_length=context_length,
                pad_token_id=self.tokenizer.pad_token_id,
                cls_token_id=self.tokenizer.cls_token_id,
            )
            result[i, :len(padded_tokens)] = torch.tensor(padded_tokens)

        return result

    def _pad_and_add_class_token(
            self,
            tokens: List[int],
            max_length: int,
            pad_token_id: int = 0,
            cls_token_id: int = 101,
    ) -> List[int]:
        """ Add padding with class token at the end """
        # truncate to leave exactly one slot for the cls token
        if len(tokens) > max_length - 1:
            tokens = tokens[:max_length - 1]

        # Add padding to reach max_length-1
        if len(tokens) < max_length - 1:
            tokens = tokens + [pad_token_id] * (max_length - 1 - len(tokens))

        # Add class token at the end
        tokens = tokens + [cls_token_id]
        return tokens
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
class SigLipTokenizer:
    """HuggingFace tokenizer wrapper for SigLIP T5 compatible sentencepiece vocabs

    NOTE: this is not needed in normal library use, but is used to import new sentencepiece tokenizers
    into OpenCLIP. Leaving code here in case future models use new tokenizers.
    """
    VOCAB_FILES = {
        # english, vocab_size=32_000
        "c4-en": "http://storage.googleapis.com/t5-data/vocabs/cc_en.32000/sentencepiece.model",
        # used in multilingual models (mT5, PaLI), vocab_size=250_000
        "mc4": "http://storage.googleapis.com/t5-data/vocabs/mc4.250000.100extra/sentencepiece.model",
        # used in SigLIP2 models, vocab_size=256000
        "gemma": "http://storage.googleapis.com/big_vision/gemma_tokenizer.model",
    }

    def __init__(
            self,
            tokenizer_name: str,
            context_length: Optional[int] = 64,
    ):
        """Load a Gemma or T5 fast tokenizer from a known vocab URL or local path.

        Args:
            tokenizer_name: a key of VOCAB_FILES (downloads the sentencepiece
                model) or a path/name passed straight to the tokenizer class.
            context_length: default padded sequence length for __call__.
        """
        if 'gemma' in tokenizer_name:
            from transformers import GemmaTokenizerFast
            tokenizer_cls = partial(
                GemmaTokenizerFast, padding_side='right', add_bos_token=False, add_eos_token=True)
        else:
            from transformers import T5TokenizerFast
            tokenizer_cls = partial(T5TokenizerFast, extra_ids=0)

        if tokenizer_name in self.VOCAB_FILES:
            # FIXME temporary hack? download the remote sentencepiece model to a temp file
            import tempfile
            import fsspec
            vocab_file = self.VOCAB_FILES[tokenizer_name]
            with tempfile.NamedTemporaryFile('wb') as dst:
                with fsspec.open(vocab_file, 'rb') as src:
                    dst.write(src.read())
                # BUG FIX: flush buffered bytes before the tokenizer re-opens
                # dst.name, otherwise it may read a partially written model file
                dst.flush()
                self.tokenizer = tokenizer_cls(dst.name, legacy=False)
        else:
            self.tokenizer = tokenizer_cls(tokenizer_name, legacy=False)

        # normalize special token ids across the supported vocabs
        self.tokenizer.pad_token_id = 0 if 'gemma' in tokenizer_name else 1
        self.tokenizer.eos_token_id = 1
        self.context_length = context_length

    def save_pretrained(self, dest):
        """Persist the wrapped HF tokenizer to `dest`."""
        self.tokenizer.save_pretrained(dest)

    def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:
        """Canonicalize and tokenize `texts` to a padded [batch, context_length] tensor."""
        # same cleaning as for default tokenizer, except lowercasing
        # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
        if isinstance(texts, str):
            texts = [texts]

        context_length = context_length or self.context_length
        assert context_length, 'Please set a valid context length in class init or call.'

        texts = [canonicalize_text(basic_clean(text)) for text in texts]
        output = self.tokenizer(
            texts,
            return_tensors='pt',
            max_length=context_length,
            padding='max_length',
            truncation=True,
        )
        return output.input_ids
|
src/open_clip/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .factory import create_model, load_checkpoint, get_tokenizer, get_input_dtype
|
| 2 |
+
from .tokenizer import SimpleTokenizer
|
src/open_clip/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (319 Bytes). View file
|
|
|
src/open_clip/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (330 Bytes). View file
|
|
|
src/open_clip/__pycache__/biosignals_coca_model.cpython-310.pyc
ADDED
|
Binary file (44.3 kB). View file
|
|
|
src/open_clip/__pycache__/biosignals_coca_model.cpython-313.pyc
ADDED
|
Binary file (70.3 kB). View file
|
|
|
src/open_clip/__pycache__/coca_model.cpython-310.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
src/open_clip/__pycache__/coca_model.cpython-313.pyc
ADDED
|
Binary file (21.2 kB). View file
|
|
|
src/open_clip/__pycache__/factory.cpython-310.pyc
ADDED
|
Binary file (3.26 kB). View file
|
|
|
src/open_clip/__pycache__/factory.cpython-313.pyc
ADDED
|
Binary file (5.05 kB). View file
|
|
|
src/open_clip/__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (24.5 kB). View file
|
|
|
src/open_clip/__pycache__/model.cpython-313.pyc
ADDED
|
Binary file (42.6 kB). View file
|
|
|
src/open_clip/__pycache__/tokenizer.cpython-310.pyc
ADDED
|
Binary file (18.6 kB). View file
|
|
|
src/open_clip/__pycache__/tokenizer.cpython-313.pyc
ADDED
|
Binary file (28.6 kB). View file
|
|
|
src/open_clip/__pycache__/transformer.cpython-310.pyc
ADDED
|
Binary file (44.1 kB). View file
|
|
|
src/open_clip/__pycache__/transformer.cpython-313.pyc
ADDED
|
Binary file (79.7 kB). View file
|
|
|
src/open_clip/biosignals_coca_model.py
ADDED
|
@@ -0,0 +1,1807 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Biosignals-Text CoCa Model
|
| 3 |
+
|
| 4 |
+
Adapted from the original CoCa model to work with biosignals (time series) data
|
| 5 |
+
instead of images. This model is designed for biosignals-text contrastive learning.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Dict, List, Optional, Union, Tuple
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
import numpy as np
|
| 13 |
+
import math
|
| 14 |
+
from dataclasses import dataclass, field
|
| 15 |
+
|
| 16 |
+
from .transformer import (
|
| 17 |
+
LayerNormFp32,
|
| 18 |
+
LayerNorm,
|
| 19 |
+
QuickGELU,
|
| 20 |
+
MultimodalTransformer,
|
| 21 |
+
ConcatMultimodalTransformer,
|
| 22 |
+
)
|
| 23 |
+
from .model import CLIPTextCfg, _build_text_tower
|
| 24 |
+
from .coca_model import MultimodalCfg, _build_text_decoder_tower, _token_to_tensor
|
| 25 |
+
|
| 26 |
+
# Optional dependency: Hugging Face `transformers` supplies the logits warpers,
# beam-search scorer, and stopping criteria used by caption generation.  When it
# is not installed we register stub entries so this module still imports;
# generation code paths must check `_has_transformers` before using them.
try:
    from transformers.generation.beam_search import BeamSearchScorer
    from transformers.generation.logits_process import (
        LogitsProcessorList,
        TopPLogitsWarper,
        TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
        MinLengthLogitsProcessor,
    )
    from transformers.generation.stopping_criteria import (
        MaxLengthCriteria,
        EosTokenCriteria,
        StoppingCriteriaList,
    )

    # Maps a generation-type name to its warper class; "beam_search" is a
    # sentinel string handled separately by the generation code.
    GENERATION_TYPES = {
        "top_k": TopKLogitsWarper,
        "top_p": TopPLogitsWarper,
        "beam_search": "beam_search",
    }
    _has_transformers = True
except ImportError:
    # NOTE: previously bound the exception as `e` but never used it.
    # Keep the same keys so lookups degrade gracefully instead of KeyError-ing.
    GENERATION_TYPES = {
        "top_k": None,
        "top_p": None,
        "beam_search": "beam_search",
    }
    _has_transformers = False
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# ============================================================================
|
| 57 |
+
# Pure Transformer Architecture Components (from PureTransformerMAE)
|
| 58 |
+
# ============================================================================
|
| 59 |
+
|
| 60 |
+
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE).

    Rotates each (even, odd) pair of feature dimensions by an angle that
    depends on the token position, encoding relative position directly in
    the query/key vectors.
    """

    def __init__(self, dim: int, theta: float = 10000.0, learned_freq: bool = False):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.learned_freq = learned_freq

        if learned_freq:
            # Trainable per-pair frequencies (used for channel attention).
            self.freqs = nn.Parameter(torch.randn(dim // 2) * 0.02)
        else:
            # Classic fixed inverse-frequency spectrum (used for temporal attention).
            inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
            self.register_buffer('freqs', inv_freq)

    def rotate_queries_or_keys(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None):
        """Rotate *x* in-place-shape: (batch, heads, seq, head_dim) -> same shape.

        Args:
            x: (batch_size, num_heads, seq_len, head_dim) queries or keys.
            position_ids: optional (seq_len,) or (batch, seq_len) positions;
                defaults to 0..seq_len-1.  A 2-D tensor is collapsed to its
                first row (all batches assumed to share one position pattern).
        """
        _, _, seq_len, head_dim = x.shape
        assert head_dim == self.dim, f"head_dim {head_dim} != self.dim {self.dim}"

        # Normalize position ids to a 1-D float tensor.
        if position_ids is None:
            pos = torch.arange(seq_len, device=x.device, dtype=torch.float)
        elif position_ids.ndim == 2:
            pos = position_ids[0].float()
        else:
            pos = position_ids.float()

        # One rotation angle per (even, odd) pair: (seq_len, dim // 2).
        angles = torch.outer(pos, self.freqs)
        # Broadcast over batch and heads: (1, 1, seq_len, dim // 2).
        cos = torch.cos(angles)[None, None, :, :]
        sin = torch.sin(angles)[None, None, :, :]

        even = x[..., 0::2]
        odd = x[..., 1::2]

        # 2-D rotation of each pair: [even', odd'] = R(angle) @ [even, odd].
        rotated = torch.empty_like(x)
        rotated[..., 0::2] = even * cos - odd * sin
        rotated[..., 1::2] = even * sin + odd * cos
        return rotated
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Scales the last dimension by the reciprocal of its RMS (no mean
    subtraction, no bias), then applies a learned per-feature gain.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # 1 / RMS over the last dimension, with eps for numerical safety.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalize in float32 for stability, cast back, then apply the gain.
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class SwiGLU(nn.Module):
    """SwiGLU activation function: SiLU(x * W1) * (x * W2).

    Two parallel projections; the SiLU-gated branch multiplicatively gates
    the plain linear branch.
    """

    def __init__(self, dim_in: int, dim_out: int, bias: bool = False):
        super().__init__()
        self.w1 = nn.Linear(dim_in, dim_out, bias=bias)
        self.w2 = nn.Linear(dim_in, dim_out, bias=bias)

    def forward(self, x):
        gate = F.silu(self.w1(x))   # gating branch
        value = self.w2(x)          # linear branch
        return gate * value
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class MLP(nn.Module):
    """MLP with configurable activation and normalization.

    Two structures depending on ``activation``:
      * "swiglu": SwiGLU gate (two parallel linears) followed by a down
        projection — no separate ``act_fn`` module is created.
      * "gelu"/"relu": classic up-project -> activation -> down-project.

    Args:
        dim: Input/output feature dimension.
        hidden_dim: Hidden (expanded) dimension.
        dropout: Dropout probability applied after the activation and again
            after the down projection (matching the original behavior).
        activation: One of "swiglu", "gelu", "relu".
        bias: Whether the linear layers carry a bias.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    def __init__(self,
                 dim: int,
                 hidden_dim: int,
                 dropout: float = 0.0,
                 activation: str = "swiglu",  # "swiglu", "gelu", "relu"
                 bias: bool = False):
        super().__init__()
        # Validate up-front: this guarantees "swiglu" can never fall through to
        # the unknown-activation error (which could happen if the act_fn
        # dispatch were checked unconditionally), and rejects bad configs
        # before any layers are allocated.
        if activation not in ("swiglu", "gelu", "relu"):
            raise ValueError(f"Unknown activation: {activation}")
        self.activation = activation

        if activation == "swiglu":
            # SwiGLU requires different structure: two parallel linear layers
            self.gate_proj = SwiGLU(dim, hidden_dim, bias=bias)
            self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
        else:
            # Standard MLP structure
            self.up_proj = nn.Linear(dim, hidden_dim, bias=bias)
            self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
            self.act_fn = nn.GELU() if activation == "gelu" else nn.ReLU()

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        if self.activation == "swiglu":
            x = self.gate_proj(x)
            x = self.dropout(x)
            x = self.down_proj(x)
        else:
            x = self.up_proj(x)
            x = self.act_fn(x)
            x = self.dropout(x)
            x = self.down_proj(x)
        # Final dropout after the down projection (applied in both branches).
        return self.dropout(x)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class ChannelPatching(nn.Module):
    """Patching layer that operates independently on each channel.

    A single Conv1d (kernel == stride == patch_size, no padding) is shared
    across channels, turning each channel's 1-D signal into a sequence of
    non-overlapping patch embeddings.
    """

    def __init__(self,
                 patch_size: int = 32,
                 conv_embed_dim: int = 256,
                 num_channels: int = 21):
        super().__init__()
        self.patch_size = patch_size
        self.conv_embed_dim = conv_embed_dim
        self.num_channels = num_channels

        # One conv applied to every channel; padding=0 keeps patches clean
        # and non-overlapping.
        self.conv_patching = nn.Conv1d(
            in_channels=1,
            out_channels=conv_embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            padding=0,
        )

    def forward(self, x):
        """Patch a multi-channel signal.

        Args:
            x: (batch_size, num_channels, signal_length)
        Returns:
            (batch_size, num_channels, num_patches, conv_embed_dim)
        """
        bsz, n_ch, sig_len = x.shape

        # Fold channels into the batch dimension so the shared conv sees each
        # channel as an independent single-channel signal.
        flat = x.reshape(bsz * n_ch, 1, sig_len)
        emb = self.conv_patching(flat)  # (bsz * n_ch, conv_embed_dim, n_patches)

        n_patches = emb.shape[-1]
        emb = emb.reshape(bsz, n_ch, self.conv_embed_dim, n_patches)

        # Move the embedding dim last: (bsz, n_ch, n_patches, conv_embed_dim).
        return emb.permute(0, 1, 3, 2)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class DualRoPEAttention(nn.Module):
    """Multi-head attention with separate RoPE for temporal and learnable RoPE for channels.

    ``attention_type="temporal"`` builds a fixed-frequency RoPE;
    ``attention_type="channel"`` uses the provided shared learnable RoPE, or
    builds its own learnable one as a fallback.
    """

    def __init__(self,
                 embed_dim: int = 256,
                 num_heads: int = 8,
                 dropout: float = 0.1,
                 attention_type: str = "temporal",  # "temporal" or "channel"
                 num_channels: int = 21,
                 shared_channel_rope: Optional[nn.Module] = None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.attention_type = attention_type

        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        # Q/K/V are bias-free; only the output projection carries a bias.
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # Pick the rotary embedding flavor for this attention axis.
        if attention_type == "temporal":
            # Fixed-frequency RoPE over time.
            self.rotary_emb = RotaryEmbedding(
                dim=self.head_dim,
                theta=10000,
                learned_freq=False
            )
        elif attention_type == "channel":
            if shared_channel_rope is not None:
                # Share one learnable RoPE across all channel-attention layers.
                self.rotary_emb = shared_channel_rope
            else:
                # Fallback: own learnable frequencies.
                self.rotary_emb = RotaryEmbedding(
                    dim=self.head_dim,
                    theta=10000,
                    learned_freq=True
                )
        else:
            raise ValueError(f"Unknown attention_type: {attention_type}")

        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def forward(self, x, position_ids=None):
        """Self-attention over *x* with RoPE-rotated queries and keys.

        Args:
            x: (batch_size, seq_len, embed_dim)
            position_ids: optional position indices forwarded to the RoPE.
        Returns:
            (batch_size, seq_len, embed_dim)
        """
        bsz, seq_len, dim = x.shape

        def split_heads(t):
            # (bsz, seq, dim) -> (bsz, heads, seq, head_dim)
            return t.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        # Rotate queries and keys with this layer's rotary embedding.
        q = self.rotary_emb.rotate_queries_or_keys(q, position_ids=position_ids)
        k = self.rotary_emb.rotate_queries_or_keys(k, position_ids=position_ids)

        # Scaled dot-product attention with dropout on the attention weights.
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        probs = self.dropout(F.softmax(scores, dim=-1))
        ctx = torch.matmul(probs, v)

        # Merge heads back and project out.
        ctx = ctx.transpose(1, 2).contiguous().view(bsz, seq_len, dim)
        return self.out_proj(ctx)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class DualTransformerBlock(nn.Module):
    """Biosignal transformer block with channel and temporal attention using dual RoPE"""
    # Architecture: one channel-attention sub-block (mixing information across
    # channels at each patch position) followed by `num_temporal_layers`
    # temporal-attention sub-blocks (mixing across patches within each channel).
    # Normalization is applied *after* each residual sum (post-norm style).
    def __init__(self,
                 embed_dim: int = 256,
                 num_heads: int = 8,
                 num_temporal_layers: int = 2,
                 dropout: float = 0.1,
                 mlp_ratio: float = 4.0,
                 num_channels: int = 21,
                 activation: str = "swiglu",
                 norm_type: str = "rmsnorm",
                 mlp_bias: bool = False,
                 shared_channel_rope: Optional[nn.Module] = None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_temporal_layers = num_temporal_layers

        # Helper function to create normalization layer
        # (RMSNorm or LayerNorm, selected once by `norm_type` for all norms below).
        def create_norm(dim):
            if norm_type == "rmsnorm":
                return RMSNorm(dim)
            elif norm_type == "layernorm":
                return nn.LayerNorm(dim)
            else:
                raise ValueError(f"Unknown norm_type: {norm_type}")

        # Channel-wise attention with shared learnable RoPE
        # (when `shared_channel_rope` is None the attention creates its own).
        self.channel_attention = DualRoPEAttention(
            embed_dim, num_heads, dropout,
            attention_type="channel", num_channels=num_channels,
            shared_channel_rope=shared_channel_rope
        )
        self.channel_norm = create_norm(embed_dim)

        # Temporal attention layers with standard RoPE
        # (each temporal layer has its own fixed-frequency rotary embedding).
        self.temporal_attention_layers = nn.ModuleList([
            DualRoPEAttention(embed_dim, num_heads, dropout, attention_type="temporal")
            for _ in range(num_temporal_layers)
        ])
        self.temporal_norms = nn.ModuleList([
            create_norm(embed_dim)
            for _ in range(num_temporal_layers)
        ])

        # MLP layers
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.channel_mlp = MLP(
            dim=embed_dim,
            hidden_dim=mlp_hidden_dim,
            dropout=dropout,
            activation=activation,
            bias=mlp_bias
        )

        self.temporal_mlps = nn.ModuleList([
            MLP(
                dim=embed_dim,
                hidden_dim=mlp_hidden_dim,
                dropout=dropout,
                activation=activation,
                bias=mlp_bias
            ) for _ in range(num_temporal_layers)
        ])

        self.channel_mlp_norm = create_norm(embed_dim)
        self.temporal_mlp_norms = nn.ModuleList([
            create_norm(embed_dim)
            for _ in range(num_temporal_layers)
        ])

    def forward(self, x, temporal_position_ids=None):
        """
        Args:
            x: (batch_size, num_channels, num_patches, embed_dim)
            temporal_position_ids: (batch_size, num_patches) or (num_patches,) - position indices for temporal RoPE
        Returns:
            (batch_size, num_channels, num_patches, embed_dim)
        """
        batch_size, num_channels, num_patches, embed_dim = x.shape

        # 1. Channel-wise attention on each patch independently
        # Fold the patch axis into the batch so each patch position becomes an
        # independent sequence of `num_channels` tokens.
        x_for_channel_attn = x.permute(0, 2, 1, 3).contiguous().reshape(batch_size * num_patches, num_channels, embed_dim)

        # Apply channel attention with learnable RoPE
        channel_attn_out = self.channel_attention(x_for_channel_attn)

        # Residual connection and layer norm
        x_for_channel_attn = self.channel_norm(x_for_channel_attn + channel_attn_out)

        # MLP
        channel_mlp_out = self.channel_mlp(x_for_channel_attn)
        x_for_channel_attn = self.channel_mlp_norm(x_for_channel_attn + channel_mlp_out)

        # Reshape back
        x = x_for_channel_attn.reshape(batch_size, num_patches, num_channels, embed_dim).permute(0, 2, 1, 3)

        # 2. Temporal attention on patches for each channel
        # Fold channels into the batch so each channel is a sequence of patches.
        x_for_temporal_attn = x.reshape(batch_size * num_channels, num_patches, embed_dim)

        # Prepare temporal position IDs
        # NOTE(review): a 2-D id tensor is collapsed to its first row — assumes
        # all batch items share the same temporal positions; confirm with callers.
        if temporal_position_ids is not None:
            if temporal_position_ids.ndim == 2:
                temporal_pos_ids_expanded = temporal_position_ids[0]
            else:
                temporal_pos_ids_expanded = temporal_position_ids
        else:
            temporal_pos_ids_expanded = None

        # Apply multiple temporal attention layers
        # (each layer: attention + post-norm residual, then MLP + post-norm residual).
        for i in range(self.num_temporal_layers):
            temporal_attn_out = self.temporal_attention_layers[i](x_for_temporal_attn, position_ids=temporal_pos_ids_expanded)
            x_for_temporal_attn = self.temporal_norms[i](x_for_temporal_attn + temporal_attn_out)

            temporal_mlp_out = self.temporal_mlps[i](x_for_temporal_attn)
            x_for_temporal_attn = self.temporal_mlp_norms[i](x_for_temporal_attn + temporal_mlp_out)

        # Reshape back
        x = x_for_temporal_attn.reshape(batch_size, num_channels, num_patches, embed_dim)

        return x
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
# ============================================================================
|
| 449 |
+
# End of Pure Transformer Architecture Components
|
| 450 |
+
# ============================================================================
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def _build_signal_tower(
        embed_dim: int,
        signal_cfg,
        output_tokens: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Build a biosignals encoder tower.

    Args:
        embed_dim: Output embedding dimension.
        signal_cfg: BiosignalsCfg instance (or a dict used to construct one).
        output_tokens: Whether to output tokens for the multimodal decoder.
        cast_dtype: Optional dtype for casting.

    Returns:
        The configured encoder (BiosignalsEncoder or
        PureTransformerBiosignalsEncoder).

    Raises:
        ValueError: If the configured architecture name is not recognized.
    """
    import logging

    # Accept plain-dict configs for convenience.
    if isinstance(signal_cfg, dict):
        signal_cfg = BiosignalsCfg(**signal_cfg)

    architecture = getattr(signal_cfg, 'architecture', 'conv_transformer')
    logging.info(f"Building biosignals encoder with architecture: {architecture}")

    if architecture == "pure_transformer":
        encoder = PureTransformerBiosignalsEncoder(
            biosignals_cfg=signal_cfg,
            embed_dim=embed_dim,
            output_tokens=output_tokens,
            cast_dtype=cast_dtype
        )
        logging.info(f"Pure Transformer architecture:")
        logging.info(f"  Patch size: {signal_cfg.patch_size}")
        logging.info(f"  Conv embed dim: {signal_cfg.conv_embed_dim}")
        logging.info(f"  Transformer blocks: {signal_cfg.transformer_layers}")
        logging.info(f"  Temporal layers per block: {signal_cfg.num_temporal_layers}")
        logging.info(f"  Activation: {signal_cfg.activation}")
        logging.info(f"  Norm type: {signal_cfg.norm_type}")
        logging.info(f"  Share channel RoPE: {signal_cfg.share_channel_rope}")
        return encoder

    if architecture == "conv_transformer":
        encoder = BiosignalsEncoder(
            biosignals_cfg=signal_cfg,
            embed_dim=embed_dim,
            output_tokens=output_tokens,
            cast_dtype=cast_dtype
        )
        logging.info(f"Conv-Transformer architecture:")
        logging.info(f"  Conv layers: {signal_cfg.conv_layers}")
        logging.info(f"  Kernel sizes: {signal_cfg.kernel_sizes}")
        logging.info(f"  Strides: {signal_cfg.strides}")
        logging.info(f"  Transformer layers: {signal_cfg.transformer_layers}")
        return encoder

    raise ValueError(f"Unknown architecture: {architecture}. Must be 'conv_transformer' or 'pure_transformer'")
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
def _build_text_decoder_tower_v2(
        embed_dim,
        multimodal_cfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
        decoder_type: str = "cross_attention",
        prefix_len: int = 0,
):
    """Build text decoder tower with support for different decoder types.

    Args:
        embed_dim: Embedding dimension
        multimodal_cfg: MultimodalCfg config (or a dict coercible to one)
        quick_gelu: Whether to use QuickGELU instead of nn.GELU
        cast_dtype: Optional dtype for casting; fp16/bf16 selects LayerNormFp32
        decoder_type: "cross_attention" or "concat"
            - "cross_attention": Uses separate cross-attention layers (default CoCa)
            - "concat": Concatenates image/biosignals and text tokens
        prefix_len: Number of prefix tokens (condition embeddings) prepended to text;
            used to pre-build the prefix-causal attention mask

    Raises:
        ValueError: if decoder_type is not one of the supported values.
    """
    if isinstance(multimodal_cfg, dict):
        multimodal_cfg = MultimodalCfg(**multimodal_cfg)

    # Both decoder variants take exactly the same constructor arguments, so
    # dispatch on the class and build once.
    decoder_classes = {
        "cross_attention": MultimodalTransformer,
        "concat": ConcatMultimodalTransformer,
    }
    if decoder_type not in decoder_classes:
        raise ValueError(f"Unknown decoder_type: {decoder_type}. Must be 'cross_attention' or 'concat'")

    return decoder_classes[decoder_type](
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        mlp_ratio=multimodal_cfg.mlp_ratio,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=QuickGELU if quick_gelu else nn.GELU,
        norm_layer=LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm,
        prefix_len=prefix_len,
    )
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
@dataclass
class BiosignalsCfg:
    """Configuration for biosignals encoder.

    Covers the two encoder variants selected via ``architecture``:
    - "conv_transformer": 1D-conv feature extractor followed by transformer layers
    - "pure_transformer": per-channel patching with dual-axis (channel/temporal) attention

    The conv-stack fields (conv_layers / kernel_sizes / strides) default to None
    and are filled with defaults in ``__post_init__`` only when
    architecture == "conv_transformer"; the pure_transformer variant leaves
    them as None because it does not use a conv stack.
    """
    input_channels: int = 12  # Number of input channels (e.g., 12-lead ECG)
    signal_length: int = 1000  # Length of input time series
    sampling_rate: int = 500  # Sampling rate in Hz

    # Architecture selection
    architecture: str = "conv_transformer"  # "conv_transformer" or "pure_transformer"

    # Architecture parameters for conv_transformer
    # (annotated Optional: these legitimately hold None until __post_init__ runs)
    conv_layers: Optional[List[int]] = None  # Conv layer dimensions
    kernel_sizes: Optional[List[int]] = None  # Kernel sizes for conv layers
    strides: Optional[List[int]] = None  # Strides for conv layers

    # Architecture parameters for pure_transformer
    patch_size: int = 32  # Patch size for pure_transformer
    conv_embed_dim: int = 256  # Conv embedding dimension for pure_transformer
    num_temporal_layers: int = 2  # Number of temporal attention layers per block
    activation: str = "swiglu"  # "swiglu", "gelu", "relu" (for pure_transformer)
    norm_type: str = "rmsnorm"  # "rmsnorm", "layernorm" (for pure_transformer)
    mlp_bias: bool = False  # Whether to use bias in MLP layers (for pure_transformer)
    share_channel_rope: bool = True  # Share channel RoPE across blocks (for pure_transformer)
    decoder_tokens: int = 32  # Number of decoder tokens for dual-axis transformer (pure_transformer)

    # Transformer parameters (shared)
    transformer_layers: int = 6  # Number of transformer layers/blocks
    transformer_width: int = 768  # Transformer width
    transformer_heads: int = 12  # Number of attention heads
    mlp_ratio: float = 4.0  # MLP expansion ratio

    # Pooling and output
    pool_type: str = 'attn'  # 'avg', 'max', 'cls', 'attn'
    dropout: float = 0.1

    def __post_init__(self):
        # Fill conv-stack defaults only for the conv architecture; explicitly
        # provided values are never overwritten.
        if self.architecture == "conv_transformer":
            if self.conv_layers is None:
                # Default conv layers for processing time series
                self.conv_layers = [64, 128, 256, 512]
            if self.kernel_sizes is None:
                # Default kernel sizes
                self.kernel_sizes = [7, 5, 3, 3]
            if self.strides is None:
                # Default strides
                self.strides = [2, 2, 2, 2]
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
class BaseBiosignalsEncoder(nn.Module):
    """
    Base class for biosignals encoders that handles common pooling and projection logic.

    Child classes implement _encode(), returning per-position features of shape
    (batch, seq_len, transformer_width) plus a flag saying whether the final
    position holds a CLS token. This class then pools those features and projects
    the pooled vector to the output embedding dimension.
    """

    def __init__(
        self,
        biosignals_cfg: BiosignalsCfg,
        embed_dim: int,
        output_tokens: bool,
        transformer_width: int,
        cast_dtype: Optional[torch.dtype] = None
    ):
        super().__init__()
        self.biosignals_cfg = biosignals_cfg
        self.embed_dim = embed_dim
        # When True, forward() also returns the per-position tokens for a
        # downstream multimodal decoder.
        self.output_tokens = output_tokens
        self.transformer_width = transformer_width
        self.pool_type = biosignals_cfg.pool_type
        # NOTE(review): cast_dtype is accepted for interface parity with the
        # other towers but is not used anywhere in this base class.

        # Projection to output embedding dimension
        self.proj_to_embed = nn.Linear(transformer_width, embed_dim)

        # Attention pooling if needed
        if self.pool_type == 'attn':
            self.attn_pool = nn.MultiheadAttention(
                transformer_width,
                biosignals_cfg.transformer_heads,
                batch_first=True
            )

    def _pool_features(self, x: torch.Tensor, has_cls_token: bool) -> torch.Tensor:
        """
        Pool features using the configured pooling method.

        Args:
            x: Features of shape (batch_size, seq_len, width)
            has_cls_token: Whether the sequence includes a CLS token at the last position

        Returns:
            pooled: Pooled features of shape (batch_size, width)
        """
        if self.pool_type == 'cls':
            # Use class token (last position)
            # NOTE(review): 'cls' and 'attn' pooling index x[:, -1] unconditionally,
            # so both implicitly assume the encoder appended a CLS token
            # (has_cls_token=True) — confirm subclasses guarantee this.
            pooled = x[:, -1]
        elif self.pool_type == 'avg':
            # Average pooling over sequence (CLS token excluded if present)
            if has_cls_token:
                pooled = x[:, :-1].mean(dim=1)
            else:
                pooled = x.mean(dim=1)
        elif self.pool_type == 'max':
            # Max pooling over sequence (CLS token excluded if present)
            if has_cls_token:
                pooled = x[:, :-1].max(dim=1)[0]
            else:
                pooled = x.max(dim=1)[0]
        elif self.pool_type == 'attn':
            # Attention pooling using cls token as query
            query = x[:, -1:]  # CLS token as query
            # CLS attends to content tokens
            pooled, _ = self.attn_pool(query, x[:, :-1], x[:, :-1])
            pooled = pooled.squeeze(1)
        else:
            raise ValueError(f"Unknown pool_type: {self.pool_type}")

        return pooled

    def _encode(self, biosignals: torch.Tensor) -> Tuple[torch.Tensor, bool]:
        """
        Encode biosignals to features. Must be implemented by child classes.

        Args:
            biosignals: Input biosignals tensor

        Returns:
            features: Encoded features of shape (batch_size, seq_len, transformer_width)
            has_cls_token: Whether the sequence includes a CLS token at the last position
        """
        raise NotImplementedError("Child classes must implement _encode()")

    def forward(self, biosignals: torch.Tensor):
        """
        Forward pass with encoding, pooling, and projection.

        Args:
            biosignals: Input biosignals tensor

        Returns:
            embedding: Global embedding (batch_size, embed_dim)
            tokens_for_decoder: Optional tokens for decoder (batch_size, seq_len, transformer_width);
                only returned when self.output_tokens is True
        """
        # Encode to features
        features, has_cls_token = self._encode(biosignals)

        # Pool features
        pooled = self._pool_features(features, has_cls_token)

        # Project to final embedding dimension
        embedding = self.proj_to_embed(pooled)

        if self.output_tokens:
            # Return tokens for multimodal decoder
            if has_cls_token:
                # Exclude CLS token from tokens for decoder
                tokens_for_decoder = features[:, :-1]
            else:
                tokens_for_decoder = features
            return embedding, tokens_for_decoder
        else:
            return embedding

    def set_grad_checkpointing(self, enable=True):
        # For compatibility with other models; this base class has no
        # checkpointable submodules, so this is a no-op.
        pass
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
class Conv1dBlock(nn.Module):
    """1D convolutional block: conv -> norm -> activation -> dropout(0.1).

    Padding is kernel_size // 2, so with stride=1 the temporal length is
    preserved for odd kernel sizes.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 norm_layer=nn.BatchNorm1d, act_layer=nn.ReLU):
        super().__init__()
        same_padding = kernel_size // 2
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=same_padding,
        )
        self.norm = norm_layer(out_channels)
        self.act = act_layer()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # Run the fixed conv -> norm -> act -> dropout pipeline.
        out = x
        for stage in (self.conv, self.norm, self.act, self.dropout):
            out = stage(out)
        return out
|
| 755 |
+
|
| 756 |
+
|
| 757 |
+
class BiosignalsEncoder(BaseBiosignalsEncoder):
    """
    Biosignals encoder that converts time series data to embeddings.
    Uses a combination of 1D convolutions and transformers.

    Pipeline: Conv1dBlock stack -> linear projection to transformer width ->
    learned positional embeddings -> optional appended CLS token ->
    TransformerBlock stack -> final layer norm. Pooling/projection to
    embed_dim is handled by BaseBiosignalsEncoder.
    """

    def __init__(
        self,
        biosignals_cfg: BiosignalsCfg,
        embed_dim: int = 512,
        output_tokens: bool = False,
        cast_dtype: Optional[torch.dtype] = None
    ):
        # Initialize base class with common pooling/projection logic
        super().__init__(
            biosignals_cfg=biosignals_cfg,
            embed_dim=embed_dim,
            output_tokens=output_tokens,
            transformer_width=biosignals_cfg.transformer_width,
            cast_dtype=cast_dtype
        )

        # Convolutional feature extraction: one Conv1dBlock per configured
        # (out_channels, kernel_size, stride) triple.
        conv_layers = []
        in_channels = biosignals_cfg.input_channels

        for i, (out_channels, kernel_size, stride) in enumerate(
            zip(biosignals_cfg.conv_layers, biosignals_cfg.kernel_sizes, biosignals_cfg.strides)
        ):
            conv_layers.append(
                Conv1dBlock(in_channels, out_channels, kernel_size, stride)
            )
            in_channels = out_channels

        self.conv_layers = nn.Sequential(*conv_layers)

        # Calculate the length after convolutions with padding - we'll use a dummy forward pass
        # to get the exact dimensions (avoids hand-computing padded/strided lengths).
        with torch.no_grad():
            dummy_input = torch.randn(1, biosignals_cfg.input_channels, biosignals_cfg.signal_length)
            dummy_output = self.conv_layers(dummy_input)
            conv_output_length = dummy_output.shape[2]

        self.conv_output_length = conv_output_length
        self.conv_output_dim = biosignals_cfg.conv_layers[-1]

        # Projection to transformer dimension
        self.proj_conv_to_transformer = nn.Linear(
            self.conv_output_dim, biosignals_cfg.transformer_width
        )

        # Positional embeddings for sequence positions (excluding CLS token)
        # CLS token gets no positional embedding as it represents global context
        self.pos_embed = nn.Parameter(
            torch.randn(1, conv_output_length, biosignals_cfg.transformer_width)
        )

        # Add a class token for global representation (only used for 'cls' and 'attn' pooling)
        self.cls_token = nn.Parameter(
            torch.randn(1, 1, biosignals_cfg.transformer_width)
        )

        # Transformer layers (LayerNormFp32 when casting to fp16/bf16)
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        act_layer = QuickGELU

        self.transformer_layers = nn.ModuleList([
            TransformerBlock(
                biosignals_cfg.transformer_width,
                biosignals_cfg.transformer_heads,
                biosignals_cfg.mlp_ratio,
                act_layer=act_layer,
                norm_layer=norm_layer,
                dropout=biosignals_cfg.dropout
            )
            for _ in range(biosignals_cfg.transformer_layers)
        ])

        # Final layer norm
        self.ln_final = norm_layer(biosignals_cfg.transformer_width)

    def _encode(self, biosignals: torch.Tensor) -> Tuple[torch.Tensor, bool]:
        """
        Encode biosignals to features before pooling.

        Args:
            biosignals: Tensor of shape (batch_size, channels, signal_length)
        Returns:
            features: Encoded features of shape (batch_size, seq_len, transformer_width)
            has_cls_token: Whether the sequence includes a CLS token at the last position
        """
        batch_size = biosignals.shape[0]

        # Apply convolutional layers
        x = self.conv_layers(biosignals)  # (batch_size, conv_dim, conv_length)

        # Transpose to (batch_size, conv_length, conv_dim)
        x = x.transpose(1, 2)

        # Project to transformer dimension
        x = self.proj_conv_to_transformer(x)  # (batch_size, conv_length, transformer_width)

        # Add positional embeddings
        # NOTE(review): pos_embed is sized for the conv output of
        # biosignals_cfg.signal_length, so inputs of a different length will
        # fail to broadcast here — confirm inputs are fixed-length upstream.
        x = x + self.pos_embed

        # Add class token only if needed for pooling
        # For consistency with causal text encoder, append CLS token (not prepend)
        if self.pool_type in ['cls', 'attn']:
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)
            x = torch.cat([x, cls_tokens], dim=1)  # (batch_size, conv_length + 1, transformer_width)
            has_cls_token = True
        else:
            has_cls_token = False

        # Apply transformer layers
        for layer in self.transformer_layers:
            x = layer(x)

        # Apply final layer norm
        x = self.ln_final(x)

        return x, has_cls_token
|
| 879 |
+
|
| 880 |
+
|
| 881 |
+
class TransformerBlock(nn.Module):
    """Post-norm transformer block.

    Self-attention and an MLP, each wrapped in a residual connection that is
    normalized *after* the add (post-LN, original-Transformer ordering).
    """

    def __init__(
        self,
        width: int,
        heads: int,
        mlp_ratio: float = 4.0,
        act_layer=QuickGELU,
        norm_layer=LayerNorm,
        dropout: float = 0.1
    ):
        super().__init__()
        hidden_dim = int(width * mlp_ratio)
        self.attention = nn.MultiheadAttention(width, heads, dropout=dropout, batch_first=True)
        self.ln_1 = norm_layer(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, hidden_dim),
            act_layer(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, width),
            nn.Dropout(dropout)
        )
        self.ln_2 = norm_layer(width)

    def forward(self, x):
        # Residual self-attention followed by post-norm.
        attn_out, _ = self.attention(x, x, x)
        x = self.ln_1(x + attn_out)
        # Residual MLP followed by post-norm.
        x = self.ln_2(x + self.mlp(x))
        return x
|
| 917 |
+
|
| 918 |
+
|
| 919 |
+
class AttnPooler(nn.Module):
    """
    CoCa-style attentional pooler.

    A single multi-head attention layer whose queries are n_query learned
    tokens and whose keys/values are the encoder sequence. Typical uses:
    - n_query = 1 => one global embedding for the contrastive loss
    - n_query = N => a compressed token set for decoder cross-attention
    (CoCa uses task-specific attentional pooling: 1 query for the contrastive
    head and a larger query set for the generative head.)
    """
    def __init__(self, dim: int, num_heads: int, n_query: int):
        super().__init__()
        self.n_query = n_query
        # Learned query tokens, small-scale init (0.02 std).
        self.query_tokens = nn.Parameter(torch.randn(1, n_query, dim) * 0.02)
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            batch_first=True
        )

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x_seq: (B, L, D) encoder sequence.
        Returns:
            (B, n_query, D) pooled tokens.
        """
        batch = x_seq.shape[0]
        # Broadcast the learned queries across the batch, then let each query
        # attend over the full sequence.
        queries = self.query_tokens.expand(batch, -1, -1)
        pooled, _ = self.attn(queries, x_seq, x_seq)
        return pooled
|
| 949 |
+
|
| 950 |
+
|
| 951 |
+
class PureTransformerBiosignalsEncoder(BaseBiosignalsEncoder):
    """
    Pure Transformer encoder for biosignals with channel+temporal attention.

    Uses CoCa-style task-specific attentional pooling:
    - contrastive_pooler (n_query=1)     -> one global token (contrastive / CLS)
    - decoder_pooler (n_query=decoder_tokens) -> small set of summary tokens for the text decoder

    Processing steps:
    1. Patch each channel independently (ChannelPatching).
    2. Alternate channel-attn and temporal-attn in DualTransformerBlocks
       (factorized attention).
    3. Keep (B, C, T, D) internally — attention runs cheaply along the channel
       or time axis separately.
    4. Flatten to (B, C*T, D) only at the end, for pooling.
    5. Run the two attentional poolers over the flattened sequence.
    6. Append the 1-query pooled token to the END of the feature sequence, so
       BaseBiosignalsEncoder can keep using pool_type='cls' or 'attn' unchanged
       (both read the last position).
    7. When output_tokens=True, the base class hands the multi-query pooled
       tokens (features[:, :-1]) — not the full ~C*T sequence — to the decoder.

    This mirrors CoCa's "task-specific attentional pooling", letting one
    encoder serve both contrastive alignment and caption-style generation.
    """

    def __init__(
        self,
        biosignals_cfg: BiosignalsCfg,
        embed_dim: int = 512,
        output_tokens: bool = False,
        cast_dtype: Optional[torch.dtype] = None
    ):
        super().__init__(
            biosignals_cfg=biosignals_cfg,
            embed_dim=embed_dim,
            output_tokens=output_tokens,
            transformer_width=biosignals_cfg.transformer_width,
            cast_dtype=cast_dtype
        )

        # --- Sanity checks for RoPE dimensions ---
        assert biosignals_cfg.transformer_width % biosignals_cfg.transformer_heads == 0, (
            f"transformer_width ({biosignals_cfg.transformer_width}) must be divisible by "
            f"transformer_heads ({biosignals_cfg.transformer_heads})"
        )
        head_dim = biosignals_cfg.transformer_width // biosignals_cfg.transformer_heads
        assert head_dim % 2 == 0, (
            f"head_dim ({head_dim}) must be even for RoPE. "
            f"Got transformer_width={biosignals_cfg.transformer_width}, "
            f"transformer_heads={biosignals_cfg.transformer_heads}"
        )

        # 1. Channel patching (Conv1d tokenizer per channel)
        self.patching = ChannelPatching(
            patch_size=biosignals_cfg.patch_size,
            conv_embed_dim=biosignals_cfg.conv_embed_dim,
            num_channels=biosignals_cfg.input_channels
        )

        # number of temporal patches per channel
        # NOTE(review): integer division assumes signal_length is a multiple of
        # patch_size; trailing samples would be dropped — confirm inputs are padded.
        self.num_patches = biosignals_cfg.signal_length // biosignals_cfg.patch_size

        # 2. Project patch embeddings to transformer_width
        self.embed_projection = nn.Linear(
            biosignals_cfg.conv_embed_dim,
            biosignals_cfg.transformer_width
        )

        # 2a. Channel ID embedding (categorical channel identity)
        self.channel_id_embed = nn.Embedding(
            num_embeddings=biosignals_cfg.input_channels,
            embedding_dim=biosignals_cfg.transformer_width,
        )

        # 3. Shared learnable RoPE for channel attention (optional)
        if biosignals_cfg.share_channel_rope:
            shared_head_dim = biosignals_cfg.transformer_width // biosignals_cfg.transformer_heads
            self.shared_channel_rope = RotaryEmbedding(
                dim=shared_head_dim,
                theta=10000,
                learned_freq=True  # learnable for channel axis
            )
        else:
            self.shared_channel_rope = None

        # 4. Dual-axis Transformer blocks (channel attention + temporal attention)
        self.transformer_blocks = nn.ModuleList([
            DualTransformerBlock(
                embed_dim=biosignals_cfg.transformer_width,
                num_heads=biosignals_cfg.transformer_heads,
                num_temporal_layers=biosignals_cfg.num_temporal_layers,
                dropout=biosignals_cfg.dropout,
                mlp_ratio=biosignals_cfg.mlp_ratio,
                num_channels=biosignals_cfg.input_channels,
                activation=biosignals_cfg.activation,
                norm_type=biosignals_cfg.norm_type,
                mlp_bias=biosignals_cfg.mlp_bias,
                shared_channel_rope=self.shared_channel_rope if biosignals_cfg.share_channel_rope else None
            ) for _ in range(biosignals_cfg.transformer_layers)
        ])

        # 5. Final norm (RMSNorm per config, else LayerNorm; LayerNormFp32 for fp16/bf16)
        norm_layer = (
            LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        )
        if biosignals_cfg.norm_type == "rmsnorm":
            self.ln_final = RMSNorm(biosignals_cfg.transformer_width)
        else:
            self.ln_final = norm_layer(biosignals_cfg.transformer_width)

        # 6. CoCa-style attentional poolers
        # - contrastive_pooler: n_query = 1 for global CLS token (contrastive head)
        # - decoder_pooler: n_query = decoder_tokens (e.g. 32) for compressed memory
        # getattr with default keeps older configs (without decoder_tokens) working.
        n_decoder_tokens = getattr(biosignals_cfg, "decoder_tokens", 32)

        self.contrastive_pooler = AttnPooler(
            dim=biosignals_cfg.transformer_width,
            num_heads=biosignals_cfg.transformer_heads,
            n_query=1
        )

        self.decoder_pooler = AttnPooler(
            dim=biosignals_cfg.transformer_width,
            num_heads=biosignals_cfg.transformer_heads,
            n_query=n_decoder_tokens
        )

    def _encode(self, biosignals: torch.Tensor):
        """
        Encode biosignals into [decoder tokens..., global token].

        Args:
            biosignals: Input signal tensor consumed by ChannelPatching
                (presumably (batch, channels, signal_length) — confirm against
                ChannelPatching's expected layout).

        Returns:
            features: (B, N_dec + 1, D)
                first N_dec tokens = pooled decoder tokens
                last token = global pooled token (contrastive CLS)
            has_cls_token: True (the global token always occupies the last slot)
        """
        B = biosignals.shape[0]
        device = biosignals.device

        # 1. Patch per channel -> (B, C, T, conv_dim)
        x = self.patching(biosignals)

        # 2. Project to model dim -> (B, C, T, D)
        x = self.embed_projection(x)

        # 2a. Add channel ID embedding (one learned bias vector per channel,
        #     broadcast across time)
        _, C, T, D = x.shape
        channel_ids = torch.arange(C, device=device)  # (C,)
        channel_bias = self.channel_id_embed(channel_ids)  # (C, D)
        channel_bias = channel_bias.view(1, C, 1, D).expand(B, C, T, D)
        x = x + channel_bias

        # 3. Temporal RoPE positions
        pos_ids = torch.arange(self.num_patches, device=device)  # (T,)

        # 4. Dual-axis transformer blocks (channel-attn + temporal-attn)
        for block in self.transformer_blocks:
            x = block(x, temporal_position_ids=pos_ids)  # stays (B, C, T, D)

        # 5. Final norm
        x = self.ln_final(x)  # (B, C, T, D)

        # 6. Flatten channels×time to a sequence for pooling (not for decoder!)
        x_seq = x.reshape(B, C * T, D)  # (B, L, D) with L = C*T

        # 7. Task-specific attentional pooling (CoCa-style)
        #    contrastive_pooler: n_query=1  -> global_token (B, 1, D)
        #    decoder_pooler:     n_query=Nd -> dec_tokens   (B, Nd, D)
        global_token = self.contrastive_pooler(x_seq)  # (B, 1, D)
        dec_tokens = self.decoder_pooler(x_seq)  # (B, N_dec, D)

        # 8. Build final feature sequence:
        #    [decoder tokens..., global token] so that:
        #    - features[:, :-1] = dec_tokens   (for decoder cross-attn)
        #    - features[:, -1]  = global_token (for contrastive / CLS pooling)
        features = torch.cat([dec_tokens, global_token], dim=1)  # (B, N_dec+1, D)

        has_cls_token = True
        return features, has_cls_token
|
| 1133 |
+
|
| 1134 |
+
|
| 1135 |
+
class SignalReconstructionDecoder(nn.Module):
    """
    Lightweight transformer decoder for signal reconstruction.

    A small stack of self-attention layers (nn.TransformerEncoder — no
    cross-attention is needed here) followed by a mean-pool and a two-layer
    MLP that emits the whole flattened signal, reshaped to
    (batch, output_channels, output_length).
    """

    def __init__(
        self,
        input_dim: int = 768,
        num_layers: int = 2,
        num_heads: int = 4,  # kept small (vs. 8) for efficiency
        output_channels: int = 10,
        output_length: int = 1920,
    ):
        super().__init__()

        # Pre-norm encoder layers with a 2x (rather than 4x) feedforward width
        # to keep the reconstruction head lightweight.
        layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            dim_feedforward=input_dim * 2,  # 1536 for input_dim=768
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers)

        # Two-layer MLP projecting the pooled feature to the flattened signal;
        # halved intermediate dimension for efficiency.
        half_dim = input_dim // 2
        self.to_signal = nn.Sequential(
            nn.Linear(input_dim, half_dim),
            nn.ReLU(),
            nn.Linear(half_dim, output_channels * output_length),
        )

        self.output_channels = output_channels
        self.output_length = output_length

    def forward(self, encoder_features):
        """
        Args:
            encoder_features: (B, seq_len, input_dim) - unprojected encoder features
        Returns:
            reconstructed: (B, output_channels, output_length)
        """
        batch = encoder_features.shape[0]

        # Self-attention over the token sequence, then mean-pool to one vector.
        pooled = self.transformer(encoder_features).mean(dim=1)

        # Project to the flattened signal and reshape to channel-major layout.
        flat_signal = self.to_signal(pooled)
        return flat_signal.reshape(batch, self.output_channels, self.output_length)
|
| 1196 |
+
|
| 1197 |
+
|
| 1198 |
+
class BiosignalsCoCa(nn.Module):
    """
    CoCa model adapted for biosignals-text contrastive learning.
    Replaces the vision tower with a biosignals encoder.

    Supports two decoder types:
    - "cross_attention": Separate cross-attention between text and biosignals (default CoCa)
    - "concat": Concatenate biosignals and text tokens with prefix-causal masking
    """

    def __init__(
        self,
        embed_dim,
        multimodal_cfg: MultimodalCfg,
        text_cfg: CLIPTextCfg,
        biosignals_cfg: BiosignalsCfg,
        quick_gelu: bool = False,
        init_logit_scale: float = np.log(1 / 0.07),  # CLIP-style temperature init (1/0.07)
        init_logit_bias: Optional[float] = None,
        nonscalar_logit_scale: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
        pad_id: int = 0,
        decoder_type: str = "cross_attention",
        num_caption_channels: int = 12,  # Number of channel/modality embeddings (22 for channels, 4 for modalities)
        prefix_len: int = 0,
        use_signal_decoder: bool = False,  # NEW: Enable signal reconstruction
    ):
        super().__init__()
        # Configs may arrive as plain dicts (e.g. parsed from a model-config JSON);
        # coerce them into their dataclass forms before use.
        multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
        biosignals_cfg = BiosignalsCfg(**biosignals_cfg) if isinstance(biosignals_cfg, dict) else biosignals_cfg

        self.decoder_type = decoder_type
        self.num_channels = num_caption_channels
        self.use_signal_decoder = use_signal_decoder

        # Debug logging for channel configuration
        import logging
        logging.info(f"BiosignalsCoCa initialized with num_caption_channels={num_caption_channels}, prefix_len={prefix_len}")
        if use_signal_decoder:
            logging.info(f"Signal reconstruction decoder enabled")

        # Text encoder (may be an HF model or a standard TextTransformer,
        # depending on text_cfg).
        self.text = _build_text_tower(
            embed_dim=embed_dim,
            text_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        # HF models expose their (possibly resized) vocab on the tower itself;
        # plain transformers take it from the config.
        vocab_size = (
            self.text.vocab_size  # for hf models
            if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
            else text_cfg.vocab_size
        )

        # Replace visual tower with biosignals tower
        self.biosignals = _build_signal_tower(
            embed_dim=embed_dim,
            signal_cfg=biosignals_cfg,
            output_tokens=True,  # Need tokens for multimodal decoder
            cast_dtype=cast_dtype,
        )

        # Multimodal caption decoder; decoder_type selects cross-attention vs
        # concatenation behavior inside the tower builder.
        self.text_decoder = _build_text_decoder_tower_v2(
            vocab_size,
            multimodal_cfg=multimodal_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
            decoder_type=decoder_type,
            prefix_len=prefix_len,
        )

        # Contrastive temperature (and optional bias, SigLIP-style); a
        # nonscalar shape keeps a length-1 dim for frameworks that need it.
        lshape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
        else:
            self.logit_bias = None
        self.pad_id = pad_id

        self.context_length = multimodal_cfg.context_length

        # Learnable channel/modality embeddings
        # num_caption_channels will be 23 for individual channel mode or 5 for modality mode
        # Dimension should match the decoder width (multimodal_cfg.width for text decoder input)
        self.channel_embeddings = nn.Parameter(
            torch.randn(num_caption_channels, multimodal_cfg.width) * 0.02
        )

        # Learnable padding embedding for -1 positions
        # This learns to be "neutral" or ignored during training (similar to [PAD] tokens)
        self.padding_embedding = nn.Parameter(
            torch.randn(multimodal_cfg.width) * 0.02
        )

        self.decoder_width = multimodal_cfg.width

        # Optional signal reconstruction decoder
        if use_signal_decoder:
            self.signal_decoder = SignalReconstructionDecoder(
                input_dim=biosignals_cfg.transformer_width,
                num_layers=2,  # Lightweight: 2 transformer layers
                num_heads=biosignals_cfg.transformer_heads,
                output_channels=biosignals_cfg.input_channels,
                output_length=biosignals_cfg.signal_length,
            )
|
| 1304 |
+
|
| 1305 |
+
@torch.jit.ignore
|
| 1306 |
+
def set_grad_checkpointing(self, enable: bool = True):
|
| 1307 |
+
self.biosignals.set_grad_checkpointing(enable)
|
| 1308 |
+
self.text.set_grad_checkpointing(enable)
|
| 1309 |
+
self.text_decoder.set_grad_checkpointing(enable)
|
| 1310 |
+
|
| 1311 |
+
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
| 1312 |
+
"""Lock the text encoder, optionally leaving the last N layers unlocked.
|
| 1313 |
+
|
| 1314 |
+
Args:
|
| 1315 |
+
unlocked_layers: Number of layers to leave unlocked (from the end)
|
| 1316 |
+
freeze_layer_norm: Whether to freeze LayerNorm parameters in locked layers
|
| 1317 |
+
"""
|
| 1318 |
+
if hasattr(self.text, 'lock'):
|
| 1319 |
+
# For HFTextEncoder (Pythia, etc.)
|
| 1320 |
+
self.text.lock(unlocked_layers, freeze_layer_norm)
|
| 1321 |
+
|
| 1322 |
+
# IMPORTANT: Unfreeze newly added token embeddings (e.g., <pad>, <coca_cls>)
|
| 1323 |
+
# These were randomly initialized and need to be trained
|
| 1324 |
+
if hasattr(self.text, 'original_vocab_size'):
|
| 1325 |
+
import logging
|
| 1326 |
+
embedding_module = self.text.transformer.get_input_embeddings()
|
| 1327 |
+
original_size = self.text.original_vocab_size
|
| 1328 |
+
current_size = embedding_module.weight.shape[0]
|
| 1329 |
+
|
| 1330 |
+
if current_size > original_size:
|
| 1331 |
+
# Enable gradients for the embedding layer
|
| 1332 |
+
embedding_module.weight.requires_grad = True
|
| 1333 |
+
|
| 1334 |
+
# Store metadata for optimizer configuration (zero weight decay)
|
| 1335 |
+
self.text._new_token_start_idx = original_size
|
| 1336 |
+
|
| 1337 |
+
# Get actual embedding size (may be padded for Tensor Cores)
|
| 1338 |
+
actual_embedding_size = embedding_module.weight.shape[0]
|
| 1339 |
+
new_vocab_size = self.text.vocab_size # Actual number of tokens (not padded)
|
| 1340 |
+
|
| 1341 |
+
# Register parameter-level hook to mask frozen token gradients
|
| 1342 |
+
# IMPORTANT: This is registered BEFORE DDP wrapping to ensure it persists
|
| 1343 |
+
def _zero_grad_frozen_tokens(grad):
|
| 1344 |
+
"""Zero out gradients for old (frozen) tokens and padding, keep only new tokens."""
|
| 1345 |
+
if grad is not None:
|
| 1346 |
+
# Zero out pretrained tokens [0:original_size]
|
| 1347 |
+
grad[:original_size] = 0
|
| 1348 |
+
# Zero out padding tokens [new_vocab_size:actual_embedding_size]
|
| 1349 |
+
if actual_embedding_size > new_vocab_size:
|
| 1350 |
+
grad[new_vocab_size:] = 0
|
| 1351 |
+
return grad
|
| 1352 |
+
|
| 1353 |
+
embedding_module.weight.register_hook(_zero_grad_frozen_tokens)
|
| 1354 |
+
|
| 1355 |
+
num_new_tokens = new_vocab_size - original_size
|
| 1356 |
+
num_padding_tokens = actual_embedding_size - new_vocab_size
|
| 1357 |
+
logging.info(f"Embedding layer configuration:")
|
| 1358 |
+
logging.info(f" Trainable new tokens: {num_new_tokens} (indices {original_size}:{new_vocab_size})")
|
| 1359 |
+
logging.info(f" Frozen pretrained tokens: {original_size} (indices 0:{original_size})")
|
| 1360 |
+
if num_padding_tokens > 0:
|
| 1361 |
+
logging.info(f" Frozen padding tokens: {num_padding_tokens} (indices {new_vocab_size}:{actual_embedding_size})")
|
| 1362 |
+
logging.info(f" Total embedding size: {actual_embedding_size}")
|
| 1363 |
+
logging.info(f"Registered gradient masking hook before DDP wrapping")
|
| 1364 |
+
logging.info(f"NOTE: Optimizer uses weight_decay=0 for embedding layer")
|
| 1365 |
+
else:
|
| 1366 |
+
# For standard TextTransformer
|
| 1367 |
+
assert False, "BiosignalsCoCa does not support locking standard TextTransformer"
|
| 1368 |
+
from .transformer import lock_text_tower
|
| 1369 |
+
lock_text_tower(self, unlocked_layers)
|
| 1370 |
+
|
| 1371 |
+
def _encode_biosignals(self, biosignals, normalize: bool = True):
|
| 1372 |
+
biosignals_latent, tokens_embs = self.biosignals(biosignals)
|
| 1373 |
+
biosignals_latent = F.normalize(biosignals_latent, dim=-1) if normalize else biosignals_latent
|
| 1374 |
+
return biosignals_latent, tokens_embs
|
| 1375 |
+
|
| 1376 |
+
def _encode_text(self, text, normalize: bool = True):
|
| 1377 |
+
text_latent, token_emb = self.text(text)
|
| 1378 |
+
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
|
| 1379 |
+
return text_latent, token_emb
|
| 1380 |
+
|
| 1381 |
+
def encode_image(self, biosignals, normalize: bool = True):
|
| 1382 |
+
biosignals_latent, _ = self._encode_biosignals(biosignals, normalize=normalize)
|
| 1383 |
+
return biosignals_latent
|
| 1384 |
+
|
| 1385 |
+
def encode_text(self, text, normalize: bool = True):
|
| 1386 |
+
text_latent, _ = self._encode_text(text, normalize=normalize)
|
| 1387 |
+
return text_latent
|
| 1388 |
+
|
| 1389 |
+
def _get_channel_condition_embs(self, channel_indices: torch.Tensor) -> torch.Tensor:
|
| 1390 |
+
"""Convert channel/modality indices to embeddings with learnable padding.
|
| 1391 |
+
|
| 1392 |
+
Args:
|
| 1393 |
+
channel_indices: (batch_size, prefix_len) tensor of indices
|
| 1394 |
+
- Individual mode: indices into 23 channel embeddings (22 channels + 1 stage_event)
|
| 1395 |
+
- Modality mode: indices into 5 modality embeddings (4 modalities + 1 stage_event)
|
| 1396 |
+
- Padded with -1 for variable length (uses learnable padding_embedding for -1)
|
| 1397 |
+
|
| 1398 |
+
Returns:
|
| 1399 |
+
condition_embs: (batch_size, prefix_len, decoder_width)
|
| 1400 |
+
Embeddings for all positions. -1 positions use learnable padding_embedding
|
| 1401 |
+
that learns to be neutral/ignored during training.
|
| 1402 |
+
"""
|
| 1403 |
+
batch_size, prefix_len = channel_indices.shape
|
| 1404 |
+
|
| 1405 |
+
# Create output tensor
|
| 1406 |
+
condition_embs = torch.zeros(batch_size, prefix_len, self.decoder_width,
|
| 1407 |
+
dtype=self.channel_embeddings.dtype,
|
| 1408 |
+
device=self.channel_embeddings.device)
|
| 1409 |
+
|
| 1410 |
+
# Create mask for valid (non-padding) indices
|
| 1411 |
+
valid_mask = channel_indices >= 0 # (batch_size, prefix_len)
|
| 1412 |
+
padding_mask = channel_indices == -1 # (batch_size, prefix_len)
|
| 1413 |
+
|
| 1414 |
+
# Gather channel embeddings for valid indices
|
| 1415 |
+
# Clamp to 0 for safe indexing (will be overwritten by padding where needed)
|
| 1416 |
+
indices_safe = channel_indices.clamp(min=0)
|
| 1417 |
+
|
| 1418 |
+
# Expand embeddings for batching
|
| 1419 |
+
expanded_embeddings = self.channel_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
|
| 1420 |
+
|
| 1421 |
+
# Gather embeddings
|
| 1422 |
+
indices_expanded = indices_safe.unsqueeze(-1).expand(-1, -1, self.decoder_width)
|
| 1423 |
+
gathered_embs = torch.gather(expanded_embeddings, 1, indices_expanded)
|
| 1424 |
+
|
| 1425 |
+
# Fill in valid positions with gathered embeddings
|
| 1426 |
+
condition_embs[valid_mask] = gathered_embs[valid_mask]
|
| 1427 |
+
|
| 1428 |
+
# Fill in padding positions with learnable padding embedding
|
| 1429 |
+
if padding_mask.any():
|
| 1430 |
+
# Broadcast padding_embedding to all padding positions
|
| 1431 |
+
condition_embs[padding_mask] = self.padding_embedding
|
| 1432 |
+
|
| 1433 |
+
return condition_embs
|
| 1434 |
+
|
| 1435 |
+
    def forward(
        self,
        biosignals,
        text: Optional[torch.Tensor] = None,
        biosignals_latent: Optional[torch.Tensor] = None,
        biosignals_embs: Optional[torch.Tensor] = None,
        channel_indices: Optional[torch.Tensor] = None,
        output_labels: bool = True,
    ):
        """Forward pass for BiosignalsCoCa model.

        Args:
            biosignals: Input biosignals tensor
            text: Optional text token ids
            biosignals_latent: Optional pre-computed biosignals latent features
            biosignals_embs: Optional pre-computed biosignals token embeddings
            channel_indices: Optional (batch_size, num_selected_channels) tensor of channel indices
                Used to select channel-specific condition embeddings. If provided, overrides condition_embs.
            output_labels: Whether to output labels for loss computation

        Returns:
            Dict with image/text features, decoder logits and logit scale;
            plus labels, logit_bias and signal-reconstruction outputs when
            the corresponding options are active. When ``text`` is None only
            the biosignal features/embeddings are returned.
        """
        # Reuse pre-computed biosignal features when the caller supplies them
        # (e.g. during generation, where they are computed once per sequence).
        if biosignals_latent is None or biosignals_embs is None:
            biosignals_latent, biosignals_embs = self._encode_biosignals(biosignals)

        # Signal-only path: no captioning, just contrastive features.
        if text is None:
            return {"image_features": biosignals_latent, "image_embs": biosignals_embs}

        text_latent, token_embs = self._encode_text(text)

        # FIXME this isn't an ideal solution, would like to improve -RW
        labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None
        if output_labels:
            # align text_embs and thus logits with labels for teacher-forcing caption loss
            token_embs = token_embs[:, :-1]

        # Convert channel indices to condition embeddings if provided
        if channel_indices is not None:
            condition_embs = self._get_channel_condition_embs(channel_indices)
        else:
            condition_embs = None

        logits = self.text_decoder(biosignals_embs, token_embs, condition_embs=condition_embs)
        out_dict = {
            "image_features": biosignals_latent,
            "text_features": text_latent,
            "logits": logits,
            "logit_scale": self.logit_scale.exp()
        }
        if labels is not None:
            out_dict["labels"] = labels
        if self.logit_bias is not None:
            out_dict["logit_bias"] = self.logit_bias

        # Optional signal reconstruction
        if self.use_signal_decoder:
            reconstructed_signal = self.signal_decoder(biosignals_embs)
            out_dict["reconstructed_signal"] = reconstructed_signal
            # Original signal is passed through so the loss can compare in-place.
            out_dict["original_signal"] = biosignals

        return out_dict
|
| 1496 |
+
|
| 1497 |
+
    def generate(
        self,
        biosignals,
        text=None,
        seq_len=30,
        max_seq_len=256,
        temperature=1.,
        generation_type="beam_search",
        top_p=0.1,
        top_k=1,
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        repetition_penalty=1.0,
        fixed_output_length=False,
        condition_embs=None,
        channel_indices=None,
    ):
        """Autoregressively generate caption token ids for a batch of biosignals.

        Supports diverse beam search ("beam_search") and sampling with
        top-p / top-k warping. Returns a LongTensor of generated token ids;
        with ``fixed_output_length`` the output is padded/forced to ``seq_len``.
        """
        # taking many ideas and components from HuggingFace GenerationMixin
        # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
        assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
        assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
        device = biosignals.device

        # Note: condition_embs parameter is for backward compatibility
        # We pass channel_indices directly to forward(), which handles the conversion internally

        with torch.no_grad():
            sot_token_id = _token_to_tensor(sot_token_id, device=device)
            eos_token_id = _token_to_tensor(eos_token_id, device=device)
            # NOTE(review): self-assignment is a no-op, kept as-is;
            # pad_token_id is intentionally NOT converted to a tensor here.
            pad_token_id = pad_token_id
            logit_processor = LogitsProcessorList(
                [
                    MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                    RepetitionPenaltyLogitsProcessor(repetition_penalty),
                ]
            )

            if stopping_criteria is None:
                stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
            stopping_criteria = StoppingCriteriaList(stopping_criteria)

            if generation_type == "beam_search":
                output = self._generate_beamsearch(
                    biosignals_inputs=biosignals,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    sot_token_id=sot_token_id,
                    num_beams=num_beams,
                    num_beam_groups=num_beam_groups,
                    min_seq_len=min_seq_len,
                    stopping_criteria=stopping_criteria,
                    logit_processor=logit_processor,
                    channel_indices=channel_indices,
                )
                # Right-pad beam-search output with pad tokens when a fixed
                # length was requested but generation stopped early.
                if fixed_output_length and output.shape[1] < seq_len:
                    pad_len = seq_len - output.shape[1]
                    return torch.cat((
                        output,
                        torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id
                    ),
                        dim=1
                    )
                return output

            elif generation_type == "top_p":
                logit_warper = GENERATION_TYPES[generation_type](top_p)
            elif generation_type == "top_k":
                logit_warper = GENERATION_TYPES[generation_type](top_k)
            else:
                raise ValueError(
                    f"generation_type has to be one of "
                    f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
                )

            # Encode once; reused for every decoding step below.
            biosignals_latent, biosignals_embs = self._encode_biosignals(biosignals)

            # Seed the sequence with the start-of-text token when no prompt given.
            if text is None:
                text = torch.ones((biosignals.shape[0], 1), device=device, dtype=torch.long) * sot_token_id

            was_training = self.training
            num_dims = len(text.shape)

            if num_dims == 1:
                text = text[None, :]

            self.eval()
            out = text

            while True:
                # Only the last max_seq_len tokens fit in the decoder context.
                x = out[:, -max_seq_len:]
                cur_len = x.shape[1]
                logits = self(
                    biosignals,
                    x,
                    biosignals_latent=biosignals_latent,
                    biosignals_embs=biosignals_embs,
                    channel_indices=channel_indices,
                    output_labels=False,
                )["logits"][:, -1]
                # Sequences that already emitted EOS/PAD are finished; they
                # keep receiving pad tokens below.
                mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

                if mask.all():
                    if not fixed_output_length:
                        break
                else:
                    logits = logits[~mask, :]
                    filtered_logits = logit_processor(x[~mask, :], logits)
                    filtered_logits = logit_warper(x[~mask, :], filtered_logits)
                    probs = F.softmax(filtered_logits / temperature, dim=-1)

                    # Force EOS on the final step so every sequence terminates.
                    if (cur_len + 1 == seq_len):
                        sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
                    else:
                        sample[~mask, :] = torch.multinomial(probs, 1)

                out = torch.cat((out, sample), dim=-1)

                cur_len += 1

                if all(stopping_criteria(out, None)):
                    break

            # Restore the original (unbatched) shape for 1-D prompts.
            if num_dims == 1:
                out = out.squeeze(0)

            self.train(was_training)
            return out
|
| 1630 |
+
|
| 1631 |
+
    def _generate_beamsearch(
        self,
        biosignals_inputs,
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        logit_processor=None,
        logit_warper=None,
        channel_indices=None,
    ):
        """Diverse (grouped) beam search over caption tokens.

        Mirrors HuggingFace's group beam search: beams are split into
        ``num_beam_groups`` groups that are advanced one after another within
        each decoding step. Returns the finalized best sequences.
        """
        device = biosignals_inputs.device
        batch_size = biosignals_inputs.shape[0]
        # Each sample is duplicated once per beam; all beams share the encoding.
        biosignals_inputs = torch.repeat_interleave(biosignals_inputs, num_beams, dim=0)
        biosignals_latent, biosignals_embs = self._encode_biosignals(biosignals_inputs)

        # Repeat channel indices for beam search if provided
        # forward() will convert them to condition embeddings internally
        if channel_indices is not None:
            channel_indices = torch.repeat_interleave(channel_indices, num_beams, dim=0)

        # Every beam starts from the start-of-text token.
        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
        input_ids = input_ids * sot_token_id
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=device,
            num_beam_groups=num_beam_groups,
        )
        # instantiate logits processors
        logits_processor = (
            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
            if logit_processor is None
            else logit_processor
        )

        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
        batch_beam_size, cur_len = input_ids.shape
        beam_indices = None

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
        # the same group don't produce same tokens everytime.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while True:

            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, biosignals_inputs=biosignals_inputs)
            outputs = self(
                model_inputs['biosignals'],
                model_inputs['text'],
                biosignals_latent=biosignals_latent,
                biosignals_embs=biosignals_embs,
                channel_indices=channel_indices,
                output_labels=False,
            )

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs['logits'][batch_group_indices, -1, :]
                vocab_size = next_token_logits.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                # Accumulate per-beam scores; broadcast over the vocab dim.
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                # Keep 2x candidates so finished (EOS) beams can be replaced.
                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )

                # Flattened index -> (beam within group, token id).
                next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
                next_tokens = next_tokens % vocab_size

                # stateless
                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=process_beam_indices,
                    group_index=beam_group_idx,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
                )

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            # increase cur_len
            cur_len = cur_len + 1
            if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
                break

        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=final_beam_indices,
        )
        return sequence_outputs['sequences']
|
| 1786 |
+
|
| 1787 |
+
|
| 1788 |
+
def prepare_inputs_for_generation(input_ids, biosignals_inputs, past=None, **kwargs):
    """Assemble the per-step input dict for autoregressive decoding.

    Args:
        input_ids: (batch, seq_len) token ids generated so far.
        biosignals_inputs: biosignal inputs forwarded unchanged to the model.
        past: optional decoding cache; when truthy, only the newest token
            is re-fed to the decoder.
        **kwargs: may carry ``attention_mask`` and/or ``position_ids``.

    Returns:
        Dict with keys "text", "biosignals", "past_key_values",
        "position_ids" and "attention_mask".
    """
    if past:
        # With a cache, earlier positions are already encoded.
        input_ids = input_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
    # FIX: previously an `else: position_ids = None` branch discarded a
    # caller-supplied position_ids whenever attention_mask was absent (or
    # when both were given). Explicit position_ids now pass through intact.
    return {
        "text": input_ids,
        "biosignals": biosignals_inputs,
        "past_key_values": past,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }
|
src/open_clip/bpe_simple_vocab_16e6.txt.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
| 3 |
+
size 1356917
|
src/open_clip/coca_model.py
ADDED
|
@@ -0,0 +1,586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
import numpy as np
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
|
| 9 |
+
from .transformer import (
|
| 10 |
+
LayerNormFp32,
|
| 11 |
+
LayerNorm,
|
| 12 |
+
QuickGELU,
|
| 13 |
+
MultimodalTransformer,
|
| 14 |
+
)
|
| 15 |
+
from .model import CLIPTextCfg, _build_text_tower
|
| 16 |
+
|
| 17 |
+
try:
    from transformers import (
        BeamSearchScorer,
        LogitsProcessorList,
        TopPLogitsWarper,
        TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
        MinLengthLogitsProcessor,
        MaxLengthCriteria,
        StopStringCriteria,
        EosTokenCriteria,
        StoppingCriteriaList
    )

    # Maps generation_type strings to their logits-warper classes.
    # "beam_search" has a dedicated code path, so a sentinel string is stored.
    GENERATION_TYPES = {
        "top_k": TopKLogitsWarper,
        "top_p": TopPLogitsWarper,
        "beam_search": "beam_search"
    }
    _has_transformers = True
except ImportError as e:
    # transformers is optional: without it, generation is unavailable
    # (_has_transformers is asserted in generate()) but the module imports.
    GENERATION_TYPES = {
        "top_k": None,
        "top_p": None,
        "beam_search": "beam_search"
    }
    _has_transformers = False
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass
class MultimodalCfg(CLIPTextCfg):
    """Configuration for the CoCa multimodal text decoder.

    Inherits all text-tower fields (width, layers, context_length, ...)
    from CLIPTextCfg and adds decoder-specific knobs.
    """
    mlp_ratio: int = 4  # MLP hidden-dim expansion ratio in decoder blocks
    dim_head: int = 64  # per-attention-head dimension
    heads: int = 8  # number of attention heads
    n_queries: int = 256  # presumably the learned-query count for attentional pooling — TODO confirm against transformer.py
    attn_pooler_heads: int = 8  # attention heads used by the attn pooler
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _build_text_decoder_tower(
        embed_dim,
        multimodal_cfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Construct the multimodal caption decoder.

    Args:
        embed_dim: output dimension of the decoder projection.
        multimodal_cfg: MultimodalCfg instance or an equivalent dict.
        quick_gelu: use QuickGELU activation instead of nn.GELU.
        cast_dtype: when fp16/bf16, select the fp32-internal LayerNorm.

    Returns:
        A MultimodalTransformer configured from `multimodal_cfg`.
    """
    if isinstance(multimodal_cfg, dict):
        multimodal_cfg = MultimodalCfg(**multimodal_cfg)

    act_layer = QuickGELU if quick_gelu else nn.GELU
    # Low-precision models keep LayerNorm statistics in fp32 for stability.
    if cast_dtype in (torch.float16, torch.bfloat16):
        norm_layer = LayerNormFp32
    else:
        norm_layer = LayerNorm

    return MultimodalTransformer(
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor:
|
| 82 |
+
if not isinstance(token_id, torch.Tensor):
|
| 83 |
+
if isinstance(token_id, int):
|
| 84 |
+
token_id = [token_id]
|
| 85 |
+
token_id = torch.tensor(token_id, device=device)
|
| 86 |
+
return token_id
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class CoCa(nn.Module):
    """Contrastive Captioner (CoCa): CLIP-style dual encoder plus a
    multimodal text decoder producing autoregressive caption logits.

    The contrastive heads return normalized image/text latents scaled by
    `logit_scale`; the decoder cross-attends over image token embeddings.
    """

    def __init__(
            self,
            embed_dim,
            multimodal_cfg: MultimodalCfg,
            text_cfg: CLIPTextCfg,
            vision_cfg=None,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            nonscalar_logit_scale: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            pad_id: int = 0,
    ):
        super().__init__()
        # Accept plain dicts (e.g. parsed from JSON configs) for both cfgs.
        multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg

        self.text = _build_text_tower(
            embed_dim=embed_dim,
            text_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        # HF-backed text towers expose their own vocab size; otherwise fall
        # back to the vocab size declared in the text config.
        vocab_size = (
            self.text.vocab_size
            if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
            else text_cfg.vocab_size
        )

        if vision_cfg is not None:
            # Deferred import avoids a circular dependency with .model.
            from .model import CLIPVisionCfg, _build_vision_tower
            vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
            self.visual = _build_vision_tower(
                embed_dim=embed_dim,
                vision_cfg=vision_cfg,
                quick_gelu=quick_gelu,
                cast_dtype=cast_dtype,
            )
        else:
            self.visual = None

        # NOTE: vocab_size is passed as the decoder's output dim, so the
        # decoder projects straight to per-token logits over the vocabulary.
        self.text_decoder = _build_text_decoder_tower(
            vocab_size,
            multimodal_cfg=multimodal_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        # Scalar by default; shape [1] when nonscalar_logit_scale is set.
        lshape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
        else:
            self.logit_bias = None
        self.pad_id = pad_id

        self.context_length = multimodal_cfg.context_length

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        # Toggle gradient checkpointing on all towers.
        # NOTE(review): self.visual may be None when vision_cfg was not
        # given — this would raise AttributeError; confirm callers always
        # construct the model with a vision tower before using this.
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)
        self.text_decoder.set_grad_checkpointing(enable)

    def _encode_image(self, images, normalize: bool = True):
        """Return (pooled image latent, per-token image embeddings)."""
        image_latent, tokens_embs = self.visual(images)
        image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
        return image_latent, tokens_embs

    def _encode_text(self, text, normalize: bool = True):
        """Return (pooled text latent, per-token text embeddings)."""
        text_latent, token_emb = self.text(text)
        text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
        return text_latent, token_emb

    def encode_image(self, images, normalize: bool = True):
        """Pooled (optionally L2-normalized) image features only."""
        image_latent, _ = self._encode_image(images, normalize=normalize)
        return image_latent

    def encode_text(self, text, normalize: bool = True):
        """Pooled (optionally L2-normalized) text features only."""
        text_latent, _ = self._encode_text(text, normalize=normalize)
        return text_latent

    def forward_intermediates(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
            image_indices: Optional[Union[int, List[int]]] = None,
            text_indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
            normalize: bool = True,
            normalize_intermediates: bool = False,
            intermediates_only: bool = False,
            image_output_fmt: str = 'NCHW',
            image_output_extra_tokens: bool = False,
            text_output_fmt: str = 'NLC',
            text_output_extra_tokens: bool = False,
            output_logits: bool = False,
            output_logit_scale_bias: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            image: Input image tensor
            text: Input text tensor
            image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
            text_indices: Take last n blocks if int, all if None, select matching indices if sequence
            stop_early: Stop iterating over blocks when last desired intermediate hit
            normalize: L2 Normalize final image and text features (if present)
            normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)
            intermediates_only: Only return intermediate features, do not return final features
            image_output_fmt: Shape of intermediate image feature outputs
            image_output_extra_tokens: Return both prefix and spatial intermediate tokens
            text_output_fmt: Shape of intermediate text feature outputs
            text_output_extra_tokens: Return both prefix and spatial intermediate tokens
            output_logits: Include logits in output
            output_logit_scale_bias: Include the logit scale bias in the output
        Returns:
            Dict merging the image and text towers' intermediate outputs,
            plus logit scale/bias entries when requested.
        """
        output = {}
        if intermediates_only:
            # intermediates only disables final feature normalization, and include logits
            normalize = False
            output_logits = False
        if output_logits:
            assert False, 'FIXME, needs implementing'

        if image is not None:
            image_output = self.visual.forward_intermediates(
                image,
                indices=image_indices,
                stop_early=stop_early,
                normalize_intermediates=normalize_intermediates,
                intermediates_only=intermediates_only,
                output_fmt=image_output_fmt,
                output_extra_tokens=image_output_extra_tokens,
            )
            if normalize and "image_features" in image_output:
                image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
            output.update(image_output)

        if text is not None:
            text_output = self.text.forward_intermediates(
                text,
                indices=text_indices,
                stop_early=stop_early,
                normalize_intermediates=normalize_intermediates,
                intermediates_only=intermediates_only,
                output_fmt=text_output_fmt,
                output_extra_tokens=text_output_extra_tokens,
            )
            if normalize and "text_features" in text_output:
                text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1)
            output.update(text_output)

        # FIXME text decoder
        logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None
        if output_logit_scale_bias:
            output["logit_scale"] = logit_scale_exp
            if self.logit_bias is not None:
                output['logit_bias'] = self.logit_bias

        return output

    def forward(
            self,
            image,
            text: Optional[torch.Tensor] = None,
            image_latent: Optional[torch.Tensor] = None,
            image_embs: Optional[torch.Tensor] = None,
            output_labels: bool = True,
    ):
        """Joint forward: contrastive features plus caption logits.

        When `text` is None only image features are returned. Precomputed
        `image_latent`/`image_embs` may be supplied to skip image encoding
        (used by `generate`).
        """
        if image_latent is None or image_embs is None:
            image_latent, image_embs = self._encode_image(image)

        if text is None:
            return {"image_features": image_latent, "image_embs": image_embs}

        text_latent, token_embs = self._encode_text(text)

        # FIXME this isn't an ideal solution, would like to improve -RW
        labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None
        if output_labels:
            # align text_embs and thus logits with labels for teacher-forcing caption loss
            token_embs = token_embs[:, :-1]

        logits = self.text_decoder(image_embs, token_embs)
        out_dict = {
            "image_features": image_latent,
            "text_features": text_latent,
            "logits": logits,
            "logit_scale": self.logit_scale.exp()
        }
        if labels is not None:
            out_dict["labels"] = labels
        if self.logit_bias is not None:
            out_dict["logit_bias"] = self.logit_bias
        return out_dict

    def generate(
        self,
        image,
        text=None,
        seq_len=30,
        max_seq_len=77,
        temperature=1.,
        generation_type="beam_search",
        top_p=0.1,  # keep tokens in the 1 - top_p quantile
        top_k=1,  # keeps the top_k most probable tokens
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        repetition_penalty=1.0,
        fixed_output_length=False  # if True output.shape == (batch_size, seq_len)
    ):
        """Autoregressively generate caption token ids for `image`.

        Supports "beam_search" (diverse group beam search), "top_p" and
        "top_k" sampling. Token id defaults (49406/49407) are the CLIP BPE
        SOT/EOS ids.
        """
        # taking many ideas and components from HuggingFace GenerationMixin
        # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
        assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
        assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
        device = image.device

        with torch.no_grad():
            sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device)
            eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device)
            pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
            logit_processor = LogitsProcessorList(
                [
                    MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                    RepetitionPenaltyLogitsProcessor(repetition_penalty),
                ]
            )

            if stopping_criteria is None:
                stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
            stopping_criteria = StoppingCriteriaList(stopping_criteria)

            if generation_type == "beam_search":
                output = self._generate_beamsearch(
                    image_inputs=image,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    sot_token_id=sot_token_id,
                    num_beams=num_beams,
                    num_beam_groups=num_beam_groups,
                    min_seq_len=min_seq_len,
                    stopping_criteria=stopping_criteria,
                    logit_processor=logit_processor,
                )
                # Right-pad the beam search result up to seq_len if requested.
                if fixed_output_length and output.shape[1] < seq_len:
                    pad_len = seq_len - output.shape[1]
                    return torch.cat((
                            output,
                            torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id
                        ),
                        dim=1
                    )
                return output

            elif generation_type == "top_p":
                logit_warper = GENERATION_TYPES[generation_type](top_p)
            elif generation_type == "top_k":
                logit_warper = GENERATION_TYPES[generation_type](top_k)
            else:
                raise ValueError(
                    f"generation_type has to be one of "
                    f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
                )

            # Encode the image once; reused at every decoding step.
            image_latent, image_embs = self._encode_image(image)

            if text is None:
                text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id

            was_training = self.training
            num_dims = len(text.shape)

            if num_dims == 1:
                text = text[None, :]

            self.eval()
            out = text

            # Sampling loop: append one token per iteration.
            while True:
                x = out[:, -max_seq_len:]
                cur_len = x.shape[1]
                logits = self(
                    image,
                    x,
                    image_latent=image_latent,
                    image_embs=image_embs,
                    output_labels=False,
                )["logits"][:, -1]
                # Rows already finished (ended with EOS or pad) keep emitting pad.
                mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

                if mask.all():
                    if not fixed_output_length:
                        break
                else:
                    logits = logits[~mask, :]
                    filtered_logits = logit_processor(x[~mask, :], logits)
                    filtered_logits = logit_warper(x[~mask, :], filtered_logits)
                    probs = F.softmax(filtered_logits / temperature, dim=-1)

                    # Force EOS on the final step so every sequence terminates.
                    if (cur_len + 1 == seq_len):
                        sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
                    else:
                        sample[~mask, :] = torch.multinomial(probs, 1)

                out = torch.cat((out, sample), dim=-1)

                cur_len += 1

                if all(stopping_criteria(out, None)):
                    break

            if num_dims == 1:
                out = out.squeeze(0)

            self.train(was_training)
            return out

    def _generate_beamsearch(
            self,
            image_inputs,
            pad_token_id=None,
            eos_token_id=None,
            sot_token_id=None,
            num_beams=6,
            num_beam_groups=3,
            min_seq_len=5,
            stopping_criteria=None,
            logit_processor=None,
            logit_warper=None,
    ):
        """Diverse group beam search decoding (adapted from HF GenerationMixin).

        NOTE(review): `logit_warper` is accepted but never used — confirm
        whether warping was intended for this path.
        """
        device = image_inputs.device
        batch_size = image_inputs.shape[0]
        # Replicate each image num_beams times so beams decode in parallel.
        image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
        image_latent, image_embs = self._encode_image(image_inputs)

        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
        input_ids = input_ids * sot_token_id
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=device,
            num_beam_groups=num_beam_groups,
        )
        # instantiate logits processors
        logits_processor = (
            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
            if logit_processor is None
            else logit_processor
        )

        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
        batch_beam_size, cur_len = input_ids.shape
        beam_indices = None

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
        # the same group don't produce same tokens everytime.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while True:

            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
            outputs = self(
                model_inputs['images'],
                model_inputs['text'],
                image_latent=image_latent,
                image_embs=image_embs,
                output_labels=False,
            )

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs['logits'][batch_group_indices, -1, :]
                vocab_size = next_token_logits.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                # Keep 2*group_size candidates so finished (EOS) hypotheses
                # still leave enough live beams.
                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )

                next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
                next_tokens = next_tokens % vocab_size

                # stateless
                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=process_beam_indices,
                    group_index=beam_group_idx,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
                )

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            # increase cur_len
            cur_len = cur_len + 1
            if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
                break

        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=final_beam_indices,
        )
        return sequence_outputs['sequences']
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
    """Assemble model inputs for one decoding step (HF GenerationMixin style).

    With a truthy `past` (cached keys/values), only the newest token is fed.
    Position ids are derived from `attention_mask` when the mask is given and
    no explicit `position_ids` kwarg is supplied; otherwise they are None.
    """
    if past:
        # cache hit: the model only needs the most recent token
        input_ids = input_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)

    position_ids = None
    if attention_mask is not None and kwargs.get("position_ids", None) is None:
        # create position_ids on the fly for batch generation; padded
        # positions (mask == 0) are pinned to 1
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)

    return {
        "text": input_ids,
        "images": image_inputs,
        "past_key_values": past,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }
|
src/open_clip/factory.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
from copy import deepcopy
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from .biosignals_coca_model import BiosignalsCoCa
|
| 10 |
+
from .model import get_cast_dtype, convert_weights_to_lp
|
| 11 |
+
from .tokenizer import SimpleTokenizer, DEFAULT_CONTEXT_LENGTH
|
| 12 |
+
|
| 13 |
+
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
|
| 14 |
+
_MODEL_CONFIGS = {}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _rescan_model_configs():
    """Scan the model_configs directories and register every valid config.

    A JSON file is registered under its stem when it contains all of
    "embed_dim", "biosignals_cfg" and "text_cfg".
    """
    global _MODEL_CONFIGS
    json_paths = []
    for base in _MODEL_CONFIG_PATHS:
        if base.is_dir():
            json_paths.extend(base.glob("*.json"))
    required_keys = ("embed_dim", "biosignals_cfg", "text_cfg")
    for path in json_paths:
        with open(path, "r") as fh:
            model_cfg = json.load(fh)
        if all(key in model_cfg for key in required_keys):
            _MODEL_CONFIGS[path.stem] = model_cfg
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
_rescan_model_configs()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_model_config(model_name: str):
    """Return a deep copy of the registered config, or None if unknown."""
    cfg = _MODEL_CONFIGS.get(model_name)
    return deepcopy(cfg)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def create_model(
        model_name: str,
        precision: str = "fp32",
        device: Union[str, torch.device] = "cpu",
        **model_kwargs,
) -> BiosignalsCoCa:
    """Build a BiosignalsCoCa from a named JSON config.

    Args:
        model_name: key into the scanned model-config registry.
        precision: one of fp32, fp16, bf16, pure_fp16, pure_bf16.
        device: target device (string or torch.device).
        **model_kwargs: overrides merged into the loaded config.

    Raises:
        RuntimeError: when `model_name` has no registered config.
    """
    if isinstance(device, str):
        device = torch.device(device)

    cfg = get_model_config(model_name)
    if cfg is None:
        raise RuntimeError(f"Model config for '{model_name}' not found. Available: {list(_MODEL_CONFIGS.keys())}")

    cfg.pop("custom_text", None)
    cfg.update(model_kwargs)

    cast_dtype = get_cast_dtype(precision)
    model = BiosignalsCoCa(**cfg, cast_dtype=cast_dtype)

    if precision in ("fp16", "bf16"):
        # mixed precision: move first, then convert eligible weights in place
        lp_dtype = torch.float16 if "fp16" in precision else torch.bfloat16
        model.to(device=device)
        convert_weights_to_lp(model, dtype=lp_dtype)
    elif precision in ("pure_fp16", "pure_bf16"):
        # pure low precision: cast the whole model
        lp_dtype = torch.float16 if "fp16" in precision else torch.bfloat16
        model.to(device=device, dtype=lp_dtype)
    else:
        model.to(device=device)

    model.output_dict = True
    return model
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def load_checkpoint(model, checkpoint_path: str, device="cpu"):
    """Load a checkpoint into `model`, tolerating DDP "module." prefixes.

    Args:
        model: target nn.Module.
        checkpoint_path: path to a torch.save'd checkpoint; either a raw
            state dict or a dict with a "state_dict" entry.
        device: map_location for torch.load.

    Returns:
        The IncompatibleKeys result of load_state_dict(strict=False).
    """
    # weights_only=False: checkpoint may contain non-tensor metadata.
    # NOTE: only load checkpoints from trusted sources (pickle deserialization).
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = checkpoint.get("state_dict", checkpoint)
    # Guard the empty-dict case: next(iter({})) would raise StopIteration.
    if state_dict and next(iter(state_dict)).startswith("module."):
        # strip DistributedDataParallel's "module." prefix
        state_dict = {k[len("module."):]: v for k, v in state_dict.items()}
    incompatible = model.load_state_dict(state_dict, strict=False)
    return incompatible
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_tokenizer(model_name: str = "", context_length: Optional[int] = None, **kwargs):
    """Create a SimpleTokenizer whose context length follows the model config.

    An explicit `context_length` wins; otherwise the model's text_cfg value
    is used, falling back to DEFAULT_CONTEXT_LENGTH.
    """
    model_cfg = get_model_config(model_name) or {}
    text_cfg = model_cfg.get("text_cfg", {})
    resolved_length = context_length
    if resolved_length is None:
        resolved_length = text_cfg.get("context_length", DEFAULT_CONTEXT_LENGTH)
    return SimpleTokenizer(context_length=resolved_length, **kwargs)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_input_dtype(precision: str):
    """Map a precision string to the torch dtype model inputs should use.

    Returns None for full-precision modes (inputs stay as-is).
    """
    precision_to_dtype = {
        "bf16": torch.bfloat16,
        "pure_bf16": torch.bfloat16,
        "fp16": torch.float16,
        "pure_fp16": torch.float16,
    }
    return precision_to_dtype.get(precision)
|
src/open_clip/model.py
ADDED
|
@@ -0,0 +1,943 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" CLIP Model
|
| 2 |
+
|
| 3 |
+
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
| 4 |
+
"""
|
| 5 |
+
import copy
|
| 6 |
+
import logging
|
| 7 |
+
import math
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from torch import nn
|
| 15 |
+
from torch.utils.checkpoint import checkpoint
|
| 16 |
+
from functools import partial
|
| 17 |
+
|
| 18 |
+
from .transformer import (
|
| 19 |
+
LayerNormFp32,
|
| 20 |
+
LayerNorm,
|
| 21 |
+
QuickGELU,
|
| 22 |
+
Attention,
|
| 23 |
+
VisionTransformer,
|
| 24 |
+
TextTransformer,
|
| 25 |
+
text_global_pool,
|
| 26 |
+
lock_text_tower,
|
| 27 |
+
to_2tuple,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
class CLIPVisionCfg:
    """Configuration for the CLIP vision (image) tower.

    `_build_vision_tower` selects one of three backbone families from this
    config: a timm model (when ``timm_model_name`` is set), a ModifiedResNet
    (when ``layers`` is a tuple/list), or a VisionTransformer otherwise.
    """
    layers: Union[Tuple[int, int, int, int], int] = 12  # int -> ViT depth; 4-tuple -> ModifiedResNet stage layers
    width: int = 768  # transformer embedding width
    head_width: int = 64  # per-head width; number of heads is derived as width // head_width
    mlp_ratio: float = 4.0  # MLP hidden dim = width * mlp_ratio
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224

    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer (overrides pool_type)
    attn_pooler_queries: int = 256  # n_queries for attentional pooler
    attn_pooler_heads: int = 8  # n heads for attentional_pooling
    no_ln_pre: bool = False  # disable pre transformer LayerNorm
    pos_embed_type: str = 'learnable'
    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling
    pool_type: str = 'tok'
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None  # extra kwargs forwarded to the activation layer via functools.partial
    norm_kwargs: Optional[dict] = None  # extra kwargs forwarded to the norm layer via functools.partial

    # Custom attention block settings
    block_type: Optional[str] = None  # attention block type ('default', 'custom'), auto-selects 'custom' if any below features enabled
    qk_norm: bool = False  # apply layer norm to q and k in attention
    scaled_cosine_attn: bool = False  # use scaled cosine attention
    scale_heads: bool = False  # learnable head-specific scale applied to attention logits
    scale_attn_inner: bool = False  # apply layer norm on attention context, before output projection
    scale_attn: bool = False  # apply layer norm after full attention block
    scale_fc: bool = False  # apply layer norm in MLP block

    # timm backbone options; a valid timm_model_name overrides layers, width, patch_size
    timm_model_name: Optional[str] = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias final projection
    timm_drop: float = 0.  # head dropout
    timm_drop_path: Optional[float] = None  # backbone stochastic depth
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
class CLIPTextCfg:
    """Configuration for the CLIP text tower.

    `_build_text_tower` builds an HFTextEncoder when ``hf_model_name`` is
    set, otherwise a TextTransformer from the fields below.
    """
    context_length: int = 77
    vocab_size: int = 49408
    hf_tokenizer_name: Optional[str] = None
    tokenizer_mode: Optional[str] = None
    tokenizer_kwargs: Optional[dict] = None

    width: int = 512  # transformer embedding width
    heads: int = 8
    layers: int = 12
    mlp_ratio: float = 4.0
    ls_init_value: Optional[float] = None  # layer scale initial value
    embed_cls: bool = False  # prepend a learned CLS embedding to the token sequence
    pad_id: int = 0
    eos_id: int = 2  # only used for when pool_type == 'eos', must match tokenizer eos
    no_causal_mask: bool = False  # disable causal masking
    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling
    pool_type: str = 'argmax'
    proj_bias: bool = False
    proj_type: str = 'linear'  # control final text projection, 'none' forces no projection
    output_tokens: bool = False
    # FIX: annotations corrected from bare `dict` to `Optional[dict]` — the
    # defaults are None, matching CLIPVisionCfg's equivalent fields.
    act_kwargs: Optional[dict] = None  # extra kwargs forwarded to the activation layer via functools.partial
    norm_kwargs: Optional[dict] = None  # extra kwargs forwarded to the norm layer via functools.partial

    # Custom attention block settings
    block_type: Optional[str] = None  # attention block type ('default', 'custom'), auto-selects 'custom' if any custom features enabled
    qk_norm: bool = False  # apply layer norm to q and k in attention
    scaled_cosine_attn: bool = False  # use scaled cosine attention
    scale_heads: bool = False  # learnable head-specific scale applied to attention logits
    scale_attn_inner: bool = False  # apply layer norm on attention context, before output projection
    scale_attn: bool = False  # apply layer norm after full attention block
    scale_fc: bool = False  # apply layer norm in MLP block

    # HuggingFace specific text tower config
    hf_model_name: Optional[str] = None
    hf_model_pretrained: bool = True
    hf_proj_type: str = 'mlp'
    hf_pooler_type: str = 'mean_pooler'  # attentional pooling for HF models
    special_tokens_to_add: Optional[dict] = None  # special tokens to add to tokenizer (e.g., for Pythia)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def get_cast_dtype(precision: str):
    """Return the dtype model weights are cast to for a given precision name.

    'bf16' -> torch.bfloat16, 'fp16' -> torch.float16, anything else -> None
    (i.e. no weight casting).
    """
    return {
        'bf16': torch.bfloat16,
        'fp16': torch.float16,
    }.get(precision)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def get_input_dtype(precision: str):
    """Return the dtype model inputs are cast to for a given precision name.

    Both mixed ('bf16'/'fp16') and pure ('pure_bf16'/'pure_fp16') modes map
    to the corresponding low-precision dtype; anything else returns None.
    """
    if precision in ('bf16', 'pure_bf16'):
        return torch.bfloat16
    if precision in ('fp16', 'pure_fp16'):
        return torch.float16
    return None
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None
):
    """Build the image encoder from a CLIPVisionCfg.

    Dispatches to one of three backbones:
      * TimmModel       - when ``vision_cfg.timm_model_name`` is set
      * ModifiedResNet  - when ``vision_cfg.layers`` is a tuple/list
      * VisionTransformer - otherwise

    Args:
        embed_dim: output embedding dimension of the tower.
        vision_cfg: vision tower config (a dict is converted to CLIPVisionCfg).
        quick_gelu: use QuickGELU activation instead of nn.GELU (non-timm only).
        cast_dtype: if fp16/bf16, use LayerNormFp32 (norm computed in fp32).

    Returns:
        The constructed visual encoder module.
    """
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
    # memory efficient in recent PyTorch releases (>= 1.10).
    # NOTE: timm models always use native GELU regardless of quick_gelu flag.
    act_layer = QuickGELU if quick_gelu else nn.GELU

    if vision_cfg.timm_model_name:
        from .timm_model import TimmModel
        visual = TimmModel(
            vision_cfg.timm_model_name,
            pretrained=vision_cfg.timm_model_pretrained,
            pool=vision_cfg.timm_pool,
            proj=vision_cfg.timm_proj,
            proj_bias=vision_cfg.timm_proj_bias,
            drop=vision_cfg.timm_drop,
            drop_path=vision_cfg.timm_drop_path,
            patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
            embed_dim=embed_dim,
            image_size=vision_cfg.image_size,
        )
    elif isinstance(vision_cfg.layers, (tuple, list)):
        from .modified_resnet import ModifiedResNet
        # ResNet final attention-pool operates on width*32 channels (stage 4 output)
        vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
        visual = ModifiedResNet(
            layers=vision_cfg.layers,
            output_dim=embed_dim,
            heads=vision_heads,
            image_size=vision_cfg.image_size,
            width=vision_cfg.width,
        )
    else:
        vision_heads = vision_cfg.width // vision_cfg.head_width
        # fp16/bf16 weights -> keep LayerNorm math in fp32 for stability
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        if vision_cfg.norm_kwargs:
            norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
        if vision_cfg.act_kwargs is not None:
            act_layer = partial(act_layer, **vision_cfg.act_kwargs)

        visual = VisionTransformer(
            image_size=vision_cfg.image_size,
            patch_size=vision_cfg.patch_size,
            width=vision_cfg.width,
            layers=vision_cfg.layers,
            heads=vision_heads,
            mlp_ratio=vision_cfg.mlp_ratio,
            ls_init_value=vision_cfg.ls_init_value,
            patch_dropout=vision_cfg.patch_dropout,
            attentional_pool=vision_cfg.attentional_pool,
            attn_pooler_queries=vision_cfg.attn_pooler_queries,
            attn_pooler_heads=vision_cfg.attn_pooler_heads,
            pos_embed_type=vision_cfg.pos_embed_type,
            no_ln_pre=vision_cfg.no_ln_pre,
            final_ln_after_pool=vision_cfg.final_ln_after_pool,
            pool_type=vision_cfg.pool_type,
            output_tokens=vision_cfg.output_tokens,
            output_dim=embed_dim,
            act_layer=act_layer,
            norm_layer=norm_layer,
            block_type=vision_cfg.block_type,
            qk_norm=vision_cfg.qk_norm,
            scaled_cosine_attn=vision_cfg.scaled_cosine_attn,
            scale_heads=vision_cfg.scale_heads,
            scale_attn_inner=vision_cfg.scale_attn_inner,
            scale_attn=vision_cfg.scale_attn,
            scale_fc=vision_cfg.scale_fc,
        )

    return visual
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def _build_text_tower(
        embed_dim: int,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Build the text encoder from a CLIPTextCfg.

    Builds an HFTextEncoder when ``text_cfg.hf_model_name`` is set (optionally
    extending its tokenizer/embeddings with special tokens), otherwise a
    TextTransformer.

    Args:
        embed_dim: output embedding dimension of the tower.
        text_cfg: text tower config (a dict is converted to CLIPTextCfg).
        quick_gelu: use QuickGELU activation instead of nn.GELU (non-HF only).
        cast_dtype: if fp16/bf16, use LayerNormFp32 (norm computed in fp32).

    Returns:
        The constructed text encoder module. For HF towers, the attributes
        ``coca_cls_token_id`` and ``pad_token_id`` are always set (None when
        the corresponding special token was not configured).
    """
    if isinstance(text_cfg, dict):
        text_cfg = CLIPTextCfg(**text_cfg)

    if text_cfg.hf_model_name:
        from .hf_model import HFTextEncoder
        text = HFTextEncoder(
            text_cfg.hf_model_name,
            output_dim=embed_dim,
            proj_type=text_cfg.hf_proj_type,
            pooler_type=text_cfg.hf_pooler_type,
            pretrained=text_cfg.hf_model_pretrained,
            output_tokens=text_cfg.output_tokens,
        )

        # FIX: default the special-token attributes up front. Previously they
        # were only assigned inside the conditionals below, so reading them
        # (e.g. in the logging block, or downstream in forward) raised
        # AttributeError whenever a config omitted one of the tokens.
        text.coca_cls_token_id = None
        text.pad_token_id = None

        # Handle special tokens if configured (e.g., for Pythia)
        special_tokens_cfg = getattr(text_cfg, 'special_tokens_to_add', None)
        if special_tokens_cfg:
            from transformers import AutoTokenizer
            # NOTE: module-level `import logging` already in scope; the
            # redundant function-local import was removed.

            # Load tokenizer from local cache only (ensures consistency with get_tokenizer())
            # get_tokenizer() is called first and downloads/caches, we just reuse that exact version
            tokenizer = AutoTokenizer.from_pretrained(
                text_cfg.hf_model_name,
                local_files_only=True
            )

            # Store original vocab size before adding new tokens
            # This is needed to unfreeze new token embeddings after locking
            original_vocab_size = len(tokenizer)
            text.original_vocab_size = original_vocab_size

            tokenizer.add_special_tokens(special_tokens_cfg)

            # Resize model embeddings to accommodate new tokens
            # pad_to_multiple_of=64 ensures optimal Tensor Core performance for embedding lookups
            new_vocab_size = len(tokenizer)
            text.transformer.resize_token_embeddings(new_vocab_size, pad_to_multiple_of=64)

            # Store token IDs for use in forward pass
            if 'additional_special_tokens' in special_tokens_cfg:
                for token in special_tokens_cfg['additional_special_tokens']:
                    if token == '<coca_cls>':
                        text.coca_cls_token_id = tokenizer.convert_tokens_to_ids(token)

            if 'pad_token' in special_tokens_cfg:
                text.config.pad_token_id = tokenizer.pad_token_id
                text.pad_token_id = tokenizer.pad_token_id

            text.config.vocab_size = new_vocab_size
            text.vocab_size = new_vocab_size

            logging.info(f"Added special tokens to {text_cfg.hf_model_name}:")
            logging.info(f"  Original vocab size: {original_vocab_size}")
            logging.info(f"  New vocab size: {new_vocab_size}")
            logging.info(f"  Added {new_vocab_size - original_vocab_size} new tokens")
            if text.coca_cls_token_id is not None:
                logging.info(f"  CoCa CLS token ID: {text.coca_cls_token_id}")
            if text.pad_token_id is not None:
                logging.info(f"  Pad token ID: {text.pad_token_id}")
    else:
        act_layer = QuickGELU if quick_gelu else nn.GELU
        # fp16/bf16 weights -> keep LayerNorm math in fp32 for stability
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        if text_cfg.norm_kwargs:
            norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)
        if text_cfg.act_kwargs is not None:
            act_layer = partial(act_layer, **text_cfg.act_kwargs)

        text = TextTransformer(
            context_length=text_cfg.context_length,
            vocab_size=text_cfg.vocab_size,
            width=text_cfg.width,
            heads=text_cfg.heads,
            layers=text_cfg.layers,
            mlp_ratio=text_cfg.mlp_ratio,
            ls_init_value=text_cfg.ls_init_value,
            output_dim=embed_dim,
            embed_cls=text_cfg.embed_cls,
            no_causal_mask=text_cfg.no_causal_mask,
            pad_id=text_cfg.pad_id,
            eos_id=text_cfg.eos_id,
            pool_type=text_cfg.pool_type,
            proj_type=text_cfg.proj_type,
            proj_bias=text_cfg.proj_bias,
            output_tokens=text_cfg.output_tokens,
            act_layer=act_layer,
            norm_layer=norm_layer,
            block_type=text_cfg.block_type,
            qk_norm=text_cfg.qk_norm,
            scaled_cosine_attn=text_cfg.scaled_cosine_attn,
            scale_heads=text_cfg.scale_heads,
            scale_attn_inner=text_cfg.scale_attn_inner,
            scale_attn=text_cfg.scale_attn,
            scale_fc=text_cfg.scale_fc,
        )
    return text
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class CLIP(nn.Module):
    """CLIP model with the text tower flattened onto the model itself.

    The text transformer's submodules (token_embedding, positional_embedding,
    ln_final, text_projection, attn_mask) are attached directly as attributes,
    matching the original OpenAI CLIP checkpoint layout.
    """
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            nonscalar_logit_scale: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        """Create the vision and text towers and contrastive logit parameters.

        Args:
            embed_dim: shared image/text embedding dimension.
            vision_cfg: vision tower config (CLIPVisionCfg or dict).
            text_cfg: text tower config (CLIPTextCfg or dict).
            quick_gelu: use QuickGELU activation (OpenAI pretrained weights).
            init_logit_scale: initial value of the log temperature.
            init_logit_bias: if set, adds a learnable logit bias (SigLIP style).
            nonscalar_logit_scale: store logit scale/bias as shape-[1] tensors.
            cast_dtype: low-precision dtype used to pick fp32-safe LayerNorm.
            output_dict: if True, forward() returns a dict instead of a tuple.
        """
        super().__init__()
        self.output_dict = output_dict

        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

        # Build the text tower, then unpack its pieces onto self so the
        # state_dict matches the classic (OpenAI) CLIP layout.
        text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.transformer = text.transformer
        self.context_length = text.context_length
        self.vocab_size = text.vocab_size
        self.token_embedding = text.token_embedding
        self.positional_embedding = text.positional_embedding
        self.ln_final = text.ln_final
        self.text_projection = text.text_projection
        self.text_pool_type = text.pool_type
        self.text_eos_id = text.eos_id
        # persistent=False: causal mask is reconstructible, keep it out of checkpoints
        self.register_buffer('attn_mask', text.attn_mask, persistent=False)

        lshape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
        else:
            self.logit_bias = None

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Freeze the text tower, optionally leaving the last N layers trainable."""
        assert freeze_layer_norm, 'Unfreezing LayerNorm is not supported. LayerNorm treated like other weights.'
        # delegates to the module-level lock_text_tower imported from .transformer
        lock_text_tower(self, unlocked_layers)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        # Enable/disable activation checkpointing on both towers.
        self.visual.set_grad_checkpointing(enable)
        self.transformer.grad_checkpointing = enable

    @torch.jit.ignore
    def no_weight_decay(self):
        # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
        no_wd = {'positional_embedding'}
        if hasattr(self.visual, 'no_weight_decay'):
            for n in self.visual.no_weight_decay():
                no_wd.add('visual.' + n)
        return no_wd

    def encode_image(self, image, normalize: bool = False):
        """Encode images to embeddings; L2-normalize when ``normalize`` is True."""
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        """Encode token ids to text embeddings; L2-normalize when ``normalize`` is True."""
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.to(cast_dtype)
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        # pool per self.text_pool_type (e.g. argmax over token ids, or eos id match)
        x = text_global_pool(x, text, self.text_pool_type, eos_token_id=getattr(self, "text_eos_id", None))
        if self.text_projection is not None:
            # projection may be an nn.Linear module or a bare parameter matrix
            if isinstance(self.text_projection, nn.Linear):
                x = self.text_projection(x)
            else:
                x = x @ self.text_projection

        return F.normalize(x, dim=-1) if normalize else x

    def get_logits(self, image, text):
        """Return (image_logits, text_logits) similarity matrices, scaled by temperature."""
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        image_logits = self.logit_scale.exp() * image_features @ text_features.T
        if self.logit_bias is not None:
            image_logits += self.logit_bias
        text_logits = image_logits.T
        return image_logits, text_logits

    def forward_intermediates(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
            image_indices: Optional[Union[int, List[int]]] = None,
            text_indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
            normalize: bool = True,
            normalize_intermediates: bool = False,
            intermediates_only: bool = False,
            image_output_fmt: str = 'NCHW',
            image_output_extra_tokens: bool = False,
            text_output_fmt: str = 'NLC',
            text_output_extra_tokens: bool = False,
            output_logits: bool = False,
            output_logit_scale_bias: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            image: Input image tensor
            text: Input text tensor
            image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
            text_indices: Take last n blocks if int, all if None, select matching indices if sequence
            stop_early: Stop iterating over blocks when last desired intermediate hit
            normalize_intermediates: Apply final norm layer to all intermediates
            normalize: L2 Normalize final features
            intermediates_only: Only return intermediate features, do not return final features
            image_output_fmt: Shape of intermediate image feature outputs
            image_output_extra_tokens: Return both prefix and spatial intermediate tokens
            text_output_fmt: Shape of intermediate text feature outputs (ignored for this model)
            text_output_extra_tokens: Return both prefix and spatial intermediate tokens (ignored for this model)
            output_logits: Include logits in output
            output_logit_scale_bias: Include the logit scale bias in the output
        Returns:
            Dict with keys among: image_intermediates, image_features,
            text_intermediates, text_features, image_logits, text_logits,
            logit_scale, logit_bias (depending on inputs and flags).
        """
        output = {}
        if intermediates_only:
            # intermediates only disables final feature normalization, and include logits
            normalize = False
            output_logits = False
        if output_logits:
            assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'

        if image is not None:
            image_output = self.visual.forward_intermediates(
                image,
                indices=image_indices,
                stop_early=stop_early,
                normalize_intermediates=normalize_intermediates,
                intermediates_only=intermediates_only,
                output_fmt=image_output_fmt,
                output_extra_tokens=image_output_extra_tokens,
            )
            if normalize and "image_features" in image_output:
                image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
            output.update(image_output)

        if text is not None:
            # Mirrors encode_text, but captures per-block intermediates.
            cast_dtype = self.transformer.get_cast_dtype()
            x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
            x = x + self.positional_embedding.to(cast_dtype)
            x, intermediates = self.transformer.forward_intermediates(
                x,
                attn_mask=self.attn_mask,
                indices=text_indices
            )
            if normalize_intermediates:
                intermediates = [self.ln_final(xi) for xi in intermediates]

            # NOTE this model doesn't support cls embed in text transformer, no need for extra intermediate tokens
            output["text_intermediates"] = intermediates

            if not intermediates_only:
                x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
                x = text_global_pool(x, text, self.text_pool_type, eos_token_id=getattr(self, "text_eos_id", None))
                if self.text_projection is not None:
                    if isinstance(self.text_projection, nn.Linear):
                        x = self.text_projection(x)
                    else:
                        x = x @ self.text_projection
                if normalize:
                    x = F.normalize(x, dim=-1)
                output["text_features"] = x

        logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None

        if output_logits:
            image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T
            if self.logit_bias is not None:
                image_logits += self.logit_bias
            text_logits = image_logits.T
            output["image_logits"] = image_logits
            output["text_logits"] = text_logits

        if output_logit_scale_bias:
            output["logit_scale"] = logit_scale_exp
            if self.logit_bias is not None:
                output['logit_bias'] = self.logit_bias

        return output

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        """Encode image and/or text; return a dict when ``self.output_dict``, else a tuple."""
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None

        if self.output_dict:
            out_dict = {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
            if self.logit_bias is not None:
                out_dict['logit_bias'] = self.logit_bias
            return out_dict

        if self.logit_bias is not None:
            return image_features, text_features, self.logit_scale.exp(), self.logit_bias
        return image_features, text_features, self.logit_scale.exp()
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
class CustomTextCLIP(nn.Module):
|
| 534 |
+
output_dict: torch.jit.Final[bool]
|
| 535 |
+
|
| 536 |
+
def __init__(
|
| 537 |
+
self,
|
| 538 |
+
embed_dim: int,
|
| 539 |
+
vision_cfg: CLIPVisionCfg,
|
| 540 |
+
text_cfg: CLIPTextCfg,
|
| 541 |
+
quick_gelu: bool = False,
|
| 542 |
+
init_logit_scale: float = np.log(1 / 0.07),
|
| 543 |
+
init_logit_bias: Optional[float] = None,
|
| 544 |
+
nonscalar_logit_scale: bool = False,
|
| 545 |
+
cast_dtype: Optional[torch.dtype] = None,
|
| 546 |
+
output_dict: bool = False,
|
| 547 |
+
):
|
| 548 |
+
super().__init__()
|
| 549 |
+
self.output_dict = output_dict
|
| 550 |
+
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
| 551 |
+
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
| 552 |
+
self.context_length = self.text.context_length
|
| 553 |
+
self.vocab_size = self.text.vocab_size
|
| 554 |
+
|
| 555 |
+
lshape = [1] if nonscalar_logit_scale else []
|
| 556 |
+
self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
|
| 557 |
+
if init_logit_bias is not None:
|
| 558 |
+
self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
|
| 559 |
+
else:
|
| 560 |
+
self.logit_bias = None
|
| 561 |
+
|
| 562 |
+
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
| 563 |
+
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
| 564 |
+
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
| 565 |
+
|
| 566 |
+
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
| 567 |
+
self.text.lock(unlocked_layers, freeze_layer_norm)
|
| 568 |
+
|
| 569 |
+
@torch.jit.ignore
|
| 570 |
+
def set_grad_checkpointing(self, enable=True):
|
| 571 |
+
self.visual.set_grad_checkpointing(enable)
|
| 572 |
+
self.text.set_grad_checkpointing(enable)
|
| 573 |
+
|
| 574 |
+
@torch.jit.ignore
|
| 575 |
+
def no_weight_decay(self):
|
| 576 |
+
# for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
|
| 577 |
+
no_wd = set()
|
| 578 |
+
if hasattr(self.visual, 'no_weight_decay'):
|
| 579 |
+
for n in self.visual.no_weight_decay():
|
| 580 |
+
no_wd.add('visual.' + n)
|
| 581 |
+
if hasattr(self.text, 'no_weight_decay'):
|
| 582 |
+
for n in self.text.no_weight_decay():
|
| 583 |
+
no_wd.add('text.' + n)
|
| 584 |
+
return no_wd
|
| 585 |
+
|
| 586 |
+
def encode_image(self, image, normalize: bool = False):
|
| 587 |
+
features = self.visual(image)
|
| 588 |
+
return F.normalize(features, dim=-1) if normalize else features
|
| 589 |
+
|
| 590 |
+
def encode_text(self, text, normalize: bool = False):
|
| 591 |
+
features = self.text(text)
|
| 592 |
+
return F.normalize(features, dim=-1) if normalize else features
|
| 593 |
+
|
| 594 |
+
def get_logits(self, image, text):
|
| 595 |
+
image_features = self.encode_image(image, normalize=True)
|
| 596 |
+
text_features = self.encode_text(text, normalize=True)
|
| 597 |
+
image_logits = self.logit_scale.exp() * image_features @ text_features.T
|
| 598 |
+
if self.logit_bias is not None:
|
| 599 |
+
image_logits += self.logit_bias
|
| 600 |
+
text_logits = image_logits.T
|
| 601 |
+
return image_logits, text_logits
|
| 602 |
+
|
| 603 |
+
def forward_intermediates(
|
| 604 |
+
self,
|
| 605 |
+
image: Optional[torch.Tensor] = None,
|
| 606 |
+
text: Optional[torch.Tensor] = None,
|
| 607 |
+
image_indices: Optional[Union[int, List[int]]] = None,
|
| 608 |
+
text_indices: Optional[Union[int, List[int]]] = None,
|
| 609 |
+
stop_early: bool = False,
|
| 610 |
+
normalize: bool = True,
|
| 611 |
+
normalize_intermediates: bool = False,
|
| 612 |
+
intermediates_only: bool = False,
|
| 613 |
+
image_output_fmt: str = 'NCHW',
|
| 614 |
+
image_output_extra_tokens: bool = False,
|
| 615 |
+
text_output_fmt: str = 'NLC',
|
| 616 |
+
text_output_extra_tokens: bool = False,
|
| 617 |
+
output_logits: bool = False,
|
| 618 |
+
output_logit_scale_bias: bool = False,
|
| 619 |
+
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
|
| 620 |
+
""" Forward features that returns intermediates.
|
| 621 |
+
|
| 622 |
+
Args:
|
| 623 |
+
image: Input image tensor
|
| 624 |
+
text: Input text tensor
|
| 625 |
+
image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
|
| 626 |
+
text_indices: Take last n blocks if int, all if None, select matching indices if sequence
|
| 627 |
+
stop_early: Stop iterating over blocks when last desired intermediate hit
|
| 628 |
+
normalize: L2 Normalize final image and text features (if present)
|
| 629 |
+
normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)
|
| 630 |
+
intermediates_only: Only return intermediate features, do not return final features
|
| 631 |
+
image_output_fmt: Shape of intermediate image feature outputs
|
| 632 |
+
image_output_extra_tokens: Return both prefix and spatial intermediate tokens
|
| 633 |
+
text_output_fmt: Shape of intermediate text feature outputs
|
| 634 |
+
text_output_extra_tokens: Return both prefix and spatial intermediate tokens
|
| 635 |
+
output_logits: Include logits in output
|
| 636 |
+
output_logit_scale_bias: Include the logit scale bias in the output
|
| 637 |
+
Returns:
|
| 638 |
+
|
| 639 |
+
"""
|
| 640 |
+
output = {}
|
| 641 |
+
if intermediates_only:
|
| 642 |
+
# intermediates only disables final feature normalization, and include logits
|
| 643 |
+
normalize = False
|
| 644 |
+
output_logits = False
|
| 645 |
+
if output_logits:
|
| 646 |
+
assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'
|
| 647 |
+
|
| 648 |
+
if image is not None:
|
| 649 |
+
image_output = self.visual.forward_intermediates(
|
| 650 |
+
image,
|
| 651 |
+
indices=image_indices,
|
| 652 |
+
stop_early=stop_early,
|
| 653 |
+
normalize_intermediates=normalize_intermediates,
|
| 654 |
+
intermediates_only=intermediates_only,
|
| 655 |
+
output_fmt=image_output_fmt,
|
| 656 |
+
output_extra_tokens=image_output_extra_tokens,
|
| 657 |
+
)
|
| 658 |
+
if normalize and "image_features" in image_output:
|
| 659 |
+
image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
|
| 660 |
+
output.update(image_output)
|
| 661 |
+
|
| 662 |
+
if text is not None:
|
| 663 |
+
text_output = self.text.forward_intermediates(
|
| 664 |
+
text,
|
| 665 |
+
indices=text_indices,
|
| 666 |
+
stop_early=stop_early,
|
| 667 |
+
normalize_intermediates=normalize_intermediates,
|
| 668 |
+
intermediates_only=intermediates_only,
|
| 669 |
+
output_fmt=text_output_fmt,
|
| 670 |
+
output_extra_tokens=text_output_extra_tokens,
|
| 671 |
+
)
|
| 672 |
+
if normalize and "text_features" in text_output:
|
| 673 |
+
text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1)
|
| 674 |
+
output.update(text_output)
|
| 675 |
+
|
| 676 |
+
logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None
|
| 677 |
+
|
| 678 |
+
if output_logits:
|
| 679 |
+
image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T
|
| 680 |
+
if self.logit_bias is not None:
|
| 681 |
+
image_logits += self.logit_bias
|
| 682 |
+
text_logits = image_logits.T
|
| 683 |
+
output["image_logits"] = image_logits
|
| 684 |
+
output["text_logits"] = text_logits
|
| 685 |
+
|
| 686 |
+
if output_logit_scale_bias:
|
| 687 |
+
output["logit_scale"] = logit_scale_exp
|
| 688 |
+
if self.logit_bias is not None:
|
| 689 |
+
output['logit_bias'] = self.logit_bias
|
| 690 |
+
|
| 691 |
+
return output
|
| 692 |
+
|
| 693 |
+
def forward(
        self,
        image: Optional[torch.Tensor] = None,
        text: Optional[torch.Tensor] = None,
):
    """Encode image and/or text and return normalized features plus logit scale (and bias).

    Returns a dict when `self.output_dict` is set, otherwise a tuple of
    (image_features, text_features, logit_scale[, logit_bias]).
    """
    image_features = None if image is None else self.encode_image(image, normalize=True)
    text_features = None if text is None else self.encode_text(text, normalize=True)

    scale = self.logit_scale.exp()
    if self.output_dict:
        result = {
            "image_features": image_features,
            "text_features": text_features,
            "logit_scale": scale,
        }
        if self.logit_bias is not None:
            result['logit_bias'] = self.logit_bias
        return result

    if self.logit_bias is None:
        return image_features, text_features, scale
    return image_features, text_features, scale, self.logit_bias
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16)"""

    def _convert_weights(module):
        # standard conv / linear weights and biases
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.to(dtype)
            if module.bias is not None:
                module.bias.data = module.bias.data.to(dtype)

        # attention projection tensors stored directly as module attributes
        if isinstance(module, (nn.MultiheadAttention, Attention)):
            proj_weights = [f"{s}_proj_weight" for s in ("in", "q", "k", "v")]
            for attr in proj_weights + ["in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(module, attr, None)
                if tensor is not None:
                    tensor.data = tensor.data.to(dtype)

        # text tower nn.Parameter projection
        if isinstance(module, (CLIP, TextTransformer)):
            proj = getattr(module, "text_projection", None)
            if proj is not None:
                proj.data = proj.data.to(dtype)

        # vision tower nn.Parameter projection
        if isinstance(module, VisionTransformer):
            proj = getattr(module, "proj", None)
            if proj is not None:
                proj.data = proj.data.to(dtype)

    model.apply(_convert_weights)


# backwards compat: old name assumed fp16 only
convert_weights_to_fp16 = convert_weights_to_lp
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
    """Remap an old-format CLIP state dict so text-tower tensors live under a `text.` prefix.

    New-format dicts (no top-level 'text_projection') are returned unchanged.
    """
    if 'text_projection' not in state_dict:
        return state_dict
    # keys belonging to the text tower in the legacy flat layout
    text_prefixes = (
        'text_projection',
        'positional_embedding',
        'token_embedding',
        'transformer',
        'ln_final',
    )
    remapped = {}
    for key, value in state_dict.items():
        new_key = 'text.' + key if key.startswith(text_prefixes) else key
        remapped[new_key] = value
    return remapped
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    """Build a CLIP model whose architecture is inferred from an OpenAI-format state dict.

    Args:
        state_dict: OpenAI CLIP checkpoint weights (flat key layout).
        quick_gelu: use QuickGELU activation; OpenAI models were trained with it.
        cast_dtype: dtype applicable weights are cast to (OpenAI dicts are fp16).

    Returns:
        The constructed CLIP model, weights loaded, in eval mode.
    """
    # ViT towers have a final 'visual.proj' parameter; ResNet towers do not
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        # one attention in_proj per transformer block
        vision_layers = len(
            [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        # positional embedding has one extra entry for the class token
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_size = vision_patch_size * grid_size
    else:
        # ModifiedResNet: count residual blocks per stage from the layer{1..4} key structure
        counts: list = [
            len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        # attnpool positional embedding must be a square grid plus one pooled token
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_size = output_width * 32  # the ResNet stem+stages downsample by 32x overall

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64  # OpenAI CLIP uses 64-dim attention heads
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))

    vision_cfg = CLIPVisionCfg(
        layers=vision_layers,
        width=vision_width,
        patch_size=vision_patch_size,
        image_size=image_size,
    )
    text_cfg = CLIPTextCfg(
        context_length=context_length,
        vocab_size=vocab_size,
        width=transformer_width,
        heads=transformer_heads,
        layers=transformer_layers,
    )
    model = CLIP(
        embed_dim,
        vision_cfg=vision_cfg,
        text_cfg=text_cfg,
        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
        cast_dtype=cast_dtype,
    )

    # drop non-weight metadata keys so load_state_dict doesn't see unexpected entries
    for key in ["input_resolution", "context_length", "vocab_size"]:
        state_dict.pop(key, None)
    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
    model.load_state_dict(state_dict)
    return model.eval()
|
| 825 |
+
|
| 826 |
+
|
| 827 |
+
def trace_model(model, batch_size=256, device=torch.device('cpu')):
    """JIT-trace a CLIP model's forward / encode_text / encode_image entry points.

    Restores `visual.image_size` on the traced module since tracing drops it.
    """
    model.eval()
    img_size = model.visual.image_size
    dummy_images = torch.ones((batch_size, 3, img_size, img_size), device=device)
    dummy_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
    traced = torch.jit.trace_module(
        model,
        inputs={
            'forward': (dummy_images, dummy_text),
            'encode_text': (dummy_text,),
            'encode_image': (dummy_images,),
        },
    )
    traced.visual.image_size = img_size
    return traced
|
| 841 |
+
|
| 842 |
+
|
| 843 |
+
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
    """Resize the visual positional-embedding grid in `state_dict` to match `model`.

    No-op when the checkpoint has no visual positional embedding, the model has
    no grid, or the sequence lengths already match. Mutates `state_dict` in place.
    """
    # Rescale the grid of position embeddings when loading from state_dict
    old_pos_embed = state_dict.get('visual.positional_embedding', None)
    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
        return
    grid_size = to_2tuple(model.visual.grid_size)
    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
    if new_seq_len == old_pos_embed.shape[0]:
        return

    # split off the class-token embedding(s); only the grid part is interpolated
    if extra_tokens:
        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
    else:
        pos_emb_tok, pos_emb_img = None, old_pos_embed
    # assumes the stored grid is square — TODO confirm for non-square checkpoints
    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))

    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
    # (seq, dim) -> (1, dim, H, W) so F.interpolate can resize spatially
    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    pos_emb_img = F.interpolate(
        pos_emb_img,
        size=grid_size,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    # back to (new_H * new_W, dim)
    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
    if pos_emb_tok is not None:
        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
    else:
        new_pos_embed = pos_emb_img
    state_dict['visual.positional_embedding'] = new_pos_embed
|
| 875 |
+
|
| 876 |
+
|
| 877 |
+
def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False):
    """Resize the text positional embedding in `state_dict` to the model's context length.

    No-op when the checkpoint lacks a text positional embedding or lengths
    already match. Mutates `state_dict` in place.
    """
    key = 'positional_embedding' if 'positional_embedding' in state_dict else 'text.positional_embedding'
    old_pos_embed = state_dict.get(key, None)
    if old_pos_embed is None:
        return
    # FIXME add support for text cls_token
    model_pos_embed = getattr(model, 'positional_embedding', None)
    if model_pos_embed is None:
        model_pos_embed = getattr(model.text, 'positional_embedding', None)

    old_num_pos = old_pos_embed.shape[0]
    old_width = old_pos_embed.shape[1]
    num_pos = model_pos_embed.shape[0]
    width = model_pos_embed.shape[1]
    assert old_width == width, 'text pos_embed width changed!'
    if old_num_pos == num_pos:
        return

    logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos)
    # (num_pos, width) -> (1, width, num_pos) for 1D interpolation along positions
    resized = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1)
    resized = F.interpolate(
        resized,
        size=num_pos,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    state_dict[key] = resized.permute(0, 2, 1)[0]
|
| 908 |
+
|
| 909 |
+
|
| 910 |
+
def get_model_preprocess_cfg(model):
    """Return the image preprocessing config attached to a model's visual tower.

    Prefers the packaged `preprocess_cfg` dict; falls back to the legacy
    per-attribute form (`image_size`, `image_mean`, `image_std`).

    Args:
        model: model object; its `visual` attribute is used when present.

    Returns:
        Dict containing whichever of 'size', 'mean', 'std' could be determined.
    """
    module = getattr(model, 'visual', model)
    preprocess_cfg = getattr(module, 'preprocess_cfg', {})
    if not preprocess_cfg:
        # use separate legacy attributes if preprocess_cfg dict not found.
        # BUGFIX: default to None — the previous getattr(module, 'image_size')
        # raised AttributeError on modules without the legacy attr, contradicting
        # the `if size is not None` guard below.
        size = getattr(module, 'image_size', None)
        if size is not None:
            preprocess_cfg['size'] = size
        mean = getattr(module, 'image_mean', None)
        if mean is not None:
            preprocess_cfg['mean'] = mean
        std = getattr(module, 'image_std', None)
        if std is not None:
            preprocess_cfg['std'] = std
    return preprocess_cfg
|
| 925 |
+
|
| 926 |
+
|
| 927 |
+
def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
    """Attach image preprocessing config to the model's visual tower.

    Sets the legacy `image_mean` / `image_std` attributes for backward
    compatibility plus the packaged `preprocess_cfg` dict (deep-copied so the
    caller's dict cannot later mutate model state).
    """
    module = getattr(model, 'visual', model)
    module.image_mean = preprocess_cfg['mean']  # legacy attribute, kept for bwd compat
    module.image_std = preprocess_cfg['std']  # legacy attribute, kept for bwd compat
    module.preprocess_cfg = copy.deepcopy(preprocess_cfg)  # canonical packaged form
|
| 932 |
+
|
| 933 |
+
|
| 934 |
+
def get_model_tokenize_cfg(model):
    """Collect tokenizer-relevant settings from a model's text tower.

    Returns:
        Dict with 'context_length' and/or 'vocab_size' when present on the
        text tower (or on the model itself if it has no `text` attribute).
    """
    module = getattr(model, 'text', model)
    cfg = {}
    for attr in ('context_length', 'vocab_size'):
        value = getattr(module, attr, None)
        if value is not None:
            cfg[attr] = value
    return cfg
|
src/open_clip/model_configs/sleep_coca_base_dualtransformer.json
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"embed_dim": 512,
|
| 3 |
+
"multimodal_cfg": {
|
| 4 |
+
"width": 768,
|
| 5 |
+
"context_length": 256,
|
| 6 |
+
"mlp_ratio": 4,
|
| 7 |
+
"layers": 12,
|
| 8 |
+
"heads": 12
|
| 9 |
+
},
|
| 10 |
+
"biosignals_cfg": {
|
| 11 |
+
"architecture": "pure_transformer",
|
| 12 |
+
"input_channels": 10,
|
| 13 |
+
"signal_length": 1920,
|
| 14 |
+
"sampling_rate": 64,
|
| 15 |
+
"patch_size": 16,
|
| 16 |
+
"conv_embed_dim": 256,
|
| 17 |
+
"num_temporal_layers": 1,
|
| 18 |
+
"activation": "swiglu",
|
| 19 |
+
"norm_type": "rmsnorm",
|
| 20 |
+
"mlp_bias": false,
|
| 21 |
+
"share_channel_rope": true,
|
| 22 |
+
"transformer_layers": 3,
|
| 23 |
+
"transformer_width": 768,
|
| 24 |
+
"transformer_heads": 12,
|
| 25 |
+
"mlp_ratio": 3.0,
|
| 26 |
+
"pool_type": "attn",
|
| 27 |
+
"dropout": 0.1,
|
| 28 |
+
"decoder_tokens": 32
|
| 29 |
+
},
|
| 30 |
+
"text_cfg": {
|
| 31 |
+
"context_length": 256,
|
| 32 |
+
"vocab_size": 49408,
|
| 33 |
+
"layers": 12,
|
| 34 |
+
"heads": 12,
|
| 35 |
+
"width": 768,
|
| 36 |
+
"embed_cls": true,
|
| 37 |
+
"output_tokens": true
|
| 38 |
+
},
|
| 39 |
+
"custom_text": true,
|
| 40 |
+
"prefix_len": 1,
|
| 41 |
+
"num_caption_channels": 12,
|
| 42 |
+
"decoder_type": "cross_attention"
|
| 43 |
+
}
|
| 44 |
+
|
src/open_clip/tokenizer.py
ADDED
|
@@ -0,0 +1,621 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" CLIP tokenizer
|
| 2 |
+
|
| 3 |
+
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
| 4 |
+
"""
|
| 5 |
+
import gzip
|
| 6 |
+
import html
|
| 7 |
+
import os
|
| 8 |
+
import random
|
| 9 |
+
import string
|
| 10 |
+
from functools import lru_cache, partial
|
| 11 |
+
from typing import Callable, List, Optional, Union, Dict
|
| 12 |
+
import warnings
|
| 13 |
+
|
| 14 |
+
import ftfy
|
| 15 |
+
import numpy as np
|
| 16 |
+
import regex as re
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
# https://stackoverflow.com/q/62691279
|
| 20 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 21 |
+
_nltk_init = False
|
| 22 |
+
|
| 23 |
+
DEFAULT_CONTEXT_LENGTH = 77 # default context length for OpenAI CLIP
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@lru_cache()
def default_bpe():
    """Return the path of the gzip'd BPE vocab file bundled next to this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, "bpe_simple_vocab_16e6.txt.gz")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@lru_cache()
def bytes_to_unicode():
    """
    Returns a dict mapping each utf-8 byte value (0-255) to a printable unicode char.

    The reversible BPE codes operate on unicode strings, so every byte needs a
    character representation that is not whitespace or a control character (the
    bpe code barfs on those). Bytes that are already printable map to themselves;
    the rest are shifted into the 256+ codepoint range.
    """
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    byte_values = printable[:]
    char_codes = printable[:]
    shift = 0
    for byte in range(2 ** 8):
        if byte not in printable:
            byte_values.append(byte)
            char_codes.append(2 ** 8 + shift)  # remap into unused codepoints above 255
            shift += 1
    return {b: chr(c) for b, c in zip(byte_values, char_codes)}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def get_pairs(word):
    """Return the set of adjacent symbol bigrams in `word`.

    Word is represented as a tuple of symbols (symbols being variable-length
    strings).
    """
    _ = word[0]  # preserve original behavior: empty word is an IndexError
    return set(zip(word, word[1:]))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def basic_clean(text):
    """Fix text encoding problems via ftfy, unescape HTML entities, and strip."""
    text = ftfy.fix_text(text)
    # double-unescape handles text that was HTML-escaped twice (common in web scrapes)
    text = html.unescape(html.unescape(text))
    return text.strip()
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def whitespace_clean(text):
    """Collapse runs of whitespace to single spaces and strip both ends."""
    return " ".join(text.split()).strip()
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _clean_canonicalize(x):
    # basic clean, then canonicalize: remove punctuation, collapse whitespace, lower case
    return canonicalize_text(basic_clean(x))


def _clean_lower(x):
    # basic clean, collapse whitespace, lower case
    return whitespace_clean(basic_clean(x)).lower()


def _clean_whitespace(x):
    # basic clean, collapse whitespace (case and punctuation preserved)
    return whitespace_clean(basic_clean(x))
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def get_clean_fn(type: str):
    """Map a clean-mode name to its text-cleaning function.

    Args:
        type: one of 'canonicalize', 'lower', 'whitespace'.

    Returns:
        The corresponding cleaning callable.

    Raises:
        ValueError: if `type` is not a known clean mode.
    """
    if type == 'canonicalize':
        return _clean_canonicalize
    elif type == 'lower':
        return _clean_lower
    elif type == 'whitespace':
        return _clean_whitespace
    # raise instead of `assert False` so invalid configs still fail under `python -O`
    raise ValueError(f"Invalid clean function ({type}).")
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def canonicalize_text(
        text,
        *,
        keep_punctuation_exact_string=None,
        trans_punctuation: dict = str.maketrans("", "", string.punctuation),
):
    """Returns canonicalized `text` (lowercase and punctuation removed).

    From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94

    Args:
        text: string to be canonicalized.
        keep_punctuation_exact_string: If provided, then this exact string kept.
            For example providing '{}' will keep any occurrences of '{}' (but will
            still remove '{' and '}' that appear separately).
    """
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        # strip punctuation from each piece, then stitch the kept string back in
        pieces = text.split(keep_punctuation_exact_string)
        text = keep_punctuation_exact_string.join(
            piece.translate(trans_punctuation) for piece in pieces
        )
    else:
        text = text.translate(trans_punctuation)
    # lowering before the whitespace collapse is equivalent to lowering after
    return " ".join(text.lower().split()).strip()
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class SimpleTokenizer(object):
    """Byte-pair-encoding (BPE) tokenizer used by CLIP.

    Encodes cleaned text into BPE token ids framed by <start_of_text> /
    <end_of_text>, padded or truncated to a fixed context length.
    """

    def __init__(
            self,
            bpe_path: str = default_bpe(),
            additional_special_tokens: Optional[List[str]] = None,
            context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
            clean: str = 'lower',
            reduction_mask: str = ''
    ):
        # byte <-> unicode tables make BPE reversible over arbitrary utf-8 input
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        # first line is a version header; keep exactly the merges that fill
        # CLIP's 49152-token vocab (256 base bytes * 2 variants + 2 specials)
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]  # word-final variants of each byte char
        for merge in merges:
            vocab.append(''.join(merge))
        special_tokens = ['<start_of_text>', '<end_of_text>']
        if additional_special_tokens:
            special_tokens += additional_special_tokens
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))  # token string -> id
        self.decoder = {v: k for k, v in self.encoder.items()}  # id -> token string
        self.bpe_ranks = dict(zip(merges, range(len(merges))))  # merge pair -> priority (lower = earlier)
        self.cache = {t:t for t in special_tokens}  # bpe() memo; specials map to themselves
        special = "|".join(special_tokens)
        # word-splitting pattern: specials, English contractions, letter runs,
        # single digits, and runs of other non-space symbols
        self.pat = re.compile(
            special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )
        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]
        self.sot_token_id = self.all_special_ids[0]
        self.eot_token_id = self.all_special_ids[1]
        self.context_length = context_length
        self.clean_fn = get_clean_fn(clean)
        self.reduction_fn = get_reduction_mask_fn(reduction_mask) if reduction_mask else None

    def bpe(self, token):
        """Apply BPE merges to one byte-encoded token; returns space-joined subwords."""
        if token in self.cache:
            return self.cache[token]
        # mark the final character as word-final before merging
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            # single-character token: nothing to merge
            return token+'</w>'

        while True:
            # merge the lowest-ranked (i.e. most frequent in training) pair first
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except Exception:
                    # `first` no longer occurs; copy the remainder verbatim
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean `text` and return its list of BPE token ids (no sot/eot markers)."""
        bpe_tokens = []
        text = self.clean_fn(text)
        for token in re.findall(self.pat, text):
            # map raw bytes to the reversible unicode alphabet before BPE
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Invert `encode`: map token ids back to text ('</w>' becomes a space)."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text

    def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.LongTensor:
        """ Returns the tokenized representation of given input string(s)

        Parameters
        ----------
        texts : Union[str, List[str]]
            An input string or a list of input strings to tokenize
        context_length : int
            The context length to use; all CLIP models use 77 as the context length

        Returns
        -------
        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
        """
        if isinstance(texts, str):
            texts = [texts]

        context_length = context_length or self.context_length
        assert context_length, 'Please set a valid context length'

        if self.reduction_fn is not None:
            # use reduction strategy for tokenize if set, otherwise default to truncation below
            return self.reduction_fn(
                texts,
                context_length=context_length,
                sot_token_id=self.sot_token_id,
                eot_token_id=self.eot_token_id,
                encode_fn=self.encode,
            )

        all_tokens = [[self.sot_token_id] + self.encode(text) + [self.eot_token_id] for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                tokens = tokens[:context_length]  # Truncate
                tokens[-1] = self.eot_token_id  # truncation must still end with eot
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
_tokenizer = SimpleTokenizer()  # module-level default tokenizer (OpenAI CLIP BPE vocab)


def decode(output_ids: torch.Tensor):
    """Decode a 1D tensor of token ids back to text using the default tokenizer."""
    output_ids = output_ids.cpu().numpy()
    return _tokenizer.decode(output_ids)


def tokenize(texts: Union[str, List[str]], context_length: int = DEFAULT_CONTEXT_LENGTH) -> torch.LongTensor:
    """Tokenize text(s) with the default tokenizer; returns a [n, context_length] LongTensor."""
    return _tokenizer(texts, context_length=context_length)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def random_mask_tokenize(
        texts: Union[str, List[str]],
        context_length: int,
        sot_token_id: int,
        eot_token_id: int,
        encode_fn: Callable,
        shuffle: bool = False,
):
    """Tokenize with random token dropping when a text exceeds the context window.

    Texts longer than `context_length - 2` keep a random subset of tokens
    (original order preserved unless `shuffle=True`); shorter texts are kept
    whole. Rows are framed with sot/eot ids and zero-padded.
    """
    encoded = [encode_fn(text) for text in texts]
    out = torch.zeros(len(encoded), context_length, dtype=torch.long)

    max_body = context_length - 2  # reserve two slots for sot and eot tokens
    for row, token_list in enumerate(encoded):
        token_tensor = torch.tensor(token_list)
        count = len(token_tensor)
        if count > max_body:
            keep = torch.randperm(count)[:max_body]
            if not shuffle:
                keep = keep.msort()  # sorted indices restore original token order
            token_tensor = token_tensor[keep]
            count = max_body
        out[row, 0] = sot_token_id
        out[row, 1:count + 1] = token_tensor
        out[row, count + 1] = eot_token_id

    return out
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def simple_mask_tokenize(
        texts: Union[str, List[str]],
        context_length: int,
        sot_token_id: int,
        eot_token_id: int,
        encode_fn: Callable,
):
    """Tokenize, keeping a random contiguous block when a text exceeds the context window.

    Rows are framed with sot/eot ids and zero-padded to `context_length`.
    """
    encoded = [encode_fn(text) for text in texts]
    out = torch.zeros(len(encoded), context_length, dtype=torch.long)

    max_body = context_length - 2  # reserve two slots for sot and eot tokens
    for row, token_list in enumerate(encoded):
        if len(token_list) > max_body:
            # pick a random window of max_body tokens (randint's high is inclusive)
            begin = random.randint(0, len(token_list) - max_body)
            token_list = token_list[begin:begin + max_body]
        framed = [sot_token_id] + token_list + [eot_token_id]
        out[row, :len(framed)] = torch.tensor(framed)

    return out
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def syntax_mask_tokenize(
        texts: Union[str, List[str]],
        context_length: int,
        sot_token_id: int,
        eot_token_id: int,
        encode_fn: Callable,
) -> torch.LongTensor:
    """ Returns the tokenized representation of given input string(s).
    Apply syntax masking before tokenize: when a text is too long, words are
    kept in priority order nouns > adjectives > verbs > everything else.
    """
    import nltk
    global _nltk_init
    if not _nltk_init:
        # run them for the first time
        nltk.download('punkt')
        nltk.download('averaged_perceptron_tagger')
        _nltk_init = True

    def get_order(x):
        # keep priority by POS tag: nouns first, then adjectives, then verbs
        if x.startswith('NN'):
            return 1
        elif x.startswith('JJ'):
            return 2
        elif x.startswith('VB'):
            return 3
        else:
            return 4

    # syntax masking
    new_texts = []
    for text in texts:
        list_tokens = nltk.tokenize.word_tokenize(text)
        pos_tags = nltk.pos_tag(list_tokens)
        # sample the words by get_order method
        order_list = [get_order(tag) for _, tag in pos_tags]
        sorted_ids = np.argsort(np.array(order_list))
        sampled_ids = sorted(sorted_ids[:context_length - 2])  # need 2 slots for sot and eot tokens
        sampled_tokens = np.take(np.array(list_tokens), sampled_ids, axis=0)  # sample the tokens

        new_text = ''
        for token in sampled_tokens:
            new_text = new_text + str(token) + ' '
        new_text = new_text.strip()
        new_texts.append(new_text)
    texts = new_texts

    all_tokens = [[sot_token_id] + encode_fn(text) + [eot_token_id] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        # still need first truncate because some words produces two tokens
        if len(tokens) > context_length:
            tokens = tokens[:context_length]  # Truncate
            tokens[-1] = eot_token_id
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def get_reduction_mask_fn(type: str):
    """Choose strategy for dropping (masking) tokens to achieve target context length."""
    assert type in ('simple', 'random', 'shuffle', 'syntax')
    if type == 'simple':
        # contiguous block [start:end] chosen at random
        return simple_mask_tokenize
    if type == 'random':
        # random token drop, original order preserved
        return random_mask_tokenize
    if type == 'shuffle':
        # random token drop, order shuffled
        return partial(random_mask_tokenize, shuffle=True)
    if type == 'syntax':
        # drop prioritized by part-of-speech (nouns kept first, then adjectives, verbs)
        return syntax_mask_tokenize
    assert False, f'Unknown type {type}.'
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
class HFTokenizer:
    """HuggingFace tokenizer wrapper with support for custom tokenization modes.

    Wraps ``transformers.AutoTokenizer`` and returns fixed-length id tensors.
    ``tokenizer_mode='clips'`` enables custom post-processing that places a
    class token in the final slot of each (padded) sequence.
    """

    def __init__(
            self,
            tokenizer_name: str,
            context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
            clean: str = 'whitespace',
            strip_sep_token: bool = False,
            language: Optional[str] = None,
            cache_dir: Optional[str] = None,
            tokenizer_mode: Optional[str] = None,  # None, 'clips'
            **kwargs
    ):
        """
        Args:
            tokenizer_name: HF hub name or local path passed to AutoTokenizer.
            context_length: default max sequence length used by __call__.
            clean: text-cleaning strategy name, resolved via get_clean_fn.
            strip_sep_token: if True, SEP token ids are zeroed in the output.
            language: optional source language, applied when the underlying
                tokenizer exposes set_src_lang_special_tokens.
            cache_dir: optional HF cache directory.
            tokenizer_mode: None/'' for standard HF behavior, 'clips' to route
                through _clips_tokenize.
            **kwargs: forwarded to AutoTokenizer.from_pretrained.
        """
        self.tokenizer_mode = tokenizer_mode or ''
        self.context_length = context_length
        self.clean_fn = get_clean_fn(clean)
        self.strip_sep_token = strip_sep_token

        # NOTE: Left as example of loading custom tokenizer from file for experimentation
        # if self.tokenizer_mode == 'bert_clips':
        #     self.special_tokens = {
        #         "bos_token": 1,
        #         "eos_token": 2,
        #         "cls_token": 101,
        #         "pad_token": 0
        #     }
        #
        #     # For BERT CLIPS mode with vocab file
        #     from tokenizers import BertWordPieceTokenizer
        #     if tokenizer_name.startswith('hf-hub:'):
        #         from huggingface_hub import hf_hub_download
        #         # Format: hf-hub:repo_id/filename
        #         repo_url = tokenizer_name[7:]
        #         parts = repo_url.split('/')
        #         filename = parts[-1]
        #         repo_id = '/'.join(parts[:-1])
        #         vocab_file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)
        #         self.tokenizer = BertWordPieceTokenizer(lowercase=True)
        #         self.tokenizer = self.tokenizer.from_file(vocab_file)
        #     else:
        #         # Assume tokenizer_name is a local path to a vocab file
        #         self.tokenizer = BertWordPieceTokenizer(lowercase=True)
        #         self.tokenizer = self.tokenizer.from_file(tokenizer_name)

        # Standard HuggingFace tokenizer initialization
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            cache_dir=cache_dir,
            **kwargs
        )

        # Set language function if available
        # (only some multilingual tokenizers expose set_src_lang_special_tokens)
        set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None)
        if callable(set_lang_fn):
            self.set_lang_fn = set_lang_fn
        if language is not None:
            self.set_language(language)

    def save_pretrained(self, dest):
        """Save the wrapped HF tokenizer to directory *dest*."""
        self.tokenizer.save_pretrained(dest)

    def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:
        """Tokenize *texts* into a (batch, context_length) LongTensor of token ids."""
        # same cleaning as for default tokenizer, except lowercasing
        # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
        if isinstance(texts, str):
            texts = [texts]

        context_length = context_length or self.context_length
        assert context_length, 'Please set a valid context length in class init or call.'

        texts = [self.clean_fn(text) for text in texts]

        # Handle different tokenization modes
        if self.tokenizer_mode == 'clips':
            return self._clips_tokenize(texts, context_length)
        else:
            # Standard tokenization
            input_ids = self.tokenizer.batch_encode_plus(
                texts,
                return_tensors='pt',
                max_length=context_length,
                padding='max_length',
                truncation=True,
            ).input_ids

            if self.strip_sep_token:
                # replace SEP ids with 0 in place (shape is preserved, not removed)
                input_ids = torch.where(
                    input_ids == self.tokenizer.sep_token_id,
                    torch.zeros_like(input_ids),
                    input_ids,
                )

            return input_ids

    def set_language(self, src_lang):
        """Set source language on tokenizers that support it; warn otherwise."""
        if hasattr(self, 'set_lang_fn'):
            self.set_lang_fn(src_lang)
        else:
            warnings.warn('Cannot set language for the tokenizer.')

    def _clips_tokenize(self, texts: List[str], context_length: int) -> torch.Tensor:
        """Use standard HF tokenizer but apply custom post-processing.

        Resulting layout per row: [BOS] tokens [EOS] [PAD]... [CLS],
        exactly context_length ids long.
        """
        # Use standard tokenizer without special tokens - we'll add our own
        encoded_outputs = self.tokenizer.batch_encode_plus(
            texts,
            add_special_tokens=False,
            padding=False,
            truncation=False,
            return_tensors=None
        )

        encoded = []
        for tokens in encoded_outputs["input_ids"]:
            # Leave room for special tokens (BOS/EOS here, CLS added later)
            tokens = tokens[:context_length - 3]
            tokens = [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
            encoded.append(tokens)

        # Create result tensor and handle padding + class token
        result = torch.zeros(len(encoded), context_length, dtype=torch.long)
        for i, tokens in enumerate(encoded):
            padded_tokens = self._pad_and_add_class_token(
                tokens,
                max_length=context_length,
                pad_token_id=self.tokenizer.pad_token_id,
                cls_token_id=self.tokenizer.cls_token_id,
            )
            result[i, :len(padded_tokens)] = torch.tensor(padded_tokens)

        return result

    def _pad_and_add_class_token(
            self,
            tokens: List[int],
            max_length: int,
            pad_token_id: int = 0,
            cls_token_id: int = 101,
    ) -> List[int]:
        """ Add padding with class token at the end """
        # truncate so the class token always fits in the final slot
        if len(tokens) > max_length - 1:
            tokens = tokens[:max_length - 1]

        # Add padding to reach max_length-1
        if len(tokens) < max_length - 1:
            tokens = tokens + [pad_token_id] * (max_length - 1 - len(tokens))

        # Add class token at the end
        tokens = tokens + [cls_token_id]
        return tokens
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
class SigLipTokenizer:
    """HuggingFace tokenizer wrapper for SigLIP T5 compatible sentencepiece vocabs

    NOTE: this is not needed in normal library use, but is used to import new sentencepiece tokenizers
    into OpenCLIP. Leaving code here in case future models use new tokenizers.
    """
    # Remote sentencepiece model files, keyed by short name.
    VOCAB_FILES = {
        # english, vocab_size=32_000
        "c4-en": "http://storage.googleapis.com/t5-data/vocabs/cc_en.32000/sentencepiece.model",
        # used in multilingual models (mT5, PaLI), vocab_size=250_000
        "mc4": "http://storage.googleapis.com/t5-data/vocabs/mc4.250000.100extra/sentencepiece.model",
        # used in SigLIP2 models, vocab_size=256000
        "gemma": "http://storage.googleapis.com/big_vision/gemma_tokenizer.model",
    }

    def __init__(
            self,
            tokenizer_name: str,
            context_length: Optional[int] = 64,
    ):
        """
        Args:
            tokenizer_name: a key into VOCAB_FILES (remote vocab is downloaded),
                or a name/path passed straight to the HF tokenizer class.
            context_length: default max sequence length used by __call__.
        """
        if 'gemma' in tokenizer_name:
            from transformers import GemmaTokenizerFast
            tokenizer_cls = partial(
                GemmaTokenizerFast, padding_side='right', add_bos_token=False, add_eos_token=True)
        else:
            from transformers import T5TokenizerFast
            tokenizer_cls = partial(T5TokenizerFast, extra_ids=0)

        if tokenizer_name in self.VOCAB_FILES:
            # FIXME temporary hack? Download remote vocab into a named temp file so the
            # tokenizer can load it from a filesystem path.
            import tempfile
            import fsspec
            vocab_file = self.VOCAB_FILES[tokenizer_name]
            with tempfile.NamedTemporaryFile('wb') as dst:
                with fsspec.open(vocab_file, 'rb') as src:
                    dst.write(src.read())
                # BUGFIX: flush buffered bytes to disk before the tokenizer reads
                # dst.name, otherwise the tail of the vocab file may still be
                # sitting in the Python file buffer and the model load is truncated.
                dst.flush()
                self.tokenizer = tokenizer_cls(dst.name, legacy=False)
        else:
            self.tokenizer = tokenizer_cls(tokenizer_name, legacy=False)

        # Force SigLIP token-id conventions regardless of the vocab's own defaults:
        # pad=0 for gemma vocabs, pad=1 for T5 vocabs; eos is always 1.
        self.tokenizer.pad_token_id = 0 if 'gemma' in tokenizer_name else 1
        self.tokenizer.eos_token_id = 1
        self.context_length = context_length

    def save_pretrained(self, dest):
        """Save the wrapped HF tokenizer to directory *dest*."""
        self.tokenizer.save_pretrained(dest)

    def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:
        """Tokenize *texts* into a (batch, context_length) LongTensor of token ids."""
        # same cleaning as for default tokenizer, except lowercasing
        # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
        if isinstance(texts, str):
            texts = [texts]

        context_length = context_length or self.context_length
        assert context_length, 'Please set a valid context length in class init or call.'

        texts = [canonicalize_text(basic_clean(text)) for text in texts]
        output = self.tokenizer(
            texts,
            return_tensors='pt',
            max_length=context_length,
            padding='max_length',
            truncation=True,
        )
        return output.input_ids
|
src/open_clip/transformer.py
ADDED
|
@@ -0,0 +1,1823 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
import math
|
| 3 |
+
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
from torch.utils.checkpoint import checkpoint
|
| 9 |
+
|
| 10 |
+
import warnings
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def to_2tuple(x):
    """Return *x* unchanged if it is already a tuple or list, else duplicate it into a 2-tuple."""
    if not isinstance(x, (tuple, list)):
        return (x, x)
    return x
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def feature_take_indices(num_blocks, indices):
    """Normalize possibly-negative block *indices* and return (absolute_indices, max_index)."""
    absolute = [idx + num_blocks if idx < 0 else idx for idx in indices]
    return absolute, max(absolute)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a (grid_size**2, embed_dim) 2D sin/cos positional embedding.

    When cls_token is True, a zero row is prepended for the class token.
    """
    axis = np.arange(grid_size, dtype=np.float32)
    # both axes use the same coordinate range, matching np.meshgrid(grid_w, grid_h)
    mesh = np.meshgrid(axis, axis)
    stacked = np.stack(mesh, axis=0).reshape([2, 1, grid_size, grid_size])
    embed = _get_2d_sincos_pos_embed_from_grid(embed_dim, stacked)
    if not cls_token:
        return embed
    return np.concatenate([np.zeros([1, embed_dim]), embed], axis=0)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Split embed_dim evenly between the two grid axes and concatenate their 1D embeddings."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # grid[0] / grid[1] hold the coordinates of the two spatial axes
    parts = [_get_1d_sincos_pos_embed_from_grid(half, axis_coords) for axis_coords in (grid[0], grid[1])]
    return np.concatenate(parts, axis=1)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 45 |
+
assert embed_dim % 2 == 0
|
| 46 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
| 47 |
+
omega /= embed_dim / 2.
|
| 48 |
+
omega = 1. / 10000**omega
|
| 49 |
+
pos = pos.reshape(-1)
|
| 50 |
+
out = np.einsum('m,d->md', pos, omega)
|
| 51 |
+
return np.concatenate([np.sin(out), np.cos(out)], axis=1)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class LayerNormFp32(nn.LayerNorm):
    """LayerNorm that computes in float32 (handles fp16 inputs) and casts back to the input dtype."""

    def forward(self, x: torch.Tensor):
        in_dtype = x.dtype
        normed = F.layer_norm(
            x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.to(in_dtype)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class LayerNorm(nn.LayerNorm):
    """Standard LayerNorm whose output is cast back to the input dtype."""

    def forward(self, x: torch.Tensor):
        in_dtype = x.dtype
        normed = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.to(in_dtype)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x).

    NOTE: slower than nn.GELU or nn.SiLU and uses more GPU memory.
    """

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class LayerScale(nn.Module):
    """Learnable per-channel scale applied over the last dimension."""

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace  # if True, scale x in place via mul_
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class PatchDropout(nn.Module):
    """Randomly drop a fraction of patch tokens during training.

    https://arxiv.org/abs/2212.00794
    """

    def __init__(
            self,
            prob: float = 0.5,
            exclude_first_token: bool = True
    ):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob
        # when True, the first (CLS) token is never dropped
        self.exclude_first_token = exclude_first_token

    def forward(self, x):
        # no-op at eval time or when dropout is disabled
        if not self.training or self.prob == 0.:
            return x

        if self.exclude_first_token:
            cls_tokens, x = x[:, :1], x[:, 1:]
        else:
            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])

        num_batch = x.size()[0]
        num_tokens = x.size()[1]

        row_ids = torch.arange(num_batch)[..., None]

        # number of tokens surviving the drop (always at least one)
        num_keep = max(1, int(num_tokens * (1 - self.prob)))

        # random scores per token; topk picks which tokens survive
        scores = torch.randn(num_batch, num_tokens)
        keep_ids = scores.topk(num_keep, dim=-1).indices

        x = x[row_ids, keep_ids]

        if self.exclude_first_token:
            x = torch.cat((cls_tokens, x), dim=1)

        return x
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class Attention(nn.Module):
    """Multi-head self-attention supporting several variants:

    - scaled cosine attention (Swin Transformer V2, https://arxiv.org/abs/2111.09883)
    - QK normalization (https://arxiv.org/abs/2106.04560)
    - per-head attention logit scaling (NormFormer, https://arxiv.org/abs/2110.09456)
    - inner (pre-projection) normalization (likely Sub-LN, https://arxiv.org/abs/2210.06423)
    """

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            scaled_cosine: bool = False,
            scale_heads: bool = False,
            inner_norm: bool = False,
            logit_scale_max: float = math.log(1. / 0.01),
            norm_layer: Type[nn.Module] = LayerNormFp32,
            attn_drop: float = 0.,
            proj_drop: float = 0.
    ):
        """
        Args:
            dim: embedding dim, must be divisible by num_heads.
            num_heads: number of attention heads.
            qkv_bias: add bias to the fused QKV projection.
            qk_norm: apply norm_layer to per-head q and k (mutually exclusive with scaled_cosine).
            scaled_cosine: use cosine-similarity attention with learned per-head logit scale.
            scale_heads: learn a per-head scale applied to attention outputs.
            inner_norm: normalize attention output before the final projection.
            logit_scale_max: clamp ceiling for the scaled-cosine logit scale (log space).
            norm_layer: norm class used for qk_norm / inner_norm.
            attn_drop: dropout on attention weights.
            proj_drop: dropout after the output projection.
        """
        super().__init__()
        assert not (scaled_cosine and qk_norm), "Cannot activate both scaled cosine and QK normalization"
        self.scaled_cosine = scaled_cosine
        self.scale_heads = scale_heads
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.logit_scale_max = logit_scale_max
        self.use_fsdpa = hasattr(nn.functional, 'scaled_dot_product_attention')

        # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None

        # QK normalization (with LN) from https://arxiv.org/abs/2106.04560 and related to other QK Norm ideas
        if qk_norm:
            self.ln_q = norm_layer(self.head_dim)
            self.ln_k = norm_layer(self.head_dim)
        else:
            self.ln_q = nn.Identity()
            self.ln_k = nn.Identity()

        # Scaled cosine attention (from Swin Transformer V2, https://arxiv.org/abs/2111.09883)
        if self.scaled_cosine:
            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
        else:
            self.logit_scale = None

        self.attn_drop = nn.Dropout(attn_drop)

        # Per-head attention logit scaling (from NormFormer, https://arxiv.org/abs/2110.09456)
        if self.scale_heads:
            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
        else:
            self.head_scale = None

        # Normalization of attention logits, before final projection.
        # Origin likely Sub-LN in (Foundation Transformers, https://arxiv.org/abs/2210.06423)
        if inner_norm:
            self.ln_inner = norm_layer(dim)
        else:
            self.ln_inner = nn.Identity()

        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        """Attend over x of shape (N, L, C); attn_mask may be (L, L), bool or additive float."""
        N, L, C = x.shape
        q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
        q = q.reshape(N, L, self.num_heads, -1).transpose(1, 2)  # (N, heads, L, head_dim)
        k = k.reshape(N, L, self.num_heads, -1).transpose(1, 2)
        v = v.reshape(N, L, self.num_heads, -1).transpose(1, 2)

        if attn_mask is not None:
            if attn_mask.ndim == 3:
                # this module works with (L, L), or (N, num_heads, L, L) masks
                attn_mask = attn_mask.reshape(N, self.num_heads, L, L)
            if attn_mask.dtype == torch.bool:
                # convert boolean "masked" positions to additive -inf
                new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
                new_attn_mask.masked_fill_(attn_mask, float("-inf"))
                attn_mask = new_attn_mask
            else:
                attn_mask = attn_mask.to(dtype=q.dtype)

        if self.logit_scale is not None:
            # Scaled cosine attention path.
            # BUGFIX: q/k/v are 4-D (N, heads, L, head_dim) here; torch.bmm only accepts
            # 3-D tensors and raised a RuntimeError — use batched torch.matmul instead.
            attn = torch.matmul(
                F.normalize(q, dim=-1),
                F.normalize(k, dim=-1).transpose(-1, -2)
            )
            logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
            attn = attn * logit_scale
            if attn_mask is not None:
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = torch.matmul(attn, v)
        else:
            q = self.ln_q(q)
            k = self.ln_k(k)
            if self.use_fsdpa:
                x = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask=attn_mask,
                    dropout_p=self.attn_drop.p if self.training else 0.,
                )
            else:
                q = q * self.scale
                # BUGFIX: same 4-D batched matmul issue as above (torch.bmm is 3-D only)
                attn = torch.matmul(q, k.transpose(-1, -2))
                if attn_mask is not None:
                    attn += attn_mask
                attn = attn.softmax(dim=-1)
                attn = self.attn_drop(attn)
                x = torch.matmul(attn, v)

        # x: (N, num_heads, L, head_dim)
        if self.head_scale is not None:
            x = x * self.head_scale
        x = x.transpose(1, 2).reshape(N, L, C)
        x = self.ln_inner(x)
        x = self.out_proj(x)
        x = self.out_drop(x)
        return x
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class AttentionalPooler(nn.Module):
    """Pool a variable-length token sequence into n_queries tokens via cross-attention."""

    def __init__(
            self,
            d_model: int,
            context_dim: int,
            n_head: int = 8,
            n_queries: int = 256,
            norm_layer: Callable = LayerNorm,
    ):
        super().__init__()
        # learned query tokens that attend over the (normalized) input sequence
        self.query = nn.Parameter(torch.randn(n_queries, d_model))
        self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim, batch_first=True)
        self.ln_q = norm_layer(d_model)
        self.ln_k = norm_layer(context_dim)

    def forward(self, x: torch.Tensor):
        batch = x.shape[0]
        keys = self.ln_k(x)
        queries = self.ln_q(self.query).unsqueeze(0).expand(batch, -1, -1)
        pooled, _ = self.attn(queries, keys, keys, need_weights=False)
        return pooled
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: (self- or cross-) attention + MLP, each residual.

    With ``is_cross_attention=True`` an extra norm (``ln_1_kv``) is created for
    the key/value stream; otherwise keys/values default to the query stream.
    """

    def __init__(
            self,
            d_model: int,
            n_head: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            is_cross_attention: bool = False,
            batch_first: bool = True,
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=batch_first)
        self.ls_1 = nn.Identity() if ls_init_value is None else LayerScale(d_model, ls_init_value)
        if is_cross_attention:
            # separate pre-norm for the key/value stream
            self.ln_1_kv = norm_layer(d_model)

        self.ln_2 = norm_layer(d_model)
        hidden_dim = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, hidden_dim)),
            ("gelu", act_layer()),
            ("c_proj", nn.Linear(hidden_dim, d_model))
        ]))
        self.ls_2 = nn.Identity() if ls_init_value is None else LayerScale(d_model, ls_init_value)

    def get_weight_dtype(self) -> torch.dtype:
        """Dtype of the block weights (original dtype if int8-quantized)."""
        fc = self.mlp.c_fc
        if hasattr(fc, 'int8_original_dtype'):
            return fc.int8_original_dtype
        return fc.weight.dtype

    def attention(
            self,
            q_x: torch.Tensor,
            k_x: Optional[torch.Tensor] = None,
            v_x: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ):
        """Run multi-head attention; keys/values default to the query stream."""
        if k_x is None:
            k_x = q_x
        if v_x is None:
            v_x = q_x

        if attn_mask is not None:
            # mask is added to logits inside MultiheadAttention; match query dtype
            attn_mask = attn_mask.to(q_x.dtype)
        return self.attn(
            q_x, k_x, v_x,
            need_weights=False,
            attn_mask=attn_mask
        )[0]

    def forward(
            self,
            q_x: torch.Tensor,
            k_x: Optional[torch.Tensor] = None,
            v_x: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ):
        has_kv_norm = hasattr(self, "ln_1_kv")
        k_x = self.ln_1_kv(k_x) if has_kv_norm and k_x is not None else None
        v_x = self.ln_1_kv(v_x) if has_kv_norm and v_x is not None else None
        x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
class CustomResidualAttentionBlock(nn.Module):
    """Residual transformer block built on the project's custom ``Attention`` module.

    Exposes the NormFormer / scaled-cosine style variants via flags: qk norm,
    cosine attention, per-head scaling, inner attention norm, a post-attention
    norm (``scale_attn``) and a mid-MLP norm (``scale_fc``).
    """

    def __init__(
            self,
            d_model: int,
            n_head: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = LayerNorm,
            qk_norm: bool = False,
            scale_cosine_attn: bool = False,
            scale_heads: bool = False,
            scale_attn_inner: bool = False,
            scale_attn: bool = False,
            scale_fc: bool = False,
            batch_first: bool = True,
    ):
        super().__init__()
        assert batch_first, 'batch_first must be True for CustomResidualAttentionBlock'

        self.ln_1 = norm_layer(d_model)
        self.attn = Attention(
            d_model,
            n_head,
            qk_norm=qk_norm,
            scaled_cosine=scale_cosine_attn,
            scale_heads=scale_heads,
            inner_norm=scale_attn_inner,
            norm_layer=norm_layer,
        )
        self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

        self.ln_2 = norm_layer(d_model)
        hidden_dim = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, hidden_dim)),
            ("gelu", act_layer()),
            # mid-MLP norm, from NormFormer / Foundation Transformers
            ('ln', norm_layer(hidden_dim) if scale_fc else nn.Identity()),
            ("c_proj", nn.Linear(hidden_dim, d_model))
        ]))
        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

    def get_weight_dtype(self) -> torch.dtype:
        """Dtype of the block weights (original dtype if int8-quantized)."""
        fc = self.mlp.c_fc
        if hasattr(fc, 'int8_original_dtype'):
            return fc.int8_original_dtype
        return fc.weight.dtype

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        attn_out = self.attn(self.ln_1(x), attn_mask=attn_mask)
        x = x + self.ls_1(self.ln_attn(attn_out))
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
class CustomTransformer(nn.Module):
    """ A custom transformer that can use different block types. """
    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = LayerNorm,
            batch_first: bool = True,
            block_types: Union[str, List[str]] = 'CustomResidualAttentionBlock',
    ):
        """
        Args:
            width: model (embedding) dimension.
            layers: number of blocks in the stack.
            heads: attention heads per block.
            mlp_ratio: MLP hidden dim = width * mlp_ratio.
            ls_init_value: LayerScale init value (None disables LayerScale).
            act_layer / norm_layer: activation / normalization layer classes.
            batch_first: run the stack in (N, L, D) instead of (L, N, D).
            block_types: one block-type name used for every layer, or a
                per-layer list of names (must have exactly ``layers`` entries).

        Raises:
            ValueError: if ``block_types`` has the wrong length or names an
                unknown block type.
        """
        super().__init__()
        self.width = width
        self.layers = layers
        self.batch_first = batch_first  # run transformer stack in batch first (N, L, D)
        self.grad_checkpointing = False

        if isinstance(block_types, str):
            block_types = [block_types] * layers
        # raise (not assert) so the config check survives `python -O`
        if len(block_types) != layers:
            raise ValueError(f'block_types must have {layers} entries, got {len(block_types)}')

        def _create_block(bt: str):
            # factory: one branch per supported block type
            if bt == 'CustomResidualAttentionBlock':
                return CustomResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio=mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    batch_first=batch_first,
                )
            # was `assert False`: fail loudly with the offending name instead
            raise ValueError(f'Unknown block type: {bt!r}')

        self.resblocks = nn.ModuleList([
            _create_block(bt)
            for bt in block_types
        ])

    def get_cast_dtype(self) -> torch.dtype:
        """Weight dtype of the first block (used by callers for input casting)."""
        return self.resblocks[0].get_weight_dtype()

    def forward_intermediates(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
    ):
        """Run the stack, also collecting outputs of the blocks chosen by ``indices``.

        Args:
            x: input tokens, (N, L, D) when ``batch_first`` else transposed here.
            attn_mask: optional attention mask forwarded to every block.
            indices: which block outputs to collect (semantics defined by
                ``feature_take_indices``).
            stop_early: stop iterating once the last wanted block has run
                (ignored under torchscript).

        Returns:
            (final_output, intermediates) — intermediates are batch-first.
        """
        take_indices, max_index = feature_take_indices(len(self.resblocks), indices)

        if not self.batch_first:
            x = x.transpose(0, 1).contiguous()  # NLD -> LND

        intermediates = []
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            blocks = self.resblocks
        else:
            blocks = self.resblocks[:max_index + 1]
        for i, blk in enumerate(blocks):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(blk, x, None, None, attn_mask, use_reentrant=False)
            else:
                x = blk(x, attn_mask=attn_mask)

            if i in take_indices:
                # normalize collected features to batch-first layout
                intermediates.append(x.transpose(0, 1) if not self.batch_first else x)

        if not self.batch_first:
            x = x.transpose(0, 1)  # LND -> NLD

        return x, intermediates

    def prune_intermediate_layers(self, indices: Union[int, List[int]] = 1):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.resblocks), indices)
        self.resblocks = self.resblocks[:max_index + 1]  # truncate blocks
        return take_indices

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        if not self.batch_first:
            x = x.transpose(0, 1)  # NLD -> LND

        for r in self.resblocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)
            else:
                x = r(x, attn_mask=attn_mask)

        if not self.batch_first:
            x = x.transpose(0, 1)  # LND -> NLD (comment fixed; original said NLD -> LND)
        return x
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
class Transformer(nn.Module):
    """Stack of pre-norm residual attention blocks.

    Builds ``CustomResidualAttentionBlock``s when ``block_type='custom'`` or
    when any custom attention feature flag is set; otherwise the default
    ``ResidualAttentionBlock`` based on ``nn.MultiheadAttention``.
    """

    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = LayerNorm,
            batch_first: bool = True,
            block_type: Optional[str] = None,
            qk_norm: bool = False,
            scaled_cosine_attn: bool = False,
            scale_heads: bool = False,
            scale_attn_inner: bool = False,
            scale_attn: bool = False,
            scale_fc: bool = False,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.batch_first = batch_first
        self.grad_checkpointing = False

        # Auto-select the custom block when any custom feature is requested.
        custom_flags = (qk_norm, scaled_cosine_attn, scale_heads, scale_attn_inner, scale_attn, scale_fc)
        if block_type is None:
            block_type = 'custom' if any(custom_flags) else 'default'

        if block_type == 'custom':
            def _make_block():
                return CustomResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    qk_norm=qk_norm,
                    scale_cosine_attn=scaled_cosine_attn,
                    scale_heads=scale_heads,
                    scale_attn_inner=scale_attn_inner,
                    scale_attn=scale_attn,
                    scale_fc=scale_fc,
                    batch_first=batch_first,
                )
        else:
            def _make_block():
                return ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    batch_first=batch_first,
                )

        self.resblocks = nn.ModuleList([_make_block() for _ in range(layers)])

    def get_cast_dtype(self) -> torch.dtype:
        """Weight dtype of the first block (used by callers for input casting)."""
        return self.resblocks[0].get_weight_dtype()

    def forward_intermediates(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
    ):
        """Run the stack, collecting outputs of the blocks selected by ``indices``.

        Returns:
            (final_output, intermediates) — intermediates are batch-first.
        """
        take_indices, max_index = feature_take_indices(len(self.resblocks), indices)

        if not self.batch_first:
            x = x.transpose(0, 1).contiguous()  # NLD -> LND

        if torch.jit.is_scripting() or not stop_early:
            # can't slice blocks in torchscript
            blocks = self.resblocks
        else:
            blocks = self.resblocks[:max_index + 1]

        collected = []
        for idx, block in enumerate(blocks):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(block, x, None, None, attn_mask, use_reentrant=False)
            else:
                x = block(x, attn_mask=attn_mask)
            if idx in take_indices:
                collected.append(x if self.batch_first else x.transpose(0, 1))

        if not self.batch_first:
            x = x.transpose(0, 1)  # LND -> NLD

        return x, collected

    def prune_intermediate_layers(self, indices: Union[int, List[int]] = 1):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.resblocks), indices)
        self.resblocks = self.resblocks[:max_index + 1]  # truncate blocks
        return take_indices

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        seq_first = not self.batch_first
        if seq_first:
            x = x.transpose(0, 1).contiguous()  # NLD -> LND

        for block in self.resblocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                x = checkpoint(block, x, None, None, attn_mask, use_reentrant=False)
            else:
                x = block(x, attn_mask=attn_mask)

        if seq_first:
            x = x.transpose(0, 1)  # LND -> NLD
        return x
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
def _expand_token(token, batch_size: int):
|
| 620 |
+
return token.view(1, 1, -1).expand(batch_size, -1, -1)
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
class VisionTransformer(nn.Module):
    """CLIP-style ViT image tower.

    Patchifies the image with a strided conv, prepends a class token, adds
    positional embeddings, runs a ``Transformer`` stack, pools ('tok', 'avg',
    'none', or attentional pooling), and projects to ``output_dim``.
    Assumes 3-channel input (``conv1`` has ``in_channels=3``).
    """
    output_tokens: torch.jit.Final[bool]

    def __init__(
            self,
            image_size: int,
            patch_size: int,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float,
            ls_init_value: float = None,
            attentional_pool: bool = False,
            attn_pooler_queries: int = 256,
            attn_pooler_heads: int = 8,
            output_dim: int = 512,
            patch_dropout: float = 0.,
            no_ln_pre: bool = False,
            pos_embed_type: str = 'learnable',
            pool_type: str = 'tok',
            final_ln_after_pool: bool = False,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            output_tokens: bool = False,
            block_type: Optional[str] = None,
            qk_norm: bool = False,
            scaled_cosine_attn: bool = False,
            scale_heads: bool = False,
            scale_attn_inner: bool = False,
            scale_attn: bool = False,
            scale_fc: bool = False,
    ):
        super().__init__()
        assert pool_type in ('tok', 'avg', 'none')
        self.output_tokens = output_tokens
        # image/patch sizes may be ints or (H, W) tuples
        image_height, image_width = self.image_size = to_2tuple(image_size)
        patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
        self.grid_size = (image_height // patch_height, image_width // patch_width)
        self.final_ln_after_pool = final_ln_after_pool  # currently ignored w/ attn pool enabled
        self.output_dim = output_dim

        # patchify: non-overlapping conv with stride == kernel == patch size
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        # class embeddings and positional embeddings
        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        if pos_embed_type == 'learnable':
            # +1 position for the class token
            self.positional_embedding = nn.Parameter(
                scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
        elif pos_embed_type == 'sin_cos_2d':
            # fixed sin-cos embedding
            assert self.grid_size[0] == self.grid_size[1],\
                'currently sin cos 2d pos embedding only supports square input'
            self.positional_embedding = nn.Parameter(
                torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False)
            # NOTE: `pos_embed_type` is reused here as a temporary for the numpy table
            pos_embed_type = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True)
            self.positional_embedding.data.copy_(torch.from_numpy(pos_embed_type).float())
        else:
            raise ValueError

        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
        self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()

        self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width)
        self.transformer = Transformer(
            width,
            layers,
            heads,
            mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            block_type=block_type,
            qk_norm=qk_norm,
            scaled_cosine_attn=scaled_cosine_attn,
            scale_heads=scale_heads,
            scale_attn_inner=scale_attn_inner,
            scale_attn=scale_attn,
            scale_fc=scale_fc,
        )

        if attentional_pool:
            if isinstance(attentional_pool, str):
                # string selects the CoCa-paper-style dual pooler layout
                self.attn_pool_type = attentional_pool
                self.pool_type = 'none'
                if attentional_pool in ('parallel', 'cascade'):
                    self.attn_pool = AttentionalPooler(
                        output_dim,
                        width,
                        n_head=attn_pooler_heads,
                        n_queries=attn_pooler_queries,
                    )
                    # single-query pooler for the contrastive embedding
                    self.attn_pool_contrastive = AttentionalPooler(
                        output_dim,
                        width,
                        n_head=attn_pooler_heads,
                        n_queries=1,
                    )
                else:
                    assert False
            else:
                # boolean True: single pooler (original OpenCLIP CoCa setup)
                self.attn_pool_type = ''
                self.pool_type = pool_type
                self.attn_pool = AttentionalPooler(
                    output_dim,
                    width,
                    n_head=attn_pooler_heads,
                    n_queries=attn_pooler_queries,
                )
                self.attn_pool_contrastive = None
            pool_dim = output_dim
        else:
            self.attn_pool = None
            pool_dim = width
            self.pool_type = pool_type

        self.ln_post = norm_layer(pool_dim)
        self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim))

        self.init_parameters()

    def lock(self, unlocked_groups: int = 0, freeze_bn_stats: bool = False):
        """Freeze all parameters, optionally re-enabling the last ``unlocked_groups``.

        Groups are ordered: [stem] + each transformer block + [last block + ln_post] + proj.
        ``freeze_bn_stats`` is accepted for API parity; no BatchNorm layers exist
        here, so it is unused.
        """
        for param in self.parameters():
            param.requires_grad = False

        if unlocked_groups != 0:
            groups = [
                [
                    self.conv1,
                    self.class_embedding,
                    self.positional_embedding,
                    self.ln_pre,
                ],
                *self.transformer.resblocks[:-1],
                [
                    self.transformer.resblocks[-1],
                    self.ln_post,
                ],
                self.proj,
            ]

            def _unlock(x):
                # recursively unlock lists, Parameters, and Modules
                if isinstance(x, Sequence):
                    for g in x:
                        _unlock(g)
                else:
                    if isinstance(x, torch.nn.Parameter):
                        x.requires_grad = True
                    else:
                        for p in x.parameters():
                            p.requires_grad = True

            _unlock(groups[-unlocked_groups:])

    def init_parameters(self):
        # FIXME OpenAI CLIP did not define an init for the VisualTransformer
        # TODO experiment if default PyTorch init, below, or alternate init is best.

        # nn.init.normal_(self.class_embedding, std=self.scale)
        # nn.init.normal_(self.positional_embedding, std=self.scale)
        #
        # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        # attn_std = self.transformer.width ** -0.5
        # fc_std = (2 * self.transformer.width) ** -0.5
        # for block in self.transformer.resblocks:
        #     nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
        #     nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
        #     nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
        #     nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        #
        # if self.text_projection is not None:
        #     nn.init.normal_(self.text_projection, std=self.scale)
        pass

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        # toggles activation checkpointing inside the transformer stack
        self.transformer.grad_checkpointing = enable

    @torch.jit.ignore
    def no_weight_decay(self):
        # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
        no_wd = {'positional_embedding', 'class_embedding'}
        return no_wd

    def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split (B, 1 + N, D) features into (pooled, patch_tokens) per ``pool_type``."""
        if self.pool_type == 'avg':
            # mean over patch tokens, excluding the class token at index 0
            pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]
        elif self.pool_type == 'tok':
            # class token as the pooled feature
            pooled, tokens = x[:, 0], x[:, 1:]
        else:
            # 'none': pass everything through unchanged
            pooled = tokens = x

        return pooled, tokens

    def _embeds(self, x:torch.Tensor) -> torch.Tensor:
        """Patchify, prepend class token, add positions, patch-dropout, pre-norm."""
        x = self.conv1(x)  # shape = [*, dim, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        # class embeddings and positional embeddings
        x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
        # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)

        # patch dropout (if active)
        x = self.patch_dropout(x)

        # apply norm before transformer
        x = self.ln_pre(x)
        return x

    def _pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attentional or global pooling; returns (pooled, tokens)."""
        if self.attn_pool is not None:
            if self.attn_pool_contrastive is not None:
                # This is untested, WIP pooling that should match paper
                x = self.ln_post(x)  # TBD LN first or separate one after each pool?
                tokens = self.attn_pool(x)
                if self.attn_pool_type == 'parallel':
                    # both poolers read the transformer output directly
                    pooled = self.attn_pool_contrastive(x)
                else:
                    assert self.attn_pool_type == 'cascade'
                    # contrastive pooler reads the first pooler's output
                    pooled = self.attn_pool_contrastive(tokens)
            else:
                # this is the original OpenCLIP CoCa setup, does not match paper
                x = self.attn_pool(x)
                x = self.ln_post(x)
                pooled, tokens = self._global_pool(x)
        elif self.final_ln_after_pool:
            pooled, tokens = self._global_pool(x)
            pooled = self.ln_post(pooled)
        else:
            x = self.ln_post(x)
            pooled, tokens = self._global_pool(x)

        return pooled, tokens

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
            normalize_intermediates: bool = False,
            intermediates_only: bool = False,
            output_fmt: str = 'NCHW',
            output_extra_tokens: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            stop_early: Stop iterating over blocks when last desired intermediate hit
            intermediates_only: Only return intermediate features
            normalize_intermediates: Apply final norm layer to all intermediates
            output_fmt: Shape of intermediate feature outputs
            output_extra_tokens: Return both extra prefix class tokens
        Returns:
            Dict with 'image_intermediates' (list), optionally
            'image_intermediates_prefix' (class tokens) and 'image_features'.
        """
        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
        reshape = output_fmt == 'NCHW'

        # forward pass
        B, _, height, width = x.shape
        x = self._embeds(x)
        x, intermediates = self.transformer.forward_intermediates(
            x,
            indices=indices,
            stop_early=stop_early,
        )

        # process intermediates
        if normalize_intermediates:
            # apply final norm to all intermediates
            intermediates = [self.ln_post(xi) for xi in intermediates]
        num_prefix_tokens = 1  # one class token that's always there (as of now)
        if num_prefix_tokens:
            # split prefix (e.g. class, distill) and spatial feature tokens
            prefix_tokens = [y[:, 0:num_prefix_tokens] for y in intermediates]
            intermediates = [y[:, num_prefix_tokens:] for y in intermediates]
        else:
            prefix_tokens = None
        if reshape:
            # reshape to BCHW output format
            H, W = height // self.patch_size[0], width // self.patch_size[1]
            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]

        output = {'image_intermediates': intermediates}
        if prefix_tokens is not None and output_extra_tokens:
            output['image_intermediates_prefix'] = prefix_tokens

        if intermediates_only:
            return output

        pooled, _ = self._pool(x)

        if self.proj is not None:
            pooled = pooled @ self.proj

        output['image_features'] = pooled

        return output

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices = self.transformer.prune_intermediate_layers(indices)
        if prune_norm:
            self.ln_post = nn.Identity()
        if prune_head:
            self.proj = None
        return take_indices

    def forward(self, x: torch.Tensor):
        """Embed, encode, pool, project; returns pooled features (and tokens if ``output_tokens``)."""
        x = self._embeds(x)
        x = self.transformer(x)
        pooled, tokens = self._pool(x)

        if self.proj is not None:
            pooled = pooled @ self.proj

        if self.output_tokens:
            return pooled, tokens

        return pooled
|
| 959 |
+
|
| 960 |
+
|
| 961 |
+
def text_global_pool(
        x: torch.Tensor,
        text: Optional[torch.Tensor] = None,
        pool_type: str = 'argmax',
        eos_token_id: Optional[int] = None,
) -> torch.Tensor:
    """Pool token features (batch, seq, dim) into one vector per sequence.

    pool_type:
        'first' / 'last': take the first / last position.
        'argmax': take the position of the highest token id in ``text``
            (the EOT token is the highest id in each CLIP sequence).
        'eos': take the first position equal to ``eos_token_id``
            (if no eos is present, argmax over zeros selects position 0).
        anything else: no pooling — ``x`` is returned unchanged.
    """
    if pool_type == 'first':
        return x[:, 0]
    if pool_type == 'last':
        return x[:, -1]

    if pool_type == 'argmax':
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        assert text is not None
        batch_idx = torch.arange(x.shape[0], device=x.device)
        return x[batch_idx, text.argmax(dim=-1)]

    if pool_type == 'eos':
        # take features from tokenizer specific eos
        assert text is not None
        assert eos_token_id is not None
        batch_idx = torch.arange(x.shape[0], device=x.device)
        eos_pos = (text == eos_token_id).int().argmax(dim=-1)
        return x[batch_idx, eos_pos]

    return x
|
| 985 |
+
|
| 986 |
+
|
| 987 |
+
class TextTransformer(nn.Module):
|
| 988 |
+
output_tokens: torch.jit.Final[bool]
|
| 989 |
+
|
| 990 |
+
def __init__(
|
| 991 |
+
self,
|
| 992 |
+
context_length: int = 77,
|
| 993 |
+
vocab_size: int = 49408,
|
| 994 |
+
width: int = 512,
|
| 995 |
+
heads: int = 8,
|
| 996 |
+
layers: int = 12,
|
| 997 |
+
mlp_ratio: float = 4.0,
|
| 998 |
+
ls_init_value: float = None,
|
| 999 |
+
output_dim: Optional[int] = 512,
|
| 1000 |
+
embed_cls: bool = False,
|
| 1001 |
+
no_causal_mask: bool = False,
|
| 1002 |
+
use_pad_mask: bool = False,
|
| 1003 |
+
correct_cls_mask: bool = False,
|
| 1004 |
+
pad_id: int = 0,
|
| 1005 |
+
eos_id: int = 2,
|
| 1006 |
+
pool_type: str = 'argmax',
|
| 1007 |
+
proj_type: str = 'linear',
|
| 1008 |
+
proj_bias: bool = False,
|
| 1009 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
| 1010 |
+
norm_layer: Type[nn.Module] = LayerNorm,
|
| 1011 |
+
output_tokens: bool = False,
|
| 1012 |
+
block_type: Optional[str] = None,
|
| 1013 |
+
qk_norm: bool = False,
|
| 1014 |
+
scaled_cosine_attn: bool = False,
|
| 1015 |
+
scale_heads: bool = False,
|
| 1016 |
+
scale_attn_inner: bool = False,
|
| 1017 |
+
scale_attn: bool = False,
|
| 1018 |
+
scale_fc: bool = False,
|
| 1019 |
+
):
|
| 1020 |
+
super().__init__()
|
| 1021 |
+
assert pool_type in ('first', 'last', 'argmax', 'eos', 'none')
|
| 1022 |
+
self.output_tokens = output_tokens
|
| 1023 |
+
self.num_pos = self.context_length = context_length
|
| 1024 |
+
self.vocab_size = vocab_size
|
| 1025 |
+
self.width = width
|
| 1026 |
+
self.output_dim = output_dim
|
| 1027 |
+
self.heads = heads
|
| 1028 |
+
self.pad_id = pad_id
|
| 1029 |
+
self.eos_id = eos_id
|
| 1030 |
+
self.pool_type = pool_type
|
| 1031 |
+
self.use_pad_mask = use_pad_mask and no_causal_mask # only use in bi‑dir mode
|
| 1032 |
+
self.correct_cls_mask = correct_cls_mask # use the correct cls mask for CoCa (original is wrong)
|
| 1033 |
+
|
| 1034 |
+
self.token_embedding = nn.Embedding(vocab_size, width)
|
| 1035 |
+
if embed_cls:
|
| 1036 |
+
self.cls_emb = nn.Parameter(torch.empty(width))
|
| 1037 |
+
self.num_pos += 1
|
| 1038 |
+
else:
|
| 1039 |
+
self.cls_emb = None
|
| 1040 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
|
| 1041 |
+
self.transformer = Transformer(
|
| 1042 |
+
width=width,
|
| 1043 |
+
layers=layers,
|
| 1044 |
+
heads=heads,
|
| 1045 |
+
mlp_ratio=mlp_ratio,
|
| 1046 |
+
ls_init_value=ls_init_value,
|
| 1047 |
+
act_layer=act_layer,
|
| 1048 |
+
norm_layer=norm_layer,
|
| 1049 |
+
block_type=block_type,
|
| 1050 |
+
qk_norm=qk_norm,
|
| 1051 |
+
scaled_cosine_attn=scaled_cosine_attn,
|
| 1052 |
+
scale_heads=scale_heads,
|
| 1053 |
+
scale_attn_inner=scale_attn_inner,
|
| 1054 |
+
scale_attn=scale_attn,
|
| 1055 |
+
scale_fc=scale_fc,
|
| 1056 |
+
)
|
| 1057 |
+
self.ln_final = norm_layer(width)
|
| 1058 |
+
|
| 1059 |
+
if no_causal_mask:
|
| 1060 |
+
self.attn_mask = None # bi‑directional
|
| 1061 |
+
else:
|
| 1062 |
+
self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False)
|
| 1063 |
+
|
| 1064 |
+
if proj_type == 'none' or not output_dim:
|
| 1065 |
+
self.text_projection = None
|
| 1066 |
+
else:
|
| 1067 |
+
if proj_bias:
|
| 1068 |
+
self.text_projection = nn.Linear(width, output_dim)
|
| 1069 |
+
else:
|
| 1070 |
+
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
|
| 1071 |
+
|
| 1072 |
+
self.init_parameters()
|
| 1073 |
+
|
| 1074 |
+
def init_parameters(self):
    """Initialize embeddings, transformer blocks, and the text projection.

    Follows the original CLIP scheme: small-std normal init for the
    embeddings, width/depth-scaled std for attention and MLP weights.
    """
    nn.init.normal_(self.token_embedding.weight, std=0.02)
    nn.init.normal_(self.positional_embedding, std=0.01)
    if self.cls_emb is not None:
        nn.init.normal_(self.cls_emb, std=0.01)

    width = self.transformer.width
    depth = self.transformer.layers
    std_attn = width ** -0.5
    std_proj = std_attn * ((2 * depth) ** -0.5)
    std_fc = (2 * width) ** -0.5
    for blk in self.transformer.resblocks:
        nn.init.normal_(blk.attn.in_proj_weight, std=std_attn)
        nn.init.normal_(blk.attn.out_proj.weight, std=std_proj)
        nn.init.normal_(blk.mlp.c_fc.weight, std=std_fc)
        nn.init.normal_(blk.mlp.c_proj.weight, std=std_proj)

    if self.text_projection is None:
        return
    if isinstance(self.text_projection, nn.Linear):
        # projection as a Linear layer (optionally with bias)
        nn.init.normal_(self.text_projection.weight, std=width ** -0.5)
        if self.text_projection.bias is not None:
            nn.init.zeros_(self.text_projection.bias)
    else:
        # projection as a raw Parameter matrix
        nn.init.normal_(self.text_projection, std=width ** -0.5)
| 1097 |
+
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
    """Toggle gradient checkpointing on the underlying transformer stack."""
    setattr(self.transformer, 'grad_checkpointing', enable)
def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
    """
    Lock the text transformer layers, optionally leaving some layers unlocked.

    Delegates the actual freezing to the module-level ``lock_text_tower`` helper.

    Args:
        unlocked_layers: Number of layers to leave unlocked (from the end).
        freeze_layer_norm: LayerNorm freeze (only for API compatibility, not functional);
            must be True — LayerNorm params are frozen along with everything else.
    """
    assert freeze_layer_norm, 'Unfreezing LayerNorm is not supported. LayerNorm treated like other weights.'
    lock_text_tower(self, unlocked_layers)
@torch.jit.ignore
def no_weight_decay(self):
    """Parameter names to exclude from weight decay.

    For timm optimizers, 1d params (logit_scale, ln/bn scale, biases) are
    already excluded by default; this only adds the embeddings.
    """
    skip = {'positional_embedding'}
    if self.cls_emb is not None:
        skip |= {'cls_emb'}
    return skip
def build_causal_mask(self):
    """Build the additive causal attention mask.

    PyTorch uses additive attention masks: future (upper-triangular)
    positions are -inf, visible positions are 0.
    """
    n = self.num_pos
    mask = torch.full((n, n), float("-inf"))
    return mask.triu_(1)  # in-place: zero on/below the diagonal
def _build_additive_mask(
|
| 1129 |
+
self,
|
| 1130 |
+
text: torch.Tensor, # [B, L] – original text ids without CLS yet
|
| 1131 |
+
seq_len: int, # L (+1 if CLS added)
|
| 1132 |
+
dtype: torch.dtype,
|
| 1133 |
+
) -> torch.Tensor:
|
| 1134 |
+
"""
|
| 1135 |
+
Returns an additive (-inf) mask of shape [B*heads, seq_len, seq_len] that
|
| 1136 |
+
simultaneously masks padding tokens and (optionally) the CLS token.
|
| 1137 |
+
"""
|
| 1138 |
+
valid = text != self.pad_id # [B, L] (True = keep)
|
| 1139 |
+
|
| 1140 |
+
if self.cls_emb is not None:
|
| 1141 |
+
cls_valid = valid.new_ones(valid.size(0), 1) # [B, 1]
|
| 1142 |
+
# cls mask pos at end if correct or front for incorrect legacy mode in existing CoCa weights
|
| 1143 |
+
valid = torch.cat([valid, cls_valid] if self.correct_cls_mask else [cls_valid, valid], 1)
|
| 1144 |
+
|
| 1145 |
+
# broadcast over query dimension
|
| 1146 |
+
key_mask = valid.unsqueeze(1).expand(-1, seq_len, -1) # [B, Q, K]
|
| 1147 |
+
additive = torch.zeros_like(key_mask, dtype=dtype)
|
| 1148 |
+
additive.masked_fill_(~key_mask, float("-inf"))
|
| 1149 |
+
additive = additive.repeat_interleave(self.heads, 0) # [B*H, Q, K]
|
| 1150 |
+
return additive
|
| 1151 |
+
|
| 1152 |
+
def _embeds(self, text) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Embed token ids and assemble the attention mask for the transformer.

    Returns:
        Tuple of (embedded sequence [B, L(+1), width], attention mask or None).
        The mask combines the causal base mask (if any) with the per-batch
        padding/CLS additive mask when either is needed.
    """
    cast_dtype = self.transformer.get_cast_dtype()
    seq_len = text.shape[1]

    embeds = self.token_embedding(text).to(cast_dtype)

    # CoCa-style: a learnable CLS token is always appended at the end.
    if self.cls_emb is not None:
        embeds = torch.cat([embeds, _expand_token(self.cls_emb, embeds.size(0))], 1)
        seq_len = seq_len + 1

    mask = self.attn_mask  # causal base mask, or None in bidirectional mode

    # Padding / CLS additive mask (per batch element).
    if self.use_pad_mask or self.cls_emb is not None:
        pad_cls_mask = self._build_additive_mask(text, seq_len, embeds.dtype)
        if mask is None:
            mask = pad_cls_mask
        else:
            # slice the causal mask to the current length, broadcast over batch
            mask = mask[:seq_len, :seq_len].unsqueeze(0) + pad_cls_mask

    embeds = embeds + self.positional_embedding[:seq_len].to(cast_dtype)
    return embeds, mask
def forward_intermediates(
        self,
        text: torch.Tensor,
        indices: Optional[Union[int, List[int]]] = None,
        stop_early: bool = False,
        normalize_intermediates: bool = False,
        intermediates_only: bool = False,
        output_fmt: str = 'NCHW',
        output_extra_tokens: bool = False,
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
    """Forward pass that also returns per-layer intermediate features.

    Args:
        text: Input text ids
        indices: Take last n blocks if int, all if None, select matching indices if sequence
        stop_early: Stop iterating over blocks when last desired intermediate hit
        normalize_intermediates: Apply the final norm layer to all intermediates
        intermediates_only: Only return intermediate features
        output_fmt: Shape of intermediate feature outputs (only 'NLC' supported)
        output_extra_tokens: Return both prefix and intermediate tokens
    Returns:
        Dict with 'text_intermediates', optional 'text_intermediates_suffix',
        and (unless intermediates_only) 'text_features'.
    """
    # NOTE(review): the default output_fmt='NCHW' always fails this assert;
    # callers must pass output_fmt='NLC' explicitly (matches upstream open_clip).
    assert output_fmt in ('NLC',), 'Output format must be NLC.'

    x, attn_mask = self._embeds(text)
    x, mids = self.transformer.forward_intermediates(
        x,
        attn_mask=attn_mask,
        indices=indices,
        stop_early=stop_early,
    )

    if normalize_intermediates:
        # apply the final norm to every captured intermediate
        mids = [self.ln_final(m) for m in mids]

    result: Dict[str, Union[torch.Tensor, List[torch.Tensor]]] = {}

    if self.cls_emb is not None:
        # split the appended CLS token off each intermediate
        if output_extra_tokens:
            result['text_intermediates_suffix'] = [m[:, -1:] for m in mids]
        mids = [m[:, :-1] for m in mids]
    result['text_intermediates'] = mids

    if intermediates_only:
        return result

    if self.cls_emb is not None:
        # appended CLS (CoCa) overrides pool_type: always pool the last token,
        # with the final LN applied after pooling
        pooled = text_global_pool(x, pool_type='last')
        pooled = self.ln_final(pooled)
    else:
        x = self.ln_final(x)
        pooled = text_global_pool(x, text, pool_type=self.pool_type, eos_token_id=getattr(self, "eos_id", None))

    if self.text_projection is not None:
        if isinstance(self.text_projection, nn.Linear):
            pooled = self.text_projection(pooled)
        else:
            pooled = pooled @ self.text_projection

    result['text_features'] = pooled
    return result
def prune_intermediate_layers(
        self,
        indices: Union[int, List[int]] = 1,
        prune_norm: bool = False,
        prune_head: bool = True,
):
    """Prune layers not required for the specified intermediates.

    Args:
        indices: intermediate indices to keep (forwarded to the transformer).
        prune_norm: replace the final LayerNorm with Identity.
        prune_head: drop the text projection head.
    Returns:
        The indices actually kept by the underlying transformer.
    """
    kept = self.transformer.prune_intermediate_layers(indices)
    if prune_norm:
        self.ln_final = nn.Identity()
    if prune_head:
        self.text_projection = None
    return kept
def forward(self, text):
    """Encode token ids into pooled text features.

    Returns the projected pooled feature, or (pooled, tokens) when
    self.output_tokens is set.
    """
    x, attn_mask = self._embeds(text)
    x = self.transformer(x, attn_mask=attn_mask)

    # x: [batch_size, n_ctx, transformer.width]
    if self.cls_emb is not None:
        # appended CLS (CoCa) overrides pool_type: always take the last token;
        # final LN is applied after pooling in this case
        pooled = text_global_pool(x, pool_type='last')
        pooled = self.ln_final(pooled)
        tokens = x[:, :-1]  # sequence without the CLS slot
    else:
        x = self.ln_final(x)
        pooled = text_global_pool(x, text, pool_type=self.pool_type, eos_token_id=getattr(self, "eos_id", None))
        tokens = x

    if self.text_projection is not None:
        if isinstance(self.text_projection, nn.Linear):
            pooled = self.text_projection(pooled)
        else:
            pooled = pooled @ self.text_projection

    return (pooled, tokens) if self.output_tokens else pooled
class MultimodalTransformer(Transformer):
    """Cross-attention based multimodal decoder.

    Text and image/biosignals embeddings are kept separate.
    Each layer has:
        1. Self-attention on text (causal, or prefix-causal when prefix_len > 0)
        2. Cross-attention from text to image/biosignals
    """
    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            context_length: int = 77,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = LayerNorm,
            output_dim: int = 512,
            batch_first: bool = True,
            prefix_len: int = 0,
    ):
        # Base Transformer builds the self-attention stack (self.resblocks).
        super().__init__(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            batch_first=batch_first,
        )
        self.context_length = context_length
        # One cross-attention block per self-attention layer.
        self.cross_attn = nn.ModuleList([
            ResidualAttentionBlock(
                width,
                heads,
                mlp_ratio,
                ls_init_value=ls_init_value,
                act_layer=act_layer,
                norm_layer=norm_layer,
                is_cross_attention=True,
                batch_first=batch_first,
            )
            for _ in range(layers)
        ])

        # Register attention masks based on prefix configuration
        self.prefix_len = prefix_len
        if prefix_len > 0:
            # Pre-build prefix-causal mask for condition tokens + text
            prefix_causal_mask = self.build_prefix_causal_mask(prefix_len, context_length)
            self.register_buffer('prefix_causal_mask', prefix_causal_mask, persistent=False)
        else:
            # Only register standard causal mask when not using prefix tokens
            self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)

        self.ln_final = norm_layer(width)
        self.text_projection = nn.Parameter(torch.empty(width, output_dim))

        self.init_parameters()

    def init_parameters(self):
        """CLIP-style normal init for self-attn, cross-attn, and projection weights."""
        proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)
        attn_std = self.width ** -0.5
        fc_std = (2 * self.width) ** -0.5
        for block in self.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        for block in self.cross_attn:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    # NOTE: a superseded -inf-based build_prefix_causal_mask variant was removed
    # here (commented-out dead code); the active version below uses a large
    # finite negative instead for fp16/bf16 kernel stability.

    def build_prefix_causal_mask(self, prefix_len: int, text_len: int):
        """Additive mask; 0 = attend, NEG = block (fp32 for stability).

        Attention pattern over [prefix | text]:
            - Prefix ↔ Prefix: full bidirectional
            - Text → Prefix: full attention
            - Text → Text: causal
            - Prefix → Text: blocked
        """
        total_len = prefix_len + text_len
        # fp32 on CPU; we'll .to(device) later without changing dtype
        mask = torch.zeros(total_len, total_len, dtype=torch.float32)

        # large finite negative (safer than -inf for fp16/bf16 kernels)
        NEG = -torch.finfo(mask.dtype).max

        # Prefix → Text: block
        mask[:prefix_len, prefix_len:] = NEG

        # Text → Text: causal (block future). Use masked_fill, not 0 * -inf.
        tri = torch.triu(torch.ones(text_len, text_len, dtype=torch.bool), diagonal=1)
        mask[prefix_len:, prefix_len:].masked_fill_(tri, NEG)
        return mask

    def forward_intermediates(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
    ):
        # Intentionally unsupported for the cross-attention decoder.
        assert False, "Not currently implemented for MultimodalTransformer w/ xattn"

    def forward(self, image_embs, text_embs, condition_embs=None):
        """Forward pass with cross-attention between text and image.

        Args:
            image_embs: (batch_size, num_image_tokens, width)
            text_embs: (batch_size, num_text_tokens, width)
            condition_embs: Optional (batch_size, num_condition_tokens, width)
                Additional conditioning tokens that will be prepended to text.
                These tokens get full bidirectional attention among themselves,
                then cross-attend to image embeddings.

        Returns:
            Text decoder outputs: (batch_size, num_text_tokens, output_dim)
            Note: Only text token outputs are returned (condition token outputs are excluded)
        """
        # Determine text length before prepending conditions
        original_text_len = text_embs.shape[1]
        assert original_text_len <= self.context_length, "original_text_len must be less than or equal to context_length"

        # Prepend condition tokens to text if provided
        if condition_embs is not None:
            condition_len = condition_embs.shape[1]

            # Safety check: condition_len must not exceed the pre-configured prefix_len
            assert condition_len <= self.prefix_len, \
                f"condition_len ({condition_len}) exceeds prefix_len ({self.prefix_len})"

            text_embs = torch.cat([condition_embs, text_embs], dim=1)  # (batch, cond_len + text_len, width)
        else:
            condition_len = 0

        # Get attention mask based on prefix configuration
        if self.prefix_len > 0:
            # Slice the pre-built prefix-causal mask based on actual condition_len.
            # The mask is built for (prefix_len + context_length); when
            # condition_len < prefix_len, skip `offset` prefix rows/cols so the
            # sliced mask keeps the right prefix/text structure.
            offset = self.prefix_len - condition_len
            seq_len = condition_len + original_text_len
            attn_mask = self.prefix_causal_mask[offset:offset+seq_len, offset:offset+seq_len].to(device=text_embs.device)
        else:
            # Use standard causal mask when prefix_len == 0
            seq_len = original_text_len
            attn_mask = self.attn_mask[:seq_len, :seq_len].to(device=text_embs.device)

        if not self.batch_first:
            image_embs = image_embs.permute(1, 0, 2)  # NLD -> LND
            text_embs = text_embs.permute(1, 0, 2)  # NLD -> LND

        # Alternate self-attention (masked) and cross-attention per layer.
        for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                text_embs = checkpoint(
                    resblock, text_embs, None, None, attn_mask, use_reentrant=False)
                text_embs = checkpoint(
                    cross_attn, text_embs, image_embs, image_embs, None, use_reentrant=False)
            else:
                text_embs = resblock(text_embs, attn_mask=attn_mask)
                text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)

        if not self.batch_first:
            text_embs = text_embs.permute(1, 0, 2)  # LND -> NLD

        out = self.ln_final(text_embs)
        if self.text_projection is not None:
            out = out @ self.text_projection

        # Extract only the text portion (skip condition tokens if present)
        if condition_len > 0:
            out = out[:, condition_len:, :]  # (batch, text_len, output_dim)

        return out

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        # Enable/disable checkpointing of both self- and cross-attention blocks.
        self.grad_checkpointing = enable
class ConcatMultimodalTransformer(Transformer):
    """Concatenation-based multimodal decoder.

    Concatenates [condition_tokens (optional), image/biosignals_tokens, text_tokens] into a single sequence.
    Uses unified self-attention with a prefix-causal mask that allows:
        - Condition tokens attend to all condition + image tokens (full bidirectional)
        - Image/biosignals tokens attend to all condition + image tokens (full bidirectional)
        - Text tokens attend to all condition + image tokens (full attention to prefix)
        - Text tokens attend to all previous text tokens (causal)

    This enables flexible conditioning where any prefix tokens (condition + image) get full
    bidirectional attention, while text maintains causal generation properties.
    """
    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            context_length: int = 77,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = LayerNorm,
            output_dim: int = 512,
            batch_first: bool = True,
            prefix_len: int = 0,
    ):
        super().__init__(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            batch_first=batch_first,
        )
        self.context_length = context_length
        self.condition_prefix_len = prefix_len  # Number of condition tokens (0, 1, or N)

        # Pre-register an empty buffer for the attention mask.
        # Will be populated on first forward pass when image token count is known.
        self.register_buffer('_cached_attn_mask', torch.empty(0), persistent=False)
        self._cached_prefix_len = None  # Track the prefix length used to build the cache

        # No cross-attention layers needed - uses self-attention only
        self.ln_final = norm_layer(width)
        self.text_projection = nn.Parameter(torch.empty(width, output_dim))

        # NOTE(review): init_parameters is intentionally NOT called here (unlike
        # MultimodalTransformer) — confirm whether weights are expected to come
        # from a checkpoint or the base Transformer init.
        # self.init_parameters()

    def init_parameters(self):
        """CLIP-style normal init for the self-attention stack and projection."""
        proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)
        attn_std = self.width ** -0.5
        fc_std = (2 * self.width) ** -0.5
        for block in self.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.width ** -0.5)

    # NOTE: a superseded -inf-based build_prefix_causal_mask variant was removed
    # here (commented-out dead code); the active version below uses a large
    # finite negative instead for fp16/bf16 kernel stability.

    def build_prefix_causal_mask(self, prefix_len: int, text_len: int):
        """Additive mask; 0 = attend, NEG = block (fp32 for stability).

        Attention pattern over [prefix | text]:
            - Prefix ↔ Prefix: full bidirectional
            - Text → Prefix: full attention
            - Text → Text: causal
            - Prefix → Text: blocked
        """
        total_len = prefix_len + text_len
        # build in fp32; move to GPU later with .to(device=...) but DON'T cast dtype
        mask = torch.zeros(total_len, total_len, dtype=torch.float32)

        # large finite negative (safer than -inf with fp16/bf16 + fused kernels)
        NEG = -torch.finfo(mask.dtype).max

        # Prefix → Text: block
        mask[:prefix_len, prefix_len:] = NEG

        # Text → Text: causal (block future). Use masked_fill, not multiply by -inf.
        tri = torch.triu(torch.ones(text_len, text_len, dtype=torch.bool), diagonal=1)
        mask[prefix_len:, prefix_len:].masked_fill_(tri, NEG)
        return mask

    def forward_intermediates(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
    ):
        # Intentionally unsupported for the concatenation decoder.
        assert False, "Not currently implemented for ConcatMultimodalTransformer"

    def forward(self, image_embs, text_embs, condition_embs=None):
        """Forward pass with concatenated embeddings.

        Args:
            image_embs: (batch_size, num_image_tokens, width)
            text_embs: (batch_size, num_text_tokens, width)
            condition_embs: Optional (batch_size, num_condition_tokens, width)
                Additional conditioning tokens that will be prepended before image tokens.
                These tokens receive full bidirectional attention like image tokens.

        Returns:
            Text decoder outputs: (batch_size, num_text_tokens, output_dim)
        """
        batch_size = text_embs.shape[0]  # NOTE(review): unused; kept for clarity/debugging
        text_len = text_embs.shape[1]

        # Guard: text length must not exceed context length
        assert text_len <= self.context_length, \
            f"text_len ({text_len}) must be <= context_length ({self.context_length})"

        # Build prefix: [condition_tokens (optional), image_tokens]
        # All prefix tokens get full bidirectional attention
        if condition_embs is not None:
            condition_len = condition_embs.shape[1]

            # Safety check: condition_len must not exceed the pre-configured condition_prefix_len
            assert condition_len <= self.condition_prefix_len, \
                f"condition_len ({condition_len}) exceeds condition_prefix_len ({self.condition_prefix_len})"

            prefix = torch.cat([condition_embs, image_embs], dim=1)  # (batch, cond_len + img_len, width)
        else:
            condition_len = 0
            prefix = image_embs

        prefix_len = prefix.shape[1]  # Total prefix length (condition + image tokens)

        # Concatenate prefix and text embeddings
        x = torch.cat([prefix, text_embs], dim=1)  # (batch, prefix_len + text_len, width)

        if not self.batch_first:
            x = x.permute(1, 0, 2)  # NLD -> LND

        # Build or retrieve cached prefix-causal attention mask.
        # Dynamically rebuilds when prefix_len changes (handles variable condition_len or image_len).
        if self._cached_prefix_len != prefix_len or self._cached_attn_mask.numel() == 0:
            # Build mask for max possible text length (context_length)
            mask = self.build_prefix_causal_mask(prefix_len, self.context_length)

            # Directly update the buffer (already registered in __init__)
            self._cached_attn_mask = mask
            self._cached_prefix_len = prefix_len

        # Slice cached mask to actual sequence length
        seq_len = prefix_len + text_len
        attn_mask = self._cached_attn_mask[:seq_len, :seq_len].to(device=x.device)

        # Apply transformer layers with unified self-attention
        for resblock in self.resblocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(resblock, x, None, None, attn_mask, use_reentrant=False)
            else:
                x = resblock(x, attn_mask=attn_mask)

        if not self.batch_first:
            x = x.permute(1, 0, 2)  # LND -> NLD

        # Apply final layer norm
        x = self.ln_final(x)

        # Extract only the text portion (skip image prefix)
        text_output = x[:, prefix_len:, :]  # (batch, text_len, width)

        # Project to output dimension
        if self.text_projection is not None:
            text_output = text_output @ self.text_projection

        return text_output

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        # Enable/disable checkpointing of the self-attention blocks.
        self.grad_checkpointing = enable
def lock_text_tower(
    model: nn.Module,
    unlocked_layers: int = 0,
):
    """
    Lock (freeze) text tower layers for CLIP models.

    Works with both model architectures:
    - CustomTextCLIP, where text components live under ``model.text``
    - Standard CLIP, where text components are direct attributes of ``model``

    All text parameters are frozen first; then the last ``unlocked_layers``
    layer groups are re-enabled. Groups are ordered: embeddings,
    resblock_0 .. resblock_{n-2}, (last resblock + ln_final), text_projection.

    Args:
        model: The CLIP model or TextTransformer module.
        unlocked_layers: Number of layer groups to leave unlocked, counted
            from the end. 0 (or negative) freezes everything.
    """
    # Determine where to look for text components.
    if hasattr(model, 'text'):
        # CustomTextCLIP or a wrapper with a nested text tower.
        text_module = model.text
    else:
        # Standard CLIP or a bare TextTransformer.
        text_module = model

    # Collect text components, dropping any that this architecture lacks.
    text_params = {}
    for key in (
        'token_embedding', 'positional_embedding', 'cls_emb',
        'transformer', 'ln_final', 'text_projection',
    ):
        component = getattr(text_module, key, None)
        if component is not None:
            text_params[key] = component

    # Freeze all text parameters first.
    for component in text_params.values():
        if isinstance(component, nn.Parameter):
            component.requires_grad = False
        elif isinstance(component, nn.Module):
            for param in component.parameters():
                param.requires_grad = False

    if unlocked_layers <= 0:
        # Nothing to selectively unlock (also guards against negative values,
        # which previously unlocked most groups via negative-slice arithmetic).
        return

    # BUG FIX: 'transformer' may have been dropped above when the attribute is
    # absent; a direct text_params['transformer'] lookup raised KeyError here.
    transformer = text_params.get('transformer')
    if transformer is None or not hasattr(transformer, 'resblocks'):
        return

    total_layers = len(transformer.resblocks)
    if total_layers == 0:
        return

    # Build ordered groups for selective unlocking (earliest -> latest).
    groups = []

    # Group 1: embeddings.
    embedding_group = [
        text_params[key]
        for key in ('token_embedding', 'positional_embedding', 'cls_emb')
        if key in text_params
    ]
    if embedding_group:
        groups.append(embedding_group)

    # Groups 2..N: individual transformer blocks (all but the last).
    if total_layers > 1:
        for block in transformer.resblocks[:-1]:
            groups.append([block])

    # Penultimate group: last transformer block + final layer norm.
    last_block = [transformer.resblocks[-1]]
    if 'ln_final' in text_params:
        last_block.append(text_params['ln_final'])
    groups.append(last_block)

    # Final group: the output projection only.
    if 'text_projection' in text_params:
        groups.append([text_params['text_projection']])

    def _unlock(module):
        # Recursively re-enable gradients for a parameter, a module, or a
        # list/tuple of either.
        if isinstance(module, (list, tuple)):
            for m in module:
                _unlock(m)
        elif isinstance(module, nn.Parameter):
            module.requires_grad = True
        elif isinstance(module, nn.Module):
            for param in module.parameters():
                param.requires_grad = True

    # Unlock the requested number of layer groups, counted from the end.
    num_groups_to_unlock = min(unlocked_layers, len(groups))
    for group in groups[-num_groups_to_unlock:]:
        _unlock(group)
|