kashif HF Staff commited on
Commit
cbacdb4
·
verified ·
1 Parent(s): bc2257f

revert: remove auto_map, restore LlamaForCausalLM (no trust_remote_code needed)

Browse files
Files changed (1) hide show
  1. modeling_carbon.py +0 -209
modeling_carbon.py DELETED
@@ -1,209 +0,0 @@
1
- """
2
- Carbon with bp-level generation and scoring.
3
-
4
- generate_bp() plugs into the standard HF generate() pipeline via a
5
- LogitsProcessor — no internal methods are overridden, so it is compatible
6
- with any transformers version.
7
- """
8
- import torch
9
- import torch.nn.functional as F
10
- from transformers import LlamaForCausalLM, LogitsProcessor, LogitsProcessorList
11
- from typing import Union
12
-
13
- BASE_TO_IDX = {"A": 0, "T": 1, "C": 2, "G": 3, "N": -1}
14
- IDX_TO_BASE = {0: "A", 1: "T", 2: "C", 3: "G", -1: "N"}
15
-
16
-
17
- class _BPLogitsProcessor(LogitsProcessor):
18
- """Forces token selection to use per-base marginal probabilities.
19
-
20
- Runs LAST in the logits-processor chain so that temperature / top-k /
21
- top-p etc. influence the marginal distributions before base selection.
22
- """
23
-
24
- def __init__(self, kmer_ids, bp_base_index, flat_idx_to_token_id, bp_powers, k, do_sample):
25
- self.kmer_ids = kmer_ids
26
- self.bp_base_index = bp_base_index
27
- self.flat_idx_to_token_id = flat_idx_to_token_id
28
- self.bp_powers = bp_powers
29
- self.k = k
30
- self.do_sample = do_sample
31
-
32
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
33
- B = scores.shape[0]
34
- kmer_probs = F.softmax(scores[:, self.kmer_ids].float(), dim=-1) # [B, num_kmers]
35
-
36
- # Marginalise to per-base probabilities [B, k, 4]
37
- bp_probs = torch.zeros(B, self.k, 4, device=scores.device, dtype=kmer_probs.dtype)
38
- for pos in range(self.k):
39
- idx = self.bp_base_index[pos] # [num_kmers] in {0,1,2,3}
40
- for nt in range(4):
41
- bp_probs[:, pos, nt] = kmer_probs[:, idx == nt].sum(dim=-1)
42
-
43
- if self.do_sample:
44
- base_indices = torch.multinomial(bp_probs.view(-1, 4), 1).view(B, self.k)
45
- else:
46
- base_indices = bp_probs.argmax(dim=-1) # [B, k]
47
-
48
- flat_idx = (base_indices * self.bp_powers).sum(dim=-1) # [B]
49
- selected = self.flat_idx_to_token_id[flat_idx] # [B]
50
-
51
- # One-hot: both argmax and multinomial land on the bp-selected token
52
- new_scores = torch.full_like(scores, float("-inf"))
53
- new_scores.scatter_(1, selected.unsqueeze(1), 0.0)
54
- return new_scores
55
-
56
-
57
- class CarbonForCausalLM(LlamaForCausalLM):
58
- """LlamaForCausalLM with bp-level generation and sequence scoring."""
59
-
60
- def setup_tokenizer(self, tokenizer):
61
- """Cache tokenizer and precompute lookup tables for bp-level operations."""
62
- self.tokenizer = tokenizer
63
- k = tokenizer.k
64
- self.k = k
65
-
66
- device = next(self.parameters()).device
67
-
68
- # Build ordered kmer list from the tokenizer's DNA vocab
69
- kmer_items = sorted(
70
- [
71
- (kmer, tid)
72
- for kmer, tid in tokenizer.dna_token_to_id.items()
73
- if len(kmer) == k and all(b in "ATCG" for b in kmer)
74
- ],
75
- key=lambda x: x[1],
76
- )
77
- kmers = [item[0] for item in kmer_items]
78
- kmer_ids = [item[1] for item in kmer_items]
79
- num_kmers = len(kmer_ids)
80
-
81
- self._kmer_ids = torch.tensor(kmer_ids, dtype=torch.long, device=device)
82
-
83
- # bp_base_index[pos, j] = base index (0-3) of kmer j at position pos
84
- bp_base_index = torch.zeros(k, num_kmers, dtype=torch.long)
85
- for j, kmer in enumerate(kmers):
86
- for pos, base in enumerate(kmer):
87
- bp_base_index[pos, j] = BASE_TO_IDX[base]
88
- self.register_buffer("_bp_base_index", bp_base_index.to(device), persistent=False)
89
-
90
- self._bp_powers = torch.tensor(
91
- [4 ** i for i in range(k - 1, -1, -1)], dtype=torch.long, device=device
92
- )
93
-
94
- # flat kmer index -> token id (flat index = sum base_idx[i] * 4^(k-1-i))
95
- flat_to_tid = torch.zeros(num_kmers, dtype=torch.long, device=device)
96
- for j, (kmer, tid) in enumerate(kmer_items):
97
- flat_idx = sum(BASE_TO_IDX[c] * (4 ** (k - 1 - i)) for i, c in enumerate(kmer))
98
- flat_to_tid[flat_idx] = tid
99
- self.register_buffer("_flat_idx_to_token_id", flat_to_tid, persistent=False)
100
-
101
- def compute_bp_probs(self, logits):
102
- """Compute per-base marginal probabilities from token logits.
103
-
104
- Args:
105
- logits: [B, V] or [B, L, V]
106
- Returns:
107
- bp_probs: [B, k, 4] or [B, L, k, 4]
108
- """
109
- squeeze = logits.dim() == 2
110
- if squeeze:
111
- logits = logits.unsqueeze(1)
112
-
113
- kmer_logits = logits[:, :, self._kmer_ids]
114
- kmer_probs = F.softmax(kmer_logits.float(), dim=-1)
115
- B, L, _ = kmer_probs.shape
116
- bp_probs = torch.zeros(B, L, self.k, 4, device=logits.device, dtype=kmer_probs.dtype)
117
- for pos in range(self.k):
118
- idx = self._bp_base_index[pos]
119
- for nt in range(4):
120
- bp_probs[:, :, pos, nt] = kmer_probs[:, :, idx == nt].sum(dim=-1)
121
-
122
- return bp_probs.squeeze(1) if squeeze else bp_probs
123
-
124
- def generate_bp(self, inputs=None, generation_config=None, **kwargs):
125
- """Like generate(), but each token is selected base-by-base from marginal distributions.
126
-
127
- Temperature, top_k, top_p, repetition_penalty etc. all apply as usual —
128
- they run before the bp processor and shift the marginal distributions.
129
- Output shape and type are identical to generate().
130
- """
131
- assert hasattr(self, "_bp_base_index"), "Call setup_tokenizer(tokenizer) first"
132
-
133
- gc = generation_config or self.generation_config
134
- do_sample = kwargs.get("do_sample", getattr(gc, "do_sample", False))
135
-
136
- bp_proc = _BPLogitsProcessor(
137
- kmer_ids=self._kmer_ids,
138
- bp_base_index=self._bp_base_index,
139
- flat_idx_to_token_id=self._flat_idx_to_token_id,
140
- bp_powers=self._bp_powers,
141
- k=self.k,
142
- do_sample=do_sample,
143
- )
144
- existing = list(kwargs.pop("logits_processor", None) or [])
145
- kwargs["logits_processor"] = LogitsProcessorList(existing + [bp_proc])
146
-
147
- return super().generate(inputs=inputs, generation_config=generation_config, **kwargs)
148
-
149
- @torch.no_grad()
150
- def score_sequence(self, sequences: Union[str, list]):
151
- """Score DNA sequence(s) at base resolution.
152
-
153
- Returns per-base probability distributions and the probability of the
154
- actual base at each position, given all preceding context.
155
-
156
- Args:
157
- sequences: single DNA string or list of DNA strings (ACGT only)
158
-
159
- Returns:
160
- (bp_probs, actual_probs) for a single sequence, or
161
- (list of bp_probs, list of actual_probs) for a batch.
162
- bp_probs[i]: [seq_len_i, 4] — P(base | context) at each position
163
- actual_probs[i]: [seq_len_i] — P(actual base | context)
164
- """
165
- assert hasattr(self, "tokenizer"), "Call setup_tokenizer(tokenizer) first"
166
-
167
- is_single = isinstance(sequences, str)
168
- if is_single:
169
- sequences = [sequences]
170
-
171
- original_lens = [len(s) for s in sequences]
172
-
173
- # Right-pad to multiple of k with 'A' (matches tokenizer convention)
174
- padded = []
175
- for s in sequences:
176
- r = len(s) % self.k
177
- padded.append(s + "A" * (self.k - r) if r else s)
178
-
179
- # Prepend <dna> tag manually (training format)
180
- tagged = ["<dna>" + s for s in padded]
181
-
182
- inputs = self.tokenizer(
183
- tagged, return_tensors="pt", padding=True, add_special_tokens=False
184
- )
185
- input_ids = inputs["input_ids"].to(self.device)
186
- attention_mask = inputs["attention_mask"].to(self.device)
187
-
188
- logits = self(input_ids, attention_mask=attention_mask, return_dict=True).logits
189
- bp_probs_all = self.compute_bp_probs(logits) # [B, L, k, 4]
190
-
191
- bp_results, actual_results = [], []
192
- for i, (seq, orig_len, pad_seq) in enumerate(zip(sequences, original_lens, padded)):
193
- num_tokens = len(pad_seq) // self.k
194
- # logits[t] predicts token t+1; logits[0] (from <dna>) predicts token 1
195
- seq_bp = bp_probs_all[i, :num_tokens] # [num_tokens, k, 4]
196
- seq_bp = seq_bp.reshape(-1, 4)[:orig_len] # [orig_len, 4]
197
- actual = self._extract_actual_probs(seq_bp, seq)
198
- bp_results.append(seq_bp)
199
- actual_results.append(actual)
200
-
201
- if is_single:
202
- return bp_results[0], actual_results[0]
203
- return bp_results, actual_results
204
-
205
- def _extract_actual_probs(self, bp_probs: torch.Tensor, sequence: str) -> torch.Tensor:
206
- actual = torch.zeros(len(sequence), device=bp_probs.device, dtype=bp_probs.dtype)
207
- for i, base in enumerate(sequence):
208
- actual[i] = bp_probs[i].max() if base == "N" else bp_probs[i, BASE_TO_IDX[base]]
209
- return actual