"""Cosmos CV4x8x8 tokenizer / detokenizer for single-frame RGB images. Wraps the Cosmos JIT encoder and decoder with a clean encode/decode interface. Input images are expected as (B, 3, H, W) float32 in [-1, 1]. Latents have shape (B, 16, 1, H/8, W/8) – the temporal dim is always 1 for single frames, matching the CV (Causal Video) 4×8×8 compression scheme. Usage: tokenizer = CosmosTokenizer(ckpt_dir, device) z = tokenizer(x, mode="encode") # or tokenizer.encode(x) x_hat = tokenizer(z, mode="decode") # or tokenizer.decode(z) """ from __future__ import annotations from pathlib import Path import numpy as np import torch import torch.nn as nn from PIL import Image _DEFAULT_CKPT_DIR = Path( "/group2/ct/weihanx/tactile_world_model/tactile_wm/pretrained_models/" "Cosmos-0.1-Tokenizer-CV4x8x8" ) class CosmosTokenizer(nn.Module): def __init__( self, ckpt_dir: str | Path = _DEFAULT_CKPT_DIR, device: str | torch.device = "cuda", dtype: str = "bfloat16", ): super().__init__() self.device = torch.device(device) self.dtype = getattr(torch, dtype) ckpt_dir = Path(ckpt_dir) for name in ("encoder.jit", "decoder.jit"): if not (ckpt_dir / name).exists(): raise FileNotFoundError(f"Cosmos checkpoint not found: {ckpt_dir / name}") self.encoder = torch.jit.load(str(ckpt_dir / "encoder.jit"), map_location=self.device).eval() self.decoder = torch.jit.load(str(ckpt_dir / "decoder.jit"), map_location=self.device).eval() def to(self, device): super().to(device) self.device = torch.device(device) self.encoder = self.encoder.to(device) self.decoder = self.decoder.to(device) return self @staticmethod def _extract(obj) -> torch.Tensor: if isinstance(obj, torch.Tensor): return obj if isinstance(obj, (tuple, list)): for item in obj: t = CosmosTokenizer._extract(item) if isinstance(t, torch.Tensor): return t raise TypeError(f"no tensor in model output: {type(obj)!r}") @torch.no_grad() def encode(self, x: torch.Tensor) -> torch.Tensor: """Tokenize (B, 3, H, W) float32 [-1,1] → latent (B, 16, 1, H/8, W/8).""" video = x.to(device=self.device, dtype=self.dtype).unsqueeze(2) # (B,3,1,H,W) return self._extract(self.encoder(video)) @torch.no_grad() def decode(self, z: torch.Tensor) -> torch.Tensor: """Detokenize latent (B, 16, 1, H/8, W/8) → (B, 3, H, W) float32 [-1,1].""" z = z.to(device=self.device, dtype=self.dtype) recon = self._extract(self.decoder(z)) # (B, 3, 1, H, W) return recon[:, :, 0].float() @torch.no_grad() def forward(self, x: torch.Tensor, mode: str) -> torch.Tensor: if mode == "encode": return self.encode(x) if mode == "decode": return self.decode(x) raise ValueError(f"mode must be 'encode' or 'decode', got {mode!r}") if __name__ == "__main__": _EPISODE_PATH = Path("/group2/ct/weihanx/tactile_world_model/mode1_v1/0323_episode_000.pt") _OUT_DIR = Path("/group2/ct/weihanx/tactile_world_model/tactile_vae/test_output/cosmos") _SAMPLE_INDICES = [0, 100, 500, 1000, 2000] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"device: {device}") # ── Load tokenizer ──────────────────────────────────────────────────────── tokenizer = CosmosTokenizer(ckpt_dir=_DEFAULT_CKPT_DIR, device=device) print(f"encoder loaded: {_DEFAULT_CKPT_DIR / 'encoder.jit'}") print(f"decoder loaded: {_DEFAULT_CKPT_DIR / 'decoder.jit'}") # ── Load episode frames ─────────────────────────────────────────────────── ep = torch.load(str(_EPISODE_PATH), map_location="cpu", weights_only=False) views = ep["view"] # (T, 3, H, W) uint8 frames_u8 = views[_SAMPLE_INDICES] # (N, 3, H, W) uint8 frames = frames_u8.float() / 127.5 - 1.0 # (N, 3, H, W) float32 in [-1, 1] print(f"\nepisode frames: {tuple(frames.shape)}") # ── Tokenize ────────────────────────────────────────────────────────────── z = tokenizer.encode(frames) print(f"latent shape: {tuple(z.shape)}") # ── Detokenize ──────────────────────────────────────────────────────────── x_hat = tokenizer.decode(z) print(f"recon shape: {tuple(x_hat.shape)}") # ── Save panels and print PSNR ──────────────────────────────────────────── _OUT_DIR.mkdir(parents=True, exist_ok=True) psnrs = [] for i, idx in enumerate(_SAMPLE_INDICES): orig_np = ((frames[i].permute(1, 2, 0).clamp(-1, 1) + 1) * 127.5).byte().numpy() recon_np = ((x_hat[i].permute(1, 2, 0).clamp(-1, 1) + 1) * 127.5).byte().cpu().numpy() diff_np = np.abs(orig_np.astype(int) - recon_np.astype(int)).astype(np.uint8) mse = float(((orig_np.astype(float) - recon_np.astype(float)) ** 2).mean()) psnr = 10 * np.log10(255.0 ** 2 / mse) if mse > 0 else float("inf") psnrs.append(psnr) print(f" frame {idx:5d} PSNR={psnr:.2f} dB") h, w = orig_np.shape[:2] panel = Image.new("RGB", (3 * w + 16, h), (20, 20, 20)) panel.paste(Image.fromarray(orig_np), (0, 0)) panel.paste(Image.fromarray(recon_np), (w + 8, 0)) panel.paste(Image.fromarray(diff_np), (2 * w + 16, 0)) panel.save(_OUT_DIR / f"cosmos_frame_{idx:05d}_panel.png") Image.fromarray(orig_np).save(_OUT_DIR / f"cosmos_frame_{idx:05d}_input.png") Image.fromarray(recon_np).save(_OUT_DIR / f"cosmos_frame_{idx:05d}_recon.png") print(f"\nmean PSNR: {np.mean(psnrs):.2f} dB (over {len(psnrs)} frames)") print(f"panels saved to {_OUT_DIR}")