import tqdm
import time
import cProfile
import numpy as np
from contextlib import contextmanager

import torch
from torch import nn

from transformers import BertTokenizer, BertConfig, BertModel
from transformers import RobertaModel, RobertaConfig, RobertaTokenizer
from transformers import XLNetTokenizer, XLNetModel, XLNetConfig
from transformers import AutoConfig, AutoModel, AutoTokenizer
from transformers import RobertaTokenizerFast

from tokenizers import BertWordPieceTokenizer

try:
    from apex import amp
except Exception:
    # Fallback when NVIDIA apex is not installed: mimic the two amp entry
    # points this module uses (initialize and scale_loss) as no-ops, so
    # training still runs, just without mixed precision.
    class _DummyAmp:
        def initialize(self, model, optimizer, opt_level="O1"):
            return model, optimizer

        @contextmanager
        def scale_loss(self, loss, optimizer):
            yield loss

    amp = _DummyAmp()


def get_bert(bert_name):
    if 'roberta' in bert_name:
        print('load roberta-base')
        model_config = RobertaConfig.from_pretrained('roberta-base')
        model_config.output_hidden_states = True
        bert = RobertaModel.from_pretrained('roberta-base', config=model_config)
    elif 'xlnet' in bert_name:
        print('load xlnet-base-cased')
        model_config = XLNetConfig.from_pretrained('xlnet-base-cased')
        model_config.output_hidden_states = True
        bert = XLNetModel.from_pretrained('xlnet-base-cased', config=model_config)
    elif bert_name in ['bert-base', 'bert-base-uncased']:
        print('load bert-base-uncased')
        model_config = BertConfig.from_pretrained('bert-base-uncased')
        model_config.output_hidden_states = True
        bert = BertModel.from_pretrained('bert-base-uncased', config=model_config)
    else:
        print(f'load {bert_name}')
        model_config = AutoConfig.from_pretrained(bert_name)
        model_config.output_hidden_states = True
        bert = AutoModel.from_pretrained(bert_name, config=model_config)
    return bert
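
# Usage sketch (the model names are examples; weights download on first use):
#   encoder = get_bert('roberta')             # loads roberta-base
#   encoder = get_bert('bert-base-chinese')   # falls through to AutoModel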


class LightXML(nn.Module):

    def __init__(self, n_labels, group_y=None, bert='bert-base', feature_layers=5, dropout=0.5, update_count=1,
                 candidates_topk=10,
                 use_swa=True, swa_warmup_epoch=10, swa_update_step=200, hidden_dim=300):
        super(LightXML, self).__init__()

        self.use_swa = use_swa
        self.swa_warmup_epoch = swa_warmup_epoch
        self.swa_update_step = swa_update_step
        self.swa_state = {}

        self.update_count = update_count
        self.candidates_topk = candidates_topk

        print('swa', self.use_swa, self.swa_warmup_epoch, self.swa_update_step, self.swa_state)
        print('update_count', self.update_count)

        self.bert_name, self.bert = bert, get_bert(bert)
        self.feature_layers, self.drop_out = feature_layers, nn.Dropout(dropout)

        # group_y maps each label group (meta-label) to its member label ids.
        # When given, the model uses the two-stage group -> label architecture.
        self.group_y = group_y
        if self.group_y is not None:
            self.group_y_labels = group_y.shape[0]
            print('hidden dim:', hidden_dim)
            print('label group numbers:', self.group_y_labels)

            # Group (meta-label) classifier.
            self.l0 = nn.Linear(self.feature_layers * self.bert.config.hidden_size, self.group_y_labels)
            # Label ranking head: document projection plus label embeddings.
            self.l1 = nn.Linear(self.feature_layers * self.bert.config.hidden_size, hidden_dim)
            self.embed = nn.Embedding(n_labels, hidden_dim)
            nn.init.xavier_uniform_(self.embed.weight)
        else:
            self.l0 = nn.Linear(self.feature_layers * self.bert.config.hidden_size, n_labels)

    def get_candidates(self, group_logits, group_gd=None):
        # Rank label groups by sigmoid score; during training the ground-truth
        # group indicators (group_gd) are added so gold groups are retrieved.
        logits = torch.sigmoid(group_logits.detach())
        if group_gd is not None:
            logits += group_gd
        k = min(self.candidates_topk, logits.shape[1])
        if k <= 0:
            raise ValueError("No group labels available for candidate selection")
        scores, indices = torch.topk(logits, k=k)
        scores, indices = scores.detach().cpu().numpy(), indices.detach().cpu().numpy()
        candidates, candidates_scores = [], []
        for index, score in zip(indices, scores):
            # Expand each selected group into its member labels, carrying the
            # group score over to every member label.
            candidates.append(self.group_y[index])
            candidates_scores.append([np.full(c.shape, s) for c, s in zip(candidates[-1], score)])
            candidates[-1] = np.concatenate(candidates[-1])
            candidates_scores[-1] = np.concatenate(candidates_scores[-1])
        # Edge-pad every row to a common length so the batch can be stacked.
        max_candidates = max([i.shape[0] for i in candidates])
        candidates = np.stack([np.pad(i, (0, max_candidates - i.shape[0]), mode='edge') for i in candidates])
        candidates_scores = np.stack([np.pad(i, (0, max_candidates - i.shape[0]), mode='edge') for i in candidates_scores])
        return indices, candidates, candidates_scores
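
    # Shape walkthrough (illustrative numbers): with batch size B and
    # candidates_topk=10, `indices` is (B, 10) group ids, `candidates` is
    # (B, max_candidates) member label ids (edge-padded per row), and
    # `candidates_scores` repeats each group's score for every member label.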

    def forward(self, input_ids, attention_mask, token_type_ids,
                labels=None, group_labels=None, candidates=None):
        is_training = labels is not None

        # get_bert sets output_hidden_states=True, so the last element of the
        # encoder output is the tuple of all hidden states.
        outs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )[-1]

        # Concatenate the [CLS] vectors of the last `feature_layers` layers.
        out = torch.cat([outs[-i][:, 0] for i in range(1, self.feature_layers + 1)], dim=-1)
        out = self.drop_out(out)
        group_logits = self.l0(out)
        if self.group_y is None:
            logits = group_logits
            if is_training:
                loss_fn = torch.nn.BCEWithLogitsLoss()
                loss = loss_fn(logits, labels)
                return logits, loss
            else:
                return logits

        if is_training:
            l = labels.to(dtype=torch.bool)
            target_candidates = torch.masked_select(candidates, l).detach().cpu()
            target_candidates_num = l.sum(dim=1).detach().cpu()
        # Candidates must also be retrieved at inference time, so this call
        # sits outside the training-only block above.
        groups, candidates, group_candidates_scores = self.get_candidates(
            group_logits, group_gd=group_labels if is_training else None)
        if is_training:
            bs = 0
            new_labels = []
            for i, n in enumerate(target_candidates_num.numpy()):
                be = bs + n
                c = set(target_candidates[bs: be].numpy())
                c2 = candidates[i]
                new_labels.append(torch.tensor([1.0 if lbl in c else 0.0 for lbl in c2]))
                # If a positive label was not retrieved, swap it in over a
                # negative candidate slot so it can still be trained on.
                if len(c) != new_labels[-1].sum():
                    s_c2 = set(c2)
                    for cc in list(c):
                        if cc in s_c2:
                            continue
                        for j in range(new_labels[-1].shape[0]):
                            if new_labels[-1][j].item() != 1:
                                c2[j] = cc
                                new_labels[-1][j] = 1.0
                                break
                bs = be
            labels = torch.stack(new_labels).to(group_logits.device)
        candidates = torch.LongTensor(candidates).to(group_logits.device)
        group_candidates_scores = torch.Tensor(group_candidates_scores).to(group_logits.device)

        emb = self.l1(out)
        embed_weights = self.embed(candidates)
        emb = emb.unsqueeze(-1)
        logits = torch.bmm(embed_weights, emb).squeeze(-1)

        if is_training:
            loss_fn = torch.nn.BCEWithLogitsLoss()
            loss_rank = loss_fn(logits, labels)
            loss_group = loss_fn(group_logits, group_labels)
            loss = loss_rank + loss_group
            return logits, loss, loss_group.detach(), loss_rank.detach()
        else:
            candidates_scores = torch.sigmoid(logits)
            candidates_scores = candidates_scores * group_candidates_scores
            return group_logits, candidates, candidates_scores
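
    # Ranking-head shapes: with batch B, candidate count C, and hidden_dim H,
    # embed_weights is (B, C, H) and emb is (B, H, 1), so torch.bmm yields
    # (B, C, 1) per-candidate scores, squeezed to (B, C).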

    def get_accuracy(self, candidates, logits, labels):
        if candidates is not None:
            candidates = candidates.detach().cpu()
        k = min(10, logits.shape[1])
        if k <= 0:
            return 0, 0, 0, 0
        scores, indices = torch.topk(logits.detach().cpu(), k=k)

        acc1, acc3, acc5, total = 0, 0, 0, 0
        for i, l in enumerate(labels):
            if candidates is not None and l.shape[0] == candidates.shape[1]:
                # Labels are given over candidate slots; map the positive
                # slots back to global label ids.
                positive_idx = np.nonzero(l)[0]
                target_labels = set(candidates[i][positive_idx].numpy())
            else:
                target_labels = set(np.nonzero(l)[0])

            if candidates is not None:
                pred_labels = candidates[i][indices[i]].numpy()
            else:
                pred_labels = indices[i, :5].numpy()

            acc1 += len(set([pred_labels[0]]) & target_labels)
            acc3 += len(set(pred_labels[:3]) & target_labels)
            acc5 += len(set(pred_labels[:5]) & target_labels)
            total += 1

        return total, acc1, acc3, acc5
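
    # Example: with true labels {3, 7} and top-5 predictions [7, 1, 3, 9, 2],
    # acc1 = 1, acc3 = 2, acc5 = 2; the caller turns the sums into
    # P@1 = acc1/total, P@3 = acc3/(3*total), P@5 = acc5/(5*total).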

    def one_epoch(self, epoch, dataloader, optimizer,
                  mode='train', eval_loader=None, eval_step=20000, log=None, log_interval=50, use_tqdm=False):

        total_steps = len(dataloader)
        bar = tqdm.tqdm(total=total_steps) if use_tqdm else None
        p1, p3, p5 = 0, 0, 0
        g_p1, g_p3, g_p5 = 0, 0, 0
        total, acc1, acc3, acc5 = 0, 0, 0, 0
        g_acc1, g_acc3, g_acc5 = 0, 0, 0
        train_loss = 0

        if mode == 'train':
            self.train()
            self.zero_grad()
        else:
            self.eval()

        if self.use_swa and epoch == self.swa_warmup_epoch and mode == 'train':
            self.swa_init()

        if self.use_swa and mode == 'eval':
            self.swa_swap_params()

        pred_scores, pred_labels = [], []
        if bar:
            bar.set_description(f'{mode}-{epoch}')

        device = next(self.parameters()).device

        def _gpu_mem():
            if device.type != "cuda":
                return "gpu_mem=NA"
            alloc = torch.cuda.memory_allocated(device) // (1024 * 1024)
            total_mem = torch.cuda.get_device_properties(device).total_memory // (1024 * 1024)
            return f"gpu_mem={alloc}MB/{total_mem}MB"

        with torch.set_grad_enabled(mode == 'train'):
            for step, data in enumerate(dataloader):
                batch = tuple(t for t in data)
                inputs = {'input_ids': batch[0].to(device),
                          'attention_mask': batch[1].to(device),
                          'token_type_ids': batch[2].to(device)}
                if mode == 'train':
                    inputs['labels'] = batch[3].to(device)
                    if self.group_y is not None:
                        inputs['group_labels'] = batch[4].to(device)
                        inputs['candidates'] = batch[5].to(device)

                outputs = self(**inputs)
                if bar:
                    bar.update(1)

                if mode == 'train':
                    loss = outputs[1]
                    # The detached loss components are only returned when the
                    # label-group head is active (see forward).
                    loss_group = outputs[2] if len(outputs) > 2 else None
                    loss_rank = outputs[3] if len(outputs) > 3 else None
                    loss /= self.update_count
                    train_loss += loss.item()

                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()

                    if (step + 1) % self.update_count == 0:
                        optimizer.step()
                        self.zero_grad()

                    if step % eval_step == 0 and eval_loader is not None and step != 0:
                        results = self.one_epoch(epoch, eval_loader, optimizer, mode='eval')
                        p1, p3, p5 = results[3:6]
                        g_p1, g_p3, g_p5 = results[:3]
                        if log is not None:
                            if self.group_y is not None:
                                log.log(f'{epoch:>2} {step:>6}: {p1:.4f}, {p3:.4f}, {p5:.4f}'
                                        f' {g_p1:.4f}, {g_p3:.4f}, {g_p5:.4f}')
                            else:
                                log.log(f'{epoch:>2} {step:>6}: {p1:.4f}, {p3:.4f}, {p5:.4f}')
                        # The nested eval call left the module in eval mode.
                        self.train()

                    if self.use_swa and step % self.swa_update_step == 0:
                        self.swa_step()

                    if bar:
                        bar.set_postfix(loss=loss.item())
                    if log_interval and step % log_interval == 0:
                        extra = ''
                        if loss_group is not None and loss_rank is not None:
                            extra = (f" loss_group={loss_group.item():.6f}"
                                     f" loss_rank={loss_rank.item():.6f}")
                        print(f"[train] epoch={epoch} step={step+1}/{total_steps} "
                              f"loss={loss.item():.6f}{extra} {_gpu_mem()}")
                elif self.group_y is None:
                    logits = outputs
                    if mode == 'eval':
                        labels = batch[3]
                        _total, _acc1, _acc3, _acc5 = self.get_accuracy(None, logits, labels.cpu().numpy())
                        total += _total; acc1 += _acc1; acc3 += _acc3; acc5 += _acc5
                        p1 = acc1 / total
                        p3 = acc3 / total / 3
                        p5 = acc5 / total / 5
                        if bar:
                            bar.set_postfix(p1=p1, p3=p3, p5=p5)
                    elif mode == 'test':
                        pred_scores.append(logits.detach().cpu())
                else:
                    group_logits, candidates, logits = outputs

                    if mode == 'eval':
                        labels = batch[3]
                        group_labels = batch[4]

                        _total, _acc1, _acc3, _acc5 = self.get_accuracy(candidates, logits, labels.cpu().numpy())
                        total += _total; acc1 += _acc1; acc3 += _acc3; acc5 += _acc5
                        p1 = acc1 / total
                        p3 = acc3 / total / 3
                        p5 = acc5 / total / 5

                        _, _g_acc1, _g_acc3, _g_acc5 = self.get_accuracy(None, group_logits, group_labels.cpu().numpy())
                        g_acc1 += _g_acc1; g_acc3 += _g_acc3; g_acc5 += _g_acc5
                        g_p1 = g_acc1 / total
                        g_p3 = g_acc3 / total / 3
                        g_p5 = g_acc5 / total / 5
                        if bar:
                            bar.set_postfix(p1=p1, p3=p3, p5=p5, g_p1=g_p1, g_p3=g_p3, g_p5=g_p5)
                        if log_interval and step % log_interval == 0:
                            print(f"[eval] epoch={epoch} step={step+1}/{total_steps} "
                                  f"g_p1={g_p1:.4f} g_p3={g_p3:.4f} g_p5={g_p5:.4f} "
                                  f"p1={p1:.4f} p3={p3:.4f} p5={p5:.4f} {_gpu_mem()}")
                    elif mode == 'test':
                        k = min(100, logits.shape[1])
                        _scores, _indices = torch.topk(logits.detach().cpu(), k=k)
                        _labels = torch.stack([candidates[i][_indices[i]] for i in range(_indices.shape[0])], dim=0)
                        pred_scores.append(_scores.cpu())
                        pred_labels.append(_labels.cpu())

        # Flush any gradients left over from an incomplete accumulation window.
        if mode == 'train' and total_steps > 0 and total_steps % self.update_count != 0:
            optimizer.step()
            self.zero_grad()

        if self.use_swa and mode == 'eval':
            self.swa_swap_params()
        if bar:
            bar.close()

        if mode == 'eval':
            return g_p1, g_p3, g_p5, p1, p3, p5
        elif mode == 'test':
            scores = torch.cat(pred_scores, dim=0).numpy() if pred_scores else None
            labels_out = torch.cat(pred_labels, dim=0).numpy() if pred_labels else None
            return scores, labels_out
        elif mode == 'train':
            return train_loss
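
    # Minimal training-driver sketch (loader names and the learning rate are
    # illustrative, not fixed by this module):
    #   model = LightXML(n_labels=1000).cuda()
    #   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    #   model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    #   for epoch in range(20):
    #       model.one_epoch(epoch, train_loader, optimizer, mode='train')
    #       g_p1, g_p3, g_p5, p1, p3, p5 = model.one_epoch(
    #           epoch, valid_loader, optimizer, mode='eval')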

    def save_model(self, path):
        # Swap in the SWA-averaged weights, save, then swap back.
        self.swa_swap_params()
        torch.save(self.state_dict(), path)
        self.swa_swap_params()

    def swa_init(self):
        self.swa_state = {'models_num': 1}
        for n, p in self.named_parameters():
            self.swa_state[n] = p.data.detach().cpu().clone()

    def swa_step(self):
        if 'models_num' not in self.swa_state:
            return
        self.swa_state['models_num'] += 1
        beta = 1.0 / self.swa_state['models_num']
        with torch.no_grad():
            for n, p in self.named_parameters():
                # Running average: state <- (1 - beta) * state + beta * param.
                self.swa_state[n].mul_(1.0 - beta).add_(p.data.detach().cpu(), alpha=beta)

    def swa_swap_params(self):
        if 'models_num' not in self.swa_state:
            return
        # Exchange the live parameters with the CPU-resident SWA averages,
        # keeping each parameter on its original device.
        for n, p in self.named_parameters():
            buf = self.swa_state[n]
            self.swa_state[n] = p.data.detach().cpu()
            p.data = buf.to(p.data.device)

    def get_fast_tokenizer(self):
        if 'roberta' in self.bert_name:
            tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base', do_lower_case=True)
        elif 'xlnet' in self.bert_name:
            tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
        elif self.bert_name in ['bert-base', 'bert-base-uncased']:
            # Expects a locally cached vocab file at this path.
            tokenizer = BertWordPieceTokenizer(
                "data/.bert-base-uncased-vocab.txt",
                lowercase=True)
        else:
            tokenizer = AutoTokenizer.from_pretrained(self.bert_name, use_fast=True)
        return tokenizer

    def get_tokenizer(self):
        # Prefer a locally cached tokenizer; fall back to downloading.
        if 'roberta' in self.bert_name:
            print('load roberta-base tokenizer')
            try:
                tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True, local_files_only=True)
            except Exception:
                tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True)
        elif 'xlnet' in self.bert_name:
            print('load xlnet-base-cased tokenizer')
            try:
                tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased', local_files_only=True)
            except Exception:
                tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
        elif self.bert_name in ['bert-base', 'bert-base-uncased']:
            print('load bert-base-uncased tokenizer')
            try:
                tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, local_files_only=True)
            except Exception:
                tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
        elif self.bert_name == 'bert-base-chinese':
            print('load bert-base-chinese tokenizer')
            try:
                tokenizer = BertTokenizer.from_pretrained('bert-base-chinese', local_files_only=True)
            except Exception:
                tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        else:
            print(f'load {self.bert_name} tokenizer')
            try:
                tokenizer = AutoTokenizer.from_pretrained(self.bert_name, local_files_only=True)
            except Exception:
                tokenizer = AutoTokenizer.from_pretrained(self.bert_name)
        return tokenizer
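
    # Tokenizer usage sketch (the input text is illustrative):
    #   tokenizer = model.get_tokenizer()
    #   enc = tokenizer('an example document', truncation=True, max_length=512)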


class LightXMLMultiZone(nn.Module):
    def __init__(self, n_labels, bert='bert-base-chinese', feature_layers=5, dropout=0.5,
                 update_count=1, fusion='concat', hidden_dim=300):
        super(LightXMLMultiZone, self).__init__()

        self.update_count = update_count
        self.bert_name, self.bert = bert, get_bert(bert)
        self.feature_layers, self.drop_out = feature_layers, nn.Dropout(dropout)
        self.fusion = fusion

        # Note: 'concat' and 'weighted' fusion hard-code exactly 6 zones.
        self.hidden_size = self.bert.config.hidden_size * self.feature_layers
        if self.fusion == 'concat':
            self.fusion_proj = nn.Linear(self.hidden_size * 6, self.hidden_size)
        elif self.fusion == 'weighted':
            self.fusion_weight = nn.Parameter(torch.ones(6))
        elif self.fusion == 'attention':
            self.attn = nn.Linear(self.hidden_size, 1)
        else:
            raise ValueError(f"Unsupported fusion: {self.fusion}")

        self.l0 = nn.Linear(self.hidden_size, n_labels)

    def encode_zones(self, input_ids, attention_mask, token_type_ids):
        # Inputs are (batch, zones, seq_len); flatten zones into the batch so
        # the encoder sees each zone as an independent sequence.
        bsz, zones, seqlen = input_ids.shape
        flat_ids = input_ids.view(bsz * zones, seqlen)
        flat_mask = attention_mask.view(bsz * zones, seqlen)
        flat_type = token_type_ids.view(bsz * zones, seqlen)

        outs = self.bert(
            flat_ids,
            attention_mask=flat_mask,
            token_type_ids=flat_type
        )[-1]

        cls = torch.cat([outs[-i][:, 0] for i in range(1, self.feature_layers + 1)], dim=-1)
        cls = cls.view(bsz, zones, -1)
        return cls

    def fuse(self, zone_reps, return_weights=False):
        if self.fusion == 'concat':
            fused = zone_reps.reshape(zone_reps.shape[0], -1)
            fused = self.fusion_proj(fused)
            if return_weights:
                return fused, None
            return fused
        if self.fusion == 'weighted':
            weights = torch.softmax(self.fusion_weight, dim=0)
            fused = (zone_reps * weights.view(1, -1, 1)).sum(dim=1)
            if return_weights:
                return fused, weights
            return fused
        if self.fusion == 'attention':
            attn_scores = self.attn(zone_reps).squeeze(-1)
            attn_weights = torch.softmax(attn_scores, dim=1)
            fused = (zone_reps * attn_weights.unsqueeze(-1)).sum(dim=1)
            if return_weights:
                return fused, attn_weights
            return fused
        raise ValueError(f"Unsupported fusion: {self.fusion}")
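
    # The class defines the `l0` classifier but no forward pass; below is a
    # minimal sketch of the intended flat head, assuming it mirrors LightXML's
    # group_y=None branch (an assumption, not confirmed by the original code):
    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        zone_reps = self.encode_zones(input_ids, attention_mask, token_type_ids)
        out = self.drop_out(self.fuse(zone_reps))
        logits = self.l0(out)
        if labels is not None:
            loss = torch.nn.BCEWithLogitsLoss()(logits, labels)
            return logits, loss
        return logits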


class LightXMLMultiZoneGroupY(nn.Module):
    def __init__(self, n_labels, group_y, bert='bert-base-chinese', feature_layers=5, dropout=0.5,
                 update_count=1, fusion='concat', hidden_dim=300, candidates_topk=20,
                 teacher_force_lambda=0.3):
        super(LightXMLMultiZoneGroupY, self).__init__()

        self.use_swa = True
        self.swa_warmup_epoch = 10
        self.swa_update_step = 200
        self.swa_state = {}

        self.update_count = update_count
        self.candidates_topk = candidates_topk
        self.teacher_force_lambda = teacher_force_lambda
        self.bert_name, self.bert = bert, get_bert(bert)
        self.feature_layers, self.drop_out = feature_layers, nn.Dropout(dropout)
        self.fusion = fusion
        self.group_y = group_y

        # As in LightXMLMultiZone, 'concat' and 'weighted' fusion assume a
        # fixed 6 zones per document.
        self.hidden_size = self.bert.config.hidden_size * self.feature_layers
        if self.fusion == 'concat':
            self.fusion_proj = nn.Linear(self.hidden_size * 6, self.hidden_size)
        elif self.fusion == 'weighted':
            self.fusion_weight = nn.Parameter(torch.ones(6))
        elif self.fusion == 'attention':
            self.attn = nn.Linear(self.hidden_size, 1)
        else:
            raise ValueError(f"Unsupported fusion: {self.fusion}")

        self.group_y_labels = group_y.shape[0]
        self.l0 = nn.Linear(self.hidden_size, self.group_y_labels)
        self.l1 = nn.Linear(self.hidden_size, hidden_dim)
        self.embed = nn.Embedding(n_labels, hidden_dim)
        nn.init.xavier_uniform_(self.embed.weight)

    def encode_zones(self, input_ids, attention_mask, token_type_ids):
        bsz, zones, seqlen = input_ids.shape
        flat_ids = input_ids.view(bsz * zones, seqlen)
        flat_mask = attention_mask.view(bsz * zones, seqlen)
        flat_type = token_type_ids.view(bsz * zones, seqlen)
        outs = self.bert(
            flat_ids,
            attention_mask=flat_mask,
            token_type_ids=flat_type
        )[-1]
        cls = torch.cat([outs[-i][:, 0] for i in range(1, self.feature_layers + 1)], dim=-1)
        return cls.view(bsz, zones, -1)

    def fuse(self, zone_reps):
        if self.fusion == 'concat':
            fused = zone_reps.reshape(zone_reps.shape[0], -1)
            return self.fusion_proj(fused)
        if self.fusion == 'weighted':
            weights = torch.softmax(self.fusion_weight, dim=0)
            return (zone_reps * weights.view(1, -1, 1)).sum(dim=1)
        if self.fusion == 'attention':
            attn_scores = self.attn(zone_reps).squeeze(-1)
            attn_weights = torch.softmax(attn_scores, dim=1)
            return (zone_reps * attn_weights.unsqueeze(-1)).sum(dim=1)
        raise ValueError(f"Unsupported fusion: {self.fusion}")

    def get_candidates(self, group_logits, group_gd=None):
        logits = torch.sigmoid(group_logits.detach())
        # Soft teacher forcing: boost (rather than force) the gold groups.
        if group_gd is not None and self.teacher_force_lambda > 0:
            logits += self.teacher_force_lambda * group_gd
        k = min(self.candidates_topk, logits.shape[1])
        if k <= 0:
            raise ValueError("No group labels available for candidate selection")
        scores, indices = torch.topk(logits, k=k)
        scores, indices = scores.detach().cpu().numpy(), indices.detach().cpu().numpy()
        candidates, candidates_scores = [], []
        for index, score in zip(indices, scores):
            candidates.append(self.group_y[index])
            candidates_scores.append([np.full(c.shape, s) for c, s in zip(candidates[-1], score)])
            candidates[-1] = np.concatenate(candidates[-1])
            candidates_scores[-1] = np.concatenate(candidates_scores[-1])
        max_candidates = max([i.shape[0] for i in candidates])
        candidates = np.stack([np.pad(i, (0, max_candidates - i.shape[0]), mode='edge') for i in candidates])
        candidates_scores = np.stack([np.pad(i, (0, max_candidates - i.shape[0]), mode='edge') for i in candidates_scores])
        return indices, candidates, candidates_scores
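
    # With teacher_force_lambda=0.3, a gold group gets +0.3 added to its
    # sigmoid score during training, so true groups are favoured in candidate
    # retrieval without being guaranteed a slot (unlike the hard +1.0 boost
    # in LightXML.get_candidates).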

    def forward(self, input_ids, attention_mask, token_type_ids,
                labels=None, group_labels=None, candidates=None):
        is_training = labels is not None
        zone_reps = self.encode_zones(input_ids, attention_mask, token_type_ids)
        out = self.drop_out(self.fuse(zone_reps))
        group_logits = self.l0(out)

        if is_training:
            l = labels.to(dtype=torch.bool)
            target_candidates = torch.masked_select(candidates, l).detach().cpu()
            target_candidates_num = l.sum(dim=1).detach().cpu()
        # Candidates must also be retrieved at inference time, so this call
        # sits outside the training-only block above.
        groups, candidates, group_candidates_scores = self.get_candidates(
            group_logits, group_gd=group_labels if is_training else None
        )
        if is_training:
            bs = 0
            new_labels = []
            for i, n in enumerate(target_candidates_num.numpy()):
                be = bs + n
                c = set(target_candidates[bs: be].numpy())
                c2 = candidates[i]
                new_labels.append(torch.tensor([1.0 if lbl in c else 0.0 for lbl in c2]))
                if len(c) != new_labels[-1].sum():
                    s_c2 = set(c2)
                    for cc in list(c):
                        if cc in s_c2:
                            continue
                        for j in range(new_labels[-1].shape[0]):
                            if new_labels[-1][j].item() != 1:
                                c2[j] = cc
                                new_labels[-1][j] = 1.0
                                break
                bs = be
            labels = torch.stack(new_labels).to(input_ids.device)

        candidates = torch.LongTensor(candidates).to(input_ids.device)
        group_candidates_scores = torch.Tensor(group_candidates_scores).to(input_ids.device)

        emb = self.l1(out)
        embed_weights = self.embed(candidates)
        emb = emb.unsqueeze(-1)
        logits = torch.bmm(embed_weights, emb).squeeze(-1)

        if is_training:
            loss_fn = torch.nn.BCEWithLogitsLoss()
            loss_rank = loss_fn(logits, labels)
            loss_group = loss_fn(group_logits, group_labels)
            loss = loss_rank + loss_group
            return logits, loss, loss_group.detach(), loss_rank.detach()
        else:
            candidates_scores = torch.sigmoid(logits)
            candidates_scores = candidates_scores * group_candidates_scores
            return group_logits, candidates, candidates_scores

    def get_accuracy(self, candidates, logits, labels):
        if candidates is not None:
            candidates = candidates.detach().cpu()
        k = min(10, logits.shape[1])
        if k <= 0:
            return 0, 0, 0, 0
        scores, indices = torch.topk(logits.detach().cpu(), k=k)

        acc1, acc3, acc5, total = 0, 0, 0, 0
        for i, l in enumerate(labels):
            if candidates is not None and l.shape[0] == candidates.shape[1]:
                positive_idx = np.nonzero(l)[0]
                target_labels = set(candidates[i][positive_idx].numpy())
            else:
                target_labels = set(np.nonzero(l)[0])

            if candidates is not None:
                pred_labels = candidates[i][indices[i]].numpy()
            else:
                pred_labels = indices[i, :5].numpy()

            acc1 += len(set([pred_labels[0]]) & target_labels)
            acc3 += len(set(pred_labels[:3]) & target_labels)
            acc5 += len(set(pred_labels[:5]) & target_labels)
            total += 1

        return total, acc1, acc3, acc5

    def one_epoch(self, epoch, dataloader, optimizer,
                  mode='train', eval_loader=None, eval_step=20000, log=None, log_interval=50, use_tqdm=False):

        total_steps = len(dataloader)
        bar = tqdm.tqdm(total=total_steps) if use_tqdm else None
        p1, p3, p5 = 0, 0, 0
        g_p1, g_p3, g_p5 = 0, 0, 0
        total, acc1, acc3, acc5 = 0, 0, 0, 0
        g_acc1, g_acc3, g_acc5 = 0, 0, 0
        cand_hit, cand_total = 0, 0
        train_loss = 0

        if mode == 'train':
            self.train()
            self.zero_grad()
        else:
            self.eval()

        if self.use_swa and epoch == self.swa_warmup_epoch and mode == 'train':
            self.swa_init()

        if self.use_swa and mode == 'eval':
            self.swa_swap_params()

        pred_scores, pred_labels = [], []
        if bar:
            bar.set_description(f'{mode}-{epoch}')

        device = next(self.parameters()).device

        def _gpu_mem():
            if device.type != "cuda":
                return "gpu_mem=NA"
            alloc = torch.cuda.memory_allocated(device) // (1024 * 1024)
            total_mem = torch.cuda.get_device_properties(device).total_memory // (1024 * 1024)
            return f"gpu_mem={alloc}MB/{total_mem}MB"

        with torch.set_grad_enabled(mode == 'train'):
            for step, data in enumerate(dataloader):
                batch = tuple(t for t in data)
                inputs = {'input_ids': batch[0].to(device),
                          'attention_mask': batch[1].to(device),
                          'token_type_ids': batch[2].to(device)}
                if mode == 'train':
                    inputs['labels'] = batch[3].to(device)
                    inputs['group_labels'] = batch[4].to(device)
                    inputs['candidates'] = batch[5].to(device)

                outputs = self(**inputs)
                if bar:
                    bar.update(1)

                if mode == 'train':
                    loss = outputs[1]
                    loss_group = outputs[2]
                    loss_rank = outputs[3]
                    loss /= self.update_count
                    train_loss += loss.item()

                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()

                    if (step + 1) % self.update_count == 0:
                        optimizer.step()
                        self.zero_grad()

                    if step % eval_step == 0 and eval_loader is not None and step != 0:
                        results = self.one_epoch(epoch, eval_loader, optimizer, mode='eval')
                        p1, p3, p5 = results[3:6]
                        g_p1, g_p3, g_p5 = results[:3]
                        if log is not None:
                            log.log(f'{epoch:>2} {step:>6}: {p1:.4f}, {p3:.4f}, {p5:.4f}'
                                    f' {g_p1:.4f}, {g_p3:.4f}, {g_p5:.4f}')
                        # The nested eval call left the module in eval mode.
                        self.train()

                    if self.use_swa and step % self.swa_update_step == 0:
                        self.swa_step()

                    if bar:
                        bar.set_postfix(loss=loss.item())
                    if log_interval and step % log_interval == 0:
                        print(f"[train] epoch={epoch} step={step+1}/{total_steps} "
                              f"loss={loss.item():.6f} loss_group={loss_group.item():.6f} "
                              f"loss_rank={loss_rank.item():.6f} {_gpu_mem()}")
                else:
                    group_logits, candidates, logits = outputs

                    if mode == 'eval':
                        labels = batch[3]
                        group_labels = batch[4]

                        _total, _acc1, _acc3, _acc5 = self.get_accuracy(candidates, logits, labels.cpu().numpy())
                        total += _total; acc1 += _acc1; acc3 += _acc3; acc5 += _acc5
                        p1 = acc1 / total
                        p3 = acc3 / total / 3
                        p5 = acc5 / total / 5

                        # Candidate recall: fraction of true labels that made
                        # it into the retrieved candidate set at all.
                        cand_np = candidates.detach().cpu().numpy()
                        labels_np = labels.cpu().numpy()
                        for i, l in enumerate(labels_np):
                            pos = np.nonzero(l)[0]
                            if pos.shape[0] == 0:
                                continue
                            cand_set = set(cand_np[i].tolist())
                            cand_hit += sum(1 for t in pos if t in cand_set)
                            cand_total += pos.shape[0]
                        candidate_recall = cand_hit / cand_total if cand_total > 0 else 0.0

                        _, _g_acc1, _g_acc3, _g_acc5 = self.get_accuracy(None, group_logits, group_labels.cpu().numpy())
                        g_acc1 += _g_acc1; g_acc3 += _g_acc3; g_acc5 += _g_acc5
                        g_p1 = g_acc1 / total
                        g_p3 = g_acc3 / total / 3
                        g_p5 = g_acc5 / total / 5
                        if bar:
                            bar.set_postfix(p1=p1, p3=p3, p5=p5, g_p1=g_p1, g_p3=g_p3, g_p5=g_p5,
                                            cand_recall=candidate_recall)
                        if log_interval and step % log_interval == 0:
                            print(f"[eval] epoch={epoch} step={step+1}/{total_steps} "
                                  f"g_p1={g_p1:.4f} g_p3={g_p3:.4f} g_p5={g_p5:.4f} "
                                  f"p1={p1:.4f} p3={p3:.4f} p5={p5:.4f} "
                                  f"cand_recall={candidate_recall:.4f} {_gpu_mem()}")
                    elif mode == 'test':
                        k = min(100, logits.shape[1])
                        _scores, _indices = torch.topk(logits.detach().cpu(), k=k)
                        _labels = torch.stack([candidates[i][_indices[i]] for i in range(_indices.shape[0])], dim=0)
                        pred_scores.append(_scores.cpu())
                        pred_labels.append(_labels.cpu())

        # Flush any gradients left over from an incomplete accumulation window.
        if mode == 'train' and total_steps > 0 and total_steps % self.update_count != 0:
            optimizer.step()
            self.zero_grad()

        if self.use_swa and mode == 'eval':
            self.swa_swap_params()
        if bar:
            bar.close()

        if mode == 'eval':
            candidate_recall = cand_hit / cand_total if cand_total > 0 else 0.0
            return g_p1, g_p3, g_p5, p1, p3, p5, candidate_recall
        elif mode == 'test':
            scores = torch.cat(pred_scores, dim=0).numpy() if pred_scores else None
            labels_out = torch.cat(pred_labels, dim=0).numpy() if pred_labels else None
            return scores, labels_out
        elif mode == 'train':
            return train_loss
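
    # Eval returns (g_p1, g_p3, g_p5, p1, p3, p5, candidate_recall);
    # candidate_recall is an upper bound on the label-level recall attainable
    # by the ranking head, since it can only score retrieved candidates.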

    def swa_init(self):
        self.swa_state = {'models_num': 1}
        for n, p in self.named_parameters():
            self.swa_state[n] = p.data.detach().cpu().clone()

    def swa_step(self):
        if 'models_num' not in self.swa_state:
            return
        self.swa_state['models_num'] += 1
        beta = 1.0 / self.swa_state['models_num']
        with torch.no_grad():
            for n, p in self.named_parameters():
                self.swa_state[n].mul_(1.0 - beta).add_(p.data.detach().cpu(), alpha=beta)

    def swa_swap_params(self):
        if 'models_num' not in self.swa_state:
            return
        for n, p in self.named_parameters():
            buf = self.swa_state[n]
            self.swa_state[n] = p.data.detach().cpu()
            p.data = buf.to(p.data.device)

    def get_tokenizer(self):
        if 'roberta' in self.bert_name:
            try:
                return RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True, local_files_only=True)
            except Exception:
                return RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True)
        if 'xlnet' in self.bert_name:
            try:
                return XLNetTokenizer.from_pretrained('xlnet-base-cased', local_files_only=True)
            except Exception:
                return XLNetTokenizer.from_pretrained('xlnet-base-cased')
        if self.bert_name in ['bert-base', 'bert-base-uncased']:
            try:
                return BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, local_files_only=True)
            except Exception:
                return BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
        if self.bert_name == 'bert-base-chinese':
            try:
                return BertTokenizer.from_pretrained('bert-base-chinese', local_files_only=True)
            except Exception:
                return BertTokenizer.from_pretrained('bert-base-chinese')
        try:
            return AutoTokenizer.from_pretrained(self.bert_name, local_files_only=True)
        except Exception:
            return AutoTokenizer.from_pretrained(self.bert_name)
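

# Smoke-test sketch, assuming network access to download bert-base-uncased;
# the batch size, sequence length, and label count are illustrative only.
if __name__ == "__main__":
    model = LightXML(n_labels=32, bert='bert-base-uncased')
    model.eval()
    ids = torch.randint(0, model.bert.config.vocab_size, (2, 16))
    mask = torch.ones_like(ids)
    types = torch.zeros_like(ids)
    with torch.no_grad():
        logits = model(ids, attention_mask=mask, token_type_ids=types)
    print('flat logits:', logits.shape)  # expected: torch.Size([2, 32])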