"""LightXML-style extreme multi-label text classifiers.

Three models built on a HuggingFace transformer encoder:

* ``LightXML``            — single-text classifier, optionally two-stage
                            (label-group shortlist -> per-label ranking).
* ``LightXMLMultiZone``   — encodes 6 text "zones" per sample and fuses them.
* ``LightXMLMultiZoneGroupY`` — multi-zone encoder combined with the
                            two-stage group/ranking head.

NVIDIA apex is optional; a no-op shim is substituted when it is missing.
"""

import tqdm
import time
import cProfile
import numpy as np
from contextlib import contextmanager

try:
    from apex import amp
except Exception:  # pragma: no cover - fallback when apex is unavailable
    class _DummyAmp:
        """No-op stand-in for apex.amp so training runs without apex."""

        def initialize(self, model, optimizer, opt_level="O1"):
            return model, optimizer

        @contextmanager
        def scale_loss(self, loss, optimizer):
            # Without mixed precision there is nothing to scale.
            yield loss

    amp = _DummyAmp()

import torch
from torch import nn
from transformers import BertTokenizer, BertConfig, BertModel
from transformers import RobertaModel, RobertaConfig, RobertaTokenizer
from transformers import XLNetTokenizer, XLNetModel, XLNetConfig
from transformers import AutoConfig, AutoModel, AutoTokenizer
from tokenizers import BertWordPieceTokenizer
from transformers import RobertaTokenizerFast


def get_bert(bert_name):
    """Load a pretrained encoder by (loose) name.

    ``output_hidden_states`` is always enabled because the models below
    concatenate the [CLS] vectors of the last ``feature_layers`` layers.

    :param bert_name: 'roberta*', 'xlnet*', 'bert-base'/'bert-base-uncased',
        or any name resolvable by ``AutoModel.from_pretrained``.
    :return: the loaded ``transformers`` model.
    """
    if 'roberta' in bert_name:
        print('load roberta-base')
        model_config = RobertaConfig.from_pretrained('roberta-base')
        model_config.output_hidden_states = True
        bert = RobertaModel.from_pretrained('roberta-base', config=model_config)
    elif 'xlnet' in bert_name:
        print('load xlnet-base-cased')
        model_config = XLNetConfig.from_pretrained('xlnet-base-cased')
        model_config.output_hidden_states = True
        bert = XLNetModel.from_pretrained('xlnet-base-cased', config=model_config)
    else:
        if bert_name in ['bert-base', 'bert-base-uncased']:
            print('load bert-base-uncased')
            model_config = BertConfig.from_pretrained('bert-base-uncased')
            model_config.output_hidden_states = True
            bert = BertModel.from_pretrained('bert-base-uncased', config=model_config)
        else:
            print(f'load {bert_name}')
            model_config = AutoConfig.from_pretrained(bert_name)
            model_config.output_hidden_states = True
            bert = AutoModel.from_pretrained(bert_name, config=model_config)
    return bert


class LightXML(nn.Module):
    """LightXML classifier with optional two-stage label-group head.

    When ``group_y`` is given (assumed: an array of arrays, mapping each
    label group to its member label ids — TODO confirm against the data
    pipeline), the model first scores label groups (``l0``), shortlists
    candidate labels from the top groups, and ranks them with a label
    embedding dot product. Otherwise ``l0`` scores all labels directly.
    """

    def __init__(self, n_labels, group_y=None, bert='bert-base', feature_layers=5,
                 dropout=0.5, update_count=1, candidates_topk=10,
                 use_swa=True, swa_warmup_epoch=10, swa_update_step=200, hidden_dim=300):
        super(LightXML, self).__init__()

        self.use_swa = use_swa
        self.swa_warmup_epoch = swa_warmup_epoch
        self.swa_update_step = swa_update_step
        self.swa_state = {}

        self.update_count = update_count          # gradient-accumulation steps
        self.candidates_topk = candidates_topk    # groups kept when shortlisting

        print('swa', self.use_swa, self.swa_warmup_epoch, self.swa_update_step, self.swa_state)
        print('update_count', self.update_count)

        self.bert_name, self.bert = bert, get_bert(bert)
        self.feature_layers, self.drop_out = feature_layers, nn.Dropout(dropout)

        self.group_y = group_y
        if self.group_y is not None:
            self.group_y_labels = group_y.shape[0]
            print('hidden dim:', hidden_dim)
            print('label goup numbers:', self.group_y_labels)
            # Stage 1: score label groups from the concatenated [CLS] features.
            self.l0 = nn.Linear(self.feature_layers*self.bert.config.hidden_size, self.group_y_labels)
            # hidden bottle layer
            self.l1 = nn.Linear(self.feature_layers*self.bert.config.hidden_size, hidden_dim)
            # Stage 2: label embeddings ranked via dot product with l1 output.
            self.embed = nn.Embedding(n_labels, hidden_dim)
            nn.init.xavier_uniform_(self.embed.weight)
        else:
            self.l0 = nn.Linear(self.feature_layers*self.bert.config.hidden_size, n_labels)

    def get_candidates(self, group_logits, group_gd=None):
        """Shortlist candidate labels from the top-k scoring label groups.

        :param group_logits: (B, n_groups) raw group scores.
        :param group_gd: optional ground-truth group indicator added to the
            sigmoid scores (teacher forcing during training).
        :return: (top group indices, padded candidate label ids,
            per-candidate group scores), all as numpy arrays.
        """
        logits = torch.sigmoid(group_logits.detach())
        if group_gd is not None:
            logits += group_gd
        k = min(self.candidates_topk, logits.shape[1])
        if k <= 0:
            raise ValueError("No group labels available for candidate selection")
        scores, indices = torch.topk(logits, k=k)
        scores, indices = scores.cpu().detach().numpy(), indices.cpu().detach().numpy()
        candidates, candidates_scores = [], []
        for index, score in zip(indices, scores):
            candidates.append(self.group_y[index])
            # Each candidate inherits the score of the group it came from.
            candidates_scores.append([np.full(c.shape, s) for c, s in zip(candidates[-1], score)])
            candidates[-1] = np.concatenate(candidates[-1])
            candidates_scores[-1] = np.concatenate(candidates_scores[-1])
        # Pad rows to equal length ('edge' repeats the last candidate).
        max_candidates = max([i.shape[0] for i in candidates])
        candidates = np.stack([np.pad(i, (0, max_candidates - i.shape[0]), mode='edge')
                               for i in candidates])
        candidates_scores = np.stack([np.pad(i, (0, max_candidates - i.shape[0]), mode='edge')
                                      for i in candidates_scores])
        return indices, candidates, candidates_scores

    def forward(self, input_ids, attention_mask, token_type_ids,
                labels=None, group_labels=None, candidates=None):
        """Run the model.

        Returns (training) ``(logits, loss)``; (inference, flat head)
        ``logits``; (inference, grouped head) ``(group_logits, candidates,
        candidates_scores)``.
        """
        is_training = labels is not None
        # [-1] selects the tuple of per-layer hidden states
        # (output_hidden_states=True in get_bert).
        outs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )[-1]
        # Concatenate the [CLS] vectors of the last `feature_layers` layers.
        out = torch.cat([outs[-i][:, 0] for i in range(1, self.feature_layers+1)], dim=-1)
        out = self.drop_out(out)
        group_logits = self.l0(out)

        if self.group_y is None:
            logits = group_logits
            if is_training:
                loss_fn = torch.nn.BCEWithLogitsLoss()
                loss = loss_fn(logits, labels)
                return logits, loss
            else:
                return logits

        if is_training:
            l = labels.to(dtype=torch.bool)
            target_candidates = torch.masked_select(candidates, l).detach().cpu()
            target_candidates_num = l.sum(dim=1).detach().cpu()
        groups, candidates, group_candidates_scores = self.get_candidates(
            group_logits, group_gd=group_labels if is_training else None)
        if is_training:
            # Rebuild the 0/1 target over the shortlisted candidates; positives
            # missing from the shortlist overwrite padded slots so no positive
            # is dropped.
            bs = 0
            new_labels = []
            for i, n in enumerate(target_candidates_num.numpy()):
                be = bs + n
                c = set(target_candidates[bs: be].numpy())
                c2 = candidates[i]
                new_labels.append(torch.tensor([1.0 if i in c else 0.0 for i in c2]))
                if len(c) != new_labels[-1].sum():
                    s_c2 = set(c2)
                    for cc in list(c):
                        if cc in s_c2:
                            continue
                        for j in range(new_labels[-1].shape[0]):
                            if new_labels[-1][j].item() != 1:
                                c2[j] = cc
                                new_labels[-1][j] = 1.0
                                break
                bs = be
            labels = torch.stack(new_labels).cuda()
        candidates, group_candidates_scores = (
            torch.LongTensor(candidates).cuda(), torch.Tensor(group_candidates_scores).cuda())

        emb = self.l1(out)
        embed_weights = self.embed(candidates)  # N, sampled_size, H
        emb = emb.unsqueeze(-1)
        logits = torch.bmm(embed_weights, emb).squeeze(-1)

        if is_training:
            loss_fn = torch.nn.BCEWithLogitsLoss()
            loss = loss_fn(logits, labels) + loss_fn(group_logits, group_labels)
            return logits, loss
        else:
            candidates_scores = torch.sigmoid(logits)
            candidates_scores = candidates_scores * group_candidates_scores
            return group_logits, candidates, candidates_scores

    def get_accuracy(self, candidates, logits, labels):
        """Count P@1/3/5 hits for one batch.

        :param candidates: candidate label ids per row, or None when logits
            are over the full label space.
        :param labels: numpy multi-hot rows (candidate-space when shapes
            match ``candidates``, else full label space).
        :return: (rows counted, hits@1, hits@3, hits@5).
        """
        if candidates is not None:
            candidates = candidates.detach().cpu()
        k = min(10, logits.shape[1])
        if k <= 0:
            return 0, 0, 0, 0
        scores, indices = torch.topk(logits.detach().cpu(), k=k)

        acc1, acc3, acc5, total = 0, 0, 0, 0
        for i, l in enumerate(labels):
            if candidates is not None and l.shape[0] == candidates.shape[1]:
                positive_idx = np.nonzero(l)[0]
                target_labels = set(candidates[i][positive_idx].numpy())
            else:
                target_labels = set(np.nonzero(l)[0])
            if candidates is not None:
                pred_labels = candidates[i][indices[i]].numpy()
            else:
                pred_labels = indices[i, :5].numpy()
            acc1 += len(set([pred_labels[0]]) & target_labels)
            acc3 += len(set(pred_labels[:3]) & target_labels)
            acc5 += len(set(pred_labels[:5]) & target_labels)
            total += 1
        return total, acc1, acc3, acc5

    def one_epoch(self, epoch, dataloader, optimizer, mode='train', eval_loader=None,
                  eval_step=20000, log=None, log_interval=50, use_tqdm=False):
        """Run one epoch in 'train', 'eval' or 'test' mode.

        Returns: train -> accumulated loss; eval -> (g_p1, g_p3, g_p5, p1,
        p3, p5); test -> (scores, labels-or-None) numpy arrays.
        """
        total_steps = len(dataloader)
        bar = tqdm.tqdm(total=total_steps) if use_tqdm else None
        p1, p3, p5 = 0, 0, 0
        g_p1, g_p3, g_p5 = 0, 0, 0
        total, acc1, acc3, acc5 = 0, 0, 0, 0
        g_acc1, g_acc3, g_acc5 = 0, 0, 0
        train_loss = 0

        if mode == 'train':
            self.train()
            self.zero_grad()
        else:
            self.eval()

        if self.use_swa and epoch == self.swa_warmup_epoch and mode == 'train':
            self.swa_init()
        if self.use_swa and mode == 'eval':
            self.swa_swap_params()

        pred_scores, pred_labels = [], []
        if bar:
            bar.set_description(f'{mode}-{epoch}')

        device = next(self.parameters()).device

        def _gpu_mem():
            # Human-readable allocated/total device memory for log lines.
            if device.type != "cuda":
                return "gpu_mem=NA"
            alloc = torch.cuda.memory_allocated(device) // (1024 * 1024)
            total = torch.cuda.get_device_properties(device).total_memory // (1024 * 1024)
            return f"gpu_mem={alloc}MB/{total}MB"

        with torch.set_grad_enabled(mode == 'train'):
            for step, data in enumerate(dataloader):
                batch = tuple(t for t in data)
                inputs = {'input_ids': batch[0].to(device),
                          'attention_mask': batch[1].to(device),
                          'token_type_ids': batch[2].to(device)}
                if mode == 'train':
                    inputs['labels'] = batch[3].to(device)
                    if self.group_y is not None:
                        inputs['group_labels'] = batch[4].to(device)
                        inputs['candidates'] = batch[5].to(device)
                outputs = self(**inputs)

                if bar:
                    bar.update(1)

                if mode == 'train':
                    # BUGFIX: forward() returns a 2-tuple (logits, loss) in
                    # training, so the previous outputs[2]/outputs[3] reads
                    # (loss_group/loss_rank) raised IndexError; they and the
                    # duplicated assignments were removed, and the periodic
                    # log below no longer formats them.
                    loss = outputs[1]
                    loss /= self.update_count
                    train_loss += loss.item()

                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()

                    if (step + 1) % self.update_count == 0:
                        optimizer.step()
                        self.zero_grad()

                    if step % eval_step == 0 and eval_loader is not None and step != 0:
                        results = self.one_epoch(epoch, eval_loader, optimizer, mode='eval')
                        p1, p3, p5 = results[3:6]
                        g_p1, g_p3, g_p5 = results[:3]
                        if self.group_y is not None:
                            log.log(f'{epoch:>2} {step:>6}: {p1:.4f}, {p3:.4f}, {p5:.4f}'
                                    f' {g_p1:.4f}, {g_p3:.4f}, {g_p5:.4f}')
                        else:
                            log.log(f'{epoch:>2} {step:>6}: {p1:.4f}, {p3:.4f}, {p5:.4f}')

                    if self.use_swa and step % self.swa_update_step == 0:
                        self.swa_step()
                    if bar:
                        bar.set_postfix(loss=loss.item())
                    if log_interval and step % log_interval == 0:
                        print(f"[train] epoch={epoch} step={step+1}/{len(dataloader)} "
                              f"loss={loss.item():.6f} {_gpu_mem()}")
                elif self.group_y is None:
                    logits = outputs
                    if mode == 'eval':
                        labels = batch[3]
                        _total, _acc1, _acc3, _acc5 = self.get_accuracy(None, logits, labels.cpu().numpy())
                        total += _total; acc1 += _acc1; acc3 += _acc3; acc5 += _acc5
                        p1 = acc1 / total
                        p3 = acc3 / total / 3
                        p5 = acc5 / total / 5
                        # BUGFIX: guard the progress bar like every other
                        # branch; this crashed when use_tqdm=False.
                        if bar:
                            bar.set_postfix(p1=p1, p3=p3, p5=p5)
                    elif mode == 'test':
                        pred_scores.append(logits.detach().cpu())
                else:
                    group_logits, candidates, logits = outputs
                    if mode == 'eval':
                        labels = batch[3]
                        group_labels = batch[4]
                        _total, _acc1, _acc3, _acc5 = self.get_accuracy(candidates, logits, labels.cpu().numpy())
                        total += _total; acc1 += _acc1; acc3 += _acc3; acc5 += _acc5
                        p1 = acc1 / total
                        p3 = acc3 / total / 3
                        p5 = acc5 / total / 5

                        _, _g_acc1, _g_acc3, _g_acc5 = self.get_accuracy(
                            None, group_logits, group_labels.cpu().numpy())
                        g_acc1 += _g_acc1; g_acc3 += _g_acc3; g_acc5 += _g_acc5
                        g_p1 = g_acc1 / total
                        g_p3 = g_acc3 / total / 3
                        g_p5 = g_acc5 / total / 5
                        if bar:
                            bar.set_postfix(p1=p1, p3=p3, p5=p5, g_p1=g_p1, g_p3=g_p3, g_p5=g_p5)
                        if log_interval and step % log_interval == 0:
                            print(f"[eval] epoch={epoch} step={step+1}/{len(dataloader)} "
                                  f"g_p1={g_p1:.4f} g_p3={g_p3:.4f} g_p5={g_p5:.4f} "
                                  f"p1={p1:.4f} p3={p3:.4f} p5={p5:.4f} {_gpu_mem()}")
                    elif mode == 'test':
                        k = min(100, logits.shape[1])
                        _scores, _indices = torch.topk(logits.detach().cpu(), k=k)
                        _labels = torch.stack([candidates[i][_indices[i]]
                                               for i in range(_indices.shape[0])], dim=0)
                        pred_scores.append(_scores.cpu())
                        pred_labels.append(_labels.cpu())

        # Flush a trailing partial accumulation window.
        if mode == 'train' and total_steps > 0 and total_steps % self.update_count != 0:
            optimizer.step()
            self.zero_grad()

        if self.use_swa and mode == 'eval':
            self.swa_swap_params()
        if bar:
            bar.close()

        if mode == 'eval':
            return g_p1, g_p3, g_p5, p1, p3, p5
        elif mode == 'test':
            return torch.cat(pred_scores, dim=0).numpy(), \
                torch.cat(pred_labels, dim=0).numpy() if len(pred_labels) != 0 else None
        elif mode == 'train':
            return train_loss

    def save_model(self, path):
        """Save the SWA-averaged weights, then restore the live weights."""
        self.swa_swap_params()
        torch.save(self.state_dict(), path)
        self.swa_swap_params()

    def swa_init(self):
        """Start SWA averaging from the current parameters (CPU copies)."""
        self.swa_state = {'models_num': 1}
        for n, p in self.named_parameters():
            self.swa_state[n] = p.data.cpu().clone().detach()

    def swa_step(self):
        """Fold the current parameters into the running SWA average."""
        if 'models_num' not in self.swa_state:
            return
        self.swa_state['models_num'] += 1
        beta = 1.0 / self.swa_state['models_num']
        with torch.no_grad():
            for n, p in self.named_parameters():
                # avg = (1 - beta) * avg + beta * p  (modern add_ signature;
                # the positional add_(beta, t) form is deprecated).
                self.swa_state[n].mul_(1.0 - beta).add_(p.data.cpu(), alpha=beta)

    def swa_swap_params(self):
        """Swap live parameters with the stored SWA average (CUDA assumed)."""
        if 'models_num' not in self.swa_state:
            return
        for n, p in self.named_parameters():
            self.swa_state[n], p.data = self.swa_state[n].cpu(), p.data.cpu()
            self.swa_state[n], p.data = p.data.cpu(), self.swa_state[n].cuda()

    def get_fast_tokenizer(self):
        """Return a fast tokenizer matching ``self.bert_name``."""
        if 'roberta' in self.bert_name:
            tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base', do_lower_case=True)
        elif 'xlnet' in self.bert_name:
            tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
        elif self.bert_name in ['bert-base', 'bert-base-uncased']:
            tokenizer = BertWordPieceTokenizer(
                "data/.bert-base-uncased-vocab.txt",
                lowercase=True)
        else:
            tokenizer = AutoTokenizer.from_pretrained(self.bert_name, use_fast=True)
        return tokenizer

    def get_tokenizer(self):
        """Return a (slow) tokenizer, preferring the local cache when present."""
        if 'roberta' in self.bert_name:
            print('load roberta-base tokenizer')
            try:
                tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True,
                                                             local_files_only=True)
            except Exception:
                tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True)
        elif 'xlnet' in self.bert_name:
            print('load xlnet-base-cased tokenizer')
            try:
                tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased', local_files_only=True)
            except Exception:
                tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
        elif self.bert_name in ['bert-base', 'bert-base-uncased']:
            print('load bert-base-uncased tokenizer')
            try:
                tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True,
                                                          local_files_only=True)
            except Exception:
                tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
        elif self.bert_name == 'bert-base-chinese':
            print('load bert-base-chinese tokenizer')
            try:
                tokenizer = BertTokenizer.from_pretrained('bert-base-chinese', local_files_only=True)
            except Exception:
                tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        else:
            print(f'load {self.bert_name} tokenizer')
            try:
                tokenizer = AutoTokenizer.from_pretrained(self.bert_name, local_files_only=True)
            except Exception:
                tokenizer = AutoTokenizer.from_pretrained(self.bert_name)
        return tokenizer


class LightXMLMultiZone(nn.Module):
    """Encoder for samples split into 6 text zones, fused into one vector.

    Fusion strategies: 'concat' (project the concatenation), 'weighted'
    (learned softmax weights per zone), 'attention' (per-sample attention
    over zones).
    """

    def __init__(self, n_labels, bert='bert-base-chinese', feature_layers=5, dropout=0.5,
                 update_count=1, fusion='concat', hidden_dim=300):
        super(LightXMLMultiZone, self).__init__()
        self.update_count = update_count
        self.bert_name, self.bert = bert, get_bert(bert)
        self.feature_layers, self.drop_out = feature_layers, nn.Dropout(dropout)
        self.fusion = fusion
        self.hidden_size = self.bert.config.hidden_size * self.feature_layers
        if self.fusion == 'concat':
            # 6 zones concatenated, projected back to hidden_size.
            self.fusion_proj = nn.Linear(self.hidden_size * 6, self.hidden_size)
        elif self.fusion == 'weighted':
            self.fusion_weight = nn.Parameter(torch.ones(6))
        elif self.fusion == 'attention':
            self.attn = nn.Linear(self.hidden_size, 1)
        else:
            raise ValueError(f"Unsupported fusion: {self.fusion}")
        self.l0 = nn.Linear(self.hidden_size, n_labels)

    def encode_zones(self, input_ids, attention_mask, token_type_ids):
        """Encode each zone independently.

        :param input_ids: (B, Z, L) token ids; masks/type ids same shape.
        :return: (B, Z, feature_layers * hidden) per-zone [CLS] features.
        """
        # input shape: (B, Z, L)
        bsz, zones, seqlen = input_ids.shape
        flat_ids = input_ids.view(bsz * zones, seqlen)
        flat_mask = attention_mask.view(bsz * zones, seqlen)
        flat_type = token_type_ids.view(bsz * zones, seqlen)
        outs = self.bert(
            flat_ids,
            attention_mask=flat_mask,
            token_type_ids=flat_type
        )[-1]
        cls = torch.cat([outs[-i][:, 0] for i in range(1, self.feature_layers + 1)], dim=-1)
        cls = cls.view(bsz, zones, -1)
        return cls

    def fuse(self, zone_reps, return_weights=False):
        """Fuse (B, Z, H) zone features to (B, H).

        With ``return_weights=True`` also returns the fusion weights
        (None for 'concat').
        """
        if self.fusion == 'concat':
            fused = zone_reps.reshape(zone_reps.shape[0], -1)
            fused = self.fusion_proj(fused)
            if return_weights:
                return fused, None
            return fused
        if self.fusion == 'weighted':
            weights = torch.softmax(self.fusion_weight, dim=0)
            fused = (zone_reps * weights.view(1, -1, 1)).sum(dim=1)
            if return_weights:
                return fused, weights
            return fused
        if self.fusion == 'attention':
            attn_scores = self.attn(zone_reps).squeeze(-1)
            attn_weights = torch.softmax(attn_scores, dim=1)
            fused = (zone_reps * attn_weights.unsqueeze(-1)).sum(dim=1)
            if return_weights:
                return fused, attn_weights
            return fused
        raise ValueError(f"Unsupported fusion: {self.fusion}")


class LightXMLMultiZoneGroupY(nn.Module):
    """Multi-zone encoder with the two-stage group/ranking LightXML head.

    ``teacher_force_lambda`` scales how strongly ground-truth groups boost
    the candidate shortlist during training.
    """

    def __init__(self, n_labels, group_y, bert='bert-base-chinese', feature_layers=5,
                 dropout=0.5, update_count=1, fusion='concat', hidden_dim=300,
                 candidates_topk=20, teacher_force_lambda=0.3):
        super(LightXMLMultiZoneGroupY, self).__init__()
        self.use_swa = True
        self.swa_warmup_epoch = 10
        self.swa_update_step = 200
        self.swa_state = {}
        self.update_count = update_count
        self.candidates_topk = candidates_topk
        self.teacher_force_lambda = teacher_force_lambda
        self.bert_name, self.bert = bert, get_bert(bert)
        self.feature_layers, self.drop_out = feature_layers, nn.Dropout(dropout)
        self.fusion = fusion
        self.group_y = group_y
        self.hidden_size = self.bert.config.hidden_size * self.feature_layers
        if self.fusion == 'concat':
            self.fusion_proj = nn.Linear(self.hidden_size * 6, self.hidden_size)
        elif self.fusion == 'weighted':
            self.fusion_weight = nn.Parameter(torch.ones(6))
        elif self.fusion == 'attention':
            self.attn = nn.Linear(self.hidden_size, 1)
        else:
            raise ValueError(f"Unsupported fusion: {self.fusion}")
        self.group_y_labels = group_y.shape[0]
        self.l0 = nn.Linear(self.hidden_size, self.group_y_labels)
        self.l1 = nn.Linear(self.hidden_size, hidden_dim)
        self.embed = nn.Embedding(n_labels, hidden_dim)
        nn.init.xavier_uniform_(self.embed.weight)

    def encode_zones(self, input_ids, attention_mask, token_type_ids):
        """Encode (B, Z, L) zone tokens to (B, Z, feature_layers * hidden)."""
        bsz, zones, seqlen = input_ids.shape
        flat_ids = input_ids.view(bsz * zones, seqlen)
        flat_mask = attention_mask.view(bsz * zones, seqlen)
        flat_type = token_type_ids.view(bsz * zones, seqlen)
        outs = self.bert(
            flat_ids,
            attention_mask=flat_mask,
            token_type_ids=flat_type
        )[-1]
        cls = torch.cat([outs[-i][:, 0] for i in range(1, self.feature_layers + 1)], dim=-1)
        return cls.view(bsz, zones, -1)

    def fuse(self, zone_reps):
        """Fuse (B, Z, H) zone features to (B, H) per ``self.fusion``."""
        if self.fusion == 'concat':
            fused = zone_reps.reshape(zone_reps.shape[0], -1)
            return self.fusion_proj(fused)
        if self.fusion == 'weighted':
            weights = torch.softmax(self.fusion_weight, dim=0)
            return (zone_reps * weights.view(1, -1, 1)).sum(dim=1)
        if self.fusion == 'attention':
            attn_scores = self.attn(zone_reps).squeeze(-1)
            attn_weights = torch.softmax(attn_scores, dim=1)
            return (zone_reps * attn_weights.unsqueeze(-1)).sum(dim=1)
        raise ValueError(f"Unsupported fusion: {self.fusion}")

    def get_candidates(self, group_logits, group_gd=None):
        """Shortlist candidate labels from the top-k groups.

        Unlike ``LightXML.get_candidates``, ground-truth groups are added
        scaled by ``teacher_force_lambda`` (soft teacher forcing).
        """
        logits = torch.sigmoid(group_logits.detach())
        if group_gd is not None and self.teacher_force_lambda > 0:
            logits += self.teacher_force_lambda * group_gd
        k = min(self.candidates_topk, logits.shape[1])
        if k <= 0:
            raise ValueError("No group labels available for candidate selection")
        scores, indices = torch.topk(logits, k=k)
        scores, indices = scores.cpu().detach().numpy(), indices.cpu().detach().numpy()
        candidates, candidates_scores = [], []
        for index, score in zip(indices, scores):
            candidates.append(self.group_y[index])
            candidates_scores.append([np.full(c.shape, s) for c, s in zip(candidates[-1], score)])
            candidates[-1] = np.concatenate(candidates[-1])
            candidates_scores[-1] = np.concatenate(candidates_scores[-1])
        max_candidates = max([i.shape[0] for i in candidates])
        candidates = np.stack([np.pad(i, (0, max_candidates - i.shape[0]), mode='edge')
                               for i in candidates])
        candidates_scores = np.stack([np.pad(i, (0, max_candidates - i.shape[0]), mode='edge')
                                      for i in candidates_scores])
        return indices, candidates, candidates_scores

    def forward(self, input_ids, attention_mask, token_type_ids,
                labels=None, group_labels=None, candidates=None):
        """Run the model.

        Returns (training) ``(logits, loss, loss_group, loss_rank)`` with the
        component losses detached; (inference) ``(group_logits, candidates,
        candidates_scores)``.
        """
        is_training = labels is not None
        zone_reps = self.encode_zones(input_ids, attention_mask, token_type_ids)
        out = self.drop_out(self.fuse(zone_reps))
        group_logits = self.l0(out)

        if is_training:
            l = labels.to(dtype=torch.bool)
            target_candidates = torch.masked_select(candidates, l).detach().cpu()
            target_candidates_num = l.sum(dim=1).detach().cpu()
        groups, candidates, group_candidates_scores = self.get_candidates(
            group_logits,
            group_gd=group_labels if is_training else None
        )
        if is_training:
            # Rebuild the 0/1 target over the shortlist; positives missing
            # from it overwrite padded slots so none are dropped.
            bs = 0
            new_labels = []
            for i, n in enumerate(target_candidates_num.numpy()):
                be = bs + n
                c = set(target_candidates[bs: be].numpy())
                c2 = candidates[i]
                new_labels.append(torch.tensor([1.0 if i in c else 0.0 for i in c2]))
                if len(c) != new_labels[-1].sum():
                    s_c2 = set(c2)
                    for cc in list(c):
                        if cc in s_c2:
                            continue
                        for j in range(new_labels[-1].shape[0]):
                            if new_labels[-1][j].item() != 1:
                                c2[j] = cc
                                new_labels[-1][j] = 1.0
                                break
                bs = be
            labels = torch.stack(new_labels).to(input_ids.device)
        candidates = torch.LongTensor(candidates).to(input_ids.device)
        group_candidates_scores = torch.Tensor(group_candidates_scores).to(input_ids.device)

        emb = self.l1(out)
        embed_weights = self.embed(candidates)
        emb = emb.unsqueeze(-1)
        logits = torch.bmm(embed_weights, emb).squeeze(-1)

        if is_training:
            loss_fn = torch.nn.BCEWithLogitsLoss()
            loss_rank = loss_fn(logits, labels)
            loss_group = loss_fn(group_logits, group_labels)
            loss = loss_rank + loss_group
            return logits, loss, loss_group.detach(), loss_rank.detach()
        else:
            candidates_scores = torch.sigmoid(logits)
            candidates_scores = candidates_scores * group_candidates_scores
            return group_logits, candidates, candidates_scores

    def get_accuracy(self, candidates, logits, labels):
        """Count P@1/3/5 hits for one batch (see ``LightXML.get_accuracy``)."""
        if candidates is not None:
            candidates = candidates.detach().cpu()
        k = min(10, logits.shape[1])
        if k <= 0:
            return 0, 0, 0, 0
        scores, indices = torch.topk(logits.detach().cpu(), k=k)

        acc1, acc3, acc5, total = 0, 0, 0, 0
        for i, l in enumerate(labels):
            if candidates is not None and l.shape[0] == candidates.shape[1]:
                positive_idx = np.nonzero(l)[0]
                target_labels = set(candidates[i][positive_idx].numpy())
            else:
                target_labels = set(np.nonzero(l)[0])
            if candidates is not None:
                pred_labels = candidates[i][indices[i]].numpy()
            else:
                pred_labels = indices[i, :5].numpy()
            acc1 += len(set([pred_labels[0]]) & target_labels)
            acc3 += len(set(pred_labels[:3]) & target_labels)
            acc5 += len(set(pred_labels[:5]) & target_labels)
            total += 1
        return total, acc1, acc3, acc5

    def one_epoch(self, epoch, dataloader, optimizer, mode='train', eval_loader=None,
                  eval_step=20000, log=None, log_interval=50, use_tqdm=False):
        """Run one epoch in 'train', 'eval' or 'test' mode.

        Returns: train -> accumulated loss; eval -> (g_p1, g_p3, g_p5, p1,
        p3, p5, candidate_recall); test -> (scores, labels-or-None).
        """
        total_steps = len(dataloader)
        bar = tqdm.tqdm(total=total_steps) if use_tqdm else None
        p1, p3, p5 = 0, 0, 0
        g_p1, g_p3, g_p5 = 0, 0, 0
        total, acc1, acc3, acc5 = 0, 0, 0, 0
        g_acc1, g_acc3, g_acc5 = 0, 0, 0
        cand_hit, cand_total = 0, 0  # shortlist recall bookkeeping
        train_loss = 0

        if mode == 'train':
            self.train()
            self.zero_grad()
        else:
            self.eval()

        if self.use_swa and epoch == self.swa_warmup_epoch and mode == 'train':
            self.swa_init()
        if self.use_swa and mode == 'eval':
            self.swa_swap_params()

        pred_scores, pred_labels = [], []
        if bar:
            bar.set_description(f'{mode}-{epoch}')

        device = next(self.parameters()).device

        def _gpu_mem():
            if device.type != "cuda":
                return "gpu_mem=NA"
            alloc = torch.cuda.memory_allocated(device) // (1024 * 1024)
            total = torch.cuda.get_device_properties(device).total_memory // (1024 * 1024)
            return f"gpu_mem={alloc}MB/{total}MB"

        with torch.set_grad_enabled(mode == 'train'):
            for step, data in enumerate(dataloader):
                batch = tuple(t for t in data)
                inputs = {'input_ids': batch[0].to(device),
                          'attention_mask': batch[1].to(device),
                          'token_type_ids': batch[2].to(device)}
                if mode == 'train':
                    inputs['labels'] = batch[3].to(device)
                    inputs['group_labels'] = batch[4].to(device)
                    inputs['candidates'] = batch[5].to(device)
                outputs = self(**inputs)

                if bar:
                    bar.update(1)

                if mode == 'train':
                    # forward() returns (logits, loss, loss_group, loss_rank)
                    # here, so the component losses are valid.
                    loss = outputs[1]
                    loss_group = outputs[2]
                    loss_rank = outputs[3]
                    loss /= self.update_count
                    train_loss += loss.item()

                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()

                    if (step + 1) % self.update_count == 0:
                        optimizer.step()
                        self.zero_grad()

                    if step % eval_step == 0 and eval_loader is not None and step != 0:
                        results = self.one_epoch(epoch, eval_loader, optimizer, mode='eval')
                        p1, p3, p5 = results[3:6]
                        g_p1, g_p3, g_p5 = results[:3]
                        log.log(f'{epoch:>2} {step:>6}: {p1:.4f}, {p3:.4f}, {p5:.4f}'
                                f' {g_p1:.4f}, {g_p3:.4f}, {g_p5:.4f}')

                    if self.use_swa and step % self.swa_update_step == 0:
                        self.swa_step()
                    if bar:
                        bar.set_postfix(loss=loss.item())
                    if log_interval and step % log_interval == 0:
                        print(f"[train] epoch={epoch} step={step+1}/{len(dataloader)} "
                              f"loss={loss.item():.6f} loss_group={loss_group.item():.6f} "
                              f"loss_rank={loss_rank.item():.6f} {_gpu_mem()}")
                else:
                    group_logits, candidates, logits = outputs
                    if mode == 'eval':
                        labels = batch[3]
                        group_labels = batch[4]
                        _total, _acc1, _acc3, _acc5 = self.get_accuracy(candidates, logits, labels.cpu().numpy())
                        total += _total; acc1 += _acc1; acc3 += _acc3; acc5 += _acc5
                        p1 = acc1 / total
                        p3 = acc3 / total / 3
                        p5 = acc5 / total / 5

                        # Shortlist recall: fraction of true labels that made
                        # it into the candidate set at all.
                        cand_np = candidates.detach().cpu().numpy()
                        labels_np = labels.cpu().numpy()
                        for i, l in enumerate(labels_np):
                            pos = np.nonzero(l)[0]
                            if pos.shape[0] == 0:
                                continue
                            cand_set = set(cand_np[i].tolist())
                            cand_hit += sum(1 for t in pos if t in cand_set)
                            cand_total += pos.shape[0]
                        candidate_recall = cand_hit / cand_total if cand_total > 0 else 0.0

                        _, _g_acc1, _g_acc3, _g_acc5 = self.get_accuracy(
                            None, group_logits, group_labels.cpu().numpy())
                        g_acc1 += _g_acc1; g_acc3 += _g_acc3; g_acc5 += _g_acc5
                        g_p1 = g_acc1 / total
                        g_p3 = g_acc3 / total / 3
                        g_p5 = g_acc5 / total / 5
                        if bar:
                            bar.set_postfix(p1=p1, p3=p3, p5=p5,
                                            g_p1=g_p1, g_p3=g_p3, g_p5=g_p5,
                                            cand_recall=candidate_recall)
                        if log_interval and step % log_interval == 0:
                            print(f"[eval] epoch={epoch} step={step+1}/{len(dataloader)} "
                                  f"g_p1={g_p1:.4f} g_p3={g_p3:.4f} g_p5={g_p5:.4f} "
                                  f"p1={p1:.4f} p3={p3:.4f} p5={p5:.4f} "
                                  f"cand_recall={candidate_recall:.4f} {_gpu_mem()}")
                    elif mode == 'test':
                        k = min(100, logits.shape[1])
                        _scores, _indices = torch.topk(logits.detach().cpu(), k=k)
                        _labels = torch.stack([candidates[i][_indices[i]]
                                               for i in range(_indices.shape[0])], dim=0)
                        pred_scores.append(_scores.cpu())
                        pred_labels.append(_labels.cpu())

        if mode == 'train' and total_steps > 0 and total_steps % self.update_count != 0:
            optimizer.step()
            self.zero_grad()

        if self.use_swa and mode == 'eval':
            self.swa_swap_params()
        if bar:
            bar.close()

        if mode == 'eval':
            candidate_recall = cand_hit / cand_total if cand_total > 0 else 0.0
            return g_p1, g_p3, g_p5, p1, p3, p5, candidate_recall
        elif mode == 'test':
            return torch.cat(pred_scores, dim=0).numpy(), \
                torch.cat(pred_labels, dim=0).numpy() if len(pred_labels) != 0 else None
        elif mode == 'train':
            return train_loss

    def swa_init(self):
        """Start SWA averaging from the current parameters (CPU copies)."""
        self.swa_state = {'models_num': 1}
        for n, p in self.named_parameters():
            self.swa_state[n] = p.data.detach().cpu().clone()

    def swa_step(self):
        """Fold the current parameters into the running SWA average."""
        if 'models_num' not in self.swa_state:
            return
        self.swa_state['models_num'] += 1
        beta = 1.0 / self.swa_state['models_num']
        with torch.no_grad():
            for n, p in self.named_parameters():
                # avg = (1 - beta) * avg + beta * p  (modern add_ signature;
                # the positional add_(beta, t) form is deprecated).
                self.swa_state[n].mul_(1.0 - beta).add_(p.data.detach().cpu(), alpha=beta)

    def swa_swap_params(self):
        """Swap live parameters with the stored SWA average, device-aware."""
        if 'models_num' not in self.swa_state:
            return
        for n, p in self.named_parameters():
            buf = self.swa_state[n]
            self.swa_state[n] = p.data.detach().cpu()
            p.data = buf.to(p.data.device)

    def get_tokenizer(self):
        """Return a tokenizer for ``self.bert_name``, preferring local cache."""
        if 'roberta' in self.bert_name:
            try:
                return RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True,
                                                        local_files_only=True)
            except Exception:
                return RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True)
        if 'xlnet' in self.bert_name:
            try:
                return XLNetTokenizer.from_pretrained('xlnet-base-cased', local_files_only=True)
            except Exception:
                return XLNetTokenizer.from_pretrained('xlnet-base-cased')
        if self.bert_name in ['bert-base', 'bert-base-uncased']:
            try:
                return BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True,
                                                     local_files_only=True)
            except Exception:
                return BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
        if self.bert_name == 'bert-base-chinese':
            try:
                return BertTokenizer.from_pretrained('bert-base-chinese', local_files_only=True)
            except Exception:
                return BertTokenizer.from_pretrained('bert-base-chinese')
        try:
            return AutoTokenizer.from_pretrained(self.bert_name, local_files_only=True)
        except Exception:
            return AutoTokenizer.from_pretrained(self.bert_name)