| | import json |
| | import logging |
| | from pathlib import Path |
| |
|
| | import torch |
| | from transformers import AutoModel, PreTrainedModel |
| | from transformers import ModernBertConfig |
| |
|
# Silence noisy transformers warnings emitted while building configs/models
# (e.g. "weights not initialized" chatter during AutoModel.from_config).
_NOISY_TRANSFORMERS_LOGGERS = (
    "transformers.modeling_utils",
    "transformers.configuration_utils",
)
for _noisy_logger in _NOISY_TRANSFORMERS_LOGGERS:
    logging.getLogger(_noisy_logger).setLevel(logging.ERROR)
| |
|
| | from .configuration_hare import HareConfig |
| | from .birwkv7 import BiRWKV7Layer, init_from_attention |
| |
|
| |
|
| | def _find_encoder(model): |
| | for attr in ['encoder', 'model']: |
| | if hasattr(model, attr): |
| | candidate = getattr(model, attr) |
| | if hasattr(candidate, 'layers'): |
| | return candidate |
| | if hasattr(model, 'layers'): |
| | return model |
| | raise RuntimeError(f"Cannot find encoder layers in {type(model).__name__}") |
| |
|
| |
|
def _perform_surgery(model, replaced_layers, hidden_size, num_heads):
    """Replace the attention module of selected encoder layers with BiRWKV7.

    Mutates *model* in place: for every layer index listed in
    *replaced_layers*, the layer's attention submodule is swapped for a
    freshly constructed ``BiRWKV7Layer`` on the same device/dtype.

    Args:
        model: backbone whose encoder layers are modified in place.
        replaced_layers: mapping of layer index (str, since JSON object keys
            are strings) to surgery metadata; only the keys are used here.
        hidden_size: hidden dimension passed to the replacement layer.
        num_heads: head count passed to the replacement layer.
    """
    encoder = _find_encoder(model)
    # Only the indices matter; the per-layer metadata values are unused here.
    for layer_idx_str in replaced_layers:
        layer = encoder.layers[int(layer_idx_str)]
        # Probe the attribute names used by common transformer implementations.
        attn_name = next(
            (name for name in ('attn', 'attention', 'self_attn', 'self_attention')
             if hasattr(layer, name)),
            None,
        )
        if attn_name is None:
            continue  # no recognizable attention module; leave the layer as-is
        attn = getattr(layer, attn_name)
        # Read device/dtype from a single parameter instead of walking the
        # parameter generator twice.
        ref_param = next(attn.parameters())
        birwkv = BiRWKV7Layer(hidden_size, num_heads).to(
            device=ref_param.device, dtype=ref_param.dtype)
        setattr(layer, attn_name, birwkv)
| |
|
| |
|
class HareModel(PreTrainedModel):
    """ModernBERT backbone in which selected attention layers are replaced by
    BiRWKV7 mixers, as described by ``config.replaced_layers``.

    The wrapped backbone lives in ``self.inner_model``; ``forward`` is a pure
    delegation to it.
    """

    config_class = HareConfig

    def __init__(self, config):
        """Build the stock ModernBERT backbone, then apply layer surgery.

        Args:
            config: a ``HareConfig``; its fields are mirrored into a
                ``ModernBertConfig`` so ``AutoModel.from_config`` can build the
                unmodified backbone first.
        """
        super().__init__(config)
        # Mirror the Hare config into a vanilla ModernBertConfig. Fields that
        # may be absent on older configs get defaults via getattr.
        base_config = ModernBertConfig(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=config.num_hidden_layers,
            intermediate_size=config.intermediate_size,
            vocab_size=config.vocab_size,
            max_position_embeddings=config.max_position_embeddings,
            pad_token_id=config.pad_token_id,
            bos_token_id=config.bos_token_id,
            eos_token_id=config.eos_token_id,
            cls_token_id=getattr(config, 'cls_token_id', config.bos_token_id),
            sep_token_id=getattr(config, 'sep_token_id', config.eos_token_id),
            global_attn_every_n_layers=getattr(config, 'global_attn_every_n_layers', 3),
            local_attention=getattr(config, 'local_attention', 128),
        )
        self.inner_model = AutoModel.from_config(base_config)

        # Falsy/None replaced_layers means a vanilla backbone: no surgery.
        if config.replaced_layers:
            _perform_surgery(
                self.inner_model,
                config.replaced_layers,
                config.hidden_size,
                config.num_attention_heads,
            )

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Delegate directly to the (possibly surgically modified) inner model."""
        outputs = self.inner_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        return outputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load a Hare checkpoint from a local directory or the HF Hub.

        Looks for ``surgery_meta.json`` next to the checkpoint (locally, then
        via ``hf_hub_download``). If it cannot be found, falls back to the
        standard ``PreTrainedModel.from_pretrained`` path. Otherwise the model
        is rebuilt from config + surgery metadata and raw weights are loaded
        from ``model.pt`` when present.

        NOTE(review): on the surgery path, *args/**kwargs are accepted but not
        forwarded anywhere (e.g. torch_dtype/device_map are ignored) — confirm
        this is intentional.
        """
        model_dir = Path(pretrained_model_name_or_path)
        surgery_meta_path = model_dir / "surgery_meta.json"

        if not surgery_meta_path.exists():
            from huggingface_hub import hf_hub_download
            try:
                surgery_meta_path = Path(hf_hub_download(
                    pretrained_model_name_or_path, "surgery_meta.json"))
                # Downloaded files land in the hub cache; use that directory
                # for subsequent sibling-file lookups (model.pt).
                model_dir = surgery_meta_path.parent
            except Exception:
                # No surgery metadata anywhere: behave like a plain
                # transformers checkpoint.
                return super().from_pretrained(
                    pretrained_model_name_or_path, *args, **kwargs)

        with open(surgery_meta_path) as f:
            meta = json.load(f)

        config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
        config.replaced_layers = meta.get("replaced_layers")
        config.surgery_variant = meta.get("variant", "conservative")

        # Construct with surgery applied (randomly initialized weights).
        model = cls(config)

        # Best-effort load of the raw state dict saved alongside the metadata.
        weights_path = model_dir / "model.pt"
        if not weights_path.exists():
            from huggingface_hub import hf_hub_download
            try:
                weights_path = Path(hf_hub_download(
                    pretrained_model_name_or_path, "model.pt"))
            except Exception:
                # Deliberate best-effort: proceed with fresh weights if the
                # checkpoint file is unavailable.
                pass

        if weights_path.exists():
            # weights_only=True restricts unpickling to tensor payloads.
            state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
            model.inner_model.load_state_dict(state_dict)

        # NOTE(review): .float() forces fp32 regardless of the saved dtype —
        # confirm half-precision checkpoints should be upcast here.
        return model.float().eval()
| |
|