# wbcsegmentor_m2f_tiny / explore_m2f_finetune_v2.py
# Uploaded by apple2373 ("Upload 5 files", commit 3e1d5f8, verified)
import os
import sys
import json
import random
import argparse
import shutil
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from datetime import datetime
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets.folder import default_loader
from torchvision.transforms import v2, RandomHorizontalFlip, RandomVerticalFlip, InterpolationMode
from tqdm.auto import tqdm
# Required for the real Mask2Former backbone
from transformers import Mask2FormerForUniversalSegmentation
# --- REPRODUCIBILITY ---
def set_seed(seed):
    """Seed every RNG this script touches and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU, current CUDA
    device and all CUDA devices). ``PYTHONHASHSEED`` is exported too; it does
    not change hashing in the already-running interpreter, only in processes
    started afterwards.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    # Trade autotuned speed for run-to-run reproducibility.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True
# --- LOGGER UTILITY ---
class Logger(object):
    """Tee text written to it into both the original stdout and a log file.

    Meant to be installed as ``sys.stdout`` so every ``print`` is shown on the
    terminal and appended to ``filename``.

    Args:
        filename: Log file path; opened in append mode with UTF-8 encoding.
    """

    def __init__(self, filename="log.txt"):
        self.terminal = sys.stdout
        # Explicit UTF-8 so non-ASCII output does not depend on the locale.
        self.log = open(filename, "a", encoding="utf-8")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
        # Flush on every write so the log is complete even after a crash.
        self.log.flush()

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    # Some libraries probe stdout for TTY support or a file descriptor.
    def isatty(self):
        return self.terminal.isatty()

    def fileno(self):
        return self.terminal.fileno()

    def close(self):
        """Close the log file; the wrapped terminal stream is left open.

        Safe to call more than once. Writing after ``close`` raises
        ``ValueError`` from the closed file.
        """
        if not self.log.closed:
            self.log.close()

    def __getattr__(self, name):
        # Called only for attributes not found normally: delegate anything
        # else code may read from sys.stdout (``encoding``, ``errors``,
        # ``buffer``, ...) to the wrapped stream. Read ``terminal`` via
        # __dict__ so a half-initialised instance cannot recurse forever.
        terminal = self.__dict__.get("terminal")
        if terminal is None:
            raise AttributeError(name)
        return getattr(terminal, name)
def get_args(argv=None):
    """Parse the command-line options for Mask2Former fine-tuning.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing a list lets the parser
            be used from notebooks and tests.

    Returns:
        ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Mask2Former Fine-tuning")
    # Paths
    parser.add_argument("--data_path", type=str, default="/home/satoshi.tsutsui/projects/wbcas/dataset_txt/pbc_attr_v1_ccrop_all.csv")
    parser.add_argument("--data_root", type=str, default="/home/satoshi.tsutsui/satoshissd/PBC/pbcseg_final_v1/")
    # Model & Resolution
    parser.add_argument("--model_name", type=str, default="facebook/mask2former-swin-tiny-ade-semantic",
                        help="Hugging Face model checkpoint")
    parser.add_argument("--resolution", type=int, default=1024, help="Input image resolution")
    parser.add_argument("--out_resolution", type=int, default=360, help="Output/Dataset resolution")
    parser.add_argument("--num_classes", type=int, default=6, help="Number of target classes")
    # NOTE: the training script only uses this to exclude a class from the
    # mIoU / per-class logging; the cross-entropy loss does not ignore it.
    parser.add_argument("--ignore_index", type=int, default=0, help="Class index to ignore")
    # Augmentation options
    parser.add_argument("--no_flip", action="store_true", help="Disable flips")
    parser.add_argument("--use_crop", action="store_true", help="Enable random crop + resize")
    parser.add_argument("--use_color", action="store_true", help="Enable color jitter")
    # Freezing Options
    parser.add_argument("--freeze_encoder", action="store_true", default=False)
    parser.add_argument("--freeze_decoder", action="store_true", default=False)
    # Training Hyperparameters
    parser.add_argument("--lr", type=float, default=0.0001)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--pflip", type=float, default=0.5)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--label_smoothing", type=float, default=0.1)
    # Paths & Device
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--save_dir", type=str, default="./experiments")
    parser.add_argument("--exp_name", type=str, default="m2f_finetune")
    return parser.parse_args(argv)
# --- MASK2FORMER WRAPPER ---
class Mask2FormerWrapper(nn.Module):
    """Hugging Face Mask2Former exposed as a dense per-class score map.

    The query-based outputs (per-query class logits and per-query mask logits)
    are merged into a ``(B, num_classes, out_resolution, out_resolution)``
    tensor so the model can be trained with a plain pixel-wise loss.

    Args:
        model_name: Hugging Face checkpoint name or local path.
        num_classes: Number of target classes, passed as ``num_labels``;
            pretrained weights whose shape no longer matches are skipped
            (``ignore_mismatched_sizes=True``).
        out_resolution: Side length of the square output map.
    """

    def __init__(self, model_name, num_classes, out_resolution):
        super().__init__()
        self.model = Mask2FormerForUniversalSegmentation.from_pretrained(
            model_name,
            num_labels=num_classes,
            ignore_mismatched_sizes=True
        )
        self.out_resolution = out_resolution
        self.num_classes = num_classes

    def forward(self, images):
        """Return per-class scores resized to ``out_resolution``.

        Args:
            images: Batch passed to the model as ``pixel_values``
                (presumably ``(B, 3, H, W)`` normalised images — confirm
                against the data loader).

        Returns:
            Tensor ``(B, num_classes, out_resolution, out_resolution)``; for
            each class, the sum over queries of
            ``softmax(class_logits)[class] * sigmoid(mask_logit)``. These are
            non-negative scores, not log-probabilities.
        """
        outputs = self.model(pixel_values=images)
        cls_logits = outputs.class_queries_logits
        mask_logits = outputs.masks_queries_logits

        cls_probs = F.softmax(cls_logits, dim=-1)
        mask_probs = torch.sigmoid(mask_logits)
        b, _, h_small, w_small = mask_probs.shape  # (B, Q, h, w)

        # flatten() works whether or not the mask tensor is contiguous
        # (view() would raise on a non-contiguous input).
        mask_probs_flat = mask_probs.flatten(2)  # (B, Q, h*w)

        # Keep the first num_classes class columns; any trailing extra column
        # (presumably Mask2Former's "no object" logit) is dropped.
        # (B, C, Q) @ (B, Q, h*w) -> (B, C, h*w)
        semantic_map = torch.bmm(cls_probs[:, :, :self.num_classes].transpose(1, 2), mask_probs_flat)
        semantic_map = semantic_map.view(b, self.num_classes, h_small, w_small)

        # Resize to out_resolution x out_resolution for loss and metrics.
        return F.interpolate(semantic_map, size=(self.out_resolution, self.out_resolution),
                             mode="bilinear", align_corners=False)
# --- DATASET ---
class SegDataset(Dataset):
    """Image/mask pairs listed in a DataFrame, with synchronised augmentation.

    ``__getitem__`` returns ``{"input": img, "mask": mask}``:
    ``img`` is a float32, ImageNet-normalised tensor resized to
    ``backbone_res`` x ``backbone_res``; ``mask`` holds the first channel of
    the mask image as int64 labels. The mask keeps its native size unless
    cropping is enabled, in which case it is resized to ``out_resolution``.

    Args:
        df: DataFrame with the image / mask path columns.
        img_col: Column holding image paths.
        mask_col: Column holding mask paths.
        backbone_res: Side length of the image fed to the backbone.
        transform: Not used; accepted for backward compatibility.
        pflip: Probability of each of the horizontal and vertical flips;
            flipping is disabled when ``pflip <= 0``.
        flip: Enable flips.
        crop: Enable a random resized crop applied to image and mask alike.
        color: Enable brightness/contrast jitter (image only).
        out_resolution: Side length image and mask are resized to after the
            random crop; should match the model's output resolution.
    """

    def __init__(self, df, img_col="img_path", mask_col="mask_path",
                 backbone_res=512,
                 transform=None, pflip=0.0, flip=True, crop=False, color=False,
                 out_resolution=360):
        self.df = df
        self.img_col = img_col
        self.mask_col = mask_col
        self.backbone_res = backbone_res
        self.out_resolution = out_resolution

        # Augmentation flags
        self.flip = flip and pflip > 0
        self.crop = crop
        self.color = color

        if self.flip:
            # pflip is the per-flip probability (previously it only switched
            # flipping on/off while p was hard-coded to 0.5).
            self.flip_transforms = v2.Compose([
                RandomHorizontalFlip(p=pflip),
                RandomVerticalFlip(p=pflip),
            ])
        else:
            # None instead of a lambda keeps the dataset picklable for
            # DataLoader workers started with the "spawn" method.
            self.flip_transforms = None

        # Built once; the jitter factors are still re-sampled on every call.
        self.color_jitter = v2.ColorJitter(brightness=0.2, contrast=0.2) if color else None

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _resize_pair(img, mask, out_size):
        """Resize image (antialiased bilinear) and mask (nearest) to a square."""
        size = (out_size, out_size)
        img = v2.functional.resize(img, size, interpolation=InterpolationMode.BILINEAR,
                                   antialias=True)
        # Nearest-neighbour so the mask only contains label values already present.
        mask = v2.functional.resize(mask, size, interpolation=InterpolationMode.NEAREST_EXACT)
        return img, mask

    def random_resized_crop(self, img, mask,
                            scale=(0.4, 1.0), ratio=(0.75, 1.33),
                            out_size=None):
        """Crop the same random region from image and mask, then resize both.

        Makes up to 10 attempts to draw a crop whose area fraction lies in
        ``scale`` and whose width/height ratio lies in ``ratio``; otherwise
        falls back to the largest centred square. ``out_size`` defaults to
        ``self.out_resolution``.
        """
        if out_size is None:
            out_size = self.out_resolution
        _, h, w = img.shape
        area = h * w
        for _ in range(10):
            target_area = random.uniform(*scale) * area
            aspect_ratio = random.uniform(*ratio)
            new_w = int(round((target_area * aspect_ratio) ** 0.5))
            new_h = int(round((target_area / aspect_ratio) ** 0.5))
            if 0 < new_w <= w and 0 < new_h <= h:
                top = random.randint(0, h - new_h)
                left = random.randint(0, w - new_w)
                break
        else:
            # Fallback: centred square crop.
            new_h = new_w = min(h, w)
            top = (h - new_h) // 2
            left = (w - new_w) // 2
        img_crop = v2.functional.crop(img, top, left, new_h, new_w)
        mask_crop = v2.functional.crop(mask, top, left, new_h, new_w)
        return self._resize_pair(img_crop, mask_crop, out_size)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = v2.functional.to_image(default_loader(row[self.img_col]))
        mask = v2.functional.to_image(default_loader(row[self.mask_col]))

        if self.flip_transforms is not None:
            # Replay the torch RNG state so the mask gets the same flips as the image.
            state = torch.get_rng_state()
            img = self.flip_transforms(img)
            torch.set_rng_state(state)
            mask = self.flip_transforms(mask)

        if self.crop:
            img, mask = self.random_resized_crop(img, mask)

        if self.color_jitter is not None:
            # Photometric change only, so the mask is left untouched.
            img = self.color_jitter(img)

        # To float32 (scale=True maps integer pixels to [0, 1]), then ImageNet mean/std.
        img = v2.functional.to_dtype(img, torch.float32, scale=True)
        img = v2.functional.normalize(
            img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        img = v2.functional.resize(img, (self.backbone_res, self.backbone_res), antialias=True)

        # Class labels come from the first channel of the mask image.
        mask = mask.long()[0]
        return {"input": img, "mask": mask}
# --- METRICS HELPERS ---
def compute_conf_matrix(pred, target, num_classes):
    """Count (target, prediction) label pairs into a confusion matrix.

    Pixels whose target is outside ``[0, num_classes)`` are skipped. Row
    index is the target class, column index the predicted class; the result
    is a ``(num_classes, num_classes)`` tensor of counts.
    """
    valid = (target >= 0) & (target < num_classes)
    # Encode each (target, pred) pair as one flat index, then histogram it.
    flat_index = target[valid].view(-1) * num_classes + pred[valid].view(-1)
    counts = torch.bincount(flat_index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)
def calculate_metrics(conf_matrix, ignore_index=None):
    """Compute per-class IoU and mean IoU from a confusion matrix.

    Args:
        conf_matrix: Square ``(C, C)`` tensor of counts, rows = target class,
            columns = predicted class.
        ignore_index: Class index left out of the mean; ``None`` keeps all.

    Returns:
        ``(miou, ious)`` where ``ious`` is a list of ``C`` Python floats (NaN
        for a class absent from both target and prediction) and ``miou`` is
        the mean of the non-ignored, non-NaN IoUs (``0.0`` if there are none).
    """
    # float64: float32 stops representing integer counts exactly past 2**24,
    # which a full validation set easily exceeds.
    cm = conf_matrix.double()
    tp = cm.diagonal()
    fp = cm.sum(dim=0) - tp
    fn = cm.sum(dim=1) - tp
    denom = tp + fp + fn
    iou = torch.where(denom > 0, tp / denom, torch.full_like(denom, float("nan")))
    ious = iou.tolist()  # single device->host transfer for all classes

    relevant_ious = [v for i, v in enumerate(ious) if i != ignore_index and not np.isnan(v)]
    miou = float(np.mean(relevant_ious)) if relevant_ious else 0.0
    return miou, ious
def validate(model, loader, criterion, device, num_classes, ignore_index, stage="val"):
    """Evaluate ``model`` on ``loader`` under bfloat16 autocast.

    Args:
        model: Module mapping an image batch to ``(B, num_classes, H, W)`` scores.
        loader: Iterable of dicts with ``"input"`` and ``"mask"`` tensors.
        criterion: Loss called as ``criterion(scores, masks)``.
        device: Device string or ``torch.device`` (``"cuda"``, ``"cuda:1"``, ``"cpu"``...).
        num_classes: Size of the confusion matrix.
        ignore_index: Class excluded from the mean IoU.
        stage: Label shown in the progress bar.

    Returns:
        ``(avg_loss, miou, class_ious)``; ``avg_loss`` is the mean per-batch
        loss (``nan`` if the loader yields no batches).

    The model is left in eval mode.
    """
    model.eval()
    # autocast takes a device *type*; "cuda:1" would be rejected as-is.
    device_type = torch.device(device).type
    total_loss = 0.0
    num_batches = 0
    # Integer counts stay exact; a float32 accumulator loses counts past 2**24.
    conf_matrix = torch.zeros(num_classes, num_classes, dtype=torch.long, device=device)
    with torch.no_grad():
        for item in tqdm(loader, desc=f"evaluating_{stage}", leave=False):
            images = item['input'].to(device)
            masks = item['mask'].to(device).long()
            with torch.autocast(device_type, dtype=torch.bfloat16):
                outputs = model(images)
                loss = criterion(outputs, masks)
            total_loss += loss.item()
            num_batches += 1
            conf_matrix += compute_conf_matrix(torch.argmax(outputs, dim=1), masks, num_classes)
    avg_loss = total_loss / num_batches if num_batches else float("nan")
    miou, class_ious = calculate_metrics(conf_matrix, ignore_index=ignore_index)
    return avg_loss, miou, class_ious
# --- MAIN EXECUTION ---
if __name__ == "__main__":
    args = get_args()
    set_seed(args.seed)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # torch.autocast wants a device *type* ("cuda"/"cpu"); strip any index
    # such as the ":1" in "--device cuda:1".
    device_type = torch.device(args.device).type

    # --- Folder naming based on augmentations ---
    aug_suffix = ""
    if args.use_crop:
        aug_suffix += "_crop"
    if args.use_color:
        aug_suffix += "_color"
    if args.no_flip:
        aug_suffix += "_noflip"
    exp_name = f"{args.exp_name}{aug_suffix}"
    run_dir = os.path.join(args.save_dir, f"{exp_name}_{timestamp}")
    os.makedirs(run_dir, exist_ok=True)
    # Keep a copy of this script next to the run's outputs.
    shutil.copy(__file__, os.path.join(run_dir, os.path.basename(__file__)))
    # From here on, print() output is also appended to run_dir/log.txt.
    sys.stdout = Logger(os.path.join(run_dir, "log.txt"))
    print(f"--- Experiment: {exp_name} ---")
    print(f"Arguments: {json.dumps(vars(args), indent=4)}")
    with open(os.path.join(run_dir, "args.json"), "w") as f:
        json.dump(vars(args), f, indent=4)
    writer = SummaryWriter(log_dir=run_dir)

    # --- Data ---
    df = pd.read_csv(args.data_path)
    # os.path.join works whether or not --data_root ends with a separator.
    df['img_path'] = df['img_name'].apply(lambda name: os.path.join(args.data_root, name))
    # Each mask sits next to its image as "<image stem>_mask.png".
    df['mask_path'] = df['img_path'].apply(lambda path: os.path.splitext(path)[0] + "_mask.png")
    # Image normalisation and resizing are done inside SegDataset.
    if args.use_crop and args.out_resolution != 360:
        print("Warning: the dataset's random crop outputs 360x360 masks by default, "
              f"but --out_resolution is {args.out_resolution}.")

    # --- Model ---
    seg_model = Mask2FormerWrapper(args.model_name, args.num_classes, args.out_resolution).to(args.device)

    def first_submodule(root, *targets):
        """Return the first submodule of *root* found at one of the dotted *targets*."""
        for target in targets:
            try:
                return root.get_submodule(target)
            except AttributeError:
                continue
        raise AttributeError(f"{type(root).__name__} has none of the submodules {targets}")

    # In the transformers Mask2FormerModel implementation the backbone and
    # pixel decoder live at pixel_level_module.encoder / .decoder; the old
    # ".backbone" / ".pixel_decoder" paths are kept only as fallbacks
    # (verify against the installed transformers version).
    m2f_core = seg_model.model.model
    if args.freeze_encoder:
        for p in first_submodule(m2f_core, "pixel_level_module.encoder", "backbone").parameters():
            p.requires_grad = False
    if args.freeze_decoder:
        for p in first_submodule(m2f_core, "pixel_level_module.decoder", "pixel_decoder").parameters():
            p.requires_grad = False
        for p in m2f_core.transformer_module.parameters():
            p.requires_grad = False

    trainable_params = [p for p in seg_model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=args.weight_decay)
    # Pixel-wise CE over every class; --ignore_index only affects the metrics.
    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=min(1e-6, args.lr / 100))

    # --- Loaders (augmentation on the training split only) ---
    train_loader = DataLoader(
        SegDataset(df[df['split'] == "train"],
                   backbone_res=args.resolution,
                   pflip=args.pflip if not args.no_flip else 0.0,
                   flip=not args.no_flip,
                   crop=args.use_crop, color=args.use_color),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
        pin_memory=True
    )
    val_loader = DataLoader(
        SegDataset(df[df['split'] == "val"],
                   backbone_res=args.resolution,
                   pflip=0, flip=False, crop=False, color=False),
        batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers
    )
    test_loader = DataLoader(
        SegDataset(df[df['split'] == "test"],
                   backbone_res=args.resolution,
                   pflip=0, flip=False, crop=False, color=False),
        batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers
    )

    # --- Training loop ---
    global_step = 0
    for epoch in range(args.epochs):
        seg_model.train()
        pbar = tqdm(train_loader, desc=f"epoch {epoch+1}/{args.epochs}")
        epoch_loss = 0
        for item in pbar:
            images, masks = item['input'].to(args.device), item['mask'].to(args.device).long()
            optimizer.zero_grad()
            with torch.autocast(device_type, dtype=torch.bfloat16):
                loss = criterion(seg_model(images), masks)
            loss.backward()
            if args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(trainable_params, args.grad_clip)
            optimizer.step()
            step_loss = loss.item()  # one device sync per step instead of three
            epoch_loss += step_loss
            writer.add_scalar("loss_train_step", step_loss, global_step)
            global_step += 1
            pbar.set_postfix({"loss": f"{step_loss:.4f}"})
        scheduler.step()
        avg_train_loss = epoch_loss / len(train_loader)

        val_loss, val_miou, val_ious = validate(seg_model, val_loader, criterion,
                                                args.device, args.num_classes, args.ignore_index, "val")
        test_loss, test_miou, test_ious = validate(seg_model, test_loader, criterion,
                                                   args.device, args.num_classes, args.ignore_index, "test")
        writer.add_scalar("loss_train_epoch", avg_train_loss, epoch)
        writer.add_scalar("loss_val", val_loss, epoch)
        writer.add_scalar("loss_test", test_loss, epoch)
        writer.add_scalar("miou_val", val_miou, epoch)
        writer.add_scalar("miou_test", test_miou, epoch)
        for i, iou in enumerate(val_ious):
            if i != args.ignore_index:
                writer.add_scalar(f"iou_val_class_{i}", iou, epoch)
        for i, iou in enumerate(test_ious):
            if i != args.ignore_index:
                writer.add_scalar(f"iou_test_class_{i}", iou, epoch)
        log_msg = (f"Epoch {epoch+1:03d} | Train Loss: {avg_train_loss:.4f} | "
                   f"Val Loss: {val_loss:.4f} | Val mIoU: {val_miou:.4f} | "
                   f"Test mIoU: {test_miou:.4f}")
        print(log_msg)
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': seg_model.state_dict(),
            'val_miou': val_miou
        }, os.path.join(run_dir, f"model_epoch={epoch+1:03d}.ckpt"))
    writer.close()
    # Restore the real stdout and release the log file handle.
    tee = sys.stdout
    sys.stdout = tee.terminal
    tee.log.close()