MAS-AI-0000's picture
Upload 2 files
1b51bf6 verified
raw
history blame
12 kB
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from .text_embedding import TextEmbeddingModel
class Tree():
def __init__(self,path):
self.name = {}
self.childs = {}
self.father = {}
self.dep = {}
self.root = None
self.max_dep = 0
self.subtree = {}
self.grad_fa = {} # the node closest to the root for each leaf
with open(path, 'r') as f:
lines = f.readlines()
for line in lines:
parts = line.strip().split()
assert len(parts) == 3, "Each line must have exactly three parts"
now,fa,name = parts
now,fa = int(now),int(fa)
if name != 'none':
self.name[now] = name.split(',')
if fa != -1:
self.childs[fa] = self.childs.get(fa, []) + [now]
self.father[now] = fa
else:
self.root = now
self.fa_pos = torch.zeros((len(self.father),len(self.father)),dtype=torch.bool)
self.dfs(self.root)
#max_dep,N,N+K 0/1
self.pos_down2up = torch.zeros((self.max_dep,len(self.name),len(self.father)),dtype=torch.bool)
self.neg_down2up = torch.zeros((self.max_dep,len(self.name),len(self.father)),dtype=torch.bool)
self.pos_up2down = torch.zeros((self.max_dep,len(self.name),len(self.father)),dtype=torch.bool)
self.neg_up2down = torch.zeros((self.max_dep,len(self.name),len(self.father)),dtype=torch.bool)
self.pos_center = torch.zeros((self.max_dep,len(self.name)),dtype=torch.long)
self.mask_center = torch.zeros((self.max_dep,len(self.name),len(self.father)),dtype=torch.bool)
#max_dep,N 0/1
self.mask = torch.zeros((self.max_dep,len(self.name)),dtype=torch.bool)
self.depth = torch.zeros(len(self.name))
self.labels = torch.zeros(len(self.name),dtype=torch.long)
self.vis_leaf()
label_value = list(set(self.grad_fa.values()))
for key, value in self.grad_fa.items():
self.labels[key] = label_value.index(value)
def dfs(self, node, depth=0,grfa=-1):
self.dep[node] = depth
self.max_dep = max(self.max_dep, depth)
if node!=self.root:
self.subtree[node] = torch.zeros(len(self.father),dtype=torch.bool)
self.subtree[node][node] = 1
# if self.fa_pos.get(node) is None:
if self.father[node] != self.root:
self.fa_pos[node] = self.fa_pos[self.father[node]].clone()
self.fa_pos[node][node] = 1
if grfa == -1:
grfa = node
if self.childs.get(node) is None:
self.grad_fa[node] = grfa
for child in self.childs.get(node, []):
self.dfs(child, depth + 1,grfa)
if node!=self.root:
self.subtree[node] = torch.logical_or(self.subtree[node], self.subtree[child])
def gen_leaf_item(self,node):
last_node = -1
leaf_id = node
self.depth[node] = self.dep[node]
while node != self.root:
now_dep=self.dep[node]-1
self.mask[now_dep,leaf_id] = 1
self.pos_center[now_dep,leaf_id] = node
self.mask_center[now_dep,leaf_id] = torch.logical_not(torch.logical_or(self.fa_pos[node],self.subtree[node]))
self.mask_center[now_dep,leaf_id,node] = 1
if last_node == -1:
self.pos_down2up[now_dep,leaf_id] = self.subtree[node]
else:
self.pos_down2up[now_dep,leaf_id]=torch.logical_xor(self.subtree[node],self.subtree[last_node])
self.neg_down2up[now_dep,leaf_id]=torch.logical_not(self.subtree[node])
if self.father[node] == self.root:
self.neg_up2down[now_dep,leaf_id] = torch.logical_not(self.subtree[node])
else:
self.neg_up2down[now_dep,leaf_id] = torch.logical_xor(self.subtree[node],self.subtree[self.father[node]])
self.pos_up2down[now_dep,leaf_id] = self.subtree[node]
last_node = node
node = self.father[node]
def vis_leaf(self):
for node, name in self.name.items():
self.gen_leaf_item(node)
def display(self):
for node, name in self.name.items():
depth = self.dep[node]
print(f"{depth}- {name} {self.father[node]}")
class SimCLR_Tree(nn.Module):
def __init__(self, opt, fabric):
super(SimCLR_Tree, self).__init__()
self.temperature = opt.temperature
self.opt = opt
self.fabric = fabric
adapter_path = getattr(opt, "adapter_path", None)
self.model = TextEmbeddingModel(
opt.model_name,
lora=opt.lora,
use_pooling=opt.pooling,
lora_r=opt.lora_r,
lora_alpha=opt.lora_alpha,
lora_dropout=opt.lora_dropout,
adapter_path=adapter_path,
)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.tree = Tree(opt.tree_txt)
self.pos_down2up = self.tree.pos_down2up.to(self.device)
self.neg_down2up = self.tree.neg_down2up.to(self.device)
self.pos_up2down = self.tree.pos_up2down.to(self.device)
self.neg_up2down = self.tree.neg_up2down.to(self.device)
self.pos_center = self.tree.pos_center.to(self.device)
self.mask_center = self.tree.mask_center.to(self.device)
self.K = self.pos_down2up.shape[0]
self.mask = self.tree.mask.to(self.device)
self.depth = self.tree.depth.to(self.device)
self.root_labels = self.tree.labels.to(self.device)
self.esp = torch.tensor(1e-6, device=self.device)
self.max_dep = self.tree.max_dep
self.leaf_cnt = len(self.tree.name)
self.names2id = {}
for key, value in self.tree.name.items():
for item in value:
self.names2id[item] = key
self.vitual_center = nn.Parameter(
torch.randn((len(self.tree.father), opt.projection_size), device=self.device),
requires_grad=True,
)
nn.init.xavier_uniform_(self.vitual_center)
self.center_labels = torch.arange(len(self.tree.father), dtype=torch.long, device=self.device)
if adapter_path is not None:
self.load_tree_state(adapter_path)
def get_encoder(self):
return self.model
def save_pretrained(self, save_directory: str, save_tokenizer: bool = True):
os.makedirs(save_directory, exist_ok=True)
self.model.save_pretrained(save_directory, save_tokenizer=save_tokenizer)
torch.save(
{"vitual_center": self.vitual_center.detach().cpu()},
os.path.join(save_directory, "tree_state.pt"),
)
def load_tree_state(self, directory: str):
state_path = os.path.join(directory, "tree_state.pt")
if not os.path.exists(state_path):
return
state = torch.load(state_path, map_location=self.vitual_center.device)
self.vitual_center.data.copy_(state["vitual_center"].to(self.vitual_center.device))
def load_from_directory(self, directory: str, is_trainable: bool = True):
if getattr(self.opt, "lora", False):
self.model.load_adapter(directory, is_trainable=is_trainable)
else:
self.model = TextEmbeddingModel(
directory,
lora=False,
use_pooling=self.opt.pooling,
output_hidden_states=False,
)
self.load_tree_state(directory)
def _compute_logits(self, q,q_labels,k,k_labels,pos_mask,neg_mask):
def cosine_similarity_matrix(q, k):
q_norm = F.normalize(q,dim=-1)
k_norm = F.normalize(k,dim=-1)
cosine_similarity = q_norm@k_norm.T
return cosine_similarity
def gen_label_mask(relation_matrix,q_labels, k_labels):
N1 = q_labels.shape[0]
N2 = k_labels.shape[0]
q_labels_expanded = q_labels.unsqueeze(1).expand(-1, N2) # N1 x N2
k_labels_expanded = k_labels.unsqueeze(0).expand(N1, -1) # N1 x N2
result_matrix = relation_matrix[:,q_labels_expanded, k_labels_expanded]
return result_matrix
logits=cosine_similarity_matrix(q,k)
logits=logits/self.temperature
logits = logits.unsqueeze(0).expand(self.K,-1,-1) #K,N1,N2
pos_mask = gen_label_mask(pos_mask,q_labels, k_labels)
neg_mask = gen_label_mask(neg_mask,q_labels, k_labels) #K,N1,N2
pos_logits = torch.sum(logits*pos_mask,dim=-1)/torch.max(torch.sum(pos_mask,dim=-1),self.esp)#K,N1
pos_logits = pos_logits.unsqueeze(-1)#K,N1,1
neg_logits = logits*neg_mask#K,N1,N2
logits = torch.cat((pos_logits, neg_logits), dim=-1)#K,N1,N2+1
#model:model set
# pos_logits_model = torch.sum(logits*same_model,dim=1)/torch.max(torch.sum(same_model,dim=1),self.esp)# N
# neg_logits_model=logits*torch.logical_not(same_model)# N,N+K
# logits_model=torch.cat((pos_logits_model.unsqueeze(1), neg_logits_model), dim=1)
return logits
def forward(self, encoded_batch, labels):
q = self.model(encoded_batch)
N1 = q.shape[0]
k = q.clone().detach()
k = self.fabric.all_gather(k).view(-1, k.size(1))
k_labels = self.fabric.all_gather(labels).view(-1)
now_depth = self.depth[labels].unsqueeze(0).expand(self.K,-1)
now_mask = self.mask[:,labels]
# leaf_labels = self.root_labels[labels]
k = torch.concat((k,self.vitual_center),dim=0)
k_labels = torch.concat((k_labels,self.center_labels),dim=0)
logits_sample = self._compute_logits(q,labels,k,k_labels,self.pos_down2up,self.neg_down2up)#K,N1,N2+1
gt_sample = torch.zeros(logits_sample.shape[:-1], dtype=torch.long,device=logits_sample.device)
logits_sample = logits_sample.permute(0,2,1)
loss_smaple1 = F.cross_entropy(logits_sample, gt_sample, reduction='none') #K,N1
loss_smaple1 = torch.sum((loss_smaple1/now_depth)*now_mask)/N1*self.max_dep
# out = self.root_classfier(q)
# loss_classfiy = F.cross_entropy(out, leaf_labels)
loss = loss_smaple1
return loss,loss_smaple1
# def forward(self, encoded_batch, labels):
# q = self.model(encoded_batch)
# # N1 = q.shape[0]
# # k = q.clone().detach()
# # k = self.fabric.all_gather(k).view(-1, k.size(1))
# # k_labels = self.fabric.all_gather(labels).view(-1)
# # now_depth = self.depth[labels].unsqueeze(0).expand(self.K,-1)
# # now_mask = self.mask[:,labels]
# leaf_labels = self.root_labels[labels]
# # k = torch.concat((k,self.vitual_center),dim=0)
# # k_labels = torch.concat((k_labels,self.center_labels),dim=0)
# # logits_sample = self._compute_logits(q,labels,k,k_labels,self.pos_down2up,self.neg_down2up)#K,N1,N2+1
# # gt_sample = torch.zeros(logits_sample.shape[:-1], dtype=torch.long,device=logits_sample.device)
# # logits_sample = logits_sample.permute(0,2,1)
# # loss_smaple1 = F.cross_entropy(logits_sample, gt_sample, reduction='none') #K,N1
# # loss_smaple1 = torch.sum((loss_smaple1/now_depth)*now_mask)/N1*self.max_dep
# out = self.root_classfier(q)
# loss_classfiy = F.cross_entropy(out, leaf_labels)
# loss = loss_classfiy
# return loss,loss_classfiy