import json
import logging
from pathlib import Path
import torch
from transformers import AutoModel, PreTrainedModel
from transformers import ModernBertConfig
# Quiet the chattiest transformers loggers during model/config construction.
for _noisy in ("transformers.modeling_utils", "transformers.configuration_utils"):
    logging.getLogger(_noisy).setLevel(logging.ERROR)
from .configuration_hare import HareConfig
from .birwkv7 import BiRWKV7Layer, init_from_attention
def _find_encoder(model):
for attr in ['encoder', 'model']:
if hasattr(model, attr):
candidate = getattr(model, attr)
if hasattr(candidate, 'layers'):
return candidate
if hasattr(model, 'layers'):
return model
raise RuntimeError(f"Cannot find encoder layers in {type(model).__name__}")
def _perform_surgery(model, replaced_layers, hidden_size, num_heads):
    """Swap the attention modules of selected layers for BiRWKV7 replacements.

    ``replaced_layers`` maps stringified layer indices to per-layer metadata;
    only the keys are consulted here. Layers where no known attention
    attribute name is found are skipped silently.
    """
    encoder = _find_encoder(model)
    for idx_key in replaced_layers:
        block = encoder.layers[int(idx_key)]
        # First matching attribute name wins.
        slot = next(
            (n for n in ('attn', 'attention', 'self_attn', 'self_attention')
             if hasattr(block, n)),
            None,
        )
        if slot is None:
            continue
        old_attn = getattr(block, slot)
        # Mirror the outgoing module's placement, read off its first parameter.
        ref_param = next(old_attn.parameters())
        replacement = BiRWKV7Layer(hidden_size, num_heads).to(
            device=ref_param.device, dtype=ref_param.dtype)
        setattr(block, slot, replacement)
class HareModel(PreTrainedModel):
    """ModernBERT-backed encoder whose attention layers may be replaced by BiRWKV7.

    Builds a ``ModernBertConfig`` from a :class:`HareConfig`, instantiates the
    underlying model via ``AutoModel.from_config``, and — when
    ``config.replaced_layers`` is truthy — performs in-place "surgery" that
    swaps the listed attention modules for ``BiRWKV7Layer`` instances.
    """

    config_class = HareConfig

    def __init__(self, config):
        super().__init__(config)
        # Translate the Hare config into the ModernBERT config the inner
        # model is built from. The getattr fallbacks cover configs that
        # predate the newer fields (cls/sep token ids, attention layout).
        base_config = ModernBertConfig(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=config.num_hidden_layers,
            intermediate_size=config.intermediate_size,
            vocab_size=config.vocab_size,
            max_position_embeddings=config.max_position_embeddings,
            pad_token_id=config.pad_token_id,
            bos_token_id=config.bos_token_id,
            eos_token_id=config.eos_token_id,
            cls_token_id=getattr(config, 'cls_token_id', config.bos_token_id),
            sep_token_id=getattr(config, 'sep_token_id', config.eos_token_id),
            global_attn_every_n_layers=getattr(config, 'global_attn_every_n_layers', 3),
            local_attention=getattr(config, 'local_attention', 128),
        )
        self.inner_model = AutoModel.from_config(base_config)
        # Replace the configured attention layers in place (no-op when the
        # mapping is empty or None).
        if config.replaced_layers:
            _perform_surgery(
                self.inner_model,
                config.replaced_layers,
                config.hidden_size,
                config.num_attention_heads,
            )

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Delegate the forward pass to the wrapped inner model unchanged."""
        outputs = self.inner_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        return outputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load a surgically-modified checkpoint.

        Looks for ``surgery_meta.json`` next to the checkpoint (locally first,
        then on the Hugging Face Hub). If it cannot be found, falls back to the
        stock ``PreTrainedModel.from_pretrained`` path. Otherwise the metadata
        is merged into the config, the model is rebuilt (re-running the layer
        surgery), and weights are loaded from ``model.pt`` when present.

        Returns the model cast to float32 and switched to eval mode.
        """
        model_dir = Path(pretrained_model_name_or_path)
        surgery_meta_path = model_dir / "surgery_meta.json"
        if not surgery_meta_path.exists():
            from huggingface_hub import hf_hub_download
            try:
                surgery_meta_path = Path(hf_hub_download(
                    pretrained_model_name_or_path, "surgery_meta.json"))
                # Downloaded files land in the hub cache; resolve sibling
                # files (model.pt) relative to that cache directory.
                model_dir = surgery_meta_path.parent
            except Exception:
                # No surgery metadata anywhere: treat as a plain checkpoint.
                return super().from_pretrained(
                    pretrained_model_name_or_path, *args, **kwargs)
        with open(surgery_meta_path) as f:
            meta = json.load(f)
        config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
        # NOTE(review): may set replaced_layers to None if the key is absent;
        # __init__ treats falsy values as "no surgery".
        config.replaced_layers = meta.get("replaced_layers")
        config.surgery_variant = meta.get("variant", "conservative")
        model = cls(config)
        weights_path = model_dir / "model.pt"
        if not weights_path.exists():
            from huggingface_hub import hf_hub_download
            try:
                weights_path = Path(hf_hub_download(
                    pretrained_model_name_or_path, "model.pt"))
            except Exception:
                # Best-effort: proceed with randomly initialized weights.
                pass
        if weights_path.exists():
            # weights_only=True restricts torch.load to tensor data (safe load).
            state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
            model.inner_model.load_state_dict(state_dict)
        # Always return a float32 model in eval mode.
        return model.float().eval()