"""Attention surgery: swap attention layers in a pretrained encoder for BiRWKV-7 layers."""
import argparse
import json
from pathlib import Path

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

from birwkv7 import BiRWKV7Layer, init_from_attention


def _find_encoder(model):
    """Return the submodule that owns the transformer layer stack."""
    for attr in ['encoder', 'model']:
        if hasattr(model, attr):
            candidate = getattr(model, attr)
            if hasattr(candidate, 'layers'):
                return candidate
    if hasattr(model, 'layers'):
        return model
    raise RuntimeError(f"Cannot find encoder layers in {type(model).__name__}")


def find_attention_layers(model):
    """Locate each layer's attention module and guess whether it uses global attention.

    Returns a list of (layer_index, attribute_path, attention_module, is_global) tuples.
    """
    encoder = _find_encoder(model)
    layers = []
    for i, layer in enumerate(encoder.layers):
        attn = None
        attn_path = None
        for name in ['attn', 'attention', 'self_attn', 'self_attention']:
            if hasattr(layer, name):
                attn = getattr(layer, name)
                attn_path = f"layers.{i}.{name}"
                break
        if attn is None:
            continue
        # Heuristics for global vs. local (sliding-window) attention, checked in order;
        # if none of the known attributes exist, fall back to "every third layer is global".
        is_global = False
        if hasattr(attn, 'local_attention'):
            is_global = not attn.local_attention
        elif hasattr(attn, 'is_global_attention'):
            is_global = attn.is_global_attention
        elif hasattr(attn, 'use_sliding_window'):
            is_global = not attn.use_sliding_window
        elif hasattr(attn, 'sliding_window'):
            is_global = attn.sliding_window is None
        else:
            is_global = (i % 3 == 2)
        layers.append((i, attn_path, attn, is_global))
    return layers
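# Illustration only (hypothetical 6-layer encoder whose attention modules live at
# `layers.<i>.attn`): find_attention_layers(model) would return entries such as
#   (0, "layers.0.attn", <attention module>, False)
#   (2, "layers.2.attn", <attention module>, True)
# where the final flag comes from the global/local heuristics above.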
def perform_surgery(model, variant, hidden_size, num_heads, replaced_layers=None):
    """Replace selected attention modules with BiRWKV7Layer and return a per-layer report."""
    layers = find_attention_layers(model)
    global_indices = [idx for idx, _, _, g in layers if g]
    local_indices = [idx for idx, _, _, g in layers if not g]
    print(f"\nFound {len(layers)} attention layers:")
    print(f" Global: {global_indices}")
    print(f" Local: {local_indices}")

    # Decide which layers to replace: an explicit map wins, otherwise the variant does.
    if replaced_layers is not None:
        replace_indices = {int(k) for k in replaced_layers.keys()}
    elif variant == 'conservative':
        # Replace only the local (sliding-window) layers; keep all global layers.
        replace_indices = set(local_indices)
    elif variant == 'aggressive':
        # Keep only the first and last global layers; replace everything else.
        keep = set()
        if global_indices:
            keep.add(global_indices[0])
            keep.add(global_indices[-1])
        replace_indices = {idx for idx, _, _, _ in layers if idx not in keep}
    elif variant == 'pure':
        # Replace every attention layer.
        replace_indices = {idx for idx, _, _, _ in layers}
    else:
        raise ValueError(f"Unknown variant: {variant}")
    print(f"\nVariant '{variant}': replacing {len(replace_indices)} of {len(layers)} layers")

    encoder = _find_encoder(model)
    report = {}
    for layer_idx, attn_path, attn_module, is_global in layers:
        if layer_idx not in replace_indices:
            print(f" Layer {layer_idx}: KEEP ({'global' if is_global else 'local'})")
            continue
        birwkv = BiRWKV7Layer(hidden_size, num_heads)
        transferred = init_from_attention(birwkv, attn_module)
        # Match the device and dtype of the module being replaced before swapping it in.
        device = next(attn_module.parameters()).device
        dtype = next(attn_module.parameters()).dtype
        birwkv = birwkv.to(device=device, dtype=dtype)
        attn_name = attn_path.split('.')[-1]
        setattr(encoder.layers[layer_idx], attn_name, birwkv)
        report[layer_idx] = {'was_global': is_global, 'transferred': transferred}
        print(f" Layer {layer_idx}: REPLACED ({'global' if is_global else 'local'}) "
              f"-> BiRWKV-7 [{', '.join(transferred)}]")
    return report
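# Worked example (hypothetical 9-layer stack where the i % 3 == 2 fallback marks
# layers {2, 5, 8} as global):
#   conservative -> replace {0, 1, 3, 4, 6, 7}        (all local layers)
#   aggressive   -> keep {2, 8}, replace the other 7  (only first/last global survive)
#   pure         -> replace all 9 layers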
def mean_pool(hidden_states, attention_mask):
    """Mean-pool token embeddings, ignoring padding positions."""
    mask = attention_mask.unsqueeze(-1).float()
    return (hidden_states * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
class HareWrapper(torch.nn.Module):
    """Thin inference wrapper that adds a sentence-embedding `encode` method."""

    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.config = model.config

    def encode(self, texts, batch_size=32, max_length=512, show_progress=False):
        """Tokenize, run the model, mean-pool, and return L2-normalized embeddings on CPU."""
        all_embs = []
        iterator = range(0, len(texts), batch_size)
        if show_progress:
            from tqdm import tqdm
            iterator = tqdm(iterator, desc="Encoding")
        for i in iterator:
            batch = texts[i:i + batch_size]
            enc = self.tokenizer(batch, padding=True, truncation=True,
                                 max_length=max_length, return_tensors='pt')
            enc = {k: v.to(next(self.model.parameters()).device) for k, v in enc.items()}
            with torch.no_grad():
                hidden = self.model(**enc).last_hidden_state
            emb = mean_pool(hidden, enc['attention_mask'])
            all_embs.append(F.normalize(emb, p=2, dim=-1).cpu())
        return torch.cat(all_embs, dim=0)

    def forward(self, **kwargs):
        return self.model(**kwargs)
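# Minimal usage sketch (illustrative; assumes a surgically modified `model` and its
# `tokenizer` are already in memory, e.g. as loaded in main() below):
#   wrapper = HareWrapper(model, tokenizer)
#   embs = wrapper.encode(["first sentence", "second sentence"])
#   sims = embs @ embs.T  # cosine similarities, since encode() L2-normalizes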
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_model', default='answerdotai/ModernBERT-base')
    parser.add_argument('--variant', choices=['conservative', 'aggressive', 'pure'],
                        default='conservative')
    parser.add_argument('--output', type=str, default=None)
    parser.add_argument('--inspect_only', action='store_true')
    args = parser.parse_args()

    print(f"Loading {args.base_model}...")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    model = AutoModel.from_pretrained(args.base_model, trust_remote_code=True)
    config = model.config
    hidden_size = config.hidden_size
    num_heads = config.num_attention_heads
    print(f" hidden_size={hidden_size}, num_heads={num_heads}, head_size={hidden_size // num_heads}")

    # Inspection mode: list the attention layers and exit without modifying anything.
    if args.inspect_only:
        layers = find_attention_layers(model)
        print(f"\n{len(layers)} attention layers:")
        for idx, path, attn, is_g in layers:
            n = sum(p.numel() for p in attn.parameters())
            print(f" Layer {idx} ({'GLOBAL' if is_g else 'local'}): {type(attn).__name__} ({n:,}) @ {path}")
        return

    if not args.output:
        parser.error("--output is required for surgery (omit it only with --inspect_only)")

    report = perform_surgery(model, args.variant, hidden_size, num_heads)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nPost-surgery: {total_params:,} params")

    # Sanity check: run a short forward pass through the modified model.
    print("Sanity check :)")
    inputs = tokenizer("Hello world", return_tensors='pt')
    inputs = {k: v.to(next(model.parameters()).device) for k, v in inputs.items()}
    with torch.no_grad():
        out = model(**inputs)
    print(f" Output: {out.last_hidden_state.shape}, norm={out.last_hidden_state.norm().item():.4f}")

    # Save weights, tokenizer, config, and a JSON record of what was replaced.
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), output_dir / 'model.pt')
    tokenizer.save_pretrained(output_dir)
    config.save_pretrained(output_dir)
    meta = {
        'base_model': args.base_model,
        'variant': args.variant,
        'hidden_size': hidden_size,
        'num_heads': num_heads,
        'replaced_layers': {str(k): v for k, v in report.items()},
        'total_params': total_params,
    }
    with open(output_dir / 'surgery_meta.json', 'w') as f:
        json.dump(meta, f, indent=2)
    print(f"\nSaved to {output_dir}/ ({total_params:,} params)")


if __name__ == '__main__':
    main()
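# Example invocations (the script filename and output paths below are illustrative):
#   python surgery.py --base_model answerdotai/ModernBERT-base --inspect_only
#   python surgery.py --variant conservative --output ./hare-conservative
#   python surgery.py --variant pure --output ./hare-pure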