added dinov2 weights
Browse files- checkpoints/dinov2.bin +3 -0
- script.py +45 -39
checkpoints/dinov2.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:78971dc00a0c488f2b2dff17d6dcb7ebe787af70a703d8212b38fc6a33dbcdd4
|
| 3 |
+
size 1217608166
|
script.py
CHANGED
|
@@ -4,12 +4,13 @@ import pandas as pd
|
|
| 4 |
import timm
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
| 7 |
-
import torch.nn.functional as F
|
| 8 |
-
import torchvision.transforms as T
|
| 9 |
from PIL import Image
|
| 10 |
from timm.models.metaformer import MlpHead
|
| 11 |
from torch.utils.data import DataLoader, Dataset
|
| 12 |
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
DIM = 518
|
| 15 |
DATE_SIZE = 4
|
|
@@ -99,11 +100,11 @@ SUBSTRATE = [
|
|
| 99 |
class ImageDataset(Dataset):
|
| 100 |
def __init__(self, df, local_filepath):
|
| 101 |
self.df = df
|
| 102 |
-
self.transform =
|
| 103 |
[
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
]
|
| 108 |
)
|
| 109 |
|
|
@@ -117,9 +118,10 @@ class ImageDataset(Dataset):
|
|
| 117 |
def __getitem__(self, idx):
|
| 118 |
image_path = os.path.join(self.local_filepath, self.filepaths[idx])
|
| 119 |
|
| 120 |
-
image =
|
|
|
|
| 121 |
|
| 122 |
-
return self.transform(image)
|
| 123 |
|
| 124 |
|
| 125 |
class EmbeddingMetadataDataset(Dataset):
|
|
@@ -270,11 +272,10 @@ class FungiMEEModel(nn.Module):
|
|
| 270 |
|
| 271 |
class FungiEnsembleModel(nn.Module):
|
| 272 |
|
| 273 |
-
def __init__(self, models
|
| 274 |
super().__init__()
|
| 275 |
|
| 276 |
self.models = nn.ModuleList()
|
| 277 |
-
self.softmax = softmax
|
| 278 |
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 279 |
|
| 280 |
for model in models:
|
|
@@ -291,12 +292,7 @@ class FungiEnsembleModel(nn.Module):
|
|
| 291 |
for model in self.models:
|
| 292 |
logits = model.forward(img_emb, metadata)
|
| 293 |
|
| 294 |
-
p = (
|
| 295 |
-
logits.softmax(dim=1).detach().cpu()
|
| 296 |
-
if self.softmax
|
| 297 |
-
else logits.detach().cpu()
|
| 298 |
-
)
|
| 299 |
-
|
| 300 |
probs.append(p)
|
| 301 |
|
| 302 |
return torch.stack(probs).mean(dim=0)
|
|
@@ -314,25 +310,32 @@ def make_submission(metadata_df):
|
|
| 314 |
OUTPUT_CSV_PATH = "./submission.csv"
|
| 315 |
BASE_CKPT_PATH = "./checkpoints"
|
| 316 |
|
| 317 |
-
|
|
|
|
| 318 |
|
| 319 |
-
models = []
|
| 320 |
|
| 321 |
-
for model_path in model_names:
|
| 322 |
-
|
| 323 |
-
|
| 324 |
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
|
| 333 |
-
|
| 334 |
|
| 335 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
|
| 337 |
embedding_dataset = EmbeddingMetadataDataset(metadata_df)
|
| 338 |
loader = DataLoader(embedding_dataset, batch_size=128, shuffle=False)
|
|
@@ -340,7 +343,7 @@ def make_submission(metadata_df):
|
|
| 340 |
preds = []
|
| 341 |
for data in tqdm(loader):
|
| 342 |
emb, metadata = data
|
| 343 |
-
pred =
|
| 344 |
preds.append(pred)
|
| 345 |
|
| 346 |
all_preds = torch.vstack(preds).numpy()
|
|
@@ -363,18 +366,21 @@ def make_submission(metadata_df):
|
|
| 363 |
|
| 364 |
if __name__ == "__main__":
|
| 365 |
|
| 366 |
-
|
| 367 |
-
|
|
|
|
|
|
|
|
|
|
| 368 |
|
| 369 |
-
with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as zip_ref:
|
| 370 |
-
|
| 371 |
|
| 372 |
-
metadata_file_path = "./_test_preprocessed.csv"
|
| 373 |
-
root_dir = "/tmp/data"
|
| 374 |
|
| 375 |
# Test submission
|
| 376 |
-
|
| 377 |
-
|
| 378 |
|
| 379 |
##############
|
| 380 |
|
|
|
|
| 4 |
import timm
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
|
|
|
|
|
|
| 7 |
from PIL import Image
|
| 8 |
from timm.models.metaformer import MlpHead
|
| 9 |
from torch.utils.data import DataLoader, Dataset
|
| 10 |
from tqdm import tqdm
|
| 11 |
+
from albumentations import Compose, Normalize, Resize
|
| 12 |
+
from albumentations.pytorch import ToTensorV2
|
| 13 |
+
import cv2
|
| 14 |
|
| 15 |
DIM = 518
|
| 16 |
DATE_SIZE = 4
|
|
|
|
| 100 |
class ImageDataset(Dataset):
|
| 101 |
def __init__(self, df, local_filepath):
|
| 102 |
self.df = df
|
| 103 |
+
self.transform = Compose(
|
| 104 |
[
|
| 105 |
+
Resize(DIM, DIM),
|
| 106 |
+
Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
| 107 |
+
ToTensorV2(),
|
| 108 |
]
|
| 109 |
)
|
| 110 |
|
|
|
|
| 118 |
def __getitem__(self, idx):
|
| 119 |
image_path = os.path.join(self.local_filepath, self.filepaths[idx])
|
| 120 |
|
| 121 |
+
image = cv2.imread(image_path)
|
| 122 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 123 |
|
| 124 |
+
return self.transform(image=image)['image']
|
| 125 |
|
| 126 |
|
| 127 |
class EmbeddingMetadataDataset(Dataset):
|
|
|
|
| 272 |
|
| 273 |
class FungiEnsembleModel(nn.Module):
|
| 274 |
|
| 275 |
+
def __init__(self, models) -> None:
|
| 276 |
super().__init__()
|
| 277 |
|
| 278 |
self.models = nn.ModuleList()
|
|
|
|
| 279 |
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 280 |
|
| 281 |
for model in models:
|
|
|
|
| 292 |
for model in self.models:
|
| 293 |
logits = model.forward(img_emb, metadata)
|
| 294 |
|
| 295 |
+
p = logits.softmax(dim=1).detach().cpu()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
probs.append(p)
|
| 297 |
|
| 298 |
return torch.stack(probs).mean(dim=0)
|
|
|
|
| 310 |
OUTPUT_CSV_PATH = "./submission.csv"
|
| 311 |
BASE_CKPT_PATH = "./checkpoints"
|
| 312 |
|
| 313 |
+
ckpt_path = os.path.join(BASE_CKPT_PATH, "dino_2_optuna_05242055.ckpt")
|
| 314 |
+
# model_names = os.listdir(BASE_CKPT_PATH)
|
| 315 |
|
| 316 |
+
# models = []
|
| 317 |
|
| 318 |
+
# for model_path in model_names:
|
| 319 |
+
# print("loading ", model_path)
|
| 320 |
+
# ckpt_path = os.path.join(BASE_CKPT_PATH, model_path)
|
| 321 |
|
| 322 |
+
# ckpt = torch.load(ckpt_path)
|
| 323 |
+
# model = FungiMEEModel()
|
| 324 |
+
# model.load_state_dict(
|
| 325 |
+
# {w: ckpt["model." + w] for w in model.state_dict().keys()}
|
| 326 |
+
# )
|
| 327 |
+
# model.eval()
|
| 328 |
+
# model.cuda()
|
| 329 |
|
| 330 |
+
# models.append(model)
|
| 331 |
|
| 332 |
+
# fungi_model = FungiEnsembleModel(models)
|
| 333 |
+
|
| 334 |
+
fungi_model = FungiMEEModel()
|
| 335 |
+
ckpt = torch.load(ckpt_path)
|
| 336 |
+
fungi_model.load_state_dict(
|
| 337 |
+
{w: ckpt["model." + w] for w in fungi_model.state_dict().keys()}
|
| 338 |
+
)
|
| 339 |
|
| 340 |
embedding_dataset = EmbeddingMetadataDataset(metadata_df)
|
| 341 |
loader = DataLoader(embedding_dataset, batch_size=128, shuffle=False)
|
|
|
|
| 343 |
preds = []
|
| 344 |
for data in tqdm(loader):
|
| 345 |
emb, metadata = data
|
| 346 |
+
pred = fungi_model.forward(emb, metadata)
|
| 347 |
preds.append(pred)
|
| 348 |
|
| 349 |
all_preds = torch.vstack(preds).numpy()
|
|
|
|
| 366 |
|
| 367 |
if __name__ == "__main__":
|
| 368 |
|
| 369 |
+
MODEL_PATH = "metaformer-s-224.pth"
|
| 370 |
+
MODEL_NAME = "timm/vit_base_patch14_reg4_dinov2.lvd142m"
|
| 371 |
+
|
| 372 |
+
# # # # Real submission
|
| 373 |
+
# import zipfile
|
| 374 |
|
| 375 |
+
# with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as zip_ref:
|
| 376 |
+
# zip_ref.extractall("/tmp/data")
|
| 377 |
|
| 378 |
+
# metadata_file_path = "./_test_preprocessed.csv"
|
| 379 |
+
# root_dir = "/tmp/data"
|
| 380 |
|
| 381 |
# Test submission
|
| 382 |
+
metadata_file_path = "../trial_submission.csv"
|
| 383 |
+
root_dir = "../data/DF_FULL"
|
| 384 |
|
| 385 |
##############
|
| 386 |
|