"""Replace attention layers in a pretrained encoder with BiRWKV-7 layers ("model surgery").

Variants:
  - conservative: replace only the local-attention layers
  - aggressive:   replace everything except the first and last global layers
  - pure:         replace every attention layer
"""

import argparse
import json
from pathlib import Path

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

from birwkv7 import BiRWKV7Layer, init_from_attention


def _find_encoder(model):
    """Locate the module that owns the transformer layer stack."""
    for attr in ['encoder', 'model']:
        if hasattr(model, attr):
            candidate = getattr(model, attr)
            if hasattr(candidate, 'layers'):
                return candidate
    if hasattr(model, 'layers'):
        return model
    raise RuntimeError(f"Cannot find encoder layers in {type(model).__name__}")


def find_attention_layers(model):
    """Return (index, path, module, is_global) for every attention sub-module."""
    encoder = _find_encoder(model)
    layers = []
    for i, layer in enumerate(encoder.layers):
        attn = None
        attn_path = None
        for name in ['attn', 'attention', 'self_attn', 'self_attention']:
            if hasattr(layer, name):
                attn = getattr(layer, name)
                attn_path = f"layers.{i}.{name}"
                break
        if attn is None:
            continue
        # Heuristics for telling global from local/sliding-window attention;
        # fall back to a fixed every-third-layer pattern if nothing matches.
        is_global = False
        if hasattr(attn, 'local_attention'):
            is_global = not attn.local_attention
        elif hasattr(attn, 'is_global_attention'):
            is_global = attn.is_global_attention
        elif hasattr(attn, 'use_sliding_window'):
            is_global = not attn.use_sliding_window
        elif hasattr(attn, 'sliding_window'):
            is_global = attn.sliding_window is None
        else:
            is_global = (i % 3 == 2)
        layers.append((i, attn_path, attn, is_global))
    return layers


def perform_surgery(model, variant, hidden_size, num_heads, replaced_layers=None):
    """Swap selected attention modules for BiRWKV7Layer instances, in place."""
    layers = find_attention_layers(model)
    global_indices = [idx for idx, _, _, g in layers if g]
    local_indices = [idx for idx, _, _, g in layers if not g]
    print(f"\nFound {len(layers)} attention layers:")
    print(f"  Global: {global_indices}")
    print(f"  Local:  {local_indices}")

    # Decide which layer indices to replace.
    if replaced_layers is not None:
        replace_indices = {int(k) for k in replaced_layers.keys()}
    elif variant == 'conservative':
        replace_indices = set(local_indices)
    elif variant == 'aggressive':
        keep = set()
        if global_indices:
            keep.add(global_indices[0])
            keep.add(global_indices[-1])
        replace_indices = {idx for idx, _, _, _ in layers if idx not in keep}
    elif variant == 'pure':
        replace_indices = {idx for idx, _, _, _ in layers}
    else:
        raise ValueError(f"Unknown variant: {variant}")

    print(f"\nVariant '{variant}': replacing {len(replace_indices)} of {len(layers)} layers")

    encoder = _find_encoder(model)
    report = {}
    for layer_idx, attn_path, attn_module, is_global in layers:
        if layer_idx not in replace_indices:
            print(f"  Layer {layer_idx}: KEEP ({'global' if is_global else 'local'})")
            continue
        birwkv = BiRWKV7Layer(hidden_size, num_heads)
        transferred = init_from_attention(birwkv, attn_module)
        # Match the device/dtype of the module being replaced.
        device = next(attn_module.parameters()).device
        dtype = next(attn_module.parameters()).dtype
        birwkv = birwkv.to(device=device, dtype=dtype)
        attn_name = attn_path.split('.')[-1]
        setattr(encoder.layers[layer_idx], attn_name, birwkv)
        report[layer_idx] = {'was_global': is_global, 'transferred': transferred}
        print(f"  Layer {layer_idx}: REPLACED ({'global' if is_global else 'local'}) "
              f"-> BiRWKV-7 [{', '.join(transferred)}]")
    return report


def mean_pool(hidden_states, attention_mask):
    """Mask-aware mean pooling over the sequence dimension."""
    mask = attention_mask.unsqueeze(-1).float()
    return (hidden_states * mask).sum(1) / mask.sum(1).clamp(min=1e-9)


class HareWrapper(torch.nn.Module):
    """Thin wrapper exposing a sentence-embedding style encode() API."""

    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.config = model.config

    def encode(self, texts, batch_size=32, max_length=512, show_progress=False):
        all_embs = []
        iterator = range(0, len(texts), batch_size)
        if show_progress:
            from tqdm import tqdm
            iterator = tqdm(iterator, desc="Encoding")
        for i in iterator:
            batch = texts[i:i + batch_size]
            enc = self.tokenizer(batch, padding=True, truncation=True,
                                 max_length=max_length, return_tensors='pt')
            enc = {k: v.to(next(self.model.parameters()).device) for k, v in enc.items()}
            with torch.no_grad():
                hidden = self.model(**enc).last_hidden_state
            emb = mean_pool(hidden, enc['attention_mask'])
            all_embs.append(F.normalize(emb, p=2, dim=-1).cpu())
        return torch.cat(all_embs, dim=0)

    def forward(self, **kwargs):
        return self.model(**kwargs)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_model', default='answerdotai/ModernBERT-base')
    parser.add_argument('--variant', choices=['conservative', 'aggressive', 'pure'],
                        default='conservative')
    parser.add_argument('--output', type=str, default=None)
    parser.add_argument('--inspect_only', action='store_true')
    args = parser.parse_args()

    print(f"Loading {args.base_model}...")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    model = AutoModel.from_pretrained(args.base_model, trust_remote_code=True)
    config = model.config
    hidden_size = config.hidden_size
    num_heads = config.num_attention_heads
    print(f"  hidden_size={hidden_size}, num_heads={num_heads}, "
          f"head_size={hidden_size // num_heads}")

    if args.inspect_only:
        layers = find_attention_layers(model)
        print(f"\n{len(layers)} attention layers:")
        for idx, path, attn, is_g in layers:
            n = sum(p.numel() for p in attn.parameters())
            print(f"  Layer {idx} ({'GLOBAL' if is_g else 'local'}): "
                  f"{type(attn).__name__} ({n:,}) @ {path}")
        return

    if not args.output:
        parser.error("--output is required for surgery (or pass --inspect_only)")

    report = perform_surgery(model, args.variant, hidden_size, num_heads)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nPost-surgery: {total_params:,} params")

    # Quick sanity check: a forward pass on a short input should still succeed.
    print("Sanity check:")
    inputs = tokenizer("Hello world", return_tensors='pt')
    inputs = {k: v.to(next(model.parameters()).device) for k, v in inputs.items()}
    with torch.no_grad():
        out = model(**inputs)
    print(f"  Output: {out.last_hidden_state.shape}, "
          f"norm={out.last_hidden_state.norm().item():.4f}")

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), output_dir / 'model.pt')
    tokenizer.save_pretrained(output_dir)
    config.save_pretrained(output_dir)
    meta = {
        'base_model': args.base_model,
        'variant': args.variant,
        'hidden_size': hidden_size,
        'num_heads': num_heads,
        'replaced_layers': {str(k): v for k, v in report.items()},
        'total_params': total_params,
    }
    with open(output_dir / 'surgery_meta.json', 'w') as f:
        json.dump(meta, f, indent=2)
    print(f"\nSaved to {output_dir}/ ({total_params:,} params)")


if __name__ == '__main__':
    main()
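
# Example invocations (illustrative only; the filename `surgery.py` and the output
# directory below are assumptions, not names taken from this repository):
#
#   # Inspect which attention layers would be found, without modifying the model:
#   #   python surgery.py --base_model answerdotai/ModernBERT-base --inspect_only
#
#   # Replace the local-attention layers (conservative variant) and save the result:
#   #   python surgery.py --variant conservative --output checkpoints/surgery-conservative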