SixOpen committed on
Commit
0f92ffc
·
verified ·
1 Parent(s): 503adbb

Update modeling_hare.py

Browse files
Files changed (1) hide show
  1. modeling_hare.py +113 -98
modeling_hare.py CHANGED
@@ -1,98 +1,113 @@
1
- import json
2
- from pathlib import Path
3
-
4
- import torch
5
- from transformers import AutoModel, AutoConfig, PreTrainedModel
6
- from transformers.modeling_outputs import BaseModelOutput
7
-
8
- from .configuration_hare import HareConfig
9
- from .birwkv7 import BiRWKV7Layer, init_from_attention
10
-
11
-
12
- def _find_encoder(model):
13
- for attr in ['encoder', 'model']:
14
- if hasattr(model, attr):
15
- candidate = getattr(model, attr)
16
- if hasattr(candidate, 'layers'):
17
- return candidate
18
- if hasattr(model, 'layers'):
19
- return model
20
- raise RuntimeError(f"Cannot find encoder layers in {type(model).__name__}")
21
-
22
-
23
- def _perform_surgery(model, replaced_layers, hidden_size, num_heads):
24
- encoder = _find_encoder(model)
25
- for layer_idx_str, info in replaced_layers.items():
26
- layer_idx = int(layer_idx_str)
27
- layer = encoder.layers[layer_idx]
28
- attn = None
29
- attn_name = None
30
- for name in ['attn', 'attention', 'self_attn', 'self_attention']:
31
- if hasattr(layer, name):
32
- attn = getattr(layer, name)
33
- attn_name = name
34
- break
35
- if attn is None:
36
- continue
37
- birwkv = BiRWKV7Layer(hidden_size, num_heads)
38
- device = next(attn.parameters()).device
39
- dtype = next(attn.parameters()).dtype
40
- birwkv = birwkv.to(device=device, dtype=dtype)
41
- setattr(layer, attn_name, birwkv)
42
-
43
-
44
- class HareModel(PreTrainedModel):
45
- config_class = HareConfig
46
-
47
- def __init__(self, config):
48
- super().__init__(config)
49
- base_config = AutoConfig.from_pretrained(
50
- "answerdotai/ModernBERT-base",
51
- hidden_size=config.hidden_size,
52
- num_attention_heads=config.num_attention_heads,
53
- num_hidden_layers=config.num_hidden_layers,
54
- intermediate_size=config.intermediate_size,
55
- vocab_size=config.vocab_size,
56
- max_position_embeddings=config.max_position_embeddings,
57
- )
58
- self.inner_model = AutoModel.from_config(base_config)
59
-
60
- if config.replaced_layers:
61
- _perform_surgery(
62
- self.inner_model,
63
- config.replaced_layers,
64
- config.hidden_size,
65
- config.num_attention_heads,
66
- )
67
-
68
- def forward(self, input_ids=None, attention_mask=None, **kwargs):
69
- outputs = self.inner_model(
70
- input_ids=input_ids,
71
- attention_mask=attention_mask,
72
- **kwargs,
73
- )
74
- return outputs
75
-
76
- @classmethod
77
- def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
78
- model_dir = Path(pretrained_model_name_or_path)
79
- surgery_meta_path = model_dir / "surgery_meta.json"
80
-
81
- if surgery_meta_path.exists():
82
- with open(surgery_meta_path) as f:
83
- meta = json.load(f)
84
-
85
- config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
86
- config.replaced_layers = meta.get("replaced_layers")
87
- config.surgery_variant = meta.get("variant", "conservative")
88
-
89
- model = cls(config)
90
-
91
- weights_path = model_dir / "model.pt"
92
- if weights_path.exists():
93
- state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
94
- model.inner_model.load_state_dict(state_dict)
95
-
96
- return model.float().eval()
97
-
98
- return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from transformers import AutoModel, AutoConfig, PreTrainedModel
6
+ from transformers.modeling_outputs import BaseModelOutput
7
+
8
+ from .configuration_hare import HareConfig
9
+ from .birwkv7 import BiRWKV7Layer, init_from_attention
10
+
11
+
12
+ def _find_encoder(model):
13
+ for attr in ['encoder', 'model']:
14
+ if hasattr(model, attr):
15
+ candidate = getattr(model, attr)
16
+ if hasattr(candidate, 'layers'):
17
+ return candidate
18
+ if hasattr(model, 'layers'):
19
+ return model
20
+ raise RuntimeError(f"Cannot find encoder layers in {type(model).__name__}")
21
+
22
+
23
def _perform_surgery(model, replaced_layers, hidden_size, num_heads):
    """Swap the attention module of each listed encoder layer for BiRWKV7.

    ``replaced_layers`` maps layer indices (as strings) to per-layer
    metadata; only the keys are consumed here. Layers whose attention
    module cannot be found under a known attribute name are skipped.
    """
    encoder = _find_encoder(model)
    attn_aliases = ('attn', 'attention', 'self_attn', 'self_attention')
    for idx_key in replaced_layers:
        target = encoder.layers[int(idx_key)]
        attn_name = next((n for n in attn_aliases if hasattr(target, n)), None)
        if attn_name is None:
            continue
        old_attn = getattr(target, attn_name)
        # Match the outgoing module's device/dtype before installing the
        # replacement in its place.
        reference = next(old_attn.parameters())
        replacement = BiRWKV7Layer(hidden_size, num_heads).to(
            device=reference.device, dtype=reference.dtype)
        setattr(target, attn_name, replacement)
42
+
43
+
44
class HareModel(PreTrainedModel):
    """ModernBERT backbone whose attention layers may be replaced by BiRWKV7.

    The wrapped ``inner_model`` is a stock ModernBERT built from a config
    derived from ``HareConfig``; every layer listed in
    ``config.replaced_layers`` then has its attention module swapped for a
    ``BiRWKV7Layer`` at construction time.
    """

    config_class = HareConfig

    def __init__(self, config):
        super().__init__(config)
        # Mirror the Hare config's dimensions onto a ModernBERT config so the
        # stock implementation builds an architecture of the right shape.
        base_config = AutoConfig.from_pretrained(
            "answerdotai/ModernBERT-base",
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=config.num_hidden_layers,
            intermediate_size=config.intermediate_size,
            vocab_size=config.vocab_size,
            max_position_embeddings=config.max_position_embeddings,
        )
        self.inner_model = AutoModel.from_config(base_config)

        if config.replaced_layers:
            _perform_surgery(
                self.inner_model,
                config.replaced_layers,
                config.hidden_size,
                config.num_attention_heads,
            )

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Delegate directly to the wrapped (possibly modified) inner model."""
        return self.inner_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )

    @staticmethod
    def _resolve_file(model_dir, repo_id, filename):
        """Return a local ``Path`` for *filename*, fetching from the Hub if needed.

        Looks in *model_dir* first; otherwise tries ``hf_hub_download`` with
        *repo_id*. Returns ``None`` when the file cannot be obtained (keeping
        the import inside the ``try`` so a missing ``huggingface_hub`` is
        treated the same as a failed download).
        """
        local = model_dir / filename
        if local.exists():
            return local
        try:
            from huggingface_hub import hf_hub_download
            return Path(hf_hub_download(repo_id, filename))
        except Exception:
            return None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load a Hare checkpoint, replaying attention surgery when present.

        If ``surgery_meta.json`` is found (locally or on the Hub), the config
        is augmented with the recorded layer replacements, the model is
        rebuilt, and weights are loaded from ``model.pt``. Without surgery
        metadata this falls back to the standard ``from_pretrained`` path.
        NOTE(review): extra ``kwargs`` (e.g. ``torch_dtype``) are silently
        ignored on the surgery path, matching the original behavior.
        """
        model_dir = Path(pretrained_model_name_or_path)

        meta_path = cls._resolve_file(
            model_dir, pretrained_model_name_or_path, "surgery_meta.json")
        if meta_path is None:
            # No surgery metadata anywhere: plain transformers loading.
            return super().from_pretrained(
                pretrained_model_name_or_path, *args, **kwargs)
        # Hub downloads land in a snapshot dir; sibling files live there too.
        model_dir = meta_path.parent

        with open(meta_path) as f:
            meta = json.load(f)

        config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
        config.replaced_layers = meta.get("replaced_layers")
        config.surgery_variant = meta.get("variant", "conservative")

        model = cls(config)

        weights_path = cls._resolve_file(
            model_dir, pretrained_model_name_or_path, "model.pt")
        if weights_path is not None:
            # weights_only=True keeps torch.load from unpickling arbitrary code.
            state_dict = torch.load(
                weights_path, map_location="cpu", weights_only=True)
            model.inner_model.load_state_dict(state_dict)

        return model.float().eval()