# scripts/tools/test_dataset_loading.py
# Smoke test for dataset loading and dataloader iteration (no real training).
# Initial commit by xingzhikb (002bd9b).
import sys
sys.path.append(".")
import logging
import os
import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from src.arguments import global_setup
from transformers.trainer_utils import get_last_checkpoint
from transformers import set_seed
import tqdm
from src.train import (
prepare_datasets,
prepare_data_transform,
SCASeq2SeqTrainer,
prepare_processor,
prepare_collate_fn,
)
from src.arguments import global_setup
import dotenv
logger = logging.getLogger(__name__)
@hydra.main(version_base="1.3", config_path="../../src/conf", config_name="conf")
def main(args: DictConfig) -> None:
# NOTE(xiaoke): follow https://github.com/huggingface/transformers/blob/main/examples/pytorch/image-classification/run_image_classification.py
logger.info(OmegaConf.to_yaml(args))
args, training_args, model_args = global_setup(args)
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
logger.warning(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"There is no checkpoint in the directory. Or we can resume from `resume_from_checkpoint`."
)
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Set seed before initializing model.
set_seed(args.training.seed)
# Initialize our dataset and prepare it
train_dataset, eval_dataset = prepare_datasets(args)
# NOTE(xiaoke): load sas_key from .env for huggingface model downloading.
logger.info(f"Try to load sas_key from .env file: {dotenv.load_dotenv('.env')}.")
use_auth_token = os.getenv("USE_AUTH_TOKEN", False)
processor = prepare_processor(model_args, use_auth_token)
train_dataset, eval_dataset = prepare_data_transform(
training_args, model_args, train_dataset, eval_dataset, processor
)
collate_fn = prepare_collate_fn(training_args, model_args, processor)
# Load the accuracy metric from the datasets package
# metric = evaluate.load("accuracy")
# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.
# def compute_metrics(p):
# """Computes accuracy on a batch of predictions"""
# return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)
compute_metrics = training_args.compute_metrics
if compute_metrics is not True:
# NOTE: compute_metrics = None triggers the default `prediction_loss_only=True`
# NOTE: compute_metrics should be a function, but we define the function in the trainer, so we use bool here to indicate the usage.
compute_metrics = None
# config = AutoConfig.from_pretrained(
# model_args.config_name or model_args.model_name_or_path,
# num_labels=len(labels),
# label2id=label2id,
# id2label=id2label,
# finetuning_task="image-classification",
# cache_dir=model_args.cache_dir,
# revision=model_args.model_revision,
# use_auth_token=True if model_args.use_auth_token else None,
# )
# model = AutoModelForImageClassification.from_pretrained(
# model_args.model_name_or_path,
# from_tf=bool(".ckpt" in model_args.model_name_or_path),
# config=config,
# cache_dir=model_args.cache_dir,
# revision=model_args.model_revision,
# use_auth_token=True if model_args.use_auth_token else None,
# ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
# )
# image_processor = AutoImageProcessor.from_pretrained(
# model_args.image_processor_name or model_args.model_name_or_path,
# cache_dir=model_args.cache_dir,
# revision=model_args.model_revision,
# use_auth_token=True if model_args.use_auth_token else None,
# )
# model = prepare_model(model_args, use_auth_token)
# prepare_model_trainable_parameters(model, args)
# # Initalize our trainer
# custom_callbacks = [LoggerCallback(), EvalLossCallback()]
# if args.wandb.log is True:
# custom_callbacks.append(CustomWandbCallBack(args))
# if training_args.evaluate_before_train:
# custom_callbacks.append(EvaluateFirstStepCallback())
model = torch.nn.Linear(10, 2)
trainer = SCASeq2SeqTrainer(
model=model,
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval or training_args.do_train else None,
compute_metrics=compute_metrics,
data_collator=collate_fn,
tokenizer=processor.tokenizer,
callbacks=None,
)
# Training
if training_args.do_train:
train_dataloader = trainer.get_train_dataloader()
run_and_print_dataloader(train_dataloader)
# Evaluation or Inference
if training_args.do_eval or training_args.do_inference:
for eval_dataset_k, eval_dataset_v in eval_dataset.items():
eval_dataloader = trainer.get_eval_dataloader(eval_dataset_v)
run_and_print_dataloader(eval_dataloader)
def run_and_print_dataloader(dataloader):
pbar = tqdm.tqdm(dataloader)
for batch in pbar:
batch_str = ""
for k, v in batch.items():
if v is None:
batch_str += f"{k}: None\n"
elif isinstance(v, torch.Tensor):
batch_str += f"{k}: {v.shape}\n"
elif isinstance(v, list):
batch_str += f"{k}: {len(v)}\n"
else:
batch_str += f"{k}: {v}\n"
pbar.write(batch_str)
if __name__ == "__main__":
main()