"""Hare model wrapper.

Wraps a ModernBERT backbone in which selected self-attention layers have
been surgically replaced by bidirectional RWKV-7 mixing layers
(``BiRWKV7Layer``).  The set of replaced layers is recorded in a
``surgery_meta.json`` sidecar file next to the checkpoint; ``from_pretrained``
re-applies the surgery before loading the saved weights.
"""

import json
import logging
from pathlib import Path

import torch
from transformers import AutoModel, ModernBertConfig, PreTrainedModel

# Silence transformers' noisy weight-initialization / configuration warnings
# BEFORE the local modules (which touch transformers) are imported.
for _logger_name in [
    "transformers.modeling_utils",
    "transformers.configuration_utils",
]:
    logging.getLogger(_logger_name).setLevel(logging.ERROR)

from .configuration_hare import HareConfig
from .birwkv7 import BiRWKV7Layer, init_from_attention

logger = logging.getLogger(__name__)


def _find_encoder(model):
    """Return the submodule of *model* that owns the ``layers`` ModuleList.

    Different transformers architectures nest their layer stack under
    different attribute names (``encoder``, ``model``) or directly on the
    top-level module.

    Raises:
        RuntimeError: if no candidate with a ``layers`` attribute is found.
    """
    for attr in ["encoder", "model"]:
        if hasattr(model, attr):
            candidate = getattr(model, attr)
            if hasattr(candidate, "layers"):
                return candidate
    if hasattr(model, "layers"):
        return model
    raise RuntimeError(f"Cannot find encoder layers in {type(model).__name__}")


def _perform_surgery(model, replaced_layers, hidden_size, num_heads):
    """Replace attention modules in *model* with ``BiRWKV7Layer`` instances.

    Args:
        model: the backbone transformer model.
        replaced_layers: mapping of layer index (as a string, since it comes
            from JSON) to per-layer surgery metadata.
        hidden_size: model hidden dimension, forwarded to ``BiRWKV7Layer``.
        num_heads: number of attention heads, forwarded to ``BiRWKV7Layer``.
    """
    encoder = _find_encoder(model)
    for layer_idx_str, info in replaced_layers.items():
        layer_idx = int(layer_idx_str)
        layer = encoder.layers[layer_idx]
        # Locate the attention submodule regardless of its attribute name.
        attn = None
        attn_name = None
        for name in ["attn", "attention", "self_attn", "self_attention"]:
            if hasattr(layer, name):
                attn = getattr(layer, name)
                attn_name = name
                break
        if attn is None:
            # A silently skipped surgery would leave the layer un-replaced
            # while the metadata says otherwise — make that visible.
            logger.warning(
                "No attention module found on layer %d; surgery skipped.",
                layer_idx,
            )
            continue
        birwkv = BiRWKV7Layer(hidden_size, num_heads)
        # Match the device/dtype of the module being replaced.  Guard against
        # a parameter-less attention module (next() would raise otherwise).
        ref_param = next(attn.parameters(), None)
        if ref_param is not None:
            birwkv = birwkv.to(device=ref_param.device, dtype=ref_param.dtype)
        setattr(layer, attn_name, birwkv)


class HareModel(PreTrainedModel):
    """ModernBERT backbone with optional BiRWKV-7 attention replacement."""

    config_class = HareConfig

    def __init__(self, config):
        super().__init__(config)
        # Build the vanilla ModernBERT backbone from our config's fields.
        base_config = ModernBertConfig(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=config.num_hidden_layers,
            intermediate_size=config.intermediate_size,
            vocab_size=config.vocab_size,
            max_position_embeddings=config.max_position_embeddings,
            pad_token_id=config.pad_token_id,
            bos_token_id=config.bos_token_id,
            eos_token_id=config.eos_token_id,
            # Older configs may lack CLS/SEP ids; fall back to BOS/EOS.
            cls_token_id=getattr(config, "cls_token_id", config.bos_token_id),
            sep_token_id=getattr(config, "sep_token_id", config.eos_token_id),
            global_attn_every_n_layers=getattr(
                config, "global_attn_every_n_layers", 3
            ),
            local_attention=getattr(config, "local_attention", 128),
        )
        self.inner_model = AutoModel.from_config(base_config)
        if config.replaced_layers:
            _perform_surgery(
                self.inner_model,
                config.replaced_layers,
                config.hidden_size,
                config.num_attention_heads,
            )

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Delegate straight to the (possibly surgically modified) backbone."""
        outputs = self.inner_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        return outputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load a Hare checkpoint, re-applying the recorded layer surgery.

        Looks for ``surgery_meta.json`` locally first, then on the Hub.  If
        no surgery metadata exists anywhere, the checkpoint is treated as a
        plain model and deferred to the standard loading path.
        """
        model_dir = Path(pretrained_model_name_or_path)
        surgery_meta_path = model_dir / "surgery_meta.json"
        if not surgery_meta_path.exists():
            from huggingface_hub import hf_hub_download

            try:
                surgery_meta_path = Path(
                    hf_hub_download(
                        pretrained_model_name_or_path, "surgery_meta.json"
                    )
                )
                model_dir = surgery_meta_path.parent
            except Exception:
                # No surgery metadata anywhere: plain pretrained load.
                return super().from_pretrained(
                    pretrained_model_name_or_path, *args, **kwargs
                )

        with open(surgery_meta_path) as f:
            meta = json.load(f)

        config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
        config.replaced_layers = meta.get("replaced_layers")
        config.surgery_variant = meta.get("variant", "conservative")

        # Constructing the model applies the surgery (see __init__), so the
        # state dict below must be loaded AFTER construction.
        model = cls(config)

        weights_path = model_dir / "model.pt"
        if not weights_path.exists():
            from huggingface_hub import hf_hub_download

            try:
                weights_path = Path(
                    hf_hub_download(pretrained_model_name_or_path, "model.pt")
                )
            except Exception:
                logger.warning(
                    "Could not fetch model.pt for %s; returning a model "
                    "with freshly initialized weights.",
                    pretrained_model_name_or_path,
                )

        if weights_path.exists():
            state_dict = torch.load(
                weights_path, map_location="cpu", weights_only=True
            )
            model.inner_model.load_state_dict(state_dict)

        return model.float().eval()