| |
| """Build precomputed HyperView embedding assets for the jaguar Space.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from dataclasses import dataclass |
| from datetime import datetime, timezone |
| from pathlib import Path |
| import sys |
| from typing import Any |
| from urllib.parse import urlparse |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| from PIL import Image |
| from torch.utils.data import DataLoader, Dataset |
| from tqdm import tqdm |
|
|
# Make the repository root importable so the `experiment_scripts.*` modules below
# resolve when this file is run as a standalone script (it lives two levels deep).
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))


# Project-local imports; must come after the sys.path tweak above.
from experiment_scripts.evaluate_inpainted_bgfg import (
    _load_arcface_benchmark,
    _load_lorentz,
    _load_triplet_benchmark,
)
from experiment_scripts.train_lorentz_reid import build_transforms


# Default input/output locations for the demo Space, all relative to the repo root.
DEFAULT_MANIFEST_PATH = PROJECT_ROOT / "HyperViewDemoHuggingFaceSpace/config/model_manifest.json"
DEFAULT_DATASET_ROOT = PROJECT_ROOT / "kaggle_jaguar_dataset_v2"
DEFAULT_CORESET_CSV = PROJECT_ROOT / "data/validation_coreset.csv"
DEFAULT_OUTPUT_DIR = PROJECT_ROOT / "HyperViewDemoHuggingFaceSpace/assets"
|
|
|
|
@dataclass
class LoadedModel:
    """A ready-to-run embedding model plus the preprocessing it was trained with."""

    model: Any          # eval-ready torch module that maps image batches to embeddings
    val_transform: Any  # validation-time transform (torchvision-style or albumentations)
    image_size: int     # square input resolution expected by the transform/model
|
|
|
|
class JaguarEmbeddingDataset(Dataset):
    """Dataset yielding (image_tensor, sample_id, label, filename, split_tag) per row.

    The 'foreground_only' variant zeroes out background pixels using the image's
    alpha channel as a binary mask; any other variant loads the plain RGB image.
    """

    def __init__(
        self,
        rows: list[dict[str, str]],
        images_dir: Path,
        transform: Any,
        image_variant: str,
    ):
        self.rows = rows
        self.images_dir = images_dir
        self.transform = transform
        self.image_variant = image_variant

    def __len__(self) -> int:
        return len(self.rows)

    @staticmethod
    def _is_albumentations_transform(transform: Any) -> bool:
        # Albumentations pipelines are called with an `image=` keyword array,
        # so they need a different invocation than torchvision-style callables.
        return transform.__class__.__module__.startswith("albumentations")

    def _load_image(self, filename: str) -> Image.Image:
        image_path = self.images_dir / filename
        if self.image_variant != "foreground_only":
            return Image.open(image_path).convert("RGB")
        # Treat any non-zero alpha as foreground; transparent pixels become black.
        rgba_np = np.asarray(Image.open(image_path).convert("RGBA"), dtype=np.uint8)
        rgb, alpha = rgba_np[..., :3], rgba_np[..., 3]
        cutout = np.where(alpha[..., None] > 0, rgb, 0).astype(np.uint8)
        return Image.fromarray(cutout, mode="RGB")

    def __getitem__(self, idx: int):
        row = self.rows[idx]
        image = self._load_image(row["filename"])

        if self.transform is None:
            raise ValueError("Validation transform is required for embedding extraction.")

        if self._is_albumentations_transform(self.transform):
            tensor = self.transform(image=np.array(image, dtype=np.uint8))["image"]
        else:
            tensor = self.transform(image)

        return tensor, row["sample_id"], row["label"], row["filename"], row["split_tag"]
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the HyperView asset-build pipeline."""
    parser = argparse.ArgumentParser(
        description="Build precomputed embedding artifacts for HyperView Space runtime."
    )
    add = parser.add_argument
    add(
        "--model_manifest",
        type=Path,
        default=DEFAULT_MANIFEST_PATH,
        help="Model manifest JSON defining the three demo models.",
    )
    add(
        "--dataset_root",
        type=Path,
        default=DEFAULT_DATASET_ROOT,
        help="Dataset root containing train.csv and train/ images.",
    )
    add(
        "--coreset_csv",
        type=Path,
        default=DEFAULT_CORESET_CSV,
        help="Validation coreset CSV used to tag split_tag=train/validation.",
    )
    add(
        "--output_dir",
        type=Path,
        default=DEFAULT_OUTPUT_DIR,
        help="Output directory for per-model embeddings and manifest JSON.",
    )
    add(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda"],
        help="Runtime device. CUDA-only by contract.",
    )
    add("--batch_size", type=int, default=64)
    add("--num_workers", type=int, default=4)
    add(
        "--image_variant",
        type=str,
        default="foreground_only",
        choices=["foreground_only", "full_rgb"],
    )
    add(
        "--max_samples",
        type=int,
        default=None,
        help="Optional smoke-mode sample cap for quick checks.",
    )
    return parser.parse_args()
|
|
|
|
def utc_now() -> str:
    """Current UTC time as an ISO-8601 'Z'-suffixed timestamp (second resolution)."""
    timestamp = datetime.now(timezone.utc)
    return timestamp.strftime("%Y-%m-%dT%H:%M:%SZ")
|
|
|
|
def resolve_device(device_name: str) -> torch.device:
    """Validate the requested device name and return a CUDA torch.device.

    Raises SystemExit when the name is not 'cuda' or when no CUDA device is
    present — this pipeline is GPU-only by contract.
    """
    if device_name != "cuda":
        # Fixed: previously this branch claimed "CUDA requested but not available"
        # even though CUDA was not what was requested.
        raise SystemExit(f"Unsupported device '{device_name}': this pipeline is CUDA-only.")
    if not torch.cuda.is_available():
        raise SystemExit("GPU unavailable: CUDA requested but not available.")
    return torch.device("cuda")
|
|
|
|
def load_model_manifest(manifest_path: Path) -> dict[str, Any]:
    """Read the demo model manifest JSON and verify it has a list of models."""
    payload = json.loads(manifest_path.read_text(encoding="utf-8"))
    has_model_list = "models" in payload and isinstance(payload["models"], list)
    if not has_model_list:
        raise ValueError(f"Invalid model manifest: {manifest_path}")
    return payload
|
|
|
|
def parse_run_url(run_url: str) -> tuple[str, str, str]:
    """Extract (entity, project, run_id) from a W&B URL like .../entity/project/runs/<id>."""
    segments = [segment for segment in urlparse(run_url).path.split("/") if segment]
    if len(segments) < 4 or segments[2] != "runs":
        raise ValueError(f"Unsupported W&B run URL format: {run_url}")
    entity, project, _, run_id = segments[:4]
    return entity, project, run_id
|
|
|
|
def pick_checkpoint_file(root: Path, checkpoint_name: str | None) -> Path:
    """Locate a checkpoint under *root*: exact filename match first, else the first *.pth.

    Raises FileNotFoundError when no .pth file exists anywhere under *root*.
    """
    if checkpoint_name:
        exact_matches = sorted(root.rglob(checkpoint_name))
        if exact_matches:
            return exact_matches[0]

    fallback = sorted(root.rglob("*.pth"))
    if fallback:
        return fallback[0]
    raise FileNotFoundError(f"No .pth checkpoints found under downloaded artifact: {root}")
|
|
|
|
def download_checkpoint_from_wandb(
    run_url: str,
    model_key: str,
    checkpoint_name: str | None,
    output_dir: Path,
) -> tuple[Path, str]:
    """Download the newest logged 'model' artifact for a W&B run.

    Returns (checkpoint_path, source_tag) where the tag records which artifact
    supplied the checkpoint. Raises ImportError if wandb is not installed and
    FileNotFoundError if the run logged no model artifacts.
    """
    try:
        import wandb
    except ImportError as exc:
        raise ImportError(
            "wandb is required to download missing checkpoints. Install with `uv pip install wandb`."
        ) from exc

    entity, project, run_id = parse_run_url(run_url)
    run = wandb.Api().run(f"{entity}/{project}/{run_id}")

    model_artifacts = [a for a in run.logged_artifacts() if a.type == "model"]
    if not model_artifacts:
        raise FileNotFoundError(
            f"No model artifacts found for run {entity}/{project}/{run_id}."
        )

    # Last logged artifact wins — treated as the most recent checkpoint.
    latest = model_artifacts[-1]
    sanitized_name = latest.name.replace("/", "_").replace(":", "_")
    target_root = output_dir / "downloaded_checkpoints" / model_key / sanitized_name
    target_root.mkdir(parents=True, exist_ok=True)
    artifact_dir = Path(latest.download(root=str(target_root)))

    checkpoint = pick_checkpoint_file(artifact_dir, checkpoint_name)
    return checkpoint, f"wandb_artifact:{latest.name}"
|
|
|
|
def resolve_checkpoint_path(model_cfg: dict[str, Any], output_dir: Path) -> tuple[Path, str]:
    """Resolve a model's checkpoint: local path if it exists, else a W&B fallback download.

    Returns (path, source) where source is 'local_path' or a wandb artifact tag.
    Raises FileNotFoundError when the local path is missing and no run_url is set.
    """
    candidate = Path(model_cfg.get("checkpoint_path", ""))
    if not candidate.is_absolute():
        # Relative manifest paths are interpreted against the repository root.
        candidate = (PROJECT_ROOT / candidate).resolve()
    if candidate.exists():
        return candidate, "local_path"

    run_url = model_cfg.get("run_url")
    if run_url:
        return download_checkpoint_from_wandb(
            run_url=run_url,
            model_key=str(model_cfg["model_key"]),
            checkpoint_name=model_cfg.get("checkpoint_name"),
            output_dir=output_dir,
        )
    raise FileNotFoundError(
        f"Checkpoint not found at {candidate} and no run_url provided for fallback download."
    )
|
|
|
|
def read_augmentation_profile(checkpoint_path: Path) -> str:
    """Read the augmentation profile stored in a checkpoint, defaulting to 'lorentz_default'.

    NOTE: weights_only=False unpickles arbitrary objects — acceptable here because
    checkpoints come from our own training runs, never from untrusted sources.
    """
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    profile = state.get("augmentation_profile", "lorentz_default")
    return str(profile)
|
|
|
|
def load_model(model_cfg: dict[str, Any], checkpoint_path: Path, device: str) -> LoadedModel:
    """Instantiate one demo model (arcface / triplet / lorentz) plus its validation transform.

    Raises ValueError for an unrecognized loader key in the manifest.
    """
    loader = str(model_cfg["loader"])

    if loader == "lorentz":
        # The lorentz loader returns its own validation transform directly.
        model, image_size, _metric, val_tf = _load_lorentz(str(checkpoint_path), device)
        return LoadedModel(model=model, val_transform=val_tf, image_size=int(image_size))

    benchmark_loaders = {
        "arcface_benchmark": _load_arcface_benchmark,
        "triplet_benchmark": _load_triplet_benchmark,
    }
    loader_fn = benchmark_loaders.get(loader)
    if loader_fn is None:
        raise ValueError(f"Unsupported loader='{loader}' in model manifest.")

    model, image_size, _metric = loader_fn(str(checkpoint_path), device)
    # Benchmark checkpoints record their augmentation profile; rebuild the matching
    # validation transform so extraction preprocessing matches training.
    profile = read_augmentation_profile(checkpoint_path)
    _train_tf, val_tf, _resolved = build_transforms(image_size, augmentation_profile=profile)
    return LoadedModel(model=model, val_transform=val_tf, image_size=int(image_size))
|
|
|
|
def build_sample_rows(
    dataset_root: Path,
    coreset_csv: Path,
    max_samples: int | None,
) -> list[dict[str, str]]:
    """Assemble one row per training image, tagged train/validation via the coreset CSV.

    Each row carries sample_id (== filename, which is unique per image), filename,
    label (ground_truth column), and split_tag: files listed in the coreset CSV get
    'validation', everything else 'train'. `max_samples` optionally keeps only the
    first N rows for smoke runs.

    Raises FileNotFoundError when train.csv, the images directory, or the coreset
    CSV is missing.
    """
    train_csv = dataset_root / "train.csv"
    images_dir = dataset_root / "train"
    if not train_csv.exists():
        raise FileNotFoundError(f"Missing train.csv at {train_csv}")
    if not images_dir.exists():
        raise FileNotFoundError(f"Missing train images directory at {images_dir}")
    if not coreset_csv.exists():
        # Checked explicitly (like the two inputs above) for a clearer message
        # than the error pandas would raise on its own.
        raise FileNotFoundError(f"Missing validation coreset CSV at {coreset_csv}")

    train_df = pd.read_csv(train_csv).copy()
    train_df["filename"] = train_df["filename"].astype(str)
    train_df["ground_truth"] = train_df["ground_truth"].astype(str)

    coreset_filenames = set(pd.read_csv(coreset_csv)["filename"].astype(str))

    if max_samples is not None:
        train_df = train_df.iloc[: int(max_samples)].copy()

    rows: list[dict[str, str]] = []
    # Iterate the two columns directly instead of iterrows(): faster and avoids
    # pandas' per-row Series construction.
    for filename, label in zip(train_df["filename"], train_df["ground_truth"]):
        rows.append(
            {
                "sample_id": filename,  # filenames are unique, so they double as stable IDs
                "filename": filename,
                "label": label,
                "split_tag": "validation" if filename in coreset_filenames else "train",
            }
        )
    return rows
|
|
|
|
def extract_embeddings(
    loaded_model: LoadedModel,
    rows: list[dict[str, str]],
    images_dir: Path,
    image_variant: str,
    device: torch.device,
    batch_size: int,
    num_workers: int,
    progress_label: str,
) -> tuple[list[str], np.ndarray, list[str], list[str], list[str]]:
    """Run the model over every sample and return aligned outputs.

    Returns (sample_ids, vectors, labels, filenames, split_tags) where vectors is a
    float32 array of shape (num_samples, embedding_dim) in dataset order.
    Raises RuntimeError when no embeddings were produced.
    """
    loader = DataLoader(
        JaguarEmbeddingDataset(
            rows=rows,
            images_dir=images_dir,
            transform=loaded_model.val_transform,
            image_variant=image_variant,
        ),
        batch_size=batch_size,
        shuffle=False,  # deterministic order keeps outputs aligned with the sample index
        num_workers=num_workers,
        pin_memory=True,
    )

    vector_batches: list[np.ndarray] = []
    ids: list[str] = []
    labels: list[str] = []
    filenames: list[str] = []
    split_tags: list[str] = []

    loaded_model.model.eval()
    with torch.no_grad():
        for images, batch_ids, batch_labels, batch_files, batch_splits in tqdm(loader, desc=progress_label):
            outputs = loaded_model.model(images.to(device, non_blocking=True))
            if isinstance(outputs, (tuple, list)):
                # Some models return (embeddings, aux); keep only the embeddings.
                outputs = outputs[0]
            vector_batches.append(outputs.detach().cpu().numpy().astype(np.float32))
            ids.extend(str(x) for x in batch_ids)
            labels.extend(str(x) for x in batch_labels)
            filenames.extend(str(x) for x in batch_files)
            split_tags.extend(str(x) for x in batch_splits)

    if not vector_batches:
        raise RuntimeError("No embeddings were generated.")

    vectors = np.vstack(vector_batches).astype(np.float32)
    return ids, vectors, labels, filenames, split_tags
|
|
|
|
def save_model_artifacts(
    output_dir: Path,
    model_cfg: dict[str, Any],
    checkpoint_path: Path,
    checkpoint_source: str,
    sample_ids: list[str],
    vectors: np.ndarray,
    labels: list[str],
    filenames: list[str],
    split_tags: list[str],
    image_variant: str,
    image_size: int,
    batch_size: int,
    num_workers: int,
) -> dict[str, Any]:
    """Write one model's embeddings (.npz) and metadata (.json) under output_dir/models/<key>.

    Returns the manifest entry for this model, with artifact paths expressed
    relative to output_dir so the Space runtime can resolve them portably.
    """
    model_key = str(model_cfg["model_key"])
    model_dir = output_dir / "models" / model_key
    model_dir.mkdir(parents=True, exist_ok=True)

    embeddings_path = model_dir / "embeddings.npz"
    metadata_path = model_dir / "metadata.json"

    np.savez_compressed(
        embeddings_path,
        ids=np.asarray(sample_ids),
        vectors=vectors,
        labels=np.asarray(labels),
        filenames=np.asarray(filenames),
        split_tags=np.asarray(split_tags),
    )

    # Fields shared by metadata.json and the returned manifest entry, defined once
    # so the two emitted descriptions of the same model cannot drift apart.
    descriptor = {
        "model_key": model_key,
        "comparison_key": model_cfg.get("comparison_key"),
        "family": model_cfg.get("family"),
        "loader": model_cfg.get("loader"),
        "space_key": model_cfg.get("space_key"),
        "geometry": model_cfg.get("geometry"),
        "layout": model_cfg.get("layout"),
        "num_samples": int(vectors.shape[0]),
        "embedding_dim": int(vectors.shape[1]),
        "checkpoint_path": str(checkpoint_path),
        "checkpoint_source": checkpoint_source,
        "run_url": model_cfg.get("run_url"),
    }

    metadata = {
        "generated_at_utc": utc_now(),
        **descriptor,
        "image_variant": image_variant,
        "image_size": int(image_size),
        "batch_size": int(batch_size),
        "num_workers": int(num_workers),
    }
    metadata_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")

    return {
        **descriptor,
        "embeddings_path": str(embeddings_path.relative_to(output_dir)),
        "metadata_path": str(metadata_path.relative_to(output_dir)),
    }
|
|
|
|
def write_sample_index(output_dir: Path, rows: list[dict[str, str]]) -> Path:
    """Persist the canonical sample ordering as sample_index.csv and return its path."""
    index_path = output_dir / "sample_index.csv"
    pd.DataFrame(rows).to_csv(index_path, index=False)
    return index_path
|
|
|
|
def main() -> int:
    """Build precomputed embedding assets for every model listed in the manifest.

    Pipeline: resolve the CUDA device, load the model manifest, enumerate dataset
    samples, then for each model resolve its checkpoint, extract embeddings for
    all samples, and write per-model artifacts plus a top-level manifest.json
    under the output directory. Returns 0 on success.
    """
    args = parse_args()
    device = resolve_device(args.device)  # fail fast if CUDA is unavailable

    model_manifest = load_model_manifest(args.model_manifest)
    output_dir = args.output_dir.resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    dataset_root = args.dataset_root.resolve()
    images_dir = dataset_root / "train"
    rows = build_sample_rows(
        dataset_root=dataset_root,
        coreset_csv=args.coreset_csv,
        max_samples=args.max_samples,
    )
    if not rows:
        raise RuntimeError("No rows found in train.csv after applying filters.")

    # The sample index fixes the canonical ordering; every model's embeddings
    # must come back in exactly this order.
    expected_ids = [row["sample_id"] for row in rows]
    sample_index_path = write_sample_index(output_dir, rows)

    emitted_models: list[dict[str, Any]] = []
    for model_cfg in model_manifest["models"]:
        model_key = str(model_cfg["model_key"])
        print(f"\n=== Building embeddings for {model_key} ===")

        checkpoint_path, checkpoint_source = resolve_checkpoint_path(model_cfg=model_cfg, output_dir=output_dir)
        print(f"Checkpoint: {checkpoint_path} ({checkpoint_source})")

        loaded_model = load_model(model_cfg=model_cfg, checkpoint_path=checkpoint_path, device=args.device)
        ids, vectors, labels, filenames, split_tags = extract_embeddings(
            loaded_model=loaded_model,
            rows=rows,
            images_dir=images_dir,
            image_variant=args.image_variant,
            device=device,
            batch_size=int(args.batch_size),
            num_workers=int(args.num_workers),
            progress_label=f"extract:{model_key}",
        )

        # Guard against any loader/dataset reordering before artifacts are written.
        if ids != expected_ids:
            raise RuntimeError(
                f"Sample ID alignment failed for {model_key}: extracted order does not match expected sample index."
            )

        emitted = save_model_artifacts(
            output_dir=output_dir,
            model_cfg=model_cfg,
            checkpoint_path=checkpoint_path,
            checkpoint_source=checkpoint_source,
            sample_ids=ids,
            vectors=vectors,
            labels=labels,
            filenames=filenames,
            split_tags=split_tags,
            image_variant=args.image_variant,
            image_size=loaded_model.image_size,
            batch_size=int(args.batch_size),
            num_workers=int(args.num_workers),
        )
        emitted_models.append(emitted)

    # Top-level manifest ties dataset provenance to the emitted model artifacts.
    manifest_out = {
        "generated_at_utc": utc_now(),
        "source_model_manifest": str(args.model_manifest.resolve()),
        "dataset": {
            "dataset_root": str(dataset_root),
            "images_dir": str(images_dir),
            "coreset_csv": str(args.coreset_csv.resolve()),
            "num_samples": len(rows),
            "image_variant": args.image_variant,
            "sample_index_csv": str(sample_index_path.relative_to(output_dir)),
        },
        "models": emitted_models,
    }

    manifest_path = output_dir / "manifest.json"
    manifest_path.write_text(json.dumps(manifest_out, indent=2), encoding="utf-8")

    print("\n=== HyperView asset build complete ===")
    print(f"Sample count: {len(rows)}")
    print(f"Manifest: {manifest_path}")
    for emitted in emitted_models:
        print(
            f"- {emitted['model_key']}: {emitted['num_samples']} x {emitted['embedding_dim']} "
            f"({emitted['embeddings_path']})"
        )

    return 0
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s integer return value to the shell as the exit code.
    raise SystemExit(main())
|
|