import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from .text_embedding import TextEmbeddingModel
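
# The tree file consumed by Tree lists one node per line, whitespace-separated:
#     <node_id> <parent_id> <comma-separated names, or 'none'>
# parent_id == -1 marks the root; 'none' marks unnamed internal nodes.
# A hypothetical four-node example (root 0, internal node 1, leaves 2 and 3):
#     0 -1 none
#     1  0 none
#     2  1 cat,feline
#     3  1 dog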
class Tree:
    def __init__(self, path):
        self.name = {}     # leaf id -> list of name strings
        self.childs = {}   # node id -> list of child ids
        self.father = {}   # node id -> parent id (the root has no entry)
        self.dep = {}      # node id -> depth, with the root at depth 0
        self.root = None
        self.max_dep = 0
        self.subtree = {}  # node id -> boolean mask over its subtree's nodes
        self.grad_fa = {}  # leaf id -> its ancestor closest to the root
        with open(path, 'r') as f:
            lines = f.readlines()
            for line in lines:
                parts = line.strip().split()
                assert len(parts) == 3, "Each line must have exactly three parts"
                now, fa, name = parts
                now, fa = int(now), int(fa)
                if name != 'none':
                    self.name[now] = name.split(',')
                if fa != -1:
                    self.childs[fa] = self.childs.get(fa, []) + [now]
                    self.father[now] = fa
                else:
                    self.root = now
        # fa_pos[v]: v's ancestors below the root, plus v itself
        self.fa_pos = torch.zeros((len(self.father), len(self.father)), dtype=torch.bool)
        self.dfs(self.root)
        # per-depth relation masks, shaped (max_dep, N, N+K): N named leaves,
        # N+K non-root nodes (leaf ids are assumed to occupy 0..N-1)
        self.pos_down2up = torch.zeros((self.max_dep, len(self.name), len(self.father)), dtype=torch.bool)
        self.neg_down2up = torch.zeros((self.max_dep, len(self.name), len(self.father)), dtype=torch.bool)
        self.pos_up2down = torch.zeros((self.max_dep, len(self.name), len(self.father)), dtype=torch.bool)
        self.neg_up2down = torch.zeros((self.max_dep, len(self.name), len(self.father)), dtype=torch.bool)
        self.pos_center = torch.zeros((self.max_dep, len(self.name)), dtype=torch.long)
        self.mask_center = torch.zeros((self.max_dep, len(self.name), len(self.father)), dtype=torch.bool)
        # per-depth validity mask, shaped (max_dep, N)
        self.mask = torch.zeros((self.max_dep, len(self.name)), dtype=torch.bool)
        self.depth = torch.zeros(len(self.name))
        self.labels = torch.zeros(len(self.name), dtype=torch.long)
        self.vis_leaf()
        # relabel each leaf by the index of its top-level ancestor;
        # sorting keeps the label assignment deterministic across runs
        label_value = sorted(set(self.grad_fa.values()))
        for key, value in self.grad_fa.items():
            self.labels[key] = label_value.index(value)

    def dfs(self, node, depth=0, grfa=-1):
        self.dep[node] = depth
        self.max_dep = max(self.max_dep, depth)
        if node != self.root:
            self.subtree[node] = torch.zeros(len(self.father), dtype=torch.bool)
            self.subtree[node][node] = 1
            # inherit the ancestor chain from the parent (empty just below the root)
            if self.father[node] != self.root:
                self.fa_pos[node] = self.fa_pos[self.father[node]].clone()
            self.fa_pos[node][node] = 1
            if grfa == -1:
                grfa = node  # first node below the root on this path
        if self.childs.get(node) is None:
            self.grad_fa[node] = grfa  # leaf: remember its top-level ancestor
        for child in self.childs.get(node, []):
            self.dfs(child, depth + 1, grfa)
            if node != self.root:
                # accumulate the child's subtree on the way back up
                self.subtree[node] = torch.logical_or(self.subtree[node], self.subtree[child])
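
    # For the hypothetical tree above (0 -> 1 -> {2, 3}), dfs produces:
    #     subtree[1] = {1, 2, 3}, subtree[2] = {2}, subtree[3] = {3}
    #     fa_pos[2]  = {1, 2}  (ancestors below the root, plus the node itself)
    #     grad_fa    = {2: 1, 3: 1}  (both leaves hang under top-level node 1)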

    def gen_leaf_item(self, node):
        # walk from the leaf up to the root, filling the per-depth masks
        last_node = -1
        leaf_id = node
        self.depth[node] = self.dep[node]
        while node != self.root:
            now_dep = self.dep[node] - 1
            self.mask[now_dep, leaf_id] = 1
            self.pos_center[now_dep, leaf_id] = node
            # keep this node plus everything outside its ancestor chain and subtree
            self.mask_center[now_dep, leaf_id] = torch.logical_not(
                torch.logical_or(self.fa_pos[node], self.subtree[node]))
            self.mask_center[now_dep, leaf_id, node] = 1
            if last_node == -1:
                self.pos_down2up[now_dep, leaf_id] = self.subtree[node]
            else:
                # this subtree minus the already-visited child's subtree
                self.pos_down2up[now_dep, leaf_id] = torch.logical_xor(
                    self.subtree[node], self.subtree[last_node])
            self.neg_down2up[now_dep, leaf_id] = torch.logical_not(self.subtree[node])
            if self.father[node] == self.root:
                self.neg_up2down[now_dep, leaf_id] = torch.logical_not(self.subtree[node])
            else:
                # sibling subtrees: inside the parent's subtree but not this one
                self.neg_up2down[now_dep, leaf_id] = torch.logical_xor(
                    self.subtree[node], self.subtree[self.father[node]])
            self.pos_up2down[now_dep, leaf_id] = self.subtree[node]
            last_node = node
            node = self.father[node]
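
    # Example: for leaf 2 in the tree (0 -> 1 -> {2, 3}), the loop fills
    #     level 1 (node 2): pos_down2up = {2},    neg_down2up = {1, 3}
    #     level 0 (node 1): pos_down2up = {1, 3}, neg_down2up = {}  (covers all)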

    def vis_leaf(self):
        for node in self.name:
            self.gen_leaf_item(node)

    def display(self):
        for node, name in self.name.items():
            depth = self.dep[node]
            print(f"{depth}- {name} {self.father[node]}")


class SimCLR_Tree(nn.Module):
    def __init__(self, opt, fabric):
        super().__init__()
        self.temperature = opt.temperature
        self.opt = opt
        self.fabric = fabric
        adapter_path = getattr(opt, "adapter_path", None)
        self.model = TextEmbeddingModel(
            opt.model_name,
            lora=opt.lora,
            use_pooling=opt.pooling,
            lora_r=opt.lora_r,
            lora_alpha=opt.lora_alpha,
            lora_dropout=opt.lora_dropout,
            adapter_path=adapter_path,
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tree = Tree(opt.tree_txt)
        self.pos_down2up = self.tree.pos_down2up.to(self.device)
        self.neg_down2up = self.tree.neg_down2up.to(self.device)
        self.pos_up2down = self.tree.pos_up2down.to(self.device)
        self.neg_up2down = self.tree.neg_up2down.to(self.device)
        self.pos_center = self.tree.pos_center.to(self.device)
        self.mask_center = self.tree.mask_center.to(self.device)
        self.K = self.pos_down2up.shape[0]  # number of depth levels
        self.mask = self.tree.mask.to(self.device)
        self.depth = self.tree.depth.to(self.device)
        self.root_labels = self.tree.labels.to(self.device)
        self.eps = torch.tensor(1e-6, device=self.device)
        self.max_dep = self.tree.max_dep
        self.leaf_cnt = len(self.tree.name)
        self.names2id = {}  # name string -> leaf id
        for key, value in self.tree.name.items():
            for item in value:
                self.names2id[item] = key
        # one learnable embedding per non-root tree node ("virtual center")
        self.virtual_center = nn.Parameter(
            torch.randn((len(self.tree.father), opt.projection_size), device=self.device),
            requires_grad=True,
        )
        nn.init.xavier_uniform_(self.virtual_center)
        self.center_labels = torch.arange(len(self.tree.father), dtype=torch.long, device=self.device)
        if adapter_path is not None:
            self.load_tree_state(adapter_path)

    def get_encoder(self):
        return self.model

    def save_pretrained(self, save_directory: str, save_tokenizer: bool = True):
        os.makedirs(save_directory, exist_ok=True)
        self.model.save_pretrained(save_directory, save_tokenizer=save_tokenizer)
        torch.save(
            {"virtual_center": self.virtual_center.detach().cpu()},
            os.path.join(save_directory, "tree_state.pt"),
        )

    def load_tree_state(self, directory: str):
        state_path = os.path.join(directory, "tree_state.pt")
        if not os.path.exists(state_path):
            return
        state = torch.load(state_path, map_location=self.virtual_center.device)
        self.virtual_center.data.copy_(state["virtual_center"].to(self.virtual_center.device))

    def load_from_directory(self, directory: str, is_trainable: bool = True):
        if getattr(self.opt, "lora", False):
            self.model.load_adapter(directory, is_trainable=is_trainable)
        else:
            self.model = TextEmbeddingModel(
                directory,
                lora=False,
                use_pooling=self.opt.pooling,
                output_hidden_states=False,
            )
        self.load_tree_state(directory)
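
    # A typical checkpoint round trip (hypothetical directory name):
    #     model.save_pretrained("checkpoints/run1")
    #     model.load_from_directory("checkpoints/run1")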

    def _compute_logits(self, q, q_labels, k, k_labels, pos_mask, neg_mask):
        def cosine_similarity_matrix(q, k):
            q_norm = F.normalize(q, dim=-1)
            k_norm = F.normalize(k, dim=-1)
            return q_norm @ k_norm.T

        def gen_label_mask(relation_matrix, q_labels, k_labels):
            # index the (K, n_labels, n_labels) relation tensor with every
            # query/key label pair, giving a (K, N1, N2) boolean mask
            N1 = q_labels.shape[0]
            N2 = k_labels.shape[0]
            q_labels_expanded = q_labels.unsqueeze(1).expand(-1, N2)  # N1 x N2
            k_labels_expanded = k_labels.unsqueeze(0).expand(N1, -1)  # N1 x N2
            return relation_matrix[:, q_labels_expanded, k_labels_expanded]

        logits = cosine_similarity_matrix(q, k)
        logits = logits / self.temperature
        logits = logits.unsqueeze(0).expand(self.K, -1, -1)  # K, N1, N2
        pos_mask = gen_label_mask(pos_mask, q_labels, k_labels)
        neg_mask = gen_label_mask(neg_mask, q_labels, k_labels)  # K, N1, N2
        # average the positive similarities per query (eps guards empty positives)
        pos_logits = torch.sum(logits * pos_mask, dim=-1) / torch.max(torch.sum(pos_mask, dim=-1), self.eps)  # K, N1
        pos_logits = pos_logits.unsqueeze(-1)  # K, N1, 1
        neg_logits = logits * neg_mask  # K, N1, N2
        logits = torch.cat((pos_logits, neg_logits), dim=-1)  # K, N1, N2+1
        return logits
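
    # Shape walkthrough: with N1 local queries, N2 gathered keys (including the
    # virtual centers) and K depth levels, _compute_logits returns (K, N1, N2+1)
    # logits whose column 0 is the averaged positive, so the target index is 0.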

    def forward(self, encoded_batch, labels):
        q = self.model(encoded_batch)
        N1 = q.shape[0]
        # gather keys and labels from every rank; keys carry no gradient
        k = q.clone().detach()
        k = self.fabric.all_gather(k).view(-1, k.size(1))
        k_labels = self.fabric.all_gather(labels).view(-1)
        now_depth = self.depth[labels].unsqueeze(0).expand(self.K, -1)
        now_mask = self.mask[:, labels]
        # append the learnable virtual centers as extra keys
        k = torch.cat((k, self.virtual_center), dim=0)
        k_labels = torch.cat((k_labels, self.center_labels), dim=0)
        logits_sample = self._compute_logits(q, labels, k, k_labels, self.pos_down2up, self.neg_down2up)  # K, N1, N2+1
        # the positive sits at index 0 of the last dimension
        gt_sample = torch.zeros(logits_sample.shape[:-1], dtype=torch.long, device=logits_sample.device)
        logits_sample = logits_sample.permute(0, 2, 1)  # cross_entropy expects (K, C, N1)
        loss_sample = F.cross_entropy(logits_sample, gt_sample, reduction='none')  # K, N1
        # average each leaf's per-level losses over its depth, mask invalid
        # levels, then take the batch mean scaled by max_dep
        loss_sample = torch.sum((loss_sample / now_depth) * now_mask) / N1 * self.max_dep
        loss = loss_sample
        return loss, loss_sample
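

# A minimal smoke test, assuming a hypothetical "tree.txt" in the format
# documented above. SimCLR_Tree itself additionally needs the training
# script's opt namespace and a Lightning Fabric instance, so only the tree
# construction is exercised here.
if __name__ == "__main__":
    tree = Tree("tree.txt")
    tree.display()
    print(f"leaves: {len(tree.name)}, non-root nodes: {len(tree.father)}, max depth: {tree.max_dep}")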