MAS-AI-0000 committed (verified)
Commit 1b51bf6 · 1 Parent(s): 2d84a53

Upload 2 files

detree/model/simclr.py ADDED
@@ -0,0 +1,284 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from .text_embedding import TextEmbeddingModel


class Tree():
    """Label hierarchy read from a text file.

    Each line is `node_id father_id names`, where `father_id` is -1 for the
    root and `names` is a comma-separated list for leaf nodes ('none' for
    internal nodes).
    """

    def __init__(self, path):
        self.name = {}
        self.childs = {}
        self.father = {}
        self.dep = {}
        self.root = None
        self.max_dep = 0
        self.subtree = {}   # subtree[node]: boolean mask of the nodes in node's subtree (node included)
        self.grad_fa = {}   # for each leaf: its ancestor closest to the root (a child of the root)
        with open(path, 'r') as f:
            lines = f.readlines()
        for line in lines:
            parts = line.strip().split()
            assert len(parts) == 3, "Each line must have exactly three parts"

            now, fa, name = parts
            now, fa = int(now), int(fa)
            if name != 'none':
                self.name[now] = name.split(',')
            if fa != -1:
                self.childs[fa] = self.childs.get(fa, []) + [now]
                self.father[now] = fa
            else:
                self.root = now
        # fa_pos[node]: boolean mask of node's ancestors below the root, node included
        self.fa_pos = torch.zeros((len(self.father), len(self.father)), dtype=torch.bool)

        self.dfs(self.root)
        # positive / negative relations per tree level: (max_dep, N, N + K) booleans,
        # where N is the number of named leaves and K the number of internal non-root nodes
        self.pos_down2up = torch.zeros((self.max_dep, len(self.name), len(self.father)), dtype=torch.bool)
        self.neg_down2up = torch.zeros((self.max_dep, len(self.name), len(self.father)), dtype=torch.bool)

        self.pos_up2down = torch.zeros((self.max_dep, len(self.name), len(self.father)), dtype=torch.bool)
        self.neg_up2down = torch.zeros((self.max_dep, len(self.name), len(self.father)), dtype=torch.bool)

        self.pos_center = torch.zeros((self.max_dep, len(self.name)), dtype=torch.long)
        self.mask_center = torch.zeros((self.max_dep, len(self.name), len(self.father)), dtype=torch.bool)

        # (max_dep, N) booleans: which tree levels are defined for each leaf
        self.mask = torch.zeros((self.max_dep, len(self.name)), dtype=torch.bool)
        self.depth = torch.zeros(len(self.name))
        self.labels = torch.zeros(len(self.name), dtype=torch.long)
        self.vis_leaf()
        label_value = list(set(self.grad_fa.values()))
        for key, value in self.grad_fa.items():
            self.labels[key] = label_value.index(value)

    def dfs(self, node, depth=0, grfa=-1):
        self.dep[node] = depth
        self.max_dep = max(self.max_dep, depth)
        if node != self.root:
            self.subtree[node] = torch.zeros(len(self.father), dtype=torch.bool)
            self.subtree[node][node] = 1

            # if self.fa_pos.get(node) is None:
            if self.father[node] != self.root:
                self.fa_pos[node] = self.fa_pos[self.father[node]].clone()
            self.fa_pos[node][node] = 1
            if grfa == -1:
                grfa = node
        if self.childs.get(node) is None:
            self.grad_fa[node] = grfa
        for child in self.childs.get(node, []):
            self.dfs(child, depth + 1, grfa)
            if node != self.root:
                self.subtree[node] = torch.logical_or(self.subtree[node], self.subtree[child])

    def gen_leaf_item(self, node):
        # Walk from a leaf up to the root and fill the per-level positive/negative masks.
        last_node = -1
        leaf_id = node
        self.depth[node] = self.dep[node]
        while node != self.root:
            now_dep = self.dep[node] - 1
            self.mask[now_dep, leaf_id] = 1
            self.pos_center[now_dep, leaf_id] = node
            self.mask_center[now_dep, leaf_id] = torch.logical_not(
                torch.logical_or(self.fa_pos[node], self.subtree[node]))
            self.mask_center[now_dep, leaf_id, node] = 1
            if last_node == -1:
                self.pos_down2up[now_dep, leaf_id] = self.subtree[node]
            else:
                self.pos_down2up[now_dep, leaf_id] = torch.logical_xor(self.subtree[node], self.subtree[last_node])
            self.neg_down2up[now_dep, leaf_id] = torch.logical_not(self.subtree[node])

            if self.father[node] == self.root:
                self.neg_up2down[now_dep, leaf_id] = torch.logical_not(self.subtree[node])
            else:
                self.neg_up2down[now_dep, leaf_id] = torch.logical_xor(self.subtree[node], self.subtree[self.father[node]])
            self.pos_up2down[now_dep, leaf_id] = self.subtree[node]

            last_node = node
            node = self.father[node]

    def vis_leaf(self):
        for node, name in self.name.items():
            self.gen_leaf_item(node)

    def display(self):
        for node, name in self.name.items():
            depth = self.dep[node]
            print(f"{depth}- {name} {self.father[node]}")


class SimCLR_Tree(nn.Module):
    """SimCLR-style contrastive objective whose positive/negative sets are taken
    from a label tree, with one learnable "virtual center" per non-root node."""

    def __init__(self, opt, fabric):
        super(SimCLR_Tree, self).__init__()

        self.temperature = opt.temperature
        self.opt = opt
        self.fabric = fabric

        adapter_path = getattr(opt, "adapter_path", None)
        self.model = TextEmbeddingModel(
            opt.model_name,
            lora=opt.lora,
            use_pooling=opt.pooling,
            lora_r=opt.lora_r,
            lora_alpha=opt.lora_alpha,
            lora_dropout=opt.lora_dropout,
            adapter_path=adapter_path,
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tree = Tree(opt.tree_txt)

        self.pos_down2up = self.tree.pos_down2up.to(self.device)
        self.neg_down2up = self.tree.neg_down2up.to(self.device)
        self.pos_up2down = self.tree.pos_up2down.to(self.device)
        self.neg_up2down = self.tree.neg_up2down.to(self.device)
        self.pos_center = self.tree.pos_center.to(self.device)
        self.mask_center = self.tree.mask_center.to(self.device)

        self.K = self.pos_down2up.shape[0]  # number of tree levels (max depth)

        self.mask = self.tree.mask.to(self.device)
        self.depth = self.tree.depth.to(self.device)
        self.root_labels = self.tree.labels.to(self.device)
        self.esp = torch.tensor(1e-6, device=self.device)  # small constant to avoid division by zero
        self.max_dep = self.tree.max_dep
        self.leaf_cnt = len(self.tree.name)

        self.names2id = {}
        for key, value in self.tree.name.items():
            for item in value:
                self.names2id[item] = key

        # one learnable center per non-root tree node, appended to the key set in forward()
        self.vitual_center = nn.Parameter(
            torch.randn((len(self.tree.father), opt.projection_size), device=self.device),
            requires_grad=True,
        )
        nn.init.xavier_uniform_(self.vitual_center)
        self.center_labels = torch.arange(len(self.tree.father), dtype=torch.long, device=self.device)
        if adapter_path is not None:
            self.load_tree_state(adapter_path)

    def get_encoder(self):
        return self.model

    def save_pretrained(self, save_directory: str, save_tokenizer: bool = True):
        os.makedirs(save_directory, exist_ok=True)
        self.model.save_pretrained(save_directory, save_tokenizer=save_tokenizer)
        torch.save(
            {"vitual_center": self.vitual_center.detach().cpu()},
            os.path.join(save_directory, "tree_state.pt"),
        )

    def load_tree_state(self, directory: str):
        state_path = os.path.join(directory, "tree_state.pt")
        if not os.path.exists(state_path):
            return
        state = torch.load(state_path, map_location=self.vitual_center.device)
        self.vitual_center.data.copy_(state["vitual_center"].to(self.vitual_center.device))

    def load_from_directory(self, directory: str, is_trainable: bool = True):
        if getattr(self.opt, "lora", False):
            self.model.load_adapter(directory, is_trainable=is_trainable)
        else:
            self.model = TextEmbeddingModel(
                directory,
                lora=False,
                use_pooling=self.opt.pooling,
                output_hidden_states=False,
            )
        self.load_tree_state(directory)

    def _compute_logits(self, q, q_labels, k, k_labels, pos_mask, neg_mask):
        def cosine_similarity_matrix(q, k):
            q_norm = F.normalize(q, dim=-1)
            k_norm = F.normalize(k, dim=-1)
            cosine_similarity = q_norm @ k_norm.T
            return cosine_similarity

        def gen_label_mask(relation_matrix, q_labels, k_labels):
            # Look up the (K, N1, N2) relation for every query/key label pair.
            N1 = q_labels.shape[0]
            N2 = k_labels.shape[0]

            q_labels_expanded = q_labels.unsqueeze(1).expand(-1, N2)  # N1 x N2
            k_labels_expanded = k_labels.unsqueeze(0).expand(N1, -1)  # N1 x N2

            result_matrix = relation_matrix[:, q_labels_expanded, k_labels_expanded]
            return result_matrix

        logits = cosine_similarity_matrix(q, k)
        logits = logits / self.temperature
        logits = logits.unsqueeze(0).expand(self.K, -1, -1)  # K,N1,N2

        pos_mask = gen_label_mask(pos_mask, q_labels, k_labels)
        neg_mask = gen_label_mask(neg_mask, q_labels, k_labels)  # K,N1,N2

        # positive logit = mean similarity over the positive set; negatives stay element-wise
        pos_logits = torch.sum(logits * pos_mask, dim=-1) / torch.max(torch.sum(pos_mask, dim=-1), self.esp)  # K,N1
        pos_logits = pos_logits.unsqueeze(-1)  # K,N1,1
        neg_logits = logits * neg_mask  # K,N1,N2

        logits = torch.cat((pos_logits, neg_logits), dim=-1)  # K,N1,N2+1

        # model:model set
        # pos_logits_model = torch.sum(logits*same_model,dim=1)/torch.max(torch.sum(same_model,dim=1),self.esp)# N
        # neg_logits_model = logits*torch.logical_not(same_model)# N,N+K
        # logits_model = torch.cat((pos_logits_model.unsqueeze(1), neg_logits_model), dim=1)

        return logits

    def forward(self, encoded_batch, labels):
        q = self.model(encoded_batch)  # local query embeddings
        N1 = q.shape[0]
        k = q.clone().detach()
        k = self.fabric.all_gather(k).view(-1, k.size(1))  # keys gathered across processes
        k_labels = self.fabric.all_gather(labels).view(-1)

        now_depth = self.depth[labels].unsqueeze(0).expand(self.K, -1)
        now_mask = self.mask[:, labels]
        # leaf_labels = self.root_labels[labels]

        # append the learnable per-node virtual centers to the key set
        k = torch.concat((k, self.vitual_center), dim=0)
        k_labels = torch.concat((k_labels, self.center_labels), dim=0)

        logits_sample = self._compute_logits(q, labels, k, k_labels, self.pos_down2up, self.neg_down2up)  # K,N1,N2+1
        # index 0 holds the averaged positive logit, so the target class is 0 everywhere
        gt_sample = torch.zeros(logits_sample.shape[:-1], dtype=torch.long, device=logits_sample.device)
        logits_sample = logits_sample.permute(0, 2, 1)
        loss_smaple1 = F.cross_entropy(logits_sample, gt_sample, reduction='none')  # K,N1
        # weight each level by the inverse leaf depth, average over the batch, rescale by max_dep
        loss_smaple1 = torch.sum((loss_smaple1 / now_depth) * now_mask) / N1 * self.max_dep

        # out = self.root_classfier(q)
        # loss_classfiy = F.cross_entropy(out, leaf_labels)

        loss = loss_smaple1

        return loss, loss_smaple1

    # def forward(self, encoded_batch, labels):
    #     q = self.model(encoded_batch)
    #     # N1 = q.shape[0]
    #     # k = q.clone().detach()
    #     # k = self.fabric.all_gather(k).view(-1, k.size(1))
    #     # k_labels = self.fabric.all_gather(labels).view(-1)
    #
    #     # now_depth = self.depth[labels].unsqueeze(0).expand(self.K,-1)
    #     # now_mask = self.mask[:,labels]
    #     leaf_labels = self.root_labels[labels]
    #
    #     # k = torch.concat((k,self.vitual_center),dim=0)
    #     # k_labels = torch.concat((k_labels,self.center_labels),dim=0)
    #
    #     # logits_sample = self._compute_logits(q,labels,k,k_labels,self.pos_down2up,self.neg_down2up)#K,N1,N2+1
    #     # gt_sample = torch.zeros(logits_sample.shape[:-1], dtype=torch.long,device=logits_sample.device)
    #     # logits_sample = logits_sample.permute(0,2,1)
    #     # loss_smaple1 = F.cross_entropy(logits_sample, gt_sample, reduction='none') #K,N1
    #     # loss_smaple1 = torch.sum((loss_smaple1/now_depth)*now_mask)/N1*self.max_dep
    #
    #     out = self.root_classfier(q)
    #     loss_classfiy = F.cross_entropy(out, leaf_labels)
    #
    #     loss = loss_classfiy
    #
    #     return loss, loss_classfiy
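For context, `Tree.__init__` parses a plain-text hierarchy file with one `node_id father_id names` triple per line: `father_id` is -1 for the root, and `names` is a comma-separated list for leaves or `none` for internal nodes. The way the class indexes its tensors additionally suggests that named leaves take the smallest ids and the root the largest. A minimal illustrative file (ids and names below are made up, not from this repository) could look like:

0 4 cat,kitten
1 4 dog
2 5 car
3 5 truck
4 6 none
5 6 none
6 -1 none

and a hypothetical smoke test, assuming the repository root is importable and the seven lines above are saved as toy_tree.txt:

from detree.model.simclr import Tree

tree = Tree("toy_tree.txt")
tree.display()        # prints depth, names and parent id for each named leaf
print(tree.max_dep)   # 2 for this toy hierarchy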
detree/model/text_embedding.py ADDED
@@ -0,0 +1,152 @@
import os
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from peft import LoraConfig, TaskType, PeftModel, get_peft_model


class TextEmbeddingModel(nn.Module):
    """Wrapper around a Hugging Face model with optional LoRA adapters."""

    def __init__(
        self,
        model_name,
        output_hidden_states=False,
        lora=False,
        infer=False,
        use_pooling="average",
        lora_r=128,
        lora_alpha=256,
        lora_dropout=0,
        adapter_path=None,
    ):
        super(TextEmbeddingModel, self).__init__()
        self.model_name = model_name
        self.use_pooling = use_pooling
        self.lora = lora
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        model_kwargs = {"trust_remote_code": True}
        if output_hidden_states:
            model_kwargs["output_hidden_states"] = True
        self.model = AutoModel.from_pretrained(model_name, **model_kwargs)

        if self.lora:
            peft_config = LoraConfig(
                task_type=TaskType.FEATURE_EXTRACTION,
                inference_mode=infer,
                r=lora_r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
            )
            self.model = get_peft_model(self.model, peft_config)
            if adapter_path is not None:
                self.load_adapter(adapter_path, is_trainable=not infer)
            else:
                self.model.print_trainable_parameters()
        elif adapter_path is not None:
            # no LoRA: load full weights directly from the given directory
            self.model = AutoModel.from_pretrained(adapter_path, **model_kwargs)

    def pooling(self, model_output, attention_mask, hidden_states=False):
        # model_output: (batch, seq, hidden), or (layers, batch, seq, hidden) when hidden_states=True
        if hidden_states:
            if self.use_pooling == "average":
                model_output = model_output.masked_fill(~attention_mask[None, ..., None].bool(), 0.0)
                emb = model_output.sum(dim=2) / attention_mask.sum(dim=1)[..., None]
            elif self.use_pooling == "max":
                emb = model_output.masked_fill(~attention_mask[None, ..., None].bool(), float("-inf"))
                emb, _ = emb.max(dim=2)
            elif self.use_pooling == "cls":
                emb = model_output[:, :, 0]
            else:
                raise ValueError("Pooling method not supported")
            emb = emb.permute(1, 0, 2)  # -> (batch, layers, hidden)
        else:
            if self.use_pooling == "average":
                model_output = model_output.masked_fill(~attention_mask[..., None].bool(), 0.0)
                emb = model_output.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
            elif self.use_pooling == "max":
                emb = model_output.masked_fill(~attention_mask[..., None].bool(), float("-inf"))
                emb, _ = emb.max(dim=1)
            elif self.use_pooling == "cls":
                emb = model_output[:, 0]
            else:
                raise ValueError("Pooling method not supported")
        return emb

    def forward(self, encoded_batch, hidden_states=False, retrun_all_emb=False):
        if "t5" in self.model_name.lower():
            # seq2seq models need decoder input ids; feed a single start token per sample
            input_ids = encoded_batch['input_ids']
            decoder_input_ids = torch.zeros((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device)
            model_output = self.model(**encoded_batch,
                                      decoder_input_ids=decoder_input_ids)
        else:
            model_output = self.model(**encoded_batch)

        if isinstance(model_output, tuple):
            model_output = model_output[0]
        if isinstance(model_output, dict):
            if hidden_states:
                model_output = model_output["hidden_states"]
                model_output = torch.stack(model_output, dim=0)
            else:
                model_output = model_output["last_hidden_state"]

        emb = self.pooling(model_output, encoded_batch['attention_mask'], hidden_states)
        if retrun_all_emb:
            return emb, model_output
        return emb

    def save_pretrained(self, save_directory: str, save_tokenizer: bool = True):
        os.makedirs(save_directory, exist_ok=True)
        # for a PeftModel this stores only the adapter; otherwise the full model weights
        self.model.save_pretrained(save_directory)
        if save_tokenizer:
            self.tokenizer.save_pretrained(save_directory)

    def load_adapter(self, adapter_path: str, is_trainable: bool = False):
        if not self.lora or not isinstance(self.model, PeftModel):
            raise ValueError("LoRA is not enabled for this model instance.")
        self.model = PeftModel.from_pretrained(
            self.model.base_model,
            adapter_path,
            is_trainable=is_trainable,
        )
        self.model.print_trainable_parameters()

    def merge_and_unload(self):
        if not isinstance(self.model, PeftModel):
            raise ValueError("The current model does not contain a LoRA adapter to merge.")
        merged_model = self.model.merge_and_unload()
        return merged_model


class ClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, hidden_size, num_labels):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, x):
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.out_proj(x)
        return x


class TextClassificationModel(nn.Module):
    def __init__(self, opt, dim=2):
        super(TextClassificationModel, self).__init__()
        self.model = TextEmbeddingModel(
            opt.model_name,
            lora=True,
            use_pooling=opt.pooling,
            lora_r=opt.lora_r,
            lora_alpha=opt.lora_alpha,
            infer=True,
        )
        self.root_classfier = nn.Linear(opt.embedding_dim, dim)

    def forward(self, encoded_batch):
        q = self.model(encoded_batch)
        out = self.root_classfier(q)
        return out
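A minimal usage sketch for `TextEmbeddingModel` (the checkpoint name and sentences below are illustrative assumptions, LoRA is left disabled, and the repository root is assumed to be importable):

import torch
from detree.model.text_embedding import TextEmbeddingModel

model = TextEmbeddingModel("sentence-transformers/all-MiniLM-L6-v2", lora=False, use_pooling="average")
texts = ["a first example sentence", "a second example sentence"]
encoded = model.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    emb = model(encoded)   # (batch_size, hidden_size) mean-pooled embeddings
print(emb.shape)

`SimCLR_Tree` builds its encoder the same way and passes such an `encoded_batch`, together with integer leaf labels, to its `forward`.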