Khriis committed
Commit 38ce35c · verified · 1 Parent(s): 7e05c76

Upload cross_scorer_model.py

Files changed (1):
  1. cross_scorer_model.py +161 -0
cross_scorer_model.py ADDED
@@ -0,0 +1,161 @@
+ import torch
+ import torch.nn as nn
+
+ from transformers.modeling_outputs import SequenceClassifierOutput
+
+
+ class CrossScorerCrossEncoder(nn.Module):
+     """Cross-encoder scorer: a transformer encoder whose [CLS] representation
+     is mapped to a scalar quality score by a small feed-forward head."""
+
+     def __init__(self, transformer):
+         super(CrossScorerCrossEncoder, self).__init__()
+
+         self.cross_encoder = transformer
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+         # Scoring head on the [CLS] token: 768 -> 512 -> 1.
+         # Note: the activation is ELU despite the attribute name.
+         self.l1 = nn.Linear(768, 512)
+         self.relu = nn.ELU()
+         self.l2 = nn.Linear(512, 1)
+
+         self.encoder_type = "cross"
+
+     def score_forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+         return_attentions=False,
+     ):
+         # Encode each concatenated (text_a, text_b) pair in one pass.
+         output = self.cross_encoder(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         # Score each pair from its [CLS] token representation.
+         pair_reps = output.last_hidden_state[:, 0, :]
+         score = self.l2(self.relu(self.l1(pair_reps)))
+
+         if output_attentions and return_attentions:
+             # This branch returns probabilities; the default path returns raw logits.
+             return score.sigmoid().squeeze(), output.attentions
+
+         return score
+
+     def cl_loss(self, pair_scores, labels):
+         # `labels` is unused: the ranking is implied by the batch layout.
+         # Scores arrive as consecutive groups of four pairs per example:
+         # [high quality, medium quality, low quality, mismatched response].
+         BSZ = pair_scores.size(0)
+         BSZ = int(BSZ / 4)
+
+         # Regroup the flat score vector into shape (BSZ, 4).
+         pair_scores = list(pair_scores.tensor_split(BSZ, dim=0))
+         pair_scores = torch.stack(pair_scores)
+
+         # Small margin between adjacent quality levels, larger between distant ones.
+         gap_1_loss_fct = nn.MarginRankingLoss(margin=0.5)
+         gap_2_loss_fct = nn.MarginRankingLoss(margin=1.0)
+
+         hq_scores = pair_scores[:, 0]
+         mq_scores = pair_scores[:, 1]
+         lq_scores = pair_scores[:, 2:-1]
+
+         # Enforce hq > mq and mq > lq by margin 0.5, and hq > lq by margin 1.0.
+         hq_mq_loss = gap_1_loss_fct(
+             hq_scores.flatten(),
+             mq_scores.flatten(),
+             torch.ones(mq_scores.flatten().size()).to(self.device))
+         mq_lq_loss = gap_1_loss_fct(
+             mq_scores.repeat(1, lq_scores.size(-1)).flatten(),
+             lq_scores.flatten(),
+             torch.ones(lq_scores.flatten().size()).to(self.device))
+         hq_lq_loss = gap_2_loss_fct(
+             hq_scores.repeat(1, lq_scores.size(-1)).flatten(),
+             lq_scores.flatten(),
+             torch.ones(lq_scores.flatten().size()).to(self.device))
+
+         # A mismatched response must score below hq by 1.0 and below mq by 0.5.
+         mismatch_scores = pair_scores[:, -1]
+         hq_mismatch_loss = gap_2_loss_fct(
+             hq_scores.flatten(),
+             mismatch_scores.flatten(),
+             torch.ones(mismatch_scores.flatten().size()).to(self.device))
+         mq_mismatch_loss = gap_1_loss_fct(
+             mq_scores.flatten(),
+             mismatch_scores.flatten(),
+             torch.ones(mismatch_scores.flatten().size()).to(self.device))
+         mismatch_loss = hq_mismatch_loss + mq_mismatch_loss
+
+         loss = hq_mq_loss + mq_lq_loss + hq_lq_loss + mismatch_loss
+         return loss
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+         random=False,  # unused; kept for interface compatibility
+     ):
+         # Score every pair in the batch and squash the logits to (0, 1).
+         pair_scores = self.score_forward(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             labels=labels,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         ).sigmoid().squeeze()
+
+         # Contrastive margin-ranking loss over the four-way quality groups.
+         loss = self.cl_loss(pair_scores, labels)
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=pair_scores,
+         )
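
A minimal usage sketch, assuming a BERT-style checkpoint whose 768-dimensional hidden size matches the scoring head ("bert-base-uncased" is a stand-in here, and the query/response strings are toy data), with one batch laid out in the consecutive group of four responses per query that cl_loss expects:

    import torch
    from transformers import AutoTokenizer, AutoModel

    # Assumed checkpoint; any 768-d BERT-style encoder fits the head dimensions.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = CrossScorerCrossEncoder(AutoModel.from_pretrained("bert-base-uncased"))
    model.to(model.device)

    query = "How do I reverse a list in Python?"
    # One group of four candidates, in the order cl_loss assumes:
    # high quality, medium quality, low quality, mismatched.
    responses = [
        "Call list.reverse() in place, or use reversed() or slicing for a new list.",
        "You can call reverse on the list.",
        "Lists cannot be reversed in Python.",
        "The capital of France is Paris.",
    ]

    batch = tokenizer([query] * len(responses), responses,
                      padding=True, truncation=True,
                      return_tensors="pt").to(model.device)

    # Inference: sigmoid of the raw logit gives a score in (0, 1) per pair.
    with torch.no_grad():
        scores = model.score_forward(**batch).sigmoid().squeeze()

    # Training: forward() applies the sigmoid itself, groups the scores in
    # fours, and returns the margin-ranking loss in a SequenceClassifierOutput.
    out = model(**batch)
    out.loss.backward()

Note that forward() applies the sigmoid before cl_loss, so the 0.5 and 1.0 margins operate on probabilities in (0, 1) rather than on unbounded logits.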