# -*- coding: utf-8 -*-
'''
@Author     : Jiangjie Chen
@Time       : 2020/10/15 16:10
@Contact    : jjchen19@fudan.edu.cn
@Description:
'''

import torch
import random
import torch.nn.functional as F
import torch.nn as nn

class ClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, hidden_size, num_labels, hidden_dropout_prob=0.2):
        super().__init__()
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.out_proj = nn.Linear(hidden_size, num_labels, bias=False)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
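
# Usage sketch (illustrative, not part of the original module): the head maps
# a pooled sentence representation straight to label logits.
#   head = ClassificationHead(hidden_size=768, num_labels=3)
#   logits = head(torch.randn(4, 768))   # -> 4 x 3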

def temperature_annealing(tau, step):
    # tau == 0. selects the annealing schedule: a high temperature (10.)
    # on every 5th step and a low temperature (1.) otherwise; any other
    # tau is treated as fixed and passed through unchanged.
    if tau == 0.:
        tau = 10. if step % 5 == 0 else 1.
    return tau
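
# For example: temperature_annealing(0., 0) == 10.,
# temperature_annealing(0., 3) == 1., temperature_annealing(0.5, 3) == 0.5.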

def get_label_embeddings(labels, label_embedding):
    '''
    :param labels: b x 3
    :param label_embedding: 3 x h'
    :return: b x h'
    '''
    emb = torch.einsum('oi,bo->bi', label_embedding, labels)
    return emb
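
# The einsum is a soft lookup: each row of `labels` is a distribution over the
# 3 classes, so the output is a probability-weighted mix of label embeddings.
# Quick shape check (illustrative):
#   labels = torch.softmax(torch.randn(4, 3), dim=-1)   # b x 3
#   emb_table = torch.randn(3, 16)                      # 3 x h'
#   get_label_embeddings(labels, emb_table).shape       # -> (4, 16)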

def soft_logic(y_i, mask, tnorm='product'):
    '''
    Soft logical aggregation of per-evidence verdicts:
        a AND b = ab
        a OR  b = 1 - (1 - a)(1 - b)
    :param y_i: b x m x 3
    :param mask: b x m
    :param tnorm: 'product', 'godel' or 'lukas' (Lukasiewicz, not implemented)
    :return: b x 3
    '''
    _sup = y_i[:, :, 2]  # b x m
    _ref = y_i[:, :, 0]  # b x m

    # Neutralize padded positions: the AND-identity (1) for SUPPORTS,
    # the OR-identity (0) for REFUTES.
    _sup = _sup * mask + (1 - mask)  # pppp1111
    _ref = _ref * mask  # pppp0000

    if tnorm == 'product':
        # SUPPORTS requires every evidence to support (soft AND);
        # REFUTES requires at least one refutation (soft OR).
        p_sup = torch.exp(torch.log(_sup).sum(1))
        p_ref = 1 - torch.exp(torch.log(1 - _ref).sum(1))
    elif tnorm == 'godel':
        p_sup = _sup.min(-1).values
        p_ref = _ref.max(-1).values
    elif tnorm == 'lukas':
        raise NotImplementedError(tnorm)
    else:
        raise NotImplementedError(tnorm)

    p_nei = 1 - p_sup - p_ref
    p_sup = torch.max(p_sup, torch.zeros_like(p_sup))
    p_ref = torch.max(p_ref, torch.zeros_like(p_ref))
    p_nei = torch.max(p_nei, torch.zeros_like(p_nei))

    logical_prob = torch.stack([p_ref, p_nei, p_sup], dim=-1)
    assert torch.lt(logical_prob, 0).to(torch.int).sum().tolist() == 0, \
        (logical_prob, _sup, _ref)
    return logical_prob  # b x 3
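
# Worked example under the product t-norm (see also __main__ below): with two
# evidence verdicts [0.3, 0.5, 0.2] and [0.1, 0.4, 0.5],
#   p_sup = 0.2 * 0.5                   = 0.10
#   p_ref = 1 - (1 - 0.3) * (1 - 0.1)   = 0.37
#   p_nei = 1 - 0.10 - 0.37             = 0.53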

def build_pseudo_labels(labels, m_attn):
    '''
    :param labels: (b,)
    :param m_attn: b x m
    :return: b x m x 3
    '''
    mask = torch.gt(m_attn, 1e-16).to(torch.int)
    sup_label = torch.tensor(2).to(labels)
    nei_label = torch.tensor(1).to(labels)
    ref_label = torch.tensor(0).to(labels)

    pseudo_labels = []
    for idx, label in enumerate(labels):
        mm = mask[idx].sum(0)  # number of valid evidence slots
        if label == 2:  # SUPPORTS: every evidence is labeled SUPPORTS
            pseudo_label = F.one_hot(sup_label.repeat(mask.size(1)),
                                     num_classes=3).to(torch.float)  # TODO: hyperparam
        elif label == 0:  # REFUTES: top-attended evidence is labeled REFUTES
            num_samples = int(magic_proportion(mm))
            ids = torch.topk(m_attn[idx], k=num_samples).indices
            pseudo_label = []
            for i in range(mask.size(1)):
                if i >= mm:  # padded slot: uniform distribution
                    _label = torch.tensor([1 / 3, 1 / 3, 1 / 3]).to(labels.device)
                elif i in ids:
                    _label = F.one_hot(ref_label, num_classes=3).to(torch.float)
                else:  # remaining evidence: SUPPORTS or NEI at random
                    if random.random() > 0.5:
                        _label = torch.tensor([0., 0., 1.]).to(labels.device)
                    else:
                        _label = torch.tensor([0., 1., 0.]).to(labels.device)
                pseudo_label.append(_label)
            pseudo_label = torch.stack(pseudo_label)
        else:  # NEI: top-attended evidence is labeled NEI, the rest SUPPORTS
            num_samples = int(magic_proportion(mm))
            ids = torch.topk(m_attn[idx], k=num_samples).indices
            pseudo_label = sup_label.repeat(mask.size(1))
            pseudo_label[ids] = nei_label
            pseudo_label = F.one_hot(pseudo_label, num_classes=3).to(torch.float)  # TODO: hyperparam
        pseudo_labels.append(pseudo_label)
    return torch.stack(pseudo_labels)
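
# Sketch of the intended use (values illustrative): given gold claim labels
# and evidence attention, each evidence slot gets a (soft) verdict that can
# supervise a per-evidence classifier.
#   labels = torch.tensor([2, 0, 1])                     # (b,) REF/NEI/SUP ids
#   m_attn = torch.softmax(torch.randn(3, 5), dim=-1)    # b x m
#   build_pseudo_labels(labels, m_attn).shape            # -> (3, 5, 3)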

def magic_proportion(m, magic_n=5):
    # Number of evidence pieces to relabel: 1 for m in 0..4, 2 for m in 5..9, etc.
    return m // magic_n + 1
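
# e.g. magic_proportion(3) == 1, magic_proportion(5) == 2, magic_proportion(12) == 3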

def sequence_mask(lengths, max_len=None):
    """
    Creates a boolean mask from sequence lengths.
    """
    batch_size = lengths.numel()
    max_len = max_len or lengths.max()
    return (torch.arange(0, max_len, device=lengths.device)
            .type_as(lengths)
            .repeat(batch_size, 1)
            .lt(lengths.unsqueeze(1)))
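
# Example (illustrative): positions before each length are kept.
#   sequence_mask(torch.tensor([1, 3]), max_len=4)
#   # -> tensor([[ True, False, False, False],
#   #            [ True,  True,  True, False]])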

def collapse_w_mask(inputs, mask):
    '''
    Mean-pools the unmasked positions of a sequence.
    :param inputs: b x L x h
    :param mask: b x L
    :return: b x h
    '''
    hidden = inputs.size(-1)
    output = inputs * mask.unsqueeze(-1).repeat((1, 1, hidden))  # b x L x h
    output = output.sum(-2)
    output /= (mask.sum(-1) + 1e-6).unsqueeze(-1).repeat((1, hidden))  # b x h
    return output
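
# Sanity sketch: averaging two unmasked rows of ones returns ones, up to the
# 1e-6 smoothing term in the denominator.
#   x = torch.ones(1, 4, 8)
#   m = torch.tensor([[1, 1, 0, 0]])
#   collapse_w_mask(x, m)   # -> roughly a 1 x 8 tensor of ones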

def parse_ce_outputs(ce_seq_output, ce_lengths):
    '''
    :param ce_seq_output: b x L1 x h
    :param ce_lengths: e.g. [0,1,1,0,2,2,0,0] (b x L1), where span id 1 marks
        the claim tokens and ids >= 2 mark evidence tokens
    :return:
        c_output: b x h
        e_output: b x h
    '''
    if ce_lengths.max() == 0:
        b, L1, h = ce_seq_output.size()
        return (torch.zeros([b, h], device=ce_seq_output.device),
                torch.zeros([b, h], device=ce_seq_output.device))
    masks = []
    for mask_id in range(1, ce_lengths.max() + 1):
        _m = torch.ones_like(ce_lengths) * mask_id
        mask = _m.eq(ce_lengths).to(torch.int)
        masks.append(mask)
    c_output = collapse_w_mask(ce_seq_output, masks[0])
    e_output = torch.stack([collapse_w_mask(ce_seq_output, m)
                            for m in masks[1:]]).mean(0)
    return c_output, e_output
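
# Illustrative call: token-level span ids pool a shared encoder output into
# one claim vector and one (averaged) evidence vector.
#   seq = torch.randn(2, 8, 16)
#   lengths = torch.tensor([[0, 1, 1, 0, 2, 2, 0, 0],
#                           [1, 1, 0, 2, 2, 3, 3, 0]])
#   c, e = parse_ce_outputs(seq, lengths)   # each 2 x 16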

def parse_qa_outputs(qa_seq_output, qa_lengths, k):
    '''
    :param qa_seq_output: b x L2 x h
    :param qa_lengths: e.g. [0,1,1,0,2,2,0,3,0,4,0,5,0,0,0,0] (b x L2), where
        span id 1 marks the question, id 2 the answer, and ids >= 3 the
        candidate answers
    :return:
        q_output: b x h
        a_output: b x h
        k_cand_output: k x b x h
    '''
    b, L2, h = qa_seq_output.size()
    if qa_lengths.max() == 0:
        return (torch.zeros([b, h], device=qa_seq_output.device),
                torch.zeros([b, h], device=qa_seq_output.device),
                torch.zeros([k, b, h], device=qa_seq_output.device))

    masks = []
    for mask_id in range(1, qa_lengths.max() + 1):
        _m = torch.ones_like(qa_lengths) * mask_id
        mask = _m.eq(qa_lengths).to(torch.int)
        masks.append(mask)

    q_output = collapse_w_mask(qa_seq_output, masks[0])
    a_output = collapse_w_mask(qa_seq_output, masks[1])
    k_cand_output = [collapse_w_mask(qa_seq_output, m)
                     for m in masks[2:2 + k]]
    # Pad with zero vectors when fewer than k candidates are present.
    for i in range(k - len(k_cand_output)):
        k_cand_output.append(torch.zeros([b, h], device=qa_seq_output.device))
    k_cand_output = torch.stack(k_cand_output, dim=0)

    return q_output, a_output, k_cand_output
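
# Illustrative call, mirroring parse_ce_outputs: span id 1 pools the question,
# id 2 the answer, ids 3.. the k candidates (zero-padded up to k).
#   seq = torch.randn(1, 10, 16)
#   lengths = torch.tensor([[0, 1, 1, 0, 2, 2, 0, 3, 0, 4]])
#   q, a, cands = parse_qa_outputs(seq, lengths, k=3)   # cands: 3 x 1 x 16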

def attention_mask_to_mask(attention_mask):
    '''
    :param attention_mask: b x m x L
    :return: b x m
    '''
    # An evidence slot is valid if any of its tokens is attended to.
    mask = torch.gt(attention_mask.sum(-1), 0).to(torch.int).sum(-1)  # (b,)
    mask = sequence_mask(mask, max_len=attention_mask.size(1)).to(torch.int)  # (b, m)
    return mask
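
# Example (illustrative): two of three evidence slots contain tokens.
#   am = torch.tensor([[[1, 1], [1, 0], [0, 0]]])   # b x m x L
#   attention_mask_to_mask(am)                      # -> tensor([[1, 1, 0]])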

if __name__ == "__main__":
    y = torch.tensor([[[0.3, 0.5, 0.2], [0.1, 0.4, 0.5]]])  # b x m x 3
    mask = torch.tensor([[1., 1.]])  # b x m
    s = soft_logic(y, mask)
    print(s)  # expected: tensor([[0.3700, 0.5300, 0.1000]])