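"""General Lightning-based inference script for TactileVAE.

Features:
- Load a Lightning `.ckpt` checkpoint. It is loaded strictly into a thin wrapper,
  so the VAE weights must be stored under the `model.` prefix.
- Load a config YAML that provides `data` and `model` sections.
- Randomly select `--num-samples` samples from a split (`train` / `val` / `test`).
- Run reconstruction inference and save per-sample metrics, a summary, and a
  reconstruction grid image.
"""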
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import sys |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import pytorch_lightning as pl |
| import torch |
| import yaml |
| from PIL import Image |
| from torch.utils.data import DataLoader, Subset |
|
|
# Make the repository root importable so that the `tactile_vae` package imports
# that follow resolve when this file is executed directly as a script. This must
# run before those imports. `parents[2]` means this file is expected to live two
# directory levels below the repo root (e.g. <repo>/<pkg>/<subdir>/this_file.py).
# Relative config paths are also anchored here (see `_resolve_path`).
_REPO_ROOT = Path(__file__).resolve().parents[2]
if str(_REPO_ROOT) not in sys.path:
    # Prepend so the local checkout wins over any installed copy of the package.
    sys.path.insert(0, str(_REPO_ROOT))
|
|
| from tactile_vae.dataset import TactileParquetDataset |
| from tactile_vae.model import TactileVAE |
|
|
# Default artifact locations for the `vae_baseline_3` run. Every one of them
# can be overridden on the command line (--config / --ckpt / --out-dir).
_PROJECT_ROOT = Path("/group2/ct/weihanx/tactile_world_model")
_RUN_NAME = "vae_baseline_3"
_RUN_DIR = _PROJECT_ROOT / "runs" / _RUN_NAME

DEFAULT_CONFIG = _RUN_DIR / "config.snapshot.yaml"
DEFAULT_CKPT = _RUN_DIR / "checkpoints" / "last.ckpt"
DEFAULT_OUT_DIR = _PROJECT_ROOT / "tactile_vae" / "inference" / _RUN_NAME
|
|
|
|
| def _resolve_path(p: str | Path) -> Path: |
| path = Path(p) |
| return path if path.is_absolute() else (_REPO_ROOT / path).resolve() |
|
|
|
|
def load_config(path: str | Path) -> dict:
    """Load a YAML experiment config and anchor its data paths at the repo root.

    Args:
        path: Config file to read (``str`` or :class:`~pathlib.Path`).

    Returns:
        The parsed config dict. ``data.root`` and, when set,
        ``data.splits_path`` are rewritten as absolute path strings via
        ``_resolve_path``.

    Raises:
        ValueError: If the file does not contain a mapping, or if it lacks a
            ``data`` mapping with a ``root`` entry.
    """
    path = Path(path)
    # Explicit encoding: the platform default text encoding is not guaranteed
    # to be UTF-8.
    with path.open(encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    if not isinstance(cfg, dict):
        raise ValueError(f"invalid config: {path}")

    data_cfg = cfg.get("data")
    if not isinstance(data_cfg, dict) or "root" not in data_cfg:
        # Report a config problem clearly instead of a bare KeyError/TypeError.
        raise ValueError(
            f"invalid config (expected a 'data' section with a 'root' entry): {path}"
        )

    data_cfg["root"] = str(_resolve_path(data_cfg["root"]))
    if data_cfg.get("splits_path"):
        data_cfg["splits_path"] = str(_resolve_path(data_cfg["splits_path"]))
    return cfg
|
|
|
|
def pick_device(spec: str) -> torch.device:
    """Translate a ``--device`` string into a :class:`torch.device`.

    ``"auto"`` selects CUDA when it is available and falls back to the CPU
    otherwise. Any other value is passed straight to :class:`torch.device`.
    """
    if spec != "auto":
        return torch.device(spec)
    name = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(name)
|
|
|
|
class InferenceModule(pl.LightningModule):
    """Thin LightningModule wrapper used to load Lightning checkpoints strictly.

    Its only registered submodule is ``model``, a ``TactileVAE`` built from
    ``config["model"]``. Under strict loading, the checkpoint's state dict must
    therefore contain exactly the ``model.*`` parameters.
    """

    def __init__(self, config: dict):
        super().__init__()
        # Keep the full config around for downstream inspection.
        self.config = config
        vae_kwargs = config["model"]
        self.model = TactileVAE(**vae_kwargs)

    def forward(self, x, **kwargs):
        """Delegate to the wrapped VAE; extra keyword arguments pass through."""
        return self.model(x, **kwargs)
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the inference script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original CLI behaviour.
            Passing an explicit list lets the script be driven
            programmatically, e.g. from a notebook or a test.

    Returns:
        Namespace with ``config``, ``ckpt``, ``out_dir``, ``split``,
        ``num_samples``, ``batch_size``, ``num_workers``, ``seed``, ``device``
        and ``max_grid``.
    """
    p = argparse.ArgumentParser(
        description="Run TactileVAE reconstruction inference on a random subset of a split."
    )
    p.add_argument("--config", type=Path, default=DEFAULT_CONFIG, help="config yaml")
    p.add_argument("--ckpt", type=Path, default=DEFAULT_CKPT, help="Lightning checkpoint .ckpt")
    p.add_argument("--out-dir", type=Path, default=DEFAULT_OUT_DIR, help="output directory")
    p.add_argument("--split", type=str, default="test", choices=["train", "val", "test"])
    p.add_argument("--num-samples", type=int, default=50, help="number of random samples from the split")
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--num-workers", type=int, default=0)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--device", type=str, default="auto", help="auto / cuda / cpu / cuda:0 ...")
    p.add_argument("--max-grid", type=int, default=16, help="max samples shown in saved reconstruction grid")
    return p.parse_args(argv)
|
|
|
|
def build_dataset(cfg: dict, split: str) -> TactileParquetDataset:
    """Instantiate the dataset for *split* from the ``data`` section of *cfg*.

    ``splits_path`` is optional (defaults to ``None``) and ``cache_files``
    defaults to ``1``. ``color_jitter`` is always passed as ``None``,
    presumably to disable augmentation at inference time.
    """
    data_cfg = cfg["data"]
    dataset_kwargs = dict(
        root=data_cfg["root"],
        split=split,
        splits_path=data_cfg.get("splits_path"),
        image_size=data_cfg["image_size"],
        cache_files=data_cfg.get("cache_files", 1),
        color_jitter=None,
    )
    return TactileParquetDataset(**dataset_kwargs)
|
|
|
|
def select_subset(ds: TactileParquetDataset, n: int, seed: int) -> tuple[Subset, list[int]]:
    """Draw a reproducible random subset of *ds* without replacement.

    Args:
        ds: Dataset to sample from (anything supporting ``len`` and indexing).
        n: Requested number of samples; clamped to ``[1, len(ds)]``.
        seed: Seed for :func:`numpy.random.default_rng`, so the same seed
            always selects the same indices.

    Returns:
        ``(subset, indices)``. ``indices`` are positions into *ds* in draw
        order, and ``subset`` is a :class:`~torch.utils.data.Subset` over them.

    Raises:
        ValueError: If *ds* is empty. Previously this silently produced an
            empty selection that later failed inside ``DataLoader`` with an
            unrelated ``batch_size`` error.
    """
    total = len(ds)
    if total == 0:
        raise ValueError("cannot select samples from an empty dataset")
    n = min(max(1, int(n)), total)
    rng = np.random.default_rng(seed)
    idx = rng.choice(total, size=n, replace=False).tolist()
    return Subset(ds, idx), idx
|
|
|
|
@torch.no_grad()
def run_inference(
    module: InferenceModule,
    ds: TactileParquetDataset,
    subset_idx: list[int],
    loader: DataLoader,
    device: torch.device,
) -> tuple[list[dict[str, Any]], float, float, list[tuple[torch.Tensor, torch.Tensor]]]:
    """Reconstruct every batch from *loader* and collect per-sample errors.

    The loader must iterate the selected subset in ``subset_idx`` order
    (i.e. ``shuffle=False``). Each yielded item is assumed to be a single
    input tensor with a leading batch dimension. The wrapped VAE is called
    with ``sample=False``, and its ``"x_hat"`` output is compared with the input.

    Args:
        module: Wrapper holding the VAE as ``module.model``.
        ds: Source dataset, used only for ``ds.sample_id(index)``.
        subset_idx: Dataset indices of the selected samples, in loader order.
        loader: Iterable of input batches.
        device: Device on which inference runs.

    Returns:
        ``(per_sample, mean_mae, mean_mse, vis_pairs)``:
        - ``per_sample``: one dict per sample with its subset rank, dataset
          index, sample id, and MAE/MSE (averaged over all non-batch elements).
        - ``mean_mae`` / ``mean_mse``: averages of the per-sample values
          (``0.0`` when the loader is empty).
        - ``vis_pairs``: CPU ``(input, reconstruction)`` tensors per sample.

    Raises:
        ValueError: If the loader yields more samples than ``subset_idx`` holds.
    """
    module.eval().to(device)
    per_sample: list[dict[str, Any]] = []
    vis_pairs: list[tuple[torch.Tensor, torch.Tensor]] = []
    mae_total = 0.0
    mse_total = 0.0

    cursor = 0
    for x in loader:
        x = x.to(device, non_blocking=True)
        x_hat = module.model(x, sample=False)["x_hat"]

        # Flatten all non-batch dims so any input rank >= 2 is supported.
        diff = (x - x_hat).flatten(start_dim=1)
        # One device->host transfer per batch rather than a blocking
        # `.item()` sync for every single sample.
        mae_batch = diff.abs().mean(dim=1).tolist()
        mse_batch = diff.pow(2).mean(dim=1).tolist()
        x_cpu = x.detach().cpu()
        x_hat_cpu = x_hat.detach().cpu()

        bs = len(mae_batch)
        if cursor + bs > len(subset_idx):
            raise ValueError(
                f"loader produced more samples than subset_idx provides ({len(subset_idx)})"
            )

        for i in range(bs):
            rank = cursor + i
            gidx = subset_idx[rank]
            mae_i = float(mae_batch[i])
            mse_i = float(mse_batch[i])
            per_sample.append(
                {
                    "subset_rank": rank,
                    "dataset_index": int(gidx),
                    "sample_id": ds.sample_id(gidx),
                    "mae": mae_i,
                    "mse": mse_i,
                }
            )
            vis_pairs.append((x_cpu[i], x_hat_cpu[i]))
            mae_total += mae_i
            mse_total += mse_i
        cursor += bs

    n_total = len(per_sample)
    mae_mean = mae_total / max(1, n_total)
    mse_mean = mse_total / max(1, n_total)
    return per_sample, mae_mean, mse_mean, vis_pairs
|
|
|
|
def _to_uint8_hwc(img: torch.Tensor) -> np.ndarray:
    """Convert a ``(C, H, W)`` or ``(H, W)`` image in ``[0, 1]`` to ``(H, W, 3)`` uint8.

    Single-channel images are replicated to three channels so they can share
    an RGB canvas with colour images.

    Raises:
        ValueError: If the tensor is not 2-D/3-D or has a channel count other
            than 1 or 3.
    """
    if img.dim() == 2:
        img = img.unsqueeze(0)
    if img.dim() != 3 or img.shape[0] not in (1, 3):
        raise ValueError(
            f"expected an image of shape (C, H, W) with C in (1, 3), got {tuple(img.shape)}"
        )
    if img.shape[0] == 1:
        img = img.repeat(3, 1, 1)
    # .float() first so half/bfloat16 tensors convert cleanly to NumPy.
    arr = img.detach().float().clamp(0, 1).permute(1, 2, 0).cpu().numpy()
    # Round to nearest instead of truncating, so e.g. 0.999 maps to 255, not 254.
    return np.rint(arr * 255.0).astype(np.uint8)


def save_grid(pairs: list[tuple[torch.Tensor, torch.Tensor]], out_path: Path, n_show: int, image_size: int) -> None:
    """Save a 2-row PNG grid: inputs on top, reconstructions below.

    Args:
        pairs: ``(input, reconstruction)`` image tensors with values in ``[0, 1]``.
        out_path: Destination image path; parent directories are created.
        n_show: Maximum number of columns (samples) to render. Nothing is
            written when this (or ``len(pairs)``) is ``<= 0``.
        image_size: Nominal tile edge from the config. Kept for API
            compatibility; tile dimensions are taken from the tensors
            themselves, so non-square or non-int config sizes still render.

    Raises:
        ValueError: If images have unsupported shapes or differ in size.
    """
    n = min(n_show, len(pairs))
    if n <= 0:
        return
    sources = [_to_uint8_hwc(src) for src, _ in pairs[:n]]
    recons = [_to_uint8_hwc(rec) for _, rec in pairs[:n]]
    tile_shape = sources[0].shape
    if any(tile.shape != tile_shape for tile in sources + recons):
        raise ValueError("all source/reconstruction images must share the same size")
    # Top row: inputs; bottom row: reconstructions; one column per sample.
    canvas = np.concatenate(
        [np.concatenate(sources, axis=1), np.concatenate(recons, axis=1)], axis=0
    )
    out_path.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(canvas).save(out_path)
|
|
|
|
def _write_json(path: Path, obj: Any) -> None:
    """Write *obj* to *path* as indented UTF-8 JSON."""
    with path.open("w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2)


def main() -> None:
    """CLI entry point: reconstruct a random subset of a split and save results.

    Writes ``summary.json``, ``per_sample_metrics.json`` and
    ``reconstruction_grid.png`` into ``--out-dir``. Exits with an error
    message if the checkpoint file is missing or the selected split is empty.
    """
    args = parse_args()
    cfg = load_config(args.config)
    device = pick_device(args.device)

    # Fail fast, before the potentially slow dataset construction and before
    # creating an output directory for a run that cannot proceed.
    if not args.ckpt.is_file():
        raise SystemExit(f"checkpoint not found: {args.ckpt}")

    args.out_dir.mkdir(parents=True, exist_ok=True)
    print(f"config: {args.config}")
    print(f"ckpt: {args.ckpt}")
    print(f"split: {args.split}")
    print(f"num_samples: {args.num_samples}")
    print(f"device: {device}")
    print(f"out_dir: {args.out_dir}")

    ds = build_dataset(cfg, split=args.split)
    if len(ds) == 0:
        # Without this, the empty selection surfaces later as an opaque
        # DataLoader "batch_size should be a positive integer" error.
        raise SystemExit(f"split '{args.split}' is empty; nothing to run inference on")
    subset, subset_idx = select_subset(ds, args.num_samples, args.seed)
    print(f"split_size={len(ds)} selected={len(subset_idx)}")
    print(f"preview_sample_ids={[ds.sample_id(i) for i in subset_idx[:5]]}")

    # shuffle=False is required: run_inference maps batch positions back to
    # `subset_idx` by order.
    loader = DataLoader(
        subset,
        batch_size=min(max(1, args.batch_size), len(subset)),
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",
        drop_last=False,
        persistent_workers=args.num_workers > 0,
    )

    module = InferenceModule.load_from_checkpoint(
        str(args.ckpt),
        config=cfg,
        strict=True,
        map_location="cpu",
    )

    per_sample, mae_mean, mse_mean, vis_pairs = run_inference(
        module=module, ds=ds, subset_idx=subset_idx, loader=loader, device=device
    )

    grid_path = args.out_dir / "reconstruction_grid.png"
    save_grid(vis_pairs, out_path=grid_path, n_show=args.max_grid, image_size=cfg["data"]["image_size"])

    summary = {
        "config": str(args.config),
        "checkpoint": str(args.ckpt),
        "split": args.split,
        "seed": args.seed,
        "selected_num_samples": len(subset_idx),
        "mean_mae": mae_mean,
        "mean_mse": mse_mean,
        "grid_path": str(grid_path),
    }
    summary_path = args.out_dir / "summary.json"
    per_sample_path = args.out_dir / "per_sample_metrics.json"
    _write_json(summary_path, summary)
    _write_json(per_sample_path, per_sample)

    print(f"mean_mae={mae_mean:.6f} mean_mse={mse_mean:.6f}")
    print(f"saved: {summary_path}")
    print(f"saved: {per_sample_path}")
    print(f"saved: {grid_path}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|