segment-monograms / src /train.py
Saranga7's picture
Deploy monogram segmentation demo
bcc432f verified
Raw
History Blame Contribute Delete
12.6 kB
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
import torch
from PIL import Image, ImageDraw
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
from src.data.dataset import MonogramSegmentationDataset, collate_batch
from src.losses import SegmentationLoss
from src.metrics import binary_metrics
from src.models import make_model
def _wandb_config(args: argparse.Namespace) -> dict[str, Any]:
    """Build the config dictionary handed to ``wandb.init``.

    Every parsed argument is copied except ``wandb_tags``; ``Path`` values are
    converted to plain strings and all other values are passed through
    unchanged.
    """
    config: dict[str, Any] = {}
    for name, setting in vars(args).items():
        if name == "wandb_tags":
            continue
        config[name] = str(setting) if isinstance(setting, Path) else setting
    return config
def _start_wandb(args: argparse.Namespace):
    """Start a Weights & Biases run if ``args.wandb`` is set.

    Creates ``<output_dir>/wandb`` and uses it as the default location for the
    ``WANDB_DIR``, ``WANDB_CACHE_DIR`` and ``WANDB_CONFIG_DIR`` environment
    variables (values that are already set are left alone), then calls
    ``wandb.init`` with the ``wandb_*`` arguments.

    Returns:
        The object returned by ``wandb.init``, or ``None`` when W&B logging
        was not requested.
    """
    if not args.wandb:
        return None

    output_dir = Path(args.output_dir)
    wandb_root = output_dir / "wandb"
    wandb_root.mkdir(parents=True, exist_ok=True)

    # setdefault keeps anything the user already exported in the environment.
    env_defaults = {
        "WANDB_DIR": wandb_root,
        "WANDB_CACHE_DIR": wandb_root / "cache",
        "WANDB_CONFIG_DIR": wandb_root / "config",
    }
    for variable, location in env_defaults.items():
        os.environ.setdefault(variable, str(location))

    # Deferred import: wandb is only loaded once logging has been requested.
    import wandb

    # Comma-separated tag string -> list of non-empty, stripped tags.
    stripped = (piece.strip() for piece in args.wandb_tags.split(","))
    tag_list = [tag for tag in stripped if tag]

    # Empty-string CLI defaults are passed to wandb as None.
    run = wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity or None,
        name=args.wandb_name or None,
        group=args.wandb_group or None,
        mode=args.wandb_mode,
        tags=tag_list,
        config=_wandb_config(args),
        dir=str(output_dir),
    )
    # Register "epoch" as the step metric for every logged key.
    wandb.define_metric("epoch")
    wandb.define_metric("*", step_metric="epoch")
    return run
def _log_artifact(run: Any, path: Path, name: str, artifact_type: str) -> None:
    """Upload one file to W&B as an artifact.

    Nothing happens when ``run`` is ``None`` or ``path`` does not exist.

    Args:
        run: Active W&B run, or ``None`` when logging is disabled.
        path: File to attach to the artifact.
        name: Artifact name.
        artifact_type: Artifact type string passed to ``wandb.Artifact``.
    """
    if run is not None and path.exists():
        import wandb

        bundle = wandb.Artifact(name=name, type=artifact_type)
        bundle.add_file(str(path))
        run.log_artifact(bundle)
def _tensor_rgb(image: torch.Tensor) -> np.ndarray:
    """Convert a 3-D image tensor into a ``uint8`` array with axis 0 moved last.

    Values are clamped to [-1, 1], mapped linearly onto [0, 1], scaled by 255
    and cast to ``uint8`` (the cast truncates the fractional part).
    Presumably the input is a (C, H, W) image normalised to [-1, 1] — confirm
    against the dataset's normalisation.
    """
    unit_range = image.detach().cpu().clamp(-1, 1) * 0.5 + 0.5
    channels_last = unit_range.permute(1, 2, 0).numpy()
    return (channels_last * 255).astype("uint8")
def _tensor_mask(mask: torch.Tensor) -> np.ndarray:
    """Convert a mask tensor into a ``uint8`` array scaled to 0..255.

    Values are clamped to [0, 1], a leading axis of size one is squeezed away,
    and the result is multiplied by 255 and cast to ``uint8``.
    """
    clipped = mask.detach().cpu().clamp(0, 1)
    plane = clipped.squeeze(0).numpy()
    return (plane * 255).astype("uint8")
def _overlay(image: np.ndarray, mask: np.ndarray, color: tuple[int, int, int]) -> np.ndarray:
    """Tint the pixels of ``image`` selected by ``mask`` with ``color``.

    Where ``mask`` is greater than zero the output is a 55 % image / 45 %
    colour blend cast to ``uint8``; everywhere else the original pixel is
    returned unchanged.

    Args:
        image: Array whose last axis holds the colour channels.
        mask: Array matching ``image`` without its last axis.
        color: Values written to channels 0, 1 and 2 of the tint layer.
    """
    tint = np.zeros_like(image)
    for channel in range(3):
        tint[..., channel] = color[channel]

    blended = (0.55 * image + 0.45 * tint).astype("uint8")
    selected = mask[..., None] > 0  # extra axis so the mask broadcasts over channels
    return np.where(selected, blended, image)
@torch.no_grad()
def _save_validation_preview(
    model: torch.nn.Module,
    loader: DataLoader,
    device: torch.device,
    threshold: float,
    out_dir: Path,
    epoch: int,
    max_images: int,
) -> Path | None:
    """Save a JPEG contact sheet comparing predictions with ground truth.

    The first batch of ``loader`` is run through ``model``. For up to
    ``max_images`` samples one row of four panels is drawn — input image,
    ground-truth mask, thresholded prediction, and the prediction overlaid in
    red on the input — with a caption built from the batch's ``sample_id``,
    ``collection`` and ``quality_label`` entries.

    Args:
        model: Network whose output logits are passed through a sigmoid.
        loader: Source of batches; only the first batch is used. Each batch is
            indexed with the keys ``image``, ``mask``, ``sample_id``,
            ``collection`` and ``quality_label``.
        device: Device the input images are moved to for the forward pass.
        threshold: Sigmoid probability at or above which a pixel is foreground.
        out_dir: Directory under which ``previews/`` is created.
        epoch: Epoch number used in the file name.
        max_images: Maximum number of rows; non-positive disables the preview.

    Returns:
        Path of the written ``previews/val_epoch_XXXX.jpg`` file, or ``None``
        when ``max_images`` is not positive or the loader yields no batch.

    The model is switched to eval mode for the forward pass and is always put
    back into training mode afterwards, including when the forward pass,
    rendering or saving raises.
    """
    if max_images <= 0:
        return None

    panel_count = 4  # image | ground truth | prediction | overlay
    label_h = 26  # height of the caption strip under each row

    model.eval()
    try:
        try:
            batch = next(iter(loader))
        except StopIteration:
            return None

        image = batch["image"].to(device)
        logits = model(image)
        pred = (torch.sigmoid(logits) >= threshold).float()

        count = min(max_images, image.size(0))
        tile_w = int(image.shape[-1])
        tile_h = int(image.shape[-2])
        row_w = tile_w * panel_count
        row_h = tile_h + label_h

        rows: list[Image.Image] = []
        for idx in range(count):
            rgb = _tensor_rgb(batch["image"][idx])
            gt = _tensor_mask(batch["mask"][idx])
            pred_mask = _tensor_mask(pred[idx])
            panels = [
                Image.fromarray(rgb),
                Image.fromarray(gt).convert("RGB"),
                Image.fromarray(pred_mask).convert("RGB"),
                Image.fromarray(_overlay(rgb, pred_mask, (255, 0, 0))),
            ]
            row = Image.new("RGB", (row_w, row_h), "white")
            for panel_idx, panel in enumerate(panels):
                row.paste(panel.resize((tile_w, tile_h)), (panel_idx * tile_w, 0))
            draw = ImageDraw.Draw(row)
            caption = f"{batch['sample_id'][idx]} | {batch['collection'][idx]} | {batch['quality_label'][idx]}"
            draw.text((4, tile_h + 6), caption, fill=(0, 0, 0))
            rows.append(row)

        preview_dir = out_dir / "previews"
        preview_dir.mkdir(parents=True, exist_ok=True)
        # Stack the rows vertically into a single sheet.
        sheet = Image.new("RGB", (row_w, row_h * len(rows)), "white")
        for idx, row in enumerate(rows):
            sheet.paste(row, (0, idx * row_h))
        path = preview_dir / f"val_epoch_{epoch:04d}.jpg"
        sheet.save(path, quality=95)
        return path
    finally:
        # Runs on every exit after eval() so a failure cannot leave the model
        # in eval mode for the rest of training.
        model.train()
def _device(name: str) -> torch.device:
    """Resolve a device string, falling back to CPU when CUDA is unavailable.

    Any CUDA request (``"cuda"``, ``"cuda:0"``, ``"cuda:1"``, ...) is replaced
    by the CPU device when ``torch.cuda.is_available()`` is false; every other
    device string is returned as parsed by ``torch.device``.

    Args:
        name: Device string accepted by ``torch.device``.

    Returns:
        The requested device, or ``torch.device("cpu")`` for a CUDA request on
        a machine without CUDA.

    Raises:
        RuntimeError: Propagated from ``torch.device`` for an invalid string.
    """
    requested = torch.device(name)
    # Compare the parsed device type so indexed forms such as "cuda:0" get the
    # same fallback as plain "cuda".
    if requested.type == "cuda" and not torch.cuda.is_available():
        return torch.device("cpu")
    return requested
def _loader(
    csv_path: Path,
    batch_size: int,
    image_size: int,
    augment: bool,
    limit: int | None,
    shuffle: bool,
    num_workers: int,
) -> DataLoader:
    """Create a ``DataLoader`` over the samples listed in a split CSV.

    Args:
        csv_path: CSV file read with ``pandas.read_csv`` and handed to
            ``MonogramSegmentationDataset``.
        batch_size: Number of samples per batch.
        image_size: Forwarded to the dataset as ``image_size``.
        augment: Forwarded to the dataset as ``augment``.
        limit: When not ``None``, keep only the first ``limit`` samples (or
            all of them if the dataset is smaller).
        shuffle: Whether the loader shuffles the samples.
        num_workers: Number of loader worker processes.

    Returns:
        A loader using ``collate_batch``; memory pinning is enabled whenever
        CUDA is available.
    """
    samples = MonogramSegmentationDataset(pd.read_csv(csv_path), image_size=image_size, augment=augment)
    if limit is not None:
        keep = min(limit, len(samples))
        samples = Subset(samples, list(range(keep)))

    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": torch.cuda.is_available(),
        "collate_fn": collate_batch,
    }
    return DataLoader(samples, **loader_options)
@torch.no_grad()
def validate(model: torch.nn.Module, loader: DataLoader, device: torch.device, threshold: float) -> dict[str, float]:
    """Average the metrics from ``binary_metrics`` over every batch of ``loader``.

    The model runs in eval mode without gradients and is switched back to
    training mode before the averages are computed.

    Args:
        model: Network producing the logits passed to ``binary_metrics``.
        loader: Iterable of batches indexed with ``image`` and ``mask``.
        device: Device both tensors are moved to.
        threshold: Forwarded to ``binary_metrics``.

    Returns:
        Mapping from each key of the first collected metrics dict to the mean
        of that key over all collected dicts, or ``{}`` when nothing was
        collected. ``binary_metrics`` presumably yields one dict per sample —
        confirm in ``src.metrics``.
    """
    model.eval()
    metric_rows: list[dict[str, float]] = []
    for batch in loader:
        inputs = batch["image"].to(device)
        target = batch["mask"].to(device)
        logits = model(inputs)
        metric_rows.extend(binary_metrics(logits, target, threshold))
    model.train()

    if not metric_rows:
        return {}

    row_count = len(metric_rows)
    averages: dict[str, float] = {}
    for name in metric_rows[0]:
        total = sum(entry[name] for entry in metric_rows)
        averages[name] = float(total / row_count)
    return averages
def train(args: argparse.Namespace) -> Path:
    """Train a segmentation model and write checkpoints and history to disk.

    Builds the train and validation loaders from ``args.train_csv`` and
    ``args.val_csv``, creates the model selected by ``args.model`` (``unet`` or
    ``segformer``), and optimises it with AdamW for ``args.epochs`` epochs.
    After every epoch the averaged training loss parts and the validation
    metrics are appended to ``history.json`` (rewritten each epoch) and, when
    W&B is enabled, logged to the run together with an optional preview image.

    Files written under ``args.output_dir``:
        * ``history.json`` — one row per epoch.
        * ``best.pt`` — saved whenever ``val_iou`` exceeds the best so far.
        * ``last.pt`` — saved once after the final epoch.
        * ``previews/`` — validation previews, only when sample logging is on.

    Args:
        args: Parsed command-line arguments as produced by ``main``.

    Returns:
        Path to ``best.pt`` when this run saved one; otherwise path to
        ``last.pt``. The fallback covers runs in which validation never
        produced an improving ``val_iou`` (empty validation set, no ``iou``
        metric, NaN values, zero epochs), where ``best.pt`` would be missing
        or left over from an earlier run.
    """
    device = _device(args.device)
    train_loader = _loader(
        Path(args.train_csv), args.batch_size, args.image_size, True, args.limit_train, True, args.num_workers
    )
    val_loader = _loader(
        Path(args.val_csv), args.batch_size, args.image_size, False, args.limit_val, False, args.num_workers
    )

    # Only the keyword arguments relevant to the chosen architecture are passed.
    model_kwargs: dict[str, Any] = {}
    if args.model == "unet":
        model_kwargs["encoder_name"] = args.encoder_name
        model_kwargs["encoder_weights"] = None if args.no_pretrained else args.encoder_weights
    elif args.model == "segformer":
        model_kwargs["model_name"] = args.segformer_model_name
        model_kwargs["pretrained"] = args.segformer_pretrained
        model_kwargs["freeze_encoder"] = args.segformer_freeze_encoder
        model_kwargs["unfreeze_last_n_blocks"] = args.segformer_unfreeze_last_n_blocks
    model = make_model(args.model, **model_kwargs).to(device)
    loss_fn = SegmentationLoss(args.bce_weight, args.dice_weight, args.cldice_weight)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    best_path = out_dir / "best.pt"
    last_path = out_dir / "last.pt"
    history_path = out_dir / "history.json"

    run = _start_wandb(args)
    best_iou = -1.0
    best_saved = False  # becomes True once this run has written best.pt
    history: list[dict[str, float | int]] = []
    try:
        for epoch in range(1, args.epochs + 1):
            totals: dict[str, float] = {}
            steps = 0
            for batch in tqdm(train_loader, desc=f"epoch {epoch}/{args.epochs}"):
                image = batch["image"].to(device)
                mask = batch["mask"].to(device)
                optimizer.zero_grad(set_to_none=True)
                loss, loss_parts = loss_fn(model(image), mask)
                loss.backward()
                optimizer.step()
                for key, value in loss_parts.items():
                    totals[key] = totals.get(key, 0.0) + float(value.item())
                steps += 1

            # max(steps, 1) avoids dividing by zero when the train loader is empty.
            train_row = {f"train_{key}": value / max(steps, 1) for key, value in totals.items()}
            val_metrics = validate(model, val_loader, device, args.threshold)
            val_row = {f"val_{key}": value for key, value in val_metrics.items()}
            row = {"epoch": epoch, **train_row, **val_row}
            history.append(row)
            history_path.write_text(json.dumps(history, indent=2))

            if run is not None:
                log_row = dict(row)
                # Previews go out on the first and last epoch and on every
                # multiple of wandb_sample_every; an interval of 0 disables the
                # periodic previews instead of raising ZeroDivisionError.
                on_interval = args.wandb_sample_every != 0 and epoch % args.wandb_sample_every == 0
                if args.wandb_log_samples and (epoch == 1 or on_interval or epoch == args.epochs):
                    preview_path = _save_validation_preview(
                        model,
                        val_loader,
                        device,
                        args.threshold,
                        out_dir,
                        epoch,
                        args.wandb_sample_count,
                    )
                    if preview_path is not None:
                        import wandb

                        log_row["val_examples"] = wandb.Image(
                            str(preview_path), caption=f"epoch {epoch}: image | gt | pred | overlay"
                        )
                run.log(log_row)

            if val_row.get("val_iou", -1.0) > best_iou:
                best_iou = float(val_row["val_iou"])
                torch.save({"model": model.state_dict(), "args": vars(args), "epoch": epoch}, best_path)
                best_saved = True
                if run is not None:
                    run.summary["best_val_iou"] = best_iou
                    run.summary["best_epoch"] = epoch

        torch.save({"model": model.state_dict(), "args": vars(args), "epoch": args.epochs}, last_path)
        if args.wandb_artifacts:
            artifact_base = args.wandb_artifact_name or f"{args.model}-{Path(args.output_dir).name}"
            _log_artifact(run, best_path, f"{artifact_base}-best", "model")
            _log_artifact(run, last_path, f"{artifact_base}-last", "model")
            _log_artifact(run, history_path, f"{artifact_base}-history", "training-history")
    finally:
        # Close the W&B run even when training is interrupted by an exception.
        if run is not None:
            run.finish()
    return best_path if best_saved else last_path
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the training script."""
    parser = argparse.ArgumentParser(description="Train a monogram stroke segmentation model.")
    add = parser.add_argument

    # Data, schedule and hardware.
    add("--model", choices=["unet", "segformer"], default="unet")
    add("--train-csv", default="splits/e3_full_data/train.csv")
    add("--val-csv", default="splits/e3_full_data/val.csv")
    add("--output-dir", default="outputs/train")
    add("--image-size", type=int, default=512)
    add("--batch-size", type=int, default=4)
    add("--num-workers", type=int, default=4)
    add("--epochs", type=int, default=20)
    add("--lr", type=float, default=1e-4)
    add("--weight-decay", type=float, default=1e-4)
    add("--threshold", type=float, default=0.5)
    add("--device", default="cuda")
    add("--limit-train", type=int)
    add("--limit-val", type=int)

    # Architecture-specific options (read by train() for unet / segformer).
    add("--no-pretrained", action="store_true")
    add("--encoder-name", default="resnet34")
    add("--encoder-weights", default="imagenet")
    add("--segformer-model-name", default="nvidia/mit-b2")
    add("--segformer-pretrained", action=argparse.BooleanOptionalAction, default=True)
    add("--segformer-freeze-encoder", action=argparse.BooleanOptionalAction, default=False)
    add("--segformer-unfreeze-last-n-blocks", type=int, default=2)

    # Loss term weights.
    add("--bce-weight", type=float, default=1.0)
    add("--dice-weight", type=float, default=1.0)
    add("--cldice-weight", type=float, default=0.5)

    # Weights & Biases logging.
    add("--wandb", action="store_true", help="Enable Weights & Biases logging.")
    add("--wandb-project", default="segment-monograms")
    add("--wandb-entity", default="")
    add("--wandb-name", default="")
    add("--wandb-group", default="")
    add("--wandb-mode", choices=["online", "offline", "disabled"], default="online")
    add("--wandb-tags", default="")
    add("--wandb-artifacts", action="store_true", help="Upload best.pt, last.pt, and history.json as W&B artifacts.")
    add("--wandb-artifact-name", default="")
    add("--wandb-log-samples", action="store_true", help="Log validation image/mask/prediction previews to W&B.")
    add("--wandb-sample-every", type=int, default=5)
    add("--wandb-sample-count", type=int, default=4)
    return parser


def main() -> None:
    """Parse the command line, run training, and print the returned checkpoint path."""
    cli_args = _build_parser().parse_args()
    checkpoint = train(cli_args)
    print(checkpoint)
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()