hemantn committed
Commit 7f948f6 · 1 Parent(s): 1f81e2b

Fix adapter to copy utility files to cached directory when loaded from Hugging Face

Files changed (2)
  1. adapter.py +74 -12
  2. test_adapter_fix.py +65 -0
adapter.py CHANGED
@@ -2,12 +2,74 @@ import os
 import sys
 import torch
 import numpy as np
+import shutil
 
 # Get the directory where this adapter.py file is located
 current_dir = os.path.dirname(os.path.abspath(__file__))
 if current_dir not in sys.path:
     sys.path.insert(0, current_dir)
 
+# List of utility files that need to be available
+UTILITY_FILES = [
+    'restoration.py',
+    'ablang_encodings.py',
+    'alignment.py',
+    'scores.py',
+    'extra_utils.py',
+    'ablang.py',
+    'encoderblock.py'
+]
+
+def ensure_utility_files_available():
+    """
+    Ensure all utility files are available in the current directory.
+    If any are missing, try to copy them from the repository root.
+    """
+    missing_files = []
+    for file in UTILITY_FILES:
+        if not os.path.exists(os.path.join(current_dir, file)):  # check next to adapter.py, not the CWD
+            missing_files.append(file)
+
+    if missing_files:
+        # Try to find the repository root (where all utility files are)
+        # Look for common parent directories that might contain the files
+        possible_paths = [
+            os.path.join(current_dir, '..'),  # Parent directory
+            os.path.join(current_dir, '..', '..'),  # Grandparent directory
+            os.path.join(os.path.expanduser('~'), 'ablang2'),  # Home directory
+            '/data/hn533621/ablang2',  # Known repository location
+        ]
+
+        for path in possible_paths:
+            if os.path.exists(path):
+                # Check if all missing files exist in this path
+                all_found = True
+                for file in missing_files:
+                    if not os.path.exists(os.path.join(path, file)):
+                        all_found = False
+                        break
+
+                if all_found:
+                    # Copy all missing files
+                    for file in missing_files:
+                        src = os.path.join(path, file)
+                        dst = os.path.join(current_dir, file)
+                        shutil.copy2(src, dst)
+                        print(f"✅ Copied {file} to cached directory")
+                    return True
+
+        # If we get here, we couldn't find the files
+        raise FileNotFoundError(
+            f"Missing utility files: {missing_files}. "
+            "These files are required for the adapter to work. "
+            "Please ensure the repository is properly set up."
+        )
+
+    return True
+
+# Ensure utility files are available before importing
+ensure_utility_files_available()
+
 # Import utility modules
 from restoration import AbRestore
 from ablang_encodings import AbEncoding
@@ -131,7 +193,7 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
         if fragmented:
             # For fragmented sequences, assume they're already in the right format
             return seqs, 'HL'
-
+
         # For paired sequences, format them as VH|VL
         formatted_seqs = []
         for seq in seqs:
@@ -151,7 +213,7 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
                formatted_seqs.append(seq[0] if seq else "")
            else:
                formatted_seqs.append(seq)
-
+
        return formatted_seqs, 'HL'
 
    valid_modes = [
@@ -245,34 +307,34 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
                formatted_seqs.append('|'.join(s))
            else:
                formatted_seqs.append(s)
-
+
        plls = []
        for seq in formatted_seqs:
            tokens = self.tokenizer([seq], padding=True, return_tensors='pt')
            input_ids = extract_input_ids(tokens, self.used_device)
-
+
            with torch.no_grad():
                output = self.AbLang(input_ids)
                if hasattr(output, 'last_hidden_state'):
                    logits = output.last_hidden_state
                else:
                    logits = output
-
+
            # Get the sequence (remove batch dimension)
            logits = logits[0]  # [seq_len, vocab_size]
            input_ids = input_ids[0]  # [seq_len]
-
+
            # Exclude all special tokens (pad, mask, etc.)
            if isinstance(self.tokenizer.all_special_tokens[0], int):
                special_token_ids = set(self.tokenizer.all_special_tokens)
            else:
                special_token_ids = set(self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens)
            valid_mask = ~torch.isin(input_ids, torch.tensor(list(special_token_ids), device=input_ids.device))
-
+
            if valid_mask.sum() > 0:
                valid_logits = logits[valid_mask]
                valid_labels = input_ids[valid_mask]
-
+
                # Calculate cross-entropy loss
                nll = torch.nn.functional.cross_entropy(
                    valid_logits,
@@ -282,9 +344,9 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
                pll = -nll.item()
            else:
                pll = 0.0
-
+
            plls.append(pll)
-
+
        return np.array(plls, dtype=np.float32)
 
    def probability(self, seqs, align=False, stepwise_masking=False, **kwargs):
@@ -306,10 +368,10 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
            logits = self._predict_logits(formatted_seqs)
        else:
            logits = self._predict_logits(formatted_seqs)
-
+
        # Apply softmax to get probabilities
        probs = logits.softmax(-1).cpu().numpy()
-
+
        if align:
            return probs
        else:
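
For context, the load path this commit targets looks roughly like the sketch below, distilled from test_adapter_fix.py (shown in full further down). It assumes network access to the hemantn/ablang2 repo and that one of the candidate source directories probed by ensure_utility_files_available() exists on the machine; cached_file is the same transformers helper the test uses to locate the snapshot directory.

import os, sys
from transformers import AutoModel, AutoTokenizer
from transformers.utils import cached_file

# Download/locate the cached snapshot and make it importable
model = AutoModel.from_pretrained("hemantn/ablang2", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("hemantn/ablang2", trust_remote_code=True)
snapshot_dir = os.path.dirname(cached_file("hemantn/ablang2", "adapter.py"))
sys.path.insert(0, snapshot_dir)

# Importing the adapter runs ensure_utility_files_available() first, so
# restoration.py, ablang_encodings.py, etc. are copied next to adapter.py
# before the `from restoration import AbRestore` line executes
from adapter import AbLang2PairedHuggingFaceAdapter
ablang = AbLang2PairedHuggingFaceAdapter(model=model, tokenizer=tokenizer)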
test_adapter_fix.py ADDED
@@ -0,0 +1,65 @@
+#!/usr/bin/env python3
+
+import sys
+import os
+from transformers import AutoModel, AutoTokenizer
+from transformers.utils import cached_file
+
+def test_adapter_from_outside():
+    """Test loading the adapter from outside the repository"""
+    print("🧪 Testing adapter loading from outside repository...")
+
+    # Clear cache first
+    cache_dir = os.path.expanduser("~/.cache/huggingface/hub/models--hemantn--ablang2")
+    if os.path.exists(cache_dir):
+        import shutil
+        shutil.rmtree(cache_dir)
+        print("🗑️ Cleared Hugging Face cache")
+
+    try:
+        # Load model and tokenizer
+        print("📥 Loading model and tokenizer...")
+        model = AutoModel.from_pretrained("hemantn/ablang2", trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained("hemantn/ablang2", trust_remote_code=True)
+
+        # Find the cached model directory and import adapter
+        adapter_path = cached_file("hemantn/ablang2", "adapter.py")
+        cached_model_dir = os.path.dirname(adapter_path)
+        sys.path.insert(0, cached_model_dir)
+
+        print(f"📁 Cached model directory: {cached_model_dir}")
+        print("📄 Files in cached directory:")
+        for f in os.listdir(cached_model_dir):
+            print(f"  {f}")
+
+        # Import and create the adapter
+        print("🔧 Importing adapter...")
+        from adapter import AbLang2PairedHuggingFaceAdapter
+        ablang = AbLang2PairedHuggingFaceAdapter(model=model, tokenizer=tokenizer)
+        print("✅ Adapter created successfully!")
+
+        # Test basic functionality
+        print("🧬 Testing restore functionality...")
+        test_seq = [
+            'EVQ***SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS',
+            'DIQLTQSPLSLPVTLGQPASISCRSS*SLEASDTNIYLSWFQQRPGQSPRRLIYKI*NRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'
+        ]
+
+        restored = ablang(test_seq, mode='restore')
+        print("✅ Restore functionality working!")
+        print(f"📊 Restored sequences: {len(restored)}")
+
+        return True
+
+    except Exception as e:
+        print(f"❌ Error: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+if __name__ == "__main__":
+    success = test_adapter_from_outside()
+    if success:
+        print("🎉 All tests passed! The adapter works from outside the repository.")
+    else:
+        print("💥 Tests failed. The adapter still has issues.")