File size: 7,308 Bytes
f8ab83c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import argparse
import json
from pathlib import Path

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, AutoConfig

from birwkv7 import BiRWKV7Layer, init_from_attention


def _find_encoder(model):
    for attr in ['encoder', 'model']:
        if hasattr(model, attr):
            candidate = getattr(model, attr)
            if hasattr(candidate, 'layers'):
                return candidate
    if hasattr(model, 'layers'):
        return model
    raise RuntimeError(f"Cannot find encoder layers in {type(model).__name__}")


def find_attention_layers(model):
    """Enumerate attention modules per encoder layer and classify each.

    Returns a list of ``(layer_index, dotted_path, module, is_global)``
    tuples. Layers with no recognizable attention attribute are skipped.
    """
    found = []
    for idx, layer in enumerate(_find_encoder(model).layers):
        # Probe the attribute names used across HF model implementations.
        attn_name = next(
            (n for n in ('attn', 'attention', 'self_attn', 'self_attention')
             if hasattr(layer, n)),
            None,
        )
        if attn_name is None:
            continue
        module = getattr(layer, attn_name)

        # Different architectures flag global vs. sliding-window attention
        # through different attributes; check each convention in turn.
        if hasattr(module, 'local_attention'):
            is_global = not module.local_attention
        elif hasattr(module, 'is_global_attention'):
            is_global = module.is_global_attention
        elif hasattr(module, 'use_sliding_window'):
            is_global = not module.use_sliding_window
        elif hasattr(module, 'sliding_window'):
            is_global = module.sliding_window is None
        else:
            # Fallback heuristic: every third layer is global — presumably
            # matches the default base model's layout; verify for others.
            is_global = (idx % 3 == 2)

        found.append((idx, f"layers.{idx}.{attn_name}", module, is_global))

    return found


def perform_surgery(model, variant, hidden_size, num_heads, replaced_layers=None):
    """Swap selected attention modules in-place for ``BiRWKV7Layer``s.

    Args:
        model: HF encoder model (mutated in place).
        variant: 'conservative' (replace local-attention layers only),
            'aggressive' (replace all but the first and last global layers),
            or 'pure' (replace everything).
        hidden_size, num_heads: dimensions for each new BiRWKV7Layer.
        replaced_layers: optional mapping whose keys name layer indices to
            replace; when given, it overrides the variant-based selection.

    Returns:
        dict keyed by replaced layer index with 'was_global' and the list
        of transferred parameter names.

    Raises:
        ValueError: unknown ``variant`` (only when replaced_layers is None).
    """
    layers = find_attention_layers(model)
    global_indices = [i for i, _, _, flag in layers if flag]
    local_indices = [i for i, _, _, flag in layers if not flag]

    print(f"\nFound {len(layers)} attention layers:")
    print(f"  Global: {global_indices}")
    print(f"  Local:  {local_indices}")

    # Decide which layer indices get replaced.
    if replaced_layers is not None:
        replace_indices = {int(k) for k in replaced_layers}
    else:
        all_indices = [i for i, _, _, _ in layers]
        if variant == 'conservative':
            replace_indices = set(local_indices)
        elif variant == 'aggressive':
            # Protect the first and last global layers (may coincide).
            protected = set(global_indices[:1] + global_indices[-1:])
            replace_indices = {i for i in all_indices if i not in protected}
        elif variant == 'pure':
            replace_indices = set(all_indices)
        else:
            raise ValueError(f"Unknown variant: {variant}")

    print(f"\nVariant '{variant}': replacing {len(replace_indices)} of {len(layers)} layers")

    encoder = _find_encoder(model)
    report = {}

    for idx, path, old_attn, is_global in layers:
        kind = 'global' if is_global else 'local'
        if idx not in replace_indices:
            print(f"  Layer {idx}: KEEP ({kind})")
            continue

        replacement = BiRWKV7Layer(hidden_size, num_heads)
        transferred = init_from_attention(replacement, old_attn)

        # Match the outgoing module's placement before splicing in.
        ref_param = next(old_attn.parameters())
        replacement = replacement.to(device=ref_param.device, dtype=ref_param.dtype)

        setattr(encoder.layers[idx], path.rsplit('.', 1)[-1], replacement)

        report[idx] = {'was_global': is_global, 'transferred': transferred}
        print(f"  Layer {idx}: REPLACED ({kind}) "
              f"-> BiRWKV-7 [{', '.join(transferred)}]")

    return report


def mean_pool(hidden_states, attention_mask):
    """Average token embeddings over unmasked positions.

    Args:
        hidden_states: float tensor, assumed (batch, seq, dim) — consistent
            with the 0/1 (batch, seq) mask broadcast below.
        attention_mask: 0/1 tensor marking valid (non-padding) tokens.

    Returns:
        (batch, dim) tensor of per-sequence mean embeddings.
    """
    weights = attention_mask.unsqueeze(-1).float()
    summed = (hidden_states * weights).sum(dim=1)
    # clamp guards against divide-by-zero on an all-padding row
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts


class HareWrapper(torch.nn.Module):
    """Wrap a HF encoder + tokenizer behind a sentence-embedding ``encode`` API."""

    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        # mirror the wrapped model's config for callers that expect it
        self.config = model.config

    def encode(self, texts, batch_size=32, max_length=512, show_progress=False):
        """Embed ``texts`` into L2-normalized mean-pooled vectors.

        Returns a (len(texts), hidden_dim) CPU tensor.
        """
        starts = range(0, len(texts), batch_size)
        if show_progress:
            from tqdm import tqdm
            starts = tqdm(starts, desc="Encoding")

        # Device is loop-invariant within one call; look it up once.
        device = next(self.model.parameters()).device

        chunks = []
        for start in starts:
            batch = texts[start:start + batch_size]
            enc = self.tokenizer(batch, padding=True, truncation=True,
                                 max_length=max_length, return_tensors='pt')
            enc = {key: tensor.to(device) for key, tensor in enc.items()}

            with torch.no_grad():
                hidden = self.model(**enc).last_hidden_state
            pooled = mean_pool(hidden, enc['attention_mask'])
            chunks.append(F.normalize(pooled, p=2, dim=-1).cpu())

        return torch.cat(chunks, dim=0)

    def forward(self, **kwargs):
        # Transparent pass-through to the wrapped model.
        return self.model(**kwargs)


def main():
    """CLI entry point: load a pretrained encoder, optionally just inspect
    its attention layout, otherwise perform BiRWKV-7 surgery, sanity-check
    a forward pass, and save weights + tokenizer + config + metadata."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_model', default='answerdotai/ModernBERT-base')
    parser.add_argument('--variant', choices=['conservative', 'aggressive', 'pure'],
                        default='conservative')
    parser.add_argument('--output', type=str, default=None)
    parser.add_argument('--inspect_only', action='store_true')
    args = parser.parse_args()

    # Fail fast on a missing --output BEFORE the expensive checkpoint
    # download/load (previously this was only checked after loading, so a
    # forgotten flag wasted the entire model load).
    if not args.inspect_only and not args.output:
        parser.error("--output required for surgery (omit for --inspect_only)")

    print(f"Loading {args.base_model}...")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    model = AutoModel.from_pretrained(args.base_model, trust_remote_code=True)
    config = model.config
    hidden_size = config.hidden_size
    num_heads = config.num_attention_heads
    print(f"  hidden_size={hidden_size}, num_heads={num_heads}, head_size={hidden_size // num_heads}")

    if args.inspect_only:
        # Report each attention layer's class, parameter count, and scope.
        layers = find_attention_layers(model)
        print(f"\n{len(layers)} attention layers:")
        for idx, path, attn, is_g in layers:
            n = sum(p.numel() for p in attn.parameters())
            print(f"  Layer {idx} ({'GLOBAL' if is_g else 'local'}): {type(attn).__name__} ({n:,}) @ {path}")
        return

    report = perform_surgery(model, args.variant, hidden_size, num_heads)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nPost-surgery: {total_params:,} params")

    # Quick smoke test: a single forward pass proves the spliced layers
    # still produce a tensor of the right shape.
    print("Sanity check :)")
    inputs = tokenizer("Hello world", return_tensors='pt')
    inputs = {k: v.to(next(model.parameters()).device) for k, v in inputs.items()}
    with torch.no_grad():
        out = model(**inputs)
    print(f"  Output: {out.last_hidden_state.shape}, norm={out.last_hidden_state.norm().item():.4f}")

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), output_dir / 'model.pt')
    tokenizer.save_pretrained(output_dir)
    # NOTE(review): this saves the *original* architecture config; the
    # surgery details needed to rebuild the model live in surgery_meta.json.
    config.save_pretrained(output_dir)

    meta = {
        'base_model': args.base_model,
        'variant': args.variant,
        'hidden_size': hidden_size,
        'num_heads': num_heads,
        'replaced_layers': {str(k): v for k, v in report.items()},
        'total_params': total_params,
    }
    with open(output_dir / 'surgery_meta.json', 'w') as f:
        json.dump(meta, f, indent=2)

    print(f"\nSaved to {output_dir}/ ({total_params:,} params)")


if __name__ == '__main__':
    main()