File size: 4,572 Bytes
0f92ffc
7d6b779
0f92ffc
 
 
7d6b779
 
 
 
 
0f92ffc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d6b779
0f92ffc
 
 
 
 
 
7d6b779
 
 
 
 
 
 
0f92ffc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import json
import logging
from pathlib import Path

import torch
from transformers import AutoModel, PreTrainedModel
from transformers import ModernBertConfig

# Silence noisy HF weight-loading / config warnings during model surgery.
for _noisy in ("transformers.modeling_utils", "transformers.configuration_utils"):
    logging.getLogger(_noisy).setLevel(logging.ERROR)

from .configuration_hare import HareConfig
from .birwkv7 import BiRWKV7Layer, init_from_attention


def _find_encoder(model):
    for attr in ['encoder', 'model']:
        if hasattr(model, attr):
            candidate = getattr(model, attr)
            if hasattr(candidate, 'layers'):
                return candidate
    if hasattr(model, 'layers'):
        return model
    raise RuntimeError(f"Cannot find encoder layers in {type(model).__name__}")


def _perform_surgery(model, replaced_layers, hidden_size, num_heads):
    """Replace attention modules with BiRWKV7 layers, in place.

    Args:
        model: a HF model whose encoder layers are located via ``_find_encoder``.
        replaced_layers: mapping of layer index (as a string key) to metadata;
            only the keys are consulted here.
        hidden_size: hidden dimension for each replacement ``BiRWKV7Layer``.
        num_heads: head count for each replacement ``BiRWKV7Layer``.

    Layers whose attention module cannot be found under a known attribute
    name are silently left untouched.
    """
    encoder = _find_encoder(model)
    # Only the indices matter; the metadata values were previously bound to an
    # unused variable — iterate the keys directly.
    for layer_idx_str in replaced_layers:
        layer = encoder.layers[int(layer_idx_str)]
        # Locate the attention sub-module under the common HF attribute names.
        attn = None
        attn_name = None
        for name in ('attn', 'attention', 'self_attn', 'self_attention'):
            if hasattr(layer, name):
                attn = getattr(layer, name)
                attn_name = name
                break
        if attn is None:
            # No recognizable attention module on this layer; skip it.
            continue
        birwkv = BiRWKV7Layer(hidden_size, num_heads)
        # Match the device/dtype of the module being replaced.
        # NOTE(review): next() raises StopIteration if attn has no parameters;
        # assumed impossible for real attention modules — confirm if support
        # for parameter-free attention variants is ever added.
        ref_param = next(attn.parameters())
        birwkv = birwkv.to(device=ref_param.device, dtype=ref_param.dtype)
        setattr(layer, attn_name, birwkv)


class HareModel(PreTrainedModel):
    """ModernBERT backbone in which selected attention layers have been
    surgically replaced by BiRWKV7 layers.

    The set of replaced layers is recorded in a ``surgery_meta.json`` file
    shipped next to the weights; ``from_pretrained`` reads it (locally or from
    the Hub) and rebuilds the hybrid architecture before loading weights.
    """

    config_class = HareConfig

    def __init__(self, config):
        super().__init__(config)
        # Build the stock ModernBERT backbone from our config fields.
        # getattr defaults cover configs saved before these fields existed.
        base_config = ModernBertConfig(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=config.num_hidden_layers,
            intermediate_size=config.intermediate_size,
            vocab_size=config.vocab_size,
            max_position_embeddings=config.max_position_embeddings,
            pad_token_id=config.pad_token_id,
            bos_token_id=config.bos_token_id,
            eos_token_id=config.eos_token_id,
            cls_token_id=getattr(config, 'cls_token_id', config.bos_token_id),
            sep_token_id=getattr(config, 'sep_token_id', config.eos_token_id),
            global_attn_every_n_layers=getattr(config, 'global_attn_every_n_layers', 3),
            local_attention=getattr(config, 'local_attention', 128),
        )
        self.inner_model = AutoModel.from_config(base_config)

        # Consistent with the optional fields above: tolerate configs that
        # lack `replaced_layers` instead of raising AttributeError.
        replaced = getattr(config, 'replaced_layers', None)
        if replaced:
            _perform_surgery(
                self.inner_model,
                replaced,
                config.hidden_size,
                config.num_attention_heads,
            )

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Delegate straight to the wrapped backbone; no post-processing."""
        return self.inner_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load a surgically-modified checkpoint.

        If no ``surgery_meta.json`` is found locally or on the Hub, falls back
        to the standard ``PreTrainedModel.from_pretrained`` path. Otherwise the
        metadata drives reconstruction and ``model.pt`` weights are loaded.

        NOTE(review): on the surgery path, *args/**kwargs (e.g. torch_dtype,
        device_map) are silently ignored — confirm callers do not rely on them.
        """
        model_dir = Path(pretrained_model_name_or_path)
        surgery_meta_path = model_dir / "surgery_meta.json"

        if not surgery_meta_path.exists():
            # Not a local checkpoint dir — try the Hub before giving up.
            from huggingface_hub import hf_hub_download
            try:
                surgery_meta_path = Path(hf_hub_download(
                    pretrained_model_name_or_path, "surgery_meta.json"))
                # Hub snapshot dir also holds the sibling model.pt, if any.
                model_dir = surgery_meta_path.parent
            except Exception:
                # No surgery metadata anywhere: plain pretrained load.
                return super().from_pretrained(
                    pretrained_model_name_or_path, *args, **kwargs)

        with open(surgery_meta_path) as f:
            meta = json.load(f)

        config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
        config.replaced_layers = meta.get("replaced_layers")
        config.surgery_variant = meta.get("variant", "conservative")

        model = cls(config)

        weights_path = model_dir / "model.pt"
        if not weights_path.exists():
            from huggingface_hub import hf_hub_download
            try:
                weights_path = Path(hf_hub_download(
                    pretrained_model_name_or_path, "model.pt"))
            except Exception:
                # Best-effort: proceed with randomly-initialized weights.
                pass

        if weights_path.exists():
            # weights_only=True avoids arbitrary-code-execution via pickle.
            state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
            model.inner_model.load_state_dict(state_dict)

        # NOTE(review): forces fp32 + eval mode regardless of caller intent —
        # presumably for inference use; confirm before training from this.
        return model.float().eval()