zhangshuyi.0109 committed on
Commit 56cebd7 · 1 Parent(s): 6e02de0

update citation & evaluation

Files changed (2)
  1. README.md +40 -13
  2. modeling_sarm_gemma2.py +0 -475
README.md CHANGED
@@ -1,5 +1,7 @@
  ---
- license: apache-2.0
+ license: llama3.1
+ base_model:
+ - meta-llama/Llama-3.1-8B-Instruct
  tags:
  - reward-model
  - rlhf
@@ -11,26 +13,33 @@ pipeline_tag: reinforcement-learning
 
  # SARM: Interpretable Reward Model via Sparse Autoencoder
 
- This repository provides the official implementation and model weights for the AAAI 26 Oral Paper, 'Interpretable Reward Model via Sparse Autoencoder'.
+ This repository contains the model weights of the AAAI 2026 Oral Paper "*Interpretable Reward Model via Sparse Autoencoder*".
 
- + **Authors** (\* indicates equal contribution)
+ ## 🔥 News
+ - [2025/11/8] Our paper has been accepted as an oral presentation at AAAI 2026. 🎉
+ - [2025/12/11] Llama-SARM-4B is ranked 18th on the [Reward Bench 2](https://huggingface.co/spaces/allenai/reward-bench) leaderboard, above GPT-4.1, Skywork-Reward-Llama-3.1-8B, and Claude-Sonnet-4! 🎉
+ ## 🔗 Links
+ + **Authors** (\* indicates equal contribution)
 
- Shuyi Zhang\*, Wei Shi\*, Sihang Li\*, Jiayi Liao, Tao Liang, Hengxing Cai, Xiang Wang
+ Shuyi Zhang\*, Wei Shi\*, Sihang Li\*, Jiayi Liao, Tao Liang, Hengxing Cai, Xiang Wang
 
  + **Paper**: [Interpretable Reward Model via Sparse Autoencoder](https://arxiv.org/abs/2508.08746)
 
- + **Model**: [Schrieffer/Llama-SARM-4B](https://huggingface.co/Schrieffer/Llama-SARM-4B)
-
- + Finetuned from model: [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct)
-
  + **Code Repository:** [https://github.com/schrieffer-z/sarm](https://github.com/schrieffer-z/sarm)
 
  + **Demo:** [Try SARM Demo in Huggingface Space](https://huggingface.co/spaces/Schrieffer/SARM-Demo)
 
- ## Reward Bench V2 evaluation
- \[Official results in progress\]
-
- ## SARM inference demo
+ ## 📊 Evaluation
+ Llama-SARM-4B shows competitive performance, even with a much smaller parameter count.
+ ### Reward Bench 2
+ | Rank | Model | Model Type | Score | Factuality | Precise IF | Math | Safety | Focus | Ties |
+ | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- |
+ | 18 | [**Schrieffer/Llama-SARM-4B**](https://huggingface.co/Schrieffer/Llama-SARM-4B) | Seq. Classifier | 73.79 | 68.74 | 42.81 | 64.48 | 91.78 | 95.56 | 79.39 |
+ | 22 | [openai/gpt-4.1-2025-04-14](https://huggingface.co/openai/gpt-4.1-2025-04-14) | Generative | 72.32 | 82.89 | 39.74 | 65.21 | 87.26 | 73.38 | 85.42 |
+ | 24 | [Skywork/Skywork-Reward-Llama-3.1-8B-v0.2](https://huggingface.co/Skywork/Skywork-Reward-Llama-3.1-8B-v0.2) | Seq. Classifier | 71.75 | 69.68 | 40.63 | 60.11 | 94.22 | 94.14 | 71.69 |
+ | 25 | [anthropic/claude-sonnet-4-20250514](https://huggingface.co/anthropic/claude-sonnet-4-20250514) | Generative | 71.17 | 76.12 | 35.94 | 70.49 | 89.09 | 75.96 | 79.39 |
+
+ ## SARM Inference Demo
  ```python
 
  import torch
@@ -74,4 +83,22 @@ for example in examples:
  "+example[0])
  print("Answer:
  "+example[1])
- print("Score:", get_reward_score(model, example[0],example[1]))
+ print("Score:", get_reward_score(model, example[0],example[1]))
+ ```
+
+ ## 📧 Contact
+
+ If you have any questions, please feel free to reach us at `shuyizhang@mail.ustc.edu.cn`.
+
+ ## 📚 Citation
+
+ If you find our work useful, please cite it as follows.
+
+ ```bibtex
+ @article{zhang2025interpretable,
+   title={Interpretable Reward Model via Sparse Autoencoder},
+   author={Zhang, Shuyi and Shi, Wei and Li, Sihang and Liao, Jiayi and Liang, Tao and Cai, Hengxing and Wang, Xiang},
+   journal={arXiv preprint arXiv:2508.08746},
+   year={2025}
+ }
+ ```
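Note: the hunks above elide most of the inference demo, so only fragments of the `get_reward_score` calls are visible. The sketch below is a hedged reconstruction of how a SARM checkpoint can be loaded and queried as a sequence-classifier reward model; the helper name `get_reward_score` and the model id `Schrieffer/Llama-SARM-4B` come from the README itself, while the chat-template handling and single-logit scoring are assumptions rather than the shipped demo code.

```python
# Illustrative sketch only: the official demo in README.md defines its own
# get_reward_score; this version shows one plausible implementation for a
# sequence-classifier reward model such as Schrieffer/Llama-SARM-4B.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "Schrieffer/Llama-SARM-4B"  # model id taken from the README links
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,  # the repo ships custom SARM modeling code
)

def get_reward_score(model, question: str, answer: str) -> float:
    """Score a (question, answer) pair with the reward head (assumed interface)."""
    messages = [
        {"role": "user", "content": question},
        {"role": "assistant", "content": answer},
    ]
    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # For a sequence-classification reward model the scalar reward is the
        # single logit produced for the whole sequence.
        return model(input_ids).logits[0][0].item()

example = ["What is the capital of France?", "The capital of France is Paris."]
print("Question:\n" + example[0])
print("Answer:\n" + example[1])
print("Score:", get_reward_score(model, example[0], example[1]))
```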
modeling_sarm_gemma2.py DELETED
@@ -1,475 +0,0 @@
1
- import torch
2
-
3
- from torch import nn
4
- from typing import List, Optional, Union, Tuple
5
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
6
- from transformers.models.gemma2.modeling_gemma2 import (
7
- Gemma2PreTrainedModel,
8
- Gemma2DecoderLayer,
9
- Gemma2RMSNorm
10
- )
11
- from transformers.modeling_outputs import (
12
- BaseModelOutputWithPast,
13
- SequenceClassifierOutputWithPast
14
- )
15
- from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
16
- from transformers.cache_utils import Cache
17
- from transformers.utils import logging
18
-
19
- # Local
20
- # from sae import TopkSAE, pre_process, Normalized_MSE_loss, Masked_Normalized_MSE_loss  # defined locally below
21
-
22
- logger = logging.get_logger(__name__)
23
-
24
- #==========================================================================================================================================================================
25
- #==========================================================================================================================================================================
26
-
27
- def get_last_assistant_masks(input_ids):
28
- i=len(input_ids)-4
29
- while i >= 0:
30
- if input_ids[i:i+4] == [128006, 78191, 128007, 271]:
31
- pos = i + 4
32
- break
33
- i -= 1
34
-
35
- assistant_masks = []
36
- for i in range(len(input_ids)):
37
- if i < pos:
38
- assistant_masks.append(0)
39
- else:
40
- assistant_masks.append(1)
41
-
42
- assert input_ids[-1]==128009
43
- return assistant_masks
44
-
45
- def Normalized_MSE_loss(x: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
46
- return (((x_hat - x) ** 2).mean(dim=-1) / (x**2).mean(dim=-1)).mean()
47
-
48
- def Masked_Normalized_MSE_loss(x: torch.Tensor, x_hat: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
49
- mask = mask.to(torch.bfloat16)
50
- loss = ((x_hat - x) ** 2).mean(dim=-1) / (x**2).mean(dim=-1)
51
- assert loss.shape==mask.shape
52
- seq_loss = (mask * loss).sum(-1) / (mask.sum(-1))
53
- return seq_loss.mean()
54
-
55
- def pre_process(hidden_stats: torch.Tensor, eps: float = 1e-6) -> tuple:
56
- '''
57
- :param hidden_stats: Hidden states (shape: [batch, max_length, hidden_size]).
58
- :param eps: Epsilon value for numerical stability.
59
- '''
60
- mean = hidden_stats.mean(dim=-1, keepdim=True)
61
- std = hidden_stats.std(dim=-1, keepdim=True)
62
- x = (hidden_stats - mean) / (std + eps)
63
- return x, mean, std
64
-
65
- class TopkSAE(nn.Module):
66
- '''
67
- TopK Sparse Autoencoder Implements:
68
- z = TopK(encoder(x - pre_bias) + latent_bias)
69
- x_hat = decoder(z) + pre_bias
70
- '''
71
- def __init__(
72
- self, hidden_size: int, latent_size: int, k: int
73
- ) -> None:
74
- '''
75
- :param hidden_size: Dimensionality of the input residual stream activation.
76
- :param latent_size: Number of latent units.
77
- :param k: Number of activated latents.
78
- '''
79
-
80
- # 'sae_pre_bias', 'sae_latent_bias', 'sae_encoder.weight', 'sae_decoder.weight'
81
-
82
- assert k <= latent_size, f'k should be less than or equal to {latent_size}'
83
- super(TopkSAE, self).__init__()
84
- self.pre_bias = nn.Parameter(torch.zeros(hidden_size))
85
- self.latent_bias = nn.Parameter(torch.zeros(latent_size))
86
- self.encoder = nn.Linear(hidden_size, latent_size, bias=False)
87
- self.decoder = nn.Linear(latent_size, hidden_size, bias=False)
88
-
89
- self.k = k
90
- self.latent_size = latent_size
91
- self.hidden_size = hidden_size
92
-
93
- # "tied" init
94
- # self.decoder.weight.data = self.encoder.weight.data.T.clone()
95
-
96
- def pre_acts(self, x: torch.Tensor) -> torch.Tensor:
97
- x = x - self.pre_bias
98
- return self.encoder(x) + self.latent_bias
99
-
100
- def get_latents(self, pre_acts: torch.Tensor) -> torch.Tensor:
101
- topk = torch.topk(pre_acts, self.k, dim=-1)
102
- latents = torch.zeros_like(pre_acts)
103
- latents.scatter_(-1, topk.indices, topk.values)
104
- return latents
105
-
106
- def encode(self, x: torch.Tensor) -> torch.Tensor:
107
- pre_acts = self.pre_acts(x)
108
- latents = self.get_latents(pre_acts)
109
- return latents
110
-
111
- def decode(self, latents: torch.Tensor) -> torch.Tensor:
112
- return self.decoder(latents) + self.pre_bias
113
-
114
- def forward(self, x: torch.Tensor) -> tuple:
115
- '''
116
- :param x: Input residual stream activation (shape: [batch_size, max_length, hidden_size]).
117
- :return: latents (shape: [batch_size, max_length, latent_size]).
118
- x_hat (shape: [batch_size, max_length, hidden_size]).
119
- '''
120
- latents = self.encode(x)
121
- x_hat = self.decode(latents)
122
- return latents, x_hat
123
-
124
-
125
- #==========================================================================================================================================================================
126
- #==========================================================================================================================================================================
127
-
128
-
129
- class MyGemma2Model(Gemma2PreTrainedModel):
130
- """
131
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Gemma2DecoderLayer`]
132
-
133
- Args:
134
- config: Gemma2Config
135
- """
136
-
137
- def __init__(
138
- self,
139
- config: Gemma2Config,
140
- ):
141
- sae_source_layer = config.sarm_param.get("sae_source_layer", config.num_hidden_layers // 2)
142
-
143
-
144
- super().__init__(config)
145
- self.padding_idx = config.pad_token_id
146
- self.vocab_size = config.vocab_size
147
-
148
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
149
- self.layers = nn.ModuleList(
150
- [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(sae_source_layer)]
151
- )
152
- self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
153
- self.gradient_checkpointing = False
154
-
155
- # Initialize weights and apply final processing
156
- self.post_init()
157
-
158
- def get_input_embeddings(self):
159
- return self.embed_tokens
160
-
161
- def set_input_embeddings(self, value):
162
- self.embed_tokens = value
163
-
164
- def forward(
165
- self,
166
- input_ids: torch.LongTensor = None,
167
- attention_mask: Optional[torch.Tensor] = None,
168
- position_ids: Optional[torch.LongTensor] = None,
169
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
170
- inputs_embeds: Optional[torch.FloatTensor] = None,
171
- use_cache: Optional[bool] = None,
172
- output_attentions: Optional[bool] = None,
173
- output_hidden_states: Optional[bool] = None,
174
- return_dict: Optional[bool] = None,
175
- cache_position: Optional[torch.LongTensor] = None,
176
- ) -> Union[Tuple, BaseModelOutputWithPast]:
177
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
178
- output_hidden_states = (
179
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
180
- )
181
- use_cache = use_cache if use_cache is not None else self.config.use_cache
182
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
183
-
184
- if (input_ids is None) ^ (inputs_embeds is not None):
185
- raise ValueError(
186
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
187
- )
188
-
189
- if self.gradient_checkpointing and self.training and use_cache:
190
- logger.warning_once(
191
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
192
- )
193
- use_cache = False
194
-
195
- if inputs_embeds is None:
196
- inputs_embeds = self.embed_tokens(input_ids)
197
-
198
- if cache_position is None:
199
- cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
200
-
201
- if position_ids is None:
202
- position_ids = cache_position.unsqueeze(0)
203
-
204
- causal_mask = self._update_causal_mask(
205
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
206
- )
207
-
208
- # embed positions
209
- hidden_states = inputs_embeds
210
-
211
- # normalized
212
- # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
213
- # See https://github.com/huggingface/transformers/pull/29402
214
- normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
215
- hidden_states = hidden_states * normalizer
216
-
217
- all_hidden_states = () if output_hidden_states else None
218
- all_self_attns = () if output_attentions else None
219
-
220
- for decoder_layer in self.layers:
221
- if output_hidden_states:
222
- all_hidden_states += (hidden_states,)
223
-
224
- if self.gradient_checkpointing and self.training:
225
- layer_outputs = self._gradient_checkpointing_func(
226
- decoder_layer.__call__,
227
- hidden_states,
228
- causal_mask,
229
- position_ids,
230
- past_key_values,
231
- output_attentions,
232
- use_cache,
233
- cache_position,
234
- )
235
- else:
236
- layer_outputs = decoder_layer(
237
- hidden_states,
238
- attention_mask=causal_mask,
239
- position_ids=position_ids,
240
- past_key_value=past_key_values,
241
- output_attentions=output_attentions,
242
- use_cache=use_cache,
243
- cache_position=cache_position,
244
- )
245
-
246
- hidden_states = layer_outputs[0]
247
-
248
- if output_attentions:
249
- all_self_attns += (layer_outputs[1],)
250
-
251
- # hidden_states = self.norm(hidden_states)
252
-
253
- # add hidden states from the last decoder layer
254
- if output_hidden_states:
255
- all_hidden_states += (hidden_states,)
256
-
257
- next_cache = past_key_values if use_cache else None
258
-
259
- if not return_dict:
260
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
261
- return BaseModelOutputWithPast(
262
- last_hidden_state=hidden_states,
263
- past_key_values=next_cache,
264
- hidden_states=all_hidden_states,
265
- attentions=all_self_attns,
266
- )
267
-
268
- def _update_causal_mask(
269
- self,
270
- attention_mask: torch.Tensor,
271
- input_tensor: torch.Tensor,
272
- cache_position: torch.Tensor,
273
- past_key_values: Cache,
274
- output_attentions: bool,
275
- ):
276
- if self.config._attn_implementation == "flash_attention_2":
277
- if attention_mask is not None and 0.0 in attention_mask:
278
- return attention_mask
279
- return None
280
-
281
- dtype, device = input_tensor.dtype, input_tensor.device
282
- min_dtype = torch.finfo(dtype).min
283
- sequence_length = input_tensor.shape[1]
284
- if past_key_values is not None:
285
- target_length = past_key_values.get_max_length()
286
- else:
287
- target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
288
-
289
- if attention_mask is not None and attention_mask.dim() == 4:
290
- # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
291
- if attention_mask.max() != 0:
292
- raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
293
- causal_mask = attention_mask
294
- else:
295
- causal_mask = torch.full(
296
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
297
- )
298
- if sequence_length != 1:
299
- causal_mask = torch.triu(causal_mask, diagonal=1)
300
- causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
301
- causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
302
- if attention_mask is not None:
303
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
304
- mask_length = attention_mask.shape[-1]
305
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
306
- padding_mask = padding_mask == 0
307
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
308
- padding_mask, min_dtype
309
- )
310
- return causal_mask
311
-
312
-
313
-
314
-
315
- #==========================================================================================================================================================================
316
- #==========================================================================================================================================================================
317
-
318
-
319
- class Gemma2SARM(Gemma2PreTrainedModel):
320
- def __init__(
321
- self, config, sae_hidden_state_source_layer, sae_latent_size, sae_k,
322
- sae_use_sequence_level=False,
323
- sarm_use_topk=False,
324
- sarm_train_mode=1
325
- ):
326
- super().__init__(config)
327
- self.num_labels = config.num_labels
328
- self.model = MyGemma2Model(config)
329
-
330
- self.score = nn.Linear(config.sarm_param['sae_latent_size'], self.num_labels, bias=False)
331
- self.sae = TopkSAE(hidden_size=self.model.config.hidden_size, latent_size=config.sarm_param['sae_latent_size'], k=config.sarm_param['sae_k'])
332
-
333
- self.sae_use_sequence_level = config.sarm_param['sae_use_sequence_level']
334
- self.sarm_use_topk = config.sarm_param['sarm_use_topk']
335
- self.sarm_train_mode = config.sarm_param['sarm_train_mode']
336
-
337
- if self.sarm_train_mode==1:
338
- for p in self.sae.parameters():
339
- p.requires_grad_(False)
340
-
341
- # Initialize weights and apply final processing
342
- self.post_init()
343
-
344
- def get_input_embeddings(self):
345
- return self.model.embed_tokens
346
-
347
- def set_input_embeddings(self, value):
348
- self.model.embed_tokens = value
349
-
350
- def forward(
351
- self,
352
- input_ids: torch.LongTensor = None,
353
- attention_mask: Optional[torch.Tensor] = None,
354
- assistant_masks: Optional[torch.Tensor] = None,
355
- position_ids: Optional[torch.LongTensor] = None,
356
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
357
- inputs_embeds: Optional[torch.FloatTensor] = None,
358
- labels: Optional[torch.LongTensor] = None,
359
- use_cache: Optional[bool] = None,
360
- output_attentions: Optional[bool] = None,
361
- output_hidden_states: Optional[bool] = None,
362
- return_dict: Optional[bool] = None,
363
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
364
- r"""
365
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
366
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
367
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
368
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
369
- """
370
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
371
-
372
- transformer_outputs = self.model(
373
- input_ids,
374
- attention_mask=attention_mask,
375
- position_ids=position_ids,
376
- past_key_values=past_key_values,
377
- inputs_embeds=inputs_embeds,
378
- use_cache=use_cache,
379
- output_attentions=output_attentions,
380
- output_hidden_states=output_hidden_states,
381
- return_dict=return_dict,
382
- )
383
- hidden_states = transformer_outputs[0]
384
-
385
- h, _, _ = pre_process(hidden_states)
386
- sae_features = self.sae.pre_acts(h)
387
- if self.sarm_use_topk:
388
- sae_features = self.sae.get_latents(sae_features)
389
-
390
- logits = self.score(sae_features)
391
-
392
- if input_ids is not None:
393
- batch_size = input_ids.shape[0]
394
- else:
395
- batch_size = inputs_embeds.shape[0]
396
-
397
- if self.config.pad_token_id is None and batch_size != 1:
398
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
399
- if self.config.pad_token_id is None:
400
- sequence_lengths = -1
401
- else:
402
- if input_ids is not None:
403
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
404
- sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
405
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
406
- sequence_lengths = sequence_lengths.to(logits.device)
407
- else:
408
- sequence_lengths = -1
409
-
410
- # ensure last_token is <|eot_id|>
411
- assert ((input_ids[torch.arange(batch_size, device=logits.device), sequence_lengths]!=torch.ones(batch_size, device=logits.device)*128009).sum() == 0).item()
412
-
413
- # joint training
414
- rec_loss = None
415
- if self.sarm_train_mode==2:
416
- if not self.sarm_use_topk:
417
- sae_features_t = self.sae.get_latents(sae_features)
418
- h_hat = self.sae.decode(sae_features_t)
419
- rec_loss = Masked_Normalized_MSE_loss(h, h_hat, assistant_masks)
420
- elif self.sarm_train_mode==3 and not self.sae_use_sequence_level:
421
- h_d = h.detach()
422
- _, h_hat = self.sae(h_d)
423
- rec_loss = Masked_Normalized_MSE_loss(h_d, h_hat, assistant_masks)
424
- elif self.sarm_train_mode==3 and self.sae_use_sequence_level:
425
- h_d = h.detach()
426
- sequence_lengths_t = sequence_lengths.view(-1,1,1)
427
- last_token_mask = torch.zeros([h_d.shape[0] ,1 ,h_d.shape[1]], device=h_d.device)
428
- last_token_mask.scatter_(-1, sequence_lengths_t, torch.ones_like(sequence_lengths_t, dtype=last_token_mask.dtype))
429
-
430
- # h_d -> (bs, seq_len, d), last_token_mask -> (bs, 1, seq_len)
431
- h_d = torch.matmul(last_token_mask.to(h_d.dtype), h_d)
432
-
433
- _, h_hat = self.sae(h_d)
434
- rec_loss = Normalized_MSE_loss(h_d, h_hat)
435
-
436
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
437
-
438
- loss = None
439
- if labels is not None:
440
- labels = labels.to(logits.device)
441
- if self.config.problem_type is None:
442
- if self.num_labels == 1:
443
- self.config.problem_type = "regression"
444
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
445
- self.config.problem_type = "single_label_classification"
446
- else:
447
- self.config.problem_type = "multi_label_classification"
448
-
449
- if self.config.problem_type == "regression":
450
- loss_fct = MSELoss()
451
- if self.num_labels == 1:
452
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
453
- else:
454
- loss = loss_fct(pooled_logits, labels)
455
- elif self.config.problem_type == "single_label_classification":
456
- loss_fct = CrossEntropyLoss()
457
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
458
- elif self.config.problem_type == "multi_label_classification":
459
- loss_fct = BCEWithLogitsLoss()
460
- loss = loss_fct(pooled_logits, labels)
461
-
462
- if rec_loss is not None:
463
- loss = rec_loss
464
-
465
- if not return_dict:
466
- output = (pooled_logits,) + transformer_outputs[1:]
467
- return ((loss,) + output) if loss is not None else output
468
-
469
- return SequenceClassifierOutputWithPast(
470
- loss=loss,
471
- logits=pooled_logits,
472
- past_key_values=transformer_outputs.past_key_values,
473
- hidden_states=transformer_outputs.hidden_states,
474
- attentions=transformer_outputs.attentions,
475
- )
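For readers skimming the deleted file, here is a condensed, self-contained sketch of the scoring path that `Gemma2SARM.forward` implements: standardize the residual-stream hidden states, project them through the SAE encoder (optionally sparsified with TopK, as when `sarm_use_topk` is set), and apply the linear `score` head to the last token's latent features. The toy sizes and random tensors below are illustrative only, not the model's real dimensions.

```python
# Condensed sketch of the SARM scoring path from the deleted file above,
# run on random tensors so it is self-contained. Sizes are illustrative.
import torch
from torch import nn

hidden_size, latent_size, k, seq_len = 16, 64, 8, 5

# Stand-ins for the SAE encoder parameters and the reward head defined above.
pre_bias = torch.zeros(hidden_size)
latent_bias = torch.zeros(latent_size)
encoder = nn.Linear(hidden_size, latent_size, bias=False)
score = nn.Linear(latent_size, 1, bias=False)

hidden_states = torch.randn(1, seq_len, hidden_size)  # residual stream from the source layer

# pre_process: per-token standardization of the hidden states
mean = hidden_states.mean(dim=-1, keepdim=True)
std = hidden_states.std(dim=-1, keepdim=True)
h = (hidden_states - mean) / (std + 1e-6)

# SAE pre-activations, then TopK sparsification (TopkSAE.get_latents)
pre_acts = encoder(h - pre_bias) + latent_bias
topk = torch.topk(pre_acts, k, dim=-1)
latents = torch.zeros_like(pre_acts).scatter_(-1, topk.indices, topk.values)

# Reward = linear head applied to the SAE features of the last (non-pad) token
reward = score(latents)[0, -1, 0]
print(float(reward))
```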