hemantn commited on
Commit
1304b94
Β·
1 Parent(s): 7f948f6

Restore original adapter code from AbLang2_final_version with file copying mechanism

Browse files
Files changed (2) hide show
  1. adapter.py +48 -24
  2. test_original_compatibility.py +69 -0
adapter.py CHANGED
@@ -1,7 +1,5 @@
1
  import os
2
  import sys
3
- import torch
4
- import numpy as np
5
  import shutil
6
 
7
  # Get the directory where this adapter.py file is located
@@ -12,7 +10,7 @@ if current_dir not in sys.path:
12
  # List of utility files that need to be available
13
  UTILITY_FILES = [
14
  'restoration.py',
15
- 'ablang_encodings.py',
16
  'alignment.py',
17
  'scores.py',
18
  'extra_utils.py',
@@ -29,7 +27,7 @@ def ensure_utility_files_available():
29
  for file in UTILITY_FILES:
30
  if not os.path.exists(file):
31
  missing_files.append(file)
32
-
33
  if missing_files:
34
  # Try to find the repository root (where all utility files are)
35
  # Look for common parent directories that might contain the files
@@ -39,7 +37,7 @@ def ensure_utility_files_available():
39
  os.path.join(os.path.expanduser('~'), 'ablang2'), # Home directory
40
  '/data/hn533621/ablang2', # Known repository location
41
  ]
42
-
43
  for path in possible_paths:
44
  if os.path.exists(path):
45
  # Check if all missing files exist in this path
@@ -48,7 +46,7 @@ def ensure_utility_files_available():
48
  if not os.path.exists(os.path.join(path, file)):
49
  all_found = False
50
  break
51
-
52
  if all_found:
53
  # Copy all missing files
54
  for file in missing_files:
@@ -57,25 +55,51 @@ def ensure_utility_files_available():
57
  shutil.copy2(src, dst)
58
  print(f"βœ… Copied {file} to cached directory")
59
  return True
60
-
61
  # If we get here, we couldn't find the files
62
  raise FileNotFoundError(
63
  f"Missing utility files: {missing_files}. "
64
  "These files are required for the adapter to work. "
65
  "Please ensure the repository is properly set up."
66
  )
67
-
68
  return True
69
 
70
  # Ensure utility files are available before importing
71
  ensure_utility_files_available()
72
 
73
- # Import utility modules
74
- from restoration import AbRestore
75
- from ablang_encodings import AbEncoding
76
- from alignment import AbAlignment
77
- from scores import AbScores
78
- from extra_utils import res_to_seq, res_to_list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  class HuggingFaceTokenizerAdapter:
81
  def __init__(self, tokenizer, device):
@@ -307,34 +331,34 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
307
  formatted_seqs.append('|'.join(s))
308
  else:
309
  formatted_seqs.append(s)
310
-
311
  plls = []
312
  for seq in formatted_seqs:
313
  tokens = self.tokenizer([seq], padding=True, return_tensors='pt')
314
  input_ids = extract_input_ids(tokens, self.used_device)
315
-
316
  with torch.no_grad():
317
  output = self.AbLang(input_ids)
318
  if hasattr(output, 'last_hidden_state'):
319
  logits = output.last_hidden_state
320
  else:
321
  logits = output
322
-
323
  # Get the sequence (remove batch dimension)
324
  logits = logits[0] # [seq_len, vocab_size]
325
  input_ids = input_ids[0] # [seq_len]
326
-
327
  # Exclude all special tokens (pad, mask, etc.)
328
  if isinstance(self.tokenizer.all_special_tokens[0], int):
329
  special_token_ids = set(self.tokenizer.all_special_tokens)
330
  else:
331
  special_token_ids = set(self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens)
332
  valid_mask = ~torch.isin(input_ids, torch.tensor(list(special_token_ids), device=input_ids.device))
333
-
334
  if valid_mask.sum() > 0:
335
  valid_logits = logits[valid_mask]
336
  valid_labels = input_ids[valid_mask]
337
-
338
  # Calculate cross-entropy loss
339
  nll = torch.nn.functional.cross_entropy(
340
  valid_logits,
@@ -344,9 +368,9 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
344
  pll = -nll.item()
345
  else:
346
  pll = 0.0
347
-
348
  plls.append(pll)
349
-
350
  return np.array(plls, dtype=np.float32)
351
 
352
  def probability(self, seqs, align=False, stepwise_masking=False, **kwargs):
@@ -368,10 +392,10 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
368
  logits = self._predict_logits(formatted_seqs)
369
  else:
370
  logits = self._predict_logits(formatted_seqs)
371
-
372
  # Apply softmax to get probabilities
373
  probs = logits.softmax(-1).cpu().numpy()
374
-
375
  if align:
376
  return probs
377
  else:
 
1
  import os
2
  import sys
 
 
3
  import shutil
4
 
5
  # Get the directory where this adapter.py file is located
 
10
  # List of utility files that need to be available
11
  UTILITY_FILES = [
12
  'restoration.py',
13
+ 'ablang_encodings.py',
14
  'alignment.py',
15
  'scores.py',
16
  'extra_utils.py',
 
27
  for file in UTILITY_FILES:
28
  if not os.path.exists(file):
29
  missing_files.append(file)
30
+
31
  if missing_files:
32
  # Try to find the repository root (where all utility files are)
33
  # Look for common parent directories that might contain the files
 
37
  os.path.join(os.path.expanduser('~'), 'ablang2'), # Home directory
38
  '/data/hn533621/ablang2', # Known repository location
39
  ]
40
+
41
  for path in possible_paths:
42
  if os.path.exists(path):
43
  # Check if all missing files exist in this path
 
46
  if not os.path.exists(os.path.join(path, file)):
47
  all_found = False
48
  break
49
+
50
  if all_found:
51
  # Copy all missing files
52
  for file in missing_files:
 
55
  shutil.copy2(src, dst)
56
  print(f"βœ… Copied {file} to cached directory")
57
  return True
58
+
59
  # If we get here, we couldn't find the files
60
  raise FileNotFoundError(
61
  f"Missing utility files: {missing_files}. "
62
  "These files are required for the adapter to work. "
63
  "Please ensure the repository is properly set up."
64
  )
65
+
66
  return True
67
 
68
  # Ensure utility files are available before importing
69
  ensure_utility_files_available()
70
 
71
+ # Create the ablang2.pretrained_utils package structure
72
+ if not os.path.exists('ablang2'):
73
+ os.makedirs('ablang2', exist_ok=True)
74
+ if not os.path.exists('ablang2/pretrained_utils'):
75
+ os.makedirs('ablang2/pretrained_utils', exist_ok=True)
76
+
77
+ # Create __init__.py files
78
+ with open('ablang2/__init__.py', 'w') as f:
79
+ f.write('# Mock ablang2 package\n')
80
+
81
+ with open('ablang2/pretrained_utils/__init__.py', 'w') as f:
82
+ f.write('# Mock pretrained_utils package\n')
83
+
84
+ # Copy utility files to the package structure
85
+ for file in UTILITY_FILES:
86
+ src = os.path.join(current_dir, file)
87
+ dst = os.path.join(current_dir, 'ablang2', 'pretrained_utils', file)
88
+ if os.path.exists(src) and not os.path.exists(dst):
89
+ shutil.copy2(src, dst)
90
+
91
+ # Also copy encodings.py as encodings.py (original name)
92
+ if os.path.exists('ablang_encodings.py') and not os.path.exists('ablang2/pretrained_utils/encodings.py'):
93
+ shutil.copy2('ablang_encodings.py', 'ablang2/pretrained_utils/encodings.py')
94
+
95
+ # Now import using the original structure
96
+ from ablang2.pretrained_utils.restoration import AbRestore
97
+ from ablang2.pretrained_utils.encodings import AbEncoding
98
+ from ablang2.pretrained_utils.alignment import AbAlignment
99
+ from ablang2.pretrained_utils.scores import AbScores
100
+ import torch
101
+ import numpy as np
102
+ from ablang2.pretrained_utils.extra_utils import res_to_seq, res_to_list
103
 
104
  class HuggingFaceTokenizerAdapter:
105
  def __init__(self, tokenizer, device):
 
331
  formatted_seqs.append('|'.join(s))
332
  else:
333
  formatted_seqs.append(s)
334
+
335
  plls = []
336
  for seq in formatted_seqs:
337
  tokens = self.tokenizer([seq], padding=True, return_tensors='pt')
338
  input_ids = extract_input_ids(tokens, self.used_device)
339
+
340
  with torch.no_grad():
341
  output = self.AbLang(input_ids)
342
  if hasattr(output, 'last_hidden_state'):
343
  logits = output.last_hidden_state
344
  else:
345
  logits = output
346
+
347
  # Get the sequence (remove batch dimension)
348
  logits = logits[0] # [seq_len, vocab_size]
349
  input_ids = input_ids[0] # [seq_len]
350
+
351
  # Exclude all special tokens (pad, mask, etc.)
352
  if isinstance(self.tokenizer.all_special_tokens[0], int):
353
  special_token_ids = set(self.tokenizer.all_special_tokens)
354
  else:
355
  special_token_ids = set(self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens)
356
  valid_mask = ~torch.isin(input_ids, torch.tensor(list(special_token_ids), device=input_ids.device))
357
+
358
  if valid_mask.sum() > 0:
359
  valid_logits = logits[valid_mask]
360
  valid_labels = input_ids[valid_mask]
361
+
362
  # Calculate cross-entropy loss
363
  nll = torch.nn.functional.cross_entropy(
364
  valid_logits,
 
368
  pll = -nll.item()
369
  else:
370
  pll = 0.0
371
+
372
  plls.append(pll)
373
+
374
  return np.array(plls, dtype=np.float32)
375
 
376
  def probability(self, seqs, align=False, stepwise_masking=False, **kwargs):
 
392
  logits = self._predict_logits(formatted_seqs)
393
  else:
394
  logits = self._predict_logits(formatted_seqs)
395
+
396
  # Apply softmax to get probabilities
397
  probs = logits.softmax(-1).cpu().numpy()
398
+
399
  if align:
400
  return probs
401
  else:
test_original_compatibility.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import sys
4
+ import os
5
+ from transformers import AutoModel, AutoTokenizer
6
+ from transformers.utils import cached_file
7
+
8
+ def test_original_compatibility():
9
+ """Test that our adapter produces the same results as the original"""
10
+ print("πŸ§ͺ Testing compatibility with original AbLang2_final_version...")
11
+
12
+ try:
13
+ # Load model and tokenizer
14
+ print("πŸ“₯ Loading model and tokenizer...")
15
+ model = AutoModel.from_pretrained("hemantn/ablang2", trust_remote_code=True)
16
+ tokenizer = AutoTokenizer.from_pretrained("hemantn/ablang2", trust_remote_code=True)
17
+
18
+ # Find the cached model directory and import adapter
19
+ adapter_path = cached_file("hemantn/ablang2", "adapter.py")
20
+ cached_model_dir = os.path.dirname(adapter_path)
21
+ sys.path.insert(0, cached_model_dir)
22
+
23
+ # Import and create the adapter
24
+ print("πŸ”§ Importing adapter...")
25
+ from adapter import AbLang2PairedHuggingFaceAdapter
26
+ ablang = AbLang2PairedHuggingFaceAdapter(model=model, tokenizer=tokenizer)
27
+ print("βœ… Adapter created successfully!")
28
+
29
+ # Test with the same sequences as in the notebook
30
+ print("🧬 Testing with notebook sequences...")
31
+ test_seqs = [
32
+ ['EVQ***SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS', 'DIQLTQSPLSLPVTLGQPASISCRSS*SLEASDTNIYLSWFQQRPGQSPRRLIYKI*NRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'],
33
+ ['EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS', 'DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK']
34
+ ]
35
+
36
+ # Test restore functionality
37
+ print("πŸ”§ Testing restore functionality...")
38
+ restored = ablang(test_seqs, mode='restore')
39
+ print("βœ… Restore functionality working!")
40
+ print(f"πŸ“Š Restored sequences: {len(restored)}")
41
+ for i, seq in enumerate(restored):
42
+ print(f" Sequence {i+1}: {seq[:50]}...")
43
+
44
+ # Test seqcoding functionality
45
+ print("πŸ”§ Testing seqcoding functionality...")
46
+ seqcodings = ablang(test_seqs, mode='seqcoding')
47
+ print("βœ… Seqcoding functionality working!")
48
+ print(f"πŸ“Š Seqcoding shape: {seqcodings.shape}")
49
+
50
+ # Test confidence functionality
51
+ print("πŸ”§ Testing confidence functionality...")
52
+ confidence_scores = ablang(test_seqs, mode='confidence')
53
+ print("βœ… Confidence functionality working!")
54
+ print(f"πŸ“Š Confidence scores: {confidence_scores}")
55
+
56
+ return True
57
+
58
+ except Exception as e:
59
+ print(f"❌ Error: {e}")
60
+ import traceback
61
+ traceback.print_exc()
62
+ return False
63
+
64
+ if __name__ == "__main__":
65
+ success = test_original_compatibility()
66
+ if success:
67
+ print("πŸŽ‰ All tests passed! The adapter is compatible with the original.")
68
+ else:
69
+ print("πŸ’₯ Tests failed. There are compatibility issues.")