# WilliamCHN's picture
# Upload standalone Inference of LightXML package
# f0bda3b verified
import tqdm
import time
import cProfile
import numpy as np
from contextlib import contextmanager
# NVIDIA apex provides mixed-precision helpers; fall back to a no-op shim
# when apex is not installed so plain fp32/CPU runs still work.
try:
    from apex import amp
except Exception: # pragma: no cover - fallback when apex is unavailable
    class _DummyAmp:
        # Mirror of apex.amp.initialize: returns model/optimizer unchanged.
        def initialize(self, model, optimizer, opt_level="O1"):
            return model, optimizer
        # Mirror of apex.amp.scale_loss: yields the loss without scaling.
        @contextmanager
        def scale_loss(self, loss, optimizer):
            yield loss
    amp = _DummyAmp()
import torch
from torch import nn
from transformers import BertTokenizer, BertConfig, BertModel
from transformers import RobertaModel, RobertaConfig, RobertaTokenizer
from transformers import XLNetTokenizer, XLNetModel, XLNetConfig
from transformers import AutoConfig, AutoModel, AutoTokenizer
from tokenizers import BertWordPieceTokenizer
from transformers import RobertaTokenizerFast
def get_bert(bert_name):
    """Load a pretrained transformer backbone with all hidden states exposed.

    Known aliases ('roberta*', 'xlnet*', 'bert-base'/'bert-base-uncased')
    map to their canonical checkpoints; anything else is resolved through
    the Auto* classes.  ``output_hidden_states`` is enabled on the config so
    callers can concatenate [CLS] features from several layers.
    """
    if 'roberta' in bert_name:
        print('load roberta-base')
        config = RobertaConfig.from_pretrained('roberta-base')
        config.output_hidden_states = True
        return RobertaModel.from_pretrained('roberta-base', config=config)
    if 'xlnet' in bert_name:
        print('load xlnet-base-cased')
        config = XLNetConfig.from_pretrained('xlnet-base-cased')
        config.output_hidden_states = True
        return XLNetModel.from_pretrained('xlnet-base-cased', config=config)
    if bert_name in ['bert-base', 'bert-base-uncased']:
        print('load bert-base-uncased')
        config = BertConfig.from_pretrained('bert-base-uncased')
        config.output_hidden_states = True
        return BertModel.from_pretrained('bert-base-uncased', config=config)
    print(f'load {bert_name}')
    config = AutoConfig.from_pretrained(bert_name)
    config.output_hidden_states = True
    return AutoModel.from_pretrained(bert_name, config=config)
class LightXML(nn.Module):
    """LightXML extreme multi-label classifier.

    A transformer encoder whose last ``feature_layers`` [CLS] vectors are
    concatenated into one document feature.  With ``group_y`` the model first
    scores label *groups* (head ``l0``) and then re-ranks candidate labels
    drawn from the top-k groups with a label-embedding dot product
    (``l1`` + ``embed``); without ``group_y`` it is a flat one-vs-all
    classifier over ``n_labels``.
    """

    def __init__(self, n_labels, group_y=None, bert='bert-base', feature_layers=5, dropout=0.5, update_count=1,
                 candidates_topk=10,
                 use_swa=True, swa_warmup_epoch=10, swa_update_step=200, hidden_dim=300):
        super(LightXML, self).__init__()
        # Stochastic weight averaging bookkeeping (see the swa_* methods).
        self.use_swa = use_swa
        self.swa_warmup_epoch = swa_warmup_epoch
        self.swa_update_step = swa_update_step
        self.swa_state = {}
        self.update_count = update_count          # gradient-accumulation steps
        self.candidates_topk = candidates_topk    # label groups kept per sample
        print('swa', self.use_swa, self.swa_warmup_epoch, self.swa_update_step, self.swa_state)
        print('update_count', self.update_count)
        self.bert_name, self.bert = bert, get_bert(bert)
        self.feature_layers, self.drop_out = feature_layers, nn.Dropout(dropout)
        self.group_y = group_y
        if self.group_y is not None:
            self.group_y_labels = group_y.shape[0]
            print('hidden dim:', hidden_dim)
            print('label goup numbers:', self.group_y_labels)
            self.l0 = nn.Linear(self.feature_layers*self.bert.config.hidden_size, self.group_y_labels)
            # hidden bottle layer
            self.l1 = nn.Linear(self.feature_layers*self.bert.config.hidden_size, hidden_dim)
            self.embed = nn.Embedding(n_labels, hidden_dim)
            nn.init.xavier_uniform_(self.embed.weight)
        else:
            self.l0 = nn.Linear(self.feature_layers*self.bert.config.hidden_size, n_labels)

    def get_candidates(self, group_logits, group_gd=None):
        """Pick candidate labels from the top-k scoring label groups.

        ``group_gd`` (ground-truth group indicator) is added to the sigmoid
        scores during training so true groups are always retained.
        Returns (group indices, padded candidate label ids, padded group
        scores broadcast onto each candidate) as numpy arrays.
        """
        logits = torch.sigmoid(group_logits.detach())
        if group_gd is not None:
            logits += group_gd
        k = min(self.candidates_topk, logits.shape[1])
        if k <= 0:
            raise ValueError("No group labels available for candidate selection")
        scores, indices = torch.topk(logits, k=k)
        scores, indices = scores.cpu().detach().numpy(), indices.cpu().detach().numpy()
        candidates, candidates_scores = [], []
        for index, score in zip(indices, scores):
            candidates.append(self.group_y[index])
            candidates_scores.append([np.full(c.shape, s) for c, s in zip(candidates[-1], score)])
            candidates[-1] = np.concatenate(candidates[-1])
            candidates_scores[-1] = np.concatenate(candidates_scores[-1])
        # Groups have different sizes: pad every row to the longest one.
        max_candidates = max([i.shape[0] for i in candidates])
        candidates = np.stack([np.pad(i, (0, max_candidates - i.shape[0]), mode='edge') for i in candidates])
        candidates_scores = np.stack([np.pad(i, (0, max_candidates - i.shape[0]), mode='edge') for i in candidates_scores])
        return indices, candidates, candidates_scores

    def forward(self, input_ids, attention_mask, token_type_ids,
                labels=None, group_labels=None, candidates=None):
        """Encode the batch and apply the classification head.

        Training (labels given): returns (logits, loss).
        Inference, flat model: returns logits.
        Inference, group model: returns (group_logits, candidates, scores).
        """
        is_training = labels is not None
        outs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )[-1]
        # Concatenate [CLS] from the last `feature_layers` hidden layers.
        out = torch.cat([outs[-i][:, 0] for i in range(1, self.feature_layers+1)], dim=-1)
        out = self.drop_out(out)
        group_logits = self.l0(out)
        if self.group_y is None:
            logits = group_logits
            if is_training:
                loss_fn = torch.nn.BCEWithLogitsLoss()
                loss = loss_fn(logits, labels)
                return logits, loss
            else:
                return logits
        if is_training:
            l = labels.to(dtype=torch.bool)
            target_candidates = torch.masked_select(candidates, l).detach().cpu()
            target_candidates_num = l.sum(dim=1).detach().cpu()
        groups, candidates, group_candidates_scores = self.get_candidates(group_logits,
                                                                          group_gd=group_labels if is_training else None)
        if is_training:
            bs = 0
            new_labels = []
            for i, n in enumerate(target_candidates_num.numpy()):
                be = bs + n
                c = set(target_candidates[bs: be].numpy())
                c2 = candidates[i]
                # 1 where the sampled candidate is a true label, else 0.
                new_labels.append(torch.tensor([1.0 if cand in c else 0.0 for cand in c2]))
                if len(c) != new_labels[-1].sum():
                    # Some true labels were missed by the top-k groups:
                    # overwrite negative slots so they are still trained on.
                    s_c2 = set(c2)
                    for cc in list(c):
                        if cc in s_c2:
                            continue
                        for j in range(new_labels[-1].shape[0]):
                            if new_labels[-1][j].item() != 1:
                                c2[j] = cc
                                new_labels[-1][j] = 1.0
                                break
                bs = be
            # Fix: follow the input device instead of hard-coding .cuda() so
            # the model also runs on CPU (matches LightXMLMultiZoneGroupY).
            labels = torch.stack(new_labels).to(input_ids.device)
        candidates = torch.LongTensor(candidates).to(input_ids.device)
        group_candidates_scores = torch.Tensor(group_candidates_scores).to(input_ids.device)
        emb = self.l1(out)
        embed_weights = self.embed(candidates) # N, sampled_size, H
        emb = emb.unsqueeze(-1)
        logits = torch.bmm(embed_weights, emb).squeeze(-1)
        if is_training:
            loss_fn = torch.nn.BCEWithLogitsLoss()
            loss = loss_fn(logits, labels) + loss_fn(group_logits, group_labels)
            return logits, loss
        else:
            candidates_scores = torch.sigmoid(logits)
            candidates_scores = candidates_scores * group_candidates_scores
            return group_logits, candidates, candidates_scores

    def get_accuracy(self, candidates, logits, labels):
        """Count P@1/P@3/P@5 hits over a batch.

        ``candidates`` (optional) maps the logit columns back to global label
        ids; ``labels`` is a dense numpy multi-hot matrix.  Returns
        (total, acc1, acc3, acc5) raw counts, not yet normalized.
        """
        if candidates is not None:
            candidates = candidates.detach().cpu()
        k = min(10, logits.shape[1])
        if k <= 0:
            return 0, 0, 0, 0
        scores, indices = torch.topk(logits.detach().cpu(), k=k)
        acc1, acc3, acc5, total = 0, 0, 0, 0
        for i, l in enumerate(labels):
            if candidates is not None and l.shape[0] == candidates.shape[1]:
                positive_idx = np.nonzero(l)[0]
                target_labels = set(candidates[i][positive_idx].numpy())
            else:
                target_labels = set(np.nonzero(l)[0])
            if candidates is not None:
                pred_labels = candidates[i][indices[i]].numpy()
            else:
                pred_labels = indices[i, :5].numpy()
            acc1 += len(set([pred_labels[0]]) & target_labels)
            acc3 += len(set(pred_labels[:3]) & target_labels)
            acc5 += len(set(pred_labels[:5]) & target_labels)
            total += 1
        return total, acc1, acc3, acc5

    def one_epoch(self, epoch, dataloader, optimizer,
                  mode='train', eval_loader=None, eval_step=20000, log=None, log_interval=50, use_tqdm=False):
        """Run one pass over ``dataloader`` in 'train', 'eval' or 'test' mode.

        Returns the accumulated training loss ('train'), the tuple
        (g_p1, g_p3, g_p5, p1, p3, p5) ('eval'), or stacked prediction
        scores/labels ('test').
        """
        total_steps = len(dataloader)
        bar = tqdm.tqdm(total=total_steps) if use_tqdm else None
        p1, p3, p5 = 0, 0, 0
        g_p1, g_p3, g_p5 = 0, 0, 0
        total, acc1, acc3, acc5 = 0, 0, 0, 0
        g_acc1, g_acc3, g_acc5 = 0, 0, 0
        train_loss = 0
        if mode == 'train':
            self.train()
            self.zero_grad()
        else:
            self.eval()
        if self.use_swa and epoch == self.swa_warmup_epoch and mode == 'train':
            self.swa_init()
        if self.use_swa and mode == 'eval':
            self.swa_swap_params()
        pred_scores, pred_labels = [], []
        if bar:
            bar.set_description(f'{mode}-{epoch}')
        device = next(self.parameters()).device
        def _gpu_mem():
            # Human-readable allocated/total GPU memory for progress logs.
            if device.type != "cuda":
                return "gpu_mem=NA"
            alloc = torch.cuda.memory_allocated(device) // (1024 * 1024)
            total = torch.cuda.get_device_properties(device).total_memory // (1024 * 1024)
            return f"gpu_mem={alloc}MB/{total}MB"
        with torch.set_grad_enabled(mode == 'train'):
            for step, data in enumerate(dataloader):
                batch = tuple(t for t in data)
                inputs = {'input_ids': batch[0].to(device),
                          'attention_mask': batch[1].to(device),
                          'token_type_ids': batch[2].to(device)}
                if mode == 'train':
                    inputs['labels'] = batch[3].to(device)
                    if self.group_y is not None:
                        inputs['group_labels'] = batch[4].to(device)
                        inputs['candidates'] = batch[5].to(device)
                outputs = self(**inputs)
                if bar:
                    bar.update(1)
                if mode == 'train':
                    # Fix: forward() returns (logits, loss) while training; the
                    # previous code read outputs[2]/outputs[3] (duplicated four
                    # times), which always raised IndexError.
                    loss = outputs[1]
                    loss /= self.update_count
                    train_loss += loss.item()
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    if (step + 1) % self.update_count == 0:
                        optimizer.step()
                        self.zero_grad()
                    if step % eval_step == 0 and eval_loader is not None and step != 0:
                        results = self.one_epoch(epoch, eval_loader, optimizer, mode='eval')
                        p1, p3, p5 = results[3:6]
                        g_p1, g_p3, g_p5 = results[:3]
                        if self.group_y is not None:
                            log.log(f'{epoch:>2} {step:>6}: {p1:.4f}, {p3:.4f}, {p5:.4f}'
                                    f' {g_p1:.4f}, {g_p3:.4f}, {g_p5:.4f}')
                        else:
                            log.log(f'{epoch:>2} {step:>6}: {p1:.4f}, {p3:.4f}, {p5:.4f}')
                    if self.use_swa and step % self.swa_update_step == 0:
                        self.swa_step()
                    if bar:
                        bar.set_postfix(loss=loss.item())
                    if log_interval and step % log_interval == 0:
                        print(f"[train] epoch={epoch} step={step+1}/{len(dataloader)} "
                              f"loss={loss.item():.6f} {_gpu_mem()}")
                elif self.group_y is None:
                    logits = outputs
                    if mode == 'eval':
                        labels = batch[3]
                        _total, _acc1, _acc3, _acc5 = self.get_accuracy(None, logits, labels.cpu().numpy())
                        total += _total; acc1 += _acc1; acc3 += _acc3; acc5 += _acc5
                        p1 = acc1 / total
                        p3 = acc3 / total / 3
                        p5 = acc5 / total / 5
                        # Fix: guard on bar, which is None when use_tqdm=False.
                        if bar:
                            bar.set_postfix(p1=p1, p3=p3, p5=p5)
                    elif mode == 'test':
                        pred_scores.append(logits.detach().cpu())
                else:
                    group_logits, candidates, logits = outputs
                    if mode == 'eval':
                        labels = batch[3]
                        group_labels = batch[4]
                        _total, _acc1, _acc3, _acc5 = self.get_accuracy(candidates, logits, labels.cpu().numpy())
                        total += _total; acc1 += _acc1; acc3 += _acc3; acc5 += _acc5
                        p1 = acc1 / total
                        p3 = acc3 / total / 3
                        p5 = acc5 / total / 5
                        _, _g_acc1, _g_acc3, _g_acc5 = self.get_accuracy(None, group_logits, group_labels.cpu().numpy())
                        g_acc1 += _g_acc1; g_acc3 += _g_acc3; g_acc5 += _g_acc5
                        g_p1 = g_acc1 / total
                        g_p3 = g_acc3 / total / 3
                        g_p5 = g_acc5 / total / 5
                        if bar:
                            bar.set_postfix(p1=p1, p3=p3, p5=p5, g_p1=g_p1, g_p3=g_p3, g_p5=g_p5)
                        if log_interval and step % log_interval == 0:
                            print(f"[eval] epoch={epoch} step={step+1}/{len(dataloader)} "
                                  f"g_p1={g_p1:.4f} g_p3={g_p3:.4f} g_p5={g_p5:.4f} "
                                  f"p1={p1:.4f} p3={p3:.4f} p5={p5:.4f} {_gpu_mem()}")
                    elif mode == 'test':
                        k = min(100, logits.shape[1])
                        _scores, _indices = torch.topk(logits.detach().cpu(), k=k)
                        _labels = torch.stack([candidates[i][_indices[i]] for i in range(_indices.shape[0])], dim=0)
                        pred_scores.append(_scores.cpu())
                        pred_labels.append(_labels.cpu())
        # Flush any leftover accumulated gradients.
        if mode == 'train' and total_steps > 0 and total_steps % self.update_count != 0:
            optimizer.step()
            self.zero_grad()
        if self.use_swa and mode == 'eval':
            self.swa_swap_params()
        if bar:
            bar.close()
        if mode == 'eval':
            return g_p1, g_p3, g_p5, p1, p3, p5
        elif mode == 'test':
            return torch.cat(pred_scores, dim=0).numpy(), torch.cat(pred_labels, dim=0).numpy() if len(pred_labels) != 0 else None
        elif mode == 'train':
            return train_loss

    def save_model(self, path):
        """Save the SWA-averaged weights by swapping them in, saving, and
        swapping the live weights back."""
        self.swa_swap_params()
        torch.save(self.state_dict(), path)
        self.swa_swap_params()

    def swa_init(self):
        """Start SWA tracking with a CPU snapshot of the current weights."""
        self.swa_state = {'models_num': 1}
        for n, p in self.named_parameters():
            self.swa_state[n] = p.data.cpu().clone().detach()

    def swa_step(self):
        """Fold the current weights into the SWA running average."""
        if 'models_num' not in self.swa_state:
            return
        self.swa_state['models_num'] += 1
        beta = 1.0 / self.swa_state['models_num']
        with torch.no_grad():
            for n, p in self.named_parameters():
                # Fix: use the keyword `alpha` form; add_(scalar, tensor) is a
                # deprecated overload in modern PyTorch.
                self.swa_state[n].mul_(1.0 - beta).add_(p.data.cpu(), alpha=beta)

    def swa_swap_params(self):
        """Exchange live weights with the stored SWA average (CPU-buffered)."""
        if 'models_num' not in self.swa_state:
            return
        for n, p in self.named_parameters():
            # Fix: restore to the parameter's own device instead of a
            # hard-coded .cuda(), matching LightXMLMultiZoneGroupY.
            buf = self.swa_state[n]
            self.swa_state[n] = p.data.detach().cpu()
            p.data = buf.to(p.data.device)

    def get_fast_tokenizer(self):
        """Return a fast (Rust-backed where available) tokenizer for bert_name."""
        if 'roberta' in self.bert_name:
            tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base', do_lower_case=True)
        elif 'xlnet' in self.bert_name:
            tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
        elif self.bert_name in ['bert-base', 'bert-base-uncased']:
            tokenizer = BertWordPieceTokenizer(
                "data/.bert-base-uncased-vocab.txt",
                lowercase=True)
        else:
            tokenizer = AutoTokenizer.from_pretrained(self.bert_name, use_fast=True)
        return tokenizer

    def get_tokenizer(self):
        """Return the slow tokenizer for bert_name, preferring a local cache
        (local_files_only) and falling back to a download."""
        if 'roberta' in self.bert_name:
            print('load roberta-base tokenizer')
            try:
                tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True, local_files_only=True)
            except Exception:
                tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True)
        elif 'xlnet' in self.bert_name:
            print('load xlnet-base-cased tokenizer')
            try:
                tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased', local_files_only=True)
            except Exception:
                tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
        elif self.bert_name in ['bert-base', 'bert-base-uncased']:
            print('load bert-base-uncased tokenizer')
            try:
                tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, local_files_only=True)
            except Exception:
                tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
        elif self.bert_name == 'bert-base-chinese':
            print('load bert-base-chinese tokenizer')
            try:
                tokenizer = BertTokenizer.from_pretrained('bert-base-chinese', local_files_only=True)
            except Exception:
                tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        else:
            print(f'load {self.bert_name} tokenizer')
            try:
                tokenizer = AutoTokenizer.from_pretrained(self.bert_name, local_files_only=True)
            except Exception:
                tokenizer = AutoTokenizer.from_pretrained(self.bert_name)
        return tokenizer
class LightXMLMultiZone(nn.Module):
    """Multi-zone LightXML encoder.

    Each document is split into 6 text zones that are encoded independently
    by the shared BERT backbone and then fused into a single feature vector
    via one of three strategies ('concat', 'weighted', 'attention') before
    the flat label head ``l0``.
    """

    def __init__(self, n_labels, bert='bert-base-chinese', feature_layers=5, dropout=0.5,
                 update_count=1, fusion='concat', hidden_dim=300):
        super(LightXMLMultiZone, self).__init__()
        self.bert_name = bert
        self.bert = get_bert(bert)
        self.feature_layers = feature_layers
        self.drop_out = nn.Dropout(dropout)
        self.update_count = update_count
        self.fusion = fusion
        self.hidden_size = self.bert.config.hidden_size * self.feature_layers
        if fusion == 'concat':
            # Project the 6 concatenated zone vectors back to hidden_size.
            self.fusion_proj = nn.Linear(self.hidden_size * 6, self.hidden_size)
        elif fusion == 'weighted':
            # One learnable scalar weight per zone, softmax-normalized in fuse().
            self.fusion_weight = nn.Parameter(torch.ones(6))
        elif fusion == 'attention':
            # Scores each zone vector; softmax over zones gives the mix.
            self.attn = nn.Linear(self.hidden_size, 1)
        else:
            raise ValueError(f"Unsupported fusion: {self.fusion}")
        self.l0 = nn.Linear(self.hidden_size, n_labels)

    def encode_zones(self, input_ids, attention_mask, token_type_ids):
        """Encode each zone independently; returns (B, Z, hidden_size)."""
        # Inputs arrive as (B, Z, L); flatten zones into the batch dimension.
        bsz, zones, seqlen = input_ids.shape
        flat = [t.view(bsz * zones, seqlen)
                for t in (input_ids, attention_mask, token_type_ids)]
        hidden_states = self.bert(
            flat[0],
            attention_mask=flat[1],
            token_type_ids=flat[2]
        )[-1]
        # [CLS] from the last `feature_layers` hidden layers, concatenated.
        layer_cls = [hidden_states[-i][:, 0] for i in range(1, self.feature_layers + 1)]
        return torch.cat(layer_cls, dim=-1).view(bsz, zones, -1)

    def fuse(self, zone_reps, return_weights=False):
        """Merge per-zone features (B, Z, H) into one vector per document.

        With ``return_weights`` the per-zone mixing weights are returned as
        well (None for 'concat', which has no per-zone weights).
        """
        if self.fusion == 'concat':
            merged = self.fusion_proj(zone_reps.reshape(zone_reps.shape[0], -1))
            return (merged, None) if return_weights else merged
        elif self.fusion == 'weighted':
            mix = torch.softmax(self.fusion_weight, dim=0)
            merged = (zone_reps * mix.view(1, -1, 1)).sum(dim=1)
            return (merged, mix) if return_weights else merged
        elif self.fusion == 'attention':
            raw_scores = self.attn(zone_reps).squeeze(-1)
            mix = torch.softmax(raw_scores, dim=1)
            merged = (zone_reps * mix.unsqueeze(-1)).sum(dim=1)
            return (merged, mix) if return_weights else merged
        raise ValueError(f"Unsupported fusion: {self.fusion}")
class LightXMLMultiZoneGroupY(nn.Module):
    """Multi-zone LightXML with a label-group (meta-label) ranking stage.

    Each document is split into 6 zones encoded independently and fused into
    one feature vector; ``l0`` scores label groups and ``l1`` + ``embed``
    re-rank candidate labels drawn from the top-k groups (same two-stage
    scheme as ``LightXML`` with ``group_y``).
    """

    def __init__(self, n_labels, group_y, bert='bert-base-chinese', feature_layers=5, dropout=0.5,
                 update_count=1, fusion='concat', hidden_dim=300, candidates_topk=20,
                 teacher_force_lambda=0.3):
        super(LightXMLMultiZoneGroupY, self).__init__()
        # SWA settings are fixed here (cf. LightXML where they are arguments).
        self.use_swa = True
        self.swa_warmup_epoch = 10
        self.swa_update_step = 200
        self.swa_state = {}
        self.update_count = update_count
        self.candidates_topk = candidates_topk
        # Weight of the ground-truth group signal mixed into candidate
        # selection during training (teacher forcing); 0 disables it.
        self.teacher_force_lambda = teacher_force_lambda
        self.bert_name, self.bert = bert, get_bert(bert)
        self.feature_layers, self.drop_out = feature_layers, nn.Dropout(dropout)
        self.fusion = fusion
        self.group_y = group_y
        self.hidden_size = self.bert.config.hidden_size * self.feature_layers
        if self.fusion == 'concat':
            self.fusion_proj = nn.Linear(self.hidden_size * 6, self.hidden_size)
        elif self.fusion == 'weighted':
            self.fusion_weight = nn.Parameter(torch.ones(6))
        elif self.fusion == 'attention':
            self.attn = nn.Linear(self.hidden_size, 1)
        else:
            raise ValueError(f"Unsupported fusion: {self.fusion}")
        self.group_y_labels = group_y.shape[0]
        self.l0 = nn.Linear(self.hidden_size, self.group_y_labels)
        self.l1 = nn.Linear(self.hidden_size, hidden_dim)
        self.embed = nn.Embedding(n_labels, hidden_dim)
        nn.init.xavier_uniform_(self.embed.weight)

    def encode_zones(self, input_ids, attention_mask, token_type_ids):
        """Encode each of the Z zones independently; returns (B, Z, hidden)."""
        bsz, zones, seqlen = input_ids.shape
        flat_ids = input_ids.view(bsz * zones, seqlen)
        flat_mask = attention_mask.view(bsz * zones, seqlen)
        flat_type = token_type_ids.view(bsz * zones, seqlen)
        outs = self.bert(
            flat_ids,
            attention_mask=flat_mask,
            token_type_ids=flat_type
        )[-1]
        # [CLS] from the last `feature_layers` hidden layers, concatenated.
        cls = torch.cat([outs[-i][:, 0] for i in range(1, self.feature_layers + 1)], dim=-1)
        return cls.view(bsz, zones, -1)

    def fuse(self, zone_reps):
        """Merge per-zone features (B, Z, H) into one vector per document."""
        if self.fusion == 'concat':
            fused = zone_reps.reshape(zone_reps.shape[0], -1)
            return self.fusion_proj(fused)
        if self.fusion == 'weighted':
            weights = torch.softmax(self.fusion_weight, dim=0)
            return (zone_reps * weights.view(1, -1, 1)).sum(dim=1)
        if self.fusion == 'attention':
            attn_scores = self.attn(zone_reps).squeeze(-1)
            attn_weights = torch.softmax(attn_scores, dim=1)
            return (zone_reps * attn_weights.unsqueeze(-1)).sum(dim=1)
        raise ValueError(f"Unsupported fusion: {self.fusion}")

    def get_candidates(self, group_logits, group_gd=None):
        """Pick candidate labels from the top-k scoring label groups.

        During training the ground-truth group indicator is blended in with
        weight ``teacher_force_lambda`` so true groups are favored.  Returns
        (group indices, padded candidate ids, padded group scores) as numpy.
        """
        logits = torch.sigmoid(group_logits.detach())
        if group_gd is not None and self.teacher_force_lambda > 0:
            logits += self.teacher_force_lambda * group_gd
        k = min(self.candidates_topk, logits.shape[1])
        if k <= 0:
            raise ValueError("No group labels available for candidate selection")
        scores, indices = torch.topk(logits, k=k)
        scores, indices = scores.cpu().detach().numpy(), indices.cpu().detach().numpy()
        candidates, candidates_scores = [], []
        for index, score in zip(indices, scores):
            candidates.append(self.group_y[index])
            candidates_scores.append([np.full(c.shape, s) for c, s in zip(candidates[-1], score)])
            candidates[-1] = np.concatenate(candidates[-1])
            candidates_scores[-1] = np.concatenate(candidates_scores[-1])
        # Groups have different sizes: pad every row to the longest one.
        max_candidates = max([i.shape[0] for i in candidates])
        candidates = np.stack([np.pad(i, (0, max_candidates - i.shape[0]), mode='edge') for i in candidates])
        candidates_scores = np.stack([np.pad(i, (0, max_candidates - i.shape[0]), mode='edge') for i in candidates_scores])
        return indices, candidates, candidates_scores

    def forward(self, input_ids, attention_mask, token_type_ids,
                labels=None, group_labels=None, candidates=None):
        """Encode zones, fuse, and apply the two-stage head.

        Training: returns (logits, loss, loss_group, loss_rank).
        Inference: returns (group_logits, candidates, candidates_scores).
        """
        is_training = labels is not None
        zone_reps = self.encode_zones(input_ids, attention_mask, token_type_ids)
        out = self.drop_out(self.fuse(zone_reps))
        group_logits = self.l0(out)
        if is_training:
            l = labels.to(dtype=torch.bool)
            target_candidates = torch.masked_select(candidates, l).detach().cpu()
            target_candidates_num = l.sum(dim=1).detach().cpu()
        groups, candidates, group_candidates_scores = self.get_candidates(
            group_logits, group_gd=group_labels if is_training else None
        )
        if is_training:
            bs = 0
            new_labels = []
            for i, n in enumerate(target_candidates_num.numpy()):
                be = bs + n
                c = set(target_candidates[bs: be].numpy())
                c2 = candidates[i]
                # Fix: use a distinct comprehension variable instead of
                # shadowing the enumerate index `i` (behavior unchanged).
                new_labels.append(torch.tensor([1.0 if cand in c else 0.0 for cand in c2]))
                if len(c) != new_labels[-1].sum():
                    # True labels missed by the top-k groups overwrite
                    # negative slots so they are still trained on.
                    s_c2 = set(c2)
                    for cc in list(c):
                        if cc in s_c2:
                            continue
                        for j in range(new_labels[-1].shape[0]):
                            if new_labels[-1][j].item() != 1:
                                c2[j] = cc
                                new_labels[-1][j] = 1.0
                                break
                bs = be
            labels = torch.stack(new_labels).to(input_ids.device)
        candidates = torch.LongTensor(candidates).to(input_ids.device)
        group_candidates_scores = torch.Tensor(group_candidates_scores).to(input_ids.device)
        emb = self.l1(out)
        embed_weights = self.embed(candidates)
        emb = emb.unsqueeze(-1)
        logits = torch.bmm(embed_weights, emb).squeeze(-1)
        if is_training:
            loss_fn = torch.nn.BCEWithLogitsLoss()
            loss_rank = loss_fn(logits, labels)
            loss_group = loss_fn(group_logits, group_labels)
            loss = loss_rank + loss_group
            return logits, loss, loss_group.detach(), loss_rank.detach()
        else:
            candidates_scores = torch.sigmoid(logits)
            candidates_scores = candidates_scores * group_candidates_scores
            return group_logits, candidates, candidates_scores

    def get_accuracy(self, candidates, logits, labels):
        """Count P@1/P@3/P@5 hits over a batch; returns raw counts
        (total, acc1, acc3, acc5), not yet normalized."""
        if candidates is not None:
            candidates = candidates.detach().cpu()
        k = min(10, logits.shape[1])
        if k <= 0:
            return 0, 0, 0, 0
        scores, indices = torch.topk(logits.detach().cpu(), k=k)
        acc1, acc3, acc5, total = 0, 0, 0, 0
        for i, l in enumerate(labels):
            if candidates is not None and l.shape[0] == candidates.shape[1]:
                positive_idx = np.nonzero(l)[0]
                target_labels = set(candidates[i][positive_idx].numpy())
            else:
                target_labels = set(np.nonzero(l)[0])
            if candidates is not None:
                pred_labels = candidates[i][indices[i]].numpy()
            else:
                pred_labels = indices[i, :5].numpy()
            acc1 += len(set([pred_labels[0]]) & target_labels)
            acc3 += len(set(pred_labels[:3]) & target_labels)
            acc5 += len(set(pred_labels[:5]) & target_labels)
            total += 1
        return total, acc1, acc3, acc5

    def one_epoch(self, epoch, dataloader, optimizer,
                  mode='train', eval_loader=None, eval_step=20000, log=None, log_interval=50, use_tqdm=False):
        """Run one pass over ``dataloader`` in 'train', 'eval' or 'test' mode.

        Returns the accumulated training loss ('train'), the tuple
        (g_p1, g_p3, g_p5, p1, p3, p5, candidate_recall) ('eval'), or stacked
        prediction scores/labels ('test').
        """
        total_steps = len(dataloader)
        bar = tqdm.tqdm(total=total_steps) if use_tqdm else None
        p1, p3, p5 = 0, 0, 0
        g_p1, g_p3, g_p5 = 0, 0, 0
        total, acc1, acc3, acc5 = 0, 0, 0, 0
        g_acc1, g_acc3, g_acc5 = 0, 0, 0
        # Fraction of true labels that survive candidate selection.
        cand_hit, cand_total = 0, 0
        train_loss = 0
        if mode == 'train':
            self.train()
            self.zero_grad()
        else:
            self.eval()
        if self.use_swa and epoch == self.swa_warmup_epoch and mode == 'train':
            self.swa_init()
        if self.use_swa and mode == 'eval':
            self.swa_swap_params()
        pred_scores, pred_labels = [], []
        if bar:
            bar.set_description(f'{mode}-{epoch}')
        device = next(self.parameters()).device
        def _gpu_mem():
            # Human-readable allocated/total GPU memory for progress logs.
            if device.type != "cuda":
                return "gpu_mem=NA"
            alloc = torch.cuda.memory_allocated(device) // (1024 * 1024)
            total = torch.cuda.get_device_properties(device).total_memory // (1024 * 1024)
            return f"gpu_mem={alloc}MB/{total}MB"
        with torch.set_grad_enabled(mode == 'train'):
            for step, data in enumerate(dataloader):
                batch = tuple(t for t in data)
                inputs = {'input_ids': batch[0].to(device),
                          'attention_mask': batch[1].to(device),
                          'token_type_ids': batch[2].to(device)}
                if mode == 'train':
                    inputs['labels'] = batch[3].to(device)
                    inputs['group_labels'] = batch[4].to(device)
                    inputs['candidates'] = batch[5].to(device)
                outputs = self(**inputs)
                if bar:
                    bar.update(1)
                if mode == 'train':
                    # forward() returns (logits, loss, loss_group, loss_rank).
                    loss = outputs[1]
                    loss_group = outputs[2]
                    loss_rank = outputs[3]
                    loss /= self.update_count
                    train_loss += loss.item()
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    if (step + 1) % self.update_count == 0:
                        optimizer.step()
                        self.zero_grad()
                    if step % eval_step == 0 and eval_loader is not None and step != 0:
                        results = self.one_epoch(epoch, eval_loader, optimizer, mode='eval')
                        p1, p3, p5 = results[3:6]
                        g_p1, g_p3, g_p5 = results[:3]
                        log.log(f'{epoch:>2} {step:>6}: {p1:.4f}, {p3:.4f}, {p5:.4f}'
                                f' {g_p1:.4f}, {g_p3:.4f}, {g_p5:.4f}')
                    if self.use_swa and step % self.swa_update_step == 0:
                        self.swa_step()
                    if bar:
                        bar.set_postfix(loss=loss.item())
                    if log_interval and step % log_interval == 0:
                        print(f"[train] epoch={epoch} step={step+1}/{len(dataloader)} "
                              f"loss={loss.item():.6f} loss_group={loss_group.item():.6f} "
                              f"loss_rank={loss_rank.item():.6f} {_gpu_mem()}")
                else:
                    group_logits, candidates, logits = outputs
                    if mode == 'eval':
                        labels = batch[3]
                        group_labels = batch[4]
                        _total, _acc1, _acc3, _acc5 = self.get_accuracy(candidates, logits, labels.cpu().numpy())
                        total += _total; acc1 += _acc1; acc3 += _acc3; acc5 += _acc5
                        p1 = acc1 / total
                        p3 = acc3 / total / 3
                        p5 = acc5 / total / 5
                        # Candidate recall: share of true labels present in
                        # the selected candidate set.
                        cand_np = candidates.detach().cpu().numpy()
                        labels_np = labels.cpu().numpy()
                        for i, l in enumerate(labels_np):
                            pos = np.nonzero(l)[0]
                            if pos.shape[0] == 0:
                                continue
                            cand_set = set(cand_np[i].tolist())
                            cand_hit += sum(1 for t in pos if t in cand_set)
                            cand_total += pos.shape[0]
                        candidate_recall = cand_hit / cand_total if cand_total > 0 else 0.0
                        _, _g_acc1, _g_acc3, _g_acc5 = self.get_accuracy(None, group_logits, group_labels.cpu().numpy())
                        g_acc1 += _g_acc1; g_acc3 += _g_acc3; g_acc5 += _g_acc5
                        g_p1 = g_acc1 / total
                        g_p3 = g_acc3 / total / 3
                        g_p5 = g_acc5 / total / 5
                        if bar:
                            bar.set_postfix(p1=p1, p3=p3, p5=p5, g_p1=g_p1, g_p3=g_p3, g_p5=g_p5,
                                            cand_recall=candidate_recall)
                        if log_interval and step % log_interval == 0:
                            print(f"[eval] epoch={epoch} step={step+1}/{len(dataloader)} "
                                  f"g_p1={g_p1:.4f} g_p3={g_p3:.4f} g_p5={g_p5:.4f} "
                                  f"p1={p1:.4f} p3={p3:.4f} p5={p5:.4f} "
                                  f"cand_recall={candidate_recall:.4f} {_gpu_mem()}")
                    elif mode == 'test':
                        k = min(100, logits.shape[1])
                        _scores, _indices = torch.topk(logits.detach().cpu(), k=k)
                        _labels = torch.stack([candidates[i][_indices[i]] for i in range(_indices.shape[0])], dim=0)
                        pred_scores.append(_scores.cpu())
                        pred_labels.append(_labels.cpu())
        # Flush any leftover accumulated gradients.
        if mode == 'train' and total_steps > 0 and total_steps % self.update_count != 0:
            optimizer.step()
            self.zero_grad()
        if self.use_swa and mode == 'eval':
            self.swa_swap_params()
        if bar:
            bar.close()
        if mode == 'eval':
            candidate_recall = cand_hit / cand_total if cand_total > 0 else 0.0
            return g_p1, g_p3, g_p5, p1, p3, p5, candidate_recall
        elif mode == 'test':
            return torch.cat(pred_scores, dim=0).numpy(), torch.cat(pred_labels, dim=0).numpy() if len(pred_labels) != 0 else None
        elif mode == 'train':
            return train_loss

    def swa_init(self):
        """Start SWA tracking with a CPU snapshot of the current weights."""
        self.swa_state = {'models_num': 1}
        for n, p in self.named_parameters():
            self.swa_state[n] = p.data.detach().cpu().clone()

    def swa_step(self):
        """Fold the current weights into the SWA running average."""
        if 'models_num' not in self.swa_state:
            return
        self.swa_state['models_num'] += 1
        beta = 1.0 / self.swa_state['models_num']
        with torch.no_grad():
            for n, p in self.named_parameters():
                # Fix: use the keyword `alpha` form; add_(scalar, tensor) is a
                # deprecated overload in modern PyTorch.
                self.swa_state[n].mul_(1.0 - beta).add_(p.data.detach().cpu(), alpha=beta)

    def swa_swap_params(self):
        """Exchange live weights with the stored SWA average (CPU-buffered)."""
        if 'models_num' not in self.swa_state:
            return
        for n, p in self.named_parameters():
            buf = self.swa_state[n]
            self.swa_state[n] = p.data.detach().cpu()
            p.data = buf.to(p.data.device)

    def get_tokenizer(self):
        """Return the tokenizer for bert_name, preferring a local cache
        (local_files_only) and falling back to a download."""
        if 'roberta' in self.bert_name:
            try:
                return RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True, local_files_only=True)
            except Exception:
                return RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True)
        if 'xlnet' in self.bert_name:
            try:
                return XLNetTokenizer.from_pretrained('xlnet-base-cased', local_files_only=True)
            except Exception:
                return XLNetTokenizer.from_pretrained('xlnet-base-cased')
        if self.bert_name in ['bert-base', 'bert-base-uncased']:
            try:
                return BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, local_files_only=True)
            except Exception:
                return BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
        if self.bert_name == 'bert-base-chinese':
            try:
                return BertTokenizer.from_pretrained('bert-base-chinese', local_files_only=True)
            except Exception:
                return BertTokenizer.from_pretrained('bert-base-chinese')
        try:
            return AutoTokenizer.from_pretrained(self.bert_name, local_files_only=True)
        except Exception:
            return AutoTokenizer.from_pretrained(self.bert_name)