# NOTE: page residue from the Hugging Face Space this file was copied from.
# Spaces: Sleeping
from __future__ import annotations

import argparse
import json
import os
import warnings
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
import torch
from PIL import Image, ImageDraw
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

from src.data.dataset import MonogramSegmentationDataset, collate_batch
from src.losses import SegmentationLoss
from src.metrics import binary_metrics
from src.models import make_model
def _wandb_config(args: argparse.Namespace) -> dict[str, Any]:
    """Return the parsed CLI arguments as a plain dict suitable for a W&B config.

    ``Path`` values are converted to ``str``. The ``wandb_tags`` entry is left
    out because tags are passed to ``wandb.init`` separately.
    """
    config: dict[str, Any] = {}
    for name, raw in vars(args).items():
        if name == "wandb_tags":
            continue
        config[name] = str(raw) if isinstance(raw, Path) else raw
    return config
def _start_wandb(args: argparse.Namespace):
    """Start a Weights & Biases run for this training job, or return ``None``.

    Nothing happens unless ``args.wandb`` is truthy. Otherwise the W&B
    directories default to ``<output_dir>/wandb`` (unless already set in the
    environment), a run is initialised from the CLI arguments, and every
    metric is configured to use ``epoch`` as its step axis.

    Returns:
        The object returned by ``wandb.init``, or ``None`` when logging is off.
    """
    if not args.wandb:
        return None

    output_root = Path(args.output_dir)
    wandb_root = output_root / "wandb"
    wandb_root.mkdir(parents=True, exist_ok=True)
    # setdefault: respect any locations the user already exported.
    for env_name, location in (
        ("WANDB_DIR", wandb_root),
        ("WANDB_CACHE_DIR", wandb_root / "cache"),
        ("WANDB_CONFIG_DIR", wandb_root / "config"),
    ):
        os.environ.setdefault(env_name, str(location))

    # Imported lazily so wandb is only needed when logging is enabled.
    import wandb

    stripped = [piece.strip() for piece in args.wandb_tags.split(",")]
    run = wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity or None,
        name=args.wandb_name or None,
        group=args.wandb_group or None,
        mode=args.wandb_mode,
        tags=[tag for tag in stripped if tag],
        config=_wandb_config(args),
        dir=str(output_root),
    )
    wandb.define_metric("epoch")
    wandb.define_metric("*", step_metric="epoch")
    return run
def _log_artifact(run: Any, path: Path, name: str, artifact_type: str) -> None:
    """Upload the single file at ``path`` to ``run`` as a W&B artifact.

    This is a no-op when there is no active run (``run is None``) or when
    ``path`` does not exist, so callers can invoke it unconditionally.
    """
    if run is not None and path.exists():
        import wandb  # lazy: only needed when a run is active

        artifact = wandb.Artifact(name=name, type=artifact_type)
        artifact.add_file(str(path))
        run.log_artifact(artifact)
def _tensor_rgb(image: torch.Tensor) -> np.ndarray:
    """Convert a channels-first image tensor into a channels-last uint8 array.

    The values are clamped to ``[-1, 1]`` and then mapped linearly to ``[0, 255]``.
    The result is truncated, not rounded. This assumes the dataset normalises
    images to ``[-1, 1]``; TODO confirm against MonogramSegmentationDataset.
    """
    unit = image.detach().cpu().clamp(-1, 1) * 0.5 + 0.5  # [-1, 1] -> [0, 1]
    channels_last = unit.permute(1, 2, 0).numpy()
    return (channels_last * 255).astype("uint8")
def _tensor_mask(mask: torch.Tensor) -> np.ndarray:
    """Convert a mask tensor with a leading singleton dim into a uint8 array.

    The values are clamped to ``[0, 1]`` and scaled to ``[0, 255]`` (truncated).
    Only a leading dimension of size 1 is dropped; other shapes pass through.
    """
    plane = mask.detach().cpu().clamp(0, 1).squeeze(0)
    return (plane.numpy() * 255).astype("uint8")
def _overlay(image: np.ndarray, mask: np.ndarray, color: tuple[int, int, int]) -> np.ndarray:
    """Tint the pixels of ``image`` selected by ``mask`` with ``color``.

    Selected pixels (``mask > 0``) become ``0.55 * pixel + 0.45 * color``,
    truncated to uint8. All other pixels are returned unchanged.

    Args:
        image: ``(H, W, C)`` array with at least 3 channels. In this file it is
            the uint8 RGB output of ``_tensor_rgb``.
        mask: ``(H, W)`` array; any positive value selects the pixel.
        color: RGB tint; only the first three entries are used.
    """
    solid = np.zeros_like(image)
    for channel in range(3):
        solid[..., channel] = color[channel]
    blended = (0.55 * image + 0.45 * solid).astype("uint8")
    selected = mask[..., None] > 0  # add a channel axis so it broadcasts over C
    return np.where(selected, blended, image)
| def _save_validation_preview( | |
| model: torch.nn.Module, | |
| loader: DataLoader, | |
| device: torch.device, | |
| threshold: float, | |
| out_dir: Path, | |
| epoch: int, | |
| max_images: int, | |
| ) -> Path | None: | |
| if max_images <= 0: | |
| return None | |
| model.eval() | |
| try: | |
| batch = next(iter(loader)) | |
| except StopIteration: | |
| model.train() | |
| return None | |
| image = batch["image"].to(device) | |
| logits = model(image) | |
| pred = (torch.sigmoid(logits) >= threshold).float() | |
| count = min(max_images, image.size(0)) | |
| tile_w = int(image.shape[-1]) | |
| tile_h = int(image.shape[-2]) | |
| label_h = 26 | |
| rows: list[Image.Image] = [] | |
| for idx in range(count): | |
| rgb = _tensor_rgb(batch["image"][idx]) | |
| gt = _tensor_mask(batch["mask"][idx]) | |
| pred_mask = _tensor_mask(pred[idx]) | |
| panels = [ | |
| Image.fromarray(rgb), | |
| Image.fromarray(gt).convert("RGB"), | |
| Image.fromarray(pred_mask).convert("RGB"), | |
| Image.fromarray(_overlay(rgb, pred_mask, (255, 0, 0))), | |
| ] | |
| row = Image.new("RGB", (tile_w * len(panels), tile_h + label_h), "white") | |
| for panel_idx, panel in enumerate(panels): | |
| row.paste(panel.resize((tile_w, tile_h)), (panel_idx * tile_w, 0)) | |
| draw = ImageDraw.Draw(row) | |
| caption = f"{batch['sample_id'][idx]} | {batch['collection'][idx]} | {batch['quality_label'][idx]}" | |
| draw.text((4, tile_h + 6), caption, fill=(0, 0, 0)) | |
| rows.append(row) | |
| preview_dir = out_dir / "previews" | |
| preview_dir.mkdir(parents=True, exist_ok=True) | |
| sheet = Image.new("RGB", (tile_w * 4, (tile_h + label_h) * len(rows)), "white") | |
| for idx, row in enumerate(rows): | |
| sheet.paste(row, (0, idx * (tile_h + label_h))) | |
| path = preview_dir / f"val_epoch_{epoch:04d}.jpg" | |
| sheet.save(path, quality=95) | |
| model.train() | |
| return path | |
def _device(name: str) -> torch.device:
    """Resolve a device string such as ``"cuda"``, ``"cuda:1"`` or ``"cpu"``.

    Any CUDA request, with or without an explicit index, falls back to the CPU
    when CUDA is unavailable. A ``RuntimeWarning`` is emitted so the fallback
    is never silent. Other device strings are returned as ``torch.device(name)``.
    """
    device = torch.device(name)
    if device.type == "cuda" and not torch.cuda.is_available():
        warnings.warn(
            f"CUDA device {name!r} requested but CUDA is unavailable; falling back to CPU.",
            RuntimeWarning,
            stacklevel=2,
        )
        return torch.device("cpu")
    return device
def _loader(
    csv_path: Path,
    batch_size: int,
    image_size: int,
    augment: bool,
    limit: int | None,
    shuffle: bool,
    num_workers: int,
) -> DataLoader:
    """Build a ``DataLoader`` over the split listed in ``csv_path``.

    The CSV is read with pandas and wrapped in ``MonogramSegmentationDataset``.
    When ``limit`` is given, only the first ``limit`` dataset indices are kept
    (fewer if the dataset is smaller). Batches are assembled with
    ``collate_batch``, and host memory is pinned only when CUDA is available.
    """
    dataset = MonogramSegmentationDataset(pd.read_csv(csv_path), image_size=image_size, augment=augment)
    if limit is not None:
        kept = min(limit, len(dataset))
        dataset = Subset(dataset, list(range(kept)))
    return DataLoader(
        dataset,
        collate_fn=collate_batch,
        pin_memory=torch.cuda.is_available(),
        num_workers=num_workers,
        shuffle=shuffle,
        batch_size=batch_size,
    )
def validate(model: torch.nn.Module, loader: DataLoader, device: torch.device, threshold: float) -> dict[str, float]:
    """Average the ``binary_metrics`` rows for ``model`` over every batch of ``loader``.

    The model runs in eval mode with autograd disabled. It is always switched
    back to train mode afterwards, even if a batch raises.

    Returns:
        The mean of each metric key across all rows returned by
        ``binary_metrics``, using the keys of the first row. Returns an empty
        dict when the loader yields no rows. Presumably each row is one
        sample's metrics; confirm in ``src.metrics``.
    """
    model.eval()
    rows: list[dict[str, float]] = []
    try:
        # Evaluation needs no gradients; skipping autograd avoids storing an
        # activation graph for every batch.
        with torch.no_grad():
            for batch in loader:
                image = batch["image"].to(device)
                mask = batch["mask"].to(device)
                rows.extend(binary_metrics(model(image), mask, threshold))
    finally:
        model.train()
    if not rows:
        return {}
    return {key: float(sum(row[key] for row in rows) / len(rows)) for key in rows[0]}
def train(args: argparse.Namespace) -> Path:
    """Train a segmentation model configured by parsed CLI ``args``.

    The function:
    - builds the train and val loaders from the CSV splits;
    - constructs the model selected by ``args.model``;
    - optimises it with AdamW on ``SegmentationLoss``, whose terms are weighted
      by ``args.bce_weight``, ``args.dice_weight`` and ``args.cldice_weight``.

    After every epoch it:
    - validates the model;
    - appends the averaged train losses and val metrics to
      ``<output_dir>/history.json``;
    - optionally logs to Weights & Biases;
    - writes ``best.pt`` whenever ``val_iou`` improves.

    ``last.pt`` is written once training finishes.

    Validation previews are logged on the first epoch, the last epoch, and
    every ``args.wandb_sample_every`` epochs. A value <= 0 means only the first
    and last epochs (instead of raising ZeroDivisionError).

    Returns:
        ``<output_dir>/best.pt`` if this run saved a best checkpoint. Otherwise
        ``<output_dir>/last.pt``, which covers the case where validation never
        reported ``val_iou`` (for example an empty validation split). The
        caller therefore never gets a path this run did not write.
    """
    device = _device(args.device)
    train_loader = _loader(Path(args.train_csv), args.batch_size, args.image_size, True, args.limit_train, True, args.num_workers)
    val_loader = _loader(Path(args.val_csv), args.batch_size, args.image_size, False, args.limit_val, False, args.num_workers)

    model_kwargs: dict[str, Any] = {}
    if args.model == "unet":
        model_kwargs["encoder_name"] = args.encoder_name
        model_kwargs["encoder_weights"] = None if args.no_pretrained else args.encoder_weights
    elif args.model == "segformer":
        model_kwargs["model_name"] = args.segformer_model_name
        model_kwargs["pretrained"] = args.segformer_pretrained
        model_kwargs["freeze_encoder"] = args.segformer_freeze_encoder
        model_kwargs["unfreeze_last_n_blocks"] = args.segformer_unfreeze_last_n_blocks
    model = make_model(args.model, **model_kwargs).to(device)
    loss_fn = SegmentationLoss(args.bce_weight, args.dice_weight, args.cldice_weight)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    best_path = out_dir / "best.pt"
    last_path = out_dir / "last.pt"
    history_path = out_dir / "history.json"
    run = _start_wandb(args)
    best_iou = -1.0
    # Tracks whether *this* run wrote best.pt. A file left over from an earlier
    # run in the same output_dir must not be returned or uploaded.
    saved_best = False
    history: list[dict[str, float | int]] = []
    try:
        for epoch in range(1, args.epochs + 1):
            totals: dict[str, float] = {}
            steps = 0
            for batch in tqdm(train_loader, desc=f"epoch {epoch}/{args.epochs}"):
                image = batch["image"].to(device)
                mask = batch["mask"].to(device)
                optimizer.zero_grad(set_to_none=True)
                loss, loss_parts = loss_fn(model(image), mask)
                loss.backward()
                optimizer.step()
                for key, value in loss_parts.items():
                    totals[key] = totals.get(key, 0.0) + float(value.item())
                steps += 1

            train_row = {f"train_{key}": value / max(steps, 1) for key, value in totals.items()}
            val_row = {f"val_{key}": value for key, value in validate(model, val_loader, device, args.threshold).items()}
            row = {"epoch": epoch, **train_row, **val_row}
            history.append(row)
            history_path.write_text(json.dumps(history, indent=2))

            if run is not None:
                log_row = dict(row)
                sample_every = args.wandb_sample_every
                # Guard the modulo: a non-positive interval disables periodic samples.
                periodic = sample_every > 0 and epoch % sample_every == 0
                if args.wandb_log_samples and (epoch == 1 or periodic or epoch == args.epochs):
                    preview_path = _save_validation_preview(
                        model,
                        val_loader,
                        device,
                        args.threshold,
                        out_dir,
                        epoch,
                        args.wandb_sample_count,
                    )
                    if preview_path is not None:
                        import wandb

                        log_row["val_examples"] = wandb.Image(str(preview_path), caption=f"epoch {epoch}: image | gt | pred | overlay")
                run.log(log_row)

            if val_row.get("val_iou", -1.0) > best_iou:
                best_iou = float(val_row["val_iou"])
                torch.save({"model": model.state_dict(), "args": vars(args), "epoch": epoch}, best_path)
                saved_best = True
                if run is not None:
                    run.summary["best_val_iou"] = best_iou
                    run.summary["best_epoch"] = epoch

        torch.save({"model": model.state_dict(), "args": vars(args), "epoch": args.epochs}, last_path)
        if args.wandb_artifacts:
            artifact_base = args.wandb_artifact_name or f"{args.model}-{Path(args.output_dir).name}"
            if saved_best:
                _log_artifact(run, best_path, f"{artifact_base}-best", "model")
            _log_artifact(run, last_path, f"{artifact_base}-last", "model")
            _log_artifact(run, history_path, f"{artifact_base}-history", "training-history")
    finally:
        if run is not None:
            run.finish()
    return best_path if saved_best else last_path
def _build_parser() -> argparse.ArgumentParser:
    """Return the argument parser for the training command line."""
    parser = argparse.ArgumentParser(description="Train a monogram stroke segmentation model.")
    add = parser.add_argument

    # Model choice and data splits.
    add("--model", choices=["unet", "segformer"], default="unet")
    add("--train-csv", default="splits/e3_full_data/train.csv")
    add("--val-csv", default="splits/e3_full_data/val.csv")
    add("--output-dir", default="outputs/train")

    # Optimisation / runtime.
    add("--image-size", type=int, default=512)
    add("--batch-size", type=int, default=4)
    add("--num-workers", type=int, default=4)
    add("--epochs", type=int, default=20)
    add("--lr", type=float, default=1e-4)
    add("--weight-decay", type=float, default=1e-4)
    add("--threshold", type=float, default=0.5)
    add("--device", default="cuda")
    add("--limit-train", type=int)
    add("--limit-val", type=int)

    # UNet encoder options.
    add("--no-pretrained", action="store_true")
    add("--encoder-name", default="resnet34")
    add("--encoder-weights", default="imagenet")

    # SegFormer options.
    add("--segformer-model-name", default="nvidia/mit-b2")
    add("--segformer-pretrained", action=argparse.BooleanOptionalAction, default=True)
    add("--segformer-freeze-encoder", action=argparse.BooleanOptionalAction, default=False)
    add("--segformer-unfreeze-last-n-blocks", type=int, default=2)

    # Loss weights.
    add("--bce-weight", type=float, default=1.0)
    add("--dice-weight", type=float, default=1.0)
    add("--cldice-weight", type=float, default=0.5)

    # Weights & Biases logging.
    add("--wandb", action="store_true", help="Enable Weights & Biases logging.")
    add("--wandb-project", default="segment-monograms")
    add("--wandb-entity", default="")
    add("--wandb-name", default="")
    add("--wandb-group", default="")
    add("--wandb-mode", choices=["online", "offline", "disabled"], default="online")
    add("--wandb-tags", default="")
    add("--wandb-artifacts", action="store_true", help="Upload best.pt, last.pt, and history.json as W&B artifacts.")
    add("--wandb-artifact-name", default="")
    add("--wandb-log-samples", action="store_true", help="Log validation image/mask/prediction previews to W&B.")
    add("--wandb-sample-every", type=int, default=5)
    add("--wandb-sample-count", type=int, default=4)
    return parser


def main() -> None:
    """Parse the command line, run training, and print the returned checkpoint path."""
    args = _build_parser().parse_args()
    checkpoint = train(args)
    print(checkpoint)
if __name__ == "__main__":
    # Script entry point; importing this module does not start training.
    main()