| | """Finetuning 🤗 Transformers model for object detection with Accelerate.""" |
| |
|
import argparse
import json
import logging
import math
import os
from collections.abc import Mapping
from functools import partial
from pathlib import Path
from typing import Any

import albumentations as A
import datasets
import numpy as np
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from huggingface_hub import HfApi
from torch.utils.data import DataLoader
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from tqdm.auto import tqdm

import transformers
from transformers import (
    AutoConfig,
    AutoImageProcessor,
    AutoModelForObjectDetection,
    SchedulerType,
    get_scheduler,
)
from transformers.image_processing_utils import BatchFeature
from transformers.image_transforms import center_to_corners_format
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed.
check_min_version("4.57.0.dev0")

logging.basicConfig(level=logging.INFO)
logger = get_logger(__name__)

require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")

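# A minimal launch example (assuming this file is saved as run_object_detection_no_trainer.py;
# the model and dataset below are the defaults defined in parse_args):
#
#   accelerate launch run_object_detection_no_trainer.py \
#       --model_name_or_path facebook/detr-resnet-50 \
#       --dataset_name cppe-5 \
#       --output_dir detr-finetuned-cppe-5 \
#       --checkpointing_steps epoch
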
def format_image_annotations_as_coco(
    image_id: str, categories: list[int], areas: list[float], bboxes: list[tuple[float]]
) -> dict:
    """Format one set of image annotations to the COCO format

    Args:
        image_id (str): image id. e.g. "0001"
        categories (list[int]): list of categories/class labels corresponding to provided bounding boxes
        areas (list[float]): list of corresponding areas to provided bounding boxes
        bboxes (list[tuple[float]]): list of bounding boxes provided in COCO format
            ([x_min, y_min, width, height] in absolute coordinates)

    Returns:
        dict: {
            "image_id": image id,
            "annotations": list of formatted annotations
        }
    """
    annotations = []
    for category, area, bbox in zip(categories, areas, bboxes):
        formatted_annotation = {
            "image_id": image_id,
            "category_id": category,
            "iscrowd": 0,
            "area": area,
            "bbox": list(bbox),
        }
        annotations.append(formatted_annotation)

    return {
        "image_id": image_id,
        "annotations": annotations,
    }

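# Example with hypothetical values:
#   format_image_annotations_as_coco("0001", categories=[2], areas=[1024.0], bboxes=[(10.0, 20.0, 32.0, 32.0)])
# returns:
#   {"image_id": "0001", "annotations": [{"image_id": "0001", "category_id": 2,
#    "iscrowd": 0, "area": 1024.0, "bbox": [10.0, 20.0, 32.0, 32.0]}]}
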
def convert_bbox_yolo_to_pascal(boxes: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
    """
    Convert bounding boxes from YOLO format (x_center, y_center, width, height) in range [0, 1]
    to Pascal VOC format (x_min, y_min, x_max, y_max) in absolute coordinates.

    Args:
        boxes (torch.Tensor): Bounding boxes in YOLO format
        image_size (tuple[int, int]): Image size in format (height, width)

    Returns:
        torch.Tensor: Bounding boxes in Pascal VOC format (x_min, y_min, x_max, y_max)
    """
    # convert from center format to corners format
    boxes = center_to_corners_format(boxes)

    # convert to absolute coordinates
    height, width = image_size
    boxes = boxes * torch.tensor([[width, height, width, height]])

    return boxes

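# Worked example: the YOLO box (0.5, 0.5, 0.5, 0.5) on a 100x200 (height x width) image
# becomes (0.25, 0.25, 0.75, 0.75) in corners format and, after scaling by
# [width, height, width, height], (50.0, 25.0, 150.0, 75.0) in Pascal VOC format.
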
def augment_and_transform_batch(
    examples: Mapping[str, Any],
    transform: A.Compose,
    image_processor: AutoImageProcessor,
    return_pixel_mask: bool = False,
) -> BatchFeature:
    """Apply augmentations and format annotations in COCO format for object detection task"""

    images = []
    annotations = []
    for image_id, image, objects in zip(examples["image_id"], examples["image"], examples["objects"]):
        image = np.array(image.convert("RGB"))

        # apply augmentations
        output = transform(image=image, bboxes=objects["bbox"], category=objects["category"])
        images.append(output["image"])

        # format annotations in COCO format
        formatted_annotations = format_image_annotations_as_coco(
            image_id, output["category"], objects["area"], output["bboxes"]
        )
        annotations.append(formatted_annotations)

    # apply the image processor transformations: resizing, rescaling, normalization
    result = image_processor(images=images, annotations=annotations, return_tensors="pt")

    if not return_pixel_mask:
        result.pop("pixel_mask", None)

    return result

def collate_fn(batch: list[BatchFeature]) -> Mapping[str, torch.Tensor | list[Any]]:
    data = {}
    data["pixel_values"] = torch.stack([x["pixel_values"] for x in batch])
    data["labels"] = [x["labels"] for x in batch]
    if "pixel_mask" in batch[0]:
        data["pixel_mask"] = torch.stack([x["pixel_mask"] for x in batch])
    return data

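# Images are padded to a fixed square size, so "pixel_values" can be stacked into a single
# tensor; "labels" stays a plain Python list because each image has a different number of boxes.
# E.g. for a batch of 8 images at image_square_size 1333 (hypothetical values):
#   {"pixel_values": FloatTensor[8, 3, 1333, 1333], "labels": [dict, ...] of length 8}
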
def nested_to_cpu(objects):
    """Move nested tensors in objects to CPU if they are on GPU"""
    if isinstance(objects, torch.Tensor):
        return objects.cpu()
    elif isinstance(objects, Mapping):
        return type(objects)({k: nested_to_cpu(v) for k, v in objects.items()})
    elif isinstance(objects, (list, tuple)):
        return type(objects)([nested_to_cpu(v) for v in objects])
    elif isinstance(objects, (np.ndarray, str, int, float, bool)):
        return objects
    raise ValueError(f"Unsupported type {type(objects)}")

def evaluation_loop(
    model: torch.nn.Module,
    image_processor: AutoImageProcessor,
    accelerator: Accelerator,
    dataloader: DataLoader,
    id2label: Mapping[int, str],
) -> dict:
    model.eval()
    metric = MeanAveragePrecision(box_format="xyxy", class_metrics=True)

    for step, batch in enumerate(tqdm(dataloader, disable=not accelerator.is_local_main_process)):
        with torch.no_grad():
            outputs = model(**batch)

        # For metric computation we need to provide:
        #  - targets as a list of dictionaries with keys "boxes", "labels"
        #  - predictions as a list of dictionaries with keys "boxes", "scores", "labels"

        # 1. Collect predicted boxes, classes, scores
        # post_process_object_detection converts boxes to Pascal VOC format
        # ([x_min, y_min, x_max, y_max] in absolute coordinates)
        image_size = torch.stack([example["orig_size"] for example in batch["labels"]], dim=0)
        predictions = image_processor.post_process_object_detection(outputs, threshold=0.0, target_sizes=image_size)
        predictions = nested_to_cpu(predictions)

        # 2. Collect ground truth boxes in the same format,
        # converting YOLO boxes to Pascal VOC format
        target = []
        for label in batch["labels"]:
            label = nested_to_cpu(label)
            boxes = convert_bbox_yolo_to_pascal(label["boxes"], label["orig_size"])
            labels = label["class_labels"]
            target.append({"boxes": boxes, "labels": labels})

        metric.update(predictions, target)

    metrics = metric.compute()

    # Replace the list of per-class metrics with a separate metric for each class
    classes = metrics.pop("classes")
    map_per_class = metrics.pop("map_per_class")
    mar_100_per_class = metrics.pop("mar_100_per_class")
    for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
        class_name = id2label[class_id.item()]
        metrics[f"map_{class_name}"] = class_map
        metrics[f"mar_100_{class_name}"] = class_mar

    # Convert metric tensors to rounded Python floats
    metrics = {k: round(v.item(), 4) for k, v in metrics.items()}

    return metrics

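# The returned dict holds overall COCO-style metrics plus the per-class entries added above,
# e.g. (hypothetical values): {"map": 0.4012, "map_50": 0.6534, "mar_100": 0.5487,
# "map_<class_name>": 0.3871, "mar_100_<class_name>": 0.5210, ...}
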
def parse_args():
    parser = argparse.ArgumentParser(description="Finetune a transformers model for object detection task")
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to a pretrained model or model identifier from huggingface.co/models.",
        default="facebook/detr-resnet-50",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        help="Name of the dataset on the hub.",
        default="cppe-5",
    )
    parser.add_argument(
        "--train_val_split",
        type=float,
        default=0.15,
        help="Fraction of the dataset to be used for validation.",
    )
    parser.add_argument(
        "--ignore_mismatched_sizes",
        action="store_true",
        help="Ignore mismatched sizes between the model and the dataset.",
    )
    parser.add_argument(
        "--image_square_size",
        type=int,
        default=1333,
        help="The longest image side is resized to this value, then the image is padded to a square.",
    )
    parser.add_argument(
        "--use_fast",
        type=bool,
        default=True,
        help="Use a fast torchvision-based image processor if it is supported for the given model.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        help="Path to a folder in which the model and dataset will be cached.",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=4,
        help="Number of workers to use for the dataloaders.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="Beta1 for AdamW optimizer",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="Beta2 for AdamW optimizer",
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-8,
        help="Epsilon for AdamW optimizer",
    )
    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument(
        "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
    )
    parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help=(
            "Whether to trust the execution of code from datasets/models defined on the Hub."
            " This option should only be set to `True` for repositories you trust and in which you have read the"
            " code, as it will execute code present on the Hub on your local machine."
        ),
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default=None,
        help="Save the training state every n steps (pass an integer), or at the end of each epoch (pass 'epoch').",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="Path of a checkpoint folder from which training should resume.",
    )
    parser.add_argument(
        "--with_tracking",
        required=False,
        action="store_true",
        help="Whether to enable experiment trackers for logging.",
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="all",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations. '
            "Only applicable when `--with_tracking` is passed."
        ),
    )
    args = parser.parse_args()

    # Sanity checks
    if args.push_to_hub or args.with_tracking:
        if args.output_dir is None:
            raise ValueError(
                "Need an `output_dir` to create a repo when `--push_to_hub` or `--with_tracking` is specified."
            )

    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    return args

def main():
    args = parse_args()

    # Initialize the accelerator. We let the accelerator handle device placement for us.
    # If tracking is enabled, it also initializes the configured experiment trackers.
    accelerator_log_kwargs = {}

    if args.with_tracking:
        accelerator_log_kwargs["log_with"] = args.report_to
        accelerator_log_kwargs["project_dir"] = args.output_dir

    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)

    # Log the accelerator state on every process for debugging, then reduce library
    # verbosity on non-main processes to avoid duplicated logs.
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now. `device_specific=True` gives each process
    # a different seed, so data augmentation differs across devices.
    if args.seed is not None:
        set_seed(args.seed, device_specific=True)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            # Retrieve or infer repo_name
            repo_name = args.hub_model_id
            if repo_name is None:
                repo_name = Path(args.output_dir).absolute().name
            # Create repo and retrieve repo_id
            api = HfApi()
            repo_id = api.create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    # Load dataset. In distributed training, the load_dataset function guarantees that
    # only one local process downloads the dataset.
    dataset = load_dataset(args.dataset_name, cache_dir=args.cache_dir, trust_remote_code=args.trust_remote_code)

    # If we don't have a validation split, split off a fraction of train as validation
    args.train_val_split = None if "validation" in dataset else args.train_val_split
    if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
        split = dataset["train"].train_test_split(args.train_val_split, seed=args.seed)
        dataset["train"] = split["train"]
        dataset["validation"] = split["test"]

    # Get dataset categories and prepare mappings for label_name <-> label_id
    if isinstance(dataset["train"].features["objects"], dict):
        categories = dataset["train"].features["objects"]["category"].feature.names
    else:
        categories = dataset["train"].features["objects"].feature["category"].names
    id2label = dict(enumerate(categories))
    label2id = {v: k for k, v in id2label.items()}

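    # For the default cppe-5 dataset this yields five PPE classes (subject to the dataset's
    # own label names), e.g. id2label = {0: "Coverall", 1: "Face_Shield", 2: "Gloves",
    # 3: "Goggles", 4: "Mask"} and label2id as the inverse mapping.
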
    # ------------------------------------------------------------------------------------------------
    # Load pretrained config, model and image processor
    # ------------------------------------------------------------------------------------------------

    common_pretrained_args = {
        "cache_dir": args.cache_dir,
        "token": args.hub_token,
        "trust_remote_code": args.trust_remote_code,
    }
    config = AutoConfig.from_pretrained(
        args.model_name_or_path, label2id=label2id, id2label=id2label, **common_pretrained_args
    )
    model = AutoModelForObjectDetection.from_pretrained(
        args.model_name_or_path,
        config=config,
        ignore_mismatched_sizes=args.ignore_mismatched_sizes,
        **common_pretrained_args,
    )
    image_processor = AutoImageProcessor.from_pretrained(
        args.model_name_or_path,
        do_resize=True,
        size={"max_height": args.image_square_size, "max_width": args.image_square_size},
        do_pad=True,
        pad_size={"height": args.image_square_size, "width": args.image_square_size},
        use_fast=args.use_fast,
        **common_pretrained_args,
    )

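    # With the default image_square_size of 1333, e.g. a 640x480 image is resized (keeping
    # aspect ratio) to fit within 1333x1333 and then padded to exactly 1333x1333, so every
    # batch element shares the same spatial shape.
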
    # ------------------------------------------------------------------------------------------------
    # Define image augmentations and dataset transforms
    # ------------------------------------------------------------------------------------------------

    max_size = args.image_square_size
    train_augment_and_transform = A.Compose(
        [
            A.Compose(
                [
                    A.SmallestMaxSize(max_size=max_size, p=1.0),
                    A.RandomSizedBBoxSafeCrop(height=max_size, width=max_size, p=1.0),
                ],
                p=0.2,
            ),
            A.OneOf(
                [
                    A.Blur(blur_limit=7, p=0.5),
                    A.MotionBlur(blur_limit=7, p=0.5),
                    A.Defocus(radius=(1, 5), alias_blur=(0.1, 0.25), p=0.1),
                ],
                p=0.1,
            ),
            A.Perspective(p=0.1),
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.HueSaturationValue(p=0.1),
        ],
        bbox_params=A.BboxParams(format="coco", label_fields=["category"], clip=True, min_area=25),
    )
    validation_transform = A.Compose(
        [A.NoOp()],
        bbox_params=A.BboxParams(format="coco", label_fields=["category"], clip=True),
    )

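    # `format="coco"` tells albumentations the boxes are [x_min, y_min, width, height];
    # `clip=True` clips augmented boxes to the image boundaries, and `min_area=25` drops
    # boxes whose visible area falls below 25 pixels after augmentation.
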
    # Make transform functions for batch and apply them to dataset splits.
    # `with_transform` applies the transform lazily, on access, so augmentations are
    # re-sampled every epoch.
    train_transform_batch = partial(
        augment_and_transform_batch, transform=train_augment_and_transform, image_processor=image_processor
    )
    validation_transform_batch = partial(
        augment_and_transform_batch, transform=validation_transform, image_processor=image_processor
    )

    with accelerator.main_process_first():
        train_dataset = dataset["train"].with_transform(train_transform_batch)
        valid_dataset = dataset["validation"].with_transform(validation_transform_batch)
        test_dataset = dataset["test"].with_transform(validation_transform_batch)

    dataloader_common_args = {
        "num_workers": args.dataloader_num_workers,
        "collate_fn": collate_fn,
    }
    train_dataloader = DataLoader(
        train_dataset, shuffle=True, batch_size=args.per_device_train_batch_size, **dataloader_common_args
    )
    valid_dataloader = DataLoader(
        valid_dataset, shuffle=False, batch_size=args.per_device_eval_batch_size, **dataloader_common_args
    )
    test_dataloader = DataLoader(
        test_dataset, shuffle=False, batch_size=args.per_device_eval_batch_size, **dataloader_common_args
    )

    # ------------------------------------------------------------------------------------------------
    # Model training and evaluation with Accelerate
    # ------------------------------------------------------------------------------------------------

    # Optimizer
    optimizer = torch.optim.AdamW(
        list(model.parameters()),
        lr=args.learning_rate,
        betas=[args.adam_beta1, args.adam_beta2],
        eps=args.adam_epsilon,
    )

    # Figure out how many steps we should save the Accelerator states
    checkpointing_steps = args.checkpointing_steps
    if checkpointing_steps is not None and checkpointing_steps.isdigit():
        checkpointing_steps = int(checkpointing_steps)

    # Scheduler and math around the number of training steps
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps
        if overrode_max_train_steps
        else args.max_train_steps * accelerator.num_processes,
    )

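    # Note on the num_processes scaling above: a scheduler prepared by Accelerate is stepped
    # on every process at each optimizer step, so warmup steps are multiplied by num_processes.
    # When max_train_steps was derived from the (not yet sharded) dataloader length, it already
    # covers all processes; when passed explicitly, it is scaled here instead.
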
    # Prepare everything with our accelerator
    model, optimizer, train_dataloader, valid_dataloader, test_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, valid_dataloader, test_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader
    # may have changed after sharding.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use and also store our configuration.
    # The trackers are initialized automatically on the main process.
    if args.with_tracking:
        experiment_config = vars(args)
        # TensorBoard cannot log Enums, need the raw value
        experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
        accelerator.init_trackers("object_detection_no_trainer", experiment_config)

    # ------------------------------------------------------------------------------------------------
    # Run training with evaluation on each epoch
    # ------------------------------------------------------------------------------------------------

    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

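    # e.g. with per_device_train_batch_size=8, 2 processes, and gradient_accumulation_steps=2
    # (hypothetical values), total_batch_size = 8 * 2 * 2 = 32 examples per optimizer update.
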
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    # Only show the progress bar once on each machine
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    starting_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint is not None and args.resume_from_checkpoint != "":
            checkpoint_path = args.resume_from_checkpoint
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
            dirs.sort(key=os.path.getctime)
            path = dirs[-1]
            checkpoint_path = path
            path = os.path.basename(checkpoint_path)

        accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
        accelerator.load_state(checkpoint_path)
        # Extract `epoch_{i}` or `step_{i}`
        training_difference = os.path.splitext(path)[0]

        if "epoch" in training_difference:
            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
            resume_step = None
            completed_steps = starting_epoch * num_update_steps_per_epoch
        else:
            # need to multiply `gradient_accumulation_steps` to reflect real steps
            resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
            starting_epoch = resume_step // len(train_dataloader)
            completed_steps = resume_step // args.gradient_accumulation_steps
            resume_step -= starting_epoch * len(train_dataloader)

        # update the progress bar if resuming from a checkpoint
        progress_bar.update(completed_steps)

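    # Worked example (hypothetical values): resuming from "step_300" with
    # gradient_accumulation_steps=2 and len(train_dataloader)=500 gives resume_step=600 raw
    # batches, starting_epoch=1, completed_steps=300, and 100 batches to skip within epoch 1.
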
    for epoch in range(starting_epoch, args.num_train_epochs):
        model.train()
        if args.with_tracking:
            total_loss = 0
        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
            # We skip the first `n` batches in the dataloader when resuming from a checkpoint
            active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
        else:
            active_dataloader = train_dataloader

        for step, batch in enumerate(active_dataloader):
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                # We keep track of the loss at each epoch
                if args.with_tracking:
                    total_loss += loss.detach().float()
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                completed_steps += 1

            if isinstance(checkpointing_steps, int):
                if completed_steps % checkpointing_steps == 0 and accelerator.sync_gradients:
                    output_dir = f"step_{completed_steps}"
                    if args.output_dir is not None:
                        output_dir = os.path.join(args.output_dir, output_dir)
                    accelerator.save_state(output_dir)

                    if args.push_to_hub and epoch < args.num_train_epochs - 1:
                        accelerator.wait_for_everyone()
                        unwrapped_model = accelerator.unwrap_model(model)
                        unwrapped_model.save_pretrained(
                            args.output_dir,
                            is_main_process=accelerator.is_main_process,
                            save_function=accelerator.save,
                        )
                        if accelerator.is_main_process:
                            image_processor.save_pretrained(args.output_dir)
                            api.upload_folder(
                                commit_message=f"Training in progress epoch {epoch}",
                                folder_path=args.output_dir,
                                repo_id=repo_id,
                                repo_type="model",
                                token=args.hub_token,
                            )

            if completed_steps >= args.max_train_steps:
                break

        logger.info("***** Running evaluation *****")
        metrics = evaluation_loop(model, image_processor, accelerator, valid_dataloader, id2label)

        logger.info(f"epoch {epoch}: {metrics}")

        if args.with_tracking:
            accelerator.log(
                {
                    "train_loss": total_loss.item() / len(train_dataloader),
                    **metrics,
                    "epoch": epoch,
                    "step": completed_steps,
                },
                step=completed_steps,
            )

        if args.push_to_hub and epoch < args.num_train_epochs - 1:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(
                args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
            )
            if accelerator.is_main_process:
                image_processor.save_pretrained(args.output_dir)
                api.upload_folder(
                    commit_message=f"Training in progress epoch {epoch}",
                    folder_path=args.output_dir,
                    repo_id=repo_id,
                    repo_type="model",
                    token=args.hub_token,
                )

        if args.checkpointing_steps == "epoch":
            output_dir = f"epoch_{epoch}"
            if args.output_dir is not None:
                output_dir = os.path.join(args.output_dir, output_dir)
            accelerator.save_state(output_dir)

    # ------------------------------------------------------------------------------------------------
    # Run evaluation on test dataset and save the model
    # ------------------------------------------------------------------------------------------------

    logger.info("***** Running evaluation on test dataset *****")
    metrics = evaluation_loop(model, image_processor, accelerator, test_dataloader, id2label)
    metrics = {f"test_{k}": v for k, v in metrics.items()}

    logger.info(f"Test metrics: {metrics}")

    if args.output_dir is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
        )
        if accelerator.is_main_process:
            with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
                json.dump(metrics, f, indent=2)

            image_processor.save_pretrained(args.output_dir)

            if args.push_to_hub:
                api.upload_folder(
                    commit_message="End of training",
                    folder_path=args.output_dir,
                    repo_id=repo_id,
                    repo_type="model",
                    token=args.hub_token,
                    ignore_patterns=["epoch_*"],
                )

    accelerator.wait_for_everyone()
    accelerator.end_training()


if __name__ == "__main__":
    main()