import numpy as np
import torch

from ablang2.pretrained_utils.alignment import AbAlignment
from ablang2.pretrained_utils.encodings import AbEncoding
from ablang2.pretrained_utils.extra_utils import res_to_seq, res_to_list
from ablang2.pretrained_utils.restoration import AbRestore
from ablang2.pretrained_utils.scores import AbScores

class HuggingFaceTokenizerAdapter:
    """Wraps a HuggingFace tokenizer so it exposes the interface the ablang2
    pretrained utilities expect: encode sequences to padded id tensors, and
    decode id tensors back to residue strings."""

    def __init__(self, tokenizer, device):
        self.tokenizer = tokenizer
        self.device = device
        self.pad_token_id = tokenizer.pad_token_id
        self.mask_token_id = getattr(tokenizer, 'mask_token_id', None) or tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
        self.vocab = tokenizer.get_vocab() if hasattr(tokenizer, 'get_vocab') else tokenizer.vocab
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.all_special_tokens = tokenizer.all_special_tokens

    def __call__(self, seqs, pad=True, w_extra_tkns=False, device=None, mode=None):
        # pad and w_extra_tkns are kept for interface compatibility and are not used here.
        if mode == 'decode':
            # seqs is a tensor (or array) of token ids; map each id back to its residue,
            # dropping padding, separator, start and stop characters.
            if isinstance(seqs, torch.Tensor):
                seqs = seqs.cpu().numpy()
            decoded = []
            for seq in seqs:
                chars = [
                    self.inv_vocab.get(int(t), '')
                    for t in seq
                    if self.inv_vocab.get(int(t), '') not in {'', '-', '*', '<', '>'}
                ]
                # res_to_seq expects a (sequence, length) pair, as in the original code.
                # The true length is not always available, so len(chars) is used as a fallback.
                decoded.append(res_to_seq([''.join(chars), len(chars)], mode='restore'))
            return decoded
        tokens = self.tokenizer(seqs, padding=True, return_tensors='pt')
        return tokens['input_ids'].to(self.device if device is None else device)

class HFAbRestore(AbRestore):
    """AbRestore variant that drives a HuggingFace-loaded model through the
    original restoration utilities."""

    def __init__(self, hf_model, hf_tokenizer, spread=11, device='cpu', ncpu=1):
        super().__init__(spread=spread, device=device, ncpu=ncpu)
        self.used_device = device
        self._hf_model = hf_model
        self.tokenizer = HuggingFaceTokenizerAdapter(hf_tokenizer, device)

    @property
    def AbLang(self):
        # AbRestore calls self.AbLang(tokens) and expects a raw logits tensor,
        # so unwrap the HuggingFace output object if necessary.
        def model_call(x):
            output = self._hf_model(x)
            if hasattr(output, 'last_hidden_state'):
                return output.last_hidden_state
            return output
        return model_call

def add_angle_brackets(seq):
    """Wrap the heavy and light chains of a 'VH|VL' string in angle brackets,
    e.g. 'EVQ...|DIQ...' -> '<EVQ...>|<DIQ...>'."""
    # Input is expected to be 'VH|VL', 'VH|' or '|VL'.
    if '|' in seq:
        vh, vl = seq.split('|', 1)
    else:
        vh, vl = seq, ''
    return f"<{vh}>|<{vl}>"

class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScores):
    """
    Adapter to use pretrained utilities with a HuggingFace-loaded ablang2_paired model and tokenizer.
    Automatically uses CUDA if available, otherwise CPU.
    """
    def __init__(self, model, tokenizer, device=None, ncpu=1):
        super().__init__()
        if device is None:
            self.used_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.used_device = torch.device(device)
        self.AbLang = model  # HuggingFace model instance
        self.tokenizer = tokenizer
        self.AbLang.to(self.used_device)
        self.AbLang.eval()
        # Always get AbRep from the underlying model
        if hasattr(self.AbLang, 'model') and hasattr(self.AbLang.model, 'AbRep'):
            self.AbRep = self.AbLang.model.AbRep
        else:
            raise AttributeError("Could not find AbRep in the HuggingFace model or its underlying model.")
        self.ncpu = ncpu
        self.spread = 11  # For compatibility with the original restoration utilities
        # Note: special-token conversion is handled per-method, since depending on the
        # tokenizer, all_special_tokens may hold either token strings or token IDs.

    def freeze(self):
        self.AbLang.eval()

    def unfreeze(self):
        self.AbLang.train()

    def _encode_sequences(self, seqs):
        # Use HuggingFace-style padding and return PyTorch tensors
        tokens = self.tokenizer(seqs, padding=True, return_tensors='pt')
        tokens = extract_input_ids(tokens, self.used_device)
        return self.AbRep(tokens).last_hidden_states.detach()

    def _predict_logits(self, seqs):
        tokens = self.tokenizer(seqs, padding=True, return_tensors='pt')
        tokens = extract_input_ids(tokens, self.used_device)
        output = self.AbLang(tokens)
        if hasattr(output, 'last_hidden_state'):
            return output.last_hidden_state.detach()
        return output.detach()

    def _preprocess_labels(self, labels):
        labels = extract_input_ids(labels, self.used_device)
        return labels

    def __call__(self, seqs, mode='seqcoding', align=False, stepwise_masking=False, fragmented=False, batch_size=50):
        """
        Use different modes for different use cases, mimicking the original pretrained class.
        """
        from ablang2.pretrained import format_seq_input

        valid_modes = [
            'rescoding', 'seqcoding', 'restore', 'likelihood', 'probability',
            'pseudo_log_likelihood', 'confidence'
        ]
        if mode not in valid_modes:
            raise SyntaxError(f"Given mode doesn't exist. Please select one of the following: {valid_modes}.")

        seqs, chain = format_seq_input(seqs, fragmented=fragmented)

        if align:
            numbered_seqs, seqs, number_alignment = self.number_sequences(
                seqs, chain=chain, fragmented=fragmented
            )
        else:
            numbered_seqs = None
            number_alignment = None

        subset_list = []
        for subset in [seqs[x:x+batch_size] for x in range(0, len(seqs), batch_size)]:
            subset_list.append(getattr(self, mode)(subset, align=align, stepwise_masking=stepwise_masking))

        return self.reformat_subsets(
            subset_list,
            mode=mode,
            align=align,
            numbered_seqs=numbered_seqs,
            seqs=seqs,
            number_alignment=number_alignment,
        )

    def pseudo_log_likelihood(self, seqs, **kwargs):
        """
        Original (non-vectorized) pseudo log-likelihood computation matching notebook behavior.
        """
        # Format input: join VH and VL with '|'
        formatted_seqs = []
        for s in seqs:
            if isinstance(s, (list, tuple)):
                formatted_seqs.append('|'.join(s))
            else:
                formatted_seqs.append(s)

        # Tokenize all sequences in batch
        labels = self.tokenizer(
            formatted_seqs, padding=True, return_tensors='pt'
        )
        labels = extract_input_ids(labels, self.used_device)

        # Convert special tokens to IDs
        if isinstance(self.tokenizer.all_special_tokens[0], int):
            special_token_ids = set(self.tokenizer.all_special_tokens)
        else:
            special_token_ids = set(self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens)
        pad_token_id = self.tokenizer.pad_token_id

        mask_token_id = getattr(self.tokenizer, 'mask_token_id', None)
        if mask_token_id is None:
            mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        plls = []
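        # Pseudo log-likelihood: mask one residue at a time, re-run the model, and
        # score the true residue at the masked position. The per-sequence score is the
        # mean log p(x_j | x_{-j}) over all non-special positions, so each sequence
        # costs one forward pass per residue.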
        with torch.no_grad():
            for seq_label in labels:
                seq_pll = []
                for j, token_id in enumerate(seq_label):
                    if token_id.item() in special_token_ids or token_id.item() == pad_token_id:
                        continue
                    masked = seq_label.clone()
                    masked[j] = mask_token_id
                    logits = self.AbLang(masked.unsqueeze(0))
                    if hasattr(logits, 'last_hidden_state'):
                        logits = logits.last_hidden_state
                    logits = logits[0, j]
                    nll = torch.nn.functional.cross_entropy(
                        logits.unsqueeze(0), token_id.unsqueeze(0), reduction="none"
                    )
                    seq_pll.append(-nll.item())
                if seq_pll:
                    plls.append(np.mean(seq_pll))
                else:
                    plls.append(float('nan'))
        return np.array(plls)

    def confidence(self, seqs, **kwargs):
        """Confidence calculation - match original ablang2 implementation by excluding all special tokens from loss."""
        # Format input: join VH and VL with '|'
        formatted_seqs = []
        for s in seqs:
            if isinstance(s, (list, tuple)):
                formatted_seqs.append('|'.join(s))
            else:
                formatted_seqs.append(s)
        
        plls = []
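        # Unlike pseudo_log_likelihood, confidence uses a single unmasked forward pass
        # per sequence: the negated mean cross-entropy of the logits against the input
        # residues, with all special tokens excluded from the loss.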
        for seq in formatted_seqs:
            tokens = self.tokenizer([seq], padding=True, return_tensors='pt')
            input_ids = extract_input_ids(tokens, self.used_device)
            
            with torch.no_grad():
                output = self.AbLang(input_ids)
                if hasattr(output, 'last_hidden_state'):
                    logits = output.last_hidden_state
                else:
                    logits = output
                
                # Get the sequence (remove batch dimension)
                logits = logits[0]  # [seq_len, vocab_size]
                input_ids = input_ids[0]  # [seq_len]
                
                # Exclude all special tokens (pad, mask, etc.)
                if isinstance(self.tokenizer.all_special_tokens[0], int):
                    special_token_ids = set(self.tokenizer.all_special_tokens)
                else:
                    special_token_ids = set(self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens)
                valid_mask = ~torch.isin(input_ids, torch.tensor(list(special_token_ids), device=input_ids.device))
                
                if valid_mask.sum() > 0:
                    valid_logits = logits[valid_mask]
                    valid_labels = input_ids[valid_mask]
                    
                    # Calculate cross-entropy loss
                    nll = torch.nn.functional.cross_entropy(
                        valid_logits,
                        valid_labels,
                        reduction="mean"
                    )
                    pll = -nll.item()
                else:
                    pll = 0.0
                
                plls.append(pll)
        
        return np.array(plls, dtype=np.float32)

    def probability(self, seqs, align=False, stepwise_masking=False, **kwargs):
        """
        Probability of mutations - applies softmax to logits to get probabilities
        """
        # Format input: join VH and VL with '|'
        formatted_seqs = []
        for s in seqs:
            if isinstance(s, (list, tuple)):
                formatted_seqs.append('|'.join(s))
            else:
                formatted_seqs.append(s)

        # Get logits from a single unmasked forward pass. Full stepwise masking is not
        # implemented in this adapter; stepwise_masking is accepted for interface
        # compatibility only and currently takes the same path.
        logits = self._predict_logits(formatted_seqs)
        
        # Apply softmax to get probabilities
        probs = logits.softmax(-1).cpu().numpy()
        
        if align:
            return probs
        else:
            # Return residue-level probabilities (excluding special tokens)
            return [res_to_list(state, seq) for state, seq in zip(probs, formatted_seqs)]

    def restore(self, seqs, align=False, **kwargs):
        hf_abrestore = HFAbRestore(self.AbLang, self.tokenizer, spread=self.spread, device=self.used_device, ncpu=self.ncpu)
        restored = hf_abrestore.restore(seqs, align=align)
        # Apply angle brackets formatting
        if isinstance(restored, np.ndarray):
            restored = np.array([add_angle_brackets(seq) for seq in restored])
        else:
            restored = [add_angle_brackets(seq) for seq in restored]
        return restored

def extract_input_ids(tokens, device):
    """Pull the input_ids tensor out of a tokenizer output (BatchEncoding, dict or raw
    tensor) and move it to the given device."""
    if hasattr(tokens, 'input_ids'):
        return tokens.input_ids.to(device)
    elif isinstance(tokens, dict):
        if 'input_ids' in tokens:
            return tokens['input_ids'].to(device)
        # Fall back to the first tensor value in the dict
        for v in tokens.values():
            if torch.is_tensor(v):
                return v.to(device)
    elif torch.is_tensor(tokens):
        return tokens.to(device)
    raise ValueError("Could not extract input_ids from tokenizer output")
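

# Minimal usage sketch, assuming a HuggingFace checkpoint that exposes the paired
# AbLang-2 architecture (i.e. model.model.AbRep exists). The checkpoint name below is
# a placeholder, not a published model id, and the sequences are truncated examples
# for illustration only.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    checkpoint = "path/to/ablang2-paired-checkpoint"  # placeholder checkpoint
    hf_model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)
    hf_tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

    ablang = AbLang2PairedHuggingFaceAdapter(hf_model, hf_tokenizer)

    # Paired input: one (VH, VL) pair per record (truncated example sequences).
    seqs = [
        ["EVQLLESGGGLVQPGGSLRLSCAAS", "DIQMTQSPSSLSASVGDRVTITC"],
    ]

    print(ablang(seqs, mode='pseudo_log_likelihood'))  # masked per-residue scoring
    print(ablang(seqs, mode='confidence'))             # single-pass confidence score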