Tingxie committed on
Commit 0087cda · verified · 1 Parent(s): 107d140

Update model_finetune.py

Files changed (1)
  1. model_finetune.py +121 -83
model_finetune.py CHANGED
@@ -1,83 +1,121 @@
-
-
-from modular_curei import CureiModel, CureiConfig
-
-import torch
-from torch import nn
-import torch.nn.functional as F
-
-class ESA(nn.Module):
-    def __init__(self, feature_dim, out_dim):
-        super().__init__()
-        self.ln_f = nn.LayerNorm(feature_dim)
-        self.linear = nn.Linear(feature_dim, out_dim)
-        self.linear1 = nn.Linear(out_dim, out_dim)
-
-    def forward(self, hidden_states, attention_mask):
-        logits = self.ln_f(hidden_states)  # (B, N, C)
-        cap_embes = self.linear(logits)  # Q
-        features_in = self.linear1(cap_embes)  # M
-        mask = attention_mask.unsqueeze(-1)  # (B, N, 1)
-        features_in = features_in.masked_fill(mask == 0, -1e4)  # (B, N, C)
-        features_k_softmax = nn.Softmax(dim=1)(features_in)
-        attn = features_k_softmax.masked_fill(mask == 0, 0)
-        aggr_feature = torch.sum(attn * cap_embes, dim=1)  # (B, C)
-        return aggr_feature
-
-
-class EIMSBERT(nn.Module):
-    def __init__(self,
-                 config=None,
-                 embed_dim=256,
-                 temperature=0.07,
-                 ):
-        super().__init__()
-
-        self.config = CureiConfig()
-        self.text_encoder = CureiModel(config=self.config)
-        self.esa = ESA(self.config.hidden_size, embed_dim)
-        self.proj = nn.Linear(embed_dim, embed_dim)
-        self.temperature = temperature
-
-    def info_nce_loss(self, features1, features2):
-        batch_size = features1.shape[0]
-        device = features1.device
-
-        features1 = F.normalize(features1, dim=1)
-        features2 = F.normalize(features2, dim=1)
-
-        similarity_matrix = torch.matmul(features1, features2.T)
-        labels = torch.arange(batch_size).to(device)
-
-        logits1 = similarity_matrix / self.temperature
-        loss1 = F.cross_entropy(logits1, labels)
-
-        logits2 = similarity_matrix.T / self.temperature
-        loss2 = F.cross_entropy(logits2, labels)
-
-        return (loss1 + loss2) / 2
-
-    def forward(self, input_ids, attention_masks, intens_tensors, num_peaks, input_ids_pre, attention_masks_pre, intens_tensors_pre, num_peaks_pre):
-
-
-        output = self.text_encoder(input_ids=input_ids,
-                                   intensities=intens_tensors,
-                                   attention_mask=attention_masks,
-                                   return_dict=True,
-                                   )
-        output_pre = self.text_encoder(input_ids=input_ids_pre,
-                                       intensities=intens_tensors_pre,
-                                       attention_mask=attention_masks_pre,
-                                       return_dict=True,
-                                       )
-        output_feats = output.last_hidden_state
-        output_pre_feats = output_pre.last_hidden_state
-
-        output_aggr_feats = self.esa(output_feats, attention_masks)
-        output_pre_aggr_feats = self.esa(output_pre_feats, attention_masks_pre)
-        output_aggr_feats = self.proj(output_aggr_feats)
-        output_pre_aggr_feats = self.proj(output_pre_aggr_feats)
-        loss = self.info_nce_loss(output_aggr_feats, output_pre_aggr_feats)
-
-        return output_aggr_feats, output_pre_aggr_feats, loss
-
+
+from modular_csuep import CsuepModel, CsuepConfig
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from transformers.activations import ACT2FN
+
+class MeanPooling(nn.Module):
+    def __init__(self, feature_dim, out_dim):
+        super().__init__()
+        self.ln_f = nn.LayerNorm(feature_dim)
+        self.linear = nn.Linear(feature_dim, out_dim)
+
+    def forward(self, hidden_states, attention_mask):
+        # hidden_states: (B, N, C_in)
+        logits = self.ln_f(hidden_states)
+        features = self.linear(logits)  # (B, N, C_out)
+
+        input_mask_expanded = attention_mask.unsqueeze(-1).expand(features.size()).to(features.dtype)
+        sum_embeddings = torch.sum(features * input_mask_expanded, dim=1)
+        sum_mask = input_mask_expanded.sum(dim=1)
+        sum_mask = torch.clamp(sum_mask, min=1e-9)
+        mean_embeddings = sum_embeddings / sum_mask
+
+        return mean_embeddings
+
+class MaxPooling(nn.Module):
+    def __init__(self, feature_dim, out_dim):
+        super().__init__()
+        self.ln_f = nn.LayerNorm(feature_dim)
+        self.linear = nn.Linear(feature_dim, out_dim)
+
+    def forward(self, hidden_states, attention_mask):
+        logits = self.ln_f(hidden_states)
+        features = self.linear(logits)  # (B, N, C_out)
+        mask = attention_mask.unsqueeze(-1)  # (B, N, 1)
+        min_value = torch.finfo(features.dtype).min
+        features = features.masked_fill(mask == 0, min_value)
+        aggr_feature = torch.max(features, dim=1)[0]  # (B, C_out)
+        return aggr_feature
+
+class ESA(nn.Module):
+    def __init__(self, feature_dim, out_dim):
+        super().__init__()
+        self.ln_f = nn.LayerNorm(feature_dim)
+        self.linear = nn.Linear(feature_dim, out_dim)
+        self.linear1 = nn.Linear(out_dim, out_dim)
+
+    def forward(self, hidden_states, attention_mask):
+        logits = self.ln_f(hidden_states)  # (B, N, C)
+        cap_embes = self.linear(logits)  # Q
+        features_in = self.linear1(cap_embes)  # M
+        mask = attention_mask.unsqueeze(-1)  # (B, N, 1)
+        features_in = features_in.masked_fill(mask == 0, -1e4)  # (B, N, C)
+        features_k_softmax = nn.Softmax(dim=1)(features_in)
+        attn = features_k_softmax.masked_fill(mask == 0, 0)
+        aggr_feature = torch.sum(attn * cap_embes, dim=1)  # (B, C)
+        return aggr_feature
+
+
+class CSUEP_finetune(nn.Module):
+    def __init__(self,
+                 config=None,
+                 embed_dim=768,
+                 temperature=0.07,
+                 ):
+        super().__init__()
+
+        self.config = CsuepConfig()
+        self.text_encoder = CsuepModel(config=self.config)
+        # self.esa = ESA(self.config.hidden_size, embed_dim)
+        self.pooler = MeanPooling(self.config.hidden_size, embed_dim)
+        self.proj = nn.Linear(embed_dim, embed_dim)
+        self.temperature = temperature
+
+    def info_nce_loss(self, features1, features2):
+        batch_size = features1.shape[0]
+        device = features1.device
+
+        features1 = F.normalize(features1, dim=1)
+        features2 = F.normalize(features2, dim=1)
+
+        similarity_matrix = torch.matmul(features1, features2.T)
+        labels = torch.arange(batch_size).to(device)
+
+        logits1 = similarity_matrix / self.temperature
+        loss1 = F.cross_entropy(logits1, labels)
+
+        logits2 = similarity_matrix.T / self.temperature
+        loss2 = F.cross_entropy(logits2, labels)
+
+        return (loss1 + loss2) / 2
+
+    def forward(self, input_ids, attention_masks, intens_tensors, num_peaks, input_ids_pre, attention_masks_pre, intens_tensors_pre, num_peaks_pre):
+
+
+        output = self.text_encoder(input_ids=input_ids,
+                                   intensities=intens_tensors,
+                                   attention_mask=attention_masks,
+                                   return_dict=True,
+                                   )
+        output_pre = self.text_encoder(input_ids=input_ids_pre,
+                                       intensities=intens_tensors_pre,
+                                       attention_mask=attention_masks_pre,
+                                       return_dict=True,
+                                       )
+        output_feats = output.last_hidden_state
+        output_pre_feats = output_pre.last_hidden_state
+
+        output_aggr_feats = self.pooler(output_feats, attention_masks)
+        output_pre_aggr_feats = self.pooler(output_pre_feats, attention_masks_pre)
+
+        # output_aggr_feats = self.esa(output_feats, attention_masks)
+        # output_pre_aggr_feats = self.esa(output_pre_feats, attention_masks_pre)
+        output_aggr_feats = self.proj(output_aggr_feats)
+        output_pre_aggr_feats = self.proj(output_pre_aggr_feats)
+        loss = self.info_nce_loss(output_aggr_feats, output_pre_aggr_feats)
+
+        return output_aggr_feats, output_pre_aggr_feats, loss
+
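
For reference, here is a minimal, self-contained sketch of the two mechanisms this commit switches the fine-tuning model onto: masked mean pooling over the token axis (as in MeanPooling) and the symmetric InfoNCE objective (as in info_nce_loss). The tensor shapes and the 0.07 temperature mirror the code above; the batch size, sequence length, and random inputs are illustrative stand-ins, not values from the repository.

import torch
import torch.nn.functional as F

B, N, C = 4, 16, 768                 # batch, tokens, hidden size (embed_dim above)
temperature = 0.07

hidden = torch.randn(B, N, C)        # stand-in for last_hidden_state
mask = torch.ones(B, N)
mask[:, 10:] = 0                     # pretend the trailing positions are padding

# Masked mean pooling: average only over unmasked tokens.
m = mask.unsqueeze(-1).to(hidden.dtype)                          # (B, N, 1)
pooled = (hidden * m).sum(dim=1) / m.sum(dim=1).clamp(min=1e-9)  # (B, C)

# Symmetric InfoNCE: matched pairs sit on the diagonal of the similarity matrix.
f1 = F.normalize(pooled, dim=1)
f2 = F.normalize(pooled + 0.01 * torch.randn_like(pooled), dim=1)  # fake "paired" view
sim = f1 @ f2.T                                                    # (B, B)
labels = torch.arange(B)
loss = (F.cross_entropy(sim / temperature, labels)
        + F.cross_entropy(sim.T / temperature, labels)) / 2
print(loss.item())

The clamp guards against a fully masked row dividing by zero, and averaging the two cross-entropy directions makes the loss symmetric in the two views, exactly as the model's info_nce_loss does.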