Upload 5 files
Browse files- args.json +26 -0
- events.out.tfevents.1776157992.gpu-pro6000-5.cluster02.eee.ntu.edu.sg.3637795.0 +3 -0
- explore_m2f_finetune_v2.py +416 -0
- log.txt +77 -0
- model_epoch=050.ckpt +3 -0
args.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"data_path": "/home/satoshi.tsutsui/projects/wbcas/dataset_txt/pbc_attr_v1_ccrop_all.csv",
|
| 3 |
+
"data_root": "/home/satoshi.tsutsui/satoshissd/PBC/pbcseg_final_v1/",
|
| 4 |
+
"model_name": "facebook/mask2former-swin-tiny-ade-semantic",
|
| 5 |
+
"resolution": 1024,
|
| 6 |
+
"out_resolution": 360,
|
| 7 |
+
"num_classes": 6,
|
| 8 |
+
"ignore_index": 0,
|
| 9 |
+
"no_flip": false,
|
| 10 |
+
"use_crop": false,
|
| 11 |
+
"use_color": true,
|
| 12 |
+
"freeze_encoder": false,
|
| 13 |
+
"freeze_decoder": false,
|
| 14 |
+
"lr": 3e-05,
|
| 15 |
+
"weight_decay": 0.01,
|
| 16 |
+
"epochs": 50,
|
| 17 |
+
"batch_size": 16,
|
| 18 |
+
"num_workers": 1,
|
| 19 |
+
"pflip": 0.5,
|
| 20 |
+
"grad_clip": 1.0,
|
| 21 |
+
"seed": 42,
|
| 22 |
+
"label_smoothing": 0.1,
|
| 23 |
+
"device": "cuda",
|
| 24 |
+
"save_dir": "./experiments",
|
| 25 |
+
"exp_name": "m2f_tiny_1024"
|
| 26 |
+
}
|
events.out.tfevents.1776157992.gpu-pro6000-5.cluster02.eee.ntu.edu.sg.3637795.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:890f04ef48de4fb90b670995166acb779950b3f5339e23c08585cbbc7cca0640
|
| 3 |
+
size 1083794
|
explore_m2f_finetune_v2.py
ADDED
|
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
import random
|
| 5 |
+
import argparse
|
| 6 |
+
import shutil
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import numpy as np
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 14 |
+
from torch.utils.data import Dataset, DataLoader
|
| 15 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 16 |
+
from torchvision.datasets.folder import default_loader
|
| 17 |
+
from torchvision.transforms import v2, RandomHorizontalFlip, RandomVerticalFlip, InterpolationMode
|
| 18 |
+
from tqdm.auto import tqdm
|
| 19 |
+
|
| 20 |
+
# Required for the real Mask2Former backbone
|
| 21 |
+
from transformers import Mask2FormerForUniversalSegmentation
|
| 22 |
+
|
| 23 |
+
# --- REPRODUCIBILITY ---
|
| 24 |
+
|
| 25 |
+
def set_seed(seed):
|
| 26 |
+
random.seed(seed)
|
| 27 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
| 28 |
+
np.random.seed(seed)
|
| 29 |
+
torch.manual_seed(seed)
|
| 30 |
+
torch.cuda.manual_seed(seed)
|
| 31 |
+
torch.cuda.manual_seed_all(seed)
|
| 32 |
+
torch.backends.cudnn.benchmark = False
|
| 33 |
+
torch.backends.cudnn.deterministic = True
|
| 34 |
+
|
| 35 |
+
# --- LOGGER UTILITY ---
|
| 36 |
+
|
| 37 |
+
class Logger(object):
|
| 38 |
+
def __init__(self, filename="log.txt"):
|
| 39 |
+
self.terminal = sys.stdout
|
| 40 |
+
self.log = open(filename, "a")
|
| 41 |
+
|
| 42 |
+
def write(self, message):
|
| 43 |
+
self.terminal.write(message)
|
| 44 |
+
self.log.write(message)
|
| 45 |
+
self.log.flush()
|
| 46 |
+
|
| 47 |
+
def flush(self):
|
| 48 |
+
self.terminal.flush()
|
| 49 |
+
self.log.flush()
|
| 50 |
+
|
| 51 |
+
# Add these two methods to fix compatibility error
|
| 52 |
+
def isatty(self):
|
| 53 |
+
return self.terminal.isatty()
|
| 54 |
+
|
| 55 |
+
def fileno(self):
|
| 56 |
+
return self.terminal.fileno()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_args():
|
| 60 |
+
parser = argparse.ArgumentParser(description="Mask2Former Fine-tuning")
|
| 61 |
+
|
| 62 |
+
# Paths
|
| 63 |
+
parser.add_argument("--data_path", type=str, default="/home/satoshi.tsutsui/projects/wbcas/dataset_txt/pbc_attr_v1_ccrop_all.csv")
|
| 64 |
+
parser.add_argument("--data_root", type=str, default="/home/satoshi.tsutsui/satoshissd/PBC/pbcseg_final_v1/")
|
| 65 |
+
|
| 66 |
+
# Model & Resolution
|
| 67 |
+
parser.add_argument("--model_name", type=str, default="facebook/mask2former-swin-tiny-ade-semantic",
|
| 68 |
+
help="Hugging Face model checkpoint")
|
| 69 |
+
parser.add_argument("--resolution", type=int, default=1024, help="Input image resolution")
|
| 70 |
+
parser.add_argument("--out_resolution", type=int, default=360, help="Output/Dataset resolution")
|
| 71 |
+
parser.add_argument("--num_classes", type=int, default=6, help="Number of target classes")
|
| 72 |
+
parser.add_argument("--ignore_index", type=int, default=0, help="Class index to ignore")
|
| 73 |
+
|
| 74 |
+
# Augmentation options (new)
|
| 75 |
+
parser.add_argument("--no_flip", action="store_true", help="Disable flips")
|
| 76 |
+
parser.add_argument("--use_crop", action="store_true", help="Enable random crop + resize")
|
| 77 |
+
parser.add_argument("--use_color", action="store_true", help="Enable color jitter")
|
| 78 |
+
|
| 79 |
+
# Freezing Options
|
| 80 |
+
parser.add_argument("--freeze_encoder", action="store_true", default=False)
|
| 81 |
+
parser.add_argument("--freeze_decoder", action="store_true", default=False)
|
| 82 |
+
|
| 83 |
+
# Training Hyperparameters
|
| 84 |
+
parser.add_argument("--lr", type=float, default=0.0001)
|
| 85 |
+
parser.add_argument("--weight_decay", type=float, default=0.01)
|
| 86 |
+
parser.add_argument("--epochs", type=int, default=50)
|
| 87 |
+
parser.add_argument("--batch_size", type=int, default=16)
|
| 88 |
+
parser.add_argument("--num_workers", type=int, default=1)
|
| 89 |
+
parser.add_argument("--pflip", type=float, default=0.5)
|
| 90 |
+
parser.add_argument("--grad_clip", type=float, default=1.0)
|
| 91 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 92 |
+
parser.add_argument("--label_smoothing", type=float, default=0.1)
|
| 93 |
+
|
| 94 |
+
# Paths & Device
|
| 95 |
+
parser.add_argument("--device", type=str, default="cuda")
|
| 96 |
+
parser.add_argument("--save_dir", type=str, default="./experiments")
|
| 97 |
+
parser.add_argument("--exp_name", type=str, default="m2f_finetune")
|
| 98 |
+
|
| 99 |
+
return parser.parse_args()
|
| 100 |
+
|
| 101 |
+
# --- MASK2FORMER WRAPPER ---
|
| 102 |
+
|
| 103 |
+
class Mask2FormerWrapper(nn.Module):
|
| 104 |
+
def __init__(self, model_name, num_classes, out_resolution):
|
| 105 |
+
super().__init__()
|
| 106 |
+
self.model = Mask2FormerForUniversalSegmentation.from_pretrained(
|
| 107 |
+
model_name,
|
| 108 |
+
num_labels=num_classes,
|
| 109 |
+
ignore_mismatched_sizes=True
|
| 110 |
+
)
|
| 111 |
+
self.out_resolution = out_resolution
|
| 112 |
+
self.num_classes = num_classes
|
| 113 |
+
|
| 114 |
+
def forward(self, images):
|
| 115 |
+
outputs = self.model(pixel_values=images)
|
| 116 |
+
cls_logits = outputs.class_queries_logits
|
| 117 |
+
mask_logits = outputs.masks_queries_logits
|
| 118 |
+
|
| 119 |
+
cls_probs = F.softmax(cls_logits, dim=-1)
|
| 120 |
+
mask_probs = torch.sigmoid(mask_logits)
|
| 121 |
+
|
| 122 |
+
b, q, h_small, w_small = mask_probs.shape
|
| 123 |
+
mask_probs_flat = mask_probs.view(b, q, h_small * w_small)
|
| 124 |
+
|
| 125 |
+
# Reconstruct semantic map from queries
|
| 126 |
+
semantic_map = torch.bmm(cls_probs[:, :, :self.num_classes].transpose(1, 2), mask_probs_flat)
|
| 127 |
+
semantic_map = semantic_map.view(b, self.num_classes, h_small, w_small)
|
| 128 |
+
|
| 129 |
+
# Resize to 360x360 for loss and metrics
|
| 130 |
+
return F.interpolate(semantic_map, size=(self.out_resolution, self.out_resolution),
|
| 131 |
+
mode="bilinear", align_corners=False)
|
| 132 |
+
|
| 133 |
+
# --- DATASET ---
|
| 134 |
+
|
| 135 |
+
class SegDataset(Dataset):
|
| 136 |
+
def __init__(self, df, img_col="img_path", mask_col="mask_path",
|
| 137 |
+
backbone_res=512, # ← NEW: input resolution
|
| 138 |
+
transform=None, pflip=0.0, flip=True, crop=False, color=False):
|
| 139 |
+
self.df = df
|
| 140 |
+
self.img_col = img_col
|
| 141 |
+
self.mask_col = mask_col
|
| 142 |
+
|
| 143 |
+
# Augmentation flags
|
| 144 |
+
self.flip = flip and pflip > 0
|
| 145 |
+
self.crop = crop
|
| 146 |
+
self.color = color
|
| 147 |
+
self.backbone_res = backbone_res # store for resizing
|
| 148 |
+
|
| 149 |
+
# Flip transforms (synced)
|
| 150 |
+
if flip and pflip > 0:
|
| 151 |
+
self.flip_transforms = v2.Compose([
|
| 152 |
+
RandomHorizontalFlip(p=0.5),
|
| 153 |
+
RandomVerticalFlip(p=0.5)
|
| 154 |
+
])
|
| 155 |
+
else:
|
| 156 |
+
self.flip_transforms = lambda x: x
|
| 157 |
+
|
| 158 |
+
def __len__(self):
|
| 159 |
+
return len(self.df)
|
| 160 |
+
|
| 161 |
+
# 🔥 Custom random resized crop → always outputs to out_resolution (360)
|
| 162 |
+
def random_resized_crop(self, img, mask,
|
| 163 |
+
scale=(0.4, 1.0), ratio=(0.75, 1.33),
|
| 164 |
+
out_size=360):
|
| 165 |
+
_, h, w = img.shape
|
| 166 |
+
area = h * w
|
| 167 |
+
|
| 168 |
+
for _ in range(10):
|
| 169 |
+
target_area = random.uniform(*scale) * area
|
| 170 |
+
aspect_ratio = random.uniform(*ratio)
|
| 171 |
+
|
| 172 |
+
new_w = int(round((target_area * aspect_ratio) ** 0.5))
|
| 173 |
+
new_h = int(round((target_area / aspect_ratio) ** 0.5))
|
| 174 |
+
|
| 175 |
+
if new_w <= w and new_h <= h:
|
| 176 |
+
top = random.randint(0, h - new_h)
|
| 177 |
+
left = random.randint(0, w - new_w)
|
| 178 |
+
|
| 179 |
+
img_crop = v2.functional.crop(img, top, left, new_h, new_w)
|
| 180 |
+
mask_crop = v2.functional.crop(mask, top, left, new_h, new_w)
|
| 181 |
+
|
| 182 |
+
img_resized = v2.functional.resize(
|
| 183 |
+
img_crop,
|
| 184 |
+
(out_size, out_size),
|
| 185 |
+
interpolation=InterpolationMode.BILINEAR,
|
| 186 |
+
antialias=True
|
| 187 |
+
)
|
| 188 |
+
mask_resized = v2.functional.resize(
|
| 189 |
+
mask_crop,
|
| 190 |
+
(out_size, out_size),
|
| 191 |
+
interpolation=InterpolationMode.NEAREST_EXACT
|
| 192 |
+
)
|
| 193 |
+
return img_resized, mask_resized
|
| 194 |
+
|
| 195 |
+
# fallback center crop
|
| 196 |
+
min_side = min(h, w)
|
| 197 |
+
top = (h - min_side) // 2
|
| 198 |
+
left = (w - min_side) // 2
|
| 199 |
+
|
| 200 |
+
img_crop = v2.functional.crop(img, top, left, min_side, min_side)
|
| 201 |
+
mask_crop = v2.functional.crop(mask, top, left, min_side, min_side)
|
| 202 |
+
|
| 203 |
+
img_resized = v2.functional.resize(
|
| 204 |
+
img_crop,
|
| 205 |
+
(out_size, out_size),
|
| 206 |
+
interpolation=InterpolationMode.BILINEAR,
|
| 207 |
+
antialias=True
|
| 208 |
+
)
|
| 209 |
+
mask_resized = v2.functional.resize(
|
| 210 |
+
mask_crop,
|
| 211 |
+
(out_size, out_size),
|
| 212 |
+
interpolation=InterpolationMode.NEAREST_EXACT
|
| 213 |
+
)
|
| 214 |
+
return img_resized, mask_resized
|
| 215 |
+
|
| 216 |
+
def __getitem__(self, idx):
|
| 217 |
+
img = v2.functional.to_image(default_loader(self.df.iloc[idx][self.img_col]))
|
| 218 |
+
mask = v2.functional.to_image(default_loader(self.df.iloc[idx][self.mask_col]))
|
| 219 |
+
|
| 220 |
+
# Sync flip
|
| 221 |
+
state = torch.get_rng_state()
|
| 222 |
+
img = self.flip_transforms(img)
|
| 223 |
+
torch.set_rng_state(state)
|
| 224 |
+
mask = self.flip_transforms(mask)
|
| 225 |
+
|
| 226 |
+
# Custom crop (output = 360)
|
| 227 |
+
if self.crop:
|
| 228 |
+
img, mask = self.random_resized_crop(img, mask)
|
| 229 |
+
|
| 230 |
+
# Color jitter
|
| 231 |
+
if self.color:
|
| 232 |
+
img = v2.ColorJitter(brightness=0.2, contrast=0.2)(img)
|
| 233 |
+
|
| 234 |
+
# Normalize & dtype
|
| 235 |
+
img = v2.functional.to_dtype(img, torch.float32, scale=True)
|
| 236 |
+
img = v2.functional.normalize(
|
| 237 |
+
img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
# ✅ FIXED: use configurable backbone_res instead of hardcoded 512
|
| 241 |
+
img = v2.functional.resize(img, (self.backbone_res, self.backbone_res), antialias=True)
|
| 242 |
+
|
| 243 |
+
mask = mask.long()[0]
|
| 244 |
+
return {"input": img, "mask": mask}
|
| 245 |
+
|
| 246 |
+
# --- METRICS HELPERS ---
|
| 247 |
+
|
| 248 |
+
def compute_conf_matrix(pred, target, num_classes):
|
| 249 |
+
mask = (target >= 0) & (target < num_classes)
|
| 250 |
+
return torch.bincount(
|
| 251 |
+
num_classes * target[mask].view(-1) + pred[mask].view(-1),
|
| 252 |
+
minlength=num_classes**2
|
| 253 |
+
).reshape(num_classes, num_classes)
|
| 254 |
+
|
| 255 |
+
def calculate_metrics(conf_matrix, ignore_index=None):
|
| 256 |
+
ious = []
|
| 257 |
+
conf_matrix = conf_matrix.float()
|
| 258 |
+
num_classes = conf_matrix.shape[0]
|
| 259 |
+
for i in range(num_classes):
|
| 260 |
+
tp = conf_matrix[i, i]
|
| 261 |
+
fp = conf_matrix[:, i].sum() - tp
|
| 262 |
+
fn = conf_matrix[i, :].sum() - tp
|
| 263 |
+
denom = tp + fp + fn
|
| 264 |
+
iou = tp / denom if denom > 0 else torch.tensor(float('nan'))
|
| 265 |
+
ious.append(iou.item())
|
| 266 |
+
relevant_ious = [iou for i, iou in enumerate(ious) if i != ignore_index and not np.isnan(iou)]
|
| 267 |
+
miou = np.mean(relevant_ious) if relevant_ious else 0
|
| 268 |
+
return miou, ious
|
| 269 |
+
|
| 270 |
+
def validate(model, loader, criterion, device, num_classes, ignore_index, stage="val"):
|
| 271 |
+
model.eval()
|
| 272 |
+
total_loss, conf_matrix = 0, torch.zeros(num_classes, num_classes, device=device)
|
| 273 |
+
with torch.no_grad():
|
| 274 |
+
for item in tqdm(loader, desc=f"evaluating_{stage}", leave=False):
|
| 275 |
+
images, masks = item['input'].to(device), item['mask'].to(device).long()
|
| 276 |
+
with torch.autocast(device, dtype=torch.bfloat16):
|
| 277 |
+
outputs = model(images)
|
| 278 |
+
loss = criterion(outputs, masks)
|
| 279 |
+
total_loss += loss.item()
|
| 280 |
+
conf_matrix += compute_conf_matrix(torch.argmax(outputs, dim=1), masks, num_classes)
|
| 281 |
+
avg_loss = total_loss / len(loader)
|
| 282 |
+
miou, class_ious = calculate_metrics(conf_matrix, ignore_index=ignore_index)
|
| 283 |
+
return avg_loss, miou, class_ious
|
| 284 |
+
|
| 285 |
+
# --- MAIN EXECUTION ---
|
| 286 |
+
|
| 287 |
+
if __name__ == "__main__":
|
| 288 |
+
args = get_args()
|
| 289 |
+
set_seed(args.seed)
|
| 290 |
+
|
| 291 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 292 |
+
|
| 293 |
+
# --- Folder naming based on augmentations ---
|
| 294 |
+
aug_suffix = ""
|
| 295 |
+
if args.use_crop: aug_suffix += "_crop"
|
| 296 |
+
if args.use_color: aug_suffix += "_color"
|
| 297 |
+
if args.no_flip: aug_suffix += "_noflip"
|
| 298 |
+
|
| 299 |
+
exp_name = f"{args.exp_name}{aug_suffix}"
|
| 300 |
+
|
| 301 |
+
run_dir = os.path.join(args.save_dir, f"{exp_name}_{timestamp}")
|
| 302 |
+
os.makedirs(run_dir, exist_ok=True)
|
| 303 |
+
shutil.copy(__file__, os.path.join(run_dir, os.path.basename(__file__)))
|
| 304 |
+
sys.stdout = Logger(os.path.join(run_dir, "log.txt"))
|
| 305 |
+
|
| 306 |
+
print(f"--- Experiment: {exp_name} ---")
|
| 307 |
+
print(f"Arguments: {json.dumps(vars(args), indent=4)}")
|
| 308 |
+
|
| 309 |
+
with open(os.path.join(run_dir, "args.json"), "w") as f:
|
| 310 |
+
json.dump(vars(args), f, indent=4)
|
| 311 |
+
writer = SummaryWriter(log_dir=run_dir)
|
| 312 |
+
|
| 313 |
+
df = pd.read_csv(args.data_path)
|
| 314 |
+
df['img_path'] = args.data_root + df['img_name']
|
| 315 |
+
df['mask_path'] = df['img_path'].apply(lambda x: x.replace(".jpg", "_mask.png"))
|
| 316 |
+
|
| 317 |
+
# Build transforms for input images (note: resized to 512 as before)
|
| 318 |
+
transform = v2.Compose([
|
| 319 |
+
v2.ToImage(),
|
| 320 |
+
v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 321 |
+
])
|
| 322 |
+
|
| 323 |
+
seg_model = Mask2FormerWrapper(args.model_name, args.num_classes, args.out_resolution).to(args.device)
|
| 324 |
+
|
| 325 |
+
if args.freeze_encoder:
|
| 326 |
+
for p in seg_model.model.model.backbone.parameters(): p.requires_grad = False
|
| 327 |
+
if args.freeze_decoder:
|
| 328 |
+
for p in seg_model.model.model.pixel_decoder.parameters(): p.requires_grad = False
|
| 329 |
+
for p in seg_model.model.model.transformer_module.parameters(): p.requires_grad = False
|
| 330 |
+
|
| 331 |
+
trainable_params = [p for p in seg_model.parameters() if p.requires_grad]
|
| 332 |
+
optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=args.weight_decay)
|
| 333 |
+
|
| 334 |
+
criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) # simplified per request
|
| 335 |
+
|
| 336 |
+
scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=min(1e-6, args.lr / 100))
|
| 337 |
+
|
| 338 |
+
# Pass augmentation flags to dataset
|
| 339 |
+
train_loader = DataLoader(
|
| 340 |
+
SegDataset(df[df['split']=="train"],
|
| 341 |
+
backbone_res=args.resolution, # ← use CLI arg
|
| 342 |
+
pflip=args.pflip if not args.no_flip else 0.0,
|
| 343 |
+
flip=not args.no_flip,
|
| 344 |
+
crop=args.use_crop, color=args.use_color),
|
| 345 |
+
batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
|
| 346 |
+
pin_memory=True
|
| 347 |
+
)
|
| 348 |
+
val_loader = DataLoader(
|
| 349 |
+
SegDataset(df[df['split']=="val"],
|
| 350 |
+
backbone_res=args.resolution,
|
| 351 |
+
pflip=0, flip=False, crop=False, color=False),
|
| 352 |
+
batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers
|
| 353 |
+
)
|
| 354 |
+
test_loader = DataLoader(
|
| 355 |
+
SegDataset(df[df['split']=="test"],
|
| 356 |
+
backbone_res=args.resolution,
|
| 357 |
+
pflip=0, flip=False, crop=False, color=False),
|
| 358 |
+
batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
global_step = 0
|
| 362 |
+
for epoch in range(args.epochs):
|
| 363 |
+
seg_model.train()
|
| 364 |
+
pbar = tqdm(train_loader, desc=f"epoch {epoch+1}/{args.epochs}")
|
| 365 |
+
epoch_loss = 0
|
| 366 |
+
|
| 367 |
+
for item in pbar:
|
| 368 |
+
images, masks = item['input'].to(args.device), item['mask'].to(args.device).long()
|
| 369 |
+
optimizer.zero_grad()
|
| 370 |
+
|
| 371 |
+
with torch.autocast(args.device, dtype=torch.bfloat16):
|
| 372 |
+
loss = criterion(seg_model(images), masks)
|
| 373 |
+
|
| 374 |
+
loss.backward()
|
| 375 |
+
if args.grad_clip > 0:
|
| 376 |
+
torch.nn.utils.clip_grad_norm_(trainable_params, args.grad_clip)
|
| 377 |
+
optimizer.step()
|
| 378 |
+
epoch_loss += loss.item()
|
| 379 |
+
|
| 380 |
+
writer.add_scalar("loss_train_step", loss.item(), global_step)
|
| 381 |
+
global_step += 1
|
| 382 |
+
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
|
| 383 |
+
|
| 384 |
+
scheduler.step()
|
| 385 |
+
avg_train_loss = epoch_loss / len(train_loader)
|
| 386 |
+
|
| 387 |
+
val_loss, val_miou, val_ious = validate(seg_model, val_loader, criterion,
|
| 388 |
+
args.device, args.num_classes, args.ignore_index, "val")
|
| 389 |
+
test_loss, test_miou, test_ious = validate(seg_model, test_loader, criterion,
|
| 390 |
+
args.device, args.num_classes, args.ignore_index, "test")
|
| 391 |
+
|
| 392 |
+
writer.add_scalar("loss_train_epoch", avg_train_loss, epoch)
|
| 393 |
+
writer.add_scalar("loss_val", val_loss, epoch)
|
| 394 |
+
writer.add_scalar("loss_test", test_loss, epoch)
|
| 395 |
+
writer.add_scalar("miou_val", val_miou, epoch)
|
| 396 |
+
writer.add_scalar("miou_test", test_miou, epoch)
|
| 397 |
+
|
| 398 |
+
for i, iou in enumerate(val_ious):
|
| 399 |
+
if i != args.ignore_index:
|
| 400 |
+
writer.add_scalar(f"iou_val_class_{i}", iou, epoch)
|
| 401 |
+
for i, iou in enumerate(test_ious):
|
| 402 |
+
if i != args.ignore_index:
|
| 403 |
+
writer.add_scalar(f"iou_test_class_{i}", iou, epoch)
|
| 404 |
+
|
| 405 |
+
log_msg = (f"Epoch {epoch+1:03d} | Train Loss: {avg_train_loss:.4f} | "
|
| 406 |
+
f"Val Loss: {val_loss:.4f} | Val mIoU: {val_miou:.4f} | "
|
| 407 |
+
f"Test mIoU: {test_miou:.4f}")
|
| 408 |
+
print(log_msg)
|
| 409 |
+
|
| 410 |
+
torch.save({
|
| 411 |
+
'epoch': epoch + 1,
|
| 412 |
+
'model_state_dict': seg_model.state_dict(),
|
| 413 |
+
'val_miou': val_miou
|
| 414 |
+
}, os.path.join(run_dir, f"model_epoch={epoch+1:03d}.ckpt"))
|
| 415 |
+
|
| 416 |
+
writer.close()
|
log.txt
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--- Experiment: m2f_tiny_1024_color ---
|
| 2 |
+
Arguments: {
|
| 3 |
+
"data_path": "/home/satoshi.tsutsui/projects/wbcas/dataset_txt/pbc_attr_v1_ccrop_all.csv",
|
| 4 |
+
"data_root": "/home/satoshi.tsutsui/satoshissd/PBC/pbcseg_final_v1/",
|
| 5 |
+
"model_name": "facebook/mask2former-swin-tiny-ade-semantic",
|
| 6 |
+
"resolution": 1024,
|
| 7 |
+
"out_resolution": 360,
|
| 8 |
+
"num_classes": 6,
|
| 9 |
+
"ignore_index": 0,
|
| 10 |
+
"no_flip": false,
|
| 11 |
+
"use_crop": false,
|
| 12 |
+
"use_color": true,
|
| 13 |
+
"freeze_encoder": false,
|
| 14 |
+
"freeze_decoder": false,
|
| 15 |
+
"lr": 3e-05,
|
| 16 |
+
"weight_decay": 0.01,
|
| 17 |
+
"epochs": 50,
|
| 18 |
+
"batch_size": 16,
|
| 19 |
+
"num_workers": 1,
|
| 20 |
+
"pflip": 0.5,
|
| 21 |
+
"grad_clip": 1.0,
|
| 22 |
+
"seed": 42,
|
| 23 |
+
"label_smoothing": 0.1,
|
| 24 |
+
"device": "cuda",
|
| 25 |
+
"save_dir": "./experiments",
|
| 26 |
+
"exp_name": "m2f_tiny_1024"
|
| 27 |
+
}
|
| 28 |
+
Epoch 001 | Train Loss: 0.4798 | Val Loss: 0.4378 | Val mIoU: 0.8971 | Test mIoU: 0.9000
|
| 29 |
+
Epoch 002 | Train Loss: 0.4374 | Val Loss: 0.4356 | Val mIoU: 0.9088 | Test mIoU: 0.9121
|
| 30 |
+
Epoch 003 | Train Loss: 0.4356 | Val Loss: 0.4344 | Val mIoU: 0.9175 | Test mIoU: 0.9201
|
| 31 |
+
Epoch 004 | Train Loss: 0.4348 | Val Loss: 0.4342 | Val mIoU: 0.9199 | Test mIoU: 0.9216
|
| 32 |
+
Epoch 005 | Train Loss: 0.4342 | Val Loss: 0.4338 | Val mIoU: 0.9184 | Test mIoU: 0.9247
|
| 33 |
+
Epoch 006 | Train Loss: 0.4339 | Val Loss: 0.4336 | Val mIoU: 0.9204 | Test mIoU: 0.9260
|
| 34 |
+
Epoch 007 | Train Loss: 0.4337 | Val Loss: 0.4338 | Val mIoU: 0.9193 | Test mIoU: 0.9224
|
| 35 |
+
Epoch 008 | Train Loss: 0.4335 | Val Loss: 0.4336 | Val mIoU: 0.9273 | Test mIoU: 0.9301
|
| 36 |
+
Epoch 009 | Train Loss: 0.4332 | Val Loss: 0.4335 | Val mIoU: 0.9222 | Test mIoU: 0.9295
|
| 37 |
+
Epoch 010 | Train Loss: 0.4332 | Val Loss: 0.4329 | Val mIoU: 0.9312 | Test mIoU: 0.9312
|
| 38 |
+
Epoch 011 | Train Loss: 0.4328 | Val Loss: 0.4328 | Val mIoU: 0.9318 | Test mIoU: 0.9320
|
| 39 |
+
Epoch 012 | Train Loss: 0.4329 | Val Loss: 0.4327 | Val mIoU: 0.9282 | Test mIoU: 0.9309
|
| 40 |
+
Epoch 013 | Train Loss: 0.4325 | Val Loss: 0.4328 | Val mIoU: 0.9276 | Test mIoU: 0.9299
|
| 41 |
+
Epoch 014 | Train Loss: 0.4324 | Val Loss: 0.4327 | Val mIoU: 0.9300 | Test mIoU: 0.9334
|
| 42 |
+
Epoch 015 | Train Loss: 0.4326 | Val Loss: 0.4329 | Val mIoU: 0.9306 | Test mIoU: 0.9332
|
| 43 |
+
Epoch 016 | Train Loss: 0.4323 | Val Loss: 0.4326 | Val mIoU: 0.9289 | Test mIoU: 0.9303
|
| 44 |
+
Epoch 017 | Train Loss: 0.4321 | Val Loss: 0.4324 | Val mIoU: 0.9300 | Test mIoU: 0.9297
|
| 45 |
+
Epoch 018 | Train Loss: 0.4321 | Val Loss: 0.4326 | Val mIoU: 0.9331 | Test mIoU: 0.9330
|
| 46 |
+
Epoch 019 | Train Loss: 0.4321 | Val Loss: 0.4325 | Val mIoU: 0.9343 | Test mIoU: 0.9358
|
| 47 |
+
Epoch 020 | Train Loss: 0.4318 | Val Loss: 0.4324 | Val mIoU: 0.9350 | Test mIoU: 0.9352
|
| 48 |
+
Epoch 021 | Train Loss: 0.4318 | Val Loss: 0.4322 | Val mIoU: 0.9347 | Test mIoU: 0.9345
|
| 49 |
+
Epoch 022 | Train Loss: 0.4317 | Val Loss: 0.4322 | Val mIoU: 0.9357 | Test mIoU: 0.9357
|
| 50 |
+
Epoch 023 | Train Loss: 0.4317 | Val Loss: 0.4323 | Val mIoU: 0.9331 | Test mIoU: 0.9339
|
| 51 |
+
Epoch 024 | Train Loss: 0.4316 | Val Loss: 0.4322 | Val mIoU: 0.9355 | Test mIoU: 0.9354
|
| 52 |
+
Epoch 025 | Train Loss: 0.4315 | Val Loss: 0.4323 | Val mIoU: 0.9344 | Test mIoU: 0.9364
|
| 53 |
+
Epoch 026 | Train Loss: 0.4314 | Val Loss: 0.4322 | Val mIoU: 0.9352 | Test mIoU: 0.9353
|
| 54 |
+
Epoch 027 | Train Loss: 0.4313 | Val Loss: 0.4321 | Val mIoU: 0.9375 | Test mIoU: 0.9359
|
| 55 |
+
Epoch 028 | Train Loss: 0.4313 | Val Loss: 0.4321 | Val mIoU: 0.9373 | Test mIoU: 0.9370
|
| 56 |
+
Epoch 029 | Train Loss: 0.4312 | Val Loss: 0.4321 | Val mIoU: 0.9366 | Test mIoU: 0.9359
|
| 57 |
+
Epoch 030 | Train Loss: 0.4312 | Val Loss: 0.4321 | Val mIoU: 0.9379 | Test mIoU: 0.9368
|
| 58 |
+
Epoch 031 | Train Loss: 0.4311 | Val Loss: 0.4321 | Val mIoU: 0.9376 | Test mIoU: 0.9378
|
| 59 |
+
Epoch 032 | Train Loss: 0.4311 | Val Loss: 0.4319 | Val mIoU: 0.9392 | Test mIoU: 0.9373
|
| 60 |
+
Epoch 033 | Train Loss: 0.4310 | Val Loss: 0.4321 | Val mIoU: 0.9378 | Test mIoU: 0.9362
|
| 61 |
+
Epoch 034 | Train Loss: 0.4310 | Val Loss: 0.4321 | Val mIoU: 0.9384 | Test mIoU: 0.9353
|
| 62 |
+
Epoch 035 | Train Loss: 0.4309 | Val Loss: 0.4320 | Val mIoU: 0.9386 | Test mIoU: 0.9368
|
| 63 |
+
Epoch 036 | Train Loss: 0.4309 | Val Loss: 0.4320 | Val mIoU: 0.9377 | Test mIoU: 0.9369
|
| 64 |
+
Epoch 037 | Train Loss: 0.4308 | Val Loss: 0.4320 | Val mIoU: 0.9372 | Test mIoU: 0.9372
|
| 65 |
+
Epoch 038 | Train Loss: 0.4308 | Val Loss: 0.4320 | Val mIoU: 0.9383 | Test mIoU: 0.9365
|
| 66 |
+
Epoch 039 | Train Loss: 0.4308 | Val Loss: 0.4319 | Val mIoU: 0.9386 | Test mIoU: 0.9369
|
| 67 |
+
Epoch 040 | Train Loss: 0.4307 | Val Loss: 0.4319 | Val mIoU: 0.9379 | Test mIoU: 0.9368
|
| 68 |
+
Epoch 041 | Train Loss: 0.4307 | Val Loss: 0.4319 | Val mIoU: 0.9377 | Test mIoU: 0.9365
|
| 69 |
+
Epoch 042 | Train Loss: 0.4307 | Val Loss: 0.4319 | Val mIoU: 0.9378 | Test mIoU: 0.9367
|
| 70 |
+
Epoch 043 | Train Loss: 0.4307 | Val Loss: 0.4319 | Val mIoU: 0.9380 | Test mIoU: 0.9368
|
| 71 |
+
Epoch 044 | Train Loss: 0.4307 | Val Loss: 0.4319 | Val mIoU: 0.9385 | Test mIoU: 0.9369
|
| 72 |
+
Epoch 045 | Train Loss: 0.4306 | Val Loss: 0.4319 | Val mIoU: 0.9383 | Test mIoU: 0.9369
|
| 73 |
+
Epoch 046 | Train Loss: 0.4306 | Val Loss: 0.4319 | Val mIoU: 0.9385 | Test mIoU: 0.9365
|
| 74 |
+
Epoch 047 | Train Loss: 0.4306 | Val Loss: 0.4319 | Val mIoU: 0.9385 | Test mIoU: 0.9367
|
| 75 |
+
Epoch 048 | Train Loss: 0.4306 | Val Loss: 0.4319 | Val mIoU: 0.9386 | Test mIoU: 0.9369
|
| 76 |
+
Epoch 049 | Train Loss: 0.4306 | Val Loss: 0.4319 | Val mIoU: 0.9383 | Test mIoU: 0.9368
|
| 77 |
+
Epoch 050 | Train Loss: 0.4306 | Val Loss: 0.4319 | Val mIoU: 0.9385 | Test mIoU: 0.9367
|
model_epoch=050.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3cc444616f5e41a215b17ae5a1bab76de7dc124d8407a08124f9eba4c02297aa
|
| 3 |
+
size 190100915
|