| | """Finetuning 🤗 Transformers model for instance segmentation leveraging the Trainer API.""" |
| |
|
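# Example launch command, for illustration only (the script filename and flag values below are
# assumptions; the model and dataset defaults are taken from the `Arguments` dataclass, and any
# `TrainingArguments` flag can be appended):
#
#   python run_instance_segmentation.py \
#       --model_name_or_path facebook/mask2former-swin-tiny-coco-instance \
#       --dataset_name qubvel-hf/ade20k-mini \
#       --output_dir ./mask2former-finetuned \
#       --do_train \
#       --do_eval \
#       --num_train_epochs 10 \
#       --per_device_train_batch_size 4
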
import logging
import os
import sys
from collections.abc import Mapping
from dataclasses import dataclass, field
from functools import partial
from typing import Any

import albumentations as A
import numpy as np
import torch
from datasets import load_dataset
from torchmetrics.detection.mean_ap import MeanAveragePrecision

import transformers
from transformers import (
    AutoImageProcessor,
    AutoModelForUniversalSegmentation,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
)
from transformers.image_processing_utils import BatchFeature
from transformers.trainer import EvalPrediction
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version


logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.57.0.dev0")

require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")


@dataclass
class Arguments:
    """
    Arguments pertaining to the data we will input to our model for training and evaluation.
    Using `HfArgumentParser` we can turn this class into argparse arguments that can be
    specified on the command line.
    """

    model_name_or_path: str = field(
        default="facebook/mask2former-swin-tiny-coco-instance",
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
    )
    dataset_name: str = field(
        default="qubvel-hf/ade20k-mini",
        metadata={
            "help": "Name of a dataset from the hub (could be your own, possibly private dataset hosted on the hub)."
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )
    image_height: int | None = field(default=512, metadata={"help": "Image height after resizing."})
    image_width: int | None = field(default=512, metadata={"help": "Image width after resizing."})
    token: str | None = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `hf auth login` (stored in `~/.huggingface`)."
            )
        },
    )
    do_reduce_labels: bool = field(
        default=False,
        metadata={
            "help": (
                "If the background class is labeled as 0 and you want to remove it from the labels, set this flag to True."
            )
        },
    )


def augment_and_transform_batch(
    examples: Mapping[str, Any], transform: A.Compose, image_processor: AutoImageProcessor
) -> BatchFeature:
    batch = {
        "pixel_values": [],
        "mask_labels": [],
        "class_labels": [],
    }

    for pil_image, pil_annotation in zip(examples["image"], examples["annotation"]):
        image = np.array(pil_image)
        # The first annotation channel encodes semantic class ids, the second encodes instance ids
        semantic_and_instance_masks = np.array(pil_annotation)[..., :2]

        # Apply augmentations
        output = transform(image=image, mask=semantic_and_instance_masks)

        aug_image = output["image"]
        aug_semantic_and_instance_masks = output["mask"]
        aug_instance_mask = aug_semantic_and_instance_masks[..., 1]

        # Create a mapping from instance id to semantic id
        unique_semantic_id_instance_id_pairs = np.unique(aug_semantic_and_instance_masks.reshape(-1, 2), axis=0)
        instance_id_to_semantic_id = {
            instance_id: semantic_id for semantic_id, instance_id in unique_semantic_id_instance_id_pairs
        }
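        # For example, {0: 0, 1: 12, 7: 12} would mean that instance ids 1 and 7 both belong to
        # semantic class 12 (the ids here are hypothetical, not taken from the dataset)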

        # Apply the image processor transformations: resizing, rescaling, normalization
        model_inputs = image_processor(
            images=[aug_image],
            segmentation_maps=[aug_instance_mask],
            instance_id_to_semantic_id=instance_id_to_semantic_id,
            return_tensors="pt",
        )

        batch["pixel_values"].append(model_inputs.pixel_values[0])
        batch["mask_labels"].append(model_inputs.mask_labels[0])
        batch["class_labels"].append(model_inputs.class_labels[0])

    return batch


def collate_fn(examples):
    batch = {}
    batch["pixel_values"] = torch.stack([example["pixel_values"] for example in examples])
    batch["class_labels"] = [example["class_labels"] for example in examples]
    batch["mask_labels"] = [example["mask_labels"] for example in examples]
    if "pixel_mask" in examples[0]:
        batch["pixel_mask"] = torch.stack([example["pixel_mask"] for example in examples])
    return batch
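
# Expected collated batch layout, for illustration (shapes assume a batch of two 512x512 images
# with 3 and 5 annotated instances respectively; this note is not part of the original script):
#   batch["pixel_values"]: float tensor of shape (2, 3, 512, 512)
#   batch["mask_labels"]:  list of float tensors with shapes (3, 512, 512) and (5, 512, 512)
#   batch["class_labels"]: list of integer tensors with shapes (3,) and (5,)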


@dataclass
class ModelOutput:
    class_queries_logits: torch.Tensor
    masks_queries_logits: torch.Tensor


def nested_cpu(tensors):
    """Recursively detach tensors and move them to CPU, preserving list/tuple/dict structure."""
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_cpu(t) for t in tensors)
    elif isinstance(tensors, Mapping):
        return type(tensors)({k: nested_cpu(t) for k, t in tensors.items()})
    elif isinstance(tensors, torch.Tensor):
        return tensors.cpu().detach()
    else:
        return tensors


class Evaluator:
    """
    Compute metrics for the instance segmentation task.
    """

    def __init__(
        self,
        image_processor: AutoImageProcessor,
        id2label: Mapping[int, str],
        threshold: float = 0.0,
    ):
        """
        Initialize evaluator with image processor, id2label mapping and threshold for filtering predictions.

        Args:
            image_processor (AutoImageProcessor): Image processor for
                the `post_process_instance_segmentation` method.
            id2label (Mapping[int, str]): Mapping from class id to class name.
            threshold (float): Threshold to filter predicted instances by confidence. Defaults to 0.0.
        """
        self.image_processor = image_processor
        self.id2label = id2label
        self.threshold = threshold
        self.metric = self.get_metric()

    def get_metric(self):
        metric = MeanAveragePrecision(iou_type="segm", class_metrics=True)
        return metric

    def reset_metric(self):
        self.metric.reset()

    def postprocess_target_batch(self, target_batch) -> list[dict[str, torch.Tensor]]:
        """Collect targets as a list of dictionaries with keys "masks" and "labels"."""
        batch_masks = target_batch[0]
        batch_labels = target_batch[1]
        post_processed_targets = []
        for masks, labels in zip(batch_masks, batch_labels):
            post_processed_targets.append(
                {
                    "masks": masks.to(dtype=torch.bool),
                    "labels": labels,
                }
            )
        return post_processed_targets

    def get_target_sizes(self, post_processed_targets) -> list[list[int]]:
        target_sizes = []
        for target in post_processed_targets:
            target_sizes.append(target["masks"].shape[-2:])
        return target_sizes

    def postprocess_prediction_batch(self, prediction_batch, target_sizes) -> list[dict[str, torch.Tensor]]:
        """Collect predictions as a list of dictionaries with keys "masks", "labels" and "scores"."""

        model_output = ModelOutput(class_queries_logits=prediction_batch[0], masks_queries_logits=prediction_batch[1])
        post_processed_output = self.image_processor.post_process_instance_segmentation(
            model_output,
            threshold=self.threshold,
            target_sizes=target_sizes,
            return_binary_maps=True,
        )

        post_processed_predictions = []
        for image_predictions, target_size in zip(post_processed_output, target_sizes):
            if image_predictions["segments_info"]:
                post_processed_image_prediction = {
                    "masks": image_predictions["segmentation"].to(dtype=torch.bool),
                    "labels": torch.tensor([x["label_id"] for x in image_predictions["segments_info"]]),
                    "scores": torch.tensor([x["score"] for x in image_predictions["segments_info"]]),
                }
            else:
                # If the model predicts no instances for an image, provide empty masks, labels and scores
                post_processed_image_prediction = {
                    "masks": torch.zeros([0, *target_size], dtype=torch.bool),
                    "labels": torch.tensor([]),
                    "scores": torch.tensor([]),
                }
            post_processed_predictions.append(post_processed_image_prediction)

        return post_processed_predictions

    @torch.no_grad()
    def __call__(
        self, evaluation_results: EvalPrediction, compute_result: bool = False
    ) -> Mapping[str, float] | None:
        """
        Update metrics with the current evaluation results and return metrics if `compute_result` is True.

        Args:
            evaluation_results (EvalPrediction): Predictions and targets from evaluation.
            compute_result (bool): Whether to compute and return metrics.

        Returns:
            Mapping[str, float] | None: Metrics as a dictionary {<metric_name>: <metric_value>}
                if `compute_result` is True, otherwise None.
        """
        prediction_batch = nested_cpu(evaluation_results.predictions)
        target_batch = nested_cpu(evaluation_results.label_ids)

        # For metric computation we need to provide:
        #  - targets as a list of dictionaries with keys "masks" and "labels"
        #  - predictions as a list of dictionaries with keys "masks", "labels" and "scores"
        post_processed_targets = self.postprocess_target_batch(target_batch)
        target_sizes = self.get_target_sizes(post_processed_targets)
        post_processed_predictions = self.postprocess_prediction_batch(prediction_batch, target_sizes)

        # Accumulate the current batch into the metric state
        self.metric.update(post_processed_predictions, post_processed_targets)

        if not compute_result:
            return

        metrics = self.metric.compute()

        # Replace the per-class metric lists with a separate metric entry for each class
        classes = metrics.pop("classes")
        map_per_class = metrics.pop("map_per_class")
        mar_100_per_class = metrics.pop("mar_100_per_class")
        for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
            class_name = self.id2label[class_id.item()] if self.id2label is not None else class_id.item()
            metrics[f"map_{class_name}"] = class_map
            metrics[f"mar_100_{class_name}"] = class_mar

        metrics = {k: round(v.item(), 4) for k, v in metrics.items()}

        # Reset the metric state so the next evaluation loop starts fresh
        self.reset_metric()

        return metrics
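
# Minimal sketch of how the Evaluator is driven, for illustration (`eval_batches` and
# `batch_results` are hypothetical names; in this script the Trainer calls the evaluator once
# per batch because `batch_eval_metrics=True` is set in `main`):
#
#   evaluator = Evaluator(image_processor, id2label={0: "cat", 1: "dog"})
#   for i, batch_results in enumerate(eval_batches):
#       metrics = evaluator(batch_results, compute_result=(i == len(eval_batches) - 1))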


def setup_logging(training_args: TrainingArguments) -> None:
    """Setup logging according to `training_args`."""

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()


def find_last_checkpoint(training_args: TrainingArguments) -> str | None:
    """Find the checkpoint to resume from according to parameters specified in `training_args`."""

    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
        # Pick up the most recent checkpoint in the output directory, if one exists
        checkpoint = get_last_checkpoint(training_args.output_dir)

    return checkpoint


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.

    parser = HfArgumentParser([Arguments, TrainingArguments])
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        args, training_args = parser.parse_args_into_dataclasses()

    # Set default training arguments for instance segmentation
    training_args.eval_do_concat_batches = False
    training_args.batch_eval_metrics = True
    training_args.remove_unused_columns = False

    # Set up logging and log a short summary on each process
    setup_logging(training_args)
    logger.warning(
        f"Process rank: {training_args.local_process_index}, device: {training_args.device}, "
        f"n_gpu: {training_args.n_gpu}, distributed training: {training_args.parallel_mode.value == 'distributed'}, "
        f"16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Load the last checkpoint from output_dir if it exists (and we are not overwriting it)
    checkpoint = find_last_checkpoint(training_args)

    # ------------------------------------------------------------------------------------------------
    # Load dataset, prepare splits
    # ------------------------------------------------------------------------------------------------

    dataset = load_dataset(args.dataset_name, trust_remote_code=args.trust_remote_code)

    # We need to specify the label2id mapping for the model,
    # a mapping from semantic class name to class index.
    # In case your dataset does not provide it, you can create it manually, e.g.
    # label2id = {"background": 0, "cat": 1, "dog": 2}
    label2id = dataset["train"][0]["semantic_class_to_id"]

    if args.do_reduce_labels:
        label2id = {name: idx for name, idx in label2id.items() if idx != 0}  # remove background class
        label2id = {name: idx - 1 for name, idx in label2id.items()}  # shift class indices by -1
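        # For illustration (values assumed, not taken from the dataset):
        #   {"background": 0, "cat": 1, "dog": 2} becomes {"cat": 0, "dog": 1}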

    id2label = {v: k for k, v in label2id.items()}

    # ------------------------------------------------------------------------------------------------
    # Load pretrained model and image processor
    # ------------------------------------------------------------------------------------------------
    model = AutoModelForUniversalSegmentation.from_pretrained(
        args.model_name_or_path,
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,
        token=args.token,
    )

    image_processor = AutoImageProcessor.from_pretrained(
        args.model_name_or_path,
        do_resize=True,
        size={"height": args.image_height, "width": args.image_width},
        do_reduce_labels=args.do_reduce_labels,
        reduce_labels=args.do_reduce_labels,  # some processors use the legacy `reduce_labels` name; pass both for compatibility
        token=args.token,
    )

    # ------------------------------------------------------------------------------------------------
    # Define image augmentations and dataset transforms
    # ------------------------------------------------------------------------------------------------
    train_augment_and_transform = A.Compose(
        [
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.HueSaturationValue(p=0.1),
        ],
    )
    validation_transform = A.Compose(
        [A.NoOp()],
    )

    # Make batch transform functions and apply them to the dataset splits
    train_transform_batch = partial(
        augment_and_transform_batch, transform=train_augment_and_transform, image_processor=image_processor
    )
    validation_transform_batch = partial(
        augment_and_transform_batch, transform=validation_transform, image_processor=image_processor
    )

    dataset["train"] = dataset["train"].with_transform(train_transform_batch)
    dataset["validation"] = dataset["validation"].with_transform(validation_transform_batch)

    # ------------------------------------------------------------------------------------------------
    # Model training and evaluation with Trainer API
    # ------------------------------------------------------------------------------------------------

    compute_metrics = Evaluator(image_processor=image_processor, id2label=id2label, threshold=0.0)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"] if training_args.do_train else None,
        eval_dataset=dataset["validation"] if training_args.do_eval else None,
        processing_class=image_processor,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

    # Final evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(eval_dataset=dataset["validation"], metric_key_prefix="test")
        trainer.log_metrics("test", metrics)
        trainer.save_metrics("test", metrics)

    # Write a model card and optionally push everything to the Hub
    kwargs = {
        "finetuned_from": args.model_name_or_path,
        "dataset": args.dataset_name,
        "tags": ["image-segmentation", "instance-segmentation", "vision"],
    }
    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


if __name__ == "__main__":
    main()