File size: 4,163 Bytes
010952f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os.path as osp
import logging
from dataclasses import dataclass
import torch
from torchvision import transforms
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import register_task
from fairseq.logging import metrics
# Import ImageDataset relative to the package; when the file is executed
# outside its package (no parent package), fall back to a sys.path-based
# absolute import of the sibling `data` module.
try:
    from ..data import ImageDataset
except ImportError:
    # Narrowed from a bare `except:`, which also swallowed
    # KeyboardInterrupt/SystemExit and masked unrelated errors.
    import sys

    sys.path.append("..")

    from data import ImageDataset
from .image_pretraining import (
ImagePretrainingConfig,
ImagePretrainingTask,
IMG_EXTENSIONS,
)
logger = logging.getLogger(__name__)
@dataclass
class ImageClassificationConfig(ImagePretrainingConfig):
    """Task configuration for image classification.

    Inherits every field from :class:`ImagePretrainingConfig` and adds
    nothing; it exists so the classification task has its own config type.
    """
@register_task("image_classification", dataclass=ImageClassificationConfig)
class ImageClassificationTask(ImagePretrainingTask):
    """Fine-tuning task for supervised image classification.

    Extends :class:`ImagePretrainingTask` by loading class labels alongside
    images (``load_classes=True``) and by deriving a top-1 ``accuracy``
    metric from per-batch ``correct`` counts in :meth:`reduce_metrics`.
    """

    cfg: ImageClassificationConfig

    @classmethod
    def setup_task(cls, cfg: ImageClassificationConfig, **kwargs):
        """Create the task from its config; no extra setup state is needed."""
        return cls(cfg)

    @staticmethod
    def _build_train_transform(cfg):
        """Build the training augmentation pipeline via timm's factory."""
        # Imported lazily so evaluation-only runs do not require timm.
        from timm.data import create_transform

        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=cfg.input_size,
            is_training=True,
            auto_augment="rand-m9-mstd0.5-inc1",
            interpolation="bicubic",
            re_prob=0.25,
            re_mode="pixel",
            re_count=1,
            mean=cfg.normalization_mean,
            std=cfg.normalization_std,
        )
        if cfg.input_size <= 32:
            # Small images (e.g. CIFAR): swap timm's resized crop for a
            # padded random crop at the native resolution.
            transform.transforms[0] = transforms.RandomCrop(
                cfg.input_size, padding=4
            )
        return transform

    @staticmethod
    def _build_eval_transform(cfg):
        """Build the deterministic eval pipeline: resize, center crop, normalize."""
        t = []
        if cfg.input_size > 32:
            crop_pct = 1
            if cfg.input_size < 384:
                crop_pct = 224 / 256
            size = int(cfg.input_size / crop_pct)
            t.append(
                transforms.Resize(
                    size, interpolation=3
                ),  # to maintain same ratio w.r.t. 224 images
            )
            t.append(transforms.CenterCrop(cfg.input_size))

        t.append(transforms.ToTensor())
        t.append(
            transforms.Normalize(cfg.normalization_mean, cfg.normalization_std)
        )
        return transforms.Compose(t)

    def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
        """Load the image dataset for ``split`` with split-appropriate transforms.

        Args:
            split: dataset split name (e.g. ``"train"``, ``"valid"``).
            task_cfg: optional config overriding ``self.cfg`` for this split.
        """
        data_path = self.cfg.data
        cfg = task_cfg or self.cfg

        # If the data root contains a per-split subdirectory, descend into it.
        path_with_split = osp.join(data_path, split)
        if osp.exists(path_with_split):
            data_path = path_with_split

        if split == "train":
            transform = self._build_train_transform(cfg)
        else:
            transform = self._build_eval_transform(cfg)

        logger.info(transform)

        self.datasets[split] = ImageDataset(
            root=data_path,
            extensions=IMG_EXTENSIONS,
            load_classes=True,
            transform=transform,
        )
        # All loaded splits must share the same class vocabulary, otherwise
        # label indices would disagree between train and eval.
        for k in self.datasets.keys():
            if k != split:
                assert self.datasets[k].classes == self.datasets[split].classes

    def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False):
        """Build the model and propagate any ``pretrained_model_args`` the
        instantiated model resolved back into ``model_cfg``."""
        model = super().build_model(model_cfg, from_checkpoint)

        actualized_cfg = getattr(model, "cfg", None)
        if actualized_cfg is not None:
            if hasattr(actualized_cfg, "pretrained_model_args"):
                model_cfg.pretrained_model_args = actualized_cfg.pretrained_model_args

        return model

    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate criterion outputs; additionally derive top-1 accuracy
        when per-batch ``correct`` counts were logged."""
        super().reduce_metrics(logging_outputs, criterion)

        # Guard the empty case (e.g. a worker that received no batches)
        # before peeking at logging_outputs[0].
        if logging_outputs and "correct" in logging_outputs[0]:
            zero = torch.scalar_tensor(0.0)
            correct = sum(log.get("correct", zero) for log in logging_outputs)
            metrics.log_scalar_sum("_correct", correct)

            metrics.log_derived(
                "accuracy",
                lambda meters: 100 * meters["_correct"].sum / meters["sample_size"].sum,
            )
|