GenerTeam committed on
Commit
7b5eb9e
·
verified ·
1 Parent(s): da5fa93

Update modeling_generator.py

Browse files
Files changed (1) hide show
  1. modeling_generator.py +160 -355
modeling_generator.py CHANGED
@@ -1,431 +1,236 @@
1
  """
2
- GENERator with bp_probs generation support.
3
 
4
- generate_bp() reuses the full HF generate() pipeline (parameter preparation,
5
- cache management, stopping criteria, logits processing, etc.) and only replaces
6
- the token selection step with bp-level independent base selection.
7
  """
8
- import os
9
- from typing import Optional, Union
10
-
11
  import torch
12
- import torch.nn as nn
13
  import torch.nn.functional as F
14
- from transformers import LlamaForCausalLM
 
15
 
16
  BASE_TO_IDX = {"A": 0, "T": 1, "C": 2, "G": 3, "N": -1}
17
  IDX_TO_BASE = {0: "A", 1: "T", 2: "C", 3: "G", -1: "N"}
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  class GENERatorForCausalLM(LlamaForCausalLM):
21
  """LlamaForCausalLM with bp-level autoregressive generation.
22
 
23
  Inherits all standard functionality (forward, generate, etc.)
24
  and adds generate_bp() for base-pair independent generation.
 
 
25
  """
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def setup_tokenizer(self, tokenizer):
28
- """Cache tokenizer and precompute lookup tables for bp generation."""
29
  self.tokenizer = tokenizer
30
  k = tokenizer.k
31
  self.k = k
32
- num_special = len(tokenizer.special_tokens)
33
- num_kmers = 4 ** k
34
 
35
- self._num_special = num_special
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- # _bp_base_index[pos, m] = kmer m 在位置 pos 的碱基编号 (0-3)
38
  bp_base_index = torch.zeros(k, num_kmers, dtype=torch.long)
39
- for j in range(k):
40
- bp_base_index[j] = torch.arange(num_kmers) >> ((k - 1 - j) * 2) & 3
41
- device = next(self.parameters()).device
42
  self.register_buffer("_bp_base_index", bp_base_index.to(device), persistent=False)
43
 
44
- # base_indices -> token_id 的查表张量
45
- # flat index = sum(base[i] * 4^(k-1-i))
46
- self._bp_powers = torch.tensor(
47
  [4 ** i for i in range(k - 1, -1, -1)], dtype=torch.long, device=device
48
  )
 
 
 
49
  flat_to_tid = torch.zeros(num_kmers, dtype=torch.long, device=device)
50
- for kmer, tid in tokenizer.vocab.items():
51
- if kmer in tokenizer.special_tokens:
52
- continue
53
- idx = sum(BASE_TO_IDX[c] * (4 ** (k - 1 - i)) for i, c in enumerate(kmer[:k]))
54
- flat_to_tid[idx] = tid
55
  self.register_buffer("_flat_idx_to_token_id", flat_to_tid, persistent=False)
56
 
57
  def compute_bp_probs(self, logits):
58
- """Compute per-base marginal probabilities from token logits (vectorized).
59
 
60
  Args:
61
- logits: [B, V] or [B, L, V] token logits
62
  Returns:
63
  bp_probs: [B, k, 4] or [B, L, k, 4]
64
  """
65
- squeeze = False
66
- if logits.dim() == 2:
67
- logits = logits.unsqueeze(1) # [B, 1, V]
68
- squeeze = True
69
 
70
- kmer_logits = logits[:, :, self._num_special:] # [B, L, num_kmers]
71
  kmer_probs = F.softmax(kmer_logits.float(), dim=-1)
72
  B, L, _ = kmer_probs.shape
73
  bp_probs = torch.zeros(B, L, self.k, 4, device=logits.device, dtype=kmer_probs.dtype)
74
  for pos in range(self.k):
75
- idx = self._bp_base_index[pos] # [num_kmers] -> 0~3
76
  for nt in range(4):
77
  bp_probs[:, :, pos, nt] = kmer_probs[:, :, idx == nt].sum(dim=-1)
78
 
79
- if squeeze:
80
- bp_probs = bp_probs.squeeze(1) # [B, k, 4]
81
- return bp_probs
82
-
83
- # -------------------------------------------------------------------------
84
- # generate_bp: sets a flag then delegates to the standard generate()
85
- # -------------------------------------------------------------------------
86
- @torch.no_grad()
87
- def generate_bp(self, inputs=None, generation_config=None, **kwargs):
88
- """Same interface as generate(), but with bp-level independent base selection.
89
 
90
- Token logits are marginalized to per-base probabilities [k, 4], and each
91
- base position is selected independently. All standard generate() parameters
92
- (temperature, top_k, top_p, do_sample, attention_mask, etc.) are fully
93
- supported — they are processed by the HF generate pipeline as usual.
94
 
95
- Returns:
96
- Same as generate() — token ids tensor or GenerateOutput.
 
97
  """
98
- assert hasattr(self, "_bp_base_index"), "Call setup_tokenizer() first"
99
- self._bp_generation = True
100
- try:
101
- return super().generate(
102
- inputs=inputs, generation_config=generation_config, **kwargs
103
- )
104
- finally:
105
- self._bp_generation = False
106
-
107
- # -------------------------------------------------------------------------
108
- # Override _sample: when _bp_generation is set, use bp-level token selection
109
- # -------------------------------------------------------------------------
110
- def _sample(
111
- self,
112
- input_ids,
113
- logits_processor,
114
- stopping_criteria,
115
- generation_config,
116
- synced_gpus,
117
- streamer,
118
- **model_kwargs,
119
- ):
120
- if not getattr(self, "_bp_generation", False):
121
- return super()._sample(
122
- input_ids,
123
- logits_processor,
124
- stopping_criteria,
125
- generation_config,
126
- synced_gpus,
127
- streamer,
128
- **model_kwargs,
129
- )
130
-
131
- # ==================================================================
132
- # BP generation mode — copied from transformers 4.56.0 _sample(),
133
- # with ONLY the token selection block replaced by bp marginalization.
134
- # ==================================================================
135
- from transformers.generation.utils import (
136
- GenerateDecoderOnlyOutput,
137
- )
138
-
139
- # init values
140
- pad_token_id = generation_config._pad_token_tensor
141
- output_attentions = generation_config.output_attentions
142
- output_hidden_states = generation_config.output_hidden_states
143
- output_scores = generation_config.output_scores
144
- output_logits = generation_config.output_logits
145
- return_dict_in_generate = generation_config.return_dict_in_generate
146
- has_eos_stopping_criteria = any(
147
- hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
148
  )
149
- do_sample = generation_config.do_sample
 
150
 
151
- # init attention / hidden states / scores tuples
152
- scores = () if (return_dict_in_generate and output_scores) else None
153
- raw_logits = () if (return_dict_in_generate and output_logits) else None
154
- decoder_attentions = (
155
- () if (return_dict_in_generate and output_attentions) else None
156
- )
157
- decoder_hidden_states = (
158
- () if (return_dict_in_generate and output_hidden_states) else None
159
- )
160
-
161
- # keep track of which sequences are already finished
162
- batch_size, cur_len = input_ids.shape[:2]
163
- this_peer_finished = False
164
- unfinished_sequences = torch.ones(
165
- batch_size, dtype=torch.long, device=input_ids.device
166
- )
167
- model_kwargs = self._get_initial_cache_position(
168
- cur_len, input_ids.device, model_kwargs
169
- )
170
-
171
- model_forward = self.__call__
172
- compile_forward = self._valid_auto_compile_criteria(
173
- model_kwargs, generation_config
174
- )
175
- if compile_forward:
176
- os.environ["TOKENIZERS_PARALLELISM"] = "0"
177
- if self.config._attn_implementation == "flash_attention_2":
178
- if (
179
- generation_config.compile_config is not None
180
- and generation_config.compile_config.fullgraph
181
- ):
182
- generation_config.compile_config.fullgraph = False
183
- model_forward = self.get_compiled_call(generation_config.compile_config)
184
-
185
- if generation_config.prefill_chunk_size is not None:
186
- model_kwargs = self._prefill_chunking(
187
- input_ids, generation_config, **model_kwargs
188
- )
189
- is_prefill = False
190
- else:
191
- is_prefill = True
192
-
193
- while self._has_unfinished_sequences(
194
- this_peer_finished, synced_gpus, device=input_ids.device
195
- ):
196
- # prepare model inputs
197
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
198
-
199
- # prepare variable output controls
200
- model_inputs.update(
201
- {"output_attentions": output_attentions} if output_attentions else {}
202
- )
203
- model_inputs.update(
204
- {"output_hidden_states": output_hidden_states}
205
- if output_hidden_states
206
- else {}
207
- )
208
-
209
- if is_prefill:
210
- outputs = self(**model_inputs, return_dict=True)
211
- is_prefill = False
212
- else:
213
- outputs = model_forward(**model_inputs, return_dict=True)
214
-
215
- # update model kwargs for next step (handles cache, attention_mask, etc.)
216
- model_kwargs = self._update_model_kwargs_for_generation(
217
- outputs,
218
- model_kwargs,
219
- is_encoder_decoder=self.config.is_encoder_decoder,
220
- )
221
- if synced_gpus and this_peer_finished:
222
- continue
223
-
224
- next_token_logits = outputs.logits[:, -1, :].to(
225
- copy=True, dtype=torch.float32, device=input_ids.device
226
- )
227
-
228
- # pre-process distribution (temperature, top_k, top_p, repetition_penalty, etc.)
229
- next_token_scores = logits_processor(input_ids, next_token_logits)
230
-
231
- # Store scores, attentions and hidden_states when required
232
- if return_dict_in_generate:
233
- if output_scores:
234
- scores += (next_token_scores,)
235
- if output_logits:
236
- raw_logits += (next_token_logits,)
237
- if output_attentions:
238
- decoder_attentions += ((outputs.attentions,),)
239
- if output_hidden_states:
240
- decoder_hidden_states += ((outputs.hidden_states,),)
241
-
242
- # =============================================================
243
- # BP-LEVEL TOKEN SELECTION (vectorized, the ONLY change)
244
- # =============================================================
245
- # [B, V] -> [B, k, 4] marginal bp probabilities
246
- bp_probs = self.compute_bp_probs(next_token_scores) # [B, k, 4]
247
-
248
- if do_sample:
249
- # [B*k, 4] -> multinomial -> [B, k]
250
- base_indices = torch.multinomial(
251
- bp_probs.view(-1, 4), 1
252
- ).view(batch_size, self.k)
253
- else:
254
- base_indices = bp_probs.argmax(dim=-1) # [B, k]
255
-
256
- # base_indices [B, k] -> flat kmer index -> token_id [B]
257
- flat_idx = (base_indices * self._bp_powers).sum(dim=-1) # [B]
258
- next_tokens = self._flat_idx_to_token_id[flat_idx] # [B]
259
- # =============================================================
260
-
261
- # finished sentences should have their next token be a padding token
262
- if has_eos_stopping_criteria:
263
- next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
264
- 1 - unfinished_sequences
265
- )
266
-
267
- # update generated ids, model inputs, and length for next step
268
- input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
269
- if streamer is not None:
270
- streamer.put(next_tokens.cpu())
271
-
272
- unfinished_sequences = unfinished_sequences & ~stopping_criteria(
273
- input_ids, scores
274
- )
275
- this_peer_finished = unfinished_sequences.max() == 0
276
- cur_len += 1
277
-
278
- del outputs
279
-
280
- if streamer is not None:
281
- streamer.end()
282
-
283
- if return_dict_in_generate:
284
- return GenerateDecoderOnlyOutput(
285
- sequences=input_ids,
286
- scores=scores,
287
- logits=raw_logits,
288
- attentions=decoder_attentions,
289
- hidden_states=decoder_hidden_states,
290
- past_key_values=model_kwargs.get("past_key_values"),
291
- )
292
- else:
293
- return input_ids
294
 
295
  @torch.no_grad()
296
- def score_sequence(self, sequences: Union[str, list[str]]):
297
- """Score DNA sequence(s) and return per-base conditional probabilities.
298
 
299
- Each sequence is manually prepended with BOS token ("<dna>") and padded
300
- with 'A' if length is not a multiple of k. Returns probabilities for the
301
- original sequences only (excluding padding).
302
 
303
  Args:
304
- sequences: Single DNA sequence string or list of sequences
305
 
306
  Returns:
307
- Tuple of (bp_probs, actual_probs):
308
- - bp_probs: Full probability distribution
309
- * Single sequence: [seq_len, 4] tensor
310
- * Batch: list of [seq_len_i, 4] tensors
311
- - actual_probs: Probability of the actual base at each position
312
- * Single sequence: [seq_len] tensor
313
- * Batch: list of [seq_len_i] tensors
314
-
315
- bp_probs[i, j] = P(base at position i is nucleotide j | context)
316
- actual_probs[i] = P(actual base at position i | context)
317
- where j: 0=A, 1=T, 2=C, 3=G
318
-
319
- Example:
320
- # Single sequence
321
- bp_probs, actual_probs = model.score_sequence("ACGT")
322
-
323
- # Batch of sequences
324
- bp_probs_list, actual_probs_list = model.score_sequence([
325
- "ACGT" * 150,
326
- "ACGT" * 149 + "AC",
327
- ])
328
  """
329
- assert hasattr(self, "tokenizer"), "Call setup_tokenizer() first"
330
 
331
- # Handle single sequence case
332
  is_single = isinstance(sequences, str)
333
  if is_single:
334
  sequences = [sequences]
335
 
336
- # Store original info
337
- original_lens = [len(seq) for seq in sequences]
338
- original_sequences = sequences.copy()
339
 
340
- # Pad each sequence to multiple of k with 'A'
341
- padded_sequences = []
342
- for seq in sequences:
343
- if len(seq) % self.k != 0:
344
- padding_len = self.k - (len(seq) % self.k)
345
- seq = seq + 'A' * padding_len
346
- padded_sequences.append(seq)
347
 
348
- # Manually prepend BOS token "<dna>" to each sequence
349
- sequences_with_bos = ["<s>" + seq for seq in padded_sequences]
350
 
351
- # Tokenize batch (without add_special_tokens since we added manually)
352
  inputs = self.tokenizer(
353
- sequences_with_bos,
354
- return_tensors="pt",
355
- padding=True,
356
- add_special_tokens=False
357
  )
358
  input_ids = inputs["input_ids"].to(self.device)
359
  attention_mask = inputs["attention_mask"].to(self.device)
360
 
361
- # Forward pass to get logits for all positions
362
- outputs = self(input_ids, attention_mask=attention_mask, return_dict=True)
363
- logits = outputs.logits # [B, max_seq_len, vocab_size]
364
-
365
- # Compute bp probabilities for all token positions
366
- bp_probs = self.compute_bp_probs(logits) # [B, max_seq_len, k, 4]
367
-
368
- # Process each sequence in the batch
369
- bp_probs_results = []
370
- actual_probs_results = []
371
 
372
- for i, (original_seq, original_len, padded_seq) in enumerate(
373
- zip(original_sequences, original_lens, padded_sequences)
374
- ):
375
- # Calculate number of actual sequence tokens (excluding BOS)
376
- num_seq_tokens = len(padded_seq) // self.k
 
 
 
 
377
 
378
- # Extract bp_probs for this sequence
379
- # logits[0] predicts token after BOS (first sequence token)
380
- # logits[i] predicts token[i+1]
381
- # So logits[0:num_seq_tokens] predict the sequence tokens
382
- seq_bp_probs = bp_probs[i, :num_seq_tokens] # [num_seq_tokens, k, 4]
383
-
384
- # Reshape: [num_seq_tokens, k, 4] -> [num_seq_tokens * k, 4]
385
- seq_result = seq_bp_probs.reshape(-1, 4)
386
-
387
- # Trim to original sequence length (remove padding)
388
- seq_result = seq_result[:original_len]
389
-
390
- # Extract actual base probabilities
391
- actual_probs = self._extract_actual_probs(seq_result, original_seq)
392
-
393
- bp_probs_results.append(seq_result)
394
- actual_probs_results.append(actual_probs)
395
-
396
- # Return single tensors if input was single sequence
397
  if is_single:
398
- return bp_probs_results[0], actual_probs_results[0]
399
-
400
- return bp_probs_results, actual_probs_results
401
-
402
- def _extract_actual_probs(self, bp_probs: torch.Tensor, sequence: str):
403
- """Extract probabilities of actual bases in the sequence.
404
-
405
- For each position i in the sequence, returns the probability that the model
406
- assigned to the actual base at that position.
407
-
408
- For 'N' bases (unknown), returns the maximum probability across all 4 bases.
409
-
410
- Args:
411
- bp_probs: [seq_len, 4] probability distribution from logits
412
- bp_probs[i] = P(position i | context before i)
413
- sequence: DNA sequence string (may contain 'N')
414
-
415
- Returns:
416
- actual_probs: [seq_len] probabilities of actual bases
417
- actual_probs[i] = bp_probs[i, sequence[i]] for A/T/C/G
418
- actual_probs[i] = max(bp_probs[i]) for 'N'
419
- """
420
- seq_len = len(sequence)
421
- actual_probs = torch.zeros(seq_len, device=bp_probs.device, dtype=bp_probs.dtype)
422
 
 
 
423
  for i, base in enumerate(sequence):
424
- if base == 'N':
425
- # For N, take the maximum probability across all 4 bases
426
- actual_probs[i] = bp_probs[i].max()
427
- else:
428
- base_idx = BASE_TO_IDX[base]
429
- actual_probs[i] = bp_probs[i, base_idx]
430
-
431
- return actual_probs
 
1
  """
2
+ GENERator with bp-level generation and scoring.
3
 
4
+ generate_bp() plugs into the standard HF generate() pipeline via a
5
+ LogitsProcessor — no internal methods are overridden, so it is compatible
6
+ with any transformers version.
7
  """
 
 
 
8
  import torch
 
9
  import torch.nn.functional as F
10
+ from transformers import LlamaForCausalLM, LogitsProcessor, LogitsProcessorList
11
+ from typing import Union
12
 
13
  BASE_TO_IDX = {"A": 0, "T": 1, "C": 2, "G": 3, "N": -1}
14
  IDX_TO_BASE = {0: "A", 1: "T", 2: "C", 3: "G", -1: "N"}
15
 
16
 
17
+ class _BPLogitsProcessor(LogitsProcessor):
18
+ """Forces token selection to use per-base marginal probabilities.
19
+
20
+ Runs LAST in the logits-processor chain so that temperature / top-k /
21
+ top-p etc. influence the marginal distributions before base selection.
22
+ """
23
+
24
+ def __init__(self, kmer_ids, bp_base_index, flat_idx_to_token_id, bp_powers, k, do_sample):
25
+ self.kmer_ids = kmer_ids
26
+ self.bp_base_index = bp_base_index
27
+ self.flat_idx_to_token_id = flat_idx_to_token_id
28
+ self.bp_powers = bp_powers
29
+ self.k = k
30
+ self.do_sample = do_sample
31
+
32
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
33
+ B = scores.shape[0]
34
+ kmer_probs = F.softmax(scores[:, self.kmer_ids].float(), dim=-1) # [B, num_kmers]
35
+
36
+ # Marginalise to per-base probabilities [B, k, 4]
37
+ bp_probs = torch.zeros(B, self.k, 4, device=scores.device, dtype=kmer_probs.dtype)
38
+ for pos in range(self.k):
39
+ idx = self.bp_base_index[pos] # [num_kmers] in {0,1,2,3}
40
+ for nt in range(4):
41
+ bp_probs[:, pos, nt] = kmer_probs[:, idx == nt].sum(dim=-1)
42
+
43
+ if self.do_sample:
44
+ base_indices = torch.multinomial(bp_probs.view(-1, 4), 1).view(B, self.k)
45
+ else:
46
+ base_indices = bp_probs.argmax(dim=-1) # [B, k]
47
+
48
+ flat_idx = (base_indices * self.bp_powers).sum(dim=-1) # [B]
49
+ selected = self.flat_idx_to_token_id[flat_idx] # [B]
50
+
51
+ # One-hot: both argmax and multinomial land on the bp-selected token
52
+ new_scores = torch.full_like(scores, float("-inf"))
53
+ new_scores.scatter_(1, selected.unsqueeze(1), 0.0)
54
+ return new_scores
55
+
56
+
57
  class GENERatorForCausalLM(LlamaForCausalLM):
58
  """LlamaForCausalLM with bp-level autoregressive generation.
59
 
60
  Inherits all standard functionality (forward, generate, etc.)
61
  and adds generate_bp() for base-pair independent generation.
62
+
63
+ The tokenizer is automatically set up when loading the model with from_pretrained().
64
  """
65
 
66
+ @classmethod
67
+ def from_pretrained(cls, *args, **kwargs):
68
+ """Load model and automatically setup tokenizer if available."""
69
+ model = super().from_pretrained(*args, **kwargs)
70
+
71
+ model_path = args[0] if len(args) > 0 else kwargs.get('pretrained_model_name_or_path')
72
+
73
+ if model_path:
74
+ try:
75
+ from transformers import AutoTokenizer
76
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
77
+ model.setup_tokenizer(tokenizer)
78
+ print(f"Tokenizer automatically loaded and configured for bp-level scoring")
79
+ except Exception as e:
80
+ print(f"Could not auto-load tokenizer: {e}")
81
+ print(f" Call model.setup_tokenizer(tokenizer) manually if needed")
82
+
83
+ return model
84
+
85
  def setup_tokenizer(self, tokenizer):
86
+ """Cache tokenizer and precompute lookup tables for bp-level operations."""
87
  self.tokenizer = tokenizer
88
  k = tokenizer.k
89
  self.k = k
 
 
90
 
91
+ device = next(self.parameters()).device
92
+
93
+ # Build ordered kmer list from the tokenizer's DNA vocab
94
+ kmer_items = sorted(
95
+ [
96
+ (kmer, tid)
97
+ for kmer, tid in tokenizer.vocab.items()
98
+ if len(kmer) == k and all(b in "ATCG" for b in kmer)
99
+ ],
100
+ key=lambda x: x[1],
101
+ )
102
+ kmers = [item[0] for item in kmer_items]
103
+ kmer_ids = [item[1] for item in kmer_items]
104
+ num_kmers = len(kmer_ids)
105
+
106
+ kmer_ids_tensor = torch.tensor(kmer_ids, dtype=torch.long, device=device)
107
+ self.register_buffer("_kmer_ids", kmer_ids_tensor, persistent=False)
108
 
109
+ # bp_base_index[pos, j] = base index (0-3) of kmer j at position pos
110
  bp_base_index = torch.zeros(k, num_kmers, dtype=torch.long)
111
+ for j, kmer in enumerate(kmers):
112
+ for pos, base in enumerate(kmer):
113
+ bp_base_index[pos, j] = BASE_TO_IDX[base]
114
  self.register_buffer("_bp_base_index", bp_base_index.to(device), persistent=False)
115
 
116
+ bp_powers = torch.tensor(
 
 
117
  [4 ** i for i in range(k - 1, -1, -1)], dtype=torch.long, device=device
118
  )
119
+ self.register_buffer("_bp_powers", bp_powers, persistent=False)
120
+
121
+ # flat kmer index -> token id (flat index = sum base_idx[i] * 4^(k-1-i))
122
  flat_to_tid = torch.zeros(num_kmers, dtype=torch.long, device=device)
123
+ for j, (kmer, tid) in enumerate(kmer_items):
124
+ flat_idx = sum(BASE_TO_IDX[c] * (4 ** (k - 1 - i)) for i, c in enumerate(kmer))
125
+ flat_to_tid[flat_idx] = tid
 
 
126
  self.register_buffer("_flat_idx_to_token_id", flat_to_tid, persistent=False)
127
 
128
  def compute_bp_probs(self, logits):
129
+ """Compute per-base marginal probabilities from token logits.
130
 
131
  Args:
132
+ logits: [B, V] or [B, L, V]
133
  Returns:
134
  bp_probs: [B, k, 4] or [B, L, k, 4]
135
  """
136
+ squeeze = logits.dim() == 2
137
+ if squeeze:
138
+ logits = logits.unsqueeze(1)
 
139
 
140
+ kmer_logits = logits[:, :, self._kmer_ids]
141
  kmer_probs = F.softmax(kmer_logits.float(), dim=-1)
142
  B, L, _ = kmer_probs.shape
143
  bp_probs = torch.zeros(B, L, self.k, 4, device=logits.device, dtype=kmer_probs.dtype)
144
  for pos in range(self.k):
145
+ idx = self._bp_base_index[pos]
146
  for nt in range(4):
147
  bp_probs[:, :, pos, nt] = kmer_probs[:, :, idx == nt].sum(dim=-1)
148
 
149
+ return bp_probs.squeeze(1) if squeeze else bp_probs
 
 
 
 
 
 
 
 
 
150
 
151
+ def generate(self, inputs=None, generation_config=None, **kwargs):
152
+ """Like generate(), but each token is selected base-by-base from marginal distributions.
 
 
153
 
154
+ Temperature, top_k, top_p, repetition_penalty etc. all apply as usual —
155
+ they run before the bp processor and shift the marginal distributions.
156
+ Output shape and type are identical to generate().
157
  """
158
+ assert hasattr(self, "_bp_base_index"), "Call setup_tokenizer(tokenizer) first"
159
+
160
+ gc = generation_config or self.generation_config
161
+ do_sample = kwargs.get("do_sample", getattr(gc, "do_sample", False))
162
+
163
+ bp_proc = _BPLogitsProcessor(
164
+ kmer_ids=self._kmer_ids,
165
+ bp_base_index=self._bp_base_index,
166
+ flat_idx_to_token_id=self._flat_idx_to_token_id,
167
+ bp_powers=self._bp_powers,
168
+ k=self.k,
169
+ do_sample=do_sample,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  )
171
+ existing = list(kwargs.pop("logits_processor", None) or [])
172
+ kwargs["logits_processor"] = LogitsProcessorList(existing + [bp_proc])
173
 
174
+ return super().generate(inputs=inputs, generation_config=generation_config, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  @torch.no_grad()
177
+ def score_sequence(self, sequences: Union[str, list]):
178
+ """Score DNA sequence(s) at base resolution.
179
 
180
+ Returns per-base probability distributions and the probability of the
181
+ actual base at each position, given all preceding context.
 
182
 
183
  Args:
184
+ sequences: single DNA string or list of DNA strings (ACGT only)
185
 
186
  Returns:
187
+ (bp_probs, actual_probs) for a single sequence, or
188
+ (list of bp_probs, list of actual_probs) for a batch.
189
+ bp_probs[i]: [seq_len_i, 4] — P(base | context) at each position
190
+ actual_probs[i]: [seq_len_i] — P(actual base | context)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  """
192
+ assert hasattr(self, "tokenizer"), "Call setup_tokenizer(tokenizer) first"
193
 
 
194
  is_single = isinstance(sequences, str)
195
  if is_single:
196
  sequences = [sequences]
197
 
198
+ original_lens = [len(s) for s in sequences]
 
 
199
 
200
+ # Right-pad to multiple of k with 'A' (matches tokenizer convention)
201
+ padded = []
202
+ for s in sequences:
203
+ r = len(s) % self.k
204
+ padded.append(s + "A" * (self.k - r) if r else s)
 
 
205
 
206
+ # Prepend BOS manually (training format)
207
+ tagged = ["<s>" + s for s in padded]
208
 
 
209
  inputs = self.tokenizer(
210
+ tagged, return_tensors="pt", padding=True, add_special_tokens=False
 
 
 
211
  )
212
  input_ids = inputs["input_ids"].to(self.device)
213
  attention_mask = inputs["attention_mask"].to(self.device)
214
 
215
+ logits = self(input_ids, attention_mask=attention_mask, return_dict=True).logits
216
+ bp_probs_all = self.compute_bp_probs(logits) # [B, L, k, 4]
 
 
 
 
 
 
 
 
217
 
218
+ bp_results, actual_results = [], []
219
+ for i, (seq, orig_len, pad_seq) in enumerate(zip(sequences, original_lens, padded)):
220
+ num_tokens = len(pad_seq) // self.k
221
+ # logits[t] predicts token t+1; logits[0] (from <s>) predicts token 1
222
+ seq_bp = bp_probs_all[i, :num_tokens] # [num_tokens, k, 4]
223
+ seq_bp = seq_bp.reshape(-1, 4)[:orig_len] # [orig_len, 4]
224
+ actual = self._extract_actual_probs(seq_bp, seq)
225
+ bp_results.append(seq_bp)
226
+ actual_results.append(actual)
227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  if is_single:
229
+ return bp_results[0], actual_results[0]
230
+ return bp_results, actual_results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
+ def _extract_actual_probs(self, bp_probs: torch.Tensor, sequence: str) -> torch.Tensor:
233
+ actual = torch.zeros(len(sequence), device=bp_probs.device, dtype=bp_probs.dtype)
234
  for i, base in enumerate(sequence):
235
+ actual[i] = bp_probs[i].max() if base == "N" else bp_probs[i, BASE_TO_IDX[base]]
236
+ return actual