import sys
sys.path.append('./')

import argparse
import json
import os

import pytorch_lightning as pl
import torch
from pytorch_lightning import Trainer, loggers
from pytorch_lightning.callbacks import ModelCheckpoint
from transformers import AutoTokenizer, T5Tokenizer, MT5ForConditionalGeneration
from transformers.optimization import get_linear_schedule_with_warmup

from fengshen.data.task_dataloader.task_datasets import LCSTSDataModel


def test():
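    """Smoke test: tokenize an LCSTS-style sample, print the encodings,
    then run beam-search generation with google/mt5-small."""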
    tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
    # German sample kept for reference; the Chinese LCSTS-style sample below
    # is the one actually used.
    # article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
    # summary = "Weiter Verhandlung in Syrien."
    article = ("日前,方舟子发文直指林志颖旗下爱碧丽推销假保健品,引起哗然。调查发现,爱碧丽没有自己的生产加工厂。"
               "其胶原蛋白饮品无核心研发,全部代工生产。号称有“逆生长”功效的爱碧丽“梦幻奇迹限量组”售价高达1080元,实际成本仅为每瓶4元!")
    summary = "林志颖公司疑涉虚假营销无厂房无研发"
    inputs = tokenizer(article, return_tensors="pt")
    tt = tokenizer.encode_plus(summary, max_length=64,
                               padding='max_length', truncation='longest_first')
    print('tt:', tt)
    print('inputs:', inputs)
    # Tokenize the target with target-side settings, matching how training
    # labels are produced.
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(summary, return_tensors="pt")
    print('labels:', labels)
    print('origin labels:', tokenizer.decode(labels['input_ids'][0]))

    model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
    model.eval()
    generated_ids = model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_length=150,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True
    )
    preds = [tokenizer.decode(g, skip_special_tokens=True,
                              clean_up_tokenization_spaces=True) for g in generated_ids]
    print(preds)


class MT5FinetuneSummaryModelCheckpoint:
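    """Builds a pytorch_lightning ModelCheckpoint callback from argparse args."""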
    @staticmethod
    def add_argparse_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')
        parser.add_argument('--monitor', default='train_loss', type=str)
        parser.add_argument('--mode', default='min', type=str)
        parser.add_argument('--dirpath', default='./ckpt/', type=str)
        parser.add_argument(
            '--filename', default='model-{epoch:02d}-{train_loss:.4f}', type=str)
        parser.add_argument('--save_last', action='store_true', default=True)
        parser.add_argument('--save_top_k', default=3, type=int)
        parser.add_argument('--every_n_train_steps', default=100, type=int)
        # NOTE: argparse's type=bool treats any non-empty string as True.
        parser.add_argument('--save_weights_only', default=True, type=bool)
        return parent_args

    def __init__(self, args):
        self.callbacks = ModelCheckpoint(monitor=args.monitor,
                                         save_top_k=args.save_top_k,
                                         mode=args.mode,
                                         every_n_train_steps=args.every_n_train_steps,
                                         save_weights_only=args.save_weights_only,
                                         dirpath=args.dirpath,
                                         filename=args.filename,
                                         save_last=args.save_last)


class MT5FinetuneSummary(pl.LightningModule):
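    """LightningModule that fine-tunes MT5 for abstractive summarization."""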

    @staticmethod
    def add_model_specific_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')
        parser.add_argument('--learning_rate', default=1e-4, type=float)
        parser.add_argument('--weight_decay', default=0.1, type=float)
        parser.add_argument('--warmup', default=0.01, type=float)
        return parent_args

    def __init__(self, args, num_data):
        super().__init__()
        self.args = args
        self.num_data = num_data
        print('num_data:', num_data)
        self.model = MT5ForConditionalGeneration.from_pretrained(
            args.pretrained_model_path)

    def setup(self, stage) -> None:
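        """Estimate the total number of optimizer steps for the LR scheduler.

        num_data is the number of training batches; the count is divided
        across GPUs and gradient-accumulation steps."""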
        if stage == 'fit':
            num_gpus = self.trainer.gpus if self.trainer.gpus is not None else 0
            self.total_step = int(self.trainer.max_epochs * self.num_data /
                                  (max(1, num_gpus) * self.trainer.accumulate_grad_batches))
            print('Total training step:', self.total_step)

    def training_step(self, batch, batch_idx):
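        # The model computes the cross-entropy loss internally when labels
        # are passed in.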
        output = self.model(input_ids=batch['input_ids'],
                            attention_mask=batch['attention_mask'], labels=batch['labels'])
        self.log('train_loss', output.loss)
        return output.loss

    def compute_metrics(self, logits, labels):
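        """Token-level accuracy between greedy predictions and labels."""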
        y_pred = torch.argmax(logits, dim=-1)
        y_pred = y_pred.view(size=(-1,))
        y_true = labels.view(size=(-1,)).float()
        corr = torch.eq(y_pred, y_true)
        # Normalize by the total number of tokens, not the batch size.
        acc = torch.sum(corr.float()) / y_true.size(0)
        return acc

    def validation_step(self, batch, batch_idx):
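        # Same forward pass as training; logs the validation loss.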
        output = self.model(input_ids=batch['input_ids'],
                            attention_mask=batch['attention_mask'], labels=batch['labels'])
        self.log('val_loss', output.loss)

    def predict_step(self, batch, batch_idx):
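        """Generate summaries for a batch and return them with the source
        text and reference summary."""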
        text = batch['text']
        summary = batch['summary']
        generated_ids = self.model.generate(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            max_length=self.args.max_dec_length
        )
        return {"pred": generated_ids, "text": text, "summary": summary}

    def configure_optimizers(self):
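        # AdamW with weight decay disabled for biases and LayerNorm weights,
        # plus a linear warmup/decay schedule stepped every batch.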
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        paras = list(
            filter(lambda p: p[1].requires_grad, self.named_parameters()))
        paras = [{
            'params':
            [p for n, p in paras if not any(nd in n for nd in no_decay)],
            'weight_decay': self.args.weight_decay
        }, {
            'params': [p for n, p in paras if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }]
        optimizer = torch.optim.AdamW(paras, lr=self.args.learning_rate)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, int(self.total_step * self.args.warmup),
            self.total_step)

        return [{
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1
            }
        }]


def save_test(data, args, data_model):
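    """Decode predicted batches and write them to output_save_path as JSON lines."""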
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_path)
    with open(args.output_save_path, 'w', encoding='utf-8') as f:
        for batch in data:
            texts = batch['text']
            summaries = batch['summary']
            preds = batch['pred']
            for idx, pred_ids in enumerate(preds):
                text = texts[idx]
                summary = summaries[idx]
                tmp_result = dict()
                # Decode token by token; avoid shadowing the outer `preds`.
                decoded = [tokenizer.decode(g, skip_special_tokens=True,
                                            clean_up_tokenization_spaces=True)
                           for g in pred_ids]
                tmp_result['summary'] = ''.join(decoded)
                tmp_result['label'] = summary
                tmp_result['origin_text'] = text
                json_data = json.dumps(tmp_result, ensure_ascii=False)
                f.write(json_data + '\n')
    print('save the result to ' + args.output_save_path)


def main():
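    # Assemble the argparser from the data module, Trainer, checkpoint
    # wrapper, and model, then either fine-tune or run prediction only.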
    total_parser = argparse.ArgumentParser("Summary Task")
    total_parser.add_argument('--do_eval_only', action='store_true', default=False)
    total_parser.add_argument('--pretrained_model_path', default='google/mt5-small', type=str)
    total_parser.add_argument('--output_save_path', default='./predict.json', type=str)
    total_parser = LCSTSDataModel.add_data_specific_args(total_parser)
    total_parser = Trainer.add_argparse_args(total_parser)
    total_parser = MT5FinetuneSummaryModelCheckpoint.add_argparse_args(total_parser)
    total_parser = MT5FinetuneSummary.add_model_specific_args(total_parser)
    args = total_parser.parse_args()

    data_model = LCSTSDataModel(args)
    if not args.do_eval_only:
        model = MT5FinetuneSummary(args, len(data_model.train_dataloader()))
        checkpoint_callback = MT5FinetuneSummaryModelCheckpoint(args).callbacks
        logger = loggers.TensorBoardLogger(save_dir=os.path.join(
            args.default_root_dir, 'log/'), name='mt5_summary')
        trainer = Trainer.from_argparse_args(args,
                                             logger=logger,
                                             callbacks=[checkpoint_callback])
        trainer.fit(model, data_model)
    else:
        trainer = Trainer.from_argparse_args(args)
        model = MT5FinetuneSummary.load_from_checkpoint(
            args.resume_from_checkpoint, args=args,
            num_data=len(data_model.predict_dataloader()))
        result = trainer.predict(model, data_model)
        # Only rank 0 writes results; guard for single-process runs where
        # torch.distributed is never initialized.
        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            save_test(result, args, data_model)


if __name__ == '__main__':
    main()

'''
Example usage (evaluation only):
python examples/mt5_summary.py --gpus=1 --test_data=test_public.jsonl
--default_root_dir=/cognitive_comp/ganruyi/fengshen/mt5_summary/eval
--do_eval_only
--resume_from_checkpoint=/cognitive_comp/ganruyi/fengshen/mt5_summary/ckpt/model-epoch=01-train_loss=1.9166.ckpt
--strategy=ddp
'''
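# A training invocation might look like the following; the data flags are an
# assumption based on LCSTSDataModel.add_data_specific_args, and the paths
# are placeholders, not verified defaults:
# python examples/mt5_summary.py --gpus=1 --train_data=train.jsonl \
#     --valid_data=valid.jsonl --default_root_dir=./mt5_summary --max_epochs=1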