import logging
import os
import os.path as osp
import sys

import numpy as np
import torch

import datasets
import transformers
from transformers import set_seed, Trainer
from transformers.trainer_utils import get_last_checkpoint

from arguments import get_args
from tasks.utils import *

# Disable Weights & Biases logging for all runs.
os.environ["WANDB_DISABLED"] = "true"

logger = logging.getLogger(__name__)
def train(trainer, resume_from_checkpoint=None, last_checkpoint=None):
    # Prefer an explicitly requested checkpoint over one auto-detected
    # in the output directory.
    checkpoint = None
    if resume_from_checkpoint is not None:
        checkpoint = resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    # trainer.save_model()
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    # log_best_metrics() is not part of the stock HF Trainer; it is assumed
    # to be provided by the project's Trainer subclass.
    trainer.log_best_metrics()
def evaluate(args, trainer, checkpoint=None):
    logger.info("*** Evaluate ***")
    if checkpoint is not None:
        trainer._load_from_checkpoint(resume_from_checkpoint=checkpoint)
        trainer._resume_watermark()
    metrics = trainer.evaluate(ignore_keys=["hidden_states", "attentions"])
    score, asr = 0.0, 0.0
    # Only watermarked runs have a watermark score / attack success rate.
    # (The original referenced the module-level `training_args` here; the
    # function only receives `args`, which is the same object when called
    # below, so use the parameter.)
    if args.watermark != "clean":
        score, asr = trainer.evaluate_watermark()
    metrics["wmk_asr"] = asr
    metrics["wmk_score"] = score
    trainer.evaluate_clean()
    # eval_memory is assumed to be populated by the custom trainer during
    # evaluation.
    torch.save(trainer.eval_memory, f"{args.output_dir}/exp11_attentions.pth")
    trainer.log_metrics("eval", metrics)
    path = osp.join(args.output_dir, "exp11_acc_asr.pth")
    torch.save(metrics, path)
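
# The two torch.save calls above persist plain Python objects, so results can
# be inspected offline. A minimal sketch (the path below is hypothetical and
# depends on the arguments used for the run):
#
#   import torch
#   metrics = torch.load("checkpoints/<run_name>/t5_p0.10/exp11_acc_asr.pth")
#   print(metrics["wmk_asr"], metrics["wmk_score"])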
def predict(trainer, predict_dataset=None):
    if predict_dataset is None:
        logger.info("No dataset is available for testing")
    elif isinstance(predict_dataset, dict):
        # Some tasks provide one dataset per split or sub-task.
        for dataset_name, d in predict_dataset.items():
            logger.info("*** Predict: %s ***", dataset_name)
            predictions, labels, metrics = trainer.predict(d, metric_key_prefix="predict")
            # argmax over axis 2 assumes token-level logits of shape
            # (batch, sequence, num_labels), as in the NER/SRL tasks.
            predictions = np.argmax(predictions, axis=2)
            trainer.log_metrics("predict", metrics)
            trainer.save_metrics("predict", metrics)
    else:
        logger.info("*** Predict ***")
        predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")
        predictions = np.argmax(predictions, axis=2)
        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)
if __name__ == '__main__':
    args = get_args()
    model_args, data_args, training_args, _ = args

    # Name the output directory after the task, dataset, model, watermark
    # mode, and prompt type so runs do not collide.
    p_type = "prefix" if model_args.prefix else "prompt"
    output_root = osp.join(
        "checkpoints",
        f"{data_args.task_name}_{data_args.dataset_name}_{model_args.model_name_or_path}_{training_args.watermark}_{p_type}",
    )
    output_dir = osp.join(output_root, f"t{training_args.trigger_num}_p{training_args.poison_rate:0.2f}")
    os.makedirs(output_dir, exist_ok=True)
    # Point every argument group at the same run directory and snapshot the
    # full configuration for reproducibility.
    for arg in args:
        arg.output_dir = output_dir
    torch.save(args, osp.join(output_dir, "args.pt"))
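
    # get_args() is assumed to return a 4-tuple of argument dataclasses,
    # e.g. via transformers.HfArgumentParser; the exact classes live in
    # arguments.py, which is not shown, so the names below are hypothetical:
    #
    #   parser = HfArgumentParser(
    #       (ModelArguments, DataTrainingArguments, TrainingArguments, QAArguments)
    #   )
    #   return parser.parse_args_into_dataclasses()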
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log a short summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    os.makedirs("checkpoints", exist_ok=True)
    # Select the task-specific get_trainer implementation.
    if data_args.task_name.lower() == "superglue":
        assert data_args.dataset_name.lower() in SUPERGLUE_DATASETS
        from tasks.superglue.get_trainer import get_trainer
    elif data_args.task_name.lower() == "glue":
        assert data_args.dataset_name.lower() in GLUE_DATASETS
        from tasks.glue.get_trainer import get_trainer
    elif data_args.task_name.lower() == "ner":
        assert data_args.dataset_name.lower() in NER_DATASETS
        from tasks.ner.get_trainer import get_trainer
    elif data_args.task_name.lower() == "srl":
        assert data_args.dataset_name.lower() in SRL_DATASETS
        from tasks.srl.get_trainer import get_trainer
    elif data_args.task_name.lower() == "qa":
        assert data_args.dataset_name.lower() in QA_DATASETS
        from tasks.qa.get_trainer import get_trainer
    elif data_args.task_name.lower() == "ag_news":
        from tasks.ag_news.get_trainer import get_trainer
    elif data_args.task_name.lower() == "imdb":
        from tasks.imdb.get_trainer import get_trainer
    else:
        raise NotImplementedError(
            "Task {} is not implemented. Please choose a task from: {}".format(
                data_args.task_name, ", ".join(TASKS)
            )
        )
    set_seed(training_args.seed)
    trainer, predict_dataset = get_trainer(args)

    # Detect an existing checkpoint in the output directory, mirroring the
    # standard HF example scripts.
    last_checkpoint = None
    if osp.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    if training_args.do_train:
        train(trainer, training_args.resume_from_checkpoint, last_checkpoint)

    if training_args.do_eval:
        if last_checkpoint is None:
            last_checkpoint = osp.join(training_args.output_dir, "checkpoint")
        logger.info(f"-> last_checkpoint: {last_checkpoint}")
        evaluate(training_args, trainer, checkpoint=last_checkpoint)

    # if training_args.do_predict:
    #     predict(trainer, predict_dataset)
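
# Example invocation (a sketch; the exact flags are defined in arguments.py,
# which is not shown, so treat these names as assumptions grounded only in
# the attributes referenced above):
#
#   python run.py \
#       --task_name glue --dataset_name sst2 \
#       --model_name_or_path roberta-large --prefix \
#       --watermark targeted --trigger_num 5 --poison_rate 0.10 \
#       --do_train --do_eval --output_dir checkpoints/tmp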