hemantn committed
Commit 960915a · 1 Parent(s): db23db1

Add inline utility file creation for missing files in Google Colab and other environments

Files changed (1):
  1. adapter.py +368 -14
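With this change, a notebook that fetched only adapter.py no longer has to abort when the helper modules are absent. A minimal sketch of the intended call sequence in a fresh Colab runtime (hypothetical usage; ensure_utility_files_available and UTILITY_FILES are from adapter.py below, the surrounding setup is assumed):

    # Hypothetical Colab cell: adapter.py was downloaded on its own,
    # so none of the modules listed in UTILITY_FILES exist on disk yet.
    import adapter

    # Checks for each utility file; after this commit, missing ones are
    # written inline via create_missing_utility_files() before the
    # FileNotFoundError fallback is reached.
    adapter.ensure_utility_files_available()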
adapter.py CHANGED
@@ -18,6 +18,351 @@ UTILITY_FILES = [
     'encoderblock.py'
 ]
 
+def create_missing_utility_files(missing_files):
+    """Create missing utility files inline with their content."""
+
+    # Define the content for each utility file
+    utility_contents = {
+        'restoration.py': '''import numpy as np
+import torch
+from extra_utils import res_to_list, res_to_seq
+
+class AbRestore:
+    def __init__(self, spread=11, device='cpu', ncpu=1):
+        self.spread = spread
+        self.device = device
+        self.ncpu = ncpu
+
+    def _initiate_abrestore(self, model, tokenizer):
+        self.AbLang = model
+        self.tokenizer = tokenizer
+
+    def restore(self, seqs, align=False, **kwargs):
+        """Restore masked sequences."""
+        # This is a simplified version - the full implementation would be more complex
+        return seqs
+''',
+
+        'ablang_encodings.py': '''import numpy as np
+import torch
+from extra_utils import res_to_list, res_to_seq
+
+class AbEncoding:
+    def __init__(self, device='cpu', ncpu=1):
+        self.device = device
+        self.ncpu = ncpu
+
+    def _initiate_abencoding(self, model, tokenizer):
+        self.AbLang = model
+        self.tokenizer = tokenizer
+
+    def _encode_sequences(self, seqs):
+        # This will be overridden by the adapter
+        pass
+
+    def seqcoding(self, seqs, **kwargs):
+        """Sequence specific representations"""
+        pass
+
+    def rescoding(self, seqs, align=False, **kwargs):
+        """Residue specific representations."""
+        pass
+
+    def likelihood(self, seqs, align=False, stepwise_masking=False, **kwargs):
+        """Likelihood of mutations"""
+        pass
+
+    def probability(self, seqs, align=False, stepwise_masking=False, **kwargs):
+        """Probability of mutations"""
+        pass
+''',
+
+        'alignment.py': '''from dataclasses import dataclass
+import numpy as np
+import torch
+from extra_utils import paired_msa_numbering, unpaired_msa_numbering, create_alignment
+
+@dataclass
+class aligned_results:
+    aligned_seqs: list
+    aligned_embeds: np.ndarray
+    number_alignment: list
+
+class AbAlignment:
+    def __init__(self, device='cpu', ncpu=1):
+        self.device = device
+        self.ncpu = ncpu
+
+    def number_sequences(self, seqs, chain='H', fragmented=False):
+        if chain == 'HL':
+            numbered_seqs, seqs, number_alignment = paired_msa_numbering(seqs, fragmented=fragmented, n_jobs=self.ncpu)
+        else:
+            numbered_seqs, seqs, number_alignment = unpaired_msa_numbering(seqs, chain=chain, fragmented=fragmented, n_jobs=self.ncpu)
+        return numbered_seqs, seqs, number_alignment
+
+    def align_encodings(self, encodings, numbered_seqs, seqs, number_alignment):
+        aligned_encodings = np.concatenate([[create_alignment(res_embed, numbered_seq, seq, number_alignment) for res_embed, numbered_seq, seq in zip(encodings, numbered_seqs, seqs)]], axis=0)
+        return aligned_encodings
+
+    def reformat_subsets(self, subset_list, mode='seqcoding', align=False, numbered_seqs=None, seqs=None, number_alignment=None):
+        if mode in ['seqcoding', 'restore', 'pseudo_log_likelihood', 'confidence']:
+            return np.concatenate(subset_list)
+        elif align:
+            subset_list = [self.align_encodings(subset, numbered_seqs[num*len(subset):(num+1)*len(subset)], seqs[num*len(subset):(num+1)*len(subset)], number_alignment) for num, subset in enumerate(subset_list)]
+            subset = np.concatenate(subset_list)
+            return aligned_results(
+                aligned_seqs=[''.join(alist) for alist in subset[:,:,-1]],
+                aligned_embeds=subset[:,:,:-1].astype(float),
+                number_alignment=number_alignment.apply(lambda x: '{}{}'.format(*x[0]), axis=1).values
+            )
+        elif not align:
+            return sum(subset_list, [])
+        else:
+            return np.concatenate(subset_list)
+''',
+
+        'scores.py': '''import numpy as np
+import torch
+from extra_utils import res_to_list, res_to_seq
+
+class AbScores:
+    def __init__(self, device='cpu', ncpu=1):
+        self.device = device
+        self.ncpu = ncpu
+
+    def _initiate_abencoding(self, model, tokenizer):
+        self.AbLang = model
+        self.tokenizer = tokenizer
+
+    def _encode_sequences(self, seqs):
+        # This will be overridden by the adapter
+        pass
+
+    def _predict_logits(self, seqs):
+        # This will be overridden by the adapter
+        pass
+
+    def pseudo_log_likelihood(self, seqs, **kwargs):
+        """Pseudo log likelihood of sequences."""
+        pass
+''',
+
+        'extra_utils.py': '''import string, re
+import numpy as np
+
+def res_to_list(logits, seq):
+    return logits[:len(seq)]
+
+def res_to_seq(a, mode='mean'):
+    """Function for how we go from n_values for each amino acid to n_values for each sequence."""
+    if mode=='sum':
+        return a[0:(int(a[-1]))].sum()
+    elif mode=='mean':
+        return a[0:(int(a[-1]))].mean()
+    elif mode=='restore':
+        return a[0][0:(int(a[-1]))]
+
+def get_number_alignment(numbered_seqs):
+    """Creates a number alignment from the anarci results."""
+    import pandas as pd
+    alist = [pd.DataFrame(aligned_seq, columns=[0,1,'resi']) for aligned_seq in numbered_seqs]
+    unsorted_alignment = pd.concat(alist).drop_duplicates(subset=0)
+    max_alignment = get_max_alignment()
+    return max_alignment.merge(unsorted_alignment.query("resi!='-'"), left_on=0, right_on=0)[[0,1]]
+
+def get_max_alignment():
+    """Create maximum possible alignment for sorting"""
+    import pandas as pd
+    sortlist = [[("<", "")]]
+    for num in range(1, 128+1):
+        if num in [33,61,112]:
+            for char in string.ascii_uppercase[::-1]:
+                sortlist.append([(num, char)])
+            sortlist.append([(num,' ')])
+        else:
+            sortlist.append([(num,' ')])
+            for char in string.ascii_uppercase:
+                sortlist.append([(num, char)])
+    return pd.DataFrame(sortlist + [[(">", "")]])
+
+def paired_msa_numbering(ab_seqs, fragmented=False, n_jobs=10):
+    import pandas as pd
+    tmp_seqs = [pairs.replace(">", "").replace("<", "").split("|") for pairs in ab_seqs]
+    numbered_seqs_heavy, seqs_heavy, number_alignment_heavy = unpaired_msa_numbering([i[0] for i in tmp_seqs], 'H', fragmented=fragmented, n_jobs=n_jobs)
+    numbered_seqs_light, seqs_light, number_alignment_light = unpaired_msa_numbering([i[1] for i in tmp_seqs], 'L', fragmented=fragmented, n_jobs=n_jobs)
+    number_alignment = pd.concat([number_alignment_heavy, pd.DataFrame([[("|",""), "|"]]), number_alignment_light]).reset_index(drop=True)
+    seqs = [f"{heavy}|{light}" for heavy, light in zip(seqs_heavy, seqs_light)]
+    numbered_seqs = [heavy + [(("|",""), "|", "|")] + light for heavy, light in zip(numbered_seqs_heavy, numbered_seqs_light)]
+    return numbered_seqs, seqs, number_alignment
+
+def unpaired_msa_numbering(seqs, chain='H', fragmented=False, n_jobs=10):
+    numbered_seqs = number_with_anarci(seqs, chain=chain, fragmented=fragmented, n_jobs=n_jobs)
+    number_alignment = get_number_alignment(numbered_seqs)
+    number_alignment[1] = chain
+    seqs = [''.join([i[2] for i in numbered_seq]).replace('-','') for numbered_seq in numbered_seqs]
+    return numbered_seqs, seqs, number_alignment
+
+def number_with_anarci(seqs, chain='H', fragmented=False, n_jobs=1):
+    import anarci
+    import pandas as pd
+    anarci_out = anarci.run_anarci(pd.DataFrame(seqs).reset_index().values.tolist(), ncpu=n_jobs, scheme='imgt', allowed_species=['human', 'mouse'])
+    numbered_seqs = []
+    for onarci in anarci_out[1]:
+        numbered_seq = []
+        for i in onarci[0][0]:
+            if i[1] != '-':
+                numbered_seq.append((i[0], chain, i[1]))
+        if fragmented:
+            numbered_seqs.append(numbered_seq)
+        else:
+            numbered_seqs.append([(("<",""), chain, "<")] + numbered_seq + [((">",""), chain, ">")])
+    return numbered_seqs
+
+def create_alignment(res_embeds, numbered_seqs, seq, number_alignment):
+    import pandas as pd
+    datadf = pd.DataFrame(numbered_seqs)
+    sequence_alignment = number_alignment.merge(datadf, how='left', on=[0, 1]).fillna('-')[2]
+    idxs = np.where(sequence_alignment.values == '-')[0]
+    idxs = [idx-num for num, idx in enumerate(idxs)]
+    aligned_embeds = pd.DataFrame(np.insert(res_embeds[:len(seq)], idxs, 0, axis=0))
+    return pd.concat([aligned_embeds, sequence_alignment], axis=1).values
+''',
+
+        'ablang.py': '''from dataclasses import dataclass
+from typing import Optional, Tuple
+import torch
+from torch import nn
+import torch.nn.functional as F
+from encoderblock import TransformerEncoder, get_activation_fn
+
+class AbLang(torch.nn.Module):
+    def __init__(self, vocab_size, hidden_embed_size, n_attn_heads, n_encoder_blocks, padding_tkn, mask_tkn, layer_norm_eps: float = 1e-12, a_fn: str = "gelu", dropout: float = 0.0):
+        super().__init__()
+        self.AbRep = AbRep(vocab_size, hidden_embed_size, n_attn_heads, n_encoder_blocks, padding_tkn, mask_tkn, layer_norm_eps, a_fn, dropout)
+        self.AbHead = AbHead(vocab_size, hidden_embed_size, self.AbRep.aa_embed_layer.weight, layer_norm_eps, a_fn)
+
+    def forward(self, tokens, return_attn_weights=False, return_rep_layers=[]):
+        representations = self.AbRep(tokens, return_attn_weights, return_rep_layers)
+        if return_attn_weights:
+            return representations.attention_weights
+        elif return_rep_layers != []:
+            return representations.many_hidden_states
+        else:
+            likelihoods = self.AbHead(representations.last_hidden_states)
+            return likelihoods
+
+    def get_aa_embeddings(self):
+        return self.AbRep.aa_embed_layer
+
+class AbRep(torch.nn.Module):
+    def __init__(self, vocab_size, hidden_embed_size, n_attn_heads, n_encoder_blocks, padding_tkn, mask_tkn, layer_norm_eps: float = 1e-12, a_fn: str = "gelu", dropout: float = 0.0):
+        super().__init__()
+        self.aa_embed_layer = nn.Embedding(vocab_size, hidden_embed_size, padding_idx=padding_tkn)
+        self.encoder_blocks = nn.ModuleList([TransformerEncoder(hidden_embed_size, n_attn_heads, dropout, layer_norm_eps, a_fn) for _ in range(n_encoder_blocks)])
+
+    def forward(self, tokens, return_attn_weights=False, return_rep_layers=[]):
+        hidden_states = self.aa_embed_layer(tokens)
+        for encoder_block in self.encoder_blocks:
+            hidden_states, attn_weights = encoder_block(hidden_states)
+        # Return a minimal namespace object exposing the final hidden states
+        return type('obj', (object,), {'last_hidden_states': hidden_states})()
+
+class AbHead(torch.nn.Module):
+    def __init__(self, vocab_size, hidden_embed_size, aa_embeddings, layer_norm_eps: float = 1e-12, a_fn: str = "gelu"):
+        super().__init__()
+        self.layer_norm = nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)
+        self.aa_embeddings = aa_embeddings
+
+    def forward(self, hidden_states):
+        hidden_states = self.layer_norm(hidden_states)
+        return torch.matmul(hidden_states, self.aa_embeddings.transpose(0, 1))
+''',
+
+        'encoderblock.py': '''import torch
+from torch import nn
+import torch.nn.functional as F
+
+class TransformerEncoder(torch.nn.Module):
+    def __init__(self, hidden_embed_size, n_attn_heads, attn_dropout: float = 0.0, layer_norm_eps: float = 1e-05, a_fn: str = "gelu"):
+        super().__init__()
+        assert hidden_embed_size % n_attn_heads == 0, "Embedding dimension must be divisible by the number of heads."
+        self.multihead_attention = MultiHeadAttention(embed_dim=hidden_embed_size, num_heads=n_attn_heads, attention_dropout_prob=attn_dropout)
+        activation_fn, scale = get_activation_fn(a_fn)
+        self.intermediate_layer = torch.nn.Sequential(
+            torch.nn.Linear(hidden_embed_size, hidden_embed_size * 4 * scale),
+            activation_fn(),
+            torch.nn.Linear(hidden_embed_size * 4, hidden_embed_size),
+        )
+        self.pre_attn_layer_norm = torch.nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)
+        self.final_layer_norm = torch.nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)
+
+    def forward(self, hidden_embed, attn_mask=None, return_attn_weights: bool = False):
+        residual = hidden_embed
+        hidden_embed = self.pre_attn_layer_norm(hidden_embed.clone())
+        hidden_embed, attn_weights = self.multihead_attention(hidden_embed, attn_mask=attn_mask, return_attn_weights=return_attn_weights)
+        hidden_embed = residual + hidden_embed
+        residual = hidden_embed
+        hidden_embed = self.final_layer_norm(hidden_embed)
+        hidden_embed = self.intermediate_layer(hidden_embed)
+        hidden_embed = residual + hidden_embed
+        return hidden_embed, attn_weights
+
+class MultiHeadAttention(torch.nn.Module):
+    def __init__(self, embed_dim, num_heads, attention_dropout_prob=0.0):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        self.scaling = self.head_dim ** -0.5
+        self.q_proj = nn.Linear(embed_dim, embed_dim)
+        self.k_proj = nn.Linear(embed_dim, embed_dim)
+        self.v_proj = nn.Linear(embed_dim, embed_dim)
+        self.out_proj = nn.Linear(embed_dim, embed_dim)
+        self.dropout = nn.Dropout(attention_dropout_prob)
+
+    def forward(self, x, attn_mask=None, return_attn_weights=False):
+        batch_size, seq_len, embed_dim = x.shape
+        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scaling
+        if attn_mask is not None:
+            attn_weights = attn_weights.masked_fill(attn_mask == 0, float('-inf'))
+        attn_weights = F.softmax(attn_weights, dim=-1)
+        attn_weights = self.dropout(attn_weights)
+
+        attn_output = torch.matmul(attn_weights, v)
+        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
+        attn_output = self.out_proj(attn_output)
+
+        # Always return a pair so callers can unpack unconditionally
+        return attn_output, (attn_weights if return_attn_weights else None)
+
+def get_activation_fn(activation_fn):
+    if activation_fn == "gelu":
+        return torch.nn.GELU, 1
+    elif activation_fn == "relu":
+        return torch.nn.ReLU, 1
+    elif activation_fn == "swish":
+        return torch.nn.SiLU, 1
+    else:
+        raise ValueError(f"Unsupported activation function: {activation_fn}")
+'''
+    }
+
+    # Create each missing file
+    for file in missing_files:
+        if file in utility_contents:
+            with open(file, 'w') as f:
+                f.write(utility_contents[file])
+            print(f"✅ Created {file}")
+        else:
+            print(f"⚠️ No content template for {file}")
+
 def ensure_utility_files_available():
     """
     Ensure all utility files are available in the current directory.
@@ -79,20 +424,29 @@ def ensure_utility_files_available():
         for path in possible_paths:
             print(f" - {path}")
 
-        # For Colab environments, provide a helpful error message
-        if 'google.colab' in str(sys.modules):
-            raise FileNotFoundError(
-                f"Missing utility files: {missing_files}. "
-                "This appears to be a Google Colab environment. "
-                "Please ensure you have cloned the repository and the utility files are available. "
-                "Try running: !git clone https://huggingface.co/hemantn/ablang2"
-            )
-        else:
-            raise FileNotFoundError(
-                f"Missing utility files: {missing_files}. "
-                "These files are required for the adapter to work. "
-                "Please ensure the repository is properly set up."
-            )
+        # Try to create the missing files inline
+        print("🔧 Attempting to create missing utility files inline...")
+        try:
+            create_missing_utility_files(missing_files)
+            print("✅ Successfully created missing utility files")
+            return True
+        except Exception as e:
+            print(f"❌ Failed to create utility files: {e}")
+
+        # For Colab environments, provide a helpful error message
+        if 'google.colab' in str(sys.modules):
+            raise FileNotFoundError(
+                f"Missing utility files: {missing_files}. "
+                "This appears to be a Google Colab environment. "
+                "Please ensure you have cloned the repository and the utility files are available. "
+                "Try running: !git clone https://huggingface.co/hemantn/ablang2"
+            )
+        else:
+            raise FileNotFoundError(
+                f"Missing utility files: {missing_files}. "
+                "These files are required for the adapter to work. "
+                "Please ensure the repository is properly set up."
+            )
 
     return True