"""Sample manifest construction and a torch ``Dataset`` for monogram segmentation.

The manifest is a ``pandas.DataFrame`` with one row per (seal image, binary
mask, annotation) triple, gathered from the "adolfo" collection and from the
sub-collections found under the "other" root.
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset

from .transforms import (
    annotation_path,
    apply_pair_augment,
    crop_pair,
    find_image_by_stem,
    image_transform,
    list_images,
    mask_transform,
    quality_key,
)

DEFAULT_ADOLFO_ROOT = Path("/scratch/mahantas/datasets/MonogramSchema_Seal_pairs")
DEFAULT_OTHER_ROOT = Path("/scratch/mahantas/datasets/285_monograms_segmentations")
DEFAULT_METADATA_CSV = Path("/scratch/mahantas/datasets/metadata_merged.csv")


@dataclass(frozen=True)
class DatasetRoots:
    """Immutable bundle of the default dataset locations."""

    adolfo_root: Path = DEFAULT_ADOLFO_ROOT
    other_root: Path = DEFAULT_OTHER_ROOT
    metadata_csv: Path = DEFAULT_METADATA_CSV


def _load_metadata(path: Path | None) -> dict[str, dict[str, Any]]:
    """Load the merged metadata CSV into a dict keyed by monogram id.

    Args:
        path: Location of the CSV, or ``None``.

    Returns:
        A mapping from the stripped string form of ``monogram_id`` to the full
        CSV record. Empty when ``path`` is ``None`` or does not exist. If the
        same id appears more than once, the last record wins.

    Raises:
        ValueError: If the CSV lacks ``monogram_id`` or ``quality_label``.
    """
    if path is None or not path.exists():
        return {}
    frame = pd.read_csv(path)
    required = {"monogram_id", "quality_label"}
    missing = required.difference(frame.columns)
    if missing:
        raise ValueError(f"{path} is missing required columns: {sorted(missing)}")
    return {str(row["monogram_id"]).strip(): row for row in frame.to_dict("records")}


def _clean_label(value: Any) -> str:
    """Return ``value`` as a string, mapping missing values (NaN/None) to ``""``."""
    if pd.isna(value):
        return ""
    return str(value)


def _metadata_value(
    metadata: dict[str, dict[str, Any]],
    source_id: str,
    key: str,
    default: Any = "",
) -> Any:
    """Look up ``key`` for ``source_id`` in ``metadata``.

    ``default`` is returned when the id or key is absent, and also when the
    stored value is missing according to ``pd.isna``.
    """
    value = metadata.get(source_id, {}).get(key, default)
    return default if pd.isna(value) else value


def _adolfo_manifest(root: Path, merged_metadata: dict[str, dict[str, Any]]) -> pd.DataFrame:
    """Build manifest rows for the "adolfo" collection.

    Reads ``root / "metadata.csv"`` and, for each record, resolves the seal
    image (``seals``), the mask (``binary_masks``) and the annotation (``ann``).
    Values from ``merged_metadata`` take precedence over the per-collection
    CSV. Records whose quality key is ``"q3"`` are skipped, as are records for
    which any of the three files cannot be found.

    Raises:
        ValueError: If the number of usable rows is not exactly 332.
    """
    metadata = pd.read_csv(root / "metadata.csv")
    rows: list[dict[str, Any]] = []
    for row in metadata.to_dict("records"):
        sample_id = str(row["monogram_id"]).strip()
        quality = quality_key(
            _metadata_value(merged_metadata, sample_id, "quality_label", row.get("quality_label"))
        )
        if quality == "q3":
            continue
        try:
            seal_path = find_image_by_stem(root / "seals", sample_id)
            mask_path = find_image_by_stem(root / "binary_masks", sample_id)
            ann_path = annotation_path(root / "ann", seal_path)
        except FileNotFoundError:
            # Best-effort: a record without a complete file triple is dropped;
            # the count check below catches unexpected losses.
            continue
        rows.append(
            {
                "sample_id": f"adolfo__{sample_id}",
                "source_id": sample_id,
                "collection": "adolfo",
                "seal_path": str(seal_path),
                "mask_path": str(mask_path),
                "ann_path": str(ann_path),
                "quality_label": quality,
                "monogram_type": _metadata_value(
                    merged_metadata, sample_id, "monogram_type", row.get("monogram_type", "")
                ),
                "broad_monogram_type": _metadata_value(
                    merged_metadata,
                    sample_id,
                    "broad_monogram_type",
                    row.get("broad_monogram_type", ""),
                ),
            }
        )
    frame = pd.DataFrame(rows)
    # Sanity check pinned to the expected size of this collection.
    if len(frame) != 332:
        raise ValueError(f"Expected 332 usable Adolfo q0-q2 masks, found {len(frame)}")
    return frame


def _other_manifest(root: Path, merged_metadata: dict[str, dict[str, Any]]) -> pd.DataFrame:
    """Build manifest rows for every sub-collection directory under ``root``.

    Each sub-directory is expected to contain ``img``, ``binary_masks`` and
    ``ann`` folders. Sub-directories are visited in sorted order. Images whose
    quality key (taken from ``merged_metadata``) is ``"q3"`` are skipped.
    Unlike the adolfo collection, a missing mask or annotation is not caught
    here and propagates from the lookup helper.

    Raises:
        ValueError: If the number of usable rows is not exactly 273.
    """
    rows: list[dict[str, Any]] = []
    for collection_dir in sorted(path for path in root.iterdir() if path.is_dir()):
        collection = collection_dir.name
        image_dir = collection_dir / "img"
        mask_dir = collection_dir / "binary_masks"
        ann_dir = collection_dir / "ann"
        for image_path in list_images(image_dir):
            # Prefix with the collection name so ids stay unique across collections.
            sample_id = f"{collection}__{image_path.stem}"
            source_id = image_path.stem
            quality = quality_key(_metadata_value(merged_metadata, source_id, "quality_label", ""))
            if quality == "q3":
                continue
            mask_path = find_image_by_stem(mask_dir, image_path.stem)
            ann_path = annotation_path(ann_dir, image_path)
            rows.append(
                {
                    "sample_id": sample_id,
                    "source_id": source_id,
                    "collection": collection,
                    "seal_path": str(image_path),
                    "mask_path": str(mask_path),
                    "ann_path": str(ann_path),
                    "quality_label": quality,
                    "monogram_type": _metadata_value(merged_metadata, source_id, "monogram_type", ""),
                    "broad_monogram_type": _metadata_value(
                        merged_metadata, source_id, "broad_monogram_type", ""
                    ),
                }
            )
    frame = pd.DataFrame(rows)
    # Sanity check pinned to the expected size of these collections.
    if len(frame) != 273:
        raise ValueError(f"Expected 273 other-collection q0-q2 masks, found {len(frame)}")
    return frame


def build_sample_manifest(
    adolfo_root: str | Path = DEFAULT_ADOLFO_ROOT,
    other_root: str | Path = DEFAULT_OTHER_ROOT,
    metadata_csv: str | Path | None = DEFAULT_METADATA_CSV,
) -> pd.DataFrame:
    """Build the combined manifest of supervised samples.

    Args:
        adolfo_root: Root directory of the adolfo collection.
        other_root: Directory whose sub-directories are the other collections.
        metadata_csv: Merged metadata CSV; ``None`` disables the merged lookup.

    Returns:
        The adolfo rows followed by the other-collection rows, re-indexed
        from zero.

    Raises:
        ValueError: If a per-collection or the total (605) row count is not
            the expected one, or if any ``sample_id`` is duplicated.
    """
    metadata_path = Path(metadata_csv) if metadata_csv is not None else None
    merged_metadata = _load_metadata(metadata_path)
    frame = pd.concat(
        [
            _adolfo_manifest(Path(adolfo_root), merged_metadata),
            _other_manifest(Path(other_root), merged_metadata),
        ],
        ignore_index=True,
        sort=False,
    )
    if len(frame) != 605:
        raise ValueError(f"Expected 605 total q0-q2 supervised masks, found {len(frame)}")
    if frame["sample_id"].duplicated().any():
        dupes = frame.loc[frame["sample_id"].duplicated(), "sample_id"].tolist()
        raise ValueError(f"Duplicate sample IDs: {dupes[:10]}")
    return frame


class MonogramSegmentationDataset(Dataset):
    """Map-style dataset yielding cropped image/mask tensors plus row metadata.

    Args:
        frame: Manifest with at least the ``seal_path``, ``mask_path``,
            ``ann_path``, ``sample_id`` and ``collection`` columns.
        image_size: Forwarded to ``crop_pair``.
        crop_padding: Forwarded to ``crop_pair``.
        augment: When true, enables crop jitter and ``apply_pair_augment``.
    """

    def __init__(
        self,
        frame: pd.DataFrame,
        image_size: int = 512,
        crop_padding: float = 0.08,
        augment: bool = False,
    ):
        # Reset the index so positional access matches __len__ after filtering.
        self.frame = frame.reset_index(drop=True)
        self.image_size = image_size
        self.crop_padding = crop_padding
        self.augment = augment
        self.image_t = image_transform()
        self.mask_t = mask_transform()

    def __len__(self) -> int:
        """Return the number of manifest rows."""
        return len(self.frame)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Load, crop, optionally augment and tensorize the sample at ``idx``.

        Returns:
            A dict with ``image`` and ``mask`` (outputs of the image and mask
            transforms) and string metadata fields copied from the row.
        """
        row = self.frame.iloc[idx].to_dict()

        # Image.open is lazy and keeps the file handle open. convert() loads
        # the pixels into a new image, so the file can be closed right after
        # instead of waiting for garbage collection.
        with Image.open(row["seal_path"]) as seal_file:
            image = seal_file.convert("RGB")
        with Image.open(row["mask_path"]) as mask_file:
            mask = mask_file.convert("L")

        # Jitter is applied only when augmenting.
        crop_jitter = 0.06 if self.augment else 0.0
        crop_scale_jitter = 0.05 if self.augment else 0.0
        image, mask = crop_pair(
            image,
            mask,
            row.get("ann_path"),
            self.image_size,
            self.crop_padding,
            crop_jitter=crop_jitter,
            crop_scale_jitter=crop_scale_jitter,
        )
        if self.augment:
            image, mask = apply_pair_augment(image, mask)

        image_tensor = self.image_t(image)
        mask_tensor = self.mask_t(mask)
        return {
            "image": image_tensor,
            "mask": mask_tensor,
            "sample_id": str(row["sample_id"]),
            "collection": str(row["collection"]),
            "quality_label": _clean_label(row.get("quality_label", "")),
            "monogram_type": _clean_label(row.get("monogram_type", "")),
            "broad_monogram_type": _clean_label(row.get("broad_monogram_type", "")),
            "seal_path": str(row["seal_path"]),
            "mask_path": str(row["mask_path"]),
            "ann_path": str(row["ann_path"]),
        }


def collate_batch(batch: list[dict[str, Any]]) -> dict[str, Any]:
    """Collate dataset items into a batch.

    ``image`` and ``mask`` are stacked with ``torch.stack``; every other field
    is gathered into a plain list in batch order.
    """
    return {
        "image": torch.stack([item["image"] for item in batch]),
        "mask": torch.stack([item["mask"] for item in batch]),
        "sample_id": [item["sample_id"] for item in batch],
        "collection": [item["collection"] for item in batch],
        "quality_label": [item["quality_label"] for item in batch],
        "monogram_type": [item["monogram_type"] for item in batch],
        "broad_monogram_type": [item["broad_monogram_type"] for item in batch],
        "seal_path": [item["seal_path"] for item in batch],
        "mask_path": [item["mask_path"] for item in batch],
        "ann_path": [item["ann_path"] for item in batch],
    }