Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| import re | |
| from pathlib import Path | |
| from shutil import copyfile | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from PIL import Image, ImageDraw, ImageFont | |
| from torch.utils.data import DataLoader | |
| from src.data.dataset import MonogramSegmentationDataset, collate_batch | |
| from src.metrics import binary_metrics | |
| from src.models import make_model | |
| def _tensor_rgb(image: torch.Tensor) -> np.ndarray: | |
| return ((image.detach().cpu().clamp(-1, 1) * 0.5 + 0.5).permute(1, 2, 0).numpy() * 255).astype("uint8") | |
| def _tensor_mask(mask: torch.Tensor) -> np.ndarray: | |
| return (mask.detach().cpu().clamp(0, 1).squeeze(0).numpy() * 255).astype("uint8") | |
| def _overlay(image: np.ndarray, mask: np.ndarray, color: tuple[int, int, int]) -> np.ndarray: | |
| color_arr = np.zeros_like(image) | |
| color_arr[..., 0] = color[0] | |
| color_arr[..., 1] = color[1] | |
| color_arr[..., 2] = color[2] | |
| return np.where(mask[..., None] > 0, (0.55 * image + 0.45 * color_arr).astype("uint8"), image) | |
| def _safe_name(value: str) -> str: | |
| return re.sub(r"[^A-Za-z0-9_.-]+", "_", value).strip("_") or "sample" | |
def _font(size: int, bold: bool = False) -> ImageFont.ImageFont:
    """Load a DejaVu TrueType font at ``size``, falling back to Pillow's default font.

    With ``bold=True`` the bold face is tried first, then the regular face.
    Fonts that cannot be opened (``OSError``) are skipped.
    """
    if bold:
        candidates = ("DejaVuSans-Bold.ttf", "DejaVuSans.ttf")
    else:
        candidates = ("DejaVuSans.ttf",)
    for candidate in candidates:
        try:
            return ImageFont.truetype(candidate, size=size)
        except OSError:
            pass
    return ImageFont.load_default()
| def _draw_centered(draw: ImageDraw.ImageDraw, box: tuple[int, int, int, int], text: str, font: ImageFont.ImageFont, fill: str) -> None: | |
| left, top, right, bottom = box | |
| text_box = draw.textbbox((0, 0), text, font=font) | |
| text_w = text_box[2] - text_box[0] | |
| text_h = text_box[3] - text_box[1] | |
| draw.text((left + (right - left - text_w) / 2, top + (bottom - top - text_h) / 2), text, font=font, fill=fill) | |
| def _primary_visual_group(setting: str, row: pd.Series) -> str: | |
| quality = str(row.get("quality_label", "") or "unlabeled") | |
| collection = str(row.get("collection", "") or "unlabeled") | |
| if setting == "e1_cross_collection": | |
| return collection | |
| if setting == "e2_quality_stratified": | |
| return quality | |
| if setting == "e3_full_data": | |
| return f"{collection}__{quality}" | |
| return collection | |
| def _selected_visual_sample_ids(frame: pd.DataFrame, setting: str, count: int) -> set[str]: | |
| if count <= 0 or frame.empty: | |
| return set() | |
| group_to_ids: dict[str, list[str]] = {} | |
| for _, row in frame.iterrows(): | |
| group_to_ids.setdefault(_primary_visual_group(setting, row), []).append(str(row["sample_id"])) | |
| selected: list[str] = [] | |
| max_group_len = max(len(ids) for ids in group_to_ids.values()) | |
| for offset in range(max_group_len): | |
| for group in sorted(group_to_ids): | |
| ids = group_to_ids[group] | |
| if offset < len(ids): | |
| selected.append(ids[offset]) | |
| if len(selected) >= count: | |
| return set(selected) | |
| return set(selected) | |
def _sample_visual(
    image: torch.Tensor,
    mask: torch.Tensor,
    pred: torch.Tensor,
    sample_id: str,
    collection: str,
    quality_label: str,
) -> Image.Image:
    """Render one sample as a horizontal strip of four captioned panels.

    The panels, left to right, are:

    * the input image;
    * the ground-truth mask;
    * the predicted mask;
    * the input with predicted pixels tinted red.

    A light band across the top shows the sample id and its
    collection/quality. A dark caption bar sits above each panel.
    """
    rgb = _tensor_rgb(image)
    pred_mask = _tensor_mask(pred)
    labels = ("Input", "Ground Truth", "Prediction", "Overlay")
    tiles = (
        Image.fromarray(rgb),
        Image.fromarray(_tensor_mask(mask)).convert("RGB"),
        Image.fromarray(pred_mask).convert("RGB"),
        Image.fromarray(_overlay(rgb, pred_mask, (255, 0, 0))),
    )

    # Band heights and font sizes scale with the input tile, with fixed minimums.
    tile_w, tile_h = tiles[0].size
    header_h = max(34, tile_h // 15)
    meta_h = max(46, tile_h // 12)
    title_font = _font(max(18, tile_w // 26), bold=True)
    meta_font = _font(max(16, tile_w // 32))

    canvas = Image.new("RGB", (tile_w * len(tiles), tile_h + header_h + meta_h), "white")
    pen = ImageDraw.Draw(canvas)

    # Metadata band spanning the full width: id on the first line, collection/quality below.
    pen.rectangle((0, 0, canvas.width, meta_h), fill=(245, 247, 250))
    pen.text((12, 8), f"{sample_id}", font=title_font, fill=(18, 24, 38))
    second_line_y = 8 + max(20, tile_w // 24)
    pen.text(
        (12, second_line_y),
        f"collection: {collection} quality: {quality_label}",
        font=meta_font,
        fill=(66, 75, 92),
    )

    header_bottom = meta_h + header_h
    for position, (label, tile) in enumerate(zip(labels, tiles)):
        x0 = position * tile_w
        caption_box = (x0, meta_h, x0 + tile_w, header_bottom)
        pen.rectangle(caption_box, fill=(32, 37, 50))
        _draw_centered(pen, caption_box, label, title_font, "white")
        canvas.paste(tile, (x0, header_bottom))
    return canvas
| def _contact_sheet(visuals: list[Image.Image], path: Path) -> None: | |
| if not visuals: | |
| return | |
| path.parent.mkdir(parents=True, exist_ok=True) | |
| width = max(image.width for image in visuals) | |
| height = sum(image.height for image in visuals) | |
| sheet = Image.new("RGB", (width, height), "white") | |
| top = 0 | |
| for image in visuals: | |
| sheet.paste(image, (0, top)) | |
| top += image.height | |
| sheet.save(path, quality=95) | |
| def _experiment_groups(setting: str, collection: str, quality_label: str) -> dict[str, str]: | |
| quality = quality_label if quality_label and quality_label.lower() != "nan" else "unlabeled" | |
| if setting == "e1_cross_collection": | |
| return {"by_collection": collection} | |
| if setting == "e2_quality_stratified": | |
| return {"by_quality": quality} | |
| if setting == "e3_full_data": | |
| return { | |
| "by_collection": collection, | |
| "by_quality": quality, | |
| "by_collection_quality": f"{collection}__{quality}", | |
| } | |
| return {"by_collection": collection, "by_quality": quality} | |
def _save_grouped_visual(
    visual: Image.Image,
    visual_path: Path,
    visual_dir: Path,
    visual_name: str,
    groups: dict[str, str],
    grouped_visuals: dict[tuple[str, str], list[Image.Image]],
) -> None:
    """Copy an already-saved visual into one sub-folder per experiment group.

    For every ``group_name -> group_value`` pair, the file at ``visual_path``
    is copied to ``visual_dir/<group_name>/<safe group_value>/<visual_name>``.
    The in-memory ``visual`` is also appended to
    ``grouped_visuals[(group_name, group_value)]``, which is mutated in place.
    """
    for name, value in groups.items():
        destination = visual_dir.joinpath(name, _safe_name(value))
        destination.mkdir(parents=True, exist_ok=True)
        copyfile(visual_path, destination / visual_name)
        bucket = grouped_visuals.setdefault((name, value), [])
        bucket.append(visual)
def _save_eval_visuals(visuals: list[Image.Image], grouped_visuals: dict[tuple[str, str], list[Image.Image]], out_dir: Path) -> None:
    """Write the overall contact sheet plus one contact sheet per experiment group.

    Sheets are written to:

    * ``out_dir/visualizations/contact_sheet.jpg``
    * ``out_dir/visualizations/<group_name>/<safe group_value>/contact_sheet.jpg``
    """
    root = out_dir / "visualizations"
    sheet_name = "contact_sheet.jpg"
    _contact_sheet(visuals, root / sheet_name)
    for key, members in grouped_visuals.items():
        group_name, group_value = key
        _contact_sheet(members, root / group_name / _safe_name(group_value) / sheet_name)
# Metric columns (as produced by binary_metrics) that are averaged in the summary tables.
_METRIC_COLUMNS = ("iou", "dice", "cldice", "precision", "recall")


def _resolve_device(requested: str) -> torch.device:
    """Return ``torch.device(requested)``, falling back to CPU when CUDA is unavailable.

    Any CUDA spec (``"cuda"``, ``"cuda:1"``, ...) falls back. Other device
    strings are used as given.
    """
    if requested.startswith("cuda") and not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device(requested)


def _load_model(args: argparse.Namespace, device: torch.device) -> torch.nn.Module:
    """Rebuild the checkpoint's model on ``device``, load its weights and switch to eval mode.

    The architecture comes from ``--model`` or, failing that, from the
    ``"args"`` entry saved in the checkpoint (default ``"unet"``). For unet and
    segformer, pretrained downloads are disabled because the checkpoint
    supplies the weights. Key mismatches are tolerated (``strict=False``) but
    reported with a warning, so a wrong architecture does not silently
    produce meaningless metrics.
    """
    import warnings

    # NOTE: torch.load unpickles the file; only evaluate checkpoints you trust.
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model_args = checkpoint.get("args", {})
    model_name = args.model or model_args.get("model", "unet")
    model_kwargs: dict = {}
    if model_name == "unet":
        model_kwargs["encoder_name"] = model_args.get("encoder_name", "resnet34")
        model_kwargs["encoder_weights"] = None
    elif model_name == "segformer":
        model_kwargs["model_name"] = model_args.get("segformer_model_name", "nvidia/mit-b2")
        model_kwargs["pretrained"] = False
        model_kwargs["freeze_encoder"] = bool(model_args.get("segformer_freeze_encoder", False))
        model_kwargs["unfreeze_last_n_blocks"] = int(model_args.get("segformer_unfreeze_last_n_blocks", 2))
    model = make_model(model_name, **model_kwargs).to(device)

    incompatible = model.load_state_dict(checkpoint["model"], strict=False)
    missing = list(getattr(incompatible, "missing_keys", None) or [])
    unexpected = list(getattr(incompatible, "unexpected_keys", None) or [])
    if missing or unexpected:
        warnings.warn(
            f"Checkpoint {args.checkpoint} does not match the {model_name!r} model exactly: "
            f"{len(missing)} missing key(s) (e.g. {missing[:3]}), "
            f"{len(unexpected)} unexpected key(s) (e.g. {unexpected[:3]}).",
            stacklevel=2,
        )
    model.eval()
    return model


def _write_metric_tables(rows: list[dict], out_dir: Path) -> None:
    """Write per-sample metrics plus per-collection, per-quality and overall means as CSV files.

    Raises:
        ValueError: if ``rows`` is empty, since there is nothing to aggregate.
    """
    if not rows:
        raise ValueError("No samples were evaluated; nothing to write to the metric tables.")
    columns = list(_METRIC_COLUMNS)
    metrics = pd.DataFrame(rows)
    metrics.to_csv(out_dir / "metrics.csv", index=False)
    metrics.groupby("collection", dropna=False)[columns].mean().to_csv(out_dir / "metrics_by_collection.csv")
    # Labels are stringified in evaluate(), so a NaN label arrives here as "nan".
    # Treat "nan" and "" as "unlabeled", matching _experiment_groups, so the
    # table groups agree with the visualisation folders.
    quality = metrics["quality_label"].fillna("").astype(str)
    quality = quality.mask(quality.str.lower().isin(["", "nan"]), "unlabeled")
    (
        metrics.assign(quality_label=quality)
        .groupby("quality_label", dropna=False)[columns]
        .mean()
        .to_csv(out_dir / "metrics_by_quality.csv")
    )
    metrics[columns].mean().to_frame("mean").to_csv(out_dir / "metrics_summary.csv")


def evaluate(args: argparse.Namespace) -> None:
    """Evaluate a segmentation checkpoint on the samples listed in ``args.csv``.

    Writes into ``args.output_dir``:

    * ``metrics.csv``: one row of metrics per sample.
    * ``metrics_by_collection.csv``, ``metrics_by_quality.csv`` and
      ``metrics_summary.csv``: mean metrics.
    * ``visualizations/``: up to ``args.visualize_count`` per-sample strips
      (chosen round-robin across the groups relevant to ``args.setting``),
      per-group copies, and contact sheets.

    Raises:
        ValueError: if the CSV (after ``--limit``) contains no rows.
    """
    device = _resolve_device(args.device)
    frame = pd.read_csv(args.csv)
    if args.limit is not None:
        frame = frame.head(args.limit).copy()
    if frame.empty:
        # Fail before loading the model instead of crashing later on an empty metrics table.
        raise ValueError(f"No samples to evaluate in {args.csv} (limit={args.limit}).")
    visual_sample_ids = _selected_visual_sample_ids(frame, args.setting, args.visualize_count)
    dataset = MonogramSegmentationDataset(frame, image_size=args.image_size, augment=False)
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=0, collate_fn=collate_batch)
    model = _load_model(args, device)

    out_dir = Path(args.output_dir)
    visual_dir = out_dir / "visualizations"
    visual_dir.mkdir(parents=True, exist_ok=True)  # parents=True also creates out_dir

    visuals: list[Image.Image] = []
    grouped_visuals: dict[tuple[str, str], list[Image.Image]] = {}
    rows: list[dict] = []
    with torch.no_grad():
        for batch in loader:
            image = batch["image"].to(device)
            mask = batch["mask"].to(device)
            logits = model(image)
            pred = (torch.sigmoid(logits) >= args.threshold).float()
            metric_rows = binary_metrics(logits, mask, args.threshold)
            for idx, metrics in enumerate(metric_rows):
                sample_id = str(batch["sample_id"][idx])
                collection = str(batch["collection"][idx])
                quality_label = str(batch["quality_label"][idx])
                rows.append(
                    {
                        "sample_id": sample_id,
                        "collection": collection,
                        "quality_label": quality_label,
                        **metrics,
                    }
                )
                if sample_id not in visual_sample_ids:
                    continue
                visual = _sample_visual(batch["image"][idx], batch["mask"][idx], pred[idx], sample_id, collection, quality_label)
                visuals.append(visual)
                # Prefix with a running index so files sort in evaluation order.
                visual_name = f"{len(visuals):03d}_{_safe_name(sample_id)}.jpg"
                visual_path = visual_dir / visual_name
                visual.save(visual_path, quality=95)
                _save_grouped_visual(
                    visual,
                    visual_path,
                    visual_dir,
                    visual_name,
                    _experiment_groups(args.setting, collection, quality_label),
                    grouped_visuals,
                )

    _save_eval_visuals(visuals, grouped_visuals, out_dir)
    _write_metric_tables(rows, out_dir)
def main() -> None:
    """Parse the command-line options and run :func:`evaluate` with them."""
    cli = argparse.ArgumentParser(description="Evaluate a monogram segmentation checkpoint.")
    # (flag, add_argument keyword options), registered in this order.
    option_specs = [
        ("--checkpoint", {"required": True}),
        ("--csv", {"default": "splits/e3_full_data/test.csv"}),
        ("--output-dir", {"default": "outputs/eval"}),
        ("--model", {"choices": ["unet", "segformer"]}),
        ("--image-size", {"type": int, "default": 512}),
        ("--batch-size", {"type": int, "default": 4}),
        ("--threshold", {"type": float, "default": 0.5}),
        ("--device", {"default": "cuda"}),
        # NOTE(review): --no-pretrained is parsed but evaluate() never reads it.
        ("--no-pretrained", {"action": "store_true"}),
        ("--limit", {"type": int}),
        ("--visualize-count", {"type": int, "default": 32}),
        ("--setting", {"default": "e3_full_data"}),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    evaluate(cli.parse_args())


if __name__ == "__main__":
    main()