from __future__ import annotations import argparse import json import os from pathlib import Path from typing import Any import numpy as np import pandas as pd import torch from PIL import Image, ImageDraw from torch.utils.data import DataLoader, Subset from tqdm import tqdm from src.data.dataset import MonogramSegmentationDataset, collate_batch from src.losses import SegmentationLoss from src.metrics import binary_metrics from src.models import make_model def _wandb_config(args: argparse.Namespace) -> dict[str, Any]: return { key: str(value) if isinstance(value, Path) else value for key, value in vars(args).items() if key != "wandb_tags" } def _start_wandb(args: argparse.Namespace): if not args.wandb: return None wandb_root = Path(args.output_dir) / "wandb" wandb_root.mkdir(parents=True, exist_ok=True) os.environ.setdefault("WANDB_DIR", str(wandb_root)) os.environ.setdefault("WANDB_CACHE_DIR", str(wandb_root / "cache")) os.environ.setdefault("WANDB_CONFIG_DIR", str(wandb_root / "config")) import wandb run = wandb.init( project=args.wandb_project, entity=args.wandb_entity or None, name=args.wandb_name or None, group=args.wandb_group or None, mode=args.wandb_mode, tags=[tag.strip() for tag in args.wandb_tags.split(",") if tag.strip()], config=_wandb_config(args), dir=str(Path(args.output_dir)), ) wandb.define_metric("epoch") wandb.define_metric("*", step_metric="epoch") return run def _log_artifact(run: Any, path: Path, name: str, artifact_type: str) -> None: if run is None or not path.exists(): return import wandb artifact = wandb.Artifact(name=name, type=artifact_type) artifact.add_file(str(path)) run.log_artifact(artifact) def _tensor_rgb(image: torch.Tensor) -> np.ndarray: return ((image.detach().cpu().clamp(-1, 1) * 0.5 + 0.5).permute(1, 2, 0).numpy() * 255).astype("uint8") def _tensor_mask(mask: torch.Tensor) -> np.ndarray: return (mask.detach().cpu().clamp(0, 1).squeeze(0).numpy() * 255).astype("uint8") def _overlay(image: np.ndarray, mask: np.ndarray, color: tuple[int, int, int]) -> np.ndarray: color_arr = np.zeros_like(image) color_arr[..., 0] = color[0] color_arr[..., 1] = color[1] color_arr[..., 2] = color[2] return np.where(mask[..., None] > 0, (0.55 * image + 0.45 * color_arr).astype("uint8"), image) @torch.no_grad() def _save_validation_preview( model: torch.nn.Module, loader: DataLoader, device: torch.device, threshold: float, out_dir: Path, epoch: int, max_images: int, ) -> Path | None: if max_images <= 0: return None model.eval() try: batch = next(iter(loader)) except StopIteration: model.train() return None image = batch["image"].to(device) logits = model(image) pred = (torch.sigmoid(logits) >= threshold).float() count = min(max_images, image.size(0)) tile_w = int(image.shape[-1]) tile_h = int(image.shape[-2]) label_h = 26 rows: list[Image.Image] = [] for idx in range(count): rgb = _tensor_rgb(batch["image"][idx]) gt = _tensor_mask(batch["mask"][idx]) pred_mask = _tensor_mask(pred[idx]) panels = [ Image.fromarray(rgb), Image.fromarray(gt).convert("RGB"), Image.fromarray(pred_mask).convert("RGB"), Image.fromarray(_overlay(rgb, pred_mask, (255, 0, 0))), ] row = Image.new("RGB", (tile_w * len(panels), tile_h + label_h), "white") for panel_idx, panel in enumerate(panels): row.paste(panel.resize((tile_w, tile_h)), (panel_idx * tile_w, 0)) draw = ImageDraw.Draw(row) caption = f"{batch['sample_id'][idx]} | {batch['collection'][idx]} | {batch['quality_label'][idx]}" draw.text((4, tile_h + 6), caption, fill=(0, 0, 0)) rows.append(row) preview_dir = out_dir / "previews" preview_dir.mkdir(parents=True, exist_ok=True) sheet = Image.new("RGB", (tile_w * 4, (tile_h + label_h) * len(rows)), "white") for idx, row in enumerate(rows): sheet.paste(row, (0, idx * (tile_h + label_h))) path = preview_dir / f"val_epoch_{epoch:04d}.jpg" sheet.save(path, quality=95) model.train() return path def _device(name: str) -> torch.device: if name == "cuda" and not torch.cuda.is_available(): return torch.device("cpu") return torch.device(name) def _loader( csv_path: Path, batch_size: int, image_size: int, augment: bool, limit: int | None, shuffle: bool, num_workers: int, ) -> DataLoader: frame = pd.read_csv(csv_path) dataset = MonogramSegmentationDataset(frame, image_size=image_size, augment=augment) if limit is not None: dataset = Subset(dataset, list(range(min(limit, len(dataset))))) return DataLoader( dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=torch.cuda.is_available(), collate_fn=collate_batch, ) @torch.no_grad() def validate(model: torch.nn.Module, loader: DataLoader, device: torch.device, threshold: float) -> dict[str, float]: model.eval() rows: list[dict[str, float]] = [] for batch in loader: image = batch["image"].to(device) mask = batch["mask"].to(device) rows.extend(binary_metrics(model(image), mask, threshold)) model.train() if not rows: return {} return {key: float(sum(row[key] for row in rows) / len(rows)) for key in rows[0]} def train(args: argparse.Namespace) -> Path: device = _device(args.device) train_loader = _loader(Path(args.train_csv), args.batch_size, args.image_size, True, args.limit_train, True, args.num_workers) val_loader = _loader(Path(args.val_csv), args.batch_size, args.image_size, False, args.limit_val, False, args.num_workers) model_kwargs: dict[str, Any] = {} if args.model == "unet": model_kwargs["encoder_name"] = args.encoder_name model_kwargs["encoder_weights"] = None if args.no_pretrained else args.encoder_weights elif args.model == "segformer": model_kwargs["model_name"] = args.segformer_model_name model_kwargs["pretrained"] = args.segformer_pretrained model_kwargs["freeze_encoder"] = args.segformer_freeze_encoder model_kwargs["unfreeze_last_n_blocks"] = args.segformer_unfreeze_last_n_blocks model = make_model(args.model, **model_kwargs).to(device) loss_fn = SegmentationLoss(args.bce_weight, args.dice_weight, args.cldice_weight) optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) out_dir = Path(args.output_dir) out_dir.mkdir(parents=True, exist_ok=True) run = _start_wandb(args) best_iou = -1.0 history: list[dict[str, float | int]] = [] try: for epoch in range(1, args.epochs + 1): totals: dict[str, float] = {} steps = 0 for batch in tqdm(train_loader, desc=f"epoch {epoch}/{args.epochs}"): image = batch["image"].to(device) mask = batch["mask"].to(device) optimizer.zero_grad(set_to_none=True) loss, loss_parts = loss_fn(model(image), mask) loss.backward() optimizer.step() for key, value in loss_parts.items(): totals[key] = totals.get(key, 0.0) + float(value.item()) steps += 1 train_row = {f"train_{key}": value / max(steps, 1) for key, value in totals.items()} val_row = {f"val_{key}": value for key, value in validate(model, val_loader, device, args.threshold).items()} row = {"epoch": epoch, **train_row, **val_row} history.append(row) (out_dir / "history.json").write_text(json.dumps(history, indent=2)) if run is not None: log_row = dict(row) preview_path = None if args.wandb_log_samples and (epoch == 1 or epoch % args.wandb_sample_every == 0 or epoch == args.epochs): preview_path = _save_validation_preview( model, val_loader, device, args.threshold, out_dir, epoch, args.wandb_sample_count, ) if preview_path is not None: import wandb log_row["val_examples"] = wandb.Image(str(preview_path), caption=f"epoch {epoch}: image | gt | pred | overlay") run.log(log_row) if val_row.get("val_iou", -1.0) > best_iou: best_iou = float(val_row["val_iou"]) torch.save({"model": model.state_dict(), "args": vars(args), "epoch": epoch}, out_dir / "best.pt") if run is not None: run.summary["best_val_iou"] = best_iou run.summary["best_epoch"] = epoch torch.save({"model": model.state_dict(), "args": vars(args), "epoch": args.epochs}, out_dir / "last.pt") if args.wandb_artifacts: artifact_base = args.wandb_artifact_name or f"{args.model}-{Path(args.output_dir).name}" _log_artifact(run, out_dir / "best.pt", f"{artifact_base}-best", "model") _log_artifact(run, out_dir / "last.pt", f"{artifact_base}-last", "model") _log_artifact(run, out_dir / "history.json", f"{artifact_base}-history", "training-history") finally: if run is not None: run.finish() return out_dir / "best.pt" def main() -> None: parser = argparse.ArgumentParser(description="Train a monogram stroke segmentation model.") parser.add_argument("--model", choices=["unet", "segformer"], default="unet") parser.add_argument("--train-csv", default="splits/e3_full_data/train.csv") parser.add_argument("--val-csv", default="splits/e3_full_data/val.csv") parser.add_argument("--output-dir", default="outputs/train") parser.add_argument("--image-size", type=int, default=512) parser.add_argument("--batch-size", type=int, default=4) parser.add_argument("--num-workers", type=int, default=4) parser.add_argument("--epochs", type=int, default=20) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--weight-decay", type=float, default=1e-4) parser.add_argument("--threshold", type=float, default=0.5) parser.add_argument("--device", default="cuda") parser.add_argument("--limit-train", type=int) parser.add_argument("--limit-val", type=int) parser.add_argument("--no-pretrained", action="store_true") parser.add_argument("--encoder-name", default="resnet34") parser.add_argument("--encoder-weights", default="imagenet") parser.add_argument("--segformer-model-name", default="nvidia/mit-b2") parser.add_argument("--segformer-pretrained", action=argparse.BooleanOptionalAction, default=True) parser.add_argument("--segformer-freeze-encoder", action=argparse.BooleanOptionalAction, default=False) parser.add_argument("--segformer-unfreeze-last-n-blocks", type=int, default=2) parser.add_argument("--bce-weight", type=float, default=1.0) parser.add_argument("--dice-weight", type=float, default=1.0) parser.add_argument("--cldice-weight", type=float, default=0.5) parser.add_argument("--wandb", action="store_true", help="Enable Weights & Biases logging.") parser.add_argument("--wandb-project", default="segment-monograms") parser.add_argument("--wandb-entity", default="") parser.add_argument("--wandb-name", default="") parser.add_argument("--wandb-group", default="") parser.add_argument("--wandb-mode", choices=["online", "offline", "disabled"], default="online") parser.add_argument("--wandb-tags", default="") parser.add_argument("--wandb-artifacts", action="store_true", help="Upload best.pt, last.pt, and history.json as W&B artifacts.") parser.add_argument("--wandb-artifact-name", default="") parser.add_argument("--wandb-log-samples", action="store_true", help="Log validation image/mask/prediction previews to W&B.") parser.add_argument("--wandb-sample-every", type=int, default=5) parser.add_argument("--wandb-sample-count", type=int, default=4) args = parser.parse_args() print(train(args)) if __name__ == "__main__": main()