tactile-vae / model /cosmos_tokenizer.py
WitneyWW's picture
Initial upload of tactile_vae (code, model, config, inference)
3770c94 verified
Raw
History Blame Contribute Delete
6.39 kB
"""Cosmos CV4x8x8 tokenizer / detokenizer for single-frame RGB images.
Wraps the Cosmos JIT encoder and decoder with a clean encode/decode interface.
Input images are expected as (B, 3, H, W) float32 in [-1, 1].
Latents have shape (B, 16, 1, H/8, W/8) – the temporal dim is always 1 for
single frames, matching the CV (Continuous Video) 4×8×8 compression scheme.
Usage:
tokenizer = CosmosTokenizer(ckpt_dir, device)
z = tokenizer(x, mode="encode") # or tokenizer.encode(x)
x_hat = tokenizer(z, mode="decode") # or tokenizer.decode(z)
"""
from __future__ import annotations
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
# Default directory holding ``encoder.jit`` / ``decoder.jit``; used when the
# caller does not pass ``ckpt_dir`` to :class:`CosmosTokenizer`.
_DEFAULT_CKPT_DIR = Path(
    "/group2/ct/weihanx/tactile_world_model/tactile_wm/pretrained_models/"
    "Cosmos-0.1-Tokenizer-CV4x8x8"
)


class CosmosTokenizer(nn.Module):
    """Encode/decode wrapper around the TorchScript Cosmos encoder and decoder.

    ``encoder.jit`` and ``decoder.jit`` are loaded from ``ckpt_dir`` and put in
    eval mode. Single images are fed to the models as one-frame videos by
    inserting a singleton temporal axis at dim 2.

    Args:
        ckpt_dir: Directory containing ``encoder.jit`` and ``decoder.jit``.
        device: Device the JIT modules are loaded onto and inputs are moved to.
        dtype: Dtype inputs are cast to before being passed to the JIT
            modules, given either as the name of a ``torch`` dtype (e.g.
            ``"bfloat16"``) or as a ``torch.dtype`` object.

    Raises:
        FileNotFoundError: If either checkpoint file is missing.
        AttributeError: If ``dtype`` is a string that names nothing in ``torch``.
        TypeError: If ``dtype`` does not resolve to a ``torch.dtype``.
    """

    def __init__(
        self,
        ckpt_dir: str | Path = _DEFAULT_CKPT_DIR,
        device: str | torch.device = "cuda",
        dtype: str | torch.dtype = "bfloat16",
    ):
        super().__init__()
        self.device = torch.device(device)
        self.dtype = self._resolve_dtype(dtype)

        ckpt_dir = Path(ckpt_dir)
        encoder_path = ckpt_dir / "encoder.jit"
        decoder_path = ckpt_dir / "decoder.jit"
        # Check both files up front so a missing decoder is reported before
        # time is spent loading the encoder.
        for path in (encoder_path, decoder_path):
            if not path.exists():
                raise FileNotFoundError(f"Cosmos checkpoint not found: {path}")

        self.encoder = torch.jit.load(str(encoder_path), map_location=self.device).eval()
        self.decoder = torch.jit.load(str(decoder_path), map_location=self.device).eval()

    @staticmethod
    def _resolve_dtype(dtype: str | torch.dtype) -> torch.dtype:
        """Return ``dtype`` as a ``torch.dtype``, looking strings up on ``torch``."""
        resolved = getattr(torch, dtype) if isinstance(dtype, str) else dtype
        if not isinstance(resolved, torch.dtype):
            raise TypeError(f"dtype must name a torch dtype, got {dtype!r}")
        return resolved

    def to(self, *args, **kwargs) -> CosmosTokenizer:
        """Move/cast the module like ``nn.Module.to`` and keep bookkeeping in sync.

        Accepts whatever ``nn.Module.to`` accepts (a device, a dtype, a
        reference tensor, keyword forms). ``self.device`` and ``self.dtype``
        are updated to match, so ``encode``/``decode`` keep casting inputs to
        whatever the weights were converted to.

        Returns:
            ``self``, to allow chaining.
        """
        # The encoder and decoder are attributes of this module, so the base
        # implementation already moves them; it also validates the arguments
        # before any bookkeeping below is touched.
        super().to(*args, **kwargs)

        device = kwargs.get("device")
        dtype = kwargs.get("dtype")
        for arg in args:
            if isinstance(arg, torch.Tensor):
                device, dtype = arg.device, arg.dtype
            elif isinstance(arg, torch.dtype):
                dtype = arg
            elif isinstance(arg, (str, torch.device)):
                device = arg
            elif isinstance(arg, int) and not isinstance(arg, bool):
                # A bare integer is a device ordinal; a bool is ``non_blocking``.
                device = arg

        if device is not None:
            self.device = torch.device(device)
        if dtype is not None:
            self.dtype = dtype
        return self

    @staticmethod
    def _extract(obj) -> torch.Tensor:
        """Return the first tensor found in ``obj``, searching depth-first.

        ``obj`` may be a tensor or an arbitrarily nested tuple/list. Items
        that contain no tensor are skipped.

        Raises:
            TypeError: If ``obj`` contains no tensor at all.
        """
        if isinstance(obj, torch.Tensor):
            return obj
        if isinstance(obj, (tuple, list)):
            for item in obj:
                try:
                    return CosmosTokenizer._extract(item)
                except TypeError:
                    # This item holds no tensor; keep looking in the rest.
                    continue
        raise TypeError(f"no tensor in model output: {type(obj)!r}")

    @torch.no_grad()
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Tokenize images (B, 3, H, W) in [-1, 1] -> latent (B, 16, 1, H/8, W/8).

        The input is moved to ``self.device``, cast to ``self.dtype`` and given
        a singleton temporal axis before being passed to the encoder.
        """
        video = x.to(device=self.device, dtype=self.dtype).unsqueeze(2)  # (B, 3, 1, H, W)
        return self._extract(self.encoder(video))

    @torch.no_grad()
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Detokenize latent (B, 16, 1, H/8, W/8) -> float32 images (B, 3, H, W).

        The latent is moved to ``self.device`` and cast to ``self.dtype``; the
        first temporal frame of the decoder output is returned as float32.
        """
        z = z.to(device=self.device, dtype=self.dtype)
        recon = self._extract(self.decoder(z))  # (B, 3, 1, H, W)
        return recon[:, :, 0].float()

    @torch.no_grad()
    def forward(self, x: torch.Tensor, mode: str) -> torch.Tensor:
        """Dispatch to :meth:`encode` or :meth:`decode` according to ``mode``.

        Raises:
            ValueError: If ``mode`` is neither ``"encode"`` nor ``"decode"``.
        """
        if mode == "encode":
            return self.encode(x)
        if mode == "decode":
            return self.decode(x)
        raise ValueError(f"mode must be 'encode' or 'decode', got {mode!r}")
def _to_uint8_hwc(img: torch.Tensor) -> np.ndarray:
    """Convert a (3, H, W) image in [-1, 1] to an (H, W, 3) uint8 array.

    Values are clamped to [-1, 1], mapped to [0, 255] and rounded to the
    nearest integer. Rounding matters: a plain cast truncates, and the float
    round trip ``v / 127.5 - 1`` -> ``(x + 1) * 127.5`` can land just under
    the original integer ``v``.
    """
    scaled = (img.detach().float().cpu().permute(1, 2, 0).clamp(-1, 1) + 1) * 127.5
    return np.ascontiguousarray(scaled.round().byte().numpy())


def _psnr_db(a: np.ndarray, b: np.ndarray) -> float:
    """Return the PSNR in dB between two uint8 images (``inf`` if identical)."""
    mse = float(((a.astype(float) - b.astype(float)) ** 2).mean())
    return 10 * np.log10(255.0 ** 2 / mse) if mse > 0 else float("inf")


def _main() -> None:
    """Round-trip a few episode frames through the tokenizer and save panels.

    For each sampled frame this writes the input, the reconstruction and a
    side-by-side input / reconstruction / absolute-difference panel, and
    prints the per-frame and mean PSNR.
    """
    episode_path = Path("/group2/ct/weihanx/tactile_world_model/mode1_v1/0323_episode_000.pt")
    out_dir = Path("/group2/ct/weihanx/tactile_world_model/tactile_vae/test_output/cosmos")
    sample_indices = [0, 100, 500, 1000, 2000]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device: {device}")

    # ── Load tokenizer ────────────────────────────────────────────────────────
    tokenizer = CosmosTokenizer(ckpt_dir=_DEFAULT_CKPT_DIR, device=device)
    print(f"encoder loaded: {_DEFAULT_CKPT_DIR / 'encoder.jit'}")
    print(f"decoder loaded: {_DEFAULT_CKPT_DIR / 'decoder.jit'}")

    # ── Load episode frames ───────────────────────────────────────────────────
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load episode files from a trusted source.
    ep = torch.load(str(episode_path), map_location="cpu", weights_only=False)
    views = ep["view"]  # (T, 3, H, W) uint8

    # Skip sample indices beyond the end of a short episode instead of
    # crashing with an IndexError.
    n_frames = len(views)
    indices = [idx for idx in sample_indices if idx < n_frames]
    skipped = [idx for idx in sample_indices if idx >= n_frames]
    if skipped:
        print(f"skipping out-of-range frame indices {skipped} (episode has {n_frames} frames)")
    if not indices:
        raise SystemExit("no sample indices fall inside the episode")

    frames_u8 = views[indices]  # (N, 3, H, W) uint8
    frames = frames_u8.float() / 127.5 - 1.0  # (N, 3, H, W) float32 in [-1, 1]
    print(f"\nepisode frames: {tuple(frames.shape)}")

    # ── Tokenize ──────────────────────────────────────────────────────────────
    z = tokenizer.encode(frames)
    print(f"latent shape: {tuple(z.shape)}")

    # ── Detokenize ────────────────────────────────────────────────────────────
    x_hat = tokenizer.decode(z)
    print(f"recon shape: {tuple(x_hat.shape)}")

    # ── Save panels and print PSNR ────────────────────────────────────────────
    out_dir.mkdir(parents=True, exist_ok=True)
    psnrs = []
    for i, idx in enumerate(indices):
        orig_np = _to_uint8_hwc(frames[i])
        recon_np = _to_uint8_hwc(x_hat[i])
        diff_np = np.abs(orig_np.astype(int) - recon_np.astype(int)).astype(np.uint8)

        psnr = _psnr_db(orig_np, recon_np)
        psnrs.append(psnr)
        print(f" frame {idx:5d} PSNR={psnr:.2f} dB")

        # Three tiles side by side, separated by 8-pixel gaps.
        h, w = orig_np.shape[:2]
        panel = Image.new("RGB", (3 * w + 16, h), (20, 20, 20))
        panel.paste(Image.fromarray(orig_np), (0, 0))
        panel.paste(Image.fromarray(recon_np), (w + 8, 0))
        panel.paste(Image.fromarray(diff_np), (2 * w + 16, 0))
        panel.save(out_dir / f"cosmos_frame_{idx:05d}_panel.png")
        Image.fromarray(orig_np).save(out_dir / f"cosmos_frame_{idx:05d}_input.png")
        Image.fromarray(recon_np).save(out_dir / f"cosmos_frame_{idx:05d}_recon.png")

    print(f"\nmean PSNR: {np.mean(psnrs):.2f} dB (over {len(psnrs)} frames)")
    print(f"panels saved to {out_dir}")


if __name__ == "__main__":
    _main()