"""General Lightning-based inference script for TactileVAE. Features: - Load any Lightning `.ckpt` checkpoint. - Load any config YAML. - Randomly select `N` samples from any split (`train` / `val` / `test`). - Run reconstruction inference and save metrics + visualization. """ from __future__ import annotations import argparse import json import sys from pathlib import Path from typing import Any import numpy as np import pytorch_lightning as pl import torch import yaml from PIL import Image from torch.utils.data import DataLoader, Subset _REPO_ROOT = Path(__file__).resolve().parents[2] if str(_REPO_ROOT) not in sys.path: sys.path.insert(0, str(_REPO_ROOT)) from tactile_vae.dataset import TactileParquetDataset from tactile_vae.model import TactileVAE DEFAULT_CONFIG = Path("/group2/ct/weihanx/tactile_world_model/runs/vae_baseline_3/config.snapshot.yaml") DEFAULT_CKPT = Path("/group2/ct/weihanx/tactile_world_model/runs/vae_baseline_3/checkpoints/last.ckpt") DEFAULT_OUT_DIR = Path("/group2/ct/weihanx/tactile_world_model/tactile_vae/inference/vae_baseline_3") def _resolve_path(p: str | Path) -> Path: path = Path(p) return path if path.is_absolute() else (_REPO_ROOT / path).resolve() def load_config(path: Path) -> dict: with path.open() as f: cfg = yaml.safe_load(f) if not isinstance(cfg, dict): raise ValueError(f"invalid config: {path}") cfg["data"]["root"] = str(_resolve_path(cfg["data"]["root"])) if cfg["data"].get("splits_path"): cfg["data"]["splits_path"] = str(_resolve_path(cfg["data"]["splits_path"])) return cfg def pick_device(spec: str) -> torch.device: if spec == "auto": return torch.device("cuda" if torch.cuda.is_available() else "cpu") return torch.device(spec) class InferenceModule(pl.LightningModule): """Minimal LightningModule used for strict Lightning checkpoint loading.""" def __init__(self, config: dict): super().__init__() self.config = config self.model = TactileVAE(**config["model"]) def forward(self, x, **kw): return self.model(x, **kw) def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser() p.add_argument("--config", type=Path, default=DEFAULT_CONFIG, help="config yaml") p.add_argument("--ckpt", type=Path, default=DEFAULT_CKPT, help="Lightning checkpoint .ckpt") p.add_argument("--out-dir", type=Path, default=DEFAULT_OUT_DIR, help="output directory") p.add_argument("--split", type=str, default="test", choices=["train", "val", "test"]) p.add_argument("--num-samples", type=int, default=50, help="number of random samples from the split") p.add_argument("--batch-size", type=int, default=16) p.add_argument("--num-workers", type=int, default=0) p.add_argument("--seed", type=int, default=0) p.add_argument("--device", type=str, default="auto", help="auto / cuda / cpu / cuda:0 ...") p.add_argument("--max-grid", type=int, default=16, help="max samples shown in saved reconstruction grid") return p.parse_args() def build_dataset(cfg: dict, split: str) -> TactileParquetDataset: dcfg = cfg["data"] return TactileParquetDataset( root=dcfg["root"], split=split, splits_path=dcfg.get("splits_path"), image_size=dcfg["image_size"], cache_files=dcfg.get("cache_files", 1), color_jitter=None, ) def select_subset(ds: TactileParquetDataset, n: int, seed: int) -> tuple[Subset, list[int]]: n = min(max(1, int(n)), len(ds)) rng = np.random.default_rng(seed) idx = rng.choice(len(ds), size=n, replace=False).tolist() return Subset(ds, idx), idx @torch.no_grad() def run_inference( module: InferenceModule, ds: TactileParquetDataset, subset_idx: list[int], loader: DataLoader, device: torch.device, ) -> tuple[list[dict[str, Any]], float, float, list[tuple[torch.Tensor, torch.Tensor]]]: module.eval().to(device) per_sample: list[dict[str, Any]] = [] vis_pairs: list[tuple[torch.Tensor, torch.Tensor]] = [] mae_total = 0.0 mse_total = 0.0 n_total = 0 cursor = 0 for x in loader: x = x.to(device, non_blocking=True) out = module.model(x, sample=False) x_hat = out["x_hat"] abs_err = (x - x_hat).abs().mean(dim=(1, 2, 3)) sq_err = ((x - x_hat) ** 2).mean(dim=(1, 2, 3)) bs = x.shape[0] for i in range(bs): gidx = subset_idx[cursor + i] sample_id = ds.sample_id(gidx) mae_i = float(abs_err[i].item()) mse_i = float(sq_err[i].item()) per_sample.append( { "subset_rank": cursor + i, "dataset_index": int(gidx), "sample_id": sample_id, "mae": mae_i, "mse": mse_i, } ) vis_pairs.append((x[i].detach().cpu(), x_hat[i].detach().cpu())) mae_total += mae_i mse_total += mse_i n_total += 1 cursor += bs mae_mean = mae_total / max(1, n_total) mse_mean = mse_total / max(1, n_total) return per_sample, mae_mean, mse_mean, vis_pairs def save_grid(pairs: list[tuple[torch.Tensor, torch.Tensor]], out_path: Path, n_show: int, image_size: int) -> None: n = min(n_show, len(pairs)) if n <= 0: return h = w = int(image_size) canvas = np.zeros((2 * h, n * w, 3), dtype=np.uint8) for i in range(n): src, rec = pairs[i] src_np = (src.clamp(0, 1).permute(1, 2, 0).numpy() * 255).astype(np.uint8) rec_np = (rec.clamp(0, 1).permute(1, 2, 0).numpy() * 255).astype(np.uint8) canvas[:h, i * w : (i + 1) * w] = src_np canvas[h:, i * w : (i + 1) * w] = rec_np out_path.parent.mkdir(parents=True, exist_ok=True) Image.fromarray(canvas).save(out_path) def main() -> None: args = parse_args() cfg = load_config(args.config) device = pick_device(args.device) args.out_dir.mkdir(parents=True, exist_ok=True) print(f"config: {args.config}") print(f"ckpt: {args.ckpt}") print(f"split: {args.split}") print(f"num_samples: {args.num_samples}") print(f"device: {device}") print(f"out_dir: {args.out_dir}") ds = build_dataset(cfg, split=args.split) subset, subset_idx = select_subset(ds, args.num_samples, args.seed) print(f"split_size={len(ds)} selected={len(subset_idx)}") print(f"preview_sample_ids={[ds.sample_id(i) for i in subset_idx[:5]]}") loader = DataLoader( subset, batch_size=min(max(1, args.batch_size), len(subset)), shuffle=False, num_workers=args.num_workers, pin_memory=device.type == "cuda", drop_last=False, persistent_workers=args.num_workers > 0, ) module = InferenceModule.load_from_checkpoint( str(args.ckpt), config=cfg, strict=True, map_location="cpu", ) per_sample, mae_mean, mse_mean, vis_pairs = run_inference( module=module, ds=ds, subset_idx=subset_idx, loader=loader, device=device ) grid_path = args.out_dir / "reconstruction_grid.png" save_grid(vis_pairs, out_path=grid_path, n_show=args.max_grid, image_size=cfg["data"]["image_size"]) summary = { "config": str(args.config), "checkpoint": str(args.ckpt), "split": args.split, "seed": args.seed, "selected_num_samples": len(subset_idx), "mean_mae": mae_mean, "mean_mse": mse_mean, "grid_path": str(grid_path), } with (args.out_dir / "summary.json").open("w") as f: json.dump(summary, f, indent=2) with (args.out_dir / "per_sample_metrics.json").open("w") as f: json.dump(per_sample, f, indent=2) print(f"mean_mae={mae_mean:.6f} mean_mse={mse_mean:.6f}") print(f"saved: {args.out_dir / 'summary.json'}") print(f"saved: {args.out_dir / 'per_sample_metrics.json'}") print(f"saved: {grid_path}") if __name__ == "__main__": main()