"""Cosmos CV4x8x8 tokenizer / detokenizer for single-frame RGB images.

Wraps the Cosmos JIT encoder and decoder with a clean encode/decode interface.
Input images are expected as (B, 3, H, W) float32 in [-1, 1].
Latents have shape (B, 16, 1, H/8, W/8) — the temporal dim is always 1 for
single frames, matching the CV (Causal Video) 4×8×8 compression scheme
(4× temporal, 8×8 spatial).

Usage:
    tokenizer = CosmosTokenizer(ckpt_dir, device)
    z = tokenizer(x, mode="encode")      # or tokenizer.encode(x)
    x_hat = tokenizer(z, mode="decode")  # or tokenizer.decode(z)
"""
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from PIL import Image |
|
|
# Default checkpoint directory for the Cosmos-0.1 CV4x8x8 tokenizer; the
# tokenizer expects `encoder.jit` and `decoder.jit` inside it.
_DEFAULT_CKPT_DIR = Path(
    "/group2/ct/weihanx/tactile_world_model/tactile_wm/pretrained_models/Cosmos-0.1-Tokenizer-CV4x8x8"
)
|
|
|
|
class CosmosTokenizer(nn.Module):
    """Cosmos CV4x8x8 JIT encoder/decoder pair applied to single RGB frames.

    Images are treated as one-frame videos: ``encode`` inserts a length-1
    time axis before calling the JIT encoder, and ``decode`` takes the first
    time step of the JIT decoder output.

    Args:
        ckpt_dir: Directory containing ``encoder.jit`` and ``decoder.jit``.
        device: Device the JIT modules are loaded onto (``map_location``).
        dtype: Compute dtype that inputs are cast to before every JIT call,
            given as a ``torch`` attribute name (e.g. ``"bfloat16"``) or a
            ``torch.dtype``. The JIT weights themselves are not cast.
            NOTE(review): presumably this must match the dtype the JIT files
            were exported in — confirm before changing the default.

    Raises:
        FileNotFoundError: If either JIT file is missing from ``ckpt_dir``.
    """

    def __init__(
        self,
        ckpt_dir: str | Path = _DEFAULT_CKPT_DIR,
        device: str | torch.device = "cuda",
        dtype: str | torch.dtype = "bfloat16",
    ):
        super().__init__()
        # Load-time device; used by the `device` property only when the JIT
        # modules expose no parameters/buffers to read a device from.
        self._device = torch.device(device)
        self.dtype = dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
        ckpt_dir = Path(ckpt_dir)
        for name in ("encoder.jit", "decoder.jit"):
            if not (ckpt_dir / name).exists():
                raise FileNotFoundError(f"Cosmos checkpoint not found: {ckpt_dir / name}")
        # Assigning ScriptModules as attributes registers them as submodules,
        # so the inherited nn.Module.to()/cuda()/cpu() move them as well —
        # no custom `to` override is needed.
        self.encoder = torch.jit.load(str(ckpt_dir / "encoder.jit"), map_location=self._device).eval()
        self.decoder = torch.jit.load(str(ckpt_dir / "decoder.jit"), map_location=self._device).eval()

    @property
    def device(self) -> torch.device:
        """Device the model currently lives on.

        Read from the first registered parameter or buffer, so it stays
        correct after any ``.to(...)``, ``.cuda()`` or ``.cpu()`` call. Falls
        back to the load-time device if the JIT modules expose no tensors
        (in which case nn.Module methods cannot move them either).
        """
        for tensor in self.parameters():
            return tensor.device
        for tensor in self.buffers():
            return tensor.device
        return self._device

    @staticmethod
    def _find_tensor(obj) -> torch.Tensor | None:
        """Depth-first search for the first tensor in a (nested) tuple/list; None if absent."""
        if isinstance(obj, torch.Tensor):
            return obj
        if isinstance(obj, (tuple, list)):
            for item in obj:
                found = CosmosTokenizer._find_tensor(item)
                if found is not None:
                    return found
        return None

    @staticmethod
    def _extract(obj) -> torch.Tensor:
        """Return the first tensor in a JIT model output.

        The output may be a bare tensor or a (possibly nested) tuple/list;
        non-tensor entries are skipped rather than aborting the search.

        Raises:
            TypeError: If ``obj`` contains no tensor at all.
        """
        found = CosmosTokenizer._find_tensor(obj)
        if found is None:
            raise TypeError(f"no tensor in model output: {type(obj)!r}")
        return found

    @torch.no_grad()
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Tokenize (B, 3, H, W) images in [-1, 1] into latents (B, 16, 1, H/8, W/8).

        ``x`` may live on any device / have any float dtype: it is moved to
        ``self.device`` and cast to ``self.dtype``. The latent is returned as
        produced by the JIT encoder (it is not cast back to float32).

        Raises:
            ValueError: If ``x`` is not 4-D.
        """
        if x.ndim != 4:
            raise ValueError(f"expected (B, 3, H, W) images, got shape {tuple(x.shape)}")
        # (B, C, H, W) -> (B, C, 1, H, W): a single-frame video.
        video = x.to(device=self.device, dtype=self.dtype).unsqueeze(2)
        return self._extract(self.encoder(video))

    @torch.no_grad()
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Detokenize latents (B, 16, 1, H/8, W/8) into (B, 3, H, W) float32 images.

        The output is nominally in [-1, 1] but is NOT clamped here.

        Raises:
            ValueError: If ``z`` is not 5-D.
        """
        if z.ndim != 5:
            raise ValueError(f"expected (B, C, T, h, w) latents, got shape {tuple(z.shape)}")
        z = z.to(device=self.device, dtype=self.dtype)
        recon = self._extract(self.decoder(z))
        # Drop the time axis by keeping the first decoded frame.
        return recon[:, :, 0].float()

    @torch.no_grad()
    def forward(self, x: torch.Tensor, mode: str) -> torch.Tensor:
        """Dispatch to :meth:`encode` (``mode="encode"``) or :meth:`decode` (``mode="decode"``).

        Raises:
            ValueError: If ``mode`` is anything else.
        """
        if mode == "encode":
            return self.encode(x)
        if mode == "decode":
            return self.decode(x)
        raise ValueError(f"mode must be 'encode' or 'decode', got {mode!r}")
|
|
|
|
def _to_uint8_hwc(img: torch.Tensor) -> np.ndarray:
    """Convert a (3, H, W) image in [-1, 1] to an (H, W, 3) uint8 array.

    Values are clamped to [-1, 1] and ROUNDED to the nearest level (plain
    ``.byte()`` truncates, which loses a level whenever float error lands
    just below an integer and biases reconstructions low), so a
    uint8 -> [-1, 1] -> uint8 round trip is exact.
    """
    scaled = (img.detach().float().cpu().clamp(-1, 1) + 1) * 127.5
    return scaled.round().to(torch.uint8).permute(1, 2, 0).contiguous().numpy()


def _main(
    episode_path: Path = Path("/group2/ct/weihanx/tactile_world_model/mode1_v1/0323_episode_000.pt"),
    out_dir: Path = Path("/group2/ct/weihanx/tactile_world_model/tactile_vae/test_output/cosmos"),
    sample_indices: tuple[int, ...] = (0, 100, 500, 1000, 2000),
) -> None:
    """Round-trip a few episode frames through the tokenizer and report PSNR.

    For each sampled frame, saves an input | reconstruction | abs-diff panel
    plus the individual input and reconstruction PNGs into ``out_dir``.

    Raises:
        ValueError: If none of ``sample_indices`` is within the episode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device: {device}")

    tokenizer = CosmosTokenizer(ckpt_dir=_DEFAULT_CKPT_DIR, device=device)
    print(f"encoder loaded: {_DEFAULT_CKPT_DIR / 'encoder.jit'}")
    print(f"decoder loaded: {_DEFAULT_CKPT_DIR / 'decoder.jit'}")

    # weights_only=False unpickles arbitrary objects — only use on trusted files.
    ep = torch.load(str(episode_path), map_location="cpu", weights_only=False)
    views = ep["view"]  # NOTE(review): assumed (N, 3, H, W) uint8 — confirm.
    n_frames = len(views)
    # Drop indices past the end of a short episode instead of raising IndexError.
    indices = [idx for idx in sample_indices if -n_frames <= idx < n_frames]
    if len(indices) < len(sample_indices):
        print(f"skipping out-of-range sample indices (episode has {n_frames} frames)")
    if not indices:
        raise ValueError(f"no sample index falls inside the episode ({n_frames} frames)")
    frames = views[indices].float() / 127.5 - 1.0
    print(f"\nepisode frames: {tuple(frames.shape)}")

    z = tokenizer.encode(frames)
    print(f"latent shape: {tuple(z.shape)}")

    x_hat = tokenizer.decode(z)
    print(f"recon shape: {tuple(x_hat.shape)}")

    out_dir.mkdir(parents=True, exist_ok=True)
    psnrs = []
    for i, idx in enumerate(indices):
        orig_np = _to_uint8_hwc(frames[i])
        recon_np = _to_uint8_hwc(x_hat[i])
        err = orig_np.astype(np.int16) - recon_np.astype(np.int16)
        diff_np = np.abs(err).astype(np.uint8)
        mse = float((err.astype(np.float64) ** 2).mean())
        psnr = 10 * np.log10(255.0 ** 2 / mse) if mse > 0 else float("inf")
        psnrs.append(psnr)
        print(f"  frame {idx:5d}  PSNR={psnr:.2f} dB")

        h, w = orig_np.shape[:2]
        # Side-by-side panel with 8 px dark gutters between the three images.
        panel = Image.new("RGB", (3 * w + 16, h), (20, 20, 20))
        panel.paste(Image.fromarray(orig_np), (0, 0))
        panel.paste(Image.fromarray(recon_np), (w + 8, 0))
        panel.paste(Image.fromarray(diff_np), (2 * w + 16, 0))
        panel.save(out_dir / f"cosmos_frame_{idx:05d}_panel.png")
        Image.fromarray(orig_np).save(out_dir / f"cosmos_frame_{idx:05d}_input.png")
        Image.fromarray(recon_np).save(out_dir / f"cosmos_frame_{idx:05d}_recon.png")

    print(f"\nmean PSNR: {np.mean(psnrs):.2f} dB  (over {len(psnrs)} frames)")
    print(f"panels saved to {out_dir}")


if __name__ == "__main__":
    _main()
|
|