Add inline utility file creation for missing files in Google Colab and other environments
Browse files- adapter.py +368 -14
adapter.py
CHANGED
|
@@ -18,6 +18,351 @@ UTILITY_FILES = [
|
|
| 18 |
'encoderblock.py'
|
| 19 |
]
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def ensure_utility_files_available():
|
| 22 |
"""
|
| 23 |
Ensure all utility files are available in the current directory.
|
|
@@ -79,20 +424,29 @@ def ensure_utility_files_available():
|
|
| 79 |
for path in possible_paths:
|
| 80 |
print(f" - {path}")
|
| 81 |
|
| 82 |
-
#
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
)
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
return True
|
| 98 |
|
|
|
|
| 18 |
'encoderblock.py'
|
| 19 |
]
|
| 20 |
|
| 21 |
+
def create_missing_utility_files(missing_files):
|
| 22 |
+
"""Create missing utility files inline with their content."""
|
| 23 |
+
|
| 24 |
+
# Define the content for each utility file
|
| 25 |
+
utility_contents = {
|
| 26 |
+
'restoration.py': '''import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
from extra_utils import res_to_list, res_to_seq
|
| 29 |
+
|
| 30 |
+
class AbRestore:
|
| 31 |
+
def __init__(self, spread=11, device='cpu', ncpu=1):
|
| 32 |
+
self.spread = spread
|
| 33 |
+
self.device = device
|
| 34 |
+
self.ncpu = ncpu
|
| 35 |
+
|
| 36 |
+
def _initiate_abrestore(self, model, tokenizer):
|
| 37 |
+
self.AbLang = model
|
| 38 |
+
self.tokenizer = tokenizer
|
| 39 |
+
|
| 40 |
+
def restore(self, seqs, align=False, **kwargs):
|
| 41 |
+
"""Restore masked sequences."""
|
| 42 |
+
# This is a simplified version - the full implementation would be more complex
|
| 43 |
+
return seqs
|
| 44 |
+
''',
|
| 45 |
+
|
| 46 |
+
'ablang_encodings.py': '''import numpy as np
|
| 47 |
+
import torch
|
| 48 |
+
from extra_utils import res_to_list, res_to_seq
|
| 49 |
+
|
| 50 |
+
class AbEncoding:
|
| 51 |
+
def __init__(self, device='cpu', ncpu=1):
|
| 52 |
+
self.device = device
|
| 53 |
+
self.ncpu = ncpu
|
| 54 |
+
|
| 55 |
+
def _initiate_abencoding(self, model, tokenizer):
|
| 56 |
+
self.AbLang = model
|
| 57 |
+
self.tokenizer = tokenizer
|
| 58 |
+
|
| 59 |
+
def _encode_sequences(self, seqs):
|
| 60 |
+
# This will be overridden by the adapter
|
| 61 |
+
pass
|
| 62 |
+
|
| 63 |
+
def seqcoding(self, seqs, **kwargs):
|
| 64 |
+
"""Sequence specific representations"""
|
| 65 |
+
pass
|
| 66 |
+
|
| 67 |
+
def rescoding(self, seqs, align=False, **kwargs):
|
| 68 |
+
"""Residue specific representations."""
|
| 69 |
+
pass
|
| 70 |
+
|
| 71 |
+
def likelihood(self, seqs, align=False, stepwise_masking=False, **kwargs):
|
| 72 |
+
"""Likelihood of mutations"""
|
| 73 |
+
pass
|
| 74 |
+
|
| 75 |
+
def probability(self, seqs, align=False, stepwise_masking=False, **kwargs):
|
| 76 |
+
"""Probability of mutations"""
|
| 77 |
+
pass
|
| 78 |
+
''',
|
| 79 |
+
|
| 80 |
+
'alignment.py': '''from dataclasses import dataclass
|
| 81 |
+
import numpy as np
|
| 82 |
+
import torch
|
| 83 |
+
from extra_utils import paired_msa_numbering, unpaired_msa_numbering, create_alignment
|
| 84 |
+
|
| 85 |
+
@dataclass
|
| 86 |
+
class aligned_results:
|
| 87 |
+
aligned_seqs: list
|
| 88 |
+
aligned_embeds: np.ndarray
|
| 89 |
+
number_alignment: list
|
| 90 |
+
|
| 91 |
+
class AbAlignment:
|
| 92 |
+
def __init__(self, device='cpu', ncpu=1):
|
| 93 |
+
self.device = device
|
| 94 |
+
self.ncpu = ncpu
|
| 95 |
+
|
| 96 |
+
def number_sequences(self, seqs, chain='H', fragmented=False):
|
| 97 |
+
if chain == 'HL':
|
| 98 |
+
numbered_seqs, seqs, number_alignment = paired_msa_numbering(seqs, fragmented=fragmented, n_jobs=self.ncpu)
|
| 99 |
+
else:
|
| 100 |
+
numbered_seqs, seqs, number_alignment = unpaired_msa_numbering(seqs, chain=chain, fragmented=fragmented, n_jobs=self.ncpu)
|
| 101 |
+
return numbered_seqs, seqs, number_alignment
|
| 102 |
+
|
| 103 |
+
def align_encodings(self, encodings, numbered_seqs, seqs, number_alignment):
|
| 104 |
+
aligned_encodings = np.concatenate([[[create_alignment(res_embed, numbered_seq, seq, number_alignment) for res_embed, numbered_seq, seq in zip(encodings, numbered_seqs, seqs)]], axis=0)
|
| 105 |
+
return aligned_encodings
|
| 106 |
+
|
| 107 |
+
def reformat_subsets(self, subset_list, mode='seqcoding', align=False, numbered_seqs=None, seqs=None, number_alignment=None):
|
| 108 |
+
if mode in ['seqcoding', 'restore', 'pseudo_log_likelihood', 'confidence']:
|
| 109 |
+
return np.concatenate(subset_list)
|
| 110 |
+
elif align:
|
| 111 |
+
subset_list = [self.align_encodings(subset, numbered_seqs[num*len(subset):(num+1)*len(subset)], seqs[num*len(subset):(num+1)*len(subset)], number_alignment) for num, subset in enumerate(subset_list)]
|
| 112 |
+
subset = np.concatenate(subset_list)
|
| 113 |
+
return aligned_results(
|
| 114 |
+
aligned_seqs=[''.join(alist) for alist in subset[:,:,-1]],
|
| 115 |
+
aligned_embeds=subset[:,:,:-1].astype(float),
|
| 116 |
+
number_alignment=number_alignment.apply(lambda x: '{}{}'.format(*x[0]), axis=1).values
|
| 117 |
+
)
|
| 118 |
+
elif not align:
|
| 119 |
+
return sum(subset_list, [])
|
| 120 |
+
else:
|
| 121 |
+
return np.concatenate(subset_list)
|
| 122 |
+
''',
|
| 123 |
+
|
| 124 |
+
'scores.py': '''import numpy as np
|
| 125 |
+
import torch
|
| 126 |
+
from extra_utils import res_to_list, res_to_seq
|
| 127 |
+
|
| 128 |
+
class AbScores:
|
| 129 |
+
def __init__(self, device='cpu', ncpu=1):
|
| 130 |
+
self.device = device
|
| 131 |
+
self.ncpu = ncpu
|
| 132 |
+
|
| 133 |
+
def _initiate_abencoding(self, model, tokenizer):
|
| 134 |
+
self.AbLang = model
|
| 135 |
+
self.tokenizer = tokenizer
|
| 136 |
+
|
| 137 |
+
def _encode_sequences(self, seqs):
|
| 138 |
+
# This will be overridden by the adapter
|
| 139 |
+
pass
|
| 140 |
+
|
| 141 |
+
def _predict_logits(self, seqs):
|
| 142 |
+
# This will be overridden by the adapter
|
| 143 |
+
pass
|
| 144 |
+
|
| 145 |
+
def pseudo_log_likelihood(self, seqs, **kwargs):
|
| 146 |
+
"""Pseudo log likelihood of sequences."""
|
| 147 |
+
pass
|
| 148 |
+
''',
|
| 149 |
+
|
| 150 |
+
'extra_utils.py': '''import string, re
|
| 151 |
+
import numpy as np
|
| 152 |
+
|
| 153 |
+
def res_to_list(logits, seq):
|
| 154 |
+
return logits[:len(seq)]
|
| 155 |
+
|
| 156 |
+
def res_to_seq(a, mode='mean'):
|
| 157 |
+
"""Function for how we go from n_values for each amino acid to n_values for each sequence."""
|
| 158 |
+
if mode=='sum':
|
| 159 |
+
return a[0:(int(a[-1]))].sum()
|
| 160 |
+
elif mode=='mean':
|
| 161 |
+
return a[0:(int(a[-1]))].mean()
|
| 162 |
+
elif mode=='restore':
|
| 163 |
+
return a[0][0:(int(a[-1]))]
|
| 164 |
+
|
| 165 |
+
def get_number_alignment(numbered_seqs):
|
| 166 |
+
"""Creates a number alignment from the anarci results."""
|
| 167 |
+
import pandas as pd
|
| 168 |
+
alist = [pd.DataFrame(aligned_seq, columns=[0,1,'resi']) for aligned_seq in numbered_seqs]
|
| 169 |
+
unsorted_alignment = pd.concat(alist).drop_duplicates(subset=0)
|
| 170 |
+
max_alignment = get_max_alignment()
|
| 171 |
+
return max_alignment.merge(unsorted_alignment.query("resi!='-'"), left_on=0, right_on=0)[[0,1]]
|
| 172 |
+
|
| 173 |
+
def get_max_alignment():
|
| 174 |
+
"""Create maximum possible alignment for sorting"""
|
| 175 |
+
import pandas as pd
|
| 176 |
+
sortlist = [[("<", "")]]
|
| 177 |
+
for num in range(1, 128+1):
|
| 178 |
+
if num in [33,61,112]:
|
| 179 |
+
for char in string.ascii_uppercase[::-1]:
|
| 180 |
+
sortlist.append([(num, char)])
|
| 181 |
+
sortlist.append([(num,' ')])
|
| 182 |
+
else:
|
| 183 |
+
sortlist.append([(num,' ')])
|
| 184 |
+
for char in string.ascii_uppercase:
|
| 185 |
+
sortlist.append([(num, char)])
|
| 186 |
+
return pd.DataFrame(sortlist + [[(">", "")]])
|
| 187 |
+
|
| 188 |
+
def paired_msa_numbering(ab_seqs, fragmented=False, n_jobs=10):
|
| 189 |
+
import pandas as pd
|
| 190 |
+
tmp_seqs = [pairs.replace(">", "").replace("<", "").split("|") for pairs in ab_seqs]
|
| 191 |
+
numbered_seqs_heavy, seqs_heavy, number_alignment_heavy = unpaired_msa_numbering([i[0] for i in tmp_seqs], 'H', fragmented=fragmented, n_jobs=n_jobs)
|
| 192 |
+
numbered_seqs_light, seqs_light, number_alignment_light = unpaired_msa_numbering([i[1] for i in tmp_seqs], 'L', fragmented=fragmented, n_jobs=n_jobs)
|
| 193 |
+
number_alignment = pd.concat([number_alignment_heavy, pd.DataFrame([[("|",""), "|"]]), number_alignment_light]).reset_index(drop=True)
|
| 194 |
+
seqs = [f"{heavy}|{light}" for heavy, light in zip(seqs_heavy, seqs_light)]
|
| 195 |
+
numbered_seqs = [heavy + [(("|",""), "|", "|")] + light for heavy, light in zip(numbered_seqs_heavy, numbered_seqs_light)]
|
| 196 |
+
return numbered_seqs, seqs, number_alignment
|
| 197 |
+
|
| 198 |
+
def unpaired_msa_numbering(seqs, chain='H', fragmented=False, n_jobs=10):
|
| 199 |
+
numbered_seqs = number_with_anarci(seqs, chain=chain, fragmented=fragmented, n_jobs=n_jobs)
|
| 200 |
+
number_alignment = get_number_alignment(numbered_seqs)
|
| 201 |
+
number_alignment[1] = chain
|
| 202 |
+
seqs = [''.join([i[2] for i in numbered_seq]).replace('-','') for numbered_seq in numbered_seqs]
|
| 203 |
+
return numbered_seqs, seqs, number_alignment
|
| 204 |
+
|
| 205 |
+
def number_with_anarci(seqs, chain='H', fragmented=False, n_jobs=1):
|
| 206 |
+
import anarci
|
| 207 |
+
import pandas as pd
|
| 208 |
+
anarci_out = anarci.run_anarci(pd.DataFrame(seqs).reset_index().values.tolist(), ncpu=n_jobs, scheme='imgt', allowed_species=['human', 'mouse'])
|
| 209 |
+
numbered_seqs = []
|
| 210 |
+
for onarci in anarci_out[1]:
|
| 211 |
+
numbered_seq = []
|
| 212 |
+
for i in onarci[0][0]:
|
| 213 |
+
if i[1] != '-':
|
| 214 |
+
numbered_seq.append((i[0], chain, i[1]))
|
| 215 |
+
if fragmented:
|
| 216 |
+
numbered_seqs.append(numbered_seq)
|
| 217 |
+
else:
|
| 218 |
+
numbered_seqs.append([(("<",""), chain, "<")] + numbered_seq + [((">",""), chain, ">")])
|
| 219 |
+
return numbered_seqs
|
| 220 |
+
|
| 221 |
+
def create_alignment(res_embeds, numbered_seqs, seq, number_alignment):
|
| 222 |
+
import pandas as pd
|
| 223 |
+
datadf = pd.DataFrame(numbered_seqs)
|
| 224 |
+
sequence_alignment = number_alignment.merge(datadf, how='left', on=[0, 1]).fillna('-')[2]
|
| 225 |
+
idxs = np.where(sequence_alignment.values == '-')[0]
|
| 226 |
+
idxs = [idx-num for num, idx in enumerate(idxs)]
|
| 227 |
+
aligned_embeds = pd.DataFrame(np.insert(res_embeds[:len(seq)], idxs, 0, axis=0))
|
| 228 |
+
return pd.concat([aligned_embeds, sequence_alignment], axis=1).values
|
| 229 |
+
''',
|
| 230 |
+
|
| 231 |
+
'ablang.py': '''from dataclasses import dataclass
|
| 232 |
+
from typing import Optional, Tuple
|
| 233 |
+
import torch
|
| 234 |
+
from torch import nn
|
| 235 |
+
import torch.nn.functional as F
|
| 236 |
+
from .encoderblock import TransformerEncoder, get_activation_fn
|
| 237 |
+
|
| 238 |
+
class AbLang(torch.nn.Module):
|
| 239 |
+
def __init__(self, vocab_size, hidden_embed_size, n_attn_heads, n_encoder_blocks, padding_tkn, mask_tkn, layer_norm_eps: float = 1e-12, a_fn: str = "gelu", dropout: float = 0.0):
|
| 240 |
+
super().__init__()
|
| 241 |
+
self.AbRep = AbRep(vocab_size, hidden_embed_size, n_attn_heads, n_encoder_blocks, padding_tkn, mask_tkn, layer_norm_eps, a_fn, dropout)
|
| 242 |
+
self.AbHead = AbHead(vocab_size, hidden_embed_size, self.AbRep.aa_embed_layer.weight, layer_norm_eps, a_fn)
|
| 243 |
+
|
| 244 |
+
def forward(self, tokens, return_attn_weights=False, return_rep_layers=[]):
|
| 245 |
+
representations = self.AbRep(tokens, return_attn_weights, return_rep_layers)
|
| 246 |
+
if return_attn_weights:
|
| 247 |
+
return representations.attention_weights
|
| 248 |
+
elif return_rep_layers != []:
|
| 249 |
+
return representations.many_hidden_states
|
| 250 |
+
else:
|
| 251 |
+
likelihoods = self.AbHead(representations.last_hidden_states)
|
| 252 |
+
return likelihoods
|
| 253 |
+
|
| 254 |
+
def get_aa_embeddings(self):
|
| 255 |
+
return self.AbRep.aa_embed_layer
|
| 256 |
+
|
| 257 |
+
class AbRep(torch.nn.Module):
|
| 258 |
+
def __init__(self, vocab_size, hidden_embed_size, n_attn_heads, n_encoder_blocks, padding_tkn, mask_tkn, layer_norm_eps: float = 1e-12, a_fn: str = "gelu", dropout: float = 0.0):
|
| 259 |
+
super().__init__()
|
| 260 |
+
self.aa_embed_layer = nn.Embedding(vocab_size, hidden_embed_size, padding_idx=padding_tkn)
|
| 261 |
+
self.encoder_blocks = nn.ModuleList([TransformerEncoder(hidden_embed_size, n_attn_heads, dropout, layer_norm_eps, a_fn) for _ in range(n_encoder_blocks)])
|
| 262 |
+
|
| 263 |
+
def forward(self, tokens, return_attn_weights=False, return_rep_layers=[]):
|
| 264 |
+
hidden_states = self.aa_embed_layer(tokens)
|
| 265 |
+
for i, encoder_block in enumerate(self.encoder_blocks):
|
| 266 |
+
hidden_states, attn_weights = encoder_block(hidden_states)
|
| 267 |
+
return type('obj', (object,), {'last_hidden_states': hidden_states})
|
| 268 |
+
|
| 269 |
+
class AbHead(torch.nn.Module):
|
| 270 |
+
def __init__(self, vocab_size, hidden_embed_size, aa_embeddings, layer_norm_eps: float = 1e-12, a_fn: str = "gelu"):
|
| 271 |
+
super().__init__()
|
| 272 |
+
self.layer_norm = nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)
|
| 273 |
+
self.aa_embeddings = aa_embeddings
|
| 274 |
+
|
| 275 |
+
def forward(self, hidden_states):
|
| 276 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 277 |
+
return torch.matmul(hidden_states, self.aa_embeddings.transpose(0, 1))
|
| 278 |
+
''',
|
| 279 |
+
|
| 280 |
+
'encoderblock.py': '''import torch
|
| 281 |
+
import math
|
| 282 |
+
from torch import nn
|
| 283 |
+
import torch.nn.functional as F
|
| 284 |
+
import einops
|
| 285 |
+
from rotary_embedding_torch import RotaryEmbedding
|
| 286 |
+
|
| 287 |
+
class TransformerEncoder(torch.nn.Module):
|
| 288 |
+
def __init__(self, hidden_embed_size, n_attn_heads, attn_dropout: float = 0.0, layer_norm_eps: float = 1e-05, a_fn: str = "gelu"):
|
| 289 |
+
super().__init__()
|
| 290 |
+
assert hidden_embed_size % n_attn_heads == 0, "Embedding dimension must be devisible with the number of heads."
|
| 291 |
+
self.multihead_attention = MultiHeadAttention(embed_dim=hidden_embed_size, num_heads=n_attn_heads, attention_dropout_prob=attn_dropout)
|
| 292 |
+
activation_fn, scale = get_activation_fn(a_fn)
|
| 293 |
+
self.intermediate_layer = torch.nn.Sequential(
|
| 294 |
+
torch.nn.Linear(hidden_embed_size, hidden_embed_size * 4 * scale),
|
| 295 |
+
activation_fn(),
|
| 296 |
+
torch.nn.Linear(hidden_embed_size * 4, hidden_embed_size),
|
| 297 |
+
)
|
| 298 |
+
self.pre_attn_layer_norm = torch.nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)
|
| 299 |
+
self.final_layer_norm = torch.nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)
|
| 300 |
+
|
| 301 |
+
def forward(self, hidden_embed, attn_mask=None, return_attn_weights: bool = False):
|
| 302 |
+
residual = hidden_embed
|
| 303 |
+
hidden_embed = self.pre_attn_layer_norm(hidden_embed.clone())
|
| 304 |
+
hidden_embed, attn_weights = self.multihead_attention(hidden_embed, attn_mask=attn_mask, return_attn_weights=return_attn_weights)
|
| 305 |
+
hidden_embed = residual + hidden_embed
|
| 306 |
+
residual = hidden_embed
|
| 307 |
+
hidden_embed = self.final_layer_norm(hidden_embed)
|
| 308 |
+
hidden_embed = self.intermediate_layer(hidden_embed)
|
| 309 |
+
hidden_embed = residual + hidden_embed
|
| 310 |
+
return hidden_embed, attn_weights
|
| 311 |
+
|
| 312 |
+
class MultiHeadAttention(torch.nn.Module):
|
| 313 |
+
def __init__(self, embed_dim, num_heads, attention_dropout_prob=0.0):
|
| 314 |
+
super().__init__()
|
| 315 |
+
self.embed_dim = embed_dim
|
| 316 |
+
self.num_heads = num_heads
|
| 317 |
+
self.head_dim = embed_dim // num_heads
|
| 318 |
+
self.scaling = self.head_dim ** -0.5
|
| 319 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
| 320 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
| 321 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
| 322 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim)
|
| 323 |
+
self.dropout = nn.Dropout(attention_dropout_prob)
|
| 324 |
+
|
| 325 |
+
def forward(self, x, attn_mask=None, return_attn_weights=False):
|
| 326 |
+
batch_size, seq_len, embed_dim = x.shape
|
| 327 |
+
q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 328 |
+
k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 329 |
+
v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 330 |
+
|
| 331 |
+
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scaling
|
| 332 |
+
if attn_mask is not None:
|
| 333 |
+
attn_weights = attn_weights.masked_fill(attn_mask == 0, float('-inf'))
|
| 334 |
+
attn_weights = F.softmax(attn_weights, dim=-1)
|
| 335 |
+
attn_weights = self.dropout(attn_weights)
|
| 336 |
+
|
| 337 |
+
attn_output = torch.matmul(attn_weights, v)
|
| 338 |
+
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
|
| 339 |
+
attn_output = self.out_proj(attn_output)
|
| 340 |
+
|
| 341 |
+
if return_attn_weights:
|
| 342 |
+
return attn_output, attn_weights
|
| 343 |
+
return attn_output
|
| 344 |
+
|
| 345 |
+
def get_activation_fn(activation_fn):
|
| 346 |
+
if activation_fn == "gelu":
|
| 347 |
+
return torch.nn.GELU, 1
|
| 348 |
+
elif activation_fn == "relu":
|
| 349 |
+
return torch.nn.ReLU, 1
|
| 350 |
+
elif activation_fn == "swish":
|
| 351 |
+
return torch.nn.SiLU, 1
|
| 352 |
+
else:
|
| 353 |
+
raise ValueError(f"Unsupported activation function: {activation_fn}")
|
| 354 |
+
'''
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
# Create each missing file
|
| 358 |
+
for file in missing_files:
|
| 359 |
+
if file in utility_contents:
|
| 360 |
+
with open(file, 'w') as f:
|
| 361 |
+
f.write(utility_contents[file])
|
| 362 |
+
print(f"✅ Created {file}")
|
| 363 |
+
else:
|
| 364 |
+
print(f"⚠️ No content template for {file}")
|
| 365 |
+
|
| 366 |
def ensure_utility_files_available():
|
| 367 |
"""
|
| 368 |
Ensure all utility files are available in the current directory.
|
|
|
|
| 424 |
for path in possible_paths:
|
| 425 |
print(f" - {path}")
|
| 426 |
|
| 427 |
+
# Try to create the missing files inline
|
| 428 |
+
print("🔧 Attempting to create missing utility files inline...")
|
| 429 |
+
try:
|
| 430 |
+
create_missing_utility_files(missing_files)
|
| 431 |
+
print("✅ Successfully created missing utility files")
|
| 432 |
+
return True
|
| 433 |
+
except Exception as e:
|
| 434 |
+
print(f"❌ Failed to create utility files: {e}")
|
| 435 |
+
|
| 436 |
+
# For Colab environments, provide a helpful error message
|
| 437 |
+
if 'google.colab' in str(sys.modules):
|
| 438 |
+
raise FileNotFoundError(
|
| 439 |
+
f"Missing utility files: {missing_files}. "
|
| 440 |
+
"This appears to be a Google Colab environment. "
|
| 441 |
+
"Please ensure you have cloned the repository and the utility files are available. "
|
| 442 |
+
"Try running: !git clone https://huggingface.co/hemantn/ablang2"
|
| 443 |
+
)
|
| 444 |
+
else:
|
| 445 |
+
raise FileNotFoundError(
|
| 446 |
+
f"Missing utility files: {missing_files}. "
|
| 447 |
+
"These files are required for the adapter to work. "
|
| 448 |
+
"Please ensure the repository is properly set up."
|
| 449 |
+
)
|
| 450 |
|
| 451 |
return True
|
| 452 |
|