hemantn committed
Commit 712d350 · 1 Parent(s): e1df3c0

Integrate utility files into main repository - make self-contained
README_Spaces.md ADDED
@@ -0,0 +1,55 @@
+ # 🧬 AbLang2 Sequence Restorer - Hugging Face Spaces
+
+ This is a Gradio web application that provides the AbLang2 sequence restoration utility through Hugging Face Spaces.
+
+ ## 🎯 What it does
+
+ The AbLang2 Sequence Restorer allows you to:
+ - **Restore masked residues** (`*`) in antibody sequences
+ - **Work with paired sequences** (heavy and light chains)
+ - **Handle single chains** (heavy or light chain only)
+ - **Use alignment** for variable missing lengths
+
+ ## 🚀 How to use
+
+ 1. **Enter sequences**: Provide heavy chain, light chain, or both sequences
+ 2. **Mask residues**: Use `*` to indicate residues you want to restore
+ 3. **Choose alignment**: Enable "Use Alignment" for variable missing lengths
+ 4. **Get results**: Click "Restore Sequences" to get the restored antibody sequences
+
+ ## 📝 Example Usage
+
+ ### Example 1: Both chains with masked residues
+ - **Heavy Chain**: `EVQ***SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS`
+ - **Light Chain**: `DIQLTQSPLSLPVTLGQPASISCRSS*SLEASDTNIYLSWFQQRPGQSPRRLIYKI*NRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK`
+
+ ### Example 2: Heavy chain only
+ - **Heavy Chain**: `EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMGWVRQAPGKGLEWVSAISGSGGSTYYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCARDY**GMDVWGQGTTVTVSS`
+ - **Light Chain**: (leave empty)
+
+ ## 🔧 Technical Details
+
+ - **Model**: AbLang2 from Hugging Face Hub (`hemantn/ablang2`)
+ - **Framework**: Gradio for the web interface
+ - **Backend**: PyTorch with the Transformers library
+ - **Processing**: Automatic GPU acceleration when available
+
+ ## 📚 Related Resources
+
+ - **Original AbLang2**: [https://github.com/TobiasHeOl/AbLang2](https://github.com/TobiasHeOl/AbLang2)
+ - **Model Repository**: [https://huggingface.co/hemantn/ablang2](https://huggingface.co/hemantn/ablang2)
+ - **Full Documentation**: See the main README.md for comprehensive usage examples
+
+ ## 🤝 Citation
+
+ If you use this tool in your research, please cite the original AbLang2 paper:
+
+ ```
+ @article{Olsen2024,
+   title={Addressing the antibody germline bias and its effect on language models for improved antibody design},
+   author={Olsen, Tobias H. and Moal, Iain H. and Deane, Charlotte M.},
+   journal={bioRxiv},
+   doi={10.1101/2024.02.02.578678},
+   year={2024}
+ }
+ ```
__pycache__/__init__.cpython-310.pyc ADDED
Binary file (418 Bytes)
__pycache__/ablang_encodings.cpython-310.pyc ADDED
Binary file (3.73 kB)
__pycache__/ablang_encodings.cpython-312.pyc ADDED
Binary file (5.64 kB)
__pycache__/adapter.cpython-310.pyc ADDED
Binary file (10.3 kB)
__pycache__/adapter.cpython-312.pyc ADDED
Binary file (17 kB)
__pycache__/alignment.cpython-310.pyc ADDED
Binary file (2.98 kB)
__pycache__/alignment.cpython-312.pyc ADDED
Binary file (3.77 kB)
__pycache__/configuration_ablang2paired.cpython-310.pyc ADDED
Binary file (1.05 kB)
__pycache__/extra_utils.cpython-310.pyc ADDED
Binary file (5.9 kB)
__pycache__/extra_utils.cpython-312.pyc ADDED
Binary file (8.55 kB)
__pycache__/modeling_ablang2paired.cpython-310.pyc ADDED
Binary file (3.89 kB)
__pycache__/restoration.cpython-310.pyc ADDED
Binary file (4.19 kB)
__pycache__/restoration.cpython-312.pyc ADDED
Binary file (6.46 kB)
__pycache__/scores.cpython-310.pyc ADDED
Binary file (3.02 kB)
__pycache__/scores.cpython-312.pyc ADDED
Binary file (5.44 kB)
ablang_encodings.py ADDED
@@ -0,0 +1,97 @@
+ import numpy as np
+ import torch
+
+ from extra_utils import res_to_list, res_to_seq
+
+
+ class AbEncoding:
+
+     def __init__(self, device = 'cpu', ncpu = 1):
+
+         self.device = device
+         self.ncpu = ncpu
+
+     def _initiate_abencoding(self, model, tokenizer):
+         self.AbLang = model
+         self.tokenizer = tokenizer
+
+     def _encode_sequences(self, seqs):
+         tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
+         with torch.no_grad():
+             return self.AbLang.AbRep(tokens).last_hidden_states
+
+     def _predict_logits(self, seqs):
+         tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
+         with torch.no_grad():
+             return self.AbLang(tokens)
+
+     def _predict_logits_with_step_masking(self, seqs):
+
+         tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
+
+         logits = []
+         for single_seq_tokens in tokens:
+
+             tkn_len = len(single_seq_tokens)
+             masked_tokens = single_seq_tokens.repeat(tkn_len, 1)
+             for num in range(tkn_len):
+                 masked_tokens[num, num] = self.tokenizer.mask_token
+
+             with torch.no_grad():
+                 logits_tmp = self.AbLang(masked_tokens)
+
+             logits_tmp = torch.stack([logits_tmp[num, num] for num in range(tkn_len)])
+
+             logits.append(logits_tmp)
+
+         return torch.stack(logits, dim=0)
+
+     def seqcoding(self, seqs, **kwargs):
+         """
+         Sequence specific representations
+         """
+
+         encodings = self._encode_sequences(seqs).cpu().numpy()
+
+         lens = np.vectorize(len)(seqs)
+         lens = np.tile(lens.reshape(-1,1,1), (encodings.shape[2], 1))
+
+         return np.apply_along_axis(res_to_seq, 2, np.c_[np.swapaxes(encodings,1,2), lens])
+
+     def rescoding(self, seqs, align=False, **kwargs):
+         """
+         Residue specific representations.
+         """
+         encodings = self._encode_sequences(seqs).cpu().numpy()
+
+         if align: return encodings
+
+         else: return [res_to_list(state, seq) for state, seq in zip(encodings, seqs)]
+
+     def likelihood(self, seqs, align=False, stepwise_masking=False, **kwargs):
+         """
+         Likelihood of mutations
+         """
+         if stepwise_masking:
+             logits = self._predict_logits_with_step_masking(seqs).cpu().numpy()
+         else:
+             logits = self._predict_logits(seqs).cpu().numpy()
+
+         if align: return logits
+
+         else: return [res_to_list(state, seq) for state, seq in zip(logits, seqs)]
+
+     def probability(self, seqs, align=False, stepwise_masking=False, **kwargs):
+         """
+         Probability of mutations
+         """
+         if stepwise_masking:
+             logits = self._predict_logits_with_step_masking(seqs)
+         else:
+             logits = self._predict_logits(seqs)
+         probs = logits.softmax(-1).cpu().numpy()
+
+         if align: return probs
+
+         else: return [res_to_list(state, seq) for state, seq in zip(probs, seqs)]
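The `seqcoding` method in `ablang_encodings.py` mean-pools residue embeddings into one vector per sequence via `np.apply_along_axis` and a stored-length trick. The same pooling can be sketched more directly; the helper below is an illustrative stand-in (with toy data), not part of the commit:

```python
import numpy as np

def seq_mean(encodings, seqs):
    # Mean-pool residue embeddings per sequence, trimming any padding
    # positions beyond each sequence's true length.
    return np.stack([enc[:len(s)].mean(axis=0) for enc, s in zip(encodings, seqs)])

# Two toy "sequences" of lengths 2 and 4, embedding dimension 3
enc = np.arange(24, dtype=float).reshape(2, 4, 3)
pooled = seq_mean(enc, ["AB", "ABCD"])
print(pooled.shape)  # (2, 3)
```

The first sequence only contributes its first two rows to the mean, mirroring how padding tokens are left out.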
adapter.py CHANGED
@@ -1,10 +1,10 @@
- from ablang2.pretrained_utils.restoration import AbRestore
- from ablang2.pretrained_utils.encodings import AbEncoding
- from ablang2.pretrained_utils.alignment import AbAlignment
- from ablang2.pretrained_utils.scores import AbScores
+ from restoration import AbRestore
+ from ablang_encodings import AbEncoding
+ from alignment import AbAlignment
+ from scores import AbScores
  import torch
  import numpy as np
- from ablang2.pretrained_utils.extra_utils import res_to_seq, res_to_list
+ from extra_utils import res_to_seq, res_to_list
 
  class HuggingFaceTokenizerAdapter:
      def __init__(self, tokenizer, device):
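The change to `adapter.py` swaps package-qualified imports (`ablang2.pretrained_utils.*`) for flat, top-level module names, which resolve once the directory holding these files is on `sys.path` (this is the pattern `app.py` uses with the cached model directory). A minimal self-contained illustration of that mechanism, using a throwaway module name (`flat_demo` is invented for this example):

```python
import os
import sys
import tempfile

# Write a stand-in "flat" module to a temporary directory, then make that
# directory importable, mirroring how app.py prepends the cached model
# directory to sys.path before `from adapter import ...`.
mod_dir = tempfile.mkdtemp()
with open(os.path.join(mod_dir, "flat_demo.py"), "w") as f:
    f.write("ANSWER = 42\n")

sys.path.insert(0, mod_dir)
import flat_demo

print(flat_demo.ANSWER)  # 42
```

Inserting at position 0 ensures the cached directory wins over any installed package with a clashing module name.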
alignment.py ADDED
@@ -0,0 +1,87 @@
+ from dataclasses import dataclass
+ import numpy as np
+ import torch
+
+ from extra_utils import paired_msa_numbering, unpaired_msa_numbering, create_alignment
+
+
+ class AbAlignment:
+
+     def __init__(self, device = 'cpu', ncpu = 1):
+
+         self.device = device
+         self.ncpu = ncpu
+
+     def number_sequences(self, seqs, chain = 'H', fragmented = False):
+         if chain == 'HL':
+             numbered_seqs, seqs, number_alignment = paired_msa_numbering(seqs, fragmented = fragmented, n_jobs = self.ncpu)
+         else:
+             assert chain == 'HL', 'Currently "Align==True" only works for paired sequences. \nPlease use paired sequences or Align=False.'
+             numbered_seqs, seqs, number_alignment = unpaired_msa_numbering(
+                 seqs, chain = chain, fragmented = fragmented, n_jobs = self.ncpu
+             )
+
+         return numbered_seqs, seqs, number_alignment
+
+     def align_encodings(self, encodings, numbered_seqs, seqs, number_alignment):
+
+         aligned_encodings = np.concatenate(
+             [[
+                 create_alignment(
+                     res_embed, numbered_seq, seq, number_alignment
+                 ) for res_embed, numbered_seq, seq in zip(encodings, numbered_seqs, seqs)
+             ]], axis=0
+         )
+         return aligned_encodings
+
+
+     def reformat_subsets(
+         self,
+         subset_list,
+         mode = 'seqcoding',
+         align = False,
+         numbered_seqs = None,
+         seqs = None,
+         number_alignment = None,
+     ):
+
+         if mode in [
+             'seqcoding',
+             'restore',
+             'pseudo_log_likelihood',
+             'confidence'
+         ]:
+             return np.concatenate(subset_list)
+         elif align:
+             subset_list = [
+                 self.align_encodings(
+                     subset,
+                     numbered_seqs[num*len(subset):(num+1)*len(subset)],
+                     seqs[num*len(subset):(num+1)*len(subset)],
+                     number_alignment
+                 ) for num, subset in enumerate(subset_list)
+             ]
+
+             subset = np.concatenate(subset_list)
+
+             return aligned_results(
+                 aligned_seqs = [''.join(alist) for alist in subset[:,:,-1]],
+                 aligned_embeds = subset[:,:,:-1].astype(float),
+                 number_alignment=number_alignment.apply(lambda x: '{}{}'.format(*x[0]), axis=1).values
+             )
+
+         elif not align:
+             return sum(subset_list, [])
+         else:
+             return np.concatenate(subset_list)  # this needs to be changed
+
+
+ @dataclass
+ class aligned_results():
+     """
+     Dataclass used to store output.
+     """
+
+     aligned_seqs: None
+     aligned_embeds: None
+     number_alignment: None
app.py ADDED
@@ -0,0 +1,336 @@
+ import gradio as gr
+ import sys
+ import os
+ from transformers import AutoModel, AutoTokenizer
+ from transformers.utils import cached_file
+
+ # Load model and tokenizer from Hugging Face Hub
+ model = AutoModel.from_pretrained("hemantn/ablang2", trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained("hemantn/ablang2", trust_remote_code=True)
+
+ # Find the cached model directory and import adapter
+ adapter_path = cached_file("hemantn/ablang2", "adapter.py")
+ cached_model_dir = os.path.dirname(adapter_path)
+ sys.path.insert(0, cached_model_dir)
+
+ # Import and create the adapter
+ from adapter import AbLang2PairedHuggingFaceAdapter
+ ablang = AbLang2PairedHuggingFaceAdapter(model=model, tokenizer=tokenizer)
+
+ def restore_sequences(heavy_chain, light_chain, use_align=False):
+     """
+     Restore masked residues in antibody sequences.
+
+     Args:
+         heavy_chain (str): Heavy chain sequence with masked residues (*)
+         light_chain (str): Light chain sequence with masked residues (*)
+         use_align (bool): Whether to use alignment for variable missing lengths
+
+     Returns:
+         tuple: (heavy_html, light_html)
+     """
+     try:
+         # Check if alignment is requested but not available
+         if use_align:
+             try:
+                 import anarci
+             except ImportError:
+                 return "Alignment feature requires the 'anarci' package, which is not available. Please disable the alignment option.", ""
+         # Prepare input sequences
+         if heavy_chain.strip() and light_chain.strip():
+             # Both chains provided
+             sequences = [[heavy_chain.strip(), light_chain.strip()]]
+         elif heavy_chain.strip():
+             # Only heavy chain provided
+             sequences = [[heavy_chain.strip(), ""]]
+         elif light_chain.strip():
+             # Only light chain provided
+             sequences = [["", light_chain.strip()]]
+         else:
+             return "Please provide at least one antibody chain sequence.", ""
+
+         # Perform restoration
+         restored = ablang(sequences, mode='restore', align=use_align)
+
+         # Format output
+         if hasattr(restored, '__len__') and len(restored) > 0:
+             result = restored[0]  # Get the first (and only) result
+
+             # Parse the result to separate heavy and light chains
+             if '>|<' in result:
+                 # Both chains present
+                 heavy_part = result.split('>|<')[0].replace('<', '').replace('>', '')
+                 light_part = result.split('>|<')[1].replace('<', '').replace('>', '')
+             elif result.startswith('<') and result.endswith('>'):
+                 # Only one chain present
+                 if heavy_chain.strip():
+                     heavy_part = result.replace('<', '').replace('>', '')
+                     light_part = ""
+                 else:
+                     heavy_part = ""
+                     light_part = result.replace('<', '').replace('>', '')
+             else:
+                 return "Error: Unexpected result format.", ""
+
+             # Create highlighted versions
+             highlighted_heavy = highlight_restored_residues(heavy_chain.strip(), heavy_part)
+             highlighted_light = highlight_restored_residues(light_chain.strip(), light_part)
+
+             # Create HTML outputs with proper styling - no scroll, wrap text
+             heavy_html = f'<div class="restored-sequence-box" style="padding: 10px; background-color: #f8f9fa; border: 1px solid #dee2e6; border-radius: 4px;">{highlighted_heavy}</div>'
+             light_html = f'<div class="restored-sequence-box" style="padding: 10px; background-color: #f8f9fa; border: 1px solid #dee2e6; border-radius: 4px;">{highlighted_light}</div>'
+
+             return heavy_html, light_html
+         else:
+             return "Error: No restoration result obtained.", ""
+
+     except Exception as e:
+         return f"Error during restoration: {str(e)}", ""
+
+ def highlight_restored_residues(original_seq, restored_seq):
+     """
+     Highlight restored residues in green.
+     """
+     if not original_seq or not restored_seq:
+         return restored_seq
+
+     highlighted = ""
+     for orig_char, rest_char in zip(original_seq, restored_seq):
+         if orig_char == '*' and rest_char != '*':
+             # This residue was restored
+             highlighted += f'<span class="restored-highlight">{rest_char}</span>'
+         else:
+             highlighted += rest_char
+
+     # Add any remaining characters from the restored sequence
+     if len(restored_seq) > len(original_seq):
+         highlighted += restored_seq[len(original_seq):]
+
+     return highlighted
+
+ # Create Gradio interface
+ with gr.Blocks(title="AbLang2 Sequence Restorer", theme=gr.themes.Soft(), css="""
+ * {
+     font-family: 'Courier New', monospace !important;
+ }
+ .sequence-input, .sequence-output {
+     font-family: 'Courier New', monospace !important;
+     font-size: 14px !important;
+     letter-spacing: 0.5px !important;
+ }
+ .restored-highlight {
+     background-color: #90EE90 !important;
+     color: #000 !important;
+     font-weight: bold !important;
+ }
+ .examples {
+     font-family: 'Courier New', monospace !important;
+     font-size: 14px !important;
+     letter-spacing: 0.5px !important;
+ }
+ .restored-sequence-box {
+     font-family: 'Courier New', monospace !important;
+     font-size: 14px !important;
+     letter-spacing: 0.5px !important;
+     white-space: pre-wrap !important;
+     word-wrap: break-word !important;
+     overflow-wrap: break-word !important;
+ }
+ .restored-heading {
+     color: #2E8B57 !important;
+     font-weight: bold !important;
+     font-size: 18px !important;
+ }
+ .example-text {
+     font-family: 'Courier New', monospace !important;
+     font-size: 12px !important;
+     white-space: pre-wrap !important;
+     word-wrap: break-word !important;
+ }
+ .examples-table {
+     font-family: 'Courier New', monospace !important;
+     font-size: 12px !important;
+     white-space: pre-wrap !important;
+     word-wrap: break-word !important;
+     max-width: none !important;
+     overflow: visible !important;
+ }
+ .examples-table td {
+     font-family: 'Courier New', monospace !important;
+     font-size: 12px !important;
+     white-space: pre-wrap !important;
+     word-wrap: break-word !important;
+     max-width: none !important;
+     overflow: visible !important;
+     text-overflow: unset !important;
+ }
+ .sequence-output label {
+     font-weight: bold !important;
+     color: #495057 !important;
+     font-size: 14px !important;
+     margin-bottom: 5px !important;
+ }
+ /* Force full display of examples */
+ .examples-container {
+     font-family: 'Courier New', monospace !important;
+     font-size: 12px !important;
+ }
+ .examples-container table {
+     width: 100% !important;
+     table-layout: auto !important;
+ }
+ .examples-container td {
+     white-space: pre-wrap !important;
+     word-wrap: break-word !important;
+     overflow-wrap: break-word !important;
+     max-width: none !important;
+     text-overflow: unset !important;
+     padding: 8px !important;
+     vertical-align: top !important;
+ }
+ .examples-container th {
+     white-space: nowrap !important;
+     padding: 8px !important;
+ }
+ /* Override any Gradio default truncation */
+ .examples table td {
+     white-space: pre-wrap !important;
+     word-wrap: break-word !important;
+     overflow-wrap: break-word !important;
+     max-width: none !important;
+     text-overflow: unset !important;
+     overflow: visible !important;
+     font-family: 'Courier New', monospace !important;
+     font-size: 12px !important;
+ }
+ .examples table {
+     table-layout: auto !important;
+     width: 100% !important;
+ }
+ /* Target the specific examples component */
+ div[data-testid="examples"] table td {
+     white-space: pre-wrap !important;
+     word-wrap: break-word !important;
+     overflow-wrap: break-word !important;
+     max-width: none !important;
+     text-overflow: unset !important;
+     overflow: visible !important;
+     font-family: 'Courier New', monospace !important;
+     font-size: 12px !important;
+ }
+ /* Force examples to show full content */
+ .examples table, .examples table td, .examples table th {
+     white-space: pre-wrap !important;
+     word-wrap: break-word !important;
+     overflow-wrap: break-word !important;
+     max-width: none !important;
+     text-overflow: unset !important;
+     overflow: visible !important;
+     font-family: 'Courier New', monospace !important;
+     font-size: 12px !important;
+     table-layout: auto !important;
+     width: auto !important;
+     min-width: 100% !important;
+ }
+ /* Override any inline styles */
+ .examples * {
+     white-space: pre-wrap !important;
+     word-wrap: break-word !important;
+     overflow-wrap: break-word !important;
+     max-width: none !important;
+     text-overflow: unset !important;
+     overflow: visible !important;
+ }
+ /* Style output labels to match input labels exactly */
+ .output-label {
+     font-weight: 600 !important;
+     color: var(--label-text-color) !important;
+     font-size: 14px !important;
+     margin-bottom: 8px !important;
+     margin-top: 16px !important;
+     line-height: 1.4 !important;
+     display: block !important;
+ }
+ """) as demo:
+     gr.Markdown("""
+     # 🧬 AbLang2 Sequence Restorer
+
+     This app uses the AbLang2 model to restore masked residues (*) in antibody sequences.
+     You can provide either one or both heavy and light chain sequences.
+
+     **Instructions:**
+     - Use `*` to mask residues you want to restore
+     - Provide heavy chain, light chain, or both
+     - Enable "Use Alignment" for variable missing lengths
+     """)
+
+     with gr.Row():
+         with gr.Column():
+             heavy_input = gr.Textbox(
+                 label="Heavy Chain Sequence",
+                 placeholder="Enter heavy chain sequence with masked residues (*)...",
+                 lines=3,
+                 max_lines=5,
+                 elem_classes=["sequence-input"]
+             )
+
+             light_input = gr.Textbox(
+                 label="Light Chain Sequence",
+                 placeholder="Enter light chain sequence with masked residues (*)...",
+                 lines=3,
+                 max_lines=5,
+                 elem_classes=["sequence-input"]
+             )
+
+             align_checkbox = gr.Checkbox(
+                 label="Use Alignment (for variable missing lengths) - Requires anarci package",
+                 value=False
+             )
+
+             restore_btn = gr.Button("🔄 Restore Sequences", variant="primary")
+
+         with gr.Column():
+             gr.Markdown("### 🧬 Restored Sequences", elem_classes=["restored-heading"])
+             gr.Markdown("*Green highlighting shows restored residues*")
+
+             gr.Markdown("**Heavy Chain Sequence**", elem_classes=["output-label"])
+             heavy_output = gr.HTML(label="")
+
+             gr.Markdown("**Light Chain Sequence**", elem_classes=["output-label"])
+             light_output = gr.HTML(label="")
+
+     # Example sequences
+     gr.Examples(
+         examples=[
+             [
+                 "EVQ***SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS",
+                 "DIQLTQSPLSLPVTLGQPASISCRSS*SLEASDTNIYLSWFQQRPGQSPRRLIYKI*NRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK"
+             ],
+             [
+                 "EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMGWVRQAPGKGLEWVSAISGSGGSTYYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCARDY**GMDVWGQGTTVTVSS",
+                 ""
+             ],
+             [
+                 "",
+                 "DIQLTQSPSSLSASVGDRVTITCRASQSISSYLNWYQQKPGKAPKLLIY*ASSLQSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYSTP*TFGQGTKVEIK"
+             ]
+         ],
+         inputs=[heavy_input, light_input],
+         label="Example Sequences"
+     )
+
+     # Connect the button to the function
+     restore_btn.click(
+         fn=restore_sequences,
+         inputs=[heavy_input, light_input, align_checkbox],
+         outputs=[heavy_output, light_output]
+     )
+
+     gr.Markdown("""
+     ---
+     **Note:** This app uses the AbLang2 model from Hugging Face Hub.
+     The restoration process may take a few seconds depending on sequence length and complexity.
+     """)
+
+ if __name__ == "__main__":
+     demo.launch()
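The `highlight_restored_residues` helper in `app.py` is pure string manipulation, so it can be exercised outside Gradio. The snippet below restates it standalone with a short masked heavy-chain fragment:

```python
def highlight_restored_residues(original_seq, restored_seq):
    # Wrap each residue that replaced a '*' mask in a highlight span.
    if not original_seq or not restored_seq:
        return restored_seq

    highlighted = ""
    for orig_char, rest_char in zip(original_seq, restored_seq):
        if orig_char == '*' and rest_char != '*':
            highlighted += f'<span class="restored-highlight">{rest_char}</span>'
        else:
            highlighted += rest_char

    # Append any tail the restored sequence gained beyond the original
    if len(restored_seq) > len(original_seq):
        highlighted += restored_seq[len(original_seq):]

    return highlighted

out = highlight_restored_residues("EV*LV", "EVQLV")
print(out)  # EV<span class="restored-highlight">Q</span>LV
```

Only positions that were masked in the input get the `.restored-highlight` class, which the app's CSS renders in green.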
extra_utils.py ADDED
@@ -0,0 +1,165 @@
+ import string, re
+ import numpy as np
+
+
+ def res_to_list(logits, seq):
+     return logits[:len(seq)]
+
+ def res_to_seq(a, mode='mean'):
+     """
+     Function for how we go from n_values for each amino acid to n_values for each sequence.
+
+     We leave out padding tokens.
+     """
+
+     if mode=='sum':
+         return a[0:(int(a[-1]))].sum()
+
+     elif mode=='mean':
+         return a[0:(int(a[-1]))].mean()
+
+     elif mode=='restore':
+         return a[0][0:(int(a[-1]))]
+
+ def get_number_alignment(numbered_seqs):
+     """
+     Creates a number alignment from the anarci results.
+     """
+     import pandas as pd
+
+     alist = [pd.DataFrame(aligned_seq, columns = [0,1,'resi']) for aligned_seq in numbered_seqs]
+     unsorted_alignment = pd.concat(alist).drop_duplicates(subset=0)
+     max_alignment = get_max_alignment()
+
+     return max_alignment.merge(unsorted_alignment.query("resi!='-'"), left_on=0, right_on=0)[[0,1]]
+
+ def get_max_alignment():
+     """
+     Create maximum possible alignment for sorting
+     """
+     import pandas as pd
+
+     sortlist = [[("<", "")]]
+     for num in range(1, 128+1):
+         if num in [33,61,112]:
+             for char in string.ascii_uppercase[::-1]:
+                 sortlist.append([(num, char)])
+
+             sortlist.append([(num,' ')])
+         else:
+             sortlist.append([(num,' ')])
+             for char in string.ascii_uppercase:
+                 sortlist.append([(num, char)])
+
+     return pd.DataFrame(sortlist + [[(">", "")]])
+
+
+ def paired_msa_numbering(ab_seqs, fragmented = False, n_jobs = 10):
+
+     import pandas as pd
+
+     tmp_seqs = [pairs.replace(">", "").replace("<", "").split("|") for pairs in ab_seqs]
+
+     numbered_seqs_heavy, seqs_heavy, number_alignment_heavy = unpaired_msa_numbering(
+         [i[0] for i in tmp_seqs], 'H', fragmented = fragmented, n_jobs = n_jobs
+     )
+     numbered_seqs_light, seqs_light, number_alignment_light = unpaired_msa_numbering(
+         [i[1] for i in tmp_seqs], 'L', fragmented = fragmented, n_jobs = n_jobs
+     )
+
+     number_alignment = pd.concat([
+         number_alignment_heavy,
+         pd.DataFrame([[("|",""), "|"]]),
+         number_alignment_light]
+     ).reset_index(drop=True)
+
+     seqs = [f"{heavy}|{light}" for heavy, light in zip(seqs_heavy, seqs_light)]
+     numbered_seqs = [
+         heavy + [(("|",""), "|", "|")] + light for heavy, light in zip(numbered_seqs_heavy, numbered_seqs_light)
+     ]
+
+     return numbered_seqs, seqs, number_alignment
+
+
+ def unpaired_msa_numbering(seqs, chain = 'H', fragmented = False, n_jobs = 10):
+
+     numbered_seqs = number_with_anarci(seqs, chain = chain, fragmented = fragmented, n_jobs = n_jobs)
+     number_alignment = get_number_alignment(numbered_seqs)
+     number_alignment[1] = chain
+
+     seqs = [''.join([i[2] for i in numbered_seq]).replace('-','') for numbered_seq in numbered_seqs]
+     return numbered_seqs, seqs, number_alignment
+
+
+ def number_with_anarci(seqs, chain = 'H', fragmented = False, n_jobs = 1):
+
+     import anarci
+     import pandas as pd
+
+     anarci_out = anarci.run_anarci(
+         pd.DataFrame(seqs).reset_index().values.tolist(),
+         ncpu=n_jobs,
+         scheme='imgt',
+         allowed_species=['human', 'mouse'],
+     )
+
+     numbered_seqs = []
+     for onarci in anarci_out[1]:
+         numbered_seq = []
+         for i in onarci[0][0]:
+             if i[1] != '-':
+                 numbered_seq.append((i[0], chain, i[1]))
+
+         if fragmented:
+             numbered_seqs.append(numbered_seq)
+         else:
+             numbered_seqs.append([(("<",""), chain, "<")] + numbered_seq + [((">",""), chain, ">")])
+
+     return numbered_seqs
+
+
+ def create_alignment(res_embeds, numbered_seqs, seq, number_alignment):
+
+     import pandas as pd
+
+     datadf = pd.DataFrame(numbered_seqs)
+     sequence_alignment = number_alignment.merge(datadf, how='left', on=[0, 1]).fillna('-')[2]
+
+     idxs = np.where(sequence_alignment.values == '-')[0]
+     idxs = [idx-num for num, idx in enumerate(idxs)]
+
+     aligned_embeds = pd.DataFrame(np.insert(res_embeds[:len(seq)], idxs, 0, axis=0))
+
+     return pd.concat([aligned_embeds, sequence_alignment], axis=1).values
+
+
+ def get_spread_sequences(seq, spread, start_position):
+     """
+     Test sequences which are 8 positions shorter (position 10 + max CDR1 gap of 7) up to 2 positions longer (possible insertions).
+     """
+     spread_sequences = []
+
+     for diff in range(start_position-8, start_position+2+1):
+         spread_sequences.append('*'*diff+seq)
+
+     return np.array(spread_sequences)
+
+ def get_sequences_from_anarci(out_anarci, max_position, spread):
+     """
+     Ensures correct masking on each side of sequence
+     """
+
+     if out_anarci == 'ANARCI_error':
+         return np.array(['ANARCI-ERR']*spread)
+
+     end_position = int(re.search(r'\d+', out_anarci[::-1]).group()[::-1])
+     # Fixes ANARCI error of poor numbering of the CDR1 region
+     start_position = int(re.search(r'\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+',
+                                   out_anarci).group().split(',')[0]) - 1
+
+     sequence = "".join(re.findall(r"(?i)[A-Z*]", "".join(re.findall(r'\),\s\'[A-Z*]', out_anarci))))
+
+     sequence_j = ''.join(sequence).replace('-','').replace('X','*') + '*'*(max_position-int(end_position))
+
+     return get_spread_sequences(sequence_j, spread, start_position)
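`get_spread_sequences` in `extra_utils.py` generates candidate sequences with different numbers of leading `*` masks, from 8 positions shorter to 2 positions longer than the detected start position; with the default `spread=11` of `AbRestore`, that range always yields 11 variants. A quick run with a toy fragment (the inputs here are illustrative):

```python
import numpy as np

def get_spread_sequences(seq, spread, start_position):
    # Prefix the sequence with every plausible number of leading '*' masks:
    # range(start-8, start+3) always contains 11 values, matching the
    # default spread of 11 used by AbRestore. (`spread` itself is unused
    # here, as in the committed function.)
    spread_sequences = []
    for diff in range(start_position - 8, start_position + 2 + 1):
        spread_sequences.append('*' * diff + seq)
    return np.array(spread_sequences)

variants = get_spread_sequences("EVQLV", 11, 8)
print(len(variants))  # 11
```

The model then scores all variants, and `AbRestore.restore` keeps the best-scoring one.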
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio>=4.0.0
+ transformers>=4.30.0
+ torch>=2.0.0
+ numpy>=1.21.0
+ pandas>=1.3.0
+ git+https://github.com/oxpig/ANARCI.git
restoration.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ from extra_utils import res_to_seq, get_sequences_from_anarci
+
+
+ class AbRestore:
+     def __init__(self, spread = 11, device = 'cpu', ncpu = 1):
+         self.spread = spread
+         self.device = device
+         self.ncpu = ncpu
+
+     def _initiate_abrestore(self, model, tokenizer):
+         self.AbLang = model
+         self.tokenizer = tokenizer
+
+     def restore(self, seqs, align = False, **kwargs):
+         """
+         Restore sequences
+         """
+         n_seqs = len(seqs)
+
+         if align:
+             seqs = self._sequence_aligning(seqs)
+             nr_seqs = len(seqs)//self.spread
+
+             tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
+             predictions = self.AbLang(tokens)[:,:,1:21]
+
+             # Reshape
+             tokens = tokens.reshape(nr_seqs, self.spread, -1)
+             predictions = predictions.reshape(nr_seqs, self.spread, -1, 20)
+             seqs = seqs.reshape(nr_seqs, -1)
+
+             # Find index of best predictions
+             best_seq_idx = torch.argmax(torch.max(predictions, -1).values[:,:,1:2].mean(2), -1)
+
+             # Select best predictions
+             tokens = tokens.gather(1, best_seq_idx.view(-1, 1).unsqueeze(1).repeat(1, 1, tokens.shape[-1])).squeeze(1)
+             predictions = predictions[range(predictions.shape[0]), best_seq_idx]
+             seqs = np.take_along_axis(seqs, best_seq_idx.view(-1, 1).cpu().numpy(), axis=1)
+
+         else:
+             tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
+             predictions = self.AbLang(tokens)[:,:,1:21]
+
+         predicted_tokens = torch.max(predictions, -1).indices + 1
+         restored_tokens = torch.where(tokens==23, predicted_tokens, tokens)
+
+         restored_seqs = self.tokenizer(restored_tokens, mode="decode")
+
+         if n_seqs < len(restored_seqs):
+             restored_seqs = [f"{h}|{l}".replace('-','') for h,l in zip(restored_seqs[:n_seqs], restored_seqs[n_seqs:])]
+             seqs = [f"{h}|{l}" for h,l in zip(seqs[:n_seqs], seqs[n_seqs:])]
+
+         return np.array([res_to_seq(seq, 'restore') for seq in np.c_[restored_seqs, np.vectorize(len)(seqs)]])
+
+     def _create_spread_of_sequences(self, seqs, chain = 'H'):
+         import pandas as pd
+         import anarci
+
+         chain_idx = 0 if chain == 'H' else 1
+         numbered_seqs = anarci.run_anarci(
+             pd.DataFrame([seq[chain_idx].replace('*', 'X') for seq in seqs]).reset_index().values.tolist(),
+             ncpu=self.ncpu,
+             scheme='imgt',
+             allowed_species=['human', 'mouse'],
+         )
+
+         anarci_data = pd.DataFrame(
+             [str(anarci[0][0]) if anarci else 'ANARCI_error' for anarci in numbered_seqs[1]],
+             columns=['anarci']
+         ).astype('<U90')
+
+         max_position = 128 if chain == 'H' else 127
+
+         seqs = anarci_data.apply(
+             lambda x: get_sequences_from_anarci(
+                 x.anarci,
+                 max_position,
+                 self.spread
+             ), axis=1, result_type='expand'
+         ).to_numpy().reshape(-1)
+
+         return seqs
+
+     def _sequence_aligning(self, seqs):
+         tmp_seqs = [pairs.replace(">", "").replace("<", "").split("|") for pairs in seqs]
+
+         spread_heavy = [f"<{seq}>" for seq in self._create_spread_of_sequences(tmp_seqs, chain = 'H')]
+         spread_light = [f"<{seq}>" for seq in self._create_spread_of_sequences(tmp_seqs, chain = 'L')]
+
+         return np.concatenate([np.array(spread_heavy), np.array(spread_light)])
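The core fill-in step of `AbRestore.restore` is the `torch.where(tokens==23, predicted_tokens, tokens)` line: wherever the input carries the mask token, the model's argmax prediction is taken, and every other position keeps its original token. A minimal, dependency-free sketch of that substitution (the token IDs below are hypothetical; in the real code the mask id 23 and the +1 shift come from the tokenizer's vocabulary layout):

```python
MASK_ID = 23  # mask token id assumed by restore(); illustrative only

def fill_masks(tokens, predicted_tokens, mask_id=MASK_ID):
    """Pure-Python analogue of torch.where(tokens == mask_id, predicted, tokens)."""
    return [p if t == mask_id else t for t, p in zip(tokens, predicted_tokens)]

tokens = [5, 23, 8, 23, 12]       # 23 marks masked residues ('*' in the input sequence)
predicted = [5, 7, 9, 14, 12]     # argmax over the 20 amino-acid logits, shifted by +1
print(fill_masks(tokens, predicted))  # [5, 7, 8, 14, 12]
```

Note that predictions at unmasked positions (e.g. the `9` above) are simply discarded, so restoration never rewrites residues the user already provided.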
scores.py ADDED
@@ -0,0 +1,98 @@
+ import numpy as np
+ import torch
+
+ from extra_utils import res_to_list, res_to_seq
+
+
+ class AbScores:
+     def __init__(self, device = 'cpu', ncpu = 1):
+         self.device = device
+         self.ncpu = ncpu
+
+     def _initiate_abencoding(self, model, tokenizer):
+         self.AbLang = model
+         self.tokenizer = tokenizer
+
+     def _encode_sequences(self, seqs):
+         tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
+         with torch.no_grad():
+             return self.AbLang.AbRep(tokens).last_hidden_states.numpy()
+
+     def _predict_logits(self, seqs):
+         tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
+         with torch.no_grad():
+             return self.AbLang(tokens), tokens
+
+     def pseudo_log_likelihood(self, seqs, **kwargs):
+         """
+         Pseudo log likelihood of sequences.
+         """
+         plls = []
+         for seq in seqs:
+             labels = self.tokenizer(
+                 seq, pad=True, w_extra_tkns=False, device=self.used_device
+             )
+
+             idxs = (
+                 ~torch.isin(labels, torch.Tensor(self.tokenizer.all_special_tokens).to(self.used_device))
+             ).nonzero()
+
+             masked_tokens = labels.repeat(len(idxs), 1)
+             for num, idx in enumerate(idxs):
+                 masked_tokens[num, idx[1]] = self.tokenizer.mask_token
+
+             with torch.no_grad():
+                 logits = self.AbLang(masked_tokens)
+
+             logits[:, :, self.tokenizer.all_special_tokens] = -float("inf")
+             logits = torch.stack([logits[num, idx[1]] for num, idx in enumerate(idxs)])
+
+             labels = labels[:,idxs[:,1:]].squeeze(2)[0]
+
+             nll = torch.nn.functional.cross_entropy(
+                 logits,
+                 labels,
+                 reduction="mean",
+             )
+
+             pll = -nll
+             plls.append(pll)
+
+         plls = torch.stack(plls, dim=0).cpu().numpy()
+
+         return plls
+
+     def confidence(self, seqs, **kwargs):
+         """
+         Log likelihood of sequences without masking.
+         """
+         labels = self.tokenizer(
+             seqs, pad=True, w_extra_tkns=False, device=self.used_device
+         )
+         with torch.no_grad():
+             logits = self.AbLang(labels)
+         logits[:, :, self.tokenizer.all_special_tokens] = -float("inf")
+
+         plls = []
+         for label, logit in zip(labels, logits):
+             idxs = (
+                 ~torch.isin(label, torch.Tensor(self.tokenizer.all_special_tokens).to(self.used_device))
+             ).nonzero().squeeze(1)
+
+             nll = torch.nn.functional.cross_entropy(
+                 logit[idxs],
+                 label[idxs],
+                 reduction="mean",
+             )
+
+             pll = -nll
+             plls.append(pll)
+
+         return torch.stack(plls, dim=0).cpu().numpy()
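`pseudo_log_likelihood` masks each non-special position in turn, scores the true residue at that position, and returns the mean negative cross-entropy over positions (so `pll = -nll`). The aggregation can be sketched without a model: given a hypothetical list of the model's probabilities for the correct residue at each masked position, the score is the mean log-probability (0 for perfect predictions, increasingly negative otherwise):

```python
import math

def pseudo_ll(per_position_probs):
    """Mean log-probability of the true residue across masked positions.
    Equivalent to -cross_entropy(..., reduction='mean') on one sequence."""
    nll = -sum(math.log(p) for p in per_position_probs) / len(per_position_probs)
    return -nll  # pll = -nll, as in AbScores.pseudo_log_likelihood

probs = [0.9, 0.8, 0.95]  # hypothetical probabilities of the correct residue
print(round(pseudo_ll(probs), 4))
```

`confidence` computes the same quantity but from a single unmasked forward pass, trading the per-position masking loop (one forward pass per residue) for speed at the cost of letting the model see each residue while scoring it.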
test_ablang2_HF_implementation.ipynb CHANGED
@@ -10,7 +10,7 @@
10
  },
11
  {
12
  "cell_type": "code",
13
- "execution_count": 11,
14
  "id": "7ae54cd0-6253-46dd-a316-4f20b12041e0",
15
  "metadata": {},
16
  "outputs": [],
@@ -40,7 +40,7 @@
40
  },
41
  {
42
  "cell_type": "code",
43
- "execution_count": 6,
44
  "id": "99192978-a008-4a32-a80e-bba238e0ec7c",
45
  "metadata": {},
46
  "outputs": [],
@@ -82,10 +82,41 @@
82
  },
83
  {
84
  "cell_type": "code",
85
- "execution_count": null,
86
  "id": "6d66ad84",
87
  "metadata": {},
88
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  "source": [
90
  "# Load model and tokenizer from Hugging Face Hub\n",
91
  "model = AutoModel.from_pretrained(\"hemantn/ablang2\", trust_remote_code=True)\n",
 
10
  },
11
  {
12
  "cell_type": "code",
13
+ "execution_count": 1,
14
  "id": "7ae54cd0-6253-46dd-a316-4f20b12041e0",
15
  "metadata": {},
16
  "outputs": [],
 
40
  },
41
  {
42
  "cell_type": "code",
43
+ "execution_count": 2,
44
  "id": "99192978-a008-4a32-a80e-bba238e0ec7c",
45
  "metadata": {},
46
  "outputs": [],
 
82
  },
83
  {
84
  "cell_type": "code",
85
+ "execution_count": 3,
86
  "id": "6d66ad84",
87
  "metadata": {},
88
+ "outputs": [
89
+ {
90
+ "name": "stderr",
91
+ "output_type": "stream",
92
+ "text": [
93
+ "A new version of the following files was downloaded from https://huggingface.co/hemantn/ablang2:\n",
94
+ "- configuration_ablang2paired.py\n",
95
+ ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
96
+ "A new version of the following files was downloaded from https://huggingface.co/hemantn/ablang2:\n",
97
+ "- modeling_ablang2paired.py\n",
98
+ ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
99
+ "/home/hn533621/.conda/envs/lib_transformer/lib/python3.10/site-packages/huggingface_hub/file_download.py:943: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
100
+ " warnings.warn(\n"
101
+ ]
102
+ },
103
+ {
104
+ "name": "stdout",
105
+ "output_type": "stream",
106
+ "text": [
107
+ "βœ… Loaded custom weights from: /home/hn533621/.cache/huggingface/hub/models--hemantn--ablang2/snapshots/e1df3c0a25269eaeb91c4891125dd9a8580a01b7/model.pt\n"
108
+ ]
109
+ },
110
+ {
111
+ "name": "stderr",
112
+ "output_type": "stream",
113
+ "text": [
114
+ "A new version of the following files was downloaded from https://huggingface.co/hemantn/ablang2:\n",
115
+ "- tokenizer_ablang2paired.py\n",
116
+ ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
117
+ ]
118
+ }
119
+ ],
120
  "source": [
121
  "# Load model and tokenizer from Hugging Face Hub\n",
122
  "model = AutoModel.from_pretrained(\"hemantn/ablang2\", trust_remote_code=True)\n",
test_align.py ADDED
@@ -0,0 +1,34 @@
+ import sys
+ import os
+ from transformers import AutoModel, AutoTokenizer
+ from transformers.utils import cached_file
+
+ # Load model and tokenizer from Hugging Face Hub
+ print("Loading model and tokenizer...")
+ model = AutoModel.from_pretrained("hemantn/ablang2", trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained("hemantn/ablang2", trust_remote_code=True)
+
+ # Find the cached model directory and import adapter
+ adapter_path = cached_file("hemantn/ablang2", "adapter.py")
+ cached_model_dir = os.path.dirname(adapter_path)
+ sys.path.insert(0, cached_model_dir)
+
+ # Import and create the adapter
+ from adapter import AbLang2PairedHuggingFaceAdapter
+ ablang = AbLang2PairedHuggingFaceAdapter(model=model, tokenizer=tokenizer)
+
+ # Test sequences from the notebook
+ test_sequences = [
+     ['EVQ***SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS',
+      'DIQLTQSPLSLPVTLGQPASISCRSS*SLEASDTNIYLSWFQQRPGQSPRRLIYKI*NRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK']
+ ]
+
+ print("Testing restore without alignment:")
+ result_no_align = ablang(test_sequences, mode='restore', align=False)
+ print(f"Result (no align): {result_no_align[0]}")
+
+ print("\nTesting restore with alignment:")
+ result_with_align = ablang(test_sequences, mode='restore', align=True)
+ print(f"Result (with align): {result_with_align[0]}")
+
+ print("\nBoth options work correctly!")
test_app_output.py ADDED
@@ -0,0 +1,56 @@
+ import sys
+ import os
+ from transformers import AutoModel, AutoTokenizer
+ from transformers.utils import cached_file
+
+ # Load model and tokenizer from Hugging Face Hub
+ print("Loading model and tokenizer...")
+ model = AutoModel.from_pretrained("hemantn/ablang2", trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained("hemantn/ablang2", trust_remote_code=True)
+
+ # Find the cached model directory and import adapter
+ adapter_path = cached_file("hemantn/ablang2", "adapter.py")
+ cached_model_dir = os.path.dirname(adapter_path)
+ sys.path.insert(0, cached_model_dir)
+
+ # Import and create the adapter
+ from adapter import AbLang2PairedHuggingFaceAdapter
+ ablang = AbLang2PairedHuggingFaceAdapter(model=model, tokenizer=tokenizer)
+
+ def restore_sequences(heavy_chain, light_chain, use_align=False):
+     try:
+         # Prepare input sequences
+         if heavy_chain.strip() and light_chain.strip():
+             sequences = [[heavy_chain.strip(), light_chain.strip()]]
+         elif heavy_chain.strip():
+             sequences = [[heavy_chain.strip(), ""]]
+         elif light_chain.strip():
+             sequences = [["", light_chain.strip()]]
+         else:
+             return "Please provide at least one antibody chain sequence."
+
+         # Perform restoration
+         restored = ablang(sequences, mode='restore', align=use_align)
+
+         # Format output
+         if hasattr(restored, '__len__') and len(restored) > 0:
+             result = restored[0]  # Get the first (and only) result
+             return result
+         else:
+             return "Error: No restoration result obtained."
+
+     except Exception as e:
+         return f"Error during restoration: {str(e)}"
+
+ # Test the function
+ heavy_chain = "EVQ***SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS"
+ light_chain = "DIQLTQSPLSLPVTLGQPASISCRSS*SLEASDTNIYLSWFQQRPGQSPRRLIYKI*NRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK"
+
+ result = restore_sequences(heavy_chain, light_chain, False)
+ print("="*80)
+ print("APP OUTPUT TEST:")
+ print("="*80)
+ print(result)
+ print("="*80)
+ print(f"Result length: {len(result)}")
+ print(f"Result type: {type(result)}")
test_integrated_adapter.py ADDED
@@ -0,0 +1,158 @@
+ #!/usr/bin/env python3
+ """
+ Test script for the integrated AbLang2 adapter functionality.
+ This script tests that all the utility files are properly integrated and the adapter works correctly.
+ """
+
+ import sys
+ import os
+
+ # Global variable to store the adapter class
+ AbLang2PairedHuggingFaceAdapter = None
+
+ def test_imports():
+     """Test that all imports work correctly"""
+     global AbLang2PairedHuggingFaceAdapter
+
+     print("πŸ” Testing imports...")
+
+     try:
+         # Test utility imports
+         from restoration import AbRestore
+         print("βœ… AbRestore imported successfully")
+
+         from ablang_encodings import AbEncoding
+         print("βœ… AbEncoding imported successfully")
+
+         from alignment import AbAlignment
+         print("βœ… AbAlignment imported successfully")
+
+         from scores import AbScores
+         print("βœ… AbScores imported successfully")
+
+         from extra_utils import res_to_seq, res_to_list
+         print("βœ… extra_utils functions imported successfully")
+
+         # Test adapter import
+         from adapter import AbLang2PairedHuggingFaceAdapter
+         print("βœ… AbLang2PairedHuggingFaceAdapter imported successfully")
+
+         print("\nπŸŽ‰ All imports successful!")
+         return True
+
+     except ImportError as e:
+         print(f"❌ Import error: {e}")
+         return False
+     except Exception as e:
+         print(f"❌ Unexpected error: {e}")
+         return False
+
+ def test_model_loading():
+     """Test model loading from Hugging Face"""
+     global AbLang2PairedHuggingFaceAdapter
+
+     print("\nπŸ” Testing model loading...")
+
+     try:
+         from transformers import AutoModel, AutoTokenizer
+
+         print("Loading model from Hugging Face...")
+         model = AutoModel.from_pretrained("hemantn/ablang2", trust_remote_code=True)
+         print("βœ… Model loaded successfully")
+
+         print("Loading tokenizer from Hugging Face...")
+         tokenizer = AutoTokenizer.from_pretrained("hemantn/ablang2", trust_remote_code=True)
+         print("βœ… Tokenizer loaded successfully")
+
+         print("Creating adapter...")
+         ablang = AbLang2PairedHuggingFaceAdapter(model=model, tokenizer=tokenizer)
+         print("βœ… Adapter created successfully")
+
+         return True, ablang
+
+     except ImportError as e:
+         print(f"❌ Transformers not available: {e}")
+         print("   This is expected if transformers is not installed")
+         return False, None
+     except Exception as e:
+         print(f"❌ Model loading error: {e}")
+         return False, None
+
+ def test_restore_functionality(ablang):
+     """Test the restore functionality"""
+     print("\nπŸ” Testing restore functionality...")
+
+     try:
+         # Test sequences
+         test_sequences = [
+             ["EVQ***SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS",
+              "DIQLTQSPLSLPVTLGQPASISCRSS*SLEASDTNIYLSWFQQRPGQSPRRLIYKI*NRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK"]
+         ]
+
+         print("Testing restore without alignment...")
+         result = ablang(test_sequences, mode='restore', align=False)
+         print(f"βœ… Restore result: {result}")
+
+         print("Testing restore with alignment...")
+         result_align = ablang(test_sequences, mode='restore', align=True)
+         print(f"βœ… Restore with alignment result: {result_align}")
+
+         return True
+
+     except Exception as e:
+         print(f"❌ Restore functionality error: {e}")
+         return False
+
+ def test_encoding_functionality(ablang):
+     """Test the encoding functionality"""
+     print("\nπŸ” Testing encoding functionality...")
+
+     try:
+         test_sequences = [
+             ["EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMGWVRQAPGKGLEWVSAISGSGGSTYYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCARDYPGHGAAFMDVWGQGTTVTVSS",
+              "DIQLTQSPSSLSASVGDRVTITCRASQSISSYLNWYQQKPGKAPKLLIYASSLQSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYSTPTTFGQGTKVEIK"]
+         ]
+
+         print("Testing sequence coding...")
+         result = ablang(test_sequences, mode='seqcoding')
+         print(f"βœ… Sequence coding result shape: {result.shape if hasattr(result, 'shape') else len(result)}")
+
+         return True
+
+     except Exception as e:
+         print(f"❌ Encoding functionality error: {e}")
+         return False
+
+ def main():
+     """Main test function"""
+     print("🧬 AbLang2 Integrated Adapter Test")
+     print("=" * 50)
+
+     # Test imports
+     if not test_imports():
+         print("\n❌ Import tests failed. Exiting.")
+         return
+
+     # Test model loading
+     model_loaded, ablang = test_model_loading()
+
+     if model_loaded and ablang is not None:
+         # Test functionality
+         test_restore_functionality(ablang)
+         test_encoding_functionality(ablang)
+
+         print("\nπŸŽ‰ All tests completed successfully!")
+         print("\nπŸ“‹ Summary:")
+         print("βœ… All utility files integrated")
+         print("βœ… Adapter imports working")
+         print("βœ… Model loading successful")
+         print("βœ… Restore functionality working")
+         print("βœ… Encoding functionality working")
+
+     else:
+         print("\n⚠️ Model loading test skipped (transformers not available)")
+         print("βœ… Core integration tests passed")
+         print("βœ… Ready for deployment")
+
+ if __name__ == "__main__":
+     main()