# HARE / modeling_hare.py
# (Hugging Face Hub page residue: uploaded by SixOpen, "Update modeling_hare.py",
# commit 7d6b779, verified — kept here as provenance, commented out so the
# module is importable.)
import json
import logging
from pathlib import Path
import torch
from transformers import AutoModel, PreTrainedModel
from transformers import ModernBertConfig
# Silence transformers' modeling/configuration loggers (e.g. the noisy
# "some weights were not initialized" notices) emitted while the inner
# ModernBERT model is constructed and while weights are loaded below.
for _logger_name in ["transformers.modeling_utils", "transformers.configuration_utils"]:
    logging.getLogger(_logger_name).setLevel(logging.ERROR)
from .configuration_hare import HareConfig
from .birwkv7 import BiRWKV7Layer, init_from_attention
def _find_encoder(model):
for attr in ['encoder', 'model']:
if hasattr(model, attr):
candidate = getattr(model, attr)
if hasattr(candidate, 'layers'):
return candidate
if hasattr(model, 'layers'):
return model
raise RuntimeError(f"Cannot find encoder layers in {type(model).__name__}")
def _perform_surgery(model, replaced_layers, hidden_size, num_heads):
    """Swap the attention modules named in *replaced_layers* for BiRWKV7 layers.

    Args:
        model: backbone whose encoder layers will be modified in place.
        replaced_layers: mapping of layer index (string) -> surgery metadata;
            only the keys are used here.
        hidden_size: hidden dimension for each new ``BiRWKV7Layer``.
        num_heads: head count for each new ``BiRWKV7Layer``.

    Layers whose attention module cannot be located under a known attribute
    name are silently left untouched (best-effort behavior).
    """
    attn_attr_names = ('attn', 'attention', 'self_attn', 'self_attention')
    encoder = _find_encoder(model)
    for idx_key in replaced_layers:
        layer = encoder.layers[int(idx_key)]
        attn_name = next((n for n in attn_attr_names if hasattr(layer, n)), None)
        if attn_name is None:
            continue
        # Match the device/dtype of the module being replaced.
        reference = next(getattr(layer, attn_name).parameters())
        replacement = BiRWKV7Layer(hidden_size, num_heads).to(
            device=reference.device, dtype=reference.dtype)
        setattr(layer, attn_name, replacement)
class HareModel(PreTrainedModel):
    """ModernBERT backbone in which selected attention layers have been
    replaced ("surgery") by bidirectional RWKV-7 mixing layers.

    The replacement plan lives in ``config.replaced_layers`` (populated from
    a ``surgery_meta.json`` sidecar when loading from a checkpoint).
    """

    config_class = HareConfig

    def __init__(self, config):
        """Build the inner ModernBERT model and apply layer surgery.

        Args:
            config: a ``HareConfig``; its fields are mirrored into a
                ``ModernBertConfig`` for the inner model.
        """
        super().__init__(config)
        # Mirror the Hare config into a plain ModernBERT config. The
        # getattr fallbacks cover configs that predate these fields.
        base_config = ModernBertConfig(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=config.num_hidden_layers,
            intermediate_size=config.intermediate_size,
            vocab_size=config.vocab_size,
            max_position_embeddings=config.max_position_embeddings,
            pad_token_id=config.pad_token_id,
            bos_token_id=config.bos_token_id,
            eos_token_id=config.eos_token_id,
            cls_token_id=getattr(config, 'cls_token_id', config.bos_token_id),
            sep_token_id=getattr(config, 'sep_token_id', config.eos_token_id),
            global_attn_every_n_layers=getattr(config, 'global_attn_every_n_layers', 3),
            local_attention=getattr(config, 'local_attention', 128),
        )
        self.inner_model = AutoModel.from_config(base_config)
        # replaced_layers maps layer index (as a string key) -> surgery
        # metadata; when present, swap those attention modules in place.
        if config.replaced_layers:
            _perform_surgery(
                self.inner_model,
                config.replaced_layers,
                config.hidden_size,
                config.num_attention_heads,
            )

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Delegate straight to the (post-surgery) inner model."""
        outputs = self.inner_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        return outputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load a Hare checkpoint, applying surgery metadata when available.

        Resolution order:
          1. Look for ``surgery_meta.json`` next to the checkpoint (local dir,
             then the Hub). If absent, fall back to the stock
             ``PreTrainedModel.from_pretrained`` path.
          2. Otherwise build the model from config + metadata and load the
             post-surgery weights from ``model.pt`` (local, then the Hub).

        NOTE(review): on the surgery path, *args/**kwargs (e.g. torch_dtype,
        device_map) are not forwarded — presumably intentional; confirm.
        """
        model_dir = Path(pretrained_model_name_or_path)
        surgery_meta_path = model_dir / "surgery_meta.json"
        if not surgery_meta_path.exists():
            from huggingface_hub import hf_hub_download
            try:
                surgery_meta_path = Path(hf_hub_download(
                    pretrained_model_name_or_path, "surgery_meta.json"))
                # Downloaded files land in the Hub cache; look for the
                # weights next to the metadata there.
                model_dir = surgery_meta_path.parent
            except Exception:
                # No surgery metadata anywhere: plain transformers load.
                return super().from_pretrained(
                    pretrained_model_name_or_path, *args, **kwargs)
        with open(surgery_meta_path) as f:
            meta = json.load(f)
        config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
        config.replaced_layers = meta.get("replaced_layers")
        config.surgery_variant = meta.get("variant", "conservative")
        # Constructing the model applies surgery, so the state dict saved
        # after surgery matches the module tree being loaded into.
        model = cls(config)
        weights_path = model_dir / "model.pt"
        if not weights_path.exists():
            from huggingface_hub import hf_hub_download
            try:
                weights_path = Path(hf_hub_download(
                    pretrained_model_name_or_path, "model.pt"))
            except Exception:
                # Best-effort: proceed with randomly initialized weights
                # when no model.pt can be found.
                pass
        if weights_path.exists():
            # weights_only=True restricts torch.load to tensor payloads
            # (no arbitrary pickled objects).
            state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
            model.inner_model.load_state_dict(state_dict)
        # Force float32 and eval mode for inference.
        return model.float().eval()