# ssl-aasist/fairseq/examples/data2vec/models/data2vec_image_classification.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit

import logging
from dataclasses import dataclass
from typing import Any

from omegaconf import II, MISSING

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import checkpoint_utils, tasks
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model

logger = logging.getLogger(__name__)

@dataclass
class Data2VecImageClassificationConfig(FairseqDataclass):
    # path to the pretrained data2vec checkpoint to fine-tune from
    model_path: str = MISSING
    no_pretrained_weights: bool = False
    num_classes: int = 1000

    # mixup/cutmix mixing strengths; either value > 0 enables timm's Mixup
    mixup: float = 0.8
    cutmix: float = 1.0
    label_smoothing: float = 0.1

    # populated from the checkpoint's stored config at construction time
    pretrained_model_args: Any = None
    data: str = II("task.data")
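
# NOTE: when training through fairseq's hydra entry point, these fields are
# typically set with dotted config overrides; an illustrative (not exhaustive)
# example with hypothetical paths:
#     model._name=data2vec_image_classification
#     model.model_path=/path/to/data2vec_base.pt
#     model.num_classes=1000
# `data` is not set directly here; II("task.data") interpolates it from the
# task configuration.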

@register_model(
    "data2vec_image_classification", dataclass=Data2VecImageClassificationConfig
)
class Data2VecImageClassificationModel(BaseFairseqModel):
    def __init__(self, cfg: Data2VecImageClassificationConfig):
        super().__init__()
        self.cfg = cfg

        if cfg.pretrained_model_args is None:
            # load the pretraining checkpoint on CPU and reuse its stored config
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
            pretrained_args = state.get("cfg", None)
            pretrained_args.criterion = None
            pretrained_args.lr_scheduler = None
            cfg.pretrained_model_args = pretrained_args

            logger.info(pretrained_args)
        else:
            state = None
            pretrained_args = cfg.pretrained_model_args

        pretrained_args.task.data = cfg.data
        task = tasks.setup_task(pretrained_args.task)
        model = task.build_model(pretrained_args.model, from_checkpoint=True)

        # strip components that are only needed for pretraining (e.g. the EMA teacher)
        model.remove_pretraining_modules()

        self.model = model

        if state is not None and not cfg.no_pretrained_weights:
            self.load_model_weights(state, model, cfg)

        # mean-pooled classification head on top of the pretrained encoder;
        # the 1e-3 down-scaling of the fresh head follows the BeiT fine-tuning recipe
        self.fc_norm = nn.LayerNorm(pretrained_args.model.embed_dim)
        self.head = nn.Linear(pretrained_args.model.embed_dim, cfg.num_classes)
        self.head.weight.data.mul_(1e-3)
        self.head.bias.data.mul_(1e-3)

        self.mixup_fn = None

        if cfg.mixup > 0 or cfg.cutmix > 0:
            from timm.data import Mixup

            # batch-level mixup/cutmix augmentation; turns hard labels into soft targets
            self.mixup_fn = Mixup(
                mixup_alpha=cfg.mixup,
                cutmix_alpha=cfg.cutmix,
                cutmix_minmax=None,
                prob=1.0,
                switch_prob=0.5,
                mode="batch",
                label_smoothing=cfg.label_smoothing,
                num_classes=cfg.num_classes,
            )

    def load_model_weights(self, state, model, cfg):
        # the EMA teacher weights are only used during pretraining; drop them
        if "_ema" in state["model"]:
            del state["model"]["_ema"]
        model.load_state_dict(state["model"], strict=True)

    @classmethod
    def build_model(cls, cfg: Data2VecImageClassificationConfig, task=None):
        """Build a new model instance."""
        return cls(cfg)

    def forward(
        self,
        img,
        label=None,
    ):
        if self.training and self.mixup_fn is not None and label is not None:
            img, label = self.mixup_fn(img, label)

        x = self.model(img, mask=False)

        # drop the CLS token, mean-pool the patch tokens, then normalize and classify
        x = x[:, 1:]
        x = self.fc_norm(x.mean(1))
        x = self.head(x)

        if label is None:
            return x

        if self.training and self.mixup_fn is not None:
            # mixup yields soft targets, so compute the soft cross-entropy explicitly
            loss = -label * F.log_softmax(x.float(), dim=-1)
        else:
            loss = F.cross_entropy(
                x.float(),
                label,
                label_smoothing=self.cfg.label_smoothing if self.training else 0,
                reduction="none",
            )

        result = {
            "losses": {"regression": loss},
            "sample_size": img.size(0),
        }

        if not self.training:
            # track the top-1 correct count during evaluation
            with torch.no_grad():
                pred = x.argmax(-1)
                correct = (pred == label).sum()
                result["correct"] = correct

        return result
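

# Usage sketch (not part of the original module): a minimal, hedged example of
# building this model directly from this file. It assumes a data2vec image
# checkpoint at a hypothetical path; in practice the model is normally
# constructed by fairseq's hydra training entry point rather than by hand.
# Since `data` defaults to the II("task.data") interpolation, a concrete path
# is passed explicitly here.
#
#     import torch
#
#     cfg = Data2VecImageClassificationConfig(
#         model_path="/path/to/data2vec_base.pt",  # hypothetical checkpoint
#         num_classes=1000,
#         data="/path/to/imagenet",                # hypothetical dataset root
#     )
#     model = Data2VecImageClassificationModel.build_model(cfg)
#     model.eval()
#
#     imgs = torch.randn(2, 3, 224, 224)  # dummy ImageNet-sized batch
#     logits = model(imgs)                # label=None returns raw logits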