hemantn committed on
Commit
0e2f128
·
1 Parent(s): 3b6231c

deployment file added

Files changed (5)
  1. LICENSE +21 -0
  2. README_Spaces.md +55 -0
  3. adapter.py +306 -0
  4. app.py +330 -0
  5. requirements.txt +6 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 hemantn
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README_Spaces.md ADDED
@@ -0,0 +1,55 @@
+ # 🧬 AbLang2 Sequence Restorer - Hugging Face Spaces
+
+ This is a Gradio web application that provides the AbLang2 sequence restoration utility through Hugging Face Spaces.
+
+ ## 🎯 What it does
+
+ The AbLang2 Sequence Restorer allows you to:
+ - **Restore masked residues** (`*`) in antibody sequences
+ - **Work with paired sequences** (heavy and light chains)
+ - **Handle single chains** (heavy or light chain only)
+ - **Use alignment** for variable missing lengths
+
+ ## 🚀 How to use
+
+ 1. **Enter sequences**: Provide the heavy chain, the light chain, or both
+ 2. **Mask residues**: Use `*` to mark the residues you want to restore
+ 3. **Choose alignment**: Enable "Use Alignment" for variable missing lengths
+ 4. **Get results**: Click "Restore Sequences" to get the restored antibody sequences
+
+ ## 📝 Example Usage
+
+ ### Example 1: Both chains with masked residues
+ - **Heavy Chain**: `EVQ***SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS`
+ - **Light Chain**: `DIQLTQSPLSLPVTLGQPASISCRSS*SLEASDTNIYLSWFQQRPGQSPRRLIYKI*NRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK`
+
+ ### Example 2: Heavy chain only
+ - **Heavy Chain**: `EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMGWVRQAPGKGLEWVSAISGSGGSTYYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCARDY**GMDVWGQGTTVTVSS`
+ - **Light Chain**: (leave empty)
+
+ ## 🔧 Technical Details
+
+ - **Model**: AbLang2 from the Hugging Face Hub (`hemantn/ablang2`)
+ - **Framework**: Gradio for the web interface
+ - **Backend**: PyTorch with the Transformers library
+ - **Processing**: Automatic GPU acceleration when available
+
+ ## 📚 Related Resources
+
+ - **Original AbLang2**: [https://github.com/TobiasHeOl/AbLang2](https://github.com/TobiasHeOl/AbLang2)
+ - **Model Repository**: [https://huggingface.co/hemantn/ablang2](https://huggingface.co/hemantn/ablang2)
+ - **Full Documentation**: See the main README.md for comprehensive usage examples
+
+ ## 🤝 Citation
+
+ If you use this tool in your research, please cite the original AbLang2 paper:
+
+ ```
+ @article{Olsen2024,
+   title={Addressing the antibody germline bias and its effect on language models for improved antibody design},
+   author={Olsen, Tobias H. and Moal, Iain H. and Deane, Charlotte M.},
+   journal={bioRxiv},
+   doi={10.1101/2024.02.02.578678},
+   year={2024}
+ }
+ ```
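As a quick sanity check on the masking convention described above, this short standalone snippet (illustrative only, not part of the Space) locates the `*` placeholders in Example 1's heavy chain:

```python
# Example 1 heavy chain from the README, with '*' marking residues to restore.
heavy = (
    "EVQ***SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKF"
    "QGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS"
)
masked_positions = [i for i, aa in enumerate(heavy) if aa == "*"]
print(masked_positions)  # five masked positions, beginning at index 3
```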
adapter.py ADDED
@@ -0,0 +1,306 @@
+ from ablang2.pretrained_utils.restoration import AbRestore
+ from ablang2.pretrained_utils.encodings import AbEncoding
+ from ablang2.pretrained_utils.alignment import AbAlignment
+ from ablang2.pretrained_utils.scores import AbScores
+ import torch
+ import numpy as np
+ from ablang2.pretrained_utils.extra_utils import res_to_seq, res_to_list
+
+
+ class HuggingFaceTokenizerAdapter:
+     def __init__(self, tokenizer, device):
+         self.tokenizer = tokenizer
+         self.device = device
+         self.pad_token_id = tokenizer.pad_token_id
+         self.mask_token_id = getattr(tokenizer, 'mask_token_id', None) or tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
+         self.vocab = tokenizer.get_vocab() if hasattr(tokenizer, 'get_vocab') else tokenizer.vocab
+         self.inv_vocab = {v: k for k, v in self.vocab.items()}
+         self.all_special_tokens = tokenizer.all_special_tokens
+
+     def __call__(self, seqs, pad=True, w_extra_tkns=False, device=None, mode=None):
+         if mode == 'decode':
+             # In decode mode, seqs is a tensor (or array) of token ids, so it
+             # must be decoded directly rather than passed to the tokenizer.
+             if isinstance(seqs, torch.Tensor):
+                 seqs = seqs.cpu().numpy()
+             decoded = []
+             for seq in seqs:
+                 chars = []
+                 for t in seq:
+                     token = self.inv_vocab.get(int(t), '')
+                     if token and token not in {'-', '*', '<', '>'}:
+                         chars.append(token)
+                 # Use res_to_seq for formatting, passing a (sequence, length)
+                 # pair as in the original code; the true length is not always
+                 # available, so len(chars) serves as a fallback.
+                 decoded.append(res_to_seq([''.join(chars), len(chars)], mode='restore'))
+             return decoded
+         tokens = self.tokenizer(seqs, padding=True, return_tensors='pt')
+         return tokens['input_ids'].to(self.device if device is None else device)
+
+
+ class HFAbRestore(AbRestore):
+     def __init__(self, hf_model, hf_tokenizer, spread=11, device='cpu', ncpu=1):
+         super().__init__(spread=spread, device=device, ncpu=ncpu)
+         self.used_device = device
+         self._hf_model = hf_model
+         self.tokenizer = HuggingFaceTokenizerAdapter(hf_tokenizer, device)
+
+     @property
+     def AbLang(self):
+         def model_call(x):
+             output = self._hf_model(x)
+             if hasattr(output, 'last_hidden_state'):
+                 return output.last_hidden_state
+             return output
+         return model_call
+
+
+ def add_angle_brackets(seq):
+     # Assumes input is 'VH|VL', 'VH|' or '|VL'.
+     if '|' in seq:
+         vh, vl = seq.split('|', 1)
+     else:
+         vh, vl = seq, ''
+     return f"<{vh}>|<{vl}>"
+
+
+ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScores):
+     """
+     Adapter for using the pretrained utilities with a HuggingFace-loaded
+     ablang2_paired model and tokenizer.
+     Automatically uses CUDA if available, otherwise CPU.
+     """
+     def __init__(self, model, tokenizer, device=None, ncpu=1):
+         super().__init__()
+         if device is None:
+             self.used_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         else:
+             self.used_device = torch.device(device)
+         self.AbLang = model  # HuggingFace model instance
+         self.tokenizer = tokenizer
+         self.AbLang.to(self.used_device)
+         self.AbLang.eval()
+         # Always get AbRep from the underlying model.
+         if hasattr(self.AbLang, 'model') and hasattr(self.AbLang.model, 'AbRep'):
+             self.AbRep = self.AbLang.model.AbRep
+         else:
+             raise AttributeError("Could not find AbRep in the HuggingFace model or its underlying model.")
+         self.ncpu = ncpu
+         self.spread = 11  # For compatibility with the original utilities.
+
+     def freeze(self):
+         self.AbLang.eval()
+
+     def unfreeze(self):
+         self.AbLang.train()
+
+     def _encode_sequences(self, seqs):
+         # Use HuggingFace-style padding and return PyTorch tensors.
+         tokens = self.tokenizer(seqs, padding=True, return_tensors='pt')
+         tokens = extract_input_ids(tokens, self.used_device)
+         return self.AbRep(tokens).last_hidden_states.detach()
+
+     def _predict_logits(self, seqs):
+         tokens = self.tokenizer(seqs, padding=True, return_tensors='pt')
+         tokens = extract_input_ids(tokens, self.used_device)
+         output = self.AbLang(tokens)
+         if hasattr(output, 'last_hidden_state'):
+             return output.last_hidden_state.detach()
+         return output.detach()
+
+     def _preprocess_labels(self, labels):
+         return extract_input_ids(labels, self.used_device)
+
+     def _format_paired(self, seqs):
+         # Join VH and VL with '|'; pass already-formatted strings through.
+         return ['|'.join(s) if isinstance(s, (list, tuple)) else s for s in seqs]
+
+     def _special_token_ids(self):
+         # all_special_tokens may hold strings or already-converted ids.
+         if isinstance(self.tokenizer.all_special_tokens[0], int):
+             return set(self.tokenizer.all_special_tokens)
+         return set(self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens)
+
+     def __call__(self, seqs, mode='seqcoding', align=False, stepwise_masking=False, fragmented=False, batch_size=50):
+         """
+         Dispatch to the different modes, mimicking the original pretrained class.
+         """
+         from ablang2.pretrained import format_seq_input
+
+         valid_modes = [
+             'rescoding', 'seqcoding', 'restore', 'likelihood', 'probability',
+             'pseudo_log_likelihood', 'confidence'
+         ]
+         if mode not in valid_modes:
+             raise ValueError(f"Given mode doesn't exist. Please select one of the following: {valid_modes}.")
+
+         seqs, chain = format_seq_input(seqs, fragmented=fragmented)
+
+         if align:
+             numbered_seqs, seqs, number_alignment = self.number_sequences(
+                 seqs, chain=chain, fragmented=fragmented
+             )
+         else:
+             numbered_seqs = None
+             number_alignment = None
+
+         subset_list = []
+         for subset in [seqs[x:x + batch_size] for x in range(0, len(seqs), batch_size)]:
+             subset_list.append(getattr(self, mode)(subset, align=align, stepwise_masking=stepwise_masking))
+
+         return self.reformat_subsets(
+             subset_list,
+             mode=mode,
+             align=align,
+             numbered_seqs=numbered_seqs,
+             seqs=seqs,
+             number_alignment=number_alignment,
+         )
+
+     def pseudo_log_likelihood(self, seqs, **kwargs):
+         """
+         Original (non-vectorized) pseudo log-likelihood computation matching
+         the notebook behavior: each residue position is masked in turn and the
+         log-probability of the true token is averaged over the sequence.
+         """
+         formatted_seqs = self._format_paired(seqs)
+
+         # Tokenize all sequences in a single batch.
+         labels = self.tokenizer(formatted_seqs, padding=True, return_tensors='pt')
+         labels = extract_input_ids(labels, self.used_device)
+
+         special_token_ids = self._special_token_ids()
+         pad_token_id = self.tokenizer.pad_token_id
+
+         mask_token_id = getattr(self.tokenizer, 'mask_token_id', None)
+         if mask_token_id is None:
+             mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
+
+         plls = []
+         with torch.no_grad():
+             for seq_label in labels:
+                 seq_pll = []
+                 for j, token_id in enumerate(seq_label):
+                     if token_id.item() in special_token_ids or token_id.item() == pad_token_id:
+                         continue
+                     masked = seq_label.clone()
+                     masked[j] = mask_token_id
+                     logits = self.AbLang(masked.unsqueeze(0))
+                     if hasattr(logits, 'last_hidden_state'):
+                         logits = logits.last_hidden_state
+                     logits = logits[0, j]
+                     nll = torch.nn.functional.cross_entropy(
+                         logits.unsqueeze(0), token_id.unsqueeze(0), reduction="none"
+                     )
+                     seq_pll.append(-nll.item())
+                 plls.append(np.mean(seq_pll) if seq_pll else float('nan'))
+         return np.array(plls)
+
+     def confidence(self, seqs, **kwargs):
+         """Confidence calculation matching the original ablang2 implementation:
+         all special tokens are excluded from the loss."""
+         formatted_seqs = self._format_paired(seqs)
+
+         plls = []
+         for seq in formatted_seqs:
+             tokens = self.tokenizer([seq], padding=True, return_tensors='pt')
+             input_ids = extract_input_ids(tokens, self.used_device)
+
+             with torch.no_grad():
+                 output = self.AbLang(input_ids)
+                 logits = output.last_hidden_state if hasattr(output, 'last_hidden_state') else output
+
+                 # Remove the batch dimension.
+                 logits = logits[0]        # [seq_len, vocab_size]
+                 input_ids = input_ids[0]  # [seq_len]
+
+                 # Exclude all special tokens (pad, mask, etc.).
+                 special_token_ids = self._special_token_ids()
+                 valid_mask = ~torch.isin(input_ids, torch.tensor(list(special_token_ids), device=input_ids.device))
+
+                 if valid_mask.sum() > 0:
+                     # Mean cross-entropy over the residue positions.
+                     nll = torch.nn.functional.cross_entropy(
+                         logits[valid_mask], input_ids[valid_mask], reduction="mean"
+                     )
+                     pll = -nll.item()
+                 else:
+                     pll = 0.0
+
+             plls.append(pll)
+
+         return np.array(plls, dtype=np.float32)
+
+     def probability(self, seqs, align=False, stepwise_masking=False, **kwargs):
+         """
+         Probability of mutations: applies softmax to the logits to obtain
+         per-residue probabilities.
+         """
+         formatted_seqs = self._format_paired(seqs)
+
+         # Stepwise masking is not implemented separately here; both paths use
+         # a single forward pass over the unmasked sequences.
+         logits = self._predict_logits(formatted_seqs)
+
+         # Apply softmax to get probabilities.
+         probs = logits.softmax(-1).cpu().numpy()
+
+         if align:
+             return probs
+         # Return residue-level probabilities (excluding special tokens).
+         return [res_to_list(state, seq) for state, seq in zip(probs, formatted_seqs)]
+
+     def restore(self, seqs, align=False, **kwargs):
+         hf_abrestore = HFAbRestore(self.AbLang, self.tokenizer, spread=self.spread, device=self.used_device, ncpu=self.ncpu)
+         restored = hf_abrestore.restore(seqs, align=align)
+         # Apply angle-bracket formatting.
+         if isinstance(restored, np.ndarray):
+             return np.array([add_angle_brackets(seq) for seq in restored])
+         return [add_angle_brackets(seq) for seq in restored]
+
+
+ def extract_input_ids(tokens, device):
+     if hasattr(tokens, 'input_ids'):
+         return tokens.input_ids.to(device)
+     if isinstance(tokens, dict):
+         if 'input_ids' in tokens:
+             return tokens['input_ids'].to(device)
+         for v in tokens.values():
+             if torch.is_tensor(v):
+                 return v.to(device)
+     elif torch.is_tensor(tokens):
+         return tokens.to(device)
+     raise ValueError("Could not extract input_ids from tokenizer output")
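The per-position masking loop at the heart of `pseudo_log_likelihood` can be isolated into a small self-contained sketch. The `toy_logits` stand-in below is hypothetical (any callable returning `[batch, seq_len, vocab]` logits would do); it is not the AbLang2 model:

```python
import torch
import numpy as np

def masked_pll(token_ids, logits_fn, mask_id, special_ids):
    """Mask each non-special position in turn and average the
    log-probability the model assigns to the true token."""
    per_token = []
    for j, tok in enumerate(token_ids):
        if tok.item() in special_ids:
            continue
        masked = token_ids.clone()
        masked[j] = mask_id
        logits = logits_fn(masked.unsqueeze(0))[0, j]
        nll = torch.nn.functional.cross_entropy(
            logits.unsqueeze(0), tok.unsqueeze(0), reduction="none"
        )
        per_token.append(-nll.item())
    return float(np.mean(per_token)) if per_token else float("nan")

# Toy logits: the "model" is always certain the token is id 2 (vocab size 5).
def toy_logits(x):
    out = torch.full((x.shape[0], x.shape[1], 5), -10.0)
    out[..., 2] = 10.0
    return out

ids = torch.tensor([0, 2, 2, 1])  # 0 and 1 play the role of special tokens
pll = masked_pll(ids, toy_logits, mask_id=4, special_ids={0, 1})
# pll is close to 0 because the toy model is certain of the true tokens
```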
app.py ADDED
@@ -0,0 +1,330 @@
+ import gradio as gr
+ import sys
+ import os
+ from transformers import AutoModel, AutoTokenizer
+ from transformers.utils import cached_file
+
+ # Load the model and tokenizer from the Hugging Face Hub.
+ model = AutoModel.from_pretrained("hemantn/ablang2", trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained("hemantn/ablang2", trust_remote_code=True)
+
+ # Find the cached model directory and make the bundled adapter importable.
+ adapter_path = cached_file("hemantn/ablang2", "adapter.py")
+ cached_model_dir = os.path.dirname(adapter_path)
+ sys.path.insert(0, cached_model_dir)
+
+ # Import and create the adapter.
+ from adapter import AbLang2PairedHuggingFaceAdapter
+ ablang = AbLang2PairedHuggingFaceAdapter(model=model, tokenizer=tokenizer)
+
+ def restore_sequences(heavy_chain, light_chain, use_align=False):
+     """
+     Restore masked residues in antibody sequences.
+
+     Args:
+         heavy_chain (str): Heavy chain sequence with masked residues (*)
+         light_chain (str): Light chain sequence with masked residues (*)
+         use_align (bool): Whether to use alignment for variable missing lengths
+
+     Returns:
+         tuple: (heavy_html, light_html) with restored residues highlighted
+     """
+     try:
+         # Prepare the input sequences.
+         if heavy_chain.strip() and light_chain.strip():
+             # Both chains provided.
+             sequences = [[heavy_chain.strip(), light_chain.strip()]]
+         elif heavy_chain.strip():
+             # Only the heavy chain provided.
+             sequences = [[heavy_chain.strip(), ""]]
+         elif light_chain.strip():
+             # Only the light chain provided.
+             sequences = [["", light_chain.strip()]]
+         else:
+             return "Please provide at least one antibody chain sequence.", ""
+
+         # Perform the restoration.
+         restored = ablang(sequences, mode='restore', align=use_align)
+
+         # Format the output.
+         if hasattr(restored, '__len__') and len(restored) > 0:
+             result = restored[0]  # The first (and only) result.
+
+             # Parse the result to separate the heavy and light chains.
+             if '>|<' in result:
+                 # Both chains present.
+                 heavy_part = result.split('>|<')[0].replace('<', '').replace('>', '')
+                 light_part = result.split('>|<')[1].replace('<', '').replace('>', '')
+             elif result.startswith('<') and result.endswith('>'):
+                 # Only one chain present.
+                 if heavy_chain.strip():
+                     heavy_part = result.replace('<', '').replace('>', '')
+                     light_part = ""
+                 else:
+                     heavy_part = ""
+                     light_part = result.replace('<', '').replace('>', '')
+             else:
+                 return "Error: Unexpected result format.", ""
+
+             # Create the highlighted versions.
+             highlighted_heavy = highlight_restored_residues(heavy_chain.strip(), heavy_part)
+             highlighted_light = highlight_restored_residues(light_chain.strip(), light_part)
+
+             # Wrap in styled HTML; the text wraps instead of scrolling.
+             heavy_html = f'<div class="restored-sequence-box" style="padding: 10px; background-color: #f8f9fa; border: 1px solid #dee2e6; border-radius: 4px;">{highlighted_heavy}</div>'
+             light_html = f'<div class="restored-sequence-box" style="padding: 10px; background-color: #f8f9fa; border: 1px solid #dee2e6; border-radius: 4px;">{highlighted_light}</div>'
+
+             return heavy_html, light_html
+         else:
+             return "Error: No restoration result obtained.", ""
+
+     except Exception as e:
+         return f"Error during restoration: {str(e)}", ""
+
+ def highlight_restored_residues(original_seq, restored_seq):
+     """
+     Highlight restored residues in green.
+     """
+     if not original_seq or not restored_seq:
+         return restored_seq
+
+     highlighted = ""
+     for orig_char, rest_char in zip(original_seq, restored_seq):
+         if orig_char == '*' and rest_char != '*':
+             # This residue was restored.
+             highlighted += f'<span class="restored-highlight">{rest_char}</span>'
+         else:
+             highlighted += rest_char
+
+     # Append any extra characters from the restored sequence.
+     if len(restored_seq) > len(original_seq):
+         highlighted += restored_seq[len(original_seq):]
+
+     return highlighted
+
+ # Create the Gradio interface.
+ with gr.Blocks(title="AbLang2 Sequence Restorer", theme=gr.themes.Soft(), css="""
+ * {
+     font-family: 'Courier New', monospace !important;
+ }
+ .sequence-input, .sequence-output {
+     font-family: 'Courier New', monospace !important;
+     font-size: 14px !important;
+     letter-spacing: 0.5px !important;
+ }
+ .restored-highlight {
+     background-color: #90EE90 !important;
+     color: #000 !important;
+     font-weight: bold !important;
+ }
+ .examples {
+     font-family: 'Courier New', monospace !important;
+     font-size: 14px !important;
+     letter-spacing: 0.5px !important;
+ }
+ .restored-sequence-box {
+     font-family: 'Courier New', monospace !important;
+     font-size: 14px !important;
+     letter-spacing: 0.5px !important;
+     white-space: pre-wrap !important;
+     word-wrap: break-word !important;
+     overflow-wrap: break-word !important;
+ }
+ .restored-heading {
+     color: #2E8B57 !important;
+     font-weight: bold !important;
+     font-size: 18px !important;
+ }
+ .example-text {
+     font-family: 'Courier New', monospace !important;
+     font-size: 12px !important;
+     white-space: pre-wrap !important;
+     word-wrap: break-word !important;
+ }
+ .examples-table {
+     font-family: 'Courier New', monospace !important;
+     font-size: 12px !important;
+     white-space: pre-wrap !important;
+     word-wrap: break-word !important;
+     max-width: none !important;
+     overflow: visible !important;
+ }
+ .examples-table td {
+     font-family: 'Courier New', monospace !important;
+     font-size: 12px !important;
+     white-space: pre-wrap !important;
+     word-wrap: break-word !important;
+     max-width: none !important;
+     overflow: visible !important;
+     text-overflow: unset !important;
+ }
+ .sequence-output label {
+     font-weight: bold !important;
+     color: #495057 !important;
+     font-size: 14px !important;
+     margin-bottom: 5px !important;
+ }
+ /* Force full display of examples */
+ .examples-container {
+     font-family: 'Courier New', monospace !important;
+     font-size: 12px !important;
+ }
+ .examples-container table {
+     width: 100% !important;
+     table-layout: auto !important;
+ }
+ .examples-container td {
+     white-space: pre-wrap !important;
+     word-wrap: break-word !important;
+     overflow-wrap: break-word !important;
+     max-width: none !important;
+     text-overflow: unset !important;
+     padding: 8px !important;
+     vertical-align: top !important;
+ }
+ .examples-container th {
+     white-space: nowrap !important;
+     padding: 8px !important;
+ }
+ /* Override any Gradio default truncation */
+ .examples table td {
+     white-space: pre-wrap !important;
+     word-wrap: break-word !important;
+     overflow-wrap: break-word !important;
+     max-width: none !important;
+     text-overflow: unset !important;
+     overflow: visible !important;
+     font-family: 'Courier New', monospace !important;
+     font-size: 12px !important;
+ }
+ .examples table {
+     table-layout: auto !important;
+     width: 100% !important;
+ }
+ /* Target the specific examples component */
+ div[data-testid="examples"] table td {
+     white-space: pre-wrap !important;
+     word-wrap: break-word !important;
+     overflow-wrap: break-word !important;
+     max-width: none !important;
+     text-overflow: unset !important;
+     overflow: visible !important;
+     font-family: 'Courier New', monospace !important;
+     font-size: 12px !important;
+ }
+ /* Force examples to show full content */
+ .examples table, .examples table td, .examples table th {
+     white-space: pre-wrap !important;
+     word-wrap: break-word !important;
+     overflow-wrap: break-word !important;
+     max-width: none !important;
+     text-overflow: unset !important;
+     overflow: visible !important;
+     font-family: 'Courier New', monospace !important;
+     font-size: 12px !important;
+     table-layout: auto !important;
+     width: auto !important;
+     min-width: 100% !important;
+ }
+ /* Override any inline styles */
+ .examples * {
+     white-space: pre-wrap !important;
+     word-wrap: break-word !important;
+     overflow-wrap: break-word !important;
+     max-width: none !important;
+     text-overflow: unset !important;
+     overflow: visible !important;
+ }
+ /* Style output labels to match input labels exactly */
+ .output-label {
+     font-weight: 600 !important;
+     color: var(--label-text-color) !important;
+     font-size: 14px !important;
+     margin-bottom: 8px !important;
+     margin-top: 16px !important;
+     line-height: 1.4 !important;
+     display: block !important;
+ }
+ """) as demo:
+     gr.Markdown("""
+     # 🧬 AbLang2 Sequence Restorer
+
+     This app uses the AbLang2 model to restore masked residues (*) in antibody sequences.
+     You can provide a heavy chain, a light chain, or both.
+
+     **Instructions:**
+     - Use `*` to mask residues you want to restore
+     - Provide the heavy chain, the light chain, or both
+     - Enable "Use Alignment" for variable missing lengths
+     """)
+
+     with gr.Row():
+         with gr.Column():
+             heavy_input = gr.Textbox(
+                 label="Heavy Chain Sequence",
+                 placeholder="Enter heavy chain sequence with masked residues (*)...",
+                 lines=3,
+                 max_lines=5,
+                 elem_classes=["sequence-input"]
+             )
+
+             light_input = gr.Textbox(
+                 label="Light Chain Sequence",
+                 placeholder="Enter light chain sequence with masked residues (*)...",
+                 lines=3,
+                 max_lines=5,
+                 elem_classes=["sequence-input"]
+             )
+
+             align_checkbox = gr.Checkbox(
+                 label="Use Alignment (for variable missing lengths)",
+                 value=False
+             )
+
+             restore_btn = gr.Button("🔄 Restore Sequences", variant="primary")
+
+         with gr.Column():
+             gr.Markdown("### 🧬 Restored Sequences", elem_classes=["restored-heading"])
+             gr.Markdown("*Green highlighting shows restored residues*")
+
+             gr.Markdown("**Heavy Chain Sequence**", elem_classes=["output-label"])
+             heavy_output = gr.HTML(label="")
+
+             gr.Markdown("**Light Chain Sequence**", elem_classes=["output-label"])
+             light_output = gr.HTML(label="")
+
+     # Example sequences.
+     gr.Examples(
+         examples=[
+             [
+                 "EVQ***SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS",
+                 "DIQLTQSPLSLPVTLGQPASISCRSS*SLEASDTNIYLSWFQQRPGQSPRRLIYKI*NRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK"
+             ],
+             [
+                 "EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMGWVRQAPGKGLEWVSAISGSGGSTYYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCARDY**GMDVWGQGTTVTVSS",
+                 ""
+             ],
+             [
+                 "",
+                 "DIQLTQSPSSLSASVGDRVTITCRASQSISSYLNWYQQKPGKAPKLLIY*ASSLQSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYSTP*TFGQGTKVEIK"
+             ]
+         ],
+         inputs=[heavy_input, light_input],
+         label="Example Sequences"
+     )
+
+     # Connect the button to the function.
+     restore_btn.click(
+         fn=restore_sequences,
+         inputs=[heavy_input, light_input, align_checkbox],
+         outputs=[heavy_output, light_output]
+     )
+
+     gr.Markdown("""
+     ---
+     **Note:** This app uses the AbLang2 model from the Hugging Face Hub.
+     Restoration may take a few seconds depending on sequence length and complexity.
+     """)
+
+ if __name__ == "__main__":
+     demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio>=4.0.0
+ transformers>=4.30.0
+ torch>=2.0.0
+ numpy>=1.21.0
+ pandas>=1.3.0
+ anarci>=1.3