#!/usr/bin/env python # Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # /// script # dependencies = [ # "transformers @ git+https://github.com/huggingface/transformers.git", # "albumentations >= 1.4.16", # "timm", # "datasets", # "torchmetrics", # "pycocotools", # ] # /// """Finetuning 🤗 Transformers model for instance segmentation with Accelerate 🚀.""" import argparse import json import logging import math import os import sys from collections.abc import Mapping from functools import partial from pathlib import Path from typing import Any import albumentations as A import datasets import numpy as np import torch from accelerate import Accelerator from accelerate.utils import set_seed from datasets import load_dataset from huggingface_hub import HfApi from torch.utils.data import DataLoader from torchmetrics.detection.mean_ap import MeanAveragePrecision from tqdm import tqdm import transformers from transformers import ( AutoImageProcessor, AutoModelForUniversalSegmentation, SchedulerType, get_scheduler, ) from transformers.image_processing_utils import BatchFeature from transformers.utils import check_min_version from transformers.utils.versions import require_version logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 
check_min_version("4.57.0.dev0")

require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")


def parse_args():
    """Parse and validate the command-line arguments for instance-segmentation fine-tuning.

    Returns:
        argparse.Namespace: the parsed arguments. If `--output_dir` is given, the directory is created.

    Raises:
        ValueError: if `--push_to_hub` or `--with_tracking` is passed without `--output_dir`, or if
            `--checkpointing_steps` is neither "epoch" nor a positive integer.
    """
    parser = argparse.ArgumentParser(description="Finetune a transformers model for instance segmentation task")
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to a pretrained model or model identifier from huggingface.co/models.",
        default="facebook/mask2former-swin-tiny-coco-instance",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        help="Name of the dataset on the hub.",
        default="qubvel-hf/ade20k-mini",
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help=(
            "Whether to trust the execution of code from datasets/models defined on the Hub."
            " This option should only be set to `True` for repositories you trust and in which you have read the"
            " code, as it will execute code present on the Hub on your local machine."
        ),
    )
    parser.add_argument(
        "--image_height",
        type=int,
        default=384,
        help="The height of the images to feed the model.",
    )
    parser.add_argument(
        "--image_width",
        type=int,
        default=384,
        help="The width of the images to feed the model.",
    )
    parser.add_argument(
        "--do_reduce_labels",
        action="store_true",
        help="Whether to reduce the number of labels by removing the background class.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        help="Path to a folder in which the model and dataset will be cached.",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=4,
        help="Number of workers to use for the dataloaders.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="Beta1 for AdamW optimizer",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="Beta2 for AdamW optimizer",
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-8,
        help="Epsilon for AdamW optimizer",
    )
    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument(
        "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
    )
    parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default=None,
        help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="If the training should continue from a checkpoint folder.",
    )
    parser.add_argument(
        "--with_tracking",
        required=False,
        action="store_true",
        help="Whether to enable experiment trackers for logging.",
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="all",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations. '
            "Only applicable when `--with_tracking` is passed."
        ),
    )
    args = parser.parse_args()

    # Sanity checks
    if args.push_to_hub or args.with_tracking:
        if args.output_dir is None:
            raise ValueError(
                "Need an `output_dir` to create a repo when `--push_to_hub` or `with_tracking` is specified."
            )

    # `main` only understands "epoch" or an integer step interval. Any other value used to be silently ignored,
    # and "0" crashed later with a modulo-by-zero in the training loop.
    if args.checkpointing_steps is not None and args.checkpointing_steps != "epoch":
        if not (args.checkpointing_steps.isdigit() and int(args.checkpointing_steps) > 0):
            raise ValueError(
                f"`--checkpointing_steps` must be 'epoch' or a positive integer, got {args.checkpointing_steps!r}."
            )

    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    return args


def augment_and_transform_batch(
    examples: Mapping[str, Any], transform: A.Compose, image_processor: AutoImageProcessor
) -> BatchFeature:
    """Augment a batch of raw examples and convert it to model inputs.

    Each annotation's first two channels are taken as a (semantic id, instance id) pair per pixel. Both the
    image and that two-channel mask go through `transform`. The image processor then receives the augmented
    instance-id map plus an instance-id -> semantic-id mapping built from the pairs still present after
    augmentation.

    Returns:
        dict with per-example lists under "pixel_values", "mask_labels" and "class_labels".
    """
    batch = {
        "pixel_values": [],
        "mask_labels": [],
        "class_labels": [],
    }

    for pil_image, pil_annotation in zip(examples["image"], examples["annotation"]):
        image = np.array(pil_image)
        semantic_and_instance_masks = np.array(pil_annotation)[..., :2]

        # Apply augmentations
        output = transform(image=image, mask=semantic_and_instance_masks)

        aug_image = output["image"]
        aug_semantic_and_instance_masks = output["mask"]
        aug_instance_mask = aug_semantic_and_instance_masks[..., 1]

        # Create mapping from instance id to semantic id
        unique_semantic_id_instance_id_pairs = np.unique(aug_semantic_and_instance_masks.reshape(-1, 2), axis=0)
        instance_id_to_semantic_id = {
            instance_id: semantic_id for semantic_id, instance_id in unique_semantic_id_instance_id_pairs
        }

        # Apply the image processor transformations: resizing, rescaling, normalization
        model_inputs = image_processor(
            images=[aug_image],
            segmentation_maps=[aug_instance_mask],
            instance_id_to_semantic_id=instance_id_to_semantic_id,
            return_tensors="pt",
        )

        batch["pixel_values"].append(model_inputs.pixel_values[0])
        batch["mask_labels"].append(model_inputs.mask_labels[0])
        batch["class_labels"].append(model_inputs.class_labels[0])

    return batch


def collate_fn(examples):
    """Stack `pixel_values` (and `pixel_mask` if present) into tensors.

    `mask_labels` and `class_labels` are kept as lists, since their first dimension (number of instances)
    can differ between examples.
    """
    batch = {}
    batch["pixel_values"] = torch.stack([example["pixel_values"] for example in examples])
    batch["class_labels"] = [example["class_labels"] for example in examples]
    batch["mask_labels"] = [example["mask_labels"] for example in examples]
    if "pixel_mask" in examples[0]:
        batch["pixel_mask"] = torch.stack([example["pixel_mask"] for example in examples])
    return batch


def nested_cpu(tensors):
    """Recursively move every tensor in nested lists/tuples/mappings to CPU (detached); leave other values as-is."""
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_cpu(t) for t in tensors)
    elif isinstance(tensors, Mapping):
        return type(tensors)({k: nested_cpu(t) for k, t in tensors.items()})
    elif isinstance(tensors, torch.Tensor):
        return tensors.cpu().detach()
    else:
        return tensors


def evaluation_loop(model, image_processor, accelerator: Accelerator, dataloader, id2label):
    """Compute segmentation mAP/mAR (torchmetrics `MeanAveragePrecision`, iou_type="segm") over `dataloader`.

    The model is switched to eval mode. Callers that keep training must call `model.train()` again, as
    `main` does at the start of every epoch.

    Returns:
        dict mapping metric name -> float rounded to 4 decimals. Per-class results appear as
        `map_{class_name}` / `mar_100_{class_name}`, named via `id2label` when it is provided.
    """
    # Without this, dropout / stochastic-depth layers stay active during evaluation.
    model.eval()
    metric = MeanAveragePrecision(iou_type="segm", class_metrics=True)

    for inputs in tqdm(dataloader, total=len(dataloader), disable=not accelerator.is_local_main_process):
        with torch.no_grad():
            outputs = model(**inputs)

        inputs = accelerator.gather_for_metrics(inputs)
        inputs = nested_cpu(inputs)

        outputs = accelerator.gather_for_metrics(outputs)
        outputs = nested_cpu(outputs)

        # For metric computation we need to provide:
        #  - targets in a form of list of dictionaries with keys "masks", "labels"
        #  - predictions in a form of list of dictionaries with keys "masks", "labels", "scores"

        post_processed_targets = []
        post_processed_predictions = []
        target_sizes = []

        # Collect targets
        for masks, labels in zip(inputs["mask_labels"], inputs["class_labels"]):
            post_processed_targets.append(
                {
                    "masks": masks.to(dtype=torch.bool),
                    "labels": labels,
                }
            )
            target_sizes.append(masks.shape[-2:])

        # Collect predictions
        post_processed_output = image_processor.post_process_instance_segmentation(
            outputs,
            threshold=0.0,
            target_sizes=target_sizes,
            return_binary_maps=True,
        )

        for image_predictions, target_size in zip(post_processed_output, target_sizes):
            if image_predictions["segments_info"]:
                post_processed_image_prediction = {
                    "masks": image_predictions["segmentation"].to(dtype=torch.bool),
                    "labels": torch.tensor([x["label_id"] for x in image_predictions["segments_info"]]),
                    "scores": torch.tensor([x["score"] for x in image_predictions["segments_info"]]),
                }
            else:
                # for void predictions, we need to provide empty tensors
                post_processed_image_prediction = {
                    "masks": torch.zeros([0, *target_size], dtype=torch.bool),
                    "labels": torch.tensor([]),
                    "scores": torch.tensor([]),
                }
            post_processed_predictions.append(post_processed_image_prediction)

        # Update metric for batch targets and predictions
        metric.update(post_processed_predictions, post_processed_targets)

    # Compute metrics
    metrics = metric.compute()

    # Replace list of per class metrics with separate metric for each class
    classes = metrics.pop("classes")
    map_per_class = metrics.pop("map_per_class")
    mar_100_per_class = metrics.pop("mar_100_per_class")
    for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
        class_name = id2label[class_id.item()] if id2label is not None else class_id.item()
        metrics[f"map_{class_name}"] = class_map
        metrics[f"mar_100_{class_name}"] = class_mar

    metrics = {k: round(v.item(), 4) for k, v in metrics.items()}

    return metrics


def setup_logging(accelerator: Accelerator) -> None:
    """Configure stdout logging: INFO/warning verbosity on the local main process, errors only on the others."""
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
        logger.setLevel(logging.INFO)
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()


# Sub-directories of `output_dir` written by `accelerator.save_state` in `main` ("step_{n}" / "epoch_{n}").
# They hold full training state and are excluded from Hub uploads.
_CHECKPOINT_DIR_PATTERNS = ("step_*", "epoch_*")


def _ensure_gitignore_patterns(directory, patterns) -> None:
    """Append each of `patterns` to `<directory>/.gitignore` unless that exact line is already present.

    The file is created if missing, and existing content is preserved.
    """
    gitignore_path = os.path.join(directory, ".gitignore")
    content = ""
    if os.path.exists(gitignore_path):
        with open(gitignore_path, encoding="utf-8") as gitignore:
            content = gitignore.read()

    existing = {line.strip() for line in content.splitlines()}
    missing = [pattern for pattern in patterns if pattern not in existing]
    if not missing:
        return

    with open(gitignore_path, "a", encoding="utf-8") as gitignore:
        # Don't glue the first new pattern onto an unterminated last line.
        if content and not content.endswith("\n"):
            gitignore.write("\n")
        gitignore.writelines(f"{pattern}\n" for pattern in missing)


def handle_repository_creation(accelerator: Accelerator, args: argparse.Namespace):
    """Create (or reuse) the Hub model repository when `args.push_to_hub` is set.

    Only the main process contacts the Hub or touches the filesystem. All processes then synchronize.

    Returns:
        The repository id on the main process when pushing to the Hub, otherwise None.
    """
    repo_id = None
    if accelerator.is_main_process:
        if args.push_to_hub:
            # Retrieve or infer repo_name
            repo_name = args.hub_model_id
            if repo_name is None:
                repo_name = Path(args.output_dir).absolute().name
            # Create repo and retrieve repo_id
            api = HfApi()
            repo_id = api.create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id

            _ensure_gitignore_patterns(args.output_dir, _CHECKPOINT_DIR_PATTERNS)
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    return repo_id


def _save_and_push_to_hub(accelerator, model, image_processor, args, api, repo_id, commit_message):
    """Save the unwrapped model and the image processor to `args.output_dir`, then upload that folder to the Hub.

    Must be called by every process because it starts with `accelerator.wait_for_everyone()`. Only the main
    process saves the image processor and uploads. Checkpoint sub-directories are excluded from the upload.
    """
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
    )
    if accelerator.is_main_process:
        image_processor.save_pretrained(args.output_dir)
        api.upload_folder(
            commit_message=commit_message,
            folder_path=args.output_dir,
            repo_id=repo_id,
            repo_type="model",
            token=args.hub_token,
            ignore_patterns=list(_CHECKPOINT_DIR_PATTERNS),
        )


def main():
    """Fine-tune a universal-segmentation model for instance segmentation with Accelerate.

    Parses arguments, sets up the accelerator, logging and Hub repo, then loads the dataset, model and image
    processor. Trains with evaluation after every epoch, with optional checkpointing, tracking and Hub pushes.
    Finally evaluates once more and saves the model, image processor and `all_results.json` to
    `args.output_dir`.
    """
    args = parse_args()

    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
    # in the environment
    accelerator_log_kwargs = {}

    if args.with_tracking:
        accelerator_log_kwargs["log_with"] = args.report_to
        accelerator_log_kwargs["project_dir"] = args.output_dir

    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)

    setup_logging(accelerator)

    # If passed along, set the training seed now.
    # We set device_specific to True as we want different data augmentation per device.
    if args.seed is not None:
        set_seed(args.seed, device_specific=True)

    # Create repository if push to hub is specified
    repo_id = handle_repository_creation(accelerator, args)

    api = HfApi() if args.push_to_hub else None

    # ------------------------------------------------------------------------------------------------
    # Load dataset, prepare splits
    # ------------------------------------------------------------------------------------------------

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    dataset = load_dataset(args.dataset_name, cache_dir=args.cache_dir, trust_remote_code=args.trust_remote_code)

    # We need to specify the label2id mapping for the model
    # it is a mapping from semantic class name to class index.
    # In case your dataset does not provide it, you can create it manually:
    # label2id = {"background": 0, "cat": 1, "dog": 2}
    label2id = dataset["train"][0]["semantic_class_to_id"]

    if args.do_reduce_labels:
        label2id = {name: idx for name, idx in label2id.items() if idx != 0}  # remove background class
        label2id = {name: idx - 1 for name, idx in label2id.items()}  # shift class indices by -1

    id2label = {v: k for k, v in label2id.items()}

    # ------------------------------------------------------------------------------------------------
    # Load pretrained model and image processor
    # ------------------------------------------------------------------------------------------------
    model = AutoModelForUniversalSegmentation.from_pretrained(
        args.model_name_or_path,
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,
        token=args.hub_token,
    )

    image_processor = AutoImageProcessor.from_pretrained(
        args.model_name_or_path,
        do_resize=True,
        size={"height": args.image_height, "width": args.image_width},
        do_reduce_labels=args.do_reduce_labels,
        reduce_labels=args.do_reduce_labels,  # TODO: remove when mask2former support `do_reduce_labels`
        token=args.hub_token,
    )

    # ------------------------------------------------------------------------------------------------
    # Define image augmentations and dataset transforms
    # ------------------------------------------------------------------------------------------------
    train_augment_and_transform = A.Compose(
        [
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.HueSaturationValue(p=0.1),
        ],
    )
    validation_transform = A.Compose(
        [A.NoOp()],
    )

    # Make transform functions for batch and apply for dataset splits
    train_transform_batch = partial(
        augment_and_transform_batch, transform=train_augment_and_transform, image_processor=image_processor
    )
    validation_transform_batch = partial(
        augment_and_transform_batch, transform=validation_transform, image_processor=image_processor
    )

    with accelerator.main_process_first():
        dataset["train"] = dataset["train"].with_transform(train_transform_batch)
        dataset["validation"] = dataset["validation"].with_transform(validation_transform_batch)

    dataloader_common_args = {
        "num_workers": args.dataloader_num_workers,
        # DataLoader rejects persistent_workers=True when num_workers == 0 (i.e. `--dataloader_num_workers 0`).
        "persistent_workers": args.dataloader_num_workers > 0,
        "collate_fn": collate_fn,
    }
    train_dataloader = DataLoader(
        dataset["train"], shuffle=True, batch_size=args.per_device_train_batch_size, **dataloader_common_args
    )
    valid_dataloader = DataLoader(
        dataset["validation"], shuffle=False, batch_size=args.per_device_eval_batch_size, **dataloader_common_args
    )

    # ------------------------------------------------------------------------------------------------
    # Define optimizer, scheduler and prepare everything with the accelerator
    # ------------------------------------------------------------------------------------------------

    # Optimizer
    optimizer = torch.optim.AdamW(
        list(model.parameters()),
        lr=args.learning_rate,
        betas=[args.adam_beta1, args.adam_beta2],
        eps=args.adam_epsilon,
    )

    # Figure out how many steps we should save the Accelerator states
    checkpointing_steps = args.checkpointing_steps
    if checkpointing_steps is not None and checkpointing_steps.isdigit():
        checkpointing_steps = int(checkpointing_steps)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps
        if overrode_max_train_steps
        else args.max_train_steps * accelerator.num_processes,
    )

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, valid_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, valid_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if args.with_tracking:
        # Copy: `vars(args)` is the namespace's own __dict__, so editing it in place would overwrite
        # `args.lr_scheduler_type`.
        experiment_config = dict(vars(args))
        # TensorBoard cannot log Enums, need the raw value
        experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
        accelerator.init_trackers("instance_segmentation_no_trainer", experiment_config)

    # ------------------------------------------------------------------------------------------------
    # Run training with evaluation on each epoch
    # ------------------------------------------------------------------------------------------------

    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(dataset['train'])}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    starting_epoch = 0
    resume_step = None

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        checkpoint_path = args.resume_from_checkpoint
        # normpath drops a trailing separator: basename("out/step_10/") would otherwise be "".
        checkpoint_name = os.path.basename(os.path.normpath(checkpoint_path))

        accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
        accelerator.load_state(checkpoint_path)

        # Extract `epoch_{i}` or `step_{i}`
        training_difference = os.path.splitext(checkpoint_name)[0]

        if "epoch" in training_difference:
            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
            resume_step = None
            completed_steps = starting_epoch * num_update_steps_per_epoch
        else:
            # need to multiply `gradient_accumulation_steps` to reflect real steps
            resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
            starting_epoch = resume_step // len(train_dataloader)
            completed_steps = resume_step // args.gradient_accumulation_steps
            resume_step -= starting_epoch * len(train_dataloader)

    # update the progress_bar if load from checkpoint
    progress_bar.update(completed_steps)

    for epoch in range(starting_epoch, args.num_train_epochs):
        model.train()
        if args.with_tracking:
            total_loss = 0
            # Count batches actually seen so the logged average is right for partial epochs
            # (resume mid-epoch or early stop at `max_train_steps`).
            num_loss_batches = 0
        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
            # We skip the first `n` batches in the dataloader when resuming from a checkpoint
            active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
        else:
            active_dataloader = train_dataloader

        for batch in active_dataloader:
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                # We keep track of the loss at each epoch
                if args.with_tracking:
                    total_loss += loss.detach().float()
                    num_loss_batches += 1
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                completed_steps += 1

            if isinstance(checkpointing_steps, int):
                if completed_steps % checkpointing_steps == 0 and accelerator.sync_gradients:
                    output_dir = f"step_{completed_steps}"
                    if args.output_dir is not None:
                        output_dir = os.path.join(args.output_dir, output_dir)
                    accelerator.save_state(output_dir)

                    if args.push_to_hub and epoch < args.num_train_epochs - 1:
                        _save_and_push_to_hub(
                            accelerator,
                            model,
                            image_processor,
                            args,
                            api,
                            repo_id,
                            commit_message=f"Training in progress epoch {epoch}",
                        )

            if completed_steps >= args.max_train_steps:
                break

        logger.info("***** Running evaluation *****")
        metrics = evaluation_loop(model, image_processor, accelerator, valid_dataloader, id2label)

        logger.info(f"epoch {epoch}: {metrics}")

        if args.with_tracking:
            accelerator.log(
                {
                    # float() works for both the int 0 (no batch ran) and a 0-d tensor.
                    "train_loss": float(total_loss) / max(num_loss_batches, 1),
                    **metrics,
                    "epoch": epoch,
                    "step": completed_steps,
                },
                step=completed_steps,
            )

        if args.push_to_hub and epoch < args.num_train_epochs - 1:
            _save_and_push_to_hub(
                accelerator,
                model,
                image_processor,
                args,
                api,
                repo_id,
                commit_message=f"Training in progress epoch {epoch}",
            )

        if args.checkpointing_steps == "epoch":
            output_dir = f"epoch_{epoch}"
            if args.output_dir is not None:
                output_dir = os.path.join(args.output_dir, output_dir)
            accelerator.save_state(output_dir)

    # ------------------------------------------------------------------------------------------------
    # Run final evaluation and save the model
    # ------------------------------------------------------------------------------------------------
    # NOTE: this reuses the *validation* dataloader; the "test_" prefix only labels the final pass.
    logger.info("***** Running evaluation on test dataset *****")
    metrics = evaluation_loop(model, image_processor, accelerator, valid_dataloader, id2label)
    metrics = {f"test_{k}": v for k, v in metrics.items()}

    logger.info(f"Test metrics: {metrics}")

    if args.output_dir is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
        )
        if accelerator.is_main_process:
            with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
                json.dump(metrics, f, indent=2)

            image_processor.save_pretrained(args.output_dir)

            if args.push_to_hub:
                api.upload_folder(
                    commit_message="End of training",
                    folder_path=args.output_dir,
                    repo_id=repo_id,
                    repo_type="model",
                    token=args.hub_token,
                    # Exclude both step and epoch checkpoints (previously only "epoch_*" was skipped).
                    ignore_patterns=list(_CHECKPOINT_DIR_PATTERNS),
                )

    accelerator.wait_for_everyone()
    accelerator.end_training()


if __name__ == "__main__":
    main()