SixOpen committed on commit f8ab83c (verified) · 1 parent: d783988

Upload folder using huggingface_hub

Files changed (11):
  1. README.md +267 -0
  2. birwkv7.py +190 -0
  3. config.json +81 -0
  4. configuration_hare.py +45 -0
  5. model.pt +3 -0
  6. modeling_hare.py +98 -0
  7. streaming.py +202 -0
  8. surgery.py +205 -0
  9. surgery_meta.json +135 -0
  10. tokenizer.json +0 -0
  11. tokenizer_config.json +16 -0
README.md ADDED
@@ -0,0 +1,267 @@
---
language: en
license: apache-2.0
tags:
- embeddings
- text-retrieval
- long-context
- rwkv
- modernbert
- streaming
- semantic-search
- retrieval
pipeline_tag: feature-extraction
library_name: transformers
base_model: Alibaba-NLP/gte-modernbert-base
---

# HARE: Hybrid Attention-Recurrence Embeddings

TL;DR: Stateful embedding model that replaces sliding-window attention with RWKV recurrence, allowing for incremental encoding and streaming semantic search.

| | |
|---|---|
| **Parameters** | 173.9M |
| **Embedding dim** | 768 |
| **Base model** | [Alibaba-NLP/gte-modernbert-base](https://huggingface.co/Alibaba-NLP/gte-modernbert-base) |
| **Architecture** | ModernBERT-base with 14/22 local attention layers replaced by bidirectional RWKV recurrence |
| **Language** | English |

Conventional embedding models are stateless: adding new content requires re-encoding from scratch, because every token representation depends on the entire sequence.
HARE replaces 14 local sliding-window attention layers in ModernBERT-base with bidirectional RWKV linear recurrence while retaining the 8 global attention layers.
Each recurrent layer maintains a fixed-size state matrix that summarizes all prior tokens at O(1) per-token cost, which makes the encoder stateful: it can save its state and resume encoding from any position.

The biggest practical advantage is being able to run semantic search on large files well before they are fully available, and across multiple streams simultaneously (for example, parallel distributed file downloads, concurrent transcripts, or documents arriving from different sources on the same topic).

## Results

### LongEmbed (Needle/Passkey: nDCG@1; others: nDCG@10)

Chunk-level: 256-token chunks, mean-pooled, max-over-chunks scoring. Token-level: full-document encoding, per-token late-interaction scoring.

| Task | Chunk-level | Token-level | GTE-ModernBERT-base |
|------|-------------|-------------|---------------------|
| Needle | 84.0 | **87.5** | 49.8 |
| Passkey | **96.3** | 52.5 | 47.0 |
| NarrativeQA | **54.2** | 53.6 | 46.6 |
| QMSum | 44.2 | **50.7** | 61.1 |
| WikimQA | 73.6 | **87.6** | 86.8 |
| SummScreenFD | 72.2 | **88.5** | 88.2 |
| **Average** | **70.7** | 70.1 | 63.2 |
| **Best-per-task** | | **77.5** | |

### LoCo (12 long-context retrieval tasks, nDCG@10)

| Task | Chunk-level | Token-level | GTE-ModernBERT-base |
|------|-------------|-------------|---------------------|
| summ_screen_fd | 71.9 | **88.4** | 93.8 |
| gov_report | 86.2 | **97.2** | 97.5 |
| qmsum | **69.6** | 69.4 | 63.1 |
| qasper_title | 74.9 | **92.2** | 88.9 |
| qasper_abstract | 88.4 | **96.4** | 98.1 |
| multifieldqa | **93.4** | 92.9 | 93.4 |
| 2wikimqa | 90.0 | **91.1** | 86.6 |
| passage_retrieval | 95.1 | **95.5** | 52.7 |
| legal_case_reports | 11.4 | **24.3** | 44.8 |
| courtlistener_HTML | 43.6 | **51.4** | 23.5 |
| courtlistener_Plain_Text | 38.1 | **50.8** | 24.8 |
| stackoverflow | **43.3** | 36.7 | 90.9 |
| **Average** | 67.2 | **73.9** | 71.5 |

Token-level HARE (73.9) surpasses both GTE-ModernBERT-base (71.5) and bge-m3 (71.7) on LoCo.

## Usage

```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("SixOpen/HARE", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("SixOpen/HARE")
model = model.cuda().eval()

texts = ["Apple released a new iPhone model today", "The latest iPhone was announced by Apple"]
enc = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors='pt')
enc = {k: v.to('cuda') for k, v in enc.items()}
with torch.no_grad():
    hidden = model(**enc).last_hidden_state
    mask = enc['attention_mask'].unsqueeze(-1).float()
    embs = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
    embs = F.normalize(embs, p=2, dim=-1)

similarity = (embs[0] @ embs[1]).item()
```

### Multi-vector retrieval (long documents)

For documents longer than 512 tokens, split into 256-token chunks with 64-token overlap and score with MaxSim.
HARE can also carry recurrent state across chunks, conditioning each chunk on all prior context without re-encoding. See the streaming demos for stateful usage.

```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("SixOpen/HARE", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("SixOpen/HARE")
model = model.cuda().eval()

query = "your query"
document = open("document.txt").read()  # any text format

# encode query
q_enc = tokenizer(query, return_tensors='pt', truncation=True, max_length=512)
q_enc = {k: v.cuda() for k, v in q_enc.items()}
with torch.no_grad():
    q_hidden = model(**q_enc).last_hidden_state
    q_mask = q_enc['attention_mask'].unsqueeze(-1).float()
    query_emb = F.normalize((q_hidden * q_mask).sum(1) / q_mask.sum(1).clamp(min=1e-9), dim=-1)

# chunk document (256 tokens, 64-token overlap)
doc_ids = tokenizer(document, return_tensors='pt', truncation=False)['input_ids'][0]
chunk_size, stride = 256, 192
chunk_embs = []
for start in range(0, len(doc_ids), stride):
    ids = doc_ids[start:start + chunk_size].unsqueeze(0).cuda()
    with torch.no_grad():
        h = model(input_ids=ids, attention_mask=torch.ones_like(ids)).last_hidden_state
    emb = F.normalize(h.mean(1), dim=-1)
    chunk_embs.append(emb)

chunk_embs = torch.cat(chunk_embs, dim=0)
scores = (query_emb @ chunk_embs.T).squeeze(0)
best_chunk = scores.argmax().item()
print(f"Best chunk: {best_chunk}, score: {scores[best_chunk]:.4f}")
```

### Stateful streaming (incremental encoding)

As noted above, unlike standard encoders HARE can save its recurrent state and resume from any position: new text is encoded with full prior context without re-encoding anything before it.

```python
import torch

from streaming import SpanEncoder

# model and tokenizer loaded as in the examples above
enc = SpanEncoder(model, tokenizer, "cuda", chunk_size=256)

# Mock lecture transcript arriving in 3 streaming pieces
pieces = [
    "Today we will cover the fundamentals of quantum computing. Classical computers "
    "use bits that are either 0 or 1. Quantum computers use qubits which can exist "
    "in superposition, meaning they can be both 0 and 1 simultaneously. ",
    "The key advantage comes from entanglement. When two qubits are entangled, "
    "measuring one instantly determines the state of the other regardless of distance. "
    "This allows quantum computers to process certain problems exponentially faster. ",
    "The most important quantum algorithm is Shor's algorithm which can factor large "
    "numbers in polynomial time. This has major implications for cryptography since "
    "RSA encryption relies on the difficulty of factoring large primes. ",
]

# Encode incrementally; only the new piece is processed each time
enc.encode_span(pieces[0], key="p0")     # encode first piece
enc.extend_right(pieces[1], "p0", "p1")  # extend with state carry
enc.extend_right(pieces[2], "p1", "p2")  # extend again

# Search the incrementally built index
q_emb = enc.encode_query("why is Shor's algorithm important for cryptography")
chunk_embs = torch.cat(enc.span_data["p2"]["chunk_embs"], dim=0)
scores = (q_emb @ chunk_embs.T).squeeze(0)
best = scores.argmax().item()
print(f"Best chunk: {best}, score: {scores[best]:.4f}")
# → Best chunk: 2, score: 0.7814
```

### Token-level late interaction (offline, full-document)

For best quality on long documents, encode the full document in one pass and score at the token level, where `query_tokens` and `doc_tokens` are L2-normalized token embeddings:

```python
score = sum(max(q_tok @ d_tok for d_tok in doc_tokens) for q_tok in query_tokens)
```
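
This MaxSim score can also be computed without the Python loops; a minimal vectorized sketch (the function name and shapes are illustrative, not part of the repo's API):

```python
import torch
import torch.nn.functional as F

def late_interaction_score(query_tokens: torch.Tensor, doc_tokens: torch.Tensor) -> float:
    """MaxSim: for each query token, take its best-matching document token,
    then sum over query tokens. Inputs are (Tq, D) and (Td, D) embeddings."""
    q = F.normalize(query_tokens, p=2, dim=-1)
    d = F.normalize(doc_tokens, p=2, dim=-1)
    sim = q @ d.T                     # (Tq, Td) cosine similarities
    return sim.max(dim=-1).values.sum().item()
```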

## Architecture

HARE starts from ModernBERT-base (22 layers, 768-dim, 12 heads) and performs architectural surgery:

- Layers 1, 2, 4, 5, 7, 8, 10, 11, 13, 14, 16, 17, 19, 20 (the 14 local sliding-window attention layers) are replaced with BiRWKV-7 bidirectional recurrence
- Layers 0, 3, 6, 9, 12, 15, 18, 21 (the 8 global attention layers) are retained unchanged
- Weight mapping: Q->R, K->K, V->V, O->O (attention projections initialize the recurrence projections)
- Recurrence-specific parameters (decay, gate, mixing coefficients) are randomly initialized and learned during training

Each BiRWKV-7 layer runs a forward (left-to-right) and a backward (right-to-left) scan, averaged. The forward scan's state matrix (64x64 per head, 12 heads per layer) can be saved and resumed for incremental encoding.

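The save/resume property of the forward scan can be illustrated with a minimal sketch of a decayed outer-product state recurrence (deliberately simplified: it omits RWKV-7's learned removal/replacement term, per-head batching, and the backward scan; all names are illustrative):

```python
import torch

def linear_scan(k, v, r, decay, state=None):
    """Decayed outer-product recurrence: state_t = state_{t-1} * decay_t + v_t k_t^T,
    output_t = state_t r_t. Per-token cost is O(D^2), independent of t."""
    T, D = k.shape
    if state is None:
        state = torch.zeros(D, D)
    outs = []
    for t in range(T):
        state = state * decay[t].unsqueeze(0) + torch.outer(v[t], k[t])
        outs.append(state @ r[t])
    return torch.stack(outs), state

# Save/resume: scanning [A; B] in one pass equals scanning A, then B from A's state.
torch.manual_seed(0)
T, D = 8, 4
k, v, r = torch.randn(T, D), torch.randn(T, D), torch.randn(T, D)
decay = torch.sigmoid(torch.randn(T, D))
full, _ = linear_scan(k, v, r, decay)
first, st = linear_scan(k[:4], v[:4], r[:4], decay[:4])
second, _ = linear_scan(k[4:], v[4:], r[4:], decay[4:], state=st)
assert torch.allclose(full, torch.cat([first, second]), atol=1e-5)
```

This exactness is what lets `SpanEncoder.extend_right` condition new chunks on all prior context without re-encoding it.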
## Training

Three-stage pipeline:

### Stage 1: Contrastive distillation

| | |
|---|---|
| Teacher | GTE-ModernBERT-base |
| Data | NLI (AllNLI) + MS-MARCO |
| Loss | (1 - alpha) * MRL-InfoNCE + alpha * cosine distillation |
| MRL dims | 64, 128, 256, 768 |
| Alpha | 0.5 |
| Epochs | 3 |
| Batch size | 32 |
| Learning rate | 2e-5 (cosine decay) |
| Max length | 512 |
| Optimizer | AdamW (weight_decay=0.01) |

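The Stage 1 objective `(1 - alpha) * MRL-InfoNCE + alpha * cosine distillation` can be sketched as follows (a hedged illustration only: the function names, temperature, and in-batch-negatives setup are assumptions, not the repo's training code):

```python
import torch
import torch.nn.functional as F

def mrl_infonce(q, d, dims=(64, 128, 256, 768), temperature=0.05):
    """Matryoshka InfoNCE: in-batch contrastive loss averaged over
    truncated embedding widths. q, d: (B, 768); row i of d is the
    positive for row i of q, all other rows are negatives."""
    losses = []
    for m in dims:
        qm = F.normalize(q[:, :m], dim=-1)
        dm = F.normalize(d[:, :m], dim=-1)
        logits = qm @ dm.T / temperature
        target = torch.arange(q.size(0))
        losses.append(F.cross_entropy(logits, target))
    return torch.stack(losses).mean()

def stage1_loss(student_q, student_d, teacher_q, teacher_d, alpha=0.5):
    """Blend contrastive learning with cosine distillation from the teacher."""
    contrastive = mrl_infonce(student_q, student_d)
    cos_q = F.cosine_similarity(student_q, teacher_q, dim=-1).mean()
    cos_d = F.cosine_similarity(student_d, teacher_d, dim=-1).mean()
    distill = 1 - (cos_q + cos_d) / 2
    return (1 - alpha) * contrastive + alpha * distill
```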
### Stage 2: Long-context self-distillation

| | |
|---|---|
| Teacher | GTE-ModernBERT-base |
| Data | NLI + MS-MARCO (10K each, 20K total) |
| Loss | (1 - alpha) * MRL-InfoNCE + alpha * cosine distillation |
| Alpha | 0.3 |
| Epochs | 1 |
| Batch size | 8 |
| Learning rate | 5e-6 (cosine decay) |
| Max length | 2048 |

### Stage 3: Synthetic IR training

| | |
|---|---|
| Data | 40% NLI + 40% MS-MARCO + 20% synthetic information-location pairs |
| Loss | MRL-InfoNCE |
| Epochs | 2 |
| Batch size | 32 |
| Learning rate | 5e-6 (cosine decay) |
| Max length | 512 |
| Merge | 30% Stage 2 weights + 70% Stage 3 weights |

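The merge step is plain linear interpolation of checkpoint parameters; a minimal sketch (the helper name is illustrative, and compatible state-dict layouts are assumed):

```python
import torch

def merge_state_dicts(sd_a, sd_b, weight_a=0.3, weight_b=0.7):
    """Linearly interpolate two checkpoints' parameters
    (here: 30% Stage 2 + 70% Stage 3). Keys must match."""
    assert sd_a.keys() == sd_b.keys()
    return {k: weight_a * sd_a[k].float() + weight_b * sd_b[k].float()
            for k in sd_a}
```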
## Files

| File | Description |
|------|-------------|
| `model.pt` | Model weights (664MB) |
| `config.json` | ModernBERT model config |
| `surgery_meta.json` | Layer replacement mapping (which layers were replaced, weight transfer record) |
| `tokenizer.json` | Tokenizer |
| `tokenizer_config.json` | Tokenizer config |
| `surgery.py` | Standalone surgery CLI tool (inspect layers, perform surgery from scratch) |
| `birwkv7.py` | BiRWKV-7 recurrence layer (required for loading) |
| `streaming.py` | SpanEncoder for stateful incremental encoding |

## Intended uses

- Semantic search and retrieval over short or long documents
- Incremental indexing where text arrives sequentially and must be searchable before completion: live transcription, real-time meeting or dispatch indexing, distributed (i.e., torrent) content search, incremental document editing
- Multi-vector retrieval with chunk-level or token-level scoring

## Citation

```bibtex
@article{osman2026hare,
  title={Stateful Embeddings via Hybrid Attention-Recurrence},
  author={Osman A. Ender},
  year={2026}
}
```
birwkv7.py ADDED
@@ -0,0 +1,190 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

_FLA_AVAILABLE = False
try:
    import torch.distributed.tensor as _tdt
    if not hasattr(_tdt, 'Replicate'):
        try:
            from torch.distributed._tensor import Replicate as _R, Shard as _S
            _tdt.Replicate = _R; _tdt.Shard = _S
        except ImportError:
            pass
    if not hasattr(_tdt, 'Placement'):
        try:
            from torch.distributed._tensor.placement_types import Placement as _P
            _tdt.Placement = _P
        except ImportError:
            pass
    if not hasattr(_tdt, 'distribute_module'):
        _tdt.distribute_module = lambda *a, **kw: None
    from fla.ops.rwkv7 import chunk_rwkv7 as _fla_chunk_rwkv7
    if torch.cuda.is_available():
        _test_r = torch.randn(1, 1, 2, 64, device='cuda', dtype=torch.bfloat16, requires_grad=True)
        _test_w = -torch.ones(1, 1, 2, 64, device='cuda', dtype=torch.bfloat16)
        _test_o, _ = _fla_chunk_rwkv7(_test_r, _test_w, _test_r, _test_r, _test_r, _test_r,
                                      head_first=False)
        _test_o.sum().backward()
        if not _test_r.grad.isnan().any():
            _FLA_AVAILABLE = True
        del _test_r, _test_w, _test_o
        torch.cuda.empty_cache()
    else:
        _FLA_AVAILABLE = True
except Exception:
    pass


class BiRWKV7Layer(nn.Module):

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads

        self.mu_r = nn.Parameter(torch.zeros(hidden_size))
        self.mu_w = nn.Parameter(torch.zeros(hidden_size))
        self.mu_k = nn.Parameter(torch.zeros(hidden_size))
        self.mu_v = nn.Parameter(torch.zeros(hidden_size))
        self.mu_a = nn.Parameter(torch.zeros(hidden_size))
        self.mu_g = nn.Parameter(torch.zeros(hidden_size))

        self.W_r = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_k = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_v = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_w = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_a = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_g = nn.Linear(hidden_size, hidden_size, bias=False)

        self.sab_gate = nn.Parameter(torch.tensor(-5.0))

        self.group_norm = nn.GroupNorm(num_heads, hidden_size)
        self.W_o = nn.Linear(hidden_size, hidden_size, bias=False)

        nn.init.normal_(self.W_w.weight, std=0.01)
        nn.init.normal_(self.W_a.weight, std=0.01)
        nn.init.normal_(self.W_g.weight, std=0.02)

    def _token_shift(self, x):
        x_prev = F.pad(x[:, :-1], (0, 0, 1, 0))

        def mix(mu):
            return x + (x_prev - x) * torch.sigmoid(mu)

        return {
            'r': mix(self.mu_r), 'w': mix(self.mu_w),
            'k': mix(self.mu_k), 'v': mix(self.mu_v),
            'a': mix(self.mu_a), 'g': mix(self.mu_g),
        }

    def _wkv7_scan_fla(self, r, w, k, v, a, sab_scale):
        B, T, H, D = r.shape
        orig_dtype = r.dtype
        r, w, k, v, a = [x.bfloat16() for x in (r, w, k, v, a)]
        k_scaled = k * (D ** -0.5)
        w_log = -0.6065306597633104 * torch.sigmoid(w)
        a_sig = torch.sigmoid(a)
        a_fla = -k_scaled
        b_fla = sab_scale * k_scaled * a_sig
        o, _ = _fla_chunk_rwkv7(r, w_log, k_scaled, v, a_fla, b_fla, scale=1.0)
        return o.to(orig_dtype)

    def _wkv7_scan_python(self, r, w, k, v, a, sab_scale):
        B, T, H, D = r.shape
        orig_dtype = r.dtype

        r, w, k, v, a = [x.float() for x in (r, w, k, v, a)]
        k = k * (D ** -0.5)
        decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w))
        a = torch.sigmoid(a)

        state = torch.zeros(B, H, D, D, device=r.device, dtype=torch.float32)
        outputs = []

        for t in range(T):
            if t > 0 and t % 16 == 0:
                state = state.detach()

            kt, vt, rt, at, dt = k[:, t], v[:, t], r[:, t], a[:, t], decay[:, t]

            sa = torch.einsum('bhij,bhj->bhi', state, -kt)
            sab = torch.einsum('bhi,bhj->bhij', sa, kt * at)
            state = state * dt.unsqueeze(-2) + sab_scale * sab + torch.einsum('bhi,bhj->bhij', vt, kt)
            state = state.clamp(-10.0, 10.0)

            outputs.append(torch.einsum('bhij,bhj->bhi', state, rt))

        return torch.stack(outputs, dim=1).to(orig_dtype)

    def _wkv7_scan(self, r, w, k, v, a, sab_scale):
        if _FLA_AVAILABLE and r.is_cuda:
            return self._wkv7_scan_fla(r, w, k, v, a, sab_scale)
        return self._wkv7_scan_python(r, w, k, v, a, sab_scale)

    def forward(self, x, attention_mask=None, **kwargs):
        B, T, C = x.shape
        H, D = self.num_heads, self.head_size

        mixed = self._token_shift(x)
        r = self.W_r(mixed['r']).view(B, T, H, D)
        w = self.W_w(mixed['w']).view(B, T, H, D)
        k = self.W_k(mixed['k']).view(B, T, H, D)
        v = self.W_v(mixed['v']).view(B, T, H, D)
        a = self.W_a(mixed['a']).view(B, T, H, D)
        g = torch.sigmoid(self.W_g(mixed['g']))

        sab_scale = torch.sigmoid(self.sab_gate)

        out_fwd = self._wkv7_scan(r, w, k, v, a, sab_scale)
        out_bwd = self._wkv7_scan(
            r.flip(1), w.flip(1), k.flip(1), v.flip(1), a.flip(1), sab_scale
        ).flip(1)

        out = (out_fwd + out_bwd).reshape(B, T, C) * 0.5
        out = self.group_norm(out.transpose(1, 2)).transpose(1, 2)
        out = self.W_o(out * g)

        return out, None


def init_from_attention(birwkv, attn_module):
    q_proj = k_proj = v_proj = o_proj = None

    if hasattr(attn_module, 'Wqkv'):
        fused = attn_module.Wqkv.weight.data
        C = fused.shape[1]
        q_proj, k_proj, v_proj = fused[:C], fused[C:2*C], fused[2*C:]
    else:
        for name in ['q_proj', 'query', 'W_q', 'wq']:
            if hasattr(attn_module, name):
                q_proj = getattr(attn_module, name).weight.data
                break
        for name in ['k_proj', 'key', 'W_k', 'wk']:
            if hasattr(attn_module, name):
                k_proj = getattr(attn_module, name).weight.data
                break
        for name in ['v_proj', 'value', 'W_v', 'wv']:
            if hasattr(attn_module, name):
                v_proj = getattr(attn_module, name).weight.data
                break

    for name in ['Wo', 'out_proj', 'o_proj', 'dense', 'W_o', 'wo']:
        if hasattr(attn_module, name):
            o_proj = getattr(attn_module, name).weight.data
            break

    transferred = []
    for src, dst, label in [
        (q_proj, birwkv.W_r, 'Q->R'),
        (k_proj, birwkv.W_k, 'K->K'),
        (v_proj, birwkv.W_v, 'V->V'),
        (o_proj, birwkv.W_o, 'O->O'),
    ]:
        if src is not None:
            dst.weight.data.copy_(src)
            transferred.append(label)

    return transferred
config.json ADDED
@@ -0,0 +1,81 @@
{
  "architectures": [
    "HareModel"
  ],
  "auto_map": {
    "AutoConfig": "configuration_hare.HareConfig",
    "AutoModel": "modeling_hare.HareModel"
  },
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 50281,
  "classifier_activation": "gelu",
  "classifier_bias": false,
  "classifier_dropout": 0.0,
  "classifier_pooling": "mean",
  "cls_token_id": 50281,
  "decoder_bias": true,
  "deterministic_flash_attn": false,
  "dtype": "float16",
  "embedding_dropout": 0.0,
  "eos_token_id": 50282,
  "global_attn_every_n_layers": 3,
  "gradient_checkpointing": false,
  "hidden_activation": "gelu",
  "hidden_size": 768,
  "initializer_cutoff_factor": 2.0,
  "initializer_range": 0.02,
  "intermediate_size": 1152,
  "layer_norm_eps": 1e-05,
  "layer_types": [
    "full_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention"
  ],
  "local_attention": 128,
  "max_position_embeddings": 8192,
  "mlp_bias": false,
  "mlp_dropout": 0.0,
  "model_type": "hare",
  "norm_bias": false,
  "norm_eps": 1e-05,
  "num_attention_heads": 12,
  "num_hidden_layers": 22,
  "pad_token_id": 50283,
  "position_embedding_type": "absolute",
  "rope_parameters": {
    "full_attention": {
      "rope_theta": 160000.0,
      "rope_type": "default"
    },
    "sliding_attention": {
      "rope_theta": 10000.0,
      "rope_type": "default"
    }
  },
  "sep_token_id": 50282,
  "sparse_pred_ignore_index": -100,
  "sparse_prediction": false,
  "tie_word_embeddings": true,
  "transformers_version": "5.2.0",
  "vocab_size": 50368
}
configuration_hare.py ADDED
@@ -0,0 +1,45 @@
from transformers import PretrainedConfig


class HareConfig(PretrainedConfig):
    model_type = "hare"

    def __init__(
        self,
        hidden_size=768,
        num_attention_heads=12,
        num_hidden_layers=22,
        intermediate_size=1152,
        hidden_activation="gelu",
        max_position_embeddings=8192,
        vocab_size=50368,
        pad_token_id=50283,
        bos_token_id=50281,
        eos_token_id=50282,
        cls_token_id=50281,
        sep_token_id=50282,
        global_attn_every_n_layers=3,
        local_attention=128,
        replaced_layers=None,
        surgery_variant="conservative",
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.hidden_activation = hidden_activation
        self.max_position_embeddings = max_position_embeddings
        self.vocab_size = vocab_size
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        self.global_attn_every_n_layers = global_attn_every_n_layers
        self.local_attention = local_attention
        self.replaced_layers = replaced_layers
        self.surgery_variant = surgery_variant
model.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:42a1d92de872ce85ff2bb1e189f8ac41fd3062e006827b15310484641e2b9157
size 695588290
modeling_hare.py ADDED
@@ -0,0 +1,98 @@
import json
from pathlib import Path

import torch
from transformers import AutoModel, AutoConfig, PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput

from .configuration_hare import HareConfig
from .birwkv7 import BiRWKV7Layer, init_from_attention


def _find_encoder(model):
    for attr in ['encoder', 'model']:
        if hasattr(model, attr):
            candidate = getattr(model, attr)
            if hasattr(candidate, 'layers'):
                return candidate
    if hasattr(model, 'layers'):
        return model
    raise RuntimeError(f"Cannot find encoder layers in {type(model).__name__}")


def _perform_surgery(model, replaced_layers, hidden_size, num_heads):
    encoder = _find_encoder(model)
    for layer_idx_str, info in replaced_layers.items():
        layer_idx = int(layer_idx_str)
        layer = encoder.layers[layer_idx]
        attn = None
        attn_name = None
        for name in ['attn', 'attention', 'self_attn', 'self_attention']:
            if hasattr(layer, name):
                attn = getattr(layer, name)
                attn_name = name
                break
        if attn is None:
            continue
        birwkv = BiRWKV7Layer(hidden_size, num_heads)
        device = next(attn.parameters()).device
        dtype = next(attn.parameters()).dtype
        birwkv = birwkv.to(device=device, dtype=dtype)
        setattr(layer, attn_name, birwkv)


class HareModel(PreTrainedModel):
    config_class = HareConfig

    def __init__(self, config):
        super().__init__(config)
        base_config = AutoConfig.from_pretrained(
            "answerdotai/ModernBERT-base",
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=config.num_hidden_layers,
            intermediate_size=config.intermediate_size,
            vocab_size=config.vocab_size,
            max_position_embeddings=config.max_position_embeddings,
        )
        self.inner_model = AutoModel.from_config(base_config)

        if config.replaced_layers:
            _perform_surgery(
                self.inner_model,
                config.replaced_layers,
                config.hidden_size,
                config.num_attention_heads,
            )

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        outputs = self.inner_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        return outputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        model_dir = Path(pretrained_model_name_or_path)
        surgery_meta_path = model_dir / "surgery_meta.json"

        if surgery_meta_path.exists():
            with open(surgery_meta_path) as f:
                meta = json.load(f)

            config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
            config.replaced_layers = meta.get("replaced_layers")
            config.surgery_variant = meta.get("variant", "conservative")

            model = cls(config)

            weights_path = model_dir / "model.pt"
            if weights_path.exists():
                state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
                model.inner_model.load_state_dict(state_dict)

            return model.float().eval()

        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
streaming.py ADDED
@@ -0,0 +1,202 @@
import torch
import torch.nn.functional as F

from birwkv7 import BiRWKV7Layer


def wkv7_forward_scan(r, w, k, v, a, sab_scale, init_state=None):
    B, T, H, D = r.shape
    r, w, k, v, a = [x.float() for x in (r, w, k, v, a)]
    k = k * (D ** -0.5)
    decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w))
    a = torch.sigmoid(a)
    sab_s = float(sab_scale)
    state = init_state.float().clone() if init_state is not None else \
        torch.zeros(B, H, D, D, device=r.device, dtype=torch.float32)
    outputs = []
    for t in range(T):
        kt, vt, rt, at, dt = k[:, t], v[:, t], r[:, t], a[:, t], decay[:, t]
        sa = torch.einsum('bhij,bhj->bhi', state, -kt)
        sab = torch.einsum('bhi,bhj->bhij', sa, kt * at)
        state = state * dt.unsqueeze(-2) + sab_s * sab + \
            torch.einsum('bhi,bhj->bhij', vt, kt)
        state = state.clamp(-10.0, 10.0)
        outputs.append(torch.einsum('bhij,bhj->bhi', state, rt))
    return torch.stack(outputs, dim=1), state.detach()


class SpanEncoder:

    def __init__(self, model, tokenizer, device, chunk_size=512):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.chunk_size = chunk_size

        self.birwkv_layers = []
        self.birwkv_ids = {}
        for m in model.modules():
            if isinstance(m, BiRWKV7Layer):
                self.birwkv_ids[id(m)] = len(self.birwkv_layers)
                self.birwkv_layers.append(m)

        self._originals = {}
        self._hooked = False
        self._active_states = [None] * len(self.birwkv_layers)
        self.span_data = {}

    def _hook(self):
        if self._hooked:
            return
        for layer in self.birwkv_layers:
            self._originals[id(layer)] = layer.forward
            layer.forward = self._make_fwd(layer)
        self._hooked = True

    def _unhook(self):
        if not self._hooked:
            return
        for layer in self.birwkv_layers:
            layer.forward = self._originals[id(layer)]
        self._originals.clear()
        self._hooked = False

    def _make_fwd(self, layer):
        enc = self
        idx = self.birwkv_ids[id(layer)]

        def fwd(x, attention_mask=None, **kwargs):
            B, T, C_ = x.shape
            H, D = layer.num_heads, layer.head_size
            prev = enc._active_states[idx]
            if prev is not None:
                x_prev = torch.cat([prev['last_x'], x[:, :-1]], dim=1)
            else:
                x_prev = F.pad(x[:, :-1], (0, 0, 1, 0))

            def mix(mu):
                return x + (x_prev - x) * torch.sigmoid(mu)

            r = layer.W_r(mix(layer.mu_r)).view(B, T, H, D)
            w = layer.W_w(mix(layer.mu_w)).view(B, T, H, D)
            k = layer.W_k(mix(layer.mu_k)).view(B, T, H, D)
            v = layer.W_v(mix(layer.mu_v)).view(B, T, H, D)
            a = layer.W_a(mix(layer.mu_a)).view(B, T, H, D)
            g = torch.sigmoid(layer.W_g(mix(layer.mu_g)))
            sab_scale = torch.sigmoid(layer.sab_gate)
            init_st = prev['wkv_state'] if prev else None

            try:
                from birwkv7_triton import wkv7_scan_triton
                r_f, k_f, v_f = r.float(), k.float() * (D ** -0.5), v.float()
                a_f = torch.sigmoid(a.float())
                decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w.float()))
                out_fwd, wkv_state = wkv7_scan_triton(
                    r_f, decay, k_f, v_f, a_f, sab_scale,
                    return_state=True, init_state=init_st)
                out_bwd = wkv7_scan_triton(
                    r_f.flip(1), decay.flip(1), k_f.flip(1),
                    v_f.flip(1), a_f.flip(1), sab_scale,
                    return_state=False).flip(1)
            except Exception:  # Triton kernel unavailable; fall back to the pure-PyTorch scan
                out_fwd, wkv_state = wkv7_forward_scan(
                    r, w, k, v, a, sab_scale, init_st)
                out_bwd = wkv7_forward_scan(
                    r.flip(1), w.flip(1), k.flip(1),
                    v.flip(1), a.flip(1), sab_scale, None)[0].flip(1)
            enc._active_states[idx] = {
                'wkv_state': wkv_state,
                'last_x': x[:, -1:].detach().clone(),
            }
            out = ((out_fwd + out_bwd) * 0.5).reshape(B, T, C_)
            out = layer.group_norm(out.transpose(1, 2)).transpose(1, 2)
            out = layer.W_o(out * g)
            return out, None
        return fwd

    @torch.no_grad()
    def _forward_encode_raw(self, text, init_states=None, max_length=8192):
        self._hook()
        if init_states is not None:
            self._active_states = [
                {k: v.clone() for k, v in s.items()} if s else None
                for s in init_states
            ]
        else:
            self._active_states = [None] * len(self.birwkv_layers)

        enc = self.tokenizer(text, return_tensors='pt', truncation=True,
                             max_length=max_length)
        ids = enc['input_ids'].to(self.device)
        mask = enc['attention_mask'].to(self.device)

        h = self.model(input_ids=ids, attention_mask=mask).last_hidden_state
        content = h[0, 1:-1, :].cpu()
        n_content = content.shape[0]

        final_states = [
            {k: v.clone() for k, v in s.items()} if s else None
            for s in self._active_states
        ]
        self._unhook()
        return content, n_content, final_states

    def _chunk_hidden(self, content, return_residual=False):
        T = content.shape[0]
        chunks = []
        last_end = 0
        for start in range(0, T, self.chunk_size):
            end = min(start + self.chunk_size, T)
            if end - start < 32:
                break
            emb = F.normalize(content[start:end].mean(0, keepdim=True),
                              p=2, dim=-1)
            chunks.append(emb)
            last_end = end
        if not chunks and T > 0:
            chunks.append(F.normalize(content.mean(0, keepdim=True),
                                      p=2, dim=-1))
            last_end = T
        if return_residual:
            residual = content[last_end:] if last_end < T else None
            return chunks, residual
        return chunks

    @torch.no_grad()
    def encode_query(self, query):
        assert not self._hooked
        enc = self.tokenizer(query, return_tensors='pt', truncation=True,
                             max_length=512)
        ids = enc['input_ids'].to(self.device)
        mask = enc['attention_mask'].to(self.device)
        h = self.model(input_ids=ids, attention_mask=mask).last_hidden_state
        m = mask.unsqueeze(-1).float()
        emb = (h * m).sum(1) / m.sum(1).clamp(min=1e-9)
175
+ return F.normalize(emb, p=2, dim=-1).cpu()
176
+
177
+ def encode_span(self, text, key):
178
+ content, n_tok, states = self._forward_encode_raw(text)
179
+ chunks, residual = self._chunk_hidden(content, return_residual=True)
180
+ self.span_data[key] = {
181
+ 'layer_states': states,
182
+ 'chunk_embs': chunks,
183
+ 'n_tokens': n_tok,
184
+ 'residual_hidden': residual,
185
+ }
186
+ return n_tok
187
+
188
+ def extend_right(self, piece_text, old_key, new_key):
189
+ old = self.span_data.pop(old_key)
190
+ content, n_new, states = self._forward_encode_raw(
191
+ piece_text, init_states=old['layer_states'])
192
+ if old.get('residual_hidden') is not None:
193
+ content = torch.cat([old['residual_hidden'], content], dim=0)
194
+ new_chunks, residual = self._chunk_hidden(
195
+ content, return_residual=True)
196
+ self.span_data[new_key] = {
197
+ 'layer_states': states,
198
+ 'chunk_embs': old['chunk_embs'] + new_chunks,
199
+ 'n_tokens': old['n_tokens'] + n_new,
200
+ 'residual_hidden': residual,
201
+ }
202
+ return n_new
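The residual bookkeeping in `encode_span`/`extend_right` is what makes appends cheap: a tail shorter than 32 tokens is held back as `residual_hidden` and prepended to the next piece, so earlier chunk embeddings are never recomputed. A minimal stdlib sketch of the same logic on plain token-id lists (the 256-token chunk size and the function names are hypothetical; the real code operates on hidden-state tensors):

```python
CHUNK_SIZE = 256   # assumed chunk size for illustration
MIN_CHUNK = 32     # tails shorter than this are carried as residual

def chunk_with_residual(tokens):
    """Split tokens into fixed chunks; return (chunks, residual tail)."""
    chunks, last_end = [], 0
    for start in range(0, len(tokens), CHUNK_SIZE):
        end = min(start + CHUNK_SIZE, len(tokens))
        if end - start < MIN_CHUNK:
            break                      # too short: becomes the residual
        chunks.append(tokens[start:end])
        last_end = end
    if not chunks and tokens:          # very short input: one small chunk
        chunks.append(tokens)
        last_end = len(tokens)
    return chunks, tokens[last_end:]

def extend_right(old_chunks, old_residual, new_tokens):
    """New text is prepended with the residual, then re-chunked;
    chunks computed earlier are kept as-is."""
    chunks, residual = chunk_with_residual(old_residual + new_tokens)
    return old_chunks + chunks, residual
```

For example, 280 tokens yield one 256-token chunk plus a 24-token residual; appending 240 more tokens re-chunks only the 264-token combination, leaving the first chunk untouched.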
surgery.py ADDED
@@ -0,0 +1,205 @@
+ import argparse
+ import json
+ from pathlib import Path
+ 
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoModel, AutoTokenizer
+ 
+ from birwkv7 import BiRWKV7Layer, init_from_attention
+ 
+ 
+ def _find_encoder(model):
+     for attr in ['encoder', 'model']:
+         if hasattr(model, attr):
+             candidate = getattr(model, attr)
+             if hasattr(candidate, 'layers'):
+                 return candidate
+     if hasattr(model, 'layers'):
+         return model
+     raise RuntimeError(f"Cannot find encoder layers in {type(model).__name__}")
+ 
+ 
+ def find_attention_layers(model):
+     encoder = _find_encoder(model)
+     layers = []
+ 
+     for i, layer in enumerate(encoder.layers):
+         attn = None
+         attn_path = None
+         for name in ['attn', 'attention', 'self_attn', 'self_attention']:
+             if hasattr(layer, name):
+                 attn = getattr(layer, name)
+                 attn_path = f"layers.{i}.{name}"
+                 break
+ 
+         if attn is None:
+             continue
+ 
+         is_global = False
+         if hasattr(attn, 'local_attention'):
+             is_global = not attn.local_attention
+         elif hasattr(attn, 'is_global_attention'):
+             is_global = attn.is_global_attention
+         elif hasattr(attn, 'use_sliding_window'):
+             is_global = not attn.use_sliding_window
+         elif hasattr(attn, 'sliding_window'):
+             is_global = attn.sliding_window is None
+         else:
+             is_global = (i % 3 == 2)
+ 
+         layers.append((i, attn_path, attn, is_global))
+ 
+     return layers
+ 
+ 
+ def perform_surgery(model, variant, hidden_size, num_heads, replaced_layers=None):
+     layers = find_attention_layers(model)
+     global_indices = [idx for idx, _, _, g in layers if g]
+     local_indices = [idx for idx, _, _, g in layers if not g]
+ 
+     print(f"\nFound {len(layers)} attention layers:")
+     print(f"  Global: {global_indices}")
+     print(f"  Local:  {local_indices}")
+ 
+     if replaced_layers is not None:
+         replace_indices = {int(k) for k in replaced_layers.keys()}
+     elif variant == 'conservative':
+         replace_indices = set(local_indices)
+     elif variant == 'aggressive':
+         keep = set()
+         if global_indices:
+             keep.add(global_indices[0])
+             keep.add(global_indices[-1])
+         replace_indices = {idx for idx, _, _, _ in layers if idx not in keep}
+     elif variant == 'pure':
+         replace_indices = {idx for idx, _, _, _ in layers}
+     else:
+         raise ValueError(f"Unknown variant: {variant}")
+ 
+     print(f"\nVariant '{variant}': replacing {len(replace_indices)} of {len(layers)} layers")
+ 
+     encoder = _find_encoder(model)
+     report = {}
+ 
+     for layer_idx, attn_path, attn_module, is_global in layers:
+         if layer_idx not in replace_indices:
+             print(f"  Layer {layer_idx}: KEEP ({'global' if is_global else 'local'})")
+             continue
+ 
+         birwkv = BiRWKV7Layer(hidden_size, num_heads)
+         transferred = init_from_attention(birwkv, attn_module)
+ 
+         device = next(attn_module.parameters()).device
+         dtype = next(attn_module.parameters()).dtype
+         birwkv = birwkv.to(device=device, dtype=dtype)
+ 
+         attn_name = attn_path.split('.')[-1]
+         setattr(encoder.layers[layer_idx], attn_name, birwkv)
+ 
+         report[layer_idx] = {'was_global': is_global, 'transferred': transferred}
+         print(f"  Layer {layer_idx}: REPLACED ({'global' if is_global else 'local'}) "
+               f"-> BiRWKV-7 [{', '.join(transferred)}]")
+ 
+     return report
+ 
+ 
+ def mean_pool(hidden_states, attention_mask):
+     mask = attention_mask.unsqueeze(-1).float()
+     return (hidden_states * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
+ 
+ 
+ class HareWrapper(torch.nn.Module):
+ 
+     def __init__(self, model, tokenizer):
+         super().__init__()
+         self.model = model
+         self.tokenizer = tokenizer
+         self.config = model.config
+ 
+     def encode(self, texts, batch_size=32, max_length=512, show_progress=False):
+         all_embs = []
+         iterator = range(0, len(texts), batch_size)
+         if show_progress:
+             from tqdm import tqdm
+             iterator = tqdm(iterator, desc="Encoding")
+ 
+         for i in iterator:
+             batch = texts[i:i+batch_size]
+             enc = self.tokenizer(batch, padding=True, truncation=True,
+                                  max_length=max_length, return_tensors='pt')
+             enc = {k: v.to(next(self.model.parameters()).device) for k, v in enc.items()}
+ 
+             with torch.no_grad():
+                 hidden = self.model(**enc).last_hidden_state
+                 emb = mean_pool(hidden, enc['attention_mask'])
+                 all_embs.append(F.normalize(emb, p=2, dim=-1).cpu())
+ 
+         return torch.cat(all_embs, dim=0)
+ 
+     def forward(self, **kwargs):
+         return self.model(**kwargs)
+ 
+ 
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--base_model', default='answerdotai/ModernBERT-base')
+     parser.add_argument('--variant', choices=['conservative', 'aggressive', 'pure'],
+                         default='conservative')
+     parser.add_argument('--output', type=str, default=None)
+     parser.add_argument('--inspect_only', action='store_true')
+     args = parser.parse_args()
+ 
+     print(f"Loading {args.base_model}...")
+     tokenizer = AutoTokenizer.from_pretrained(args.base_model)
+     model = AutoModel.from_pretrained(args.base_model, trust_remote_code=True)
+     config = model.config
+     hidden_size = config.hidden_size
+     num_heads = config.num_attention_heads
+     print(f"  hidden_size={hidden_size}, num_heads={num_heads}, head_size={hidden_size // num_heads}")
+ 
+     if args.inspect_only:
+         layers = find_attention_layers(model)
+         print(f"\n{len(layers)} attention layers:")
+         for idx, path, attn, is_g in layers:
+             n = sum(p.numel() for p in attn.parameters())
+             print(f"  Layer {idx} ({'GLOBAL' if is_g else 'local'}): {type(attn).__name__} ({n:,}) @ {path}")
+         return
+ 
+     if not args.output:
+         parser.error("--output required for surgery (omit for --inspect_only)")
+ 
+     report = perform_surgery(model, args.variant, hidden_size, num_heads)
+ 
+     total_params = sum(p.numel() for p in model.parameters())
+     print(f"\nPost-surgery: {total_params:,} params")
+ 
+     print("Sanity check :)")
+     inputs = tokenizer("Hello world", return_tensors='pt')
+     inputs = {k: v.to(next(model.parameters()).device) for k, v in inputs.items()}
+     with torch.no_grad():
+         out = model(**inputs)
+     print(f"  Output: {out.last_hidden_state.shape}, norm={out.last_hidden_state.norm().item():.4f}")
+ 
+     output_dir = Path(args.output)
+     output_dir.mkdir(parents=True, exist_ok=True)
+     torch.save(model.state_dict(), output_dir / 'model.pt')
+     tokenizer.save_pretrained(output_dir)
+     config.save_pretrained(output_dir)
+ 
+     meta = {
+         'base_model': args.base_model,
+         'variant': args.variant,
+         'hidden_size': hidden_size,
+         'num_heads': num_heads,
+         'replaced_layers': {str(k): v for k, v in report.items()},
+         'total_params': total_params,
+     }
+     with open(output_dir / 'surgery_meta.json', 'w') as f:
+         json.dump(meta, f, indent=2)
+ 
+     print(f"\nSaved to {output_dir}/ ({total_params:,} params)")
+ 
+ 
+ if __name__ == '__main__':
+     main()
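The three variants differ only in which layer indices are swapped out. A self-contained sketch of that selection, assuming gte-modernbert's 22-layer layout with a global-attention layer every third layer starting at layer 0 (the layout reflected in `surgery_meta.json` below; names here are illustrative):

```python
def select_replacements(n_layers, variant, is_global):
    """Mirror of the variant logic in perform_surgery, on bare indices."""
    globs = [i for i in range(n_layers) if is_global(i)]
    if variant == 'conservative':        # swap only the local layers
        return {i for i in range(n_layers) if not is_global(i)}
    if variant == 'aggressive':          # keep first and last global layer
        keep = {globs[0], globs[-1]} if globs else set()
        return set(range(n_layers)) - keep
    if variant == 'pure':                # no attention layers remain
        return set(range(n_layers))
    raise ValueError(f"Unknown variant: {variant}")

# Assumed layout: layers 0, 3, 6, ... are global.
replaced = select_replacements(22, 'conservative', lambda i: i % 3 == 0)
```

On this layout, `conservative` yields the 14 local indices recorded in the meta file (1, 2, 4, 5, ..., 20), `aggressive` keeps only layers 0 and 21, and `pure` replaces all 22.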
surgery_meta.json ADDED
@@ -0,0 +1,135 @@
+ {
+   "base_model": "Alibaba-NLP/gte-modernbert-base",
+   "variant": "conservative",
+   "hidden_size": 768,
+   "num_heads": 12,
+   "replaced_layers": {
+     "1": {
+       "was_global": false,
+       "transferred": [
+         "Q->R",
+         "K->K",
+         "V->V",
+         "O->O"
+       ]
+     },
+     "2": {
+       "was_global": false,
+       "transferred": [
+         "Q->R",
+         "K->K",
+         "V->V",
+         "O->O"
+       ]
+     },
+     "4": {
+       "was_global": false,
+       "transferred": [
+         "Q->R",
+         "K->K",
+         "V->V",
+         "O->O"
+       ]
+     },
+     "5": {
+       "was_global": false,
+       "transferred": [
+         "Q->R",
+         "K->K",
+         "V->V",
+         "O->O"
+       ]
+     },
+     "7": {
+       "was_global": false,
+       "transferred": [
+         "Q->R",
+         "K->K",
+         "V->V",
+         "O->O"
+       ]
+     },
+     "8": {
+       "was_global": false,
+       "transferred": [
+         "Q->R",
+         "K->K",
+         "V->V",
+         "O->O"
+       ]
+     },
+     "10": {
+       "was_global": false,
+       "transferred": [
+         "Q->R",
+         "K->K",
+         "V->V",
+         "O->O"
+       ]
+     },
+     "11": {
+       "was_global": false,
+       "transferred": [
+         "Q->R",
+         "K->K",
+         "V->V",
+         "O->O"
+       ]
+     },
+     "13": {
+       "was_global": false,
+       "transferred": [
+         "Q->R",
+         "K->K",
+         "V->V",
+         "O->O"
+       ]
+     },
+     "14": {
+       "was_global": false,
+       "transferred": [
+         "Q->R",
+         "K->K",
+         "V->V",
+         "O->O"
+       ]
+     },
+     "16": {
+       "was_global": false,
+       "transferred": [
+         "Q->R",
+         "K->K",
+         "V->V",
+         "O->O"
+       ]
+     },
+     "17": {
+       "was_global": false,
+       "transferred": [
+         "Q->R",
+         "K->K",
+         "V->V",
+         "O->O"
+       ]
+     },
+     "19": {
+       "was_global": false,
+       "transferred": [
+         "Q->R",
+         "K->K",
+         "V->V",
+         "O->O"
+       ]
+     },
+     "20": {
+       "was_global": false,
+       "transferred": [
+         "Q->R",
+         "K->K",
+         "V->V",
+         "O->O"
+       ]
+     }
+   },
+   "total_params": 173872910
+ }
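A quick consistency check on the map above: under the assumed every-third-layer-global layout, the 14 replaced indices should be exactly the complement of the 8 global layers. In stdlib Python (keys inlined from the JSON above):

```python
import json

# "replaced_layers" keys copied from the meta file above.
meta_keys = json.loads(
    '["1","2","4","5","7","8","10","11","13","14","16","17","19","20"]')
replaced = sorted(int(k) for k in meta_keys)
kept = [i for i in range(22) if i not in replaced]

assert len(replaced) == 14
assert kept == [0, 3, 6, 9, 12, 15, 18, 21]   # global layers, i % 3 == 0
```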
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "backend": "tokenizers",
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "is_local": true,
+   "mask_token": "[MASK]",
+   "model_input_names": [
+     "input_ids",
+     "attention_mask"
+   ],
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "tokenizer_class": "TokenizersBackend",
+   "unk_token": "[UNK]"
+ }