Schrieffer2sy committed on

Commit a641af6 · 1 Parent(s): be1a05c

add remote code
README.md CHANGED
@@ -15,10 +15,59 @@ tags:

 + **Paper**: [Interpretable Reward Model via Sparse Autoencoder](https://arxiv.org/abs/2508.08746)

- + **Model**: [schrieffer/SARM-4B](https://huggingface.co/schrieffer/SARM-4B)
+ + **Model**: [schrieffer/SARM-4B](https://huggingface.co/schrieffer/Llama-SARM-4B)

 + Finetuned from model: [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct)

 + **Code Repository:** [https://github.com/schrieffer-z/sarm](https://github.com/schrieffer-z/sarm)

- + **Demo:** [Try SARM Demo in Huggingface Space](https://huggingface.co/spaces/Schrieffer/SARM-Demo)
+ + **Demo:** [Try SARM Demo in Huggingface Space](https://huggingface.co/spaces/Schrieffer/SARM-Demo)
+
+ ```python
+ import torch
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+ def get_reward_score(prompt: str, response: str) -> float:
+     """
+     Receives a prompt and a response, and returns the reward score calculated by the SARM model.
+     """
+     # Use the same chat template as used during model training.
+     messages = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response}]
+     # Move the tokenized conversation onto the same device as the model.
+     input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
+
+     with torch.no_grad():
+         score = model(input_ids).logits.item()
+
+     return round(score, 4)
+
+ device = "cuda"
+ path = "Schrieffer/Llama-SARM-4B"
+
+ tokenizer = AutoTokenizer.from_pretrained(path)
+ model = AutoModelForSequenceClassification.from_pretrained(
+     path,
+     device_map=device,
+     trust_remote_code=True,
+     torch_dtype=torch.bfloat16
+ )
+
+ examples = [
+     ["What is the capital of France?", "The capital of France is Paris."],
+     ["What is the capital of France?", "Berlin is a large city in Germany."],
+     ["Write a short poem about the moon.", "Silver orb in velvet night, / Casting shadows, soft and light. / Silent watcher, distant, bright, / Guiding dreams till morning's light."],
+     ["Write a short poem about the moon.", "The moon is a rock."]
+ ]
+
+ for example in examples:
+     print(" example ".center(80, "="))
+     print("Question:\n" + example[0])
+     print("Answer:\n" + example[1])
+     print("Score:", get_reward_score(example[0], example[1]))
+ ```
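
Because the reward head reads the score off sparse-autoencoder latents, the features behind a given score can also be inspected. The sketch below is not part of the shipped README: it reuses the `model` and `tokenizer` objects from the snippet above, and the `model.model` / `model.sae` attributes and normalization steps are assumptions taken from `modeling_sarm_llama.py` rather than a documented API.

```python
# Hedged sketch: list the most active SAE latents at the scoring position (<|eot_id|>),
# mirroring the normalization and encoding steps in LlamaSARM.forward.
def get_top_latents(prompt: str, response: str, top_n: int = 10):
    messages = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response}]
    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
    with torch.no_grad():
        hidden = model.model(input_ids).last_hidden_state    # truncated Llama backbone (layers up to sae_source_layer)
        normed = (hidden - hidden.mean(-1, keepdim=True)) / (hidden.std(-1, keepdim=True) + 1e-6)
        latents = model.sae.encode(normed)[0, -1]             # sparse latent vector at the last token
    values, indices = latents.topk(top_n)
    return list(zip(indices.tolist(), values.tolist()))

print(get_top_latents("What is the capital of France?", "The capital of France is Paris."))
```

Mapping latent indices to human-readable concepts is done offline, as described in the paper and the code repository.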
config.json CHANGED
@@ -4,12 +4,16 @@
   ],
   "attention_bias": false,
   "attention_dropout": 0.0,
+  "auto_map": {
+    "AutoModelForSequenceClassification": "modeling_sarm_llama.LlamaSARM"
+  },
   "bos_token_id": 128000,
   "eos_token_id": [
     128001,
     128008,
     128009
   ],
+  "head_dim": 128,
   "hidden_act": "silu",
   "hidden_size": 4096,
   "id2label": {
@@ -37,9 +41,17 @@
     "rope_type": "llama3"
   },
   "rope_theta": 500000.0,
+  "sarm_param": {
+    "sae_k": 192,
+    "sae_latent_size": 65536,
+    "sae_source_layer": 16,
+    "sae_use_sequence_level": false,
+    "sarm_train_mode": 1,
+    "sarm_use_topk": true
+  },
   "tie_word_embeddings": false,
   "torch_dtype": "bfloat16",
-  "transformers_version": "4.43.4",
+  "transformers_version": "4.51.0",
   "use_cache": false,
   "vocab_size": 128257
 }
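
In the updated config, `auto_map` is what lets `AutoModelForSequenceClassification.from_pretrained(..., trust_remote_code=True)` resolve to the custom `LlamaSARM` class, while `sarm_param` records the SAE geometry: 65,536 latents, 192 of them kept active per token, fed from the hidden states of decoder layer 16. A minimal sketch of reading these fields back, assuming `AutoConfig` keeps the unrecognized `sarm_param` key as a plain attribute:

```python
# Hedged sketch: inspect the SARM-specific fields stored in config.json.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Schrieffer/Llama-SARM-4B")
sarm = config.sarm_param           # assumption: extra keys survive as a dict attribute
print(sarm["sae_latent_size"])     # 65536 sparse-autoencoder latents
print(sarm["sae_k"])               # 192 latents kept per token by the TopK activation
print(sarm["sae_source_layer"])    # hidden states taken after decoder layer 16
print(config.auto_map)             # {'AutoModelForSequenceClassification': 'modeling_sarm_llama.LlamaSARM'}
```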
modeling_sarm_gemma2.py ADDED
@@ -0,0 +1,475 @@
1
+ import torch
2
+
3
+ from torch import nn
4
+ from typing import List, Optional, Union, Tuple
5
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
6
+ from transformers.models.gemma2.modeling_gemma2 import (
7
+ Gemma2PreTrainedModel,
8
+ Gemma2DecoderLayer,
9
+ Gemma2RMSNorm
10
+ )
11
+ from transformers.modeling_outputs import (
12
+ BaseModelOutputWithPast,
13
+ SequenceClassifierOutputWithPast
14
+ )
15
+ from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
16
+ from transformers.cache_utils import Cache
17
+ from transformers.utils import logging
18
+
19
+ # Local: TopkSAE, pre_process, Normalized_MSE_loss and Masked_Normalized_MSE_loss are defined below,
20
+ # so no external `sae` module needs to be imported.
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ #==========================================================================================================================================================================
25
+ #==========================================================================================================================================================================
26
+
27
+ def get_last_assistant_masks(input_ids):
28
+ i=len(input_ids)-4
29
+ while i >= 0:
30
+ if input_ids[i:i+4] == [128006, 78191, 128007, 271]:
31
+ pos = i + 4
32
+ break
33
+ i -= 1
34
+
35
+ assistant_masks = []
36
+ for i in range(len(input_ids)):
37
+ if i < pos:
38
+ assistant_masks.append(0)
39
+ else:
40
+ assistant_masks.append(1)
41
+
42
+ assert input_ids[-1]==128009
43
+ return assistant_masks
44
+
45
+ def Normalized_MSE_loss(x: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
46
+ return (((x_hat - x) ** 2).mean(dim=-1) / (x**2).mean(dim=-1)).mean()
47
+
48
+ def Masked_Normalized_MSE_loss(x: torch.Tensor, x_hat: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
49
+ mask = mask.to(torch.bfloat16)
50
+ loss = ((x_hat - x) ** 2).mean(dim=-1) / (x**2).mean(dim=-1)
51
+ assert loss.shape==mask.shape
52
+ seq_loss = (mask * loss).sum(-1) / (mask.sum(-1))
53
+ return seq_loss.mean()
54
+
55
+ def pre_process(hidden_stats: torch.Tensor, eps: float = 1e-6) -> tuple:
56
+ '''
57
+ :param hidden_stats: Hidden states (shape: [batch, max_length, hidden_size]).
58
+ :param eps: Epsilon value for numerical stability.
59
+ '''
60
+ mean = hidden_stats.mean(dim=-1, keepdim=True)
61
+ std = hidden_stats.std(dim=-1, keepdim=True)
62
+ x = (hidden_stats - mean) / (std + eps)
63
+ return x, mean, std
64
+
65
+ class TopkSAE(nn.Module):
66
+ '''
67
+ TopK Sparse Autoencoder Implements:
68
+ z = TopK(encoder(x - pre_bias) + latent_bias)
69
+ x_hat = decoder(z) + pre_bias
70
+ '''
71
+ def __init__(
72
+ self, hidden_size: int, latent_size: int, k: int
73
+ ) -> None:
74
+ '''
75
+ :param hidden_size: Dimensionality of the input residual stream activation.
76
+ :param latent_size: Number of latent units.
77
+ :param k: Number of activated latents.
78
+ '''
79
+
80
+ # 'sae_pre_bias', 'sae_latent_bias', 'sae_encoder.weight', 'sae_decoder.weight'
81
+
82
+ assert k <= latent_size, f'k should be less than or equal to {latent_size}'
83
+ super(TopkSAE, self).__init__()
84
+ self.pre_bias = nn.Parameter(torch.zeros(hidden_size))
85
+ self.latent_bias = nn.Parameter(torch.zeros(latent_size))
86
+ self.encoder = nn.Linear(hidden_size, latent_size, bias=False)
87
+ self.decoder = nn.Linear(latent_size, hidden_size, bias=False)
88
+
89
+ self.k = k
90
+ self.latent_size = latent_size
91
+ self.hidden_size = hidden_size
92
+
93
+ # "tied" init
94
+ # self.decoder.weight.data = self.encoder.weight.data.T.clone()
95
+
96
+ def pre_acts(self, x: torch.Tensor) -> torch.Tensor:
97
+ x = x - self.pre_bias
98
+ return self.encoder(x) + self.latent_bias
99
+
100
+ def get_latents(self, pre_acts: torch.Tensor) -> torch.Tensor:
101
+ topk = torch.topk(pre_acts, self.k, dim=-1)
102
+ latents = torch.zeros_like(pre_acts)
103
+ latents.scatter_(-1, topk.indices, topk.values)
104
+ return latents
105
+
106
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
107
+ pre_acts = self.pre_acts(x)
108
+ latents = self.get_latents(pre_acts)
109
+ return latents
110
+
111
+ def decode(self, latents: torch.Tensor) -> torch.Tensor:
112
+ return self.decoder(latents) + self.pre_bias
113
+
114
+ def forward(self, x: torch.Tensor) -> tuple:
115
+ '''
116
+ :param x: Input residual stream activation (shape: [batch_size, max_length, hidden_size]).
117
+ :return: latents (shape: [batch_size, max_length, latent_size]).
118
+ x_hat (shape: [batch_size, max_length, hidden_size]).
119
+ '''
120
+ latents = self.encode(x)
121
+ x_hat = self.decode(latents)
122
+ return latents, x_hat
123
+
124
+
125
+ #==========================================================================================================================================================================
126
+ #==========================================================================================================================================================================
127
+
128
+
129
+ class MyGemma2Model(Gemma2PreTrainedModel):
130
+ """
131
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Gemma2DecoderLayer`]
132
+
133
+ Args:
134
+ config: Gemma2Config
135
+ """
136
+
137
+ def __init__(
138
+ self,
139
+ config: Gemma2Config,
140
+ ):
141
+ sae_source_layer = config.sarm_param.get("sae_source_layer", config.num_hidden_layers // 2)
142
+
143
+
144
+ super().__init__(config)
145
+ self.padding_idx = config.pad_token_id
146
+ self.vocab_size = config.vocab_size
147
+
148
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
149
+ self.layers = nn.ModuleList(
150
+ [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(sae_source_layer)]
151
+ )
152
+ self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
153
+ self.gradient_checkpointing = False
154
+
155
+ # Initialize weights and apply final processing
156
+ self.post_init()
157
+
158
+ def get_input_embeddings(self):
159
+ return self.embed_tokens
160
+
161
+ def set_input_embeddings(self, value):
162
+ self.embed_tokens = value
163
+
164
+ def forward(
165
+ self,
166
+ input_ids: torch.LongTensor = None,
167
+ attention_mask: Optional[torch.Tensor] = None,
168
+ position_ids: Optional[torch.LongTensor] = None,
169
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
170
+ inputs_embeds: Optional[torch.FloatTensor] = None,
171
+ use_cache: Optional[bool] = None,
172
+ output_attentions: Optional[bool] = None,
173
+ output_hidden_states: Optional[bool] = None,
174
+ return_dict: Optional[bool] = None,
175
+ cache_position: Optional[torch.LongTensor] = None,
176
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
177
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
178
+ output_hidden_states = (
179
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
180
+ )
181
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
182
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
183
+
184
+ if (input_ids is None) ^ (inputs_embeds is not None):
185
+ raise ValueError(
186
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
187
+ )
188
+
189
+ if self.gradient_checkpointing and self.training and use_cache:
190
+ logger.warning_once(
191
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
192
+ )
193
+ use_cache = False
194
+
195
+ if inputs_embeds is None:
196
+ inputs_embeds = self.embed_tokens(input_ids)
197
+
198
+ if cache_position is None:
199
+ cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
200
+
201
+ if position_ids is None:
202
+ position_ids = cache_position.unsqueeze(0)
203
+
204
+ causal_mask = self._update_causal_mask(
205
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
206
+ )
207
+
208
+ # embed positions
209
+ hidden_states = inputs_embeds
210
+
211
+ # normalized
212
+ # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
213
+ # See https://github.com/huggingface/transformers/pull/29402
214
+ normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
215
+ hidden_states = hidden_states * normalizer
216
+
217
+ all_hidden_states = () if output_hidden_states else None
218
+ all_self_attns = () if output_attentions else None
219
+
220
+ for decoder_layer in self.layers:
221
+ if output_hidden_states:
222
+ all_hidden_states += (hidden_states,)
223
+
224
+ if self.gradient_checkpointing and self.training:
225
+ layer_outputs = self._gradient_checkpointing_func(
226
+ decoder_layer.__call__,
227
+ hidden_states,
228
+ causal_mask,
229
+ position_ids,
230
+ past_key_values,
231
+ output_attentions,
232
+ use_cache,
233
+ cache_position,
234
+ )
235
+ else:
236
+ layer_outputs = decoder_layer(
237
+ hidden_states,
238
+ attention_mask=causal_mask,
239
+ position_ids=position_ids,
240
+ past_key_value=past_key_values,
241
+ output_attentions=output_attentions,
242
+ use_cache=use_cache,
243
+ cache_position=cache_position,
244
+ )
245
+
246
+ hidden_states = layer_outputs[0]
247
+
248
+ if output_attentions:
249
+ all_self_attns += (layer_outputs[1],)
250
+
251
+ # hidden_states = self.norm(hidden_states)
252
+
253
+ # add hidden states from the last decoder layer
254
+ if output_hidden_states:
255
+ all_hidden_states += (hidden_states,)
256
+
257
+ next_cache = past_key_values if use_cache else None
258
+
259
+ if not return_dict:
260
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
261
+ return BaseModelOutputWithPast(
262
+ last_hidden_state=hidden_states,
263
+ past_key_values=next_cache,
264
+ hidden_states=all_hidden_states,
265
+ attentions=all_self_attns,
266
+ )
267
+
268
+ def _update_causal_mask(
269
+ self,
270
+ attention_mask: torch.Tensor,
271
+ input_tensor: torch.Tensor,
272
+ cache_position: torch.Tensor,
273
+ past_key_values: Cache,
274
+ output_attentions: bool,
275
+ ):
276
+ if self.config._attn_implementation == "flash_attention_2":
277
+ if attention_mask is not None and 0.0 in attention_mask:
278
+ return attention_mask
279
+ return None
280
+
281
+ dtype, device = input_tensor.dtype, input_tensor.device
282
+ min_dtype = torch.finfo(dtype).min
283
+ sequence_length = input_tensor.shape[1]
284
+ if past_key_values is not None:
285
+ target_length = past_key_values.get_max_length()
286
+ else:
287
+ target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
288
+
289
+ if attention_mask is not None and attention_mask.dim() == 4:
290
+ # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
291
+ if attention_mask.max() != 0:
292
+ raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
293
+ causal_mask = attention_mask
294
+ else:
295
+ causal_mask = torch.full(
296
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
297
+ )
298
+ if sequence_length != 1:
299
+ causal_mask = torch.triu(causal_mask, diagonal=1)
300
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
301
+ causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
302
+ if attention_mask is not None:
303
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
304
+ mask_length = attention_mask.shape[-1]
305
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
306
+ padding_mask = padding_mask == 0
307
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
308
+ padding_mask, min_dtype
309
+ )
310
+ return causal_mask
311
+
312
+
313
+
314
+
315
+ #==========================================================================================================================================================================
316
+ #==========================================================================================================================================================================
317
+
318
+
319
+ class Gemma2SARM(Gemma2PreTrainedModel):
320
+ def __init__(
321
+ self, config, sae_hidden_state_source_layer, sae_latent_size, sae_k,
322
+ sae_use_sequence_level=False,
323
+ sarm_use_topk=False,
324
+ sarm_train_mode=1
325
+ ):
326
+ super().__init__(config)
327
+ self.num_labels = config.num_labels
328
+ self.model = MyGemma2Model(config)
329
+
330
+ self.score = nn.Linear(config.sarm_param['sae_latent_size'], self.num_labels, bias=False)
331
+ self.sae = TopkSAE(hidden_size=self.model.config.hidden_size, latent_size=config.sarm_param['sae_latent_size'], k=config.sarm_param['sae_k'])
332
+
333
+ self.sae_use_sequence_level = config.sarm_param['sae_use_sequence_level']
334
+ self.sarm_use_topk = config.sarm_param['sarm_use_topk']
335
+ self.sarm_train_mode = config.sarm_param['sarm_train_mode']
336
+
337
+ if self.sarm_train_mode==1:
338
+ for p in self.sae.parameters():
339
+ p.requires_grad_(False)
340
+
341
+ # Initialize weights and apply final processing
342
+ self.post_init()
343
+
344
+ def get_input_embeddings(self):
345
+ return self.model.embed_tokens
346
+
347
+ def set_input_embeddings(self, value):
348
+ self.model.embed_tokens = value
349
+
350
+ def forward(
351
+ self,
352
+ input_ids: torch.LongTensor = None,
353
+ attention_mask: Optional[torch.Tensor] = None,
354
+ assistant_masks: Optional[torch.Tensor] = None,
355
+ position_ids: Optional[torch.LongTensor] = None,
356
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
357
+ inputs_embeds: Optional[torch.FloatTensor] = None,
358
+ labels: Optional[torch.LongTensor] = None,
359
+ use_cache: Optional[bool] = None,
360
+ output_attentions: Optional[bool] = None,
361
+ output_hidden_states: Optional[bool] = None,
362
+ return_dict: Optional[bool] = None,
363
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
364
+ r"""
365
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
366
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
367
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
368
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
369
+ """
370
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
371
+
372
+ transformer_outputs = self.model(
373
+ input_ids,
374
+ attention_mask=attention_mask,
375
+ position_ids=position_ids,
376
+ past_key_values=past_key_values,
377
+ inputs_embeds=inputs_embeds,
378
+ use_cache=use_cache,
379
+ output_attentions=output_attentions,
380
+ output_hidden_states=output_hidden_states,
381
+ return_dict=return_dict,
382
+ )
383
+ hidden_states = transformer_outputs[0]
384
+
385
+ h, _, _ = pre_process(hidden_states)
386
+ sae_features = self.sae.pre_acts(h)
387
+ if self.sarm_use_topk:
388
+ sae_features = self.sae.get_latents(sae_features)
389
+
390
+ logits = self.score(sae_features)
391
+
392
+ if input_ids is not None:
393
+ batch_size = input_ids.shape[0]
394
+ else:
395
+ batch_size = inputs_embeds.shape[0]
396
+
397
+ if self.config.pad_token_id is None and batch_size != 1:
398
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
399
+ if self.config.pad_token_id is None:
400
+ sequence_lengths = -1
401
+ else:
402
+ if input_ids is not None:
403
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
404
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
405
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
406
+ sequence_lengths = sequence_lengths.to(logits.device)
407
+ else:
408
+ sequence_lengths = -1
409
+
410
+ # ensure last_token is <|eot_id|>
411
+ assert ((input_ids[torch.arange(batch_size, device=logits.device), sequence_lengths]!=torch.ones(batch_size, device=logits.device)*128009).sum() == 0).item()
412
+
413
+ # joint training
414
+ rec_loss = None
415
+ if self.sarm_train_mode==2:
416
+ if not self.sarm_use_topk:
417
+ sae_features_t = self.sae.get_latents(sae_features)
418
+ h_hat = self.sae.decode(sae_features_t)
419
+ rec_loss = Masked_Normalized_MSE_loss(h, h_hat, assistant_masks)
420
+ elif self.sarm_train_mode==3 and not self.sae_use_sequence_level:
421
+ h_d = h.detach()
422
+ _, h_hat = self.sae(h_d)
423
+ rec_loss = Masked_Normalized_MSE_loss(h_d, h_hat, assistant_masks)
424
+ elif self.sarm_train_mode==3 and self.sae_use_sequence_level:
425
+ h_d = h.detach()
426
+ sequence_lengths_t = sequence_lengths.view(-1,1,1)
427
+ last_token_mask = torch.zeros([h_d.shape[0] ,1 ,h_d.shape[1]], device=h_d.device)
428
+ last_token_mask.scatter_(-1, sequence_lengths_t, torch.ones_like(sequence_lengths_t, dtype=last_token_mask.dtype))
429
+
430
+ # h_d -> (bs, seq_len, d), last_token_mask -> (bs, 1, seq_len)
431
+ h_d = torch.matmul(last_token_mask.to(h_d.dtype), h_d)
432
+
433
+ _, h_hat = self.sae(h_d)
434
+ rec_loss = Normalized_MSE_loss(h_d, h_hat)
435
+
436
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
437
+
438
+ loss = None
439
+ if labels is not None:
440
+ labels = labels.to(logits.device)
441
+ if self.config.problem_type is None:
442
+ if self.num_labels == 1:
443
+ self.config.problem_type = "regression"
444
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
445
+ self.config.problem_type = "single_label_classification"
446
+ else:
447
+ self.config.problem_type = "multi_label_classification"
448
+
449
+ if self.config.problem_type == "regression":
450
+ loss_fct = MSELoss()
451
+ if self.num_labels == 1:
452
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
453
+ else:
454
+ loss = loss_fct(pooled_logits, labels)
455
+ elif self.config.problem_type == "single_label_classification":
456
+ loss_fct = CrossEntropyLoss()
457
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
458
+ elif self.config.problem_type == "multi_label_classification":
459
+ loss_fct = BCEWithLogitsLoss()
460
+ loss = loss_fct(pooled_logits, labels)
461
+
462
+ if rec_loss is not None:
463
+ loss = rec_loss
464
+
465
+ if not return_dict:
466
+ output = (pooled_logits,) + transformer_outputs[1:]
467
+ return ((loss,) + output) if loss is not None else output
468
+
469
+ return SequenceClassifierOutputWithPast(
470
+ loss=loss,
471
+ logits=pooled_logits,
472
+ past_key_values=transformer_outputs.past_key_values,
473
+ hidden_states=transformer_outputs.hidden_states,
474
+ attentions=transformer_outputs.attentions,
475
+ )
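
`TopkSAE` above is the interpretability core shared by both SARM variants: token hidden states are standardized by `pre_process` and encoded so that only `k` latents remain non-zero per position. A tiny self-contained sketch of that behaviour, assuming `modeling_sarm_llama.py` (which defines an identical class) has been downloaded and is importable; the released checkpoint uses `hidden_size=4096`, `latent_size=65536`, `k=192`:

```python
# Hedged sketch: verify the TopK sparsity on toy dimensions.
import torch
from modeling_sarm_llama import TopkSAE, pre_process

sae = TopkSAE(hidden_size=8, latent_size=32, k=4)
hidden = torch.randn(2, 5, 8)              # (batch, seq_len, hidden_size)
normed, _, _ = pre_process(hidden)         # per-token standardization, as in the model
latents, reconstruction = sae(normed)

print(latents.shape)                       # torch.Size([2, 5, 32])
print((latents != 0).sum(dim=-1))          # at most k=4 active latents per position
print(reconstruction.shape)                # torch.Size([2, 5, 8])
```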
modeling_sarm_llama.py ADDED
@@ -0,0 +1,547 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from typing import List, Optional, Union, Tuple
5
+ from transformers import LlamaConfig
6
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
7
+ from transformers.utils import logging
8
+ from transformers.modeling_outputs import (
9
+ SequenceClassifierOutputWithPast,
10
+ BaseModelOutputWithPast
11
+ )
12
+ from transformers.models.llama.modeling_llama import (
13
+ LlamaDecoderLayer,
14
+ LlamaRMSNorm,
15
+ LlamaRotaryEmbedding,
16
+ LlamaPreTrainedModel
17
+ )
18
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ #==========================================================================================================================================================================
24
+ #==========================================================================================================================================================================
25
+ def get_last_assistant_masks(input_ids):
26
+ i=len(input_ids)-4
27
+ while i >= 0:
28
+ if input_ids[i:i+4] == [128006, 78191, 128007, 271]:
29
+ pos = i + 4
30
+ break
31
+ i -= 1
32
+
33
+ assistant_masks = []
34
+ for i in range(len(input_ids)):
35
+ if i < pos:
36
+ assistant_masks.append(0)
37
+ else:
38
+ assistant_masks.append(1)
39
+
40
+ assert input_ids[-1]==128009
41
+ return assistant_masks
42
+
43
+ def Normalized_MSE_loss(x: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
44
+ return (((x_hat - x) ** 2).mean(dim=-1) / (x**2).mean(dim=-1)).mean()
45
+
46
+ def Masked_Normalized_MSE_loss(x: torch.Tensor, x_hat: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
47
+ mask = mask.to(torch.bfloat16)
48
+ loss = ((x_hat - x) ** 2).mean(dim=-1) / (x**2).mean(dim=-1)
49
+ assert loss.shape==mask.shape
50
+ seq_loss = (mask * loss).sum(-1) / (mask.sum(-1))
51
+ return seq_loss.mean()
52
+
53
+ def pre_process(hidden_stats: torch.Tensor, eps: float = 1e-6) -> tuple:
54
+ '''
55
+ :param hidden_stats: Hidden states (shape: [batch, max_length, hidden_size]).
56
+ :param eps: Epsilon value for numerical stability.
57
+ '''
58
+ mean = hidden_stats.mean(dim=-1, keepdim=True)
59
+ std = hidden_stats.std(dim=-1, keepdim=True)
60
+ x = (hidden_stats - mean) / (std + eps)
61
+ return x, mean, std
62
+
63
+ class TopkSAE(nn.Module):
64
+ '''
65
+ TopK Sparse Autoencoder Implements:
66
+ z = TopK(encoder(x - pre_bias) + latent_bias)
67
+ x_hat = decoder(z) + pre_bias
68
+ '''
69
+ def __init__(
70
+ self, hidden_size: int, latent_size: int, k: int
71
+ ) -> None:
72
+ '''
73
+ :param hidden_size: Dimensionality of the input residual stream activation.
74
+ :param latent_size: Number of latent units.
75
+ :param k: Number of activated latents.
76
+ '''
77
+
78
+ # 'sae_pre_bias', 'sae_latent_bias', 'sae_encoder.weight', 'sae_decoder.weight'
79
+
80
+ assert k <= latent_size, f'k should be less than or equal to {latent_size}'
81
+ super(TopkSAE, self).__init__()
82
+ self.pre_bias = nn.Parameter(torch.zeros(hidden_size))
83
+ self.latent_bias = nn.Parameter(torch.zeros(latent_size))
84
+ self.encoder = nn.Linear(hidden_size, latent_size, bias=False)
85
+ self.decoder = nn.Linear(latent_size, hidden_size, bias=False)
86
+
87
+ self.k = k
88
+ self.latent_size = latent_size
89
+ self.hidden_size = hidden_size
90
+
91
+ # "tied" init
92
+ # self.decoder.weight.data = self.encoder.weight.data.T.clone()
93
+
94
+ def pre_acts(self, x: torch.Tensor) -> torch.Tensor:
95
+ x = x - self.pre_bias
96
+ return self.encoder(x) + self.latent_bias
97
+
98
+ def get_latents(self, pre_acts: torch.Tensor) -> torch.Tensor:
99
+ topk = torch.topk(pre_acts, self.k, dim=-1)
100
+ latents = torch.zeros_like(pre_acts)
101
+ latents.scatter_(-1, topk.indices, topk.values)
102
+ return latents
103
+
104
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
105
+ pre_acts = self.pre_acts(x)
106
+ latents = self.get_latents(pre_acts)
107
+ return latents
108
+
109
+ def decode(self, latents: torch.Tensor) -> torch.Tensor:
110
+ return self.decoder(latents) + self.pre_bias
111
+
112
+ def forward(self, x: torch.Tensor) -> tuple:
113
+ '''
114
+ :param x: Input residual stream activation (shape: [batch_size, max_length, hidden_size]).
115
+ :return: latents (shape: [batch_size, max_length, latent_size]).
116
+ x_hat (shape: [batch_size, max_length, hidden_size]).
117
+ '''
118
+ latents = self.encode(x)
119
+ x_hat = self.decode(latents)
120
+ return latents, x_hat
121
+
122
+
123
+ #==========================================================================================================================================================================
124
+ #==========================================================================================================================================================================
125
+ class MyLlamaModel(LlamaPreTrainedModel):
126
+ def __init__(
127
+ self,
128
+ config: LlamaConfig,
129
+ ):
130
+ sae_source_layer = config.sarm_param.get("sae_source_layer", config.num_hidden_layers // 2)
131
+
132
+ super().__init__(config)
133
+ self.padding_idx = config.pad_token_id
134
+ self.vocab_size = config.vocab_size
135
+
136
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
137
+ self.layers = nn.ModuleList(
138
+ [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(sae_source_layer)]
139
+ )
140
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
141
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
142
+ self.gradient_checkpointing = False
143
+ if getattr(config, "pretraining_tp", 1) != 1:
144
+ logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
145
+
146
+ # Initialize weights and apply final processing
147
+ self.post_init()
148
+
149
+ def get_input_embeddings(self):
150
+ return self.embed_tokens
151
+
152
+ def set_input_embeddings(self, value):
153
+ self.embed_tokens = value
154
+
155
+ def forward(
156
+ self,
157
+ input_ids: torch.LongTensor = None,
158
+ attention_mask: Optional[torch.Tensor] = None,
159
+ position_ids: Optional[torch.LongTensor] = None,
160
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
161
+ inputs_embeds: Optional[torch.FloatTensor] = None,
162
+ use_cache: Optional[bool] = None,
163
+ output_attentions: Optional[bool] = None,
164
+ output_hidden_states: Optional[bool] = None,
165
+ return_dict: Optional[bool] = None,
166
+ cache_position: Optional[torch.LongTensor] = None,
167
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
168
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
169
+ output_hidden_states = (
170
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
171
+ )
172
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
173
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
174
+
175
+ if (input_ids is None) ^ (inputs_embeds is not None):
176
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
177
+
178
+ if self.gradient_checkpointing and self.training and use_cache:
179
+ logger.warning_once(
180
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
181
+ )
182
+ use_cache = False
183
+
184
+ if inputs_embeds is None:
185
+ inputs_embeds = self.embed_tokens(input_ids)
186
+
187
+ # kept for BC (non `Cache` `past_key_values` inputs)
188
+ return_legacy_cache = False
189
+ if use_cache and not isinstance(past_key_values, Cache):
190
+ return_legacy_cache = True
191
+ if past_key_values is None:
192
+ past_key_values = DynamicCache()
193
+ else:
194
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
195
+ logger.warning_once(
196
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
197
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
198
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
199
+ )
200
+
201
+ if cache_position is None:
202
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
203
+ cache_position = torch.arange(
204
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
205
+ )
206
+ if position_ids is None:
207
+ position_ids = cache_position.unsqueeze(0)
208
+
209
+ causal_mask = self._update_causal_mask(
210
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
211
+ )
212
+ hidden_states = inputs_embeds
213
+
214
+ # create position embeddings to be shared across the decoder layers
215
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
216
+
217
+ # decoder layers
218
+ all_hidden_states = () if output_hidden_states else None
219
+ all_self_attns = () if output_attentions else None
220
+ next_decoder_cache = None
221
+
222
+
223
+ for decoder_layer in self.layers:
224
+ if output_hidden_states:
225
+ all_hidden_states += (hidden_states,)
226
+
227
+ if self.gradient_checkpointing and self.training:
228
+ layer_outputs = self._gradient_checkpointing_func(
229
+ decoder_layer.__call__,
230
+ hidden_states,
231
+ causal_mask,
232
+ position_ids,
233
+ past_key_values,
234
+ output_attentions,
235
+ use_cache,
236
+ cache_position,
237
+ position_embeddings,
238
+ )
239
+ else:
240
+ layer_outputs = decoder_layer(
241
+ hidden_states,
242
+ attention_mask=causal_mask,
243
+ position_ids=position_ids,
244
+ past_key_value=past_key_values,
245
+ output_attentions=output_attentions,
246
+ use_cache=use_cache,
247
+ cache_position=cache_position,
248
+ position_embeddings=position_embeddings,
249
+ )
250
+
251
+ hidden_states = layer_outputs[0]
252
+
253
+ if use_cache:
254
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
255
+
256
+ if output_attentions:
257
+ all_self_attns += (layer_outputs[1],)
258
+
259
+ # hidden_states = self.norm(hidden_states)
260
+
261
+ # add hidden states from the last decoder layer
262
+ if output_hidden_states:
263
+ all_hidden_states += (hidden_states,)
264
+
265
+ next_cache = next_decoder_cache if use_cache else None
266
+ if return_legacy_cache:
267
+ next_cache = next_cache.to_legacy_cache()
268
+
269
+ if not return_dict:
270
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
271
+ return BaseModelOutputWithPast(
272
+ last_hidden_state=hidden_states,
273
+ past_key_values=next_cache,
274
+ hidden_states=all_hidden_states,
275
+ attentions=all_self_attns,
276
+ )
277
+
278
+ def _update_causal_mask(
279
+ self,
280
+ attention_mask: torch.Tensor,
281
+ input_tensor: torch.Tensor,
282
+ cache_position: torch.Tensor,
283
+ past_key_values: Cache,
284
+ output_attentions: bool,
285
+ ):
286
+ if self.config._attn_implementation == "flash_attention_2":
287
+ if attention_mask is not None and 0.0 in attention_mask:
288
+ return attention_mask
289
+ return None
290
+
291
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
292
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
293
+ # to infer the attention mask.
294
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
295
+ using_static_cache = isinstance(past_key_values, StaticCache)
296
+
297
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
298
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
299
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
300
+ attention_mask,
301
+ inputs_embeds=input_tensor,
302
+ past_key_values_length=past_seen_tokens,
303
+ is_training=self.training,
304
+ ):
305
+ return None
306
+
307
+ dtype, device = input_tensor.dtype, input_tensor.device
308
+ sequence_length = input_tensor.shape[1]
309
+ if using_static_cache:
310
+ target_length = past_key_values.get_max_cache_shape()
311
+ else:
312
+ target_length = (
313
+ attention_mask.shape[-1]
314
+ if isinstance(attention_mask, torch.Tensor)
315
+ else past_seen_tokens + sequence_length + 1
316
+ )
317
+
318
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
319
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
320
+ attention_mask,
321
+ sequence_length=sequence_length,
322
+ target_length=target_length,
323
+ dtype=dtype,
324
+ device=device,
325
+ cache_position=cache_position,
326
+ batch_size=input_tensor.shape[0],
327
+ )
328
+
329
+ if (
330
+ self.config._attn_implementation == "sdpa"
331
+ and attention_mask is not None
332
+ and attention_mask.device.type == "cuda"
333
+ and not output_attentions
334
+ ):
335
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
336
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
337
+ # Details: https://github.com/pytorch/pytorch/issues/110213
338
+ min_dtype = torch.finfo(dtype).min
339
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
340
+
341
+ return causal_mask
342
+
343
+ @staticmethod
344
+ def _prepare_4d_causal_attention_mask_with_cache_position(
345
+ attention_mask: torch.Tensor,
346
+ sequence_length: int,
347
+ target_length: int,
348
+ dtype: torch.dtype,
349
+ device: torch.device,
350
+ cache_position: torch.Tensor,
351
+ batch_size: int,
352
+ **kwargs,
353
+ ):
354
+ """
355
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
356
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
357
+
358
+ Args:
359
+ attention_mask (`torch.Tensor`):
360
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
361
+ `(batch_size, 1, query_length, key_value_length)`.
362
+ sequence_length (`int`):
363
+ The sequence length being processed.
364
+ target_length (`int`):
365
+ The target length: when generating with static cache, the mask should be as long as the static cache,
366
+ to account for the 0 padding, the part of the cache that is not filled yet.
367
+ dtype (`torch.dtype`):
368
+ The dtype to use for the 4D attention mask.
369
+ device (`torch.device`):
370
+ The device to place the 4D attention mask on.
371
+ cache_position (`torch.Tensor`):
372
+ Indices depicting the position of the input sequence tokens in the sequence.
373
+ batch_size (`torch.Tensor`):
374
+ Batch size.
375
+ """
376
+ if attention_mask is not None and attention_mask.dim() == 4:
377
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
378
+ causal_mask = attention_mask
379
+ else:
380
+ min_dtype = torch.finfo(dtype).min
381
+ causal_mask = torch.full(
382
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
383
+ )
384
+ if sequence_length != 1:
385
+ causal_mask = torch.triu(causal_mask, diagonal=1)
386
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
387
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
388
+ if attention_mask is not None:
389
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
390
+ mask_length = attention_mask.shape[-1]
391
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
392
+ padding_mask = padding_mask == 0
393
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
394
+ padding_mask, min_dtype
395
+ )
396
+
397
+ return causal_mask
398
+
399
+
400
+
401
+
402
+ #==========================================================================================================================================================================
403
+ #============================ Adapted from LlamaForSequenceClassification into the SAE-based reward model (SARM) form ============================
404
+ #==========================================================================================================================================================================
405
+
406
+
407
+ class LlamaSARM(LlamaPreTrainedModel):
408
+ def __init__(
409
+ self, config
410
+ ):
411
+ super().__init__(config)
412
+ self.num_labels = config.num_labels
413
+ self.model = MyLlamaModel(config)
414
+
415
+ self.score = nn.Linear(config.sarm_param['sae_latent_size'], self.num_labels, bias=False)
416
+ self.sae = TopkSAE(hidden_size=self.model.config.hidden_size, latent_size=config.sarm_param['sae_latent_size'], k=config.sarm_param['sae_k'])
417
+
418
+ self.sae_use_sequence_level = config.sarm_param['sae_use_sequence_level']
419
+ self.sarm_use_topk = config.sarm_param['sarm_use_topk']
420
+ self.sarm_train_mode = config.sarm_param['sarm_train_mode']
421
+
422
+ if self.sarm_train_mode==0:
423
+ for p in self.model.parameters():
424
+ p.requires_grad_(False)
425
+ if self.sarm_train_mode==0 or self.sarm_train_mode==1:
426
+ for p in self.sae.parameters():
427
+ p.requires_grad_(False)
428
+
429
+ # Initialize weights and apply final processing
430
+ self.post_init()
431
+
432
+
433
+ def get_input_embeddings(self):
434
+ return self.model.embed_tokens
435
+
436
+ def set_input_embeddings(self, value):
437
+ self.model.embed_tokens = value
438
+
439
+
440
+ def forward(
441
+ self,
442
+ input_ids: Optional[torch.LongTensor] = None,
443
+ attention_mask: Optional[torch.Tensor] = None,
444
+ assistant_masks: Optional[torch.Tensor] = None,
445
+ position_ids: Optional[torch.LongTensor] = None,
446
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
447
+ inputs_embeds: Optional[torch.FloatTensor] = None,
448
+ labels: Optional[torch.LongTensor] = None,
449
+ use_cache: Optional[bool] = None,
450
+ output_attentions: Optional[bool] = None,
451
+ output_hidden_states: Optional[bool] = None,
452
+ return_dict: Optional[bool] = None,
453
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
454
+ r"""
455
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
456
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
457
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
458
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
459
+ """
460
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
461
+
462
+ transformer_outputs = self.model(
463
+ input_ids,
464
+ attention_mask=attention_mask,
465
+ position_ids=position_ids,
466
+ past_key_values=past_key_values,
467
+ inputs_embeds=inputs_embeds,
468
+ use_cache=use_cache,
469
+ output_attentions=output_attentions,
470
+ output_hidden_states=output_hidden_states,
471
+ return_dict=return_dict,
472
+ )
473
+ hidden_states = transformer_outputs[0]
474
+
475
+
476
+ h, _, _ = pre_process(hidden_states)
477
+ sae_features = self.sae.pre_acts(h)
478
+ if self.sarm_use_topk:
479
+ sae_features = self.sae.get_latents(sae_features)
480
+
481
+
482
+ logits = self.score(sae_features)
483
+
484
+ if input_ids is not None:
485
+ batch_size = input_ids.shape[0]
486
+ else:
487
+ batch_size = inputs_embeds.shape[0]
488
+
489
+ if self.config.pad_token_id is None and batch_size != 1:
490
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
491
+ if self.config.pad_token_id is None:
492
+ sequence_lengths = -1
493
+ else:
494
+ if input_ids is not None:
495
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
496
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
497
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
498
+ sequence_lengths = sequence_lengths.to(logits.device)
499
+ else:
500
+ sequence_lengths = -1
501
+ # ensure last_token is <|eot_id|>
502
+ assert ((input_ids[torch.arange(batch_size, device=logits.device), sequence_lengths]!=torch.ones(batch_size, device=logits.device)*128009).sum() == 0).item()
503
+
504
+ # joint training
505
+ rec_loss = None
506
+ if self.sarm_train_mode==2:
507
+ if not self.sarm_use_topk:
508
+ sae_features_t = self.sae.get_latents(sae_features)
509
+ h_hat = self.sae.decode(sae_features_t)
510
+ rec_loss = Masked_Normalized_MSE_loss(h, h_hat, assistant_masks)
511
+ elif self.sarm_train_mode==3 and not self.sae_use_sequence_level:
512
+ h_d = h.detach()
513
+ _, h_hat = self.sae(h_d)
514
+ rec_loss = Masked_Normalized_MSE_loss(h_d, h_hat, assistant_masks)
515
+ elif self.sarm_train_mode==3 and self.sae_use_sequence_level:
516
+ h_d = h.detach()
517
+ sequence_lengths_t = sequence_lengths.view(-1,1,1)
518
+ last_token_mask = torch.zeros([h_d.shape[0] ,1 ,h_d.shape[1]], device=h_d.device)
519
+ last_token_mask.scatter_(-1, sequence_lengths_t, torch.ones_like(sequence_lengths_t, dtype=last_token_mask.dtype))
520
+
521
+ # h_d -> (bs, seq_len, d), last_token_mask -> (bs, 1, seq_len)
522
+ h_d = torch.matmul(last_token_mask.to(h_d.dtype), h_d)
523
+
524
+ _, h_hat = self.sae(h_d)
525
+ rec_loss = Normalized_MSE_loss(h_d, h_hat)
526
+
527
+
528
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
529
+
530
+
531
+ loss = None
532
+ if labels is not None:
533
+ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
534
+ if rec_loss is not None:
535
+ loss = rec_loss
536
+
537
+ if not return_dict:
538
+ output = (pooled_logits,) + transformer_outputs[1:]
539
+ return ((loss,) + output) if loss is not None else output
540
+
541
+ return SequenceClassifierOutputWithPast(
542
+ loss=loss,
543
+ logits=pooled_logits,
544
+ past_key_values=transformer_outputs.past_key_values,
545
+ hidden_states=transformer_outputs.hidden_states,
546
+ attentions=transformer_outputs.attentions,
547
+ )
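
Both `Gemma2SARM` and `LlamaSARM` pool the reward logit at the final `<|eot_id|>` position, located with the ONNX-friendly pad-token trick in `forward`. A toy illustration of that indexing, with `pad_token_id=0` standing in for the real pad id:

```python
# Hedged sketch of the sequence-length computation used to pool the reward logit.
import torch

pad_token_id = 0
input_ids = torch.tensor([
    [11, 12, 13, 128009, 0, 0],     # right-padded; <|eot_id|> at index 3
    [21, 22, 23, 24, 25, 128009],   # unpadded; <|eot_id|> at index 5
])

# index of the first pad token minus one; the modulo keeps it valid when there is no padding
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)                 # tensor([3, 5])

batch_size = input_ids.shape[0]
logits = torch.randn(batch_size, input_ids.shape[-1], 1)
pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
print(pooled_logits.shape)              # torch.Size([2, 1])
```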
tokenizer_config.json CHANGED
@@ -2061,11 +2061,12 @@
2061
  "chat_template": "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n",
2062
  "clean_up_tokenization_spaces": true,
2063
  "eos_token": "<|eot_id|>",
2064
  "model_input_names": [
2065
  "input_ids",
2066
  "attention_mask"
2067
  ],
2068
  "model_max_length": 4096,
2069
  "pad_token": "[PAD]",
2070
- "tokenizer_class": "PreTrainedTokenizerFast"
2071
  }
 
2061
  "chat_template": "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n",
2062
  "clean_up_tokenization_spaces": true,
2063
  "eos_token": "<|eot_id|>",
2064
+ "extra_special_tokens": {},
2065
  "model_input_names": [
2066
  "input_ids",
2067
  "attention_mask"
2068
  ],
2069
  "model_max_length": 4096,
2070
  "pad_token": "[PAD]",
2071
+ "tokenizer_class": "PreTrainedTokenizer"
2072
  }
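
The updated tokenizer config keeps the Llama-3 chat template, a `[PAD]` pad token, and a 4096-token `model_max_length`. A quick sketch to confirm that a rendered conversation ends with the `<|eot_id|>` token (id 128009) that the reward head scores; the exact outputs are assumptions based on the template and fields above:

```python
# Hedged sketch: check the chat template terminates the assistant turn with <|eot_id|>.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Schrieffer/Llama-SARM-4B")
messages = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."},
]
ids = tokenizer.apply_chat_template(messages)
print(tokenizer.convert_ids_to_tokens(ids[-1]))         # expected: '<|eot_id|>'
print(tokenizer.pad_token, tokenizer.model_max_length)  # expected: [PAD] 4096
```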