# HARE / surgery.py
import argparse
import json
from pathlib import Path

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

from birwkv7 import BiRWKV7Layer, init_from_attention


def _find_encoder(model):
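    """Locate the module that owns the transformer layer stack.

    Checks common wrapper attributes ('encoder', 'model') for a `.layers`
    container, falls back to the model itself, and raises if nothing matches.
    """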
    for attr in ['encoder', 'model']:
        if hasattr(model, attr):
            candidate = getattr(model, attr)
            if hasattr(candidate, 'layers'):
                return candidate
    if hasattr(model, 'layers'):
        return model
    raise RuntimeError(f"Cannot find encoder layers in {type(model).__name__}")


def find_attention_layers(model):
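    """Return a list of (layer_index, attribute_path, attention_module, is_global) tuples.

    Globality is inferred from whichever attribute the attention module exposes
    (local_attention, is_global_attention, use_sliding_window, sliding_window);
    if none are present, every third layer (i % 3 == 2) is assumed to be global.
    """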
    encoder = _find_encoder(model)
    layers = []
    for i, layer in enumerate(encoder.layers):
        attn = None
        attn_path = None
        for name in ['attn', 'attention', 'self_attn', 'self_attention']:
            if hasattr(layer, name):
                attn = getattr(layer, name)
                attn_path = f"layers.{i}.{name}"
                break
        if attn is None:
            continue
        is_global = False
        if hasattr(attn, 'local_attention'):
            is_global = not attn.local_attention
        elif hasattr(attn, 'is_global_attention'):
            is_global = attn.is_global_attention
        elif hasattr(attn, 'use_sliding_window'):
            is_global = not attn.use_sliding_window
        elif hasattr(attn, 'sliding_window'):
            is_global = attn.sliding_window is None
        else:
            is_global = (i % 3 == 2)
        layers.append((i, attn_path, attn, is_global))
    return layers


def perform_surgery(model, variant, hidden_size, num_heads, replaced_layers=None):
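    """Swap selected attention modules for BiRWKV7Layer modules, in place.

    variant: 'conservative' replaces only local-attention layers, 'aggressive'
    keeps the first and last global layers, 'pure' replaces everything.
    An explicit replaced_layers mapping (layer index -> info) overrides the variant.
    Each new module is seeded from the old attention weights via init_from_attention
    and cast to the old module's device/dtype. Returns a report keyed by layer index.
    """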
    layers = find_attention_layers(model)
    global_indices = [idx for idx, _, _, g in layers if g]
    local_indices = [idx for idx, _, _, g in layers if not g]
    print(f"\nFound {len(layers)} attention layers:")
    print(f" Global: {global_indices}")
    print(f" Local: {local_indices}")

    if replaced_layers is not None:
        replace_indices = {int(k) for k in replaced_layers.keys()}
    elif variant == 'conservative':
        replace_indices = set(local_indices)
    elif variant == 'aggressive':
        keep = set()
        if global_indices:
            keep.add(global_indices[0])
            keep.add(global_indices[-1])
        replace_indices = {idx for idx, _, _, _ in layers if idx not in keep}
    elif variant == 'pure':
        replace_indices = {idx for idx, _, _, _ in layers}
    else:
        raise ValueError(f"Unknown variant: {variant}")

    print(f"\nVariant '{variant}': replacing {len(replace_indices)} of {len(layers)} layers")
    encoder = _find_encoder(model)
    report = {}
    for layer_idx, attn_path, attn_module, is_global in layers:
        if layer_idx not in replace_indices:
            print(f" Layer {layer_idx}: KEEP ({'global' if is_global else 'local'})")
            continue
        birwkv = BiRWKV7Layer(hidden_size, num_heads)
        transferred = init_from_attention(birwkv, attn_module)
        device = next(attn_module.parameters()).device
        dtype = next(attn_module.parameters()).dtype
        birwkv = birwkv.to(device=device, dtype=dtype)
        attn_name = attn_path.split('.')[-1]
        setattr(encoder.layers[layer_idx], attn_name, birwkv)
        report[layer_idx] = {'was_global': is_global, 'transferred': transferred}
        print(f" Layer {layer_idx}: REPLACED ({'global' if is_global else 'local'}) "
              f"-> BiRWKV-7 [{', '.join(transferred)}]")
    return report


def mean_pool(hidden_states, attention_mask):
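    """Mean-pool token states over the sequence, ignoring padding via attention_mask."""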
    mask = attention_mask.unsqueeze(-1).float()
    return (hidden_states * mask).sum(1) / mask.sum(1).clamp(min=1e-9)


class HareWrapper(torch.nn.Module):
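    """Wraps the modified encoder and its tokenizer, adding an `encode` helper that
    returns L2-normalized, mean-pooled sentence embeddings."""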
    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.config = model.config

    def encode(self, texts, batch_size=32, max_length=512, show_progress=False):
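        """Embed texts in batches: tokenize, forward, mean-pool, L2-normalize.

        Returns a (len(texts), hidden_size) tensor on the CPU.
        """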
        all_embs = []
        iterator = range(0, len(texts), batch_size)
        if show_progress:
            from tqdm import tqdm
            iterator = tqdm(iterator, desc="Encoding")
        for i in iterator:
            batch = texts[i:i+batch_size]
            enc = self.tokenizer(batch, padding=True, truncation=True,
                                 max_length=max_length, return_tensors='pt')
            enc = {k: v.to(next(self.model.parameters()).device) for k, v in enc.items()}
            with torch.no_grad():
                hidden = self.model(**enc).last_hidden_state
            emb = mean_pool(hidden, enc['attention_mask'])
            all_embs.append(F.normalize(emb, p=2, dim=-1).cpu())
        return torch.cat(all_embs, dim=0)

    def forward(self, **kwargs):
        return self.model(**kwargs)

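
# Example (illustrative) of embedding texts with the modified model:
#   wrapper = HareWrapper(model, tokenizer)
#   embs = wrapper.encode(["a query", "a relevant passage"], batch_size=2)
#   embs.shape  # (2, hidden_size); rows are L2-normalized mean-pooled embeddings
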
def main():
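    """CLI entry point: load the base encoder, inspect or perform surgery, sanity-check, and save."""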
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_model', default='answerdotai/ModernBERT-base')
    parser.add_argument('--variant', choices=['conservative', 'aggressive', 'pure'],
                        default='conservative')
    parser.add_argument('--output', type=str, default=None)
    parser.add_argument('--inspect_only', action='store_true')
    args = parser.parse_args()

    print(f"Loading {args.base_model}...")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    model = AutoModel.from_pretrained(args.base_model, trust_remote_code=True)
    config = model.config
    hidden_size = config.hidden_size
    num_heads = config.num_attention_heads
    print(f" hidden_size={hidden_size}, num_heads={num_heads}, head_size={hidden_size // num_heads}")

    if args.inspect_only:
        layers = find_attention_layers(model)
        print(f"\n{len(layers)} attention layers:")
        for idx, path, attn, is_g in layers:
            n = sum(p.numel() for p in attn.parameters())
            print(f" Layer {idx} ({'GLOBAL' if is_g else 'local'}): {type(attn).__name__} ({n:,}) @ {path}")
        return

    if not args.output:
        parser.error("--output required for surgery (omit for --inspect_only)")

    report = perform_surgery(model, args.variant, hidden_size, num_heads)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nPost-surgery: {total_params:,} params")

    print("Sanity check :)")
    inputs = tokenizer("Hello world", return_tensors='pt')
    inputs = {k: v.to(next(model.parameters()).device) for k, v in inputs.items()}
    with torch.no_grad():
        out = model(**inputs)
    print(f" Output: {out.last_hidden_state.shape}, norm={out.last_hidden_state.norm().item():.4f}")

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), output_dir / 'model.pt')
    tokenizer.save_pretrained(output_dir)
    config.save_pretrained(output_dir)
    meta = {
        'base_model': args.base_model,
        'variant': args.variant,
        'hidden_size': hidden_size,
        'num_heads': num_heads,
        'replaced_layers': {str(k): v for k, v in report.items()},
        'total_params': total_params,
    }
    with open(output_dir / 'surgery_meta.json', 'w') as f:
        json.dump(meta, f, indent=2)
    print(f"\nSaved to {output_dir}/ ({total_params:,} params)")
if __name__ == '__main__':
    main()