| import os |
| import os.path as osp |
| import torch |
| from torch.utils.tensorboard import SummaryWriter |
| from torch.utils.data import DataLoader, RandomSampler |
| from torch.cuda.amp import GradScaler, autocast |
| from datetime import datetime |
| from easydict import EasyDict as edict |
| from tqdm import tqdm |
| import pdb |
| import pprint |
| import json |
| import pickle |
| from collections import defaultdict |
| import copy |
| from time import time |
|
|
| from config import cfg |
| from torchlight import initialize_exp, set_seed, get_dump_path |
| from src.data import load_data, load_data_kg, Collator_base, Collator_kg, SeqDataset, KGDataset, Collator_order, load_order_data |
| from src.utils import set_optim, Loss_log, add_special_token, time_trans |
| from src.distributed_utils import init_distributed_mode, dist_pdb, is_main_process, reduce_value, cleanup |
| import torch.distributed as dist |
|
|
| from itertools import cycle |
| from model import BertTokenizer, HWBert, KGEModel, OD_model, KE_model |
| import torch.multiprocessing |
| from torch.nn.parallel import DistributedDataParallel |
|
|
| |
|
|
|
|
| class Runner: |
    def __init__(self, args, writer=None, logger=None, rank=0):
        """Build the full training context: paths, models, data, optimizer.

        Args:
            args: parsed experiment config (edict-like namespace).
            writer: TensorBoard SummaryWriter (rank 0 only, else None).
            logger: experiment logger (rank 0 only, else None).
            rank: distributed process rank; 0 is the logging/saving process.
        """
        self.datapath = edict()
        self.datapath.log_dir = get_dump_path(args)
        self.datapath.model_dir = os.path.join(self.datapath.log_dir, 'model')
        self.rank = rank

        # Keep the *initial* mask probability; args.mlm_probability is mutated
        # each epoch by mask_rate_update().
        self.mlm_probability = args.mlm_probability
        self.args = args
        self.writer = writer
        self.logger = logger

        self.model_list = []
        self.model = HWBert(self.args)

        # data_init() must run before dataloader_init(): it creates the
        # datasets, samplers and tokenizer used below.
        self.data_init()
        self.model.cuda()

        self.od_model, self.ke_model = None, None
        self.scaler = GradScaler()  # AMP loss scaler (used in train())

        # train_strategy is cumulative: >=2 adds the KG-embedding head,
        # >=4 adds the order-detection head. Strategy 3 is a reserved slot.
        if self.args.train_strategy >= 2:
            self.ke_model = KE_model(self.args)
        if self.args.train_strategy >= 3:
            # reserved for a future task
            pass
        if self.args.train_strategy >= 4:
            self.od_model = OD_model(self.args)

        # Resume from previously saved checkpoints unless using a raw
        # pretrained PLM or explicitly starting from pretrain weights.
        if self.args.model_name not in ['MacBert', 'TeleBert', 'TeleBert2', 'TeleBert3'] and not self.args.from_pretrain:
            self.model = self._load_model(self.model, self.args.model_name)
            self.od_model = self._load_model(self.od_model, f"od_{self.args.model_name}")
            self.ke_model = self._load_model(self.ke_model, f"ke_{self.args.model_name}")

        if self.args.only_test:
            self.dataloader_init(self.seq_test_set)
        else:
            # ERNIE-style schedule starts with random token masking and
            # switches to whole-word masking later (see run()).
            if self.args.ernie_stratege > 0:
                self.args.mask_stratege = 'rand'
            self.dataloader_init(self.seq_train_set, self.kg_train_set, self.order_train_set)
            if self.args.dist:
                # Wraps the models in DDP and fills self.model_list.
                self.model_sync()
            else:
                self.model_list = [model for model in [self.model, self.od_model, self.ke_model] if model is not None]

        self.optim_init(self.args)
|
|
    def model_sync(self):
        """Synchronise initial weights across distributed workers.

        Rank 0 dumps its freshly initialised weights to disk; every rank then
        loads that same file and wraps the model in DDP, guaranteeing all
        processes start from identical parameters.
        """
        checkpoint_path = osp.join(self.args.data_path, "tmp", "initial_weights.pt")
        checkpoint_path_od = osp.join(self.args.data_path, "tmp", "initial_weights_od.pt")
        checkpoint_path_ke = osp.join(self.args.data_path, "tmp", "initial_weights_ke.pt")
        if self.rank == 0:
            torch.save(self.model.state_dict(), checkpoint_path)
            if self.od_model is not None:
                torch.save(self.od_model.state_dict(), checkpoint_path_od)
            if self.ke_model is not None:
                torch.save(self.ke_model.state_dict(), checkpoint_path_ke)
        # All ranks wait until rank 0 has finished writing the checkpoints.
        dist.barrier()

        self.model = self._model_sync(self.model, checkpoint_path)
        if self.od_model is not None:
            self.od_model = self._model_sync(self.od_model, checkpoint_path_od)
        if self.ke_model is not None:
            self.ke_model = self._model_sync(self.ke_model, checkpoint_path_ke)
|
|
    def _model_sync(self, model, checkpoint_path):
        """Load shared initial weights into `model` and wrap it in DDP.

        Note: the *DDP wrapper* is appended to self.model_list (so the
        optimizer/grad-clipping see the wrapped parameters), while the bare
        `model.module` is returned for direct attribute access by callers.
        """
        model.load_state_dict(torch.load(checkpoint_path, map_location=self.args.device))
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(self.args.device)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[self.args.gpu], find_unused_parameters=True)
        self.model_list.append(model)
        # Unwrap: callers use the underlying module, not the DDP shell.
        model = model.module
        return model
|
|
    def optim_init(self, opt, total_step=None, accumulation_step=None):
        """(Re)create optimizer and LR scheduler.

        Args:
            opt: the args namespace; total_steps/warmup_steps are written back
                into it as a side effect.
            total_step: override for the schedule length (used when switching
                tasks mid-run); None derives it from epochs * steps/epoch.
            accumulation_step: gradient-accumulation factor forwarded to
                set_optim.
        """
        step_per_epoch = len(self.train_dataloader)
        opt.total_steps = int(step_per_epoch * opt.epoch) if total_step is None else int(total_step)
        # 15% linear warmup of the total schedule.
        opt.warmup_steps = int(opt.total_steps * 0.15)

        if self.rank == 0 and total_step is None:
            self.logger.info(f"warmup_steps: {opt.warmup_steps}")
            self.logger.info(f"total_steps: {opt.total_steps}")
            self.logger.info(f"weight_decay: {opt.weight_decay}")

        # Optionally freeze the first N encoder layers (N = args.freeze_layer).
        freeze_part = ['bert.encoder.layer.1.', 'bert.encoder.layer.2.', 'bert.encoder.layer.3.', 'bert.encoder.layer.4.'][:self.args.freeze_layer]
        self.optimizer, self.scheduler = set_optim(opt, self.model_list, freeze_part, accumulation_step)
|
|
    def data_init(self):
        """Load datasets for every enabled task, build distributed samplers,
        and initialise the tokenizer (with optional special tokens).
        """
        self.seq_train_set, self.seq_test_set, self.kg_train_set, self.kg_data = None, None, None, None
        self.order_train_set, self.order_test_set = None, None

        # Cumulative task loading mirrors the model construction in __init__.
        if self.args.train_strategy >= 1 and self.args.train_strategy <= 4:
            self.seq_train_set, self.seq_test_set, train_test_split = load_data(self.logger, self.args)
        if self.args.train_strategy >= 2:
            self.kg_train_set, self.kg_data = load_data_kg(self.logger, self.args)
        if self.args.train_strategy >= 3:
            # reserved for a future task
            pass
        if self.args.train_strategy >= 4:
            self.order_train_set, self.order_test_set, train_test_split = load_order_data(self.logger, self.args)

        # Distributed samplers are only needed for training.
        if self.args.dist and not self.args.only_test:
            if self.args.train_strategy >= 1 and self.args.train_strategy <= 4:
                self.seq_train_sampler = torch.utils.data.distributed.DistributedSampler(self.seq_train_set)
            if self.args.train_strategy >= 2:
                self.kg_train_sampler = torch.utils.data.distributed.DistributedSampler(self.kg_train_set)
            if self.args.train_strategy >= 3:
                pass
            if self.args.train_strategy >= 4:
                self.order_train_sampler = torch.utils.data.distributed.DistributedSampler(self.order_train_set)

        # Tokenizer: fall back to MacBert's vocab when the model folder for a
        # fine-tuned checkpoint does not exist locally.
        model_name = self.args.model_name
        if self.args.model_name in ['TeleBert', 'TeleBert2', 'TeleBert3']:
            self.tokenizer = BertTokenizer.from_pretrained(osp.join(self.args.data_root, 'transformer', model_name), do_lower_case=True)
        else:
            if not osp.exists(osp.join(self.args.data_root, 'transformer', self.args.model_name)):
                model_name = 'MacBert'
            self.tokenizer = BertTokenizer.from_pretrained(osp.join(self.args.data_root, 'transformer', model_name), do_lower_case=True)

        self.special_token = None
        # Skip special-token injection when only testing a stock PLM: its
        # embedding matrix was never resized for them.
        if self.args.add_special_word and not (self.args.only_test and self.args.model_name in ['MacBert', 'TeleBert', 'TeleBert2', 'TeleBert3']):
            # NOTE(review): 'specail_emb_path' looks like a typo'd config key;
            # it must match the name declared in config.py — confirm there.
            self.tokenizer, special_token, _ = add_special_token(self.tokenizer, model=self.model.encoder, rank=self.rank, cache_path=self.args.specail_emb_path)
            # Lower-cased copy used by the collator for matching.
            self.special_token = [token.lower() for token in special_token]
|
|
| def _dataloader_dist(self, train_set, train_sampler, batch_size, collator): |
| train_dataloader = DataLoader( |
| train_set, |
| sampler=train_sampler, |
| pin_memory=True, |
| num_workers=self.args.workers, |
| persistent_workers=True, |
| drop_last=True, |
| batch_size=batch_size, |
| collate_fn=collator |
| ) |
| return train_dataloader |
|
|
| def _dataloader(self, train_set, batch_size, collator): |
| train_dataloader = DataLoader( |
| train_set, |
| num_workers=self.args.workers, |
| persistent_workers=True, |
| shuffle=(self.args.only_test == 0), |
| drop_last=(self.args.only_test == 0), |
| batch_size=batch_size, |
| collate_fn=collator |
| ) |
| return train_dataloader |
|
|
| def dataloader_init(self, train_set=None, kg_train_set=None, order_train_set=None): |
| bs = self.args.batch_size |
| bs_ke = self.args.batch_size_ke |
| bs_od = self.args.batch_size_od |
| bs_ad = self.args.batch_size_ad |
| |
| if self.args.dist and not self.args.only_test: |
| self.args.workers = min([os.cpu_count(), self.args.batch_size, self.args.workers]) |
| |
| |
|
|
| if train_set is not None: |
| seq_collator = Collator_base(self.args, tokenizer=self.tokenizer, special_token=self.special_token) |
| self.train_dataloader = self._dataloader_dist(train_set, self.seq_train_sampler, bs, seq_collator) |
| if kg_train_set is not None: |
| kg_collator = Collator_kg(self.args, tokenizer=self.tokenizer, data=self.kg_data) |
| self.train_dataloader_kg = self._dataloader_dist(kg_train_set, self.kg_train_sampler, bs_ke, kg_collator) |
| if order_train_set is not None: |
| order_collator = Collator_order(self.args, tokenizer=self.tokenizer) |
| self.train_dataloader_order = self._dataloader_dist(order_train_set, self.order_train_sampler, bs_od, order_collator) |
| else: |
| if train_set is not None: |
| seq_collator = Collator_base(self.args, tokenizer=self.tokenizer, special_token=self.special_token) |
| self.train_dataloader = self._dataloader(train_set, bs, seq_collator) |
| if kg_train_set is not None: |
| kg_collator = Collator_kg(self.args, tokenizer=self.tokenizer, data=self.kg_data) |
| self.train_dataloader_kg = self._dataloader(kg_train_set, bs_ke, kg_collator) |
| if order_train_set is not None: |
| order_collator = Collator_order(self.args, tokenizer=self.tokenizer) |
| self.train_dataloader_order = self._dataloader(order_train_set, bs_od, order_collator) |
|
|
| def dist_step(self, task=0): |
| |
| if self.args.dist: |
| if task == 0: |
| self.seq_train_sampler.set_epoch(self.dist_epoch) |
| if task == 1: |
| self.kg_train_sampler.set_epoch(self.dist_epoch) |
| if task == 2: |
| |
| pass |
| if task == 3: |
| self.order_train_sampler.set_epoch(self.dist_epoch) |
| self.dist_epoch += 1 |
|
|
    def mask_rate_update(self, i):
        """Anneal args.mlm_probability toward args.final_mlm_probability.

        Args:
            i: zero-based epoch index.
        """
        if self.args.mlm_probability_increase == "curve":
            # Accelerating schedule: step size grows with the epoch index and
            # with the remaining gap. NOTE(review): because the gap shrinks
            # while (i+1) grows, the trajectory may overshoot/undershoot the
            # final rate — confirm intended curve against training logs.
            self.args.mlm_probability += (i + 1) * ((self.args.final_mlm_probability - self.args.mlm_probability) / self.args.epoch)
        else:
            assert self.args.mlm_probability_increase == "linear"
            # Linear schedule: fixed increment of (final - initial)/epochs per
            # epoch; self.mlm_probability holds the *initial* rate saved in
            # __init__, so the step size is constant.
            self.args.mlm_probability += (self.args.final_mlm_probability - self.mlm_probability) / self.args.epoch

        if self.rank == 0:
            self.logger.info(f"Moving Mlm_probability in next epoch to: {self.args.mlm_probability*100}%")
|
|
| def task_switch(self, training_strategy): |
| |
| if training_strategy == 1 or self.args.train_together: |
| return (0, 0), None |
|
|
| |
| |
|
|
| for i in range(4): |
| for task in range(training_strategy): |
| if self.args.epoch_matrix[task][i] > 0: |
| self.args.epoch_matrix[task][i] -= 1 |
| return (task, i), self.args.epoch_matrix[task][i] + 1 |
|
|
    def run(self):
        """Top-level training loop over args.epoch epochs.

        Each epoch: pick the active task via task_switch(), advance the
        distributed samplers, optionally rebuild optimizer/schedule on a task
        switch, apply the ERNIE rand->wwm masking switch and the mask-rate
        annealing, then run one epoch of train().
        """
        self.loss_log = Loss_log()
        self.curr_loss = 0.
        self.lr = self.args.lr
        self.curr_loss_dic = defaultdict(float)
        self.curr_kpi_loss_dic = defaultdict(float)
        self.loss_weight = [1, 1]
        self.kpi_loss_weight = [1, 1]
        self.step = 0  # step within the current task/stage (reset on switch)
        self.total_step_sum = 0  # global step, used for TensorBoard x-axis
        task_last = 0
        stage_last = 0
        self.dist_epoch = 0

        with tqdm(total=self.args.epoch) as _tqdm:
            for i in range(self.args.epoch):
                (task, stage), task_epoch = self.task_switch(self.args.train_strategy)
                self.dist_step(task)
                dataloader = self.task_dataloader_choose(task)
                if self.args.train_together and self.args.train_strategy > 1:
                    # Index 0 is a placeholder so dataloader_list[t] lines up
                    # with task id t in train().
                    self.dataloader_list = ['#']
                    for t in range(1, self.args.train_strategy):
                        self.dist_step(t)
                        self.dataloader_list.append(iter(self.task_dataloader_choose(t)))

                if task != task_last or stage != stage_last:
                    self.step = 0
                    if self.rank == 0:
                        print(f"switch to task [{task}] in stage [{stage}]...")
                    if stage != stage_last:
                        # Snapshot the models at the end of the finished stage.
                        self._save_model(stage=f'_stg{stage_last}')
                    if task_epoch is not None:
                        # Fresh optimizer + LR schedule sized for this task's
                        # remaining epoch budget.
                        self.optim_init(self.args, total_step=len(dataloader) * task_epoch, accumulation_step=self.args.accumulation_steps_dict[task])
                    task_last = task
                    stage_last = stage

                # ERNIE curriculum: after `ernie_stratege` epochs of random
                # masking, switch to whole-word masking once (the sentinel
                # value disables this branch afterwards).
                if task == 0 and self.args.ernie_stratege > 0 and i >= self.args.ernie_stratege:
                    self.args.ernie_stratege = 10000000
                    if self.rank == 0:
                        self.logger.info("switch to wwm stratege...")
                    self.args.mask_stratege = 'wwm'

                # While the mask rate is still annealing, the collators must
                # be rebuilt so they pick up the new probability.
                if self.args.mlm_probability != self.args.final_mlm_probability:
                    self.mask_rate_update(i)
                    self.dataloader_init(self.seq_train_set, self.kg_train_set, self.order_train_set)

                self.train(_tqdm, dataloader, task)
                _tqdm.update(1)

        if self.rank == 0:
            self.logger.info(f"min loss {self.loss_log.get_min_loss()}")
            if not self.args.only_test and self.args.save_model:
                self._save_model()
|
|
    def task_dataloader_choose(self, task):
        """Put the models needed by `task` in train mode and return its loader.

        Task ids: 0 = MLM, 1 = KG embedding, 2 = reserved, 3 = order detection.

        NOTE(review): task == 2 leaves `dataloader` unassigned, so reaching it
        would raise UnboundLocalError — acceptable only while task 2 stays a
        reserved placeholder (its epoch_matrix row must be all zeros).
        """
        self.model.train()
        if task == 0:
            dataloader = self.train_dataloader
        elif task == 1:
            self.ke_model.train()
            dataloader = self.train_dataloader_kg
        elif task == 2:
            # reserved for a future task
            pass
        elif task == 3:
            self.od_model.train()
            dataloader = self.train_dataloader_order
        return dataloader
| |
|
|
    def loss_output(self, batch, task):
        """Forward `batch` through the model(s) for `task` and return the loss.

        Task ids: 0 = MLM (backbone only), 1 = KG embedding, 2 = reserved,
        3 = order detection (also logs batch accuracy into self.loss_log).
        """
        if task == 0:
            # NOTE(review): only the scalar loss is kept; the per-component
            # loss dict in _output never reaches output_statistic from here.
            _output = self.model(batch)
            loss = _output['loss']
        elif task == 1:
            loss = self.ke_model(batch, self.model)
        elif task == 2:
            # reserved for a future task; falling through returns an
            # unassigned `loss` (UnboundLocalError) — must stay unreachable.
            pass
        elif task == 3:
            # batch = (token batch, binary order labels)
            emb = self.model.cls_embedding(batch[0], tp=self.args.plm_emb_type)
            loss, loss_dic = self.od_model(emb, batch[1].cuda())
            order_score = self.od_model.predict(emb)
            token_right = self.od_model.right_caculate(order_score, batch[1], threshold=0.5)
            self.loss_log.update_token(batch[1].shape[0], [token_right])
        return loss
|
|
    def train(self, _tqdm, dataloader, task=0):
        """One epoch of AMP training with gradient accumulation.

        Args:
            _tqdm: shared progress bar from run() (description updated here).
            dataloader: loader for the primary `task`.
            task: primary task id; in train_together mode the other tasks'
                loaders (self.dataloader_list) are interleaved every step.

        Returns:
            (curr_loss, curr_loss_dic): running loss accumulators.
        """
        loss_weight, kpi_loss_weight, kpi_loss_dict, _output = None, None, None, None
        self.loss_log.acc_init()
        accumulation_steps = self.args.accumulation_steps_dict[task]
        torch.cuda.empty_cache()

        for batch in dataloader:
            loss = self.args.mask_loss_scale * self.loss_output(batch, task)
            if self.args.train_together and self.args.train_strategy > 1:
                # Pull one batch from every auxiliary task; re-init a loader
                # (and its sampler epoch) when it runs dry.
                for t in range(1, self.args.train_strategy):
                    try:
                        batch = next(self.dataloader_list[t])
                    except StopIteration:
                        self.dist_step(t)
                        self.dataloader_list[t] = iter(self.task_dataloader_choose(t))
                        batch = next(self.dataloader_list[t])
                    loss += self.loss_output(batch, t)

            # Scale for accumulation, then backprop through the AMP scaler.
            loss = loss / accumulation_steps
            self.scaler.scale(loss).backward()
            if self.args.dist:
                # Averaged loss is for logging only; backward already ran.
                loss = reduce_value(loss, average=True)
            self.step += 1
            self.total_step_sum += 1

            if not self.args.dist or is_main_process():
                self.output_statistic(loss, _output)
                acc_descrip = f"Acc: {self.loss_log.get_token_acc()}" if self.loss_log.get_token_acc() > 0 else ""
                _tqdm.set_description(f'Train | step [{self.step}/{self.args.total_steps}] {acc_descrip} LR [{self.lr}] Loss {self.loss_log.get_loss():.5f} ')
                if self.step % self.args.eval_step == 0 and self.step > 0:
                    self.loss_log.update(self.curr_loss)
                    self.update_loss_log()

            # Optimizer step only on accumulation boundaries.
            if self.step % accumulation_steps == 0 and self.step > 0:
                # Unscale before clipping so the clip threshold applies to
                # true gradient magnitudes.
                self.scaler.unscale_(self.optimizer)
                for model in self.model_list:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.clip)

                scale = self.scaler.get_scale()
                self.scaler.step(self.optimizer)
                self.scaler.update()
                # If the scaler shrank its scale, this step was skipped due to
                # inf/nan grads — don't advance the LR schedule either.
                skip_lr_sched = (scale > self.scaler.get_scale())
                if not skip_lr_sched:
                    self.scheduler.step()

                if not self.args.dist or is_main_process():
                    self.lr = self.scheduler.get_last_lr()[-1]
                    self.writer.add_scalars("lr", {"lr": self.lr}, self.total_step_sum)
                for model in self.model_list:
                    model.zero_grad(set_to_none=True)

        if self.args.dist:
            torch.cuda.synchronize(self.args.device)
        return self.curr_loss, self.curr_loss_dic
|
|
| def output_statistic(self, loss, output): |
| |
| self.curr_loss += loss.item() |
| if output is None: |
| return |
| for key in output['loss_dic'].keys(): |
| self.curr_loss_dic[key] += output['loss_dic'][key] |
| if 'kpi_loss_dict' in output and output['kpi_loss_dict'] is not None: |
| for key in output['kpi_loss_dict'].keys(): |
| self.curr_kpi_loss_dic[key] += output['kpi_loss_dict'][key] |
| if 'loss_weight' in output and output['loss_weight'] is not None: |
| self.loss_weight = output['loss_weight'] |
| |
| if 'kpi_loss_weight' in output and output['kpi_loss_weight'] is not None: |
| self.kpi_loss_weight = output['kpi_loss_weight'] |
|
|
    def update_loss_log(self, task=0):
        """Flush accumulated losses/weights to TensorBoard, then zero the
        accumulators. Called every args.eval_step steps from train().

        NOTE(review): the `task` parameter is currently unused.
        """
        vis_dict = {"train_loss": self.curr_loss}
        vis_dict.update(self.curr_loss_dic)
        self.writer.add_scalars("loss", vis_dict, self.total_step_sum)
        if self.loss_weight is not None:
            loss_weight_dic = {}
            if self.args.train_strategy == 1:
                # Uncertainty-weighting convention: effective weight is
                # 1 / sigma^2 with sigma = loss_weight[i].
                loss_weight_dic["mask"] = 1 / (self.loss_weight[0]**2)
                if self.args.use_NumEmb:
                    loss_weight_dic["kpi"] = 1 / (self.loss_weight[1]**2)
                    vis_kpi_dic = {"recover": 1 / (self.kpi_loss_weight[0]**2), "classifier": 1 / (self.kpi_loss_weight[1]**2)}
                    if self.args.contrastive_loss and len(self.kpi_loss_weight) > 2:
                        vis_kpi_dic.update({"contrastive": 1 / (self.kpi_loss_weight[2]**2)})
                    self.writer.add_scalars("kpi_loss_weight", vis_kpi_dic, self.total_step_sum)
                    self.writer.add_scalars("kpi_loss", self.curr_kpi_loss_dic, self.total_step_sum)
                self.writer.add_scalars("loss_weight", loss_weight_dic, self.total_step_sum)

        # Reset accumulators for the next logging window.
        self.curr_loss = 0.
        for key in self.curr_loss_dic:
            self.curr_loss_dic[key] = 0.
        if len(self.curr_kpi_loss_dic) > 0:
            for key in self.curr_kpi_loss_dic:
                self.curr_kpi_loss_dic[key] = 0.
|
|
| |
| def eval(self): |
| self.model.eval() |
| torch.cuda.empty_cache() |
|
|
    def mask_test(self, test_log):
        """Evaluate masked-token prediction top-k accuracy on the held-out
        sequence split.

        Temporarily flips args.only_test to 0 so dataloader_init builds a
        shuffled training-style loader over seq_test_set, then restores it.
        """
        # A held-out split only exists when train_ratio < 1.
        assert self.args.train_ratio < 1
        topk = (1, 100, 500)
        test_log.acc_init(topk)
        self.args.only_test = 0
        self.dataloader_init(self.seq_test_set)
        sz_test = len(self.train_dataloader)
        # NOTE(review): printed as "perplexity" but this is the raw summed
        # loss, not exp(mean loss) — confirm intended metric.
        loss_sum = 0
        with tqdm(total=sz_test) as _tqdm:
            for step, batch in enumerate(self.train_dataloader):
                with torch.no_grad():
                    token_num, token_right, loss = self.model.mask_prediction(batch, len(self.tokenizer), topk)
                    test_log.update_token(token_num, token_right)
                    loss_sum += loss
                _tqdm.update(1)
                _tqdm.set_description(f'Test | step [{step}/{sz_test}] Top{topk} Token_Acc: {test_log.get_token_acc()}')
        print(f"perplexity: {loss_sum}")
        self.args.only_test = 1
        print(f"Top{topk} acc is {test_log.get_token_acc()}")
|
|
    def emb_generate(self, path_gen):
        """Embed a downstream-task JSON file and save embeddings to disk.

        Args:
            path_gen: basename of `<data_path>/downstream_task/<name>.json`;
                falls back to args.path_gen when None.

        Saves one tensor per embedding type ('cls', 'last_avg') plus an
        optional KG-projected tensor when the model name matches the type.
        """
        assert len(self.args.path_gen) > 0 or path_gen is not None
        data_path = self.args.data_path
        if path_gen is None:
            path_gen = self.args.path_gen
        with open(osp.join(data_path, 'downstream_task', f'{path_gen}.json'), "r") as fp:
            data = json.load(fp)
            print(f"read file {path_gen} done!")
        test_set = SeqDataset(data)
        self.dataloader_init(test_set)
        sz_test = len(self.train_dataloader)
        all_emb_dic = defaultdict(list)
        emb_output = {}
        all_emb_ent = []
        tps = ['cls', 'last_avg']
        for step, batch in enumerate(self.train_dataloader):
            for tp in tps:
                with torch.no_grad():
                    batch_embedding = self.model.cls_embedding(batch, tp=tp)
                    # When the checkpoint was trained with this embedding
                    # type (name contains it), also export the KG-space
                    # entity projection.
                    if tp in self.args.model_name and self.ke_model is not None:
                        batch_embedding_ent = self.ke_model.get_embedding(batch_embedding, is_ent=True)
                        batch_embedding_ent = batch_embedding_ent.cpu()
                        all_emb_ent.append(batch_embedding_ent)
                    batch_embedding = batch_embedding.cpu()
                    all_emb_dic[tp].append(batch_embedding)
        torch.cuda.empty_cache()
        for tp in tps:
            emb_output[tp] = torch.cat(all_emb_dic[tp])
            assert emb_output[tp].shape[0] == len(data)
        if len(all_emb_ent) > 0:
            emb_output_ent = torch.cat(all_emb_ent)
        save_path = osp.join(data_path, 'downstream_task', 'output')
        os.makedirs(save_path, exist_ok=True)
        for tp in tps:
            save_dir = osp.join(save_path, f'{path_gen}_emb_{self.args.model_name.replace("DistributedDataParallel", "")}_{tp}.pt')
            torch.save(emb_output[tp], save_dir)
        if len(all_emb_ent) > 0:
            save_dir = osp.join(save_path, f'{path_gen}_emb_{self.args.model_name.replace("DistributedDataParallel", "")}_ent.pt')
            torch.save(emb_output_ent, save_dir)
|
|
    def KGE_test(self):
        """Link-prediction evaluation of the KG-embedding head.

        Pipeline: (1) collect unique entities/relations from the triple set
        and remap the triples to integer ids; (2) embed every entity and
        relation with the PLM + KE projection; (3) run KGEModel.test_step for
        ranking metrics over (a subset of) the triples.
        """
        sz_test = len(self.kg_train_set)
        ent_set = set()
        rel_set = set()
        with tqdm(total=sz_test) as _tqdm:
            _tqdm.set_description('trans entity/relation ID')
            # Triples are (head_text, relation_text, tail_text).
            for batch in self.kg_train_set:
                ent_set.add(batch[0])
                ent_set.add(batch[2])
                rel_set.add(batch[1])
                _tqdm.update(1)
        all_ent, all_rel = list(ent_set), list(rel_set)
        nent, nrel = len(all_ent), len(all_rel)
        ent_dic, rel_dic = {}, {}
        for i in range(nent):
            ent_dic[all_ent[i]] = i
        for i in range(nrel):
            rel_dic[all_rel[i]] = i
        id_format_triple = []
        with tqdm(total=sz_test) as _tqdm:
            _tqdm.set_description('trans triple ID')
            for triple in self.kg_train_set:
                id_format_triple.append((ent_dic[triple[0]], rel_dic[triple[1]], ent_dic[triple[2]]))
                _tqdm.update(1)

        ent_dataset = KGDataset(all_ent)
        rel_dataset = KGDataset(all_rel)

        # Large batches are fine here: forward passes only, no gradients.
        ent_dataloader = DataLoader(
            ent_dataset,
            batch_size=self.args.batch_size * 32,
            num_workers=self.args.workers,
            persistent_workers=True,
            shuffle=False
        )
        rel_dataloader = DataLoader(
            rel_dataset,
            batch_size=self.args.batch_size * 32,
            num_workers=self.args.workers,
            persistent_workers=True,
            shuffle=False
        )

        sz_test = len(ent_dataloader) + len(rel_dataloader)
        with tqdm(total=sz_test) as _tqdm:
            ent_emb = []
            rel_emb = []
            step = 0
            _tqdm.set_description('get the ent embedding')
            with torch.no_grad():
                for batch in ent_dataloader:
                    batch = self.tokenizer.batch_encode_plus(
                        batch,
                        padding='max_length',
                        max_length=self.args.maxlength,
                        truncation=True,
                        return_tensors="pt",
                        return_token_type_ids=False,
                        return_attention_mask=True,
                        add_special_tokens=False
                    )
                    batch_emb = self.model.cls_embedding(batch, tp=self.args.plm_emb_type)
                    batch_emb = self.ke_model.get_embedding(batch_emb, is_ent=True)
                    # Move to CPU immediately to keep GPU memory bounded.
                    ent_emb.append(batch_emb.cpu())
                    _tqdm.update(1)
                    step += 1
                    torch.cuda.empty_cache()
                    _tqdm.set_description(f'ENT emb: [{step}/{sz_test}]')

                _tqdm.set_description('get the rel embedding')
                for batch in rel_dataloader:
                    batch = self.tokenizer.batch_encode_plus(
                        batch,
                        padding='max_length',
                        max_length=self.args.maxlength,
                        truncation=True,
                        return_tensors="pt",
                        return_token_type_ids=False,
                        return_attention_mask=True,
                        add_special_tokens=False
                    )
                    batch_emb = self.model.cls_embedding(batch, tp=self.args.plm_emb_type)
                    batch_emb = self.ke_model.get_embedding(batch_emb, is_ent=False)
                    rel_emb.append(batch_emb.cpu())
                    _tqdm.update(1)
                    step += 1
                    torch.cuda.empty_cache()
                    _tqdm.set_description(f'REL emb: [{step}/{sz_test}]')

        all_ent_emb = torch.cat(ent_emb).cuda()
        all_rel_emb = torch.cat(rel_emb).cuda()

        # Frozen-embedding KGE model used purely for scoring/ranking.
        kge_model_for_test = KGEModel(nentity=len(all_ent), nrelation=len(all_rel), hidden_dim=self.args.ke_dim,
                                      gamma=self.args.ke_margin, entity_embedding=all_ent_emb, relation_embedding=all_rel_emb).cuda()
        # ke_test_num > 0 limits evaluation to a prefix of the triples.
        if self.args.ke_test_num > 0:
            test_triples = id_format_triple[:self.args.ke_test_num]
        else:
            test_triples = id_format_triple
        with torch.no_grad():
            metrics = kge_model_for_test.test_step(test_triples=test_triples, all_true_triples=id_format_triple, args=self.args, nentity=len(all_ent), nrelation=len(all_rel))
        print(f"result:{metrics}")
|
|
    def OD_test(self):
        """Evaluate the order-detection head and export its encodings.

        Iterates self.train_dataloader (expected to yield (tokens, labels)
        pairs — the caller is responsible for pointing it at the order test
        split), reports thresholded accuracy, and saves the encoded
        embeddings for downstream use.
        """
        self.od_model.eval()
        test_log = Loss_log()
        test_log.acc_init()
        sz_test = len(self.train_dataloader)
        all_emb_ent = []
        with tqdm(total=sz_test) as _tqdm:
            for step, batch in enumerate(self.train_dataloader):
                with torch.no_grad():
                    emb = self.model.cls_embedding(batch[0], tp=self.args.plm_emb_type)
                    out_emb = self.od_model.encode(emb)
                    emb_cpu = out_emb.cpu()
                    all_emb_ent.append(emb_cpu)
                    order_score = self.od_model.predict(emb)
                    token_right = self.od_model.right_caculate(order_score, batch[1], threshold=self.args.order_threshold)
                    test_log.update_token(batch[1].shape[0], [token_right])
                _tqdm.update(1)
                _tqdm.set_description(f'Test | step [{step}/{sz_test}] Acc: {test_log.get_token_acc()}')

        emb_output = torch.cat(all_emb_ent)
        data_path = self.args.data_path
        save_path = osp.join(data_path, 'downstream_task', 'output')
        os.makedirs(save_path, exist_ok=True)
        save_dir = osp.join(save_path, f'ratio{self.args.train_ratio}_{emb_output.shape[0]}emb_{self.args.model_name.replace("DistributedDataParallel", "")}.pt')
        torch.save(emb_output, save_dir)
        print(f"save {emb_output.shape[0]} embeddings done...")
|
|
| @ torch.no_grad() |
| def test(self, path_gen=None): |
| test_log = Loss_log() |
| self.model.eval() |
| if not (self.args.mask_test or self.args.embed_gen or self.args.ke_test or len(self.args.order_test_name) > 0): |
| return |
| if self.args.mask_test: |
| self.mask_test(test_log) |
| if self.args.embed_gen: |
| self.emb_generate(path_gen) |
| if self.args.ke_test: |
| self.KGE_test() |
| if len(self.args.order_test_name) > 0: |
| runner.OD_test() |
|
|
| def _load_model(self, model, name): |
| if model is None: |
| return None |
| |
| _name = name if name[:3] not in ['od_', 'ke_'] else name[3:] |
| save_path = osp.join(self.args.data_path, 'save', _name) |
| save_name = osp.join(save_path, f'{name}.pkl') |
| if not osp.exists(save_path) or not osp.exists(save_name): |
| return model.cuda() |
| |
| if 'Distribute' in self.args.model_name: |
| model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(os.path.join(save_name), map_location=self.args.device).items()}) |
| else: |
| model.load_state_dict(torch.load(save_name, map_location=self.args.device)) |
| model.cuda() |
| if self.rank == 0: |
| print(f"loading model [{name}.pkl] done!") |
|
|
| return model |
|
|
| def _save_model(self, stage=''): |
| model_name = type(self.model).__name__ |
| |
| save_path = osp.join(self.args.data_path, 'save') |
| os.makedirs(save_path, exist_ok=True) |
| if self.args.train_strategy == 1: |
| save_name = f'{self.args.exp_name}_{self.args.exp_id}_s{self.args.random_seed}{stage}' |
| else: |
| save_name = f'{self.args.exp_name}_{self.args.exp_id}_s{self.args.random_seed}_{self.args.plm_emb_type}{stage}' |
| save_path = osp.join(save_path, save_name) |
| os.makedirs(save_path, exist_ok=True) |
| |
| self._save(self.model, save_path, save_name) |
|
|
| |
| save_name_od = f'od_{save_name}' |
| self._save(self.od_model, save_path, save_name_od) |
| save_name_ke = f'ke_{save_name}' |
| self._save(self.ke_model, save_path, save_name_ke) |
| return save_path |
|
|
    def _save(self, model, save_path, save_name):
        """Write `model`'s state dict (and optionally a HuggingFace-style
        pretrain export) to disk; no-op when `model` is None.
        """
        if model is None:
            return
        if self.args.save_model:
            torch.save(model.state_dict(), osp.join(save_path, f'{save_name}.pkl'))
            print(f"saving {save_name} done!")

        # Pretrain export only applies to the backbone (the prefix guard
        # excludes the od_/ke_ heads). NOTE(review): this saves
        # self.model.encoder rather than model.encoder — equivalent today
        # because only the backbone reaches this branch, but fragile; verify
        # if new save names are added.
        if self.args.save_pretrain and not save_name.startswith('od_') and not save_name.startswith('ke_'):
            self.tokenizer.save_pretrained(osp.join(self.args.plm_path, f'{save_name}'))
            self.model.encoder.save_pretrained(osp.join(self.args.plm_path, f'{save_name}'))
            print(f"saving [pretrained] {save_name} done!")
|
|
|
|
if __name__ == '__main__':
    # NOTE(review): rebinding `cfg` shadows the imported `config.cfg` class;
    # works, but a distinct name would be clearer.
    cfg = cfg()
    cfg.get_args()
    cfgs = cfg.update_train_configs()
    set_seed(cfgs.random_seed)

    if cfgs.dist and not cfgs.only_test:
        # Sets cfgs.rank / cfgs.gpu / cfgs.device for this process.
        init_distributed_mode(args=cfgs)
    else:
        # Avoid fd exhaustion with many workers in single-process mode.
        torch.multiprocessing.set_sharing_strategy('file_system')
    rank = cfgs.rank

    # Only rank 0 logs and writes TensorBoard summaries.
    writer, logger = None, None
    if rank == 0:
        logger = initialize_exp(cfgs)
        logger_path = get_dump_path(cfgs)
        cfgs.time_stamp = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now())
        # (sic: 'bath_size' typo is kept — it is a runtime string already
        # baked into existing TensorBoard runs.)
        comment = f'bath_size={cfgs.batch_size} exp_id={cfgs.exp_id}'
        if not cfgs.no_tensorboard and not cfgs.only_test:
            writer = SummaryWriter(log_dir=os.path.join(logger_path, 'tensorboard', cfgs.time_stamp), comment=comment)

    cfgs.device = torch.device(cfgs.device)

    runner = Runner(cfgs, writer, logger, rank)

    if cfgs.only_test:
        if cfgs.embed_gen:
            # Run the non-generation tests once with embed_gen disabled so
            # they are not repeated for every generation file below.
            if cfgs.mask_test or cfgs.ke_test:
                runner.args.embed_gen = 0
                runner.test()
                runner.args.embed_gen = 1
            # Downstream serialization files to embed, one test() call each.
            gen_dir = ['yht_serialize_withAttribute', 'yht_serialize_withoutAttr', 'yht_name_serialize', 'zyc_serialize_withAttribute', 'zyc_serialize_withoutAttr', 'zyc_name_serialize',
                       'yz_serialize_withAttribute', 'yz_serialize_withoutAttr', 'yz_name_serialize', 'yz_serialize_net']
            runner.args.mask_test, runner.args.ke_test = 0, 0
            for item in gen_dir:
                runner.test(item)
        else:
            runner.test()
    else:
        runner.run()

    if not cfgs.no_tensorboard and not cfgs.only_test and rank == 0:
        writer.close()
        logger.info("done!")

    # Tear down the process group cleanly in distributed training.
    if cfgs.dist and not cfgs.only_test:
        dist.barrier()
        dist.destroy_process_group()
| |
|
|