Upload model
Browse files- config.json +5 -1
- config.py +28 -0
- configuration_cetacean_classifier.py +35 -0
- metric_learning.py +59 -0
- modeling_cetacean_classifier.py +66 -0
- train.py +318 -0
- utils.py +41 -0
config.json
CHANGED
|
@@ -2,7 +2,11 @@
|
|
| 2 |
"architectures": [
|
| 3 |
"CetaceanClassifierModelForImageClassification"
|
| 4 |
],
|
| 5 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
"torch_dtype": "float32",
|
| 7 |
"transformers_version": "4.46.0"
|
| 8 |
}
|
|
|
|
| 2 |
"architectures": [
|
| 3 |
"CetaceanClassifierModelForImageClassification"
|
| 4 |
],
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "configuration_cetacean_classifier.CetaceanClassifierConfig",
|
| 7 |
+
"AutoModelForImageClassification": "modeling_cetacean_classifier.CetaceanClassifierModelForImageClassification"
|
| 8 |
+
},
|
| 9 |
+
"model_type": "cetaceanet",
|
| 10 |
"torch_dtype": "float32",
|
| 11 |
"transformers_version": "4.46.0"
|
| 12 |
}
|
config.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import yaml
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Config(dict):
|
| 7 |
+
def __getattr__(self, key):
|
| 8 |
+
try:
|
| 9 |
+
val = self[key]
|
| 10 |
+
except KeyError:
|
| 11 |
+
return super().__getattr__(key)
|
| 12 |
+
if isinstance(val, dict):
|
| 13 |
+
return Config(val)
|
| 14 |
+
return val
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def load_config(path: str, default_path: Optional[str]) -> Config:
|
| 18 |
+
with open(path) as f:
|
| 19 |
+
cfg = Config(yaml.full_load(f))
|
| 20 |
+
if default_path is not None:
|
| 21 |
+
# set keys not included in `path` by default
|
| 22 |
+
with open(default_path) as f:
|
| 23 |
+
default_cfg = Config(yaml.full_load(f))
|
| 24 |
+
for key, val in default_cfg.items():
|
| 25 |
+
if key not in cfg:
|
| 26 |
+
print(f"used default config {key}: {val}")
|
| 27 |
+
cfg[key] = val
|
| 28 |
+
return cfg
|
configuration_cetacean_classifier.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class CetaceanClassifierConfig(PretrainedConfig):
|
| 6 |
+
model_type = "cetaceanet"
|
| 7 |
+
|
| 8 |
+
def __init__(
|
| 9 |
+
self,
|
| 10 |
+
# block_type="bottleneck",
|
| 11 |
+
# layers: List[int] = [3, 4, 6, 3],
|
| 12 |
+
# num_classes: int = 1000,
|
| 13 |
+
# input_channels: int = 3,
|
| 14 |
+
# cardinality: int = 1,
|
| 15 |
+
# base_width: int = 64,
|
| 16 |
+
# stem_width: int = 64,
|
| 17 |
+
# stem_type: str = "",
|
| 18 |
+
# avg_down: bool = False,
|
| 19 |
+
**kwargs,
|
| 20 |
+
):
|
| 21 |
+
# if block_type not in ["basic", "bottleneck"]:
|
| 22 |
+
# raise ValueError(f"`block_type` must be 'basic' or bottleneck', got {block_type}.")
|
| 23 |
+
# if stem_type not in ["", "deep", "deep-tiered"]:
|
| 24 |
+
# raise ValueError(f"`stem_type` must be '', 'deep' or 'deep-tiered', got {stem_type}.")
|
| 25 |
+
|
| 26 |
+
# self.block_type = block_type
|
| 27 |
+
# self.layers = layers
|
| 28 |
+
# self.num_classes = num_classes
|
| 29 |
+
# self.input_channels = input_channels
|
| 30 |
+
# self.cardinality = cardinality
|
| 31 |
+
# self.base_width = base_width
|
| 32 |
+
# self.stem_width = stem_width
|
| 33 |
+
# self.stem_type = stem_type
|
| 34 |
+
# self.avg_down = avg_down
|
| 35 |
+
super().__init__(**kwargs)
|
metric_learning.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class GeM(nn.Module):
|
| 10 |
+
def __init__(self, p=3, eps=1e-6, requires_grad=False):
|
| 11 |
+
super().__init__()
|
| 12 |
+
self.p = nn.Parameter(torch.ones(1) * p, requires_grad=requires_grad)
|
| 13 |
+
self.eps = eps
|
| 14 |
+
|
| 15 |
+
def forward(self, x: torch.Tensor):
|
| 16 |
+
return x.clamp(min=self.eps).pow(self.p).mean((-2, -1)).pow(1.0 / self.p)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Copied and modified from
|
| 20 |
+
# https://github.com/ChristofHenkel/kaggle-landmark-2021-1st-place/blob/034a7d8665bb4696981698348c9370f2d4e61e35/models/ch_mdl_dolg_efficientnet.py
|
| 21 |
+
class ArcMarginProductSubcenter(nn.Module):
|
| 22 |
+
def __init__(self, in_features: int, out_features: int, k: int = 3):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features))
|
| 25 |
+
self.reset_parameters()
|
| 26 |
+
self.k = k
|
| 27 |
+
self.out_features = out_features
|
| 28 |
+
|
| 29 |
+
def reset_parameters(self):
|
| 30 |
+
stdv = 1.0 / math.sqrt(self.weight.size(1))
|
| 31 |
+
self.weight.data.uniform_(-stdv, stdv)
|
| 32 |
+
|
| 33 |
+
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
| 34 |
+
cosine_all = F.linear(F.normalize(features), F.normalize(self.weight))
|
| 35 |
+
cosine_all = cosine_all.view(-1, self.out_features, self.k)
|
| 36 |
+
cosine, _ = torch.max(cosine_all, dim=2)
|
| 37 |
+
return cosine
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ArcFaceLossAdaptiveMargin(nn.modules.Module):
|
| 41 |
+
def __init__(self, margins: np.ndarray, n_classes: int, s: float = 30.0):
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.s = s
|
| 44 |
+
self.margins = margins
|
| 45 |
+
self.out_dim = n_classes
|
| 46 |
+
|
| 47 |
+
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
| 48 |
+
ms = self.margins[labels.cpu().numpy()]
|
| 49 |
+
cos_m = torch.from_numpy(np.cos(ms)).float().cuda()
|
| 50 |
+
sin_m = torch.from_numpy(np.sin(ms)).float().cuda()
|
| 51 |
+
th = torch.from_numpy(np.cos(math.pi - ms)).float().cuda()
|
| 52 |
+
mm = torch.from_numpy(np.sin(math.pi - ms) * ms).float().cuda()
|
| 53 |
+
labels = F.one_hot(labels, self.out_dim).float()
|
| 54 |
+
logits = logits.float()
|
| 55 |
+
cosine = logits
|
| 56 |
+
sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
|
| 57 |
+
phi = cosine * cos_m.view(-1, 1) - sine * sin_m.view(-1, 1)
|
| 58 |
+
phi = torch.where(cosine > th.view(-1, 1), phi, cosine - mm.view(-1, 1))
|
| 59 |
+
return ((labels * phi) + ((1.0 - labels) * cosine)) * self.s
|
modeling_cetacean_classifier.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PreTrainedModel
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from .configuration_cetacean_classifier import CetaceanClassifierConfig
|
| 7 |
+
from .train import SphereClassifier
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
WHALE_CLASSES = np.array(
|
| 11 |
+
[
|
| 12 |
+
"beluga",
|
| 13 |
+
"blue_whale",
|
| 14 |
+
"bottlenose_dolphin",
|
| 15 |
+
"brydes_whale",
|
| 16 |
+
"commersons_dolphin",
|
| 17 |
+
"common_dolphin",
|
| 18 |
+
"cuviers_beaked_whale",
|
| 19 |
+
"dusky_dolphin",
|
| 20 |
+
"false_killer_whale",
|
| 21 |
+
"fin_whale",
|
| 22 |
+
"frasiers_dolphin",
|
| 23 |
+
"gray_whale",
|
| 24 |
+
"humpback_whale",
|
| 25 |
+
"killer_whale",
|
| 26 |
+
"long_finned_pilot_whale",
|
| 27 |
+
"melon_headed_whale",
|
| 28 |
+
"minke_whale",
|
| 29 |
+
"pantropic_spotted_dolphin",
|
| 30 |
+
"pygmy_killer_whale",
|
| 31 |
+
"rough_toothed_dolphin",
|
| 32 |
+
"sei_whale",
|
| 33 |
+
"short_finned_pilot_whale",
|
| 34 |
+
"southern_right_whale",
|
| 35 |
+
"spinner_dolphin",
|
| 36 |
+
"spotted_dolphin",
|
| 37 |
+
"white_sided_dolphin",
|
| 38 |
+
]
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class CetaceanClassifierModelForImageClassification(PreTrainedModel):
|
| 43 |
+
config_class = CetaceanClassifierConfig
|
| 44 |
+
|
| 45 |
+
def __init__(self, config):
|
| 46 |
+
super().__init__(config)
|
| 47 |
+
self.model = SphereClassifier.load_from_checkpoint("cetacean_classifier/last.ckpt")
|
| 48 |
+
self.model.eval()
|
| 49 |
+
|
| 50 |
+
def preprocess_image(self, img: Image) -> torch.Tensor:
|
| 51 |
+
image_resized = img.resize((480, 480))
|
| 52 |
+
image_resized = np.array(image_resized)[None]
|
| 53 |
+
image_resized = np.transpose(image_resized, [0, 3, 2, 1])
|
| 54 |
+
image_tensor = torch.Tensor(image_resized)
|
| 55 |
+
return image_tensor
|
| 56 |
+
|
| 57 |
+
def forward(self, img: Image, labels=None):
|
| 58 |
+
tensor = self.preprocess_image(img)
|
| 59 |
+
head_id_logits, head_species_logits = self.model(tensor)
|
| 60 |
+
head_species_logits = head_species_logits.detach().numpy()
|
| 61 |
+
sorted_idx = head_species_logits.argsort()[0]
|
| 62 |
+
sorted_idx = np.array(list(reversed(sorted_idx)))
|
| 63 |
+
top_three_logits = sorted_idx[:3]
|
| 64 |
+
top_three_whale_preds = WHALE_CLASSES[top_three_logits]
|
| 65 |
+
|
| 66 |
+
return {"predictions": top_three_whale_preds}
|
train.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import argparse
|
| 2 |
+
# import os
|
| 3 |
+
# import warnings
|
| 4 |
+
from typing import Dict, List, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
# import optuna
|
| 8 |
+
# import pandas as pd
|
| 9 |
+
import timm
|
| 10 |
+
import torch
|
| 11 |
+
# import wandb
|
| 12 |
+
# from optuna.integration import PyTorchLightningPruningCallback
|
| 13 |
+
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
|
| 14 |
+
# from pytorch_lightning import loggers as pl_loggers
|
| 15 |
+
# from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
| 16 |
+
# from sklearn.model_selection import StratifiedKFold
|
| 17 |
+
# from torch.utils.data import ConcatDataset, DataLoader
|
| 18 |
+
|
| 19 |
+
from .config import Config, load_config
|
| 20 |
+
# from .dataset import WhaleDataset, load_df
|
| 21 |
+
from .metric_learning import ArcFaceLossAdaptiveMargin, ArcMarginProductSubcenter, GeM
|
| 22 |
+
from .utils import WarmupCosineLambda, map_dict, topk_average_precision
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# def parse():
|
| 26 |
+
# parser = argparse.ArgumentParser(description="Training for HappyWhale")
|
| 27 |
+
# parser.add_argument("--out_base_dir", default="result")
|
| 28 |
+
# parser.add_argument("--in_base_dir", default="input")
|
| 29 |
+
# parser.add_argument("--exp_name", default="tmp")
|
| 30 |
+
# parser.add_argument("--load_snapshot", action="store_true")
|
| 31 |
+
# parser.add_argument("--save_checkpoint", action="store_true")
|
| 32 |
+
# parser.add_argument("--wandb_logger", action="store_true")
|
| 33 |
+
# parser.add_argument("--config_path", default="config/debug.yaml")
|
| 34 |
+
# return parser.parse_args()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# class WhaleDataModule(LightningDataModule):
|
| 38 |
+
# def __init__(
|
| 39 |
+
# self,
|
| 40 |
+
# df: pd.DataFrame,
|
| 41 |
+
# cfg: Config,
|
| 42 |
+
# image_dir: str,
|
| 43 |
+
# val_bbox_name: str,
|
| 44 |
+
# fold: int,
|
| 45 |
+
# additional_dataset: WhaleDataset = None,
|
| 46 |
+
# ):
|
| 47 |
+
# super().__init__()
|
| 48 |
+
# self.cfg = cfg
|
| 49 |
+
# self.image_dir = image_dir
|
| 50 |
+
# self.val_bbox_name = val_bbox_name
|
| 51 |
+
# self.additional_dataset = additional_dataset
|
| 52 |
+
# if cfg.n_data != -1:
|
| 53 |
+
# df = df.iloc[: cfg.n_data]
|
| 54 |
+
# self.all_df = df
|
| 55 |
+
# if fold == -1:
|
| 56 |
+
# self.train_df = df
|
| 57 |
+
# else:
|
| 58 |
+
# skf = StratifiedKFold(n_splits=cfg.n_splits, shuffle=True, random_state=0)
|
| 59 |
+
# train_idx, val_idx = list(skf.split(df, df.individual_id))[fold]
|
| 60 |
+
# self.train_df = df.iloc[train_idx].copy()
|
| 61 |
+
# self.val_df = df.iloc[val_idx].copy()
|
| 62 |
+
# # relabel ids not included in training data as "new individual"
|
| 63 |
+
# new_mask = ~self.val_df.individual_id.isin(self.train_df.individual_id)
|
| 64 |
+
# self.val_df.individual_id.mask(new_mask, cfg.num_classes, inplace=True)
|
| 65 |
+
# print(f"new: {(self.val_df.individual_id == cfg.num_classes).sum()} / {len(self.val_df)}")
|
| 66 |
+
|
| 67 |
+
# def get_dataset(self, df, data_aug):
|
| 68 |
+
# return WhaleDataset(df, self.cfg, self.image_dir, self.val_bbox_name, data_aug)
|
| 69 |
+
|
| 70 |
+
# def train_dataloader(self):
|
| 71 |
+
# dataset = self.get_dataset(self.train_df, True)
|
| 72 |
+
# if self.additional_dataset is not None:
|
| 73 |
+
# dataset = ConcatDataset([dataset, self.additional_dataset])
|
| 74 |
+
# return DataLoader(
|
| 75 |
+
# dataset,
|
| 76 |
+
# batch_size=self.cfg.batch_size,
|
| 77 |
+
# shuffle=True,
|
| 78 |
+
# num_workers=2,
|
| 79 |
+
# pin_memory=True,
|
| 80 |
+
# drop_last=True,
|
| 81 |
+
# )
|
| 82 |
+
|
| 83 |
+
# def val_dataloader(self):
|
| 84 |
+
# if self.cfg.n_splits == -1:
|
| 85 |
+
# return None
|
| 86 |
+
# return DataLoader(
|
| 87 |
+
# self.get_dataset(self.val_df, False),
|
| 88 |
+
# batch_size=self.cfg.batch_size,
|
| 89 |
+
# shuffle=False,
|
| 90 |
+
# num_workers=2,
|
| 91 |
+
# pin_memory=True,
|
| 92 |
+
# )
|
| 93 |
+
|
| 94 |
+
# def all_dataloader(self):
|
| 95 |
+
# return DataLoader(
|
| 96 |
+
# self.get_dataset(self.all_df, False),
|
| 97 |
+
# batch_size=self.cfg.batch_size,
|
| 98 |
+
# shuffle=False,
|
| 99 |
+
# num_workers=2,
|
| 100 |
+
# pin_memory=True,
|
| 101 |
+
# )
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class SphereClassifier(LightningModule):
|
| 105 |
+
def __init__(self, cfg: dict, id_class_nums=None, species_class_nums=None):
|
| 106 |
+
super().__init__()
|
| 107 |
+
if not isinstance(cfg, Config):
|
| 108 |
+
cfg = Config(cfg)
|
| 109 |
+
self.save_hyperparameters(cfg, ignore=["id_class_nums", "species_class_nums"])
|
| 110 |
+
self.test_results_fp = None
|
| 111 |
+
|
| 112 |
+
print(cfg.model_name)
|
| 113 |
+
|
| 114 |
+
# NN architecture
|
| 115 |
+
self.backbone = timm.create_model(
|
| 116 |
+
cfg.model_name,
|
| 117 |
+
in_chans=3,
|
| 118 |
+
pretrained=cfg.pretrained,
|
| 119 |
+
num_classes=0,
|
| 120 |
+
features_only=True,
|
| 121 |
+
out_indices=cfg.out_indices,
|
| 122 |
+
)
|
| 123 |
+
feature_dims = self.backbone.feature_info.channels()
|
| 124 |
+
print(f"feature dims: {feature_dims}")
|
| 125 |
+
self.global_pools = torch.nn.ModuleList(
|
| 126 |
+
[GeM(p=cfg.global_pool.p, requires_grad=cfg.global_pool.train) for _ in cfg.out_indices]
|
| 127 |
+
)
|
| 128 |
+
self.mid_features = np.sum(feature_dims)
|
| 129 |
+
if cfg.normalization == "batchnorm":
|
| 130 |
+
self.neck = torch.nn.BatchNorm1d(self.mid_features)
|
| 131 |
+
elif cfg.normalization == "layernorm":
|
| 132 |
+
self.neck = torch.nn.LayerNorm(self.mid_features)
|
| 133 |
+
self.head_id = ArcMarginProductSubcenter(self.mid_features, cfg.num_classes, cfg.n_center_id)
|
| 134 |
+
self.head_species = ArcMarginProductSubcenter(self.mid_features, cfg.num_species_classes, cfg.n_center_species)
|
| 135 |
+
if id_class_nums is not None and species_class_nums is not None:
|
| 136 |
+
margins_id = np.power(id_class_nums, cfg.margin_power_id) * cfg.margin_coef_id + cfg.margin_cons_id
|
| 137 |
+
margins_species = (
|
| 138 |
+
np.power(species_class_nums, cfg.margin_power_species) * cfg.margin_coef_species
|
| 139 |
+
+ cfg.margin_cons_species
|
| 140 |
+
)
|
| 141 |
+
print("margins_id", margins_id)
|
| 142 |
+
print("margins_species", margins_species)
|
| 143 |
+
self.margin_fn_id = ArcFaceLossAdaptiveMargin(margins_id, cfg.num_classes, cfg.s_id)
|
| 144 |
+
self.margin_fn_species = ArcFaceLossAdaptiveMargin(margins_species, cfg.num_species_classes, cfg.s_species)
|
| 145 |
+
self.loss_fn_id = torch.nn.CrossEntropyLoss()
|
| 146 |
+
self.loss_fn_species = torch.nn.CrossEntropyLoss()
|
| 147 |
+
|
| 148 |
+
def get_feat(self, x: torch.Tensor) -> torch.Tensor:
|
| 149 |
+
ms = self.backbone(x)
|
| 150 |
+
h = torch.cat([global_pool(m) for m, global_pool in zip(ms, self.global_pools)], dim=1)
|
| 151 |
+
return self.neck(h)
|
| 152 |
+
|
| 153 |
+
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 154 |
+
feat = self.get_feat(x)
|
| 155 |
+
return self.head_id(feat), self.head_species(feat)
|
| 156 |
+
|
| 157 |
+
def training_step(self, batch, batch_idx):
|
| 158 |
+
x, ids, species = batch["image"], batch["label"], batch["label_species"]
|
| 159 |
+
logits_ids, logits_species = self(x)
|
| 160 |
+
margin_logits_ids = self.margin_fn_id(logits_ids, ids)
|
| 161 |
+
loss_ids = self.loss_fn_id(margin_logits_ids, ids)
|
| 162 |
+
loss_species = self.loss_fn_species(self.margin_fn_species(logits_species, species), species)
|
| 163 |
+
self.log_dict({"train/loss_ids": loss_ids.detach()}, on_step=False, on_epoch=True)
|
| 164 |
+
self.log_dict({"train/loss_species": loss_species.detach()}, on_step=False, on_epoch=True)
|
| 165 |
+
with torch.no_grad():
|
| 166 |
+
self.log_dict(map_dict(logits_ids, ids, "train"), on_step=False, on_epoch=True)
|
| 167 |
+
self.log_dict(
|
| 168 |
+
{"train/acc_species": topk_average_precision(logits_species, species, 1).mean().detach()},
|
| 169 |
+
on_step=False,
|
| 170 |
+
on_epoch=True,
|
| 171 |
+
)
|
| 172 |
+
return loss_ids * self.hparams.loss_id_ratio + loss_species * (1 - self.hparams.loss_id_ratio)
|
| 173 |
+
|
| 174 |
+
def validation_step(self, batch, batch_idx):
|
| 175 |
+
x, ids, species = batch["image"], batch["label"], batch["label_species"]
|
| 176 |
+
out1, out_species1 = self(x)
|
| 177 |
+
out2, out_species2 = self(x.flip(3))
|
| 178 |
+
output, output_species = (out1 + out2) / 2, (out_species1 + out_species2) / 2
|
| 179 |
+
self.log_dict(map_dict(output, ids, "val"), on_step=False, on_epoch=True)
|
| 180 |
+
self.log_dict(
|
| 181 |
+
{"val/acc_species": topk_average_precision(output_species, species, 1).mean().detach()},
|
| 182 |
+
on_step=False,
|
| 183 |
+
on_epoch=True,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
def configure_optimizers(self):
|
| 187 |
+
backbone_params = list(self.backbone.parameters()) + list(self.global_pools.parameters())
|
| 188 |
+
head_params = (
|
| 189 |
+
list(self.neck.parameters()) + list(self.head_id.parameters()) + list(self.head_species.parameters())
|
| 190 |
+
)
|
| 191 |
+
params = [
|
| 192 |
+
{"params": backbone_params, "lr": self.hparams.lr_backbone},
|
| 193 |
+
{"params": head_params, "lr": self.hparams.lr_head},
|
| 194 |
+
]
|
| 195 |
+
if self.hparams.optimizer == "Adam":
|
| 196 |
+
optimizer = torch.optim.Adam(params)
|
| 197 |
+
elif self.hparams.optimizer == "AdamW":
|
| 198 |
+
optimizer = torch.optim.AdamW(params)
|
| 199 |
+
elif self.hparams.optimizer == "RAdam":
|
| 200 |
+
optimizer = torch.optim.RAdam(params)
|
| 201 |
+
|
| 202 |
+
warmup_steps = self.hparams.max_epochs * self.hparams.warmup_steps_ratio
|
| 203 |
+
cycle_steps = self.hparams.max_epochs - warmup_steps
|
| 204 |
+
lr_lambda = WarmupCosineLambda(warmup_steps, cycle_steps, self.hparams.lr_decay_scale)
|
| 205 |
+
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
| 206 |
+
return [optimizer], [scheduler]
|
| 207 |
+
|
| 208 |
+
def test_step(self, batch, batch_idx):
|
| 209 |
+
x = batch["image"]
|
| 210 |
+
feat1 = self.get_feat(x)
|
| 211 |
+
out1, out_species1 = self.head_id(feat1), self.head_species(feat1)
|
| 212 |
+
feat2 = self.get_feat(x.flip(3))
|
| 213 |
+
out2, out_species2 = self.head_id(feat2), self.head_species(feat2)
|
| 214 |
+
pred_logit, pred_idx = ((out1 + out2) / 2).cpu().sort(descending=True)
|
| 215 |
+
return {
|
| 216 |
+
"original_index": batch["original_index"],
|
| 217 |
+
"label": batch["label"],
|
| 218 |
+
"label_species": batch["label_species"],
|
| 219 |
+
"pred_logit": pred_logit[:, :1000],
|
| 220 |
+
"pred_idx": pred_idx[:, :1000],
|
| 221 |
+
"pred_species": ((out_species1 + out_species2) / 2).cpu(),
|
| 222 |
+
"embed_features1": feat1.cpu(),
|
| 223 |
+
"embed_features2": feat2.cpu(),
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
def test_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]):
|
| 227 |
+
outputs = self.all_gather(outputs)
|
| 228 |
+
if self.trainer.global_rank == 0:
|
| 229 |
+
epoch_results: Dict[str, np.ndarray] = {}
|
| 230 |
+
for key in outputs[0].keys():
|
| 231 |
+
if torch.cuda.device_count() > 1:
|
| 232 |
+
result = torch.cat([x[key] for x in outputs], dim=1).flatten(end_dim=1)
|
| 233 |
+
else:
|
| 234 |
+
result = torch.cat([x[key] for x in outputs], dim=0)
|
| 235 |
+
epoch_results[key] = result.detach().cpu().numpy()
|
| 236 |
+
np.savez_compressed(self.test_results_fp, **epoch_results)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# def train(
|
| 240 |
+
# df: pd.DataFrame,
|
| 241 |
+
# args: argparse.Namespace,
|
| 242 |
+
# cfg: Config,
|
| 243 |
+
# fold: int,
|
| 244 |
+
# do_inference: bool = False,
|
| 245 |
+
# additional_dataset: WhaleDataset = None,
|
| 246 |
+
# optuna_trial: Optional[optuna.Trial] = None,
|
| 247 |
+
# ) -> Optional[float]:
|
| 248 |
+
# out_dir = f"{args.out_base_dir}/{args.exp_name}/{fold}"
|
| 249 |
+
# id_class_nums = df.individual_id.value_counts().sort_index().values
|
| 250 |
+
# species_class_nums = df.species.value_counts().sort_index().values
|
| 251 |
+
# model = SphereClassifier(cfg, id_class_nums=id_class_nums, species_class_nums=species_class_nums)
|
| 252 |
+
# data_module = WhaleDataModule(
|
| 253 |
+
# df, cfg, f"{args.in_base_dir}/train_images", cfg.val_bbox, fold, additional_dataset=additional_dataset
|
| 254 |
+
# )
|
| 255 |
+
# loggers = [pl_loggers.CSVLogger(out_dir)]
|
| 256 |
+
# if args.wandb_logger:
|
| 257 |
+
# loggers.append(
|
| 258 |
+
# pl_loggers.WandbLogger(
|
| 259 |
+
# project="kaggle-happywhale", group=args.exp_name, name=f"{args.exp_name}/{fold}", save_dir=out_dir
|
| 260 |
+
# )
|
| 261 |
+
# )
|
| 262 |
+
# callbacks = [LearningRateMonitor("epoch")]
|
| 263 |
+
# if optuna_trial is not None:
|
| 264 |
+
# callbacks.append(PyTorchLightningPruningCallback(optuna_trial, "val/mapNone"))
|
| 265 |
+
# if args.save_checkpoint:
|
| 266 |
+
# callbacks.append(ModelCheckpoint(out_dir, save_last=True, save_top_k=0))
|
| 267 |
+
# trainer = Trainer(
|
| 268 |
+
# gpus=torch.cuda.device_count(),
|
| 269 |
+
# max_epochs=cfg["max_epochs"],
|
| 270 |
+
# logger=loggers,
|
| 271 |
+
# callbacks=callbacks,
|
| 272 |
+
# checkpoint_callback=args.save_checkpoint,
|
| 273 |
+
# precision=16,
|
| 274 |
+
# sync_batchnorm=True,
|
| 275 |
+
# )
|
| 276 |
+
# ckpt_path = f"{out_dir}/last.ckpt"
|
| 277 |
+
# if not os.path.exists(ckpt_path) or not args.load_snapshot:
|
| 278 |
+
# ckpt_path = None
|
| 279 |
+
# trainer.fit(model, ckpt_path=ckpt_path, datamodule=data_module)
|
| 280 |
+
# if do_inference:
|
| 281 |
+
# for test_bbox in cfg.test_bboxes:
|
| 282 |
+
# # all train data
|
| 283 |
+
# model.test_results_fp = f"{out_dir}/train_{test_bbox}_results.npz"
|
| 284 |
+
# trainer.test(model, data_module.all_dataloader())
|
| 285 |
+
# # test data
|
| 286 |
+
# model.test_results_fp = f"{out_dir}/test_{test_bbox}_results.npz"
|
| 287 |
+
# df_test = load_df(args.in_base_dir, cfg, "sample_submission.csv", False)
|
| 288 |
+
# test_data_module = WhaleDataModule(df_test, cfg, f"{args.in_base_dir}/test_images", test_bbox, -1)
|
| 289 |
+
# trainer.test(model, test_data_module.all_dataloader())
|
| 290 |
+
|
| 291 |
+
# if args.wandb_logger:
|
| 292 |
+
# wandb.finish()
|
| 293 |
+
# if optuna_trial is not None:
|
| 294 |
+
# return trainer.callback_metrics["val/mapNone"].item()
|
| 295 |
+
# else:
|
| 296 |
+
# return None
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
# def main():
|
| 300 |
+
# args = parse()
|
| 301 |
+
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
|
| 302 |
+
# cfg = load_config(args.config_path, "config/default.yaml")
|
| 303 |
+
# print(cfg)
|
| 304 |
+
# df = load_df(args.in_base_dir, cfg, "train.csv", True)
|
| 305 |
+
# pseudo_dataset = None
|
| 306 |
+
# if cfg.pseudo_label is not None:
|
| 307 |
+
# pseudo_df = load_df(args.in_base_dir, cfg, cfg.pseudo_label, False)
|
| 308 |
+
# pseudo_dataset = WhaleDataset(
|
| 309 |
+
# pseudo_df[pseudo_df.conf > cfg.pseudo_conf_threshold], cfg, f"{args.in_base_dir}/test_images", "", True
|
| 310 |
+
# )
|
| 311 |
+
# if cfg["n_splits"] == -1:
|
| 312 |
+
# train(df, args, cfg, -1, do_inference=True, additional_dataset=pseudo_dataset)
|
| 313 |
+
# else:
|
| 314 |
+
# train(df, args, cfg, 0, do_inference=True, additional_dataset=pseudo_dataset)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
# if __name__ == "__main__":
|
| 318 |
+
# main()
|
utils.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class WarmupCosineLambda:
|
| 8 |
+
def __init__(self, warmup_steps: int, cycle_steps: int, decay_scale: float, exponential_warmup: bool = False):
|
| 9 |
+
self.warmup_steps = warmup_steps
|
| 10 |
+
self.cycle_steps = cycle_steps
|
| 11 |
+
self.decay_scale = decay_scale
|
| 12 |
+
self.exponential_warmup = exponential_warmup
|
| 13 |
+
|
| 14 |
+
def __call__(self, epoch: int):
|
| 15 |
+
if epoch < self.warmup_steps:
|
| 16 |
+
if self.exponential_warmup:
|
| 17 |
+
return self.decay_scale * pow(self.decay_scale, -epoch / self.warmup_steps)
|
| 18 |
+
ratio = epoch / self.warmup_steps
|
| 19 |
+
else:
|
| 20 |
+
ratio = (1 + math.cos(math.pi * (epoch - self.warmup_steps) / self.cycle_steps)) / 2
|
| 21 |
+
return self.decay_scale + (1 - self.decay_scale) * ratio
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def topk_average_precision(output: torch.Tensor, y: torch.Tensor, k: int):
|
| 25 |
+
score_array = torch.tensor([1.0 / i for i in range(1, k + 1)], device=output.device)
|
| 26 |
+
topk = output.topk(k)[1]
|
| 27 |
+
match_mat = topk == y[:, None].expand(topk.shape)
|
| 28 |
+
return (match_mat * score_array).sum(dim=1)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def calc_map5(output: torch.Tensor, y: torch.Tensor, threshold: Optional[float]):
|
| 32 |
+
if threshold is not None:
|
| 33 |
+
output = torch.cat([output, torch.full((output.shape[0], 1), threshold, device=output.device)], dim=1)
|
| 34 |
+
return topk_average_precision(output, y, 5).mean().detach()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def map_dict(output: torch.Tensor, y: torch.Tensor, prefix: str):
|
| 38 |
+
d = {f"{prefix}/acc": topk_average_precision(output, y, 1).mean().detach()}
|
| 39 |
+
for threshold in [None, 0.3, 0.4, 0.5, 0.6, 0.7]:
|
| 40 |
+
d[f"{prefix}/map{threshold}"] = calc_map5(output, y, threshold)
|
| 41 |
+
return d
|