Upload folder using huggingface_hub
- README.md +267 -0
- birwkv7.py +190 -0
- config.json +81 -0
- configuration_hare.py +45 -0
- model.pt +3 -0
- modeling_hare.py +98 -0
- streaming.py +202 -0
- surgery.py +205 -0
- surgery_meta.json +135 -0
- tokenizer.json +0 -0
- tokenizer_config.json +16 -0
README.md
ADDED
@@ -0,0 +1,267 @@
---
language: en
license: apache-2.0
tags:
- embeddings
- text-retrieval
- long-context
- rwkv
- modernbert
- streaming
- semantic-search
- retrieval
pipeline_tag: feature-extraction
library_name: transformers
base_model: Alibaba-NLP/gte-modernbert-base
---

# HARE: Hybrid Attention-Recurrence Embeddings


TL;DR: Stateful embedding model that replaces sliding-window attention with RWKV recurrence, allowing for incremental encoding and streaming semantic search.

| | |
|---|---|
| **Parameters** | 173.9M |
| **Embedding dim** | 768 |
| **Base model** | [Alibaba-NLP/gte-modernbert-base](https://huggingface.co/Alibaba-NLP/gte-modernbert-base) |
| **Architecture** | ModernBERT-base with 14/22 local attention layers replaced by bidirectional RWKV recurrence |
| **Language** | English |

Conventional embedding models are stateless: adding new content requires re-encoding from scratch because token representations depend on the entire sequence.
HARE replaces 14 local sliding-window attention layers in ModernBERT-base with bidirectional RWKV linear recurrence while retaining 8 global attention layers.
Each recurrent layer maintains a fixed-size state matrix that summarizes all prior tokens at O(1) per-token cost, making the encoder stateful: it can save its state and resume encoding from any position.

The biggest practical advantage is being able to run semantic search over large files well before they are fully available, and across multiple streams simultaneously (for example, parallel distributed file transfers, concurrent transcripts, or documents on the same topic arriving from different sources).

## Results

### LongEmbed (Needle/Passkey: nDCG@1; others: nDCG@10)

Chunk-level: 256-token chunks, mean-pooled, max-over-chunks scoring. Token-level: full-document encoding, per-token late interaction scoring.

| Task | Chunk-level | Token-level | GTE-ModernBERT-base |
|------|-------------|-------------|---------------------|
| Needle | 84.0 | **87.5** | 49.8 |
| Passkey | **96.3** | 52.5 | 47.0 |
| NarrativeQA | **54.2** | 53.6 | 46.6 |
| QMSum | 44.2 | **50.7** | 61.1 |
| WikimQA | 73.6 | **87.6** | 86.8 |
| SummScreenFD | 72.2 | **88.5** | 88.2 |
| **Average** | **70.7** | 70.1 | 63.2 |
| **Best-per-task** | | **77.5** | |

### LoCo (12 long-context retrieval tasks, nDCG@10)

| Task | Chunk-level | Token-level | GTE-ModernBERT-base |
|------|-------------|-------------|---------------------|
| summ_screen_fd | 71.9 | **88.4** | 93.8 |
| gov_report | 86.2 | **97.2** | 97.5 |
| qmsum | **69.6** | 69.4 | 63.1 |
| qasper_title | 74.9 | **92.2** | 88.9 |
| qasper_abstract | 88.4 | **96.4** | 98.1 |
| multifieldqa | **93.4** | 92.9 | 93.4 |
| 2wikimqa | 90.0 | **91.1** | 86.6 |
| passage_retrieval | 95.1 | **95.5** | 52.7 |
| legal_case_reports | 11.4 | **24.3** | 44.8 |
| courtlistener_HTML | 43.6 | **51.4** | 23.5 |
| courtlistener_Plain_Text | 38.1 | **50.8** | 24.8 |
| stackoverflow | **43.3** | 36.7 | 90.9 |
| **Average** | 67.2 | **73.9** | 71.5 |

Token-level HARE (73.9) surpasses both GTE-ModernBERT-base (71.5) and bge-m3 (71.7) on LoCo.


## Usage

```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("SixOpen/HARE", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("SixOpen/HARE")
model = model.cuda().eval()

texts = ["Apple released a new iPhone model today", "The latest iPhone was announced by Apple"]
enc = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors='pt')
enc = {k: v.to('cuda') for k, v in enc.items()}
with torch.no_grad():
    hidden = model(**enc).last_hidden_state
mask = enc['attention_mask'].unsqueeze(-1).float()
embs = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
embs = F.normalize(embs, p=2, dim=-1)

similarity = (embs[0] @ embs[1]).item()
```

### Multi-vector retrieval (long documents)

For documents longer than 512 tokens, split into 256-token chunks with 64-token overlap and score the query against each chunk, keeping the best match (max-over-chunks).
HARE can also carry recurrent state across chunks, conditioning each chunk on all prior context without re-encoding. See the streaming demos for stateful usage.

```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("SixOpen/HARE", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("SixOpen/HARE")
model = model.cuda().eval()

query = "your query"
document = open("document.txt").read()  # any text format

# encode query
q_enc = tokenizer(query, return_tensors='pt', truncation=True, max_length=512)
q_enc = {k: v.cuda() for k, v in q_enc.items()}
with torch.no_grad():
    q_hidden = model(**q_enc).last_hidden_state
q_mask = q_enc['attention_mask'].unsqueeze(-1).float()
query_emb = F.normalize((q_hidden * q_mask).sum(1) / q_mask.sum(1).clamp(min=1e-9), dim=-1)

# chunk document (256 tokens, 64-token overlap)
doc_ids = tokenizer(document, return_tensors='pt', truncation=False)['input_ids'][0]
chunk_size, stride = 256, 192
chunk_embs = []
for start in range(0, len(doc_ids), stride):
    ids = doc_ids[start:start + chunk_size].unsqueeze(0).cuda()
    with torch.no_grad():
        h = model(input_ids=ids, attention_mask=torch.ones_like(ids)).last_hidden_state
    emb = F.normalize(h.mean(1), dim=-1)
    chunk_embs.append(emb)

chunk_embs = torch.cat(chunk_embs, dim=0)
scores = (query_emb @ chunk_embs.T).squeeze(0)
best_chunk = scores.argmax().item()
print(f"Best chunk: {best_chunk}, score: {scores[best_chunk]:.4f}")
```

### Stateful streaming (incremental encoding)

As noted above, unlike standard encoders HARE can save its state and resume from any position. New text is encoded with full prior context, without re-encoding anything that came before it.

```python
from streaming import SpanEncoder

enc = SpanEncoder(model, tokenizer, "cuda", chunk_size=256)

# Mock lecture transcript arriving in 3 streaming pieces
pieces = [
    "Today we will cover the fundamentals of quantum computing. Classical computers "
    "use bits that are either 0 or 1. Quantum computers use qubits which can exist "
    "in superposition, meaning they can be both 0 and 1 simultaneously. ",
    "The key advantage comes from entanglement. When two qubits are entangled, "
    "measuring one instantly determines the state of the other regardless of distance. "
    "This allows quantum computers to process certain problems exponentially faster. ",
    "The most important quantum algorithm is Shor's algorithm which can factor large "
    "numbers in polynomial time. This has major implications for cryptography since "
    "RSA encryption relies on the difficulty of factoring large primes. ",
]

# Encode incrementally; only the new piece is processed each time
enc.encode_span(pieces[0], key="p0")     # encode first piece
enc.extend_right(pieces[1], "p0", "p1")  # extend with state carry
enc.extend_right(pieces[2], "p1", "p2")  # extend again

# Search the incrementally built index
q_emb = enc.encode_query("why is Shor's algorithm important for cryptography")
chunk_embs = torch.cat(enc.span_data["p2"]["chunk_embs"], dim=0)
scores = (q_emb @ chunk_embs.T).squeeze(0)
best = scores.argmax().item()
print(f"Best chunk: {best}, score: {scores[best]:.4f}")
# → Best chunk: 2, score: 0.7814
```

### Token-level late interaction (offline, full-document)

For best quality on long documents, encode the full document in one pass and score at the token level, where `query_tokens` and `doc_tokens` are L2-normalized token embeddings:

```python
score = sum(max(q_tok @ d_tok for d_tok in doc_tokens) for q_tok in query_tokens)
```
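
For reference, a minimal vectorized version of the same MaxSim scoring in PyTorch; this is a sketch that assumes `query_tokens` with shape `(Q, 768)` and `doc_tokens` with shape `(N, 768)` are already L2-normalized token embeddings taken from the last hidden state, as in the snippets above:

```python
import torch

def maxsim_score(query_tokens: torch.Tensor, doc_tokens: torch.Tensor) -> float:
    # (Q, 768) @ (768, N) -> cosine similarity of every query token against every document token
    sim = query_tokens @ doc_tokens.T
    # for each query token, take its best-matching document token, then sum over the query
    return sim.max(dim=-1).values.sum().item()
```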

## Architecture

HARE starts from ModernBERT-base (22 layers, 768-dim, 12 heads) and performs architectural surgery:

- Layers 1, 2, 4, 5, 7, 8, 10, 11, 13, 14, 16, 17, 19, 20 (14 local sliding-window attention layers) are replaced with BiRWKV-7 bidirectional recurrence
- Layers 0, 3, 6, 9, 12, 15, 18, 21 (8 global attention layers) are retained unchanged
- Weight mapping: Q->R, K->K, V->V, O->O (attention projections initialize recurrence projections)
- Recurrence-specific parameters (decay, gate, mixing coefficients) are randomly initialized and learned during training

Each BiRWKV-7 layer runs a forward (left-to-right) and a backward (right-to-left) scan and averages the two outputs. The forward scan's state matrix (64x64 per head, 12 heads per layer) can be saved and resumed for incremental encoding.
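
The kept/replaced indices above follow ModernBERT's layer pattern (a global-attention layer every third layer, starting at layer 0), and the per-head state size follows from `hidden_size / num_heads`. A small sketch using the values in `config.json`:

```python
num_hidden_layers = 22
global_attn_every_n_layers = 3   # from config.json
hidden_size, num_heads = 768, 12

global_layers = [i for i in range(num_hidden_layers) if i % global_attn_every_n_layers == 0]
replaced_layers = [i for i in range(num_hidden_layers) if i % global_attn_every_n_layers != 0]
head_size = hidden_size // num_heads

print(global_layers)    # [0, 3, 6, 9, 12, 15, 18, 21] -> attention layers kept
print(replaced_layers)  # the 14 local layers swapped for BiRWKV-7
print(head_size)        # 64 -> each head carries a 64x64 recurrent state matrix
```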

## Training

Three-stage pipeline:

### Stage 1: Contrastive distillation

| | |
|---|---|
| Teacher | GTE-ModernBERT-base |
| Data | NLI (AllNLI) + MS-MARCO |
| Loss | (1 - alpha) * MRL-InfoNCE + alpha * cosine distillation |
| MRL dims | 64, 128, 256, 768 |
| Alpha | 0.5 |
| Epochs | 3 |
| Batch size | 32 |
| Learning rate | 2e-5 (cosine decay) |
| Max length | 512 |
| Optimizer | AdamW (weight_decay=0.01) |
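
For readers who prefer the objective in code form, here is a minimal sketch of the Stage 1 loss under the settings above; the temperature and the exact InfoNCE formulation are assumptions for illustration, not values taken from the training script:

```python
import torch
import torch.nn.functional as F

MRL_DIMS = (64, 128, 256, 768)
ALPHA = 0.5          # weight on the distillation term (Stage 1)
TEMPERATURE = 0.05   # assumed; not stated in the card

def stage1_loss(student_q, student_d, teacher_q, teacher_d):
    """student_q/d: (B, 768) query/passage embeddings; teacher_*: frozen GTE-ModernBERT-base embeddings."""
    # MRL-InfoNCE: in-batch contrastive loss applied at each Matryoshka dimension
    infonce = 0.0
    for dim in MRL_DIMS:
        q = F.normalize(student_q[:, :dim], dim=-1)
        d = F.normalize(student_d[:, :dim], dim=-1)
        logits = q @ d.T / TEMPERATURE                      # in-batch negatives
        labels = torch.arange(q.size(0), device=q.device)
        infonce = infonce + F.cross_entropy(logits, labels)
    infonce = infonce / len(MRL_DIMS)

    # cosine distillation toward the teacher's embeddings
    distill = (1 - F.cosine_similarity(student_q, teacher_q, dim=-1)).mean() + \
              (1 - F.cosine_similarity(student_d, teacher_d, dim=-1)).mean()

    return (1 - ALPHA) * infonce + ALPHA * distill
```

Stage 2 uses the same form with alpha = 0.3; Stage 3 drops the distillation term and keeps only MRL-InfoNCE.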

### Stage 2: Long-context self-distillation

| | |
|---|---|
| Teacher | GTE-ModernBERT-base |
| Data | NLI + MS-MARCO (10K each, 20K total) |
| Loss | (1 - alpha) * MRL-InfoNCE + alpha * cosine distillation |
| Alpha | 0.3 |
| Epochs | 1 |
| Batch size | 8 |
| Learning rate | 5e-6 (cosine decay) |
| Max length | 2048 |

### Stage 3: Synthetic IR training

| | |
|---|---|
| Data | 40% NLI + 40% MS-MARCO + 20% synthetic information-location pairs |
| Loss | MRL-InfoNCE |
| Epochs | 2 |
| Batch size | 32 |
| Learning rate | 5e-6 (cosine decay) |
| Max length | 512 |
| Merge | 30% Stage 2 weights + 70% Stage 3 weights |
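
The final checkpoint is a plain linear interpolation of the Stage 2 and Stage 3 weights. A sketch of that merge (checkpoint filenames are hypothetical):

```python
import torch

stage2 = torch.load("stage2.pt", map_location="cpu")  # hypothetical checkpoint paths
stage3 = torch.load("stage3.pt", map_location="cpu")

# 30/70 interpolation of floating-point parameters; non-float buffers are taken from Stage 3
merged = {
    name: 0.3 * stage2[name] + 0.7 * stage3[name]
    if torch.is_floating_point(stage3[name]) else stage3[name]
    for name in stage3
}
torch.save(merged, "merged.pt")
```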

## Files

| File | Description |
|------|-------------|
| `model.pt` | Model weights (664MB) |
| `config.json` | ModernBERT model config |
| `surgery_meta.json` | Layer replacement mapping (which layers were replaced, weight transfer record) |
| `tokenizer.json` | Tokenizer |
| `tokenizer_config.json` | Tokenizer config |
| `surgery.py` | Standalone surgery CLI tool (inspect layers, perform surgery from scratch) |
| `birwkv7.py` | BiRWKV-7 recurrence layer (required for loading) |
| `streaming.py` | SpanEncoder for stateful incremental encoding |

## Intended uses

- Semantic search and retrieval over short or long documents
- Incremental indexing where text arrives sequentially and must be searchable before completion: live transcription, real-time meeting or dispatch indexing, distributed (i.e., torrent) content search, incremental document editing
- Multi-vector retrieval with chunk-level or token-level scoring


## Citation

```bibtex
@article{osman2026hare,
  title={Stateful Embeddings via Hybrid Attention-Recurrence},
  author={Osman A. Ender},
  year={2026}
}
```
birwkv7.py
ADDED
@@ -0,0 +1,190 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

_FLA_AVAILABLE = False
try:
    import torch.distributed.tensor as _tdt
    if not hasattr(_tdt, 'Replicate'):
        try:
            from torch.distributed._tensor import Replicate as _R, Shard as _S
            _tdt.Replicate = _R; _tdt.Shard = _S
        except ImportError:
            pass
    if not hasattr(_tdt, 'Placement'):
        try:
            from torch.distributed._tensor.placement_types import Placement as _P
            _tdt.Placement = _P
        except ImportError:
            pass
    if not hasattr(_tdt, 'distribute_module'):
        _tdt.distribute_module = lambda *a, **kw: None
    from fla.ops.rwkv7 import chunk_rwkv7 as _fla_chunk_rwkv7
    if torch.cuda.is_available():
        _test_r = torch.randn(1, 1, 2, 64, device='cuda', dtype=torch.bfloat16, requires_grad=True)
        _test_w = -torch.ones(1, 1, 2, 64, device='cuda', dtype=torch.bfloat16)
        _test_o, _ = _fla_chunk_rwkv7(_test_r, _test_w, _test_r, _test_r, _test_r, _test_r,
                                      head_first=False)
        _test_o.sum().backward()
        if not _test_r.grad.isnan().any():
            _FLA_AVAILABLE = True
        del _test_r, _test_w, _test_o
        torch.cuda.empty_cache()
    else:
        _FLA_AVAILABLE = True
except Exception:
    pass


class BiRWKV7Layer(nn.Module):

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads

        self.mu_r = nn.Parameter(torch.zeros(hidden_size))
        self.mu_w = nn.Parameter(torch.zeros(hidden_size))
        self.mu_k = nn.Parameter(torch.zeros(hidden_size))
        self.mu_v = nn.Parameter(torch.zeros(hidden_size))
        self.mu_a = nn.Parameter(torch.zeros(hidden_size))
        self.mu_g = nn.Parameter(torch.zeros(hidden_size))

        self.W_r = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_k = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_v = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_w = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_a = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_g = nn.Linear(hidden_size, hidden_size, bias=False)

        self.sab_gate = nn.Parameter(torch.tensor(-5.0))

        self.group_norm = nn.GroupNorm(num_heads, hidden_size)
        self.W_o = nn.Linear(hidden_size, hidden_size, bias=False)

        nn.init.normal_(self.W_w.weight, std=0.01)
        nn.init.normal_(self.W_a.weight, std=0.01)
        nn.init.normal_(self.W_g.weight, std=0.02)

    def _token_shift(self, x):
        x_prev = F.pad(x[:, :-1], (0, 0, 1, 0))

        def mix(mu):
            return x + (x_prev - x) * torch.sigmoid(mu)

        return {
            'r': mix(self.mu_r), 'w': mix(self.mu_w),
            'k': mix(self.mu_k), 'v': mix(self.mu_v),
            'a': mix(self.mu_a), 'g': mix(self.mu_g),
        }

    def _wkv7_scan_fla(self, r, w, k, v, a, sab_scale):
        B, T, H, D = r.shape
        orig_dtype = r.dtype
        r, w, k, v, a = [x.bfloat16() for x in (r, w, k, v, a)]
        k_scaled = k * (D ** -0.5)
        w_log = -0.6065306597633104 * torch.sigmoid(w)
        a_sig = torch.sigmoid(a)
        a_fla = -k_scaled
        b_fla = sab_scale * k_scaled * a_sig
        o, _ = _fla_chunk_rwkv7(r, w_log, k_scaled, v, a_fla, b_fla, scale=1.0)
        return o.to(orig_dtype)

    def _wkv7_scan_python(self, r, w, k, v, a, sab_scale):
        B, T, H, D = r.shape
        orig_dtype = r.dtype

        r, w, k, v, a = [x.float() for x in (r, w, k, v, a)]
        k = k * (D ** -0.5)
        decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w))
        a = torch.sigmoid(a)

        state = torch.zeros(B, H, D, D, device=r.device, dtype=torch.float32)
        outputs = []

        for t in range(T):
            if t > 0 and t % 16 == 0:
                state = state.detach()

            kt, vt, rt, at, dt = k[:, t], v[:, t], r[:, t], a[:, t], decay[:, t]

            sa = torch.einsum('bhij,bhj->bhi', state, -kt)
            sab = torch.einsum('bhi,bhj->bhij', sa, kt * at)
            state = state * dt.unsqueeze(-2) + sab_scale * sab + torch.einsum('bhi,bhj->bhij', vt, kt)
            state = state.clamp(-10.0, 10.0)

            outputs.append(torch.einsum('bhij,bhj->bhi', state, rt))

        return torch.stack(outputs, dim=1).to(orig_dtype)

    def _wkv7_scan(self, r, w, k, v, a, sab_scale):
        if _FLA_AVAILABLE and r.is_cuda:
            return self._wkv7_scan_fla(r, w, k, v, a, sab_scale)
        return self._wkv7_scan_python(r, w, k, v, a, sab_scale)

    def forward(self, x, attention_mask=None, **kwargs):
        B, T, C = x.shape
        H, D = self.num_heads, self.head_size

        mixed = self._token_shift(x)
        r = self.W_r(mixed['r']).view(B, T, H, D)
        w = self.W_w(mixed['w']).view(B, T, H, D)
        k = self.W_k(mixed['k']).view(B, T, H, D)
        v = self.W_v(mixed['v']).view(B, T, H, D)
        a = self.W_a(mixed['a']).view(B, T, H, D)
        g = torch.sigmoid(self.W_g(mixed['g']))

        sab_scale = torch.sigmoid(self.sab_gate)

        out_fwd = self._wkv7_scan(r, w, k, v, a, sab_scale)
        out_bwd = self._wkv7_scan(
            r.flip(1), w.flip(1), k.flip(1), v.flip(1), a.flip(1), sab_scale
        ).flip(1)

        out = (out_fwd + out_bwd).reshape(B, T, C) * 0.5
        out = self.group_norm(out.transpose(1, 2)).transpose(1, 2)
        out = self.W_o(out * g)

        return out, None


def init_from_attention(birwkv, attn_module):
    q_proj = k_proj = v_proj = o_proj = None

    if hasattr(attn_module, 'Wqkv'):
        fused = attn_module.Wqkv.weight.data
        C = fused.shape[1]
        q_proj, k_proj, v_proj = fused[:C], fused[C:2*C], fused[2*C:]
    else:
        for name in ['q_proj', 'query', 'W_q', 'wq']:
            if hasattr(attn_module, name):
                q_proj = getattr(attn_module, name).weight.data
                break
        for name in ['k_proj', 'key', 'W_k', 'wk']:
            if hasattr(attn_module, name):
                k_proj = getattr(attn_module, name).weight.data
                break
        for name in ['v_proj', 'value', 'W_v', 'wv']:
            if hasattr(attn_module, name):
                v_proj = getattr(attn_module, name).weight.data
                break

    for name in ['Wo', 'out_proj', 'o_proj', 'dense', 'W_o', 'wo']:
        if hasattr(attn_module, name):
            o_proj = getattr(attn_module, name).weight.data
            break

    transferred = []
    for src, dst, label in [
        (q_proj, birwkv.W_r, 'Q->R'),
        (k_proj, birwkv.W_k, 'K->K'),
        (v_proj, birwkv.W_v, 'V->V'),
        (o_proj, birwkv.W_o, 'O->O'),
    ]:
        if src is not None:
            dst.weight.data.copy_(src)
            transferred.append(label)

    return transferred
config.json
ADDED
@@ -0,0 +1,81 @@
{
  "architectures": [
    "HareModel"
  ],
  "auto_map": {
    "AutoConfig": "configuration_hare.HareConfig",
    "AutoModel": "modeling_hare.HareModel"
  },
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 50281,
  "classifier_activation": "gelu",
  "classifier_bias": false,
  "classifier_dropout": 0.0,
  "classifier_pooling": "mean",
  "cls_token_id": 50281,
  "decoder_bias": true,
  "deterministic_flash_attn": false,
  "dtype": "float16",
  "embedding_dropout": 0.0,
  "eos_token_id": 50282,
  "global_attn_every_n_layers": 3,
  "gradient_checkpointing": false,
  "hidden_activation": "gelu",
  "hidden_size": 768,
  "initializer_cutoff_factor": 2.0,
  "initializer_range": 0.02,
  "intermediate_size": 1152,
  "layer_norm_eps": 1e-05,
  "layer_types": [
    "full_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention"
  ],
  "local_attention": 128,
  "max_position_embeddings": 8192,
  "mlp_bias": false,
  "mlp_dropout": 0.0,
  "model_type": "hare",
  "norm_bias": false,
  "norm_eps": 1e-05,
  "num_attention_heads": 12,
  "num_hidden_layers": 22,
  "pad_token_id": 50283,
  "position_embedding_type": "absolute",
  "rope_parameters": {
    "full_attention": {
      "rope_theta": 160000.0,
      "rope_type": "default"
    },
    "sliding_attention": {
      "rope_theta": 10000.0,
      "rope_type": "default"
    }
  },
  "sep_token_id": 50282,
  "sparse_pred_ignore_index": -100,
  "sparse_prediction": false,
  "tie_word_embeddings": true,
  "transformers_version": "5.2.0",
  "vocab_size": 50368
}
configuration_hare.py
ADDED
@@ -0,0 +1,45 @@
from transformers import PretrainedConfig


class HareConfig(PretrainedConfig):
    model_type = "hare"

    def __init__(
        self,
        hidden_size=768,
        num_attention_heads=12,
        num_hidden_layers=22,
        intermediate_size=1152,
        hidden_activation="gelu",
        max_position_embeddings=8192,
        vocab_size=50368,
        pad_token_id=50283,
        bos_token_id=50281,
        eos_token_id=50282,
        cls_token_id=50281,
        sep_token_id=50282,
        global_attn_every_n_layers=3,
        local_attention=128,
        replaced_layers=None,
        surgery_variant="conservative",
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.hidden_activation = hidden_activation
        self.max_position_embeddings = max_position_embeddings
        self.vocab_size = vocab_size
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        self.global_attn_every_n_layers = global_attn_every_n_layers
        self.local_attention = local_attention
        self.replaced_layers = replaced_layers
        self.surgery_variant = surgery_variant
model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:42a1d92de872ce85ff2bb1e189f8ac41fd3062e006827b15310484641e2b9157
size 695588290
modeling_hare.py
ADDED
@@ -0,0 +1,98 @@
import json
from pathlib import Path

import torch
from transformers import AutoModel, AutoConfig, PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput

from .configuration_hare import HareConfig
from .birwkv7 import BiRWKV7Layer, init_from_attention


def _find_encoder(model):
    for attr in ['encoder', 'model']:
        if hasattr(model, attr):
            candidate = getattr(model, attr)
            if hasattr(candidate, 'layers'):
                return candidate
    if hasattr(model, 'layers'):
        return model
    raise RuntimeError(f"Cannot find encoder layers in {type(model).__name__}")


def _perform_surgery(model, replaced_layers, hidden_size, num_heads):
    encoder = _find_encoder(model)
    for layer_idx_str, info in replaced_layers.items():
        layer_idx = int(layer_idx_str)
        layer = encoder.layers[layer_idx]
        attn = None
        attn_name = None
        for name in ['attn', 'attention', 'self_attn', 'self_attention']:
            if hasattr(layer, name):
                attn = getattr(layer, name)
                attn_name = name
                break
        if attn is None:
            continue
        birwkv = BiRWKV7Layer(hidden_size, num_heads)
        device = next(attn.parameters()).device
        dtype = next(attn.parameters()).dtype
        birwkv = birwkv.to(device=device, dtype=dtype)
        setattr(layer, attn_name, birwkv)


class HareModel(PreTrainedModel):
    config_class = HareConfig

    def __init__(self, config):
        super().__init__(config)
        base_config = AutoConfig.from_pretrained(
            "answerdotai/ModernBERT-base",
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=config.num_hidden_layers,
            intermediate_size=config.intermediate_size,
            vocab_size=config.vocab_size,
            max_position_embeddings=config.max_position_embeddings,
        )
        self.inner_model = AutoModel.from_config(base_config)

        if config.replaced_layers:
            _perform_surgery(
                self.inner_model,
                config.replaced_layers,
                config.hidden_size,
                config.num_attention_heads,
            )

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        outputs = self.inner_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        return outputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        model_dir = Path(pretrained_model_name_or_path)
        surgery_meta_path = model_dir / "surgery_meta.json"

        if surgery_meta_path.exists():
            with open(surgery_meta_path) as f:
                meta = json.load(f)

            config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
            config.replaced_layers = meta.get("replaced_layers")
            config.surgery_variant = meta.get("variant", "conservative")

            model = cls(config)

            weights_path = model_dir / "model.pt"
            if weights_path.exists():
                state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
                model.inner_model.load_state_dict(state_dict)

            return model.float().eval()

        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
streaming.py
ADDED
@@ -0,0 +1,202 @@
import torch
import torch.nn.functional as F

from birwkv7 import BiRWKV7Layer


def wkv7_forward_scan(r, w, k, v, a, sab_scale, init_state=None):
    B, T, H, D = r.shape
    r, w, k, v, a = [x.float() for x in (r, w, k, v, a)]
    k = k * (D ** -0.5)
    decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w))
    a = torch.sigmoid(a)
    sab_s = float(sab_scale)
    state = init_state.float().clone() if init_state is not None else \
        torch.zeros(B, H, D, D, device=r.device, dtype=torch.float32)
    outputs = []
    for t in range(T):
        kt, vt, rt, at, dt = k[:, t], v[:, t], r[:, t], a[:, t], decay[:, t]
        sa = torch.einsum('bhij,bhj->bhi', state, -kt)
        sab = torch.einsum('bhi,bhj->bhij', sa, kt * at)
        state = state * dt.unsqueeze(-2) + sab_s * sab + \
            torch.einsum('bhi,bhj->bhij', vt, kt)
        state = state.clamp(-10.0, 10.0)
        outputs.append(torch.einsum('bhij,bhj->bhi', state, rt))
    return torch.stack(outputs, dim=1), state.detach()


class SpanEncoder:

    def __init__(self, model, tokenizer, device, chunk_size=512):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.chunk_size = chunk_size

        self.birwkv_layers = []
        self.birwkv_ids = {}
        for m in model.modules():
            if isinstance(m, BiRWKV7Layer):
                self.birwkv_ids[id(m)] = len(self.birwkv_layers)
                self.birwkv_layers.append(m)

        self._originals = {}
        self._hooked = False
        self._active_states = [None] * len(self.birwkv_layers)
        self.span_data = {}

    def _hook(self):
        if self._hooked:
            return
        for layer in self.birwkv_layers:
            self._originals[id(layer)] = layer.forward
            layer.forward = self._make_fwd(layer)
        self._hooked = True

    def _unhook(self):
        if not self._hooked:
            return
        for layer in self.birwkv_layers:
            layer.forward = self._originals[id(layer)]
        self._originals.clear()
        self._hooked = False

    def _make_fwd(self, layer):
        enc = self
        idx = self.birwkv_ids[id(layer)]

        def fwd(x, attention_mask=None, **kwargs):
            B, T, C_ = x.shape
            H, D = layer.num_heads, layer.head_size
            prev = enc._active_states[idx]
            if prev is not None:
                x_prev = torch.cat([prev['last_x'], x[:, :-1]], dim=1)
            else:
                x_prev = F.pad(x[:, :-1], (0, 0, 1, 0))

            def mix(mu):
                return x + (x_prev - x) * torch.sigmoid(mu)

            r = layer.W_r(mix(layer.mu_r)).view(B, T, H, D)
            w = layer.W_w(mix(layer.mu_w)).view(B, T, H, D)
            k = layer.W_k(mix(layer.mu_k)).view(B, T, H, D)
            v = layer.W_v(mix(layer.mu_v)).view(B, T, H, D)
            a = layer.W_a(mix(layer.mu_a)).view(B, T, H, D)
            g = torch.sigmoid(layer.W_g(mix(layer.mu_g)))
            sab_scale = torch.sigmoid(layer.sab_gate)
            init_st = prev['wkv_state'] if prev else None

            try:
                from birwkv7_triton import wkv7_scan_triton
                r_f, k_f, v_f = r.float(), k.float() * (D ** -0.5), v.float()
                a_f = torch.sigmoid(a.float())
                decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w.float()))
                out_fwd, wkv_state = wkv7_scan_triton(
                    r_f, decay, k_f, v_f, a_f, sab_scale,
                    return_state=True, init_state=init_st)
                out_bwd = wkv7_scan_triton(
                    r_f.flip(1), decay.flip(1), k_f.flip(1),
                    v_f.flip(1), a_f.flip(1), sab_scale,
                    return_state=False).flip(1)
            except (ImportError, Exception):
                out_fwd, wkv_state = wkv7_forward_scan(
                    r, w, k, v, a, sab_scale, init_st)
                out_bwd = wkv7_forward_scan(
                    r.flip(1), w.flip(1), k.flip(1),
                    v.flip(1), a.flip(1), sab_scale, None)[0].flip(1)
            enc._active_states[idx] = {
                'wkv_state': wkv_state,
                'last_x': x[:, -1:].detach().clone(),
            }
            out = ((out_fwd + out_bwd) * 0.5).reshape(B, T, C_)
            out = layer.group_norm(out.transpose(1, 2)).transpose(1, 2)
            out = layer.W_o(out * g)
            return out, None
        return fwd

    @torch.no_grad()
    def _forward_encode_raw(self, text, init_states=None, max_length=8192):
        self._hook()
        if init_states is not None:
            self._active_states = [
                {k: v.clone() for k, v in s.items()} if s else None
                for s in init_states
            ]
        else:
            self._active_states = [None] * len(self.birwkv_layers)

        enc = self.tokenizer(text, return_tensors='pt', truncation=True,
                             max_length=max_length)
        ids = enc['input_ids'].to(self.device)
        mask = enc['attention_mask'].to(self.device)

        h = self.model(input_ids=ids, attention_mask=mask).last_hidden_state
        content = h[0, 1:-1, :].cpu()
        n_content = content.shape[0]

        final_states = [
            {k: v.clone() for k, v in s.items()} if s else None
            for s in self._active_states
        ]
        self._unhook()
        return content, n_content, final_states

    def _chunk_hidden(self, content, return_residual=False):
        T = content.shape[0]
        chunks = []
        last_end = 0
        for start in range(0, T, self.chunk_size):
            end = min(start + self.chunk_size, T)
            if end - start < 32:
                break
            emb = F.normalize(content[start:end].mean(0, keepdim=True),
                              p=2, dim=-1)
            chunks.append(emb)
            last_end = end
        if not chunks and T > 0:
            chunks.append(F.normalize(content.mean(0, keepdim=True),
                                      p=2, dim=-1))
            last_end = T
        if return_residual:
            residual = content[last_end:] if last_end < T else None
            return chunks, residual
        return chunks

    @torch.no_grad()
    def encode_query(self, query):
        assert not self._hooked
        enc = self.tokenizer(query, return_tensors='pt', truncation=True,
                             max_length=512)
        ids = enc['input_ids'].to(self.device)
        mask = enc['attention_mask'].to(self.device)
        h = self.model(input_ids=ids, attention_mask=mask).last_hidden_state
        m = mask.unsqueeze(-1).float()
        emb = (h * m).sum(1) / m.sum(1).clamp(min=1e-9)
        return F.normalize(emb, p=2, dim=-1).cpu()

    def encode_span(self, text, key):
        content, n_tok, states = self._forward_encode_raw(text)
        chunks, residual = self._chunk_hidden(content, return_residual=True)
        self.span_data[key] = {
            'layer_states': states,
            'chunk_embs': chunks,
            'n_tokens': n_tok,
            'residual_hidden': residual,
        }
        return n_tok

    def extend_right(self, piece_text, old_key, new_key):
        old = self.span_data.pop(old_key)
        content, n_new, states = self._forward_encode_raw(
            piece_text, init_states=old['layer_states'])
        if old.get('residual_hidden') is not None:
            content = torch.cat([old['residual_hidden'], content], dim=0)
        new_chunks, residual = self._chunk_hidden(
            content, return_residual=True)
        self.span_data[new_key] = {
            'layer_states': states,
            'chunk_embs': old['chunk_embs'] + new_chunks,
            'n_tokens': old['n_tokens'] + n_new,
            'residual_hidden': residual,
        }
        return n_new
surgery.py
ADDED
@@ -0,0 +1,205 @@
import argparse
import json
from pathlib import Path

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, AutoConfig

from birwkv7 import BiRWKV7Layer, init_from_attention


def _find_encoder(model):
    for attr in ['encoder', 'model']:
        if hasattr(model, attr):
            candidate = getattr(model, attr)
            if hasattr(candidate, 'layers'):
                return candidate
    if hasattr(model, 'layers'):
        return model
    raise RuntimeError(f"Cannot find encoder layers in {type(model).__name__}")


def find_attention_layers(model):
    encoder = _find_encoder(model)
    layers = []

    for i, layer in enumerate(encoder.layers):
        attn = None
        attn_path = None
        for name in ['attn', 'attention', 'self_attn', 'self_attention']:
            if hasattr(layer, name):
                attn = getattr(layer, name)
                attn_path = f"layers.{i}.{name}"
                break

        if attn is None:
            continue

        is_global = False
        if hasattr(attn, 'local_attention'):
            is_global = not attn.local_attention
        elif hasattr(attn, 'is_global_attention'):
            is_global = attn.is_global_attention
        elif hasattr(attn, 'use_sliding_window'):
            is_global = not attn.use_sliding_window
        elif hasattr(attn, 'sliding_window'):
            is_global = attn.sliding_window is None
        else:
            is_global = (i % 3 == 2)

        layers.append((i, attn_path, attn, is_global))

    return layers


def perform_surgery(model, variant, hidden_size, num_heads, replaced_layers=None):
    layers = find_attention_layers(model)
    global_indices = [idx for idx, _, _, g in layers if g]
    local_indices = [idx for idx, _, _, g in layers if not g]

    print(f"\nFound {len(layers)} attention layers:")
    print(f"  Global: {global_indices}")
    print(f"  Local: {local_indices}")

    if replaced_layers is not None:
        replace_indices = {int(k) for k in replaced_layers.keys()}
    elif variant == 'conservative':
        replace_indices = set(local_indices)
    elif variant == 'aggressive':
        keep = set()
        if global_indices:
            keep.add(global_indices[0])
            keep.add(global_indices[-1])
        replace_indices = {idx for idx, _, _, _ in layers if idx not in keep}
    elif variant == 'pure':
        replace_indices = {idx for idx, _, _, _ in layers}
    else:
        raise ValueError(f"Unknown variant: {variant}")

    print(f"\nVariant '{variant}': replacing {len(replace_indices)} of {len(layers)} layers")

    encoder = _find_encoder(model)
    report = {}

    for layer_idx, attn_path, attn_module, is_global in layers:
        if layer_idx not in replace_indices:
            print(f"  Layer {layer_idx}: KEEP ({'global' if is_global else 'local'})")
            continue

        birwkv = BiRWKV7Layer(hidden_size, num_heads)
        transferred = init_from_attention(birwkv, attn_module)

        device = next(attn_module.parameters()).device
        dtype = next(attn_module.parameters()).dtype
        birwkv = birwkv.to(device=device, dtype=dtype)

        attn_name = attn_path.split('.')[-1]
        setattr(encoder.layers[layer_idx], attn_name, birwkv)

        report[layer_idx] = {'was_global': is_global, 'transferred': transferred}
        print(f"  Layer {layer_idx}: REPLACED ({'global' if is_global else 'local'}) "
              f"-> BiRWKV-7 [{', '.join(transferred)}]")

    return report


def mean_pool(hidden_states, attention_mask):
    mask = attention_mask.unsqueeze(-1).float()
    return (hidden_states * mask).sum(1) / mask.sum(1).clamp(min=1e-9)


class HareWrapper(torch.nn.Module):

    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.config = model.config

    def encode(self, texts, batch_size=32, max_length=512, show_progress=False):
        all_embs = []
        iterator = range(0, len(texts), batch_size)
        if show_progress:
            from tqdm import tqdm
            iterator = tqdm(iterator, desc="Encoding")

        for i in iterator:
            batch = texts[i:i+batch_size]
            enc = self.tokenizer(batch, padding=True, truncation=True,
                                 max_length=max_length, return_tensors='pt')
            enc = {k: v.to(next(self.model.parameters()).device) for k, v in enc.items()}

            with torch.no_grad():
                hidden = self.model(**enc).last_hidden_state
            emb = mean_pool(hidden, enc['attention_mask'])
            all_embs.append(F.normalize(emb, p=2, dim=-1).cpu())

        return torch.cat(all_embs, dim=0)

    def forward(self, **kwargs):
        return self.model(**kwargs)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_model', default='answerdotai/ModernBERT-base')
    parser.add_argument('--variant', choices=['conservative', 'aggressive', 'pure'],
                        default='conservative')
    parser.add_argument('--output', type=str, default=None)
    parser.add_argument('--inspect_only', action='store_true')
    args = parser.parse_args()

    print(f"Loading {args.base_model}...")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    model = AutoModel.from_pretrained(args.base_model, trust_remote_code=True)
    config = model.config
    hidden_size = config.hidden_size
    num_heads = config.num_attention_heads
    print(f"  hidden_size={hidden_size}, num_heads={num_heads}, head_size={hidden_size // num_heads}")

    if args.inspect_only:
        layers = find_attention_layers(model)
        print(f"\n{len(layers)} attention layers:")
        for idx, path, attn, is_g in layers:
            n = sum(p.numel() for p in attn.parameters())
            print(f"  Layer {idx} ({'GLOBAL' if is_g else 'local'}): {type(attn).__name__} ({n:,}) @ {path}")
        return

    if not args.output:
        parser.error("--output required for surgery (omit for --inspect_only)")

    report = perform_surgery(model, args.variant, hidden_size, num_heads)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nPost-surgery: {total_params:,} params")

    print("Sanity check :)")
    inputs = tokenizer("Hello world", return_tensors='pt')
    inputs = {k: v.to(next(model.parameters()).device) for k, v in inputs.items()}
    with torch.no_grad():
        out = model(**inputs)
    print(f"  Output: {out.last_hidden_state.shape}, norm={out.last_hidden_state.norm().item():.4f}")

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), output_dir / 'model.pt')
    tokenizer.save_pretrained(output_dir)
    config.save_pretrained(output_dir)

    meta = {
        'base_model': args.base_model,
        'variant': args.variant,
        'hidden_size': hidden_size,
        'num_heads': num_heads,
        'replaced_layers': {str(k): v for k, v in report.items()},
        'total_params': total_params,
    }
    with open(output_dir / 'surgery_meta.json', 'w') as f:
        json.dump(meta, f, indent=2)

    print(f"\nSaved to {output_dir}/ ({total_params:,} params)")


if __name__ == '__main__':
    main()
surgery_meta.json
ADDED
@@ -0,0 +1,135 @@
{
  "base_model": "Alibaba-NLP/gte-modernbert-base",
  "variant": "conservative",
  "hidden_size": 768,
  "num_heads": 12,
  "replaced_layers": {
    "1": {
      "was_global": false,
      "transferred": [
        "Q->R",
        "K->K",
        "V->V",
        "O->O"
      ]
    },
    "2": {
      "was_global": false,
      "transferred": [
        "Q->R",
        "K->K",
        "V->V",
        "O->O"
      ]
    },
    "4": {
      "was_global": false,
      "transferred": [
        "Q->R",
        "K->K",
        "V->V",
        "O->O"
      ]
    },
    "5": {
      "was_global": false,
      "transferred": [
        "Q->R",
        "K->K",
        "V->V",
        "O->O"
      ]
    },
    "7": {
      "was_global": false,
      "transferred": [
        "Q->R",
        "K->K",
        "V->V",
        "O->O"
      ]
    },
    "8": {
      "was_global": false,
      "transferred": [
        "Q->R",
        "K->K",
        "V->V",
        "O->O"
      ]
    },
    "10": {
      "was_global": false,
      "transferred": [
        "Q->R",
        "K->K",
        "V->V",
        "O->O"
      ]
    },
    "11": {
      "was_global": false,
      "transferred": [
        "Q->R",
        "K->K",
        "V->V",
        "O->O"
      ]
    },
    "13": {
      "was_global": false,
      "transferred": [
        "Q->R",
        "K->K",
        "V->V",
        "O->O"
      ]
    },
    "14": {
      "was_global": false,
      "transferred": [
        "Q->R",
        "K->K",
        "V->V",
        "O->O"
      ]
    },
    "16": {
      "was_global": false,
      "transferred": [
        "Q->R",
        "K->K",
        "V->V",
        "O->O"
      ]
    },
    "17": {
      "was_global": false,
      "transferred": [
        "Q->R",
        "K->K",
        "V->V",
        "O->O"
      ]
    },
    "19": {
      "was_global": false,
      "transferred": [
        "Q->R",
        "K->K",
        "V->V",
        "O->O"
      ]
    },
    "20": {
      "was_global": false,
      "transferred": [
        "Q->R",
        "K->K",
        "V->V",
        "O->O"
      ]
    }
  },
  "total_params": 173872910
}
tokenizer.json
ADDED
The diff for this file is too large to render.
tokenizer_config.json
ADDED
@@ -0,0 +1,16 @@
{
  "backend": "tokenizers",
  "clean_up_tokenization_spaces": true,
  "cls_token": "[CLS]",
  "is_local": true,
  "mask_token": "[MASK]",
  "model_input_names": [
    "input_ids",
    "attention_mask"
  ],
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "tokenizer_class": "TokenizersBackend",
  "unk_token": "[UNK]"
}