annnli commited on
Commit
da73d19
·
verified ·
1 Parent(s): 4991f82

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_roberta_cl.py +389 -46
modeling_roberta_cl.py CHANGED
@@ -2,10 +2,36 @@ import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  import torch.distributed as dist
 
5
 
6
  import transformers
7
- from transformers.models.roberta.modeling_roberta import RobertaForSequenceClassification, RobertaClassificationHead
8
- from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  class ResidualBlock(nn.Module):
11
  def __init__(self, dim):
@@ -72,6 +98,70 @@ class RobertaClassificationHeadForEmbedding(RobertaClassificationHead):
72
  # x = self.dropout(x)
73
  # x = self.out_proj(x)
74
  return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  def cl_init(cls, config):
77
  """
@@ -104,6 +194,8 @@ def cl_forward(cls,
104
  output_attentions=None,
105
  output_hidden_states=None,
106
  return_dict=None,
 
 
107
  latter_sentiment_spoof_mask=None,
108
  ):
109
  return_dict = return_dict if return_dict is not None else cls.config.use_return_dict
@@ -112,29 +204,97 @@ def cl_forward(cls,
112
  # original + cls.model_args.num_paraphrased + cls.model_args.num_negative
113
  num_sent = input_ids.size(1)
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  # Flatten input for encoding
116
  input_ids = input_ids.view((-1, input_ids.size(-1))) # (bs * num_sent, len)
117
  attention_mask = attention_mask.view((-1, attention_mask.size(-1))) # (bs * num_sent len)
118
  if token_type_ids is not None:
119
  token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1))) # (bs * num_sent, len)
120
 
121
- # Get raw embeddings
122
- outputs = cls.roberta(
123
- input_ids,
124
- attention_mask=attention_mask,
125
- token_type_ids=token_type_ids,
126
- position_ids=position_ids,
127
- head_mask=head_mask,
128
- inputs_embeds=inputs_embeds,
129
- output_attentions=output_attentions,
130
- output_hidden_states=False,
131
- return_dict=True,
132
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
- # Pooling
135
- sequence_output = outputs[0] # (bs*num_sent, seq_len, hidden)
136
- pooler_output = cls.classifier(sequence_output) # (bs*num_sent, hidden)
137
- pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1))) # (bs, num_sent, hidden)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  # Mapping
140
  pooler_output = cls.map(pooler_output) # (bs, num_sent, hidden_states)
@@ -150,6 +310,11 @@ def cl_forward(cls,
150
  # Gather all embeddings if using distributed training
151
  if dist.is_initialized() and cls.training:
152
  raise NotImplementedError
 
 
 
 
 
153
 
154
  # get sign value before calculating similarity
155
  original = torch.tanh(original * 1000)
@@ -160,21 +325,61 @@ def cl_forward(cls,
160
  for cname, n in zip(spoofing_cnames, negative_list):
161
  negative_dict[cname] = n
162
 
163
- # Calculate triplet loss
164
- loss_triplet = 0
165
- for i in range(batch_size):
166
- for j in range(cls.model_args.num_paraphrased):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  for cname in spoofing_cnames:
168
  if cname == 'latter_sentiment_spoof_0' and latter_sentiment_spoof_mask[i] == 0:
169
  continue
170
- ori = original[i]
171
- pos = paraphrase_list[j][i]
172
- neg = negative_dict[cname][i]
173
- loss_triplet += F.relu(cls.sim(ori, neg) * cls.model_args.temp - cls.sim(ori, pos) * cls.model_args.temp + cls.model_args.margin)
174
- loss_triplet /= (batch_size * cls.model_args.num_paraphrased * len(spoofing_cnames))
 
 
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  # Calculate loss for uniform perturbation and unbiased token preference
177
  def sign_loss(x):
 
178
  row = torch.abs(torch.mean(torch.mean(x, dim=0)))
179
  col = torch.abs(torch.mean(torch.mean(x, dim=1)))
180
  return (row + col)/2
@@ -185,6 +390,8 @@ def cl_forward(cls,
185
  loss_3_list = [cls.sim(original, p).unsqueeze(1) for p in paraphrase_list] # [(bs, 1)] * num_paraphrased
186
  loss_3_tensor = torch.cat(loss_3_list, dim=1) # (bs, num_paraphrased)
187
  loss_3 = loss_3_tensor.mean() * cls.model_args.temp
 
 
188
 
189
  # calculate loss_sent: similarity between original and sentiment spoofed text
190
  negative_sample_loss = {}
@@ -202,7 +409,14 @@ def cl_forward(cls,
202
  ori_ori_cos_removed = remove_diagonal_elements(ori_ori_cos) # (bs, bs-1)
203
  loss_5 = ori_ori_cos_removed.mean() * cls.model_args.temp
204
 
205
- loss = loss_gr + loss_triplet
 
 
 
 
 
 
 
206
 
207
  result = {
208
  'loss': loss,
@@ -217,7 +431,10 @@ def cl_forward(cls,
217
  key = f"sim_{cname.replace('_spoof_0', '')}"
218
  result[key] = l
219
 
220
- result['loss_tl'] = loss_triplet
 
 
 
221
 
222
  if not return_dict:
223
  raise NotImplementedError
@@ -238,23 +455,60 @@ def sentemb_forward(
238
  output_attentions=None,
239
  output_hidden_states=None,
240
  return_dict=None,
 
 
241
  ):
242
 
243
  return_dict = return_dict if return_dict is not None else cls.config.use_return_dict
244
 
245
- outputs = cls.roberta(
246
- input_ids,
247
- attention_mask=attention_mask,
248
- token_type_ids=token_type_ids,
249
- position_ids=position_ids,
250
- head_mask=head_mask,
251
- inputs_embeds=inputs_embeds,
252
- output_attentions=output_attentions,
253
- output_hidden_states=False,
254
- return_dict=True,
255
- )
256
- sequence_output = outputs[0]
257
- pooler_output = cls.classifier(sequence_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
  # Mapping
260
  mapping_output = cls.map(pooler_output)
@@ -276,18 +530,103 @@ class RobertaForCL(RobertaForSequenceClassification):
276
 
277
  def __init__(self, config, *model_args, **model_kargs):
278
  super().__init__(config)
279
- self.model_args = model_kargs.get("model_args", None)
280
 
281
  self.classifier = RobertaClassificationHeadForEmbedding(config)
282
 
283
- if self.model_args:
284
- cl_init(self, config)
285
 
286
  self.map = SemanticModel(input_dim=768)
287
-
 
 
 
 
 
 
 
 
288
  # Initialize weights and apply final processing
289
  self.post_init()
290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  def forward(self,
292
  input_ids=None,
293
  attention_mask=None,
@@ -300,6 +639,8 @@ class RobertaForCL(RobertaForSequenceClassification):
300
  output_hidden_states=None,
301
  return_dict=None,
302
  sent_emb=False,
 
 
303
  latter_sentiment_spoof_mask=None,
304
  ):
305
  if sent_emb:
@@ -327,6 +668,8 @@ class RobertaForCL(RobertaForSequenceClassification):
327
  output_attentions=output_attentions,
328
  output_hidden_states=output_hidden_states,
329
  return_dict=return_dict,
 
 
330
  latter_sentiment_spoof_mask=latter_sentiment_spoof_mask,
331
  )
332
 
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  import torch.distributed as dist
5
+ from torch import Tensor
6
 
7
  import transformers
8
+ from transformers import RobertaTokenizer
9
+ from transformers.models.roberta.modeling_roberta import RobertaForSequenceClassification, RobertaClassificationHead, RobertaPreTrainedModel, RobertaModel, RobertaLMHead
10
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2PreTrainedModel, Qwen2Model
11
+ from transformers.activations import gelu
12
+ from transformers.file_utils import (
13
+ add_code_sample_docstrings,
14
+ add_start_docstrings,
15
+ add_start_docstrings_to_model_forward,
16
+ replace_return_docstrings,
17
+ )
18
+ from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions
19
+
20
+ class MLPLayer(nn.Module):
21
+ """
22
+ Head for getting sentence representations over RoBERTa/BERT's CLS representation.
23
+ """
24
+
25
+ def __init__(self, config):
26
+ super().__init__()
27
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
28
+ self.activation = nn.Tanh()
29
+
30
+ def forward(self, features, **kwargs):
31
+ x = self.dense(features)
32
+ x = self.activation(x)
33
+
34
+ return x
35
 
36
  class ResidualBlock(nn.Module):
37
  def __init__(self, dim):
 
98
  # x = self.dropout(x)
99
  # x = self.out_proj(x)
100
  return x
101
+
102
+
103
+ class QueryHead(nn.Module):
104
+ def __init__(self, hidden_size):
105
+ super(QueryHead, self).__init__()
106
+ # Learnable query vector
107
+ self.query = nn.Parameter(torch.randn(hidden_size))
108
+
109
+ def forward(self, hidden_states, attention_mask=None):
110
+ """
111
+ Args:
112
+ hidden_states: Tensor of shape (batch_size, seq_length, hidden_size)
113
+ attention_mask: Tensor of shape (batch_size, seq_length) with 1 for real tokens and 0 for padding tokens.
114
+ Returns:
115
+ sequence_embedding: Tensor of shape (batch_size, hidden_size)
116
+ """
117
+ # Compute raw attention scores
118
+ attention_scores = torch.matmul(hidden_states, self.query) # (batch_size, seq_length)
119
+
120
+ # Apply attention mask (set padding positions to large negative value before softmax)
121
+ if attention_mask is not None:
122
+ attention_scores = attention_scores.masked_fill(attention_mask == 0, -1e4)
123
+
124
+ # Normalize attention scores
125
+ attention_weights = F.softmax(attention_scores, dim=1) # (batch_size, seq_length)
126
+
127
+ # Aggregate hidden states
128
+ sequence_embedding = torch.matmul(attention_weights.unsqueeze(1), hidden_states).squeeze(1) # (batch_size, hidden_size)
129
+
130
+ return sequence_embedding
131
+
132
+
133
+ class AttentionPooling(nn.Module):
134
+ def __init__(self, hidden_dim):
135
+ super().__init__()
136
+ self.key_proj = nn.Linear(hidden_dim, hidden_dim, bias=False) # Key matrix W_K
137
+ self.value_proj = nn.Linear(hidden_dim, hidden_dim, bias=False) # Value matrix W_V
138
+ self.query = nn.Parameter(torch.randn(hidden_dim)) # Learnable query vector
139
+
140
+ def forward(self, x, attention_mask=None):
141
+ """
142
+ Args:
143
+ x: Tensor of shape (B, L, H), the last hidden layer output.
144
+ attention_mask: Tensor of shape (B, L) with 1 for real tokens and 0 for padding tokens.
145
+ Returns:
146
+ pooled_output: Tensor of shape (B, H), the pooled sequence embedding.
147
+ """
148
+ K = self.key_proj(x) # (B, L, H)
149
+ V = self.value_proj(x) # (B, L, H)
150
+
151
+ # Compute attention scores
152
+ attn_scores = torch.matmul(K, self.query) / (K.shape[-1] ** 0.5) # (B, L)
153
+
154
+ # Apply attention mask (set padding tokens to large negative value)
155
+ if attention_mask is not None:
156
+ attn_scores = attn_scores.masked_fill(attention_mask == 0, -1e4)
157
+
158
+ attn_weights = F.softmax(attn_scores, dim=1) # (B, L)
159
+
160
+ # Weighted sum of values
161
+ pooled_output = torch.matmul(attn_weights.unsqueeze(1), V).squeeze(1) # (B, H)
162
+ # pooled_output = torch.sum(attn_weights.unsqueeze(-1) * V, dim=1) # (B, H)
163
+
164
+ return pooled_output
165
 
166
  def cl_init(cls, config):
167
  """
 
194
  output_attentions=None,
195
  output_hidden_states=None,
196
  return_dict=None,
197
+ mlm_input_ids=None,
198
+ mlm_labels=None,
199
  latter_sentiment_spoof_mask=None,
200
  ):
201
  return_dict = return_dict if return_dict is not None else cls.config.use_return_dict
 
204
  # original + cls.model_args.num_paraphrased + cls.model_args.num_negative
205
  num_sent = input_ids.size(1)
206
 
207
+ # # input_ids: (bs, num_sent, len)
208
+ # # random downsample one paraphrased sentence from sentences index in [1, cls.model_args.num_paraphrased-1]
209
+ # # randomly generate one index from [1, cls.model_args.num_paraphrased-1]
210
+ # # exclude tensor [:, index, :] from input_ids
211
+ # paraphrased_idx = torch.randint(1, cls.model_args.num_paraphrased, (batch_size,))
212
+ # mask = torch.ones_like(input_ids, dtype=torch.bool)
213
+ # for i in range(batch_size):
214
+ # mask[i, paraphrased_idx[i], :] = False
215
+ # input_ids = input_ids[mask].view(batch_size, num_sent - 1, -1)
216
+ # attention_mask = attention_mask[mask].view(batch_size, num_sent - 1, -1)
217
+ # num_paraphrased = cls.model_args.num_paraphrased - 1
218
+ # num_sent -= 1
219
+ # if token_type_ids is not None:
220
+ # token_type_ids = token_type_ids[mask].view(batch_size, num_sent - 1, -1)
221
+
222
+ mlm_outputs = None
223
  # Flatten input for encoding
224
  input_ids = input_ids.view((-1, input_ids.size(-1))) # (bs * num_sent, len)
225
  attention_mask = attention_mask.view((-1, attention_mask.size(-1))) # (bs * num_sent len)
226
  if token_type_ids is not None:
227
  token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1))) # (bs * num_sent, len)
228
 
229
+ if 'roberta' in cls.model_args.model_name_or_path:
230
+ # Get raw embeddings
231
+ outputs = cls.roberta(
232
+ input_ids,
233
+ attention_mask=attention_mask,
234
+ token_type_ids=token_type_ids,
235
+ position_ids=position_ids,
236
+ head_mask=head_mask,
237
+ inputs_embeds=inputs_embeds,
238
+ output_attentions=output_attentions,
239
+ output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False,
240
+ return_dict=True,
241
+ )
242
+
243
+ # MLM auxiliary objective
244
+ if mlm_input_ids is not None:
245
+ mlm_input_ids = mlm_input_ids.view((-1, mlm_input_ids.size(-1)))
246
+ mlm_outputs = cls.roberta(
247
+ mlm_input_ids,
248
+ attention_mask=attention_mask,
249
+ token_type_ids=token_type_ids,
250
+ position_ids=position_ids,
251
+ head_mask=head_mask,
252
+ inputs_embeds=inputs_embeds,
253
+ output_attentions=output_attentions,
254
+ output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False,
255
+ return_dict=True,
256
+ )
257
 
258
+ # Pooling
259
+ sequence_output = outputs[0] # (bs*num_sent, seq_len, hidden)
260
+ pooler_output = cls.classifier(sequence_output) # (bs*num_sent, hidden)
261
+ pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1))) # (bs, num_sent, hidden)
262
+
263
+ elif 'qwen2' in cls.model_args.model_name_or_path.lower():
264
+ def last_token_pool(last_hidden_states: Tensor,
265
+ attention_mask: Tensor) -> Tensor:
266
+ left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
267
+ if left_padding:
268
+ return last_hidden_states[:, -1]
269
+ else:
270
+ sequence_lengths = attention_mask.sum(dim=1) - 1
271
+ batch_size = last_hidden_states.shape[0]
272
+ return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
273
+
274
+ outputs = cls.model(
275
+ input_ids,
276
+ attention_mask=attention_mask,
277
+ token_type_ids=token_type_ids,
278
+ position_ids=position_ids,
279
+ head_mask=head_mask,
280
+ inputs_embeds=inputs_embeds,
281
+ output_attentions=output_attentions,
282
+ output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False,
283
+ return_dict=True,
284
+ )
285
+
286
+ if cls.model_args.pooler_type in ['query', 'attention']:
287
+ pooler_output = cls.pool(outputs.last_hidden_state, attention_mask)
288
+ elif cls.model_args.pooler_type == 'last':
289
+ pooler_output = last_token_pool(outputs.last_hidden_state, attention_mask)
290
+ else:
291
+ raise NotImplementedError
292
+ # normalize embeddings
293
+ pooler_output = F.normalize(pooler_output, p=2, dim=1)
294
+
295
+ pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1))) # (bs, num_sent, hidden_states)
296
+ else:
297
+ raise NotImplementedError
298
 
299
  # Mapping
300
  pooler_output = cls.map(pooler_output) # (bs, num_sent, hidden_states)
 
310
  # Gather all embeddings if using distributed training
311
  if dist.is_initialized() and cls.training:
312
  raise NotImplementedError
313
+
314
+ # straight-through estimate sign function
315
+ def sign_ste(x):
316
+ x_nogradient = x.detach()
317
+ return x + x.sign() - x_nogradient
318
 
319
  # get sign value before calculating similarity
320
  original = torch.tanh(original * 1000)
 
325
  for cname, n in zip(spoofing_cnames, negative_list):
326
  negative_dict[cname] = n
327
 
328
+ # z1 = sign_ste(z1)
329
+ # z2_list = [sign_ste(z2) for z2 in z2_list]
330
+ # z3_list = [sign_ste(z3) for z3 in z3_list]
331
+
332
+ # Compute contrastive loss
333
+ if cls.model_args.cl_weight != 0:
334
+ negative_weight = cls.model_args.hard_negative_weight
335
+ ori_ori_cos = cls.sim(original.unsqueeze(1), original.unsqueeze(0)) # (bs, bs)
336
+ ori_ori_cos_removed = remove_diagonal_elements(ori_ori_cos) # (bs, bs-1)
337
+ ori_para_cos_list = [cls.sim(original, p).unsqueeze(1) for p in paraphrase_list] # [(bs, 1)] * num_paraphrased
338
+ ori_neg_cos_list = [cls.sim(original, n).unsqueeze(1) for n in negative_list] # [(bs,1)] * num_negative
339
+ ori_neg_cos_dict = {}
340
+ for cname, n in zip(spoofing_cnames, ori_neg_cos_list):
341
+ ori_neg_cos_dict[cname] = n
342
+
343
+ loss_cl = 0
344
+ for i in range(batch_size):
345
+ ori = ori_ori_cos_removed[i].sum()
346
+ neg = 0
347
  for cname in spoofing_cnames:
348
  if cname == 'latter_sentiment_spoof_0' and latter_sentiment_spoof_mask[i] == 0:
349
  continue
350
+ neg += ori_neg_cos_dict[cname][i]
351
+ for j in range(cls.model_args.num_paraphrased):
352
+ pos = ori_para_cos_list[j][i]
353
+ denominator = ori + pos + negative_weight * neg
354
+ fraction = pos / (ori + pos + negative_weight * neg)
355
+ loss_cl -= torch.log(fraction)
356
+ loss_cl /= (batch_size * cls.model_args.num_paraphrased)
357
 
358
+ # Calculate triplet loss
359
+ if cls.model_args.tl_weight != 0:
360
+ loss_triplet = 0
361
+ for i in range(batch_size):
362
+ for j in range(cls.model_args.num_paraphrased):
363
+ for cname in spoofing_cnames:
364
+ if cname == 'latter_sentiment_spoof_0' and latter_sentiment_spoof_mask[i] == 0:
365
+ continue
366
+ ori = original[i]
367
+ pos = paraphrase_list[j][i]
368
+ neg = negative_dict[cname][i]
369
+ loss_triplet += F.relu(cls.sim(ori, neg) * cls.model_args.temp - cls.sim(ori, pos) * cls.model_args.temp + cls.model_args.margin)
370
+ loss_triplet /= (batch_size * cls.model_args.num_paraphrased * len(spoofing_cnames))
371
+
372
+ # Calculate loss for MLM
373
+ if mlm_outputs is not None and mlm_labels is not None:
374
+ raise NotImplementedError
375
+ # mlm_labels = mlm_labels.view(-1, mlm_labels.size(-1))
376
+ # prediction_scores = cls.lm_head(mlm_outputs.last_hidden_state)
377
+ # masked_lm_loss = loss_fct(prediction_scores.view(-1, cls.config.vocab_size), mlm_labels.view(-1))
378
+ # loss_cl = loss_cl + cls.model_args.mlm_weight * masked_lm_loss
379
+
380
  # Calculate loss for uniform perturbation and unbiased token preference
381
  def sign_loss(x):
382
+ # smooth_sign = sign_ste(x)
383
  row = torch.abs(torch.mean(torch.mean(x, dim=0)))
384
  col = torch.abs(torch.mean(torch.mean(x, dim=1)))
385
  return (row + col)/2
 
390
  loss_3_list = [cls.sim(original, p).unsqueeze(1) for p in paraphrase_list] # [(bs, 1)] * num_paraphrased
391
  loss_3_tensor = torch.cat(loss_3_list, dim=1) # (bs, num_paraphrased)
392
  loss_3 = loss_3_tensor.mean() * cls.model_args.temp
393
+ # debug:
394
+ # loss_3 = loss_3[valid_for_loss3.bool()]
395
 
396
  # calculate loss_sent: similarity between original and sentiment spoofed text
397
  negative_sample_loss = {}
 
409
  ori_ori_cos_removed = remove_diagonal_elements(ori_ori_cos) # (bs, bs-1)
410
  loss_5 = ori_ori_cos_removed.mean() * cls.model_args.temp
411
 
412
+ if cls.model_args.cl_weight != 0 and cls.model_args.tl_weight != 0:
413
+ loss = loss_gr + cls.model_args.cl_weight * loss_cl + cls.model_args.tl_weight * loss_triplet
414
+ elif cls.model_args.cl_weight != 0 and cls.model_args.tl_weight == 0:
415
+ loss = loss_gr + cls.model_args.cl_weight * loss_cl
416
+ elif cls.model_args.cl_weight == 0 and cls.model_args.tl_weight != 0:
417
+ loss = loss_gr + cls.model_args.tl_weight * loss_triplet
418
+ else:
419
+ raise ValueError("Both contrastive loss and triplet loss weights are zero.")
420
 
421
  result = {
422
  'loss': loss,
 
431
  key = f"sim_{cname.replace('_spoof_0', '')}"
432
  result[key] = l
433
 
434
+ if cls.model_args.cl_weight != 0:
435
+ result['loss_cl'] = loss_cl
436
+ if cls.model_args.tl_weight != 0:
437
+ result['loss_tl'] = loss_triplet
438
 
439
  if not return_dict:
440
  raise NotImplementedError
 
455
  output_attentions=None,
456
  output_hidden_states=None,
457
  return_dict=None,
458
+ lambda_1=1.0,
459
+ lambda_2=1.0,
460
  ):
461
 
462
  return_dict = return_dict if return_dict is not None else cls.config.use_return_dict
463
 
464
+ if 'roberta' in cls.model_args.model_name_or_path:
465
+ outputs = cls.roberta(
466
+ input_ids,
467
+ attention_mask=attention_mask,
468
+ token_type_ids=token_type_ids,
469
+ position_ids=position_ids,
470
+ head_mask=head_mask,
471
+ inputs_embeds=inputs_embeds,
472
+ output_attentions=output_attentions,
473
+ output_hidden_states=False,
474
+ return_dict=True,
475
+ )
476
+ sequence_output = outputs[0]
477
+ pooler_output = cls.classifier(sequence_output)
478
+ elif 'qwen2' in cls.model_args.model_name_or_path.lower():
479
+ def last_token_pool(last_hidden_states: Tensor,
480
+ attention_mask: Tensor) -> Tensor:
481
+ left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
482
+ if left_padding:
483
+ return last_hidden_states[:, -1]
484
+ else:
485
+ sequence_lengths = attention_mask.sum(dim=1) - 1
486
+ batch_size = last_hidden_states.shape[0]
487
+ return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
488
+
489
+ outputs = cls.model(
490
+ input_ids,
491
+ attention_mask=attention_mask,
492
+ token_type_ids=token_type_ids,
493
+ position_ids=position_ids,
494
+ head_mask=head_mask,
495
+ inputs_embeds=inputs_embeds,
496
+ output_attentions=output_attentions,
497
+ output_hidden_states=True,
498
+ return_dict=True,
499
+ )
500
+
501
+ if cls.model_args.pooler_type in ['query', 'attention']:
502
+ pooler_output = cls.pool(outputs.last_hidden_state, attention_mask)
503
+ elif cls.model_args.pooler_type == 'last':
504
+ pooler_output = last_token_pool(outputs.last_hidden_state, attention_mask)
505
+ else:
506
+ raise NotImplementedError
507
+ # normalize embeddings
508
+ pooler_output = F.normalize(pooler_output, p=2, dim=1)
509
+ else:
510
+ raise NotImplementedError
511
+
512
 
513
  # Mapping
514
  mapping_output = cls.map(pooler_output)
 
530
 
531
  def __init__(self, config, *model_args, **model_kargs):
532
  super().__init__(config)
533
+ self.model_args = model_kargs["model_args"]
534
 
535
  self.classifier = RobertaClassificationHeadForEmbedding(config)
536
 
537
+ if self.model_args.do_mlm:
538
+ self.lm_head = RobertaLMHead(config)
539
 
540
  self.map = SemanticModel(input_dim=768)
541
+ cl_init(self, config)
542
+
543
+ if self.model_args.freeze_base:
544
+ # Freeze RoBERTa encoder parameters
545
+ for param in self.roberta.parameters():
546
+ param.requires_grad = False
547
+ for param in self.classifier.parameters():
548
+ param.requires_grad = False
549
+
550
  # Initialize weights and apply final processing
551
  self.post_init()
552
 
553
+ def initialize_mlp_weights(self, pretrained_model_state_dict):
554
+ """
555
+ Initialize MLP weights using the pretrained classifier's weights.
556
+ """
557
+ self.mlp.dense.weight.data = pretrained_model_state_dict.classifier.dense.weight.data.clone()
558
+ self.mlp.dense.bias.data = pretrained_model_state_dict.classifier.dense.bias.data.clone()
559
+
560
+ def forward(self,
561
+ input_ids=None,
562
+ attention_mask=None,
563
+ token_type_ids=None,
564
+ position_ids=None,
565
+ head_mask=None,
566
+ inputs_embeds=None,
567
+ labels=None,
568
+ output_attentions=None,
569
+ output_hidden_states=None,
570
+ return_dict=None,
571
+ sent_emb=False,
572
+ mlm_input_ids=None,
573
+ mlm_labels=None,
574
+ latter_sentiment_spoof_mask=None,
575
+ ):
576
+ if sent_emb:
577
+ return sentemb_forward(self,
578
+ input_ids=input_ids,
579
+ attention_mask=attention_mask,
580
+ token_type_ids=token_type_ids,
581
+ position_ids=position_ids,
582
+ head_mask=head_mask,
583
+ inputs_embeds=inputs_embeds,
584
+ labels=labels,
585
+ output_attentions=output_attentions,
586
+ output_hidden_states=output_hidden_states,
587
+ return_dict=return_dict,
588
+ )
589
+ else:
590
+ return cl_forward(self,
591
+ input_ids=input_ids,
592
+ attention_mask=attention_mask,
593
+ token_type_ids=token_type_ids,
594
+ position_ids=position_ids,
595
+ head_mask=head_mask,
596
+ inputs_embeds=inputs_embeds,
597
+ labels=labels,
598
+ output_attentions=output_attentions,
599
+ output_hidden_states=output_hidden_states,
600
+ return_dict=return_dict,
601
+ mlm_input_ids=mlm_input_ids,
602
+ mlm_labels=mlm_labels,
603
+ latter_sentiment_spoof_mask=latter_sentiment_spoof_mask,
604
+ )
605
+
606
+ class Qwen2ForCL(Qwen2PreTrainedModel):
607
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
608
+
609
+ def __init__(self, config, *model_args, **model_kargs):
610
+ super().__init__(config)
611
+ self.model_args = model_kargs["model_args"]
612
+ self.model = Qwen2Model(config)
613
+
614
+ if self.model_args.pooler_type == 'query':
615
+ self.pool = QueryHead(config.hidden_size)
616
+ elif self.model_args.pooler_type == 'attention':
617
+ self.pool = AttentionPooling(config.hidden_size)
618
+
619
+ # if self.model_args.do_mlm:
620
+ # self.lm_head = RobertaLMHead(config)
621
+
622
+ cl_init(self, config)
623
+ self.map = SemanticModel(input_dim=1536)
624
+
625
+ if self.model_args.freeze_base:
626
+ # Freeze Qwen parameters
627
+ for param in self.model.parameters():
628
+ param.requires_grad = False
629
+
630
  def forward(self,
631
  input_ids=None,
632
  attention_mask=None,
 
639
  output_hidden_states=None,
640
  return_dict=None,
641
  sent_emb=False,
642
+ mlm_input_ids=None,
643
+ mlm_labels=None,
644
  latter_sentiment_spoof_mask=None,
645
  ):
646
  if sent_emb:
 
668
  output_attentions=output_attentions,
669
  output_hidden_states=output_hidden_states,
670
  return_dict=return_dict,
671
+ mlm_input_ids=mlm_input_ids,
672
+ mlm_labels=mlm_labels,
673
  latter_sentiment_spoof_mask=latter_sentiment_spoof_mask,
674
  )
675