|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Self-training for sequence classification.""" |
|
|
|
|
|
import argparse |
|
|
import dataclasses |
|
|
import json |
|
|
import logging |
|
|
import os |
|
|
import shutil |
|
|
from typing import List, Optional |
|
|
|
|
|
import datasets |
|
|
from accelerate import Accelerator |
|
|
from datasets import load_dataset |
|
|
from finetuning import finetune |
|
|
from tqdm.auto import tqdm |
|
|
|
|
|
import transformers |
|
|
from transformers import AutoConfig, set_seed |
|
|
from transformers.trainer_utils import IntervalStrategy |
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
MODEL_BIN_FILE = "pytorch_model.bin" |
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class STModelArguments:
    """Arguments pertaining to which config/tokenizer/model we are going to fine-tune from."""

    # Hub model identifier or local checkpoint path used as the starting point
    # for self-training.
    model_name_or_path: str = dataclasses.field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
    )
    # Optional directory overriding the default huggingface.co download cache.
    cache_dir: Optional[str] = dataclasses.field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co."},
    )
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class STDataArguments:
    """Arguments pertaining to what data we are going to input our model for training and evaluation."""

    # Labeled training data (csv or json); required.
    train_file: str = dataclasses.field(metadata={"help": "A csv or a json file containing the training data."})
    # Unlabeled data to pseudo-label on each self-training iteration; required.
    infer_file: str = dataclasses.field(metadata={"help": "A csv or a json file containing the data to predict on."})
    # Validation data; only required when an evaluation strategy is enabled.
    eval_file: Optional[str] = dataclasses.field(
        default=None, metadata={"help": "A csv or a json file containing the validation data."}
    )
    # Downstream task name, forwarded to the fine-tuning routine.
    task_name: Optional[str] = dataclasses.field(
        default=None,
        metadata={"help": "The name of the task to train on."},
    )
    # Explicit class-label names, forwarded to the fine-tuning routine.
    label_list: Optional[List[str]] = dataclasses.field(
        default=None, metadata={"help": "The list of labels for the task."}
    )
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class STTrainingArguments:
    """Training arguments pertaining to the training loop itself."""

    # Root directory for all per-iteration data, checkpoints and results.
    output_dir: str = dataclasses.field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."}
    )
    # Metric key looked up in the per-checkpoint eval results json.
    eval_metric: Optional[str] = dataclasses.field(
        default="accuracy", metadata={"help": "The evaluation metric used for the task."}
    )
    # Mirrors transformers.IntervalStrategy; "no" disables evaluation entirely.
    # NOTE: help text fixed — the valid values are "no"/"steps"/"epoch" (the
    # original said "step" and had an unbalanced quote).
    evaluation_strategy: Optional[str] = dataclasses.field(
        default="no",
        metadata={
            "help": 'The evaluation strategy to adopt during training. Possible values are: ["no", "steps", "epoch"]'
        },
    )
    # Iterations without improvement tolerated before self-training stops.
    early_stopping_patience: Optional[int] = dataclasses.field(
        default=10,
        metadata={"help": "Number of evaluation calls with no improvement after which training will be stopped."},
    )
    # Minimum metric delta that counts as an improvement for early stopping.
    early_stopping_threshold: Optional[float] = dataclasses.field(
        default=0.0,
        metadata={
            "help": "How much the specified evaluation metric must improve to satisfy early stopping conditions."
        },
    )
    # Drop pseudo-labeled rows whose predicted probability is too low.
    do_filter_by_confidence: Optional[bool] = dataclasses.field(
        default=False,
        metadata={"help": "Whether to filter the pseudo-labeled data based on the confidence score."},
    )
    # Keep only the top eval_result * len(dataset) most confident rows.
    do_filter_by_val_performance: Optional[bool] = dataclasses.field(
        default=False,
        metadata={"help": "Whether to filter the pseudo-labeled data based on the validation performance."},
    )
    # Run an extra stage-2 fine-tune on the original labeled data each iteration.
    finetune_on_labeled_data: Optional[bool] = dataclasses.field(
        default=False,
        metadata={"help": "Whether to fine-tune on labeled data after pseudo training."},
    )
    # Probability cutoff used by do_filter_by_confidence.
    confidence_threshold: Optional[float] = dataclasses.field(
        default=0.0,
        metadata={"help": "Confidence threshold for pseudo-labeled data filtering."},
    )
    # Upper bound on self-training iterations.
    # NOTE: help text fixed — it was a copy-paste of the early-stopping help.
    max_selftrain_iterations: Optional[int] = dataclasses.field(
        default=100,
        metadata={"help": "Maximum number of self-training iterations to run."},
    )
    # Random seed forwarded to transformers.set_seed and dataset shuffling.
    seed: Optional[int] = dataclasses.field(
        default=None,
        metadata={"help": "Random seed for initialization."},
    )
|
|
|
|
|
|
|
|
def create_pseudo_labeled_data(args, infer_input, infer_output, eval_result, id2label, next_data_dir):
    """Create pseudo labeled data for the next self-training iteration.

    Args:
        args: Flat namespace of self-training options (filtering flags,
            confidence_threshold, seed, data_file_extension).
        infer_input: Dataset of unlabeled examples that were predicted on.
        infer_output: Dataset of model predictions ("label", "prediction",
            "probability" columns) aligned row-by-row with ``infer_input``.
        eval_result: Validation metric in [0, 1]; used as the fraction of rows
            to keep when filtering by validation performance.
        id2label: Mapping from class index to class-name string.
        next_data_dir: Directory where the pseudo-labeled training file is written.
    """
    # Pair every input example with its prediction columns, row by row.
    dataset = datasets.concatenate_datasets([infer_input, infer_output], axis=1)

    if args.do_filter_by_confidence:
        # Keep only predictions the model is sufficiently confident about.
        dataset = dataset.filter(lambda example: example["probability"] > args.confidence_threshold)

    if args.do_filter_by_val_performance:
        assert eval_result >= 0.0 and eval_result <= 1.0
        # Keep the top eval_result fraction of rows, most confident first.
        num_selected_rows = int(eval_result * len(dataset))
        # BUGFIX: replaced a leftover debug print() with a proper log message.
        logger.info("Selecting %d pseudo-labeled rows based on validation performance.", num_selected_rows)
        dataset = dataset.sort("probability", reverse=True)
        dataset = dataset.select(range(num_selected_rows))

    # Replace the (possibly stale) "label" column with the model's predictions,
    # mapped back to string label names.
    dataset = dataset.remove_columns(["label", "probability"])
    dataset = dataset.rename_column("prediction", "label")
    dataset = dataset.map(lambda example: {"label": id2label[example["label"]]})
    dataset = dataset.shuffle(seed=args.seed)

    # Persist in the same format (csv/json) as the original data files.
    pseudo_labeled_data_file = os.path.join(next_data_dir, f"train_pseudo.{args.data_file_extension}")
    if args.data_file_extension == "csv":
        dataset.to_csv(pseudo_labeled_data_file, index=False)
    else:
        dataset.to_json(pseudo_labeled_data_file)
|
|
|
|
|
|
|
|
def selftrain(model_name_or_path, train_file, infer_file, output_dir, **kwargs):
    """Self-training a pre-trained model on a downstream task.

    Each iteration fine-tunes the model (stage 1), optionally fine-tunes again
    on the labeled data (stage 2), predicts on the unlabeled ``infer_file`` and
    turns those predictions into pseudo-labels that become the next iteration's
    training data. Stops after ``max_selftrain_iterations`` iterations or when
    early stopping triggers.

    Args:
      model_name_or_path: Path to pretrained model or model identifier from
        huggingface.co/models.
      train_file: A csv or a json file containing the training data.
      infer_file: A csv or a json file containing the data to predict on.
      output_dir: The output directory where the model predictions and checkpoints
        will be written.
      **kwargs: Dictionary of key/value pairs with which to update the
        configuration object after loading. The values in kwargs of any keys which
        are configuration attributes will be used to override the loaded values.
    """
    # Initialize the accelerator; it handles device placement and process sync.
    accelerator = Accelerator()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state)

    # Only the local main process logs at INFO; replicas only report errors.
    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)

    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Flatten the three argument dataclasses into one namespace, then let any
    # matching kwargs override the defaults.
    model_args = STModelArguments(model_name_or_path=model_name_or_path)
    data_args = STDataArguments(train_file=train_file, infer_file=infer_file)
    training_args = STTrainingArguments(output_dir=output_dir)
    args = argparse.Namespace()

    for arg_class in (model_args, data_args, training_args):
        for key, value in vars(arg_class).items():
            setattr(args, key, value)

    for key, value in kwargs.items():
        if hasattr(args, key):
            setattr(args, key, value)

    # Sanity checks on the provided data files.
    data_files = {}
    args.data_file_extension = None

    # Training data and data to predict on are mandatory.
    assert args.train_file is not None
    assert args.infer_file is not None
    data_files["train"] = args.train_file
    data_files["infer"] = args.infer_file

    # Validation data is only required when evaluation is enabled.
    if args.evaluation_strategy != IntervalStrategy.NO.value:
        assert args.eval_file is not None
        data_files["eval"] = args.eval_file

    # All provided files must share one extension, either csv or json.
    for key in data_files:
        extension = data_files[key].split(".")[-1]
        assert extension in ["csv", "json"], f"`{key}_file` should be a csv or a json file."
        if args.data_file_extension is None:
            args.data_file_extension = extension
        else:
            assert extension == args.data_file_extension, f"`{key}_file` should be a {args.data_file_extension} file`."

    assert (
        args.eval_metric in datasets.list_metrics()
    ), f"{args.eval_metric} not in the list of supported metrics {datasets.list_metrics()}."

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    logger.info("Creating the initial data directory for self-training...")
    # data_dir_format(i) -> "<output_dir>/self-train_iter-<i>"
    data_dir_format = f"{args.output_dir}/self-train_iter-{{}}".format
    initial_data_dir = data_dir_format(0)

    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
            os.makedirs(initial_data_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    best_iteration = None
    best_eval_result = None
    early_stopping_patience_counter = 0
    should_training_stop = False
    # Show the progress bar only on the local main process.
    progress_bar = tqdm(range(args.max_selftrain_iterations), disable=not accelerator.is_local_main_process)

    # Self-training loop.
    for iteration in range(0, int(args.max_selftrain_iterations)):
        current_data_dir = data_dir_format(iteration)
        assert os.path.exists(current_data_dir)

        # Stage 1: fine-tune on the labeled data at iteration 0, then on the
        # pseudo-labeled data produced by the previous iteration.
        current_output_dir = os.path.join(current_data_dir, "stage-1")
        arguments_dict = {
            "accelerator": accelerator,
            "model_name_or_path": args.model_name_or_path,
            "cache_dir": args.cache_dir,
            "do_train": True,
            "train_file": data_files["train"] if iteration == 0 else data_files["train_pseudo"],
            "do_eval": True if args.eval_file is not None else False,
            # BUGFIX: use .get() — the "eval" key only exists when an
            # evaluation strategy is enabled, so indexing raised KeyError for
            # the default evaluation_strategy == "no".
            "eval_file": data_files.get("eval"),
            "do_predict": True,
            "infer_file": data_files["infer"],
            "task_name": args.task_name,
            "label_list": args.label_list,
            "output_dir": current_output_dir,
            "eval_metric": args.eval_metric,
            "evaluation_strategy": args.evaluation_strategy,
            "early_stopping_patience": args.early_stopping_patience,
            "early_stopping_threshold": args.early_stopping_threshold,
            "seed": args.seed,
        }
        # Forward extra kwargs that are not self-training arguments to finetune().
        for key, value in kwargs.items():
            if key not in arguments_dict and not hasattr(training_args, key):
                arguments_dict.update({key: value})

        # Resume support: skip a stage whose best checkpoint already exists.
        model_bin_file_path = os.path.join(current_output_dir, "best-checkpoint", MODEL_BIN_FILE)
        if os.path.exists(model_bin_file_path):
            logger.info(
                "Found existing model checkpoint at %s. Skipping self-training: iteration: %d, stage: 1.",
                model_bin_file_path,
                iteration,
            )
        else:
            logger.info("***** Running self-training: iteration: %d, stage: 1 *****", iteration)
            finetune(**arguments_dict)
            accelerator.wait_for_everyone()
            assert os.path.exists(model_bin_file_path)
            logger.info("Self-training job completed: iteration: %d, stage: 1.", iteration)

        if iteration > 0 and args.finetune_on_labeled_data:
            # Stage 2 (optional): fine-tune the stage-1 model on the original
            # labeled data.
            model_path = os.path.join(current_output_dir, "best-checkpoint")
            current_output_dir = os.path.join(current_data_dir, "stage-2")
            # Update arguments_dict in place for the stage-2 run.
            arguments_dict["model_name_or_path"] = model_path
            arguments_dict["train_file"] = data_files["train"]
            arguments_dict["output_dir"] = current_output_dir

            model_bin_file_path = os.path.join(current_output_dir, "best-checkpoint", MODEL_BIN_FILE)
            if os.path.exists(model_bin_file_path):
                logger.info(
                    "Found existing model checkpoint at %s. Skipping self-training: iteration: %d, stage: 2.",
                    model_bin_file_path,
                    iteration,
                )
            else:
                logger.info("***** Running self-training: iteration: %d, stage: 2 *****", iteration)
                finetune(**arguments_dict)
                accelerator.wait_for_everyone()
                assert os.path.exists(model_bin_file_path)
                logger.info("Self-training job completed: iteration: %d, stage: 2.", iteration)

        new_iteration = iteration
        next_data_dir = data_dir_format(iteration + 1)

        # Read back this iteration's best checkpoint, its eval metric and its
        # predictions on the unlabeled data.
        config = AutoConfig.from_pretrained(os.path.join(current_output_dir, "best-checkpoint"))
        id2label = config.id2label
        eval_results_file = os.path.join(current_output_dir, "eval_results_best-checkpoint.json")
        test_results_file = os.path.join(current_output_dir, "test_results_best-checkpoint.json")
        assert os.path.exists(eval_results_file)

        with open(eval_results_file, "r") as f:
            eval_result = float(json.load(f)[args.eval_metric])
        infer_output_file = os.path.join(current_output_dir, "infer_output_best-checkpoint.csv")
        assert os.path.exists(infer_output_file)

        # Loading the dataset from local csv or json files.
        infer_input = load_dataset(args.data_file_extension, data_files={"data": data_files["infer"]})["data"]
        infer_output = load_dataset("csv", data_files={"data": infer_output_file})["data"]

        if accelerator.is_main_process:
            os.makedirs(next_data_dir, exist_ok=True)
            shutil.copy(eval_results_file, os.path.join(output_dir, f"eval_results_iter-{iteration}.json"))
            if os.path.exists(test_results_file):
                # BUGFIX: copy the test results file (the original copied the
                # eval results file into test_results_iter-*.json).
                shutil.copy(test_results_file, os.path.join(output_dir, f"test_results_iter-{iteration}.json"))
            create_pseudo_labeled_data(args, infer_input, infer_output, eval_result, id2label, next_data_dir)
        accelerator.wait_for_everyone()

        data_files["train_pseudo"] = os.path.join(next_data_dir, f"train_pseudo.{args.data_file_extension}")

        if args.evaluation_strategy != IntervalStrategy.NO.value:
            new_eval_result = eval_result

            if best_iteration is None:
                best_iteration = new_iteration
                best_eval_result = new_eval_result
            else:
                if new_eval_result - best_eval_result > args.early_stopping_threshold:
                    best_iteration = new_iteration
                    best_eval_result = new_eval_result
                    early_stopping_patience_counter = 0
                else:
                    # A tie still moves the pointer forward (prefer the later
                    # model) but counts against the patience budget.
                    if new_eval_result == best_eval_result:
                        best_iteration = new_iteration
                        best_eval_result = new_eval_result
                    early_stopping_patience_counter += 1

            if early_stopping_patience_counter >= args.early_stopping_patience:
                should_training_stop = True

        progress_bar.update(1)

        if should_training_stop:
            break

    if best_iteration is not None:
        # Save the best iteration's results under a stable name.
        logger.info("Best iteration: %d", best_iteration)
        logger.info("Best evaluation result: %s = %f", args.eval_metric, best_eval_result)
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            # BUGFIX: copy the file of the best iteration — with early stopping
            # the last iteration (`iteration`) is generally not the best one.
            shutil.copy(
                os.path.join(output_dir, f"eval_results_iter-{best_iteration}.json"),
                os.path.join(output_dir, "eval_results_best-iteration.json"),
            )
    else:
        # Evaluation was disabled; treat the last iteration as the best one.
        logger.info("Best iteration: %d", args.max_selftrain_iterations - 1)
        logger.info("Best evaluation result: %s = %f", args.eval_metric, eval_result)
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            shutil.copy(
                os.path.join(output_dir, f"eval_results_iter-{args.max_selftrain_iterations - 1}.json"),
                os.path.join(output_dir, "eval_results_best-iteration.json"),
            )
|
|
|