"""
Oasis 500M — sai_wm third-party wrapper.

Loaded via trust_remote_code=True:

    wm = AutoWorldModel.from_pretrained(
        "your-org/oasis-minecraft",
        trust_remote_code=True,
        device="cuda:0",
    )

The src/ directory (dit.py, vae.py, utils/) is included alongside this file
in the HF repo. Weights are downloaded from Etched/oasis-500m.
"""

import json
import logging
import os
import sys

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from huggingface_hub import hf_hub_download

logger = logging.getLogger(__name__)

# Number of action keys (matches open-oasis generate.py)
NUM_ACTION_KEYS = 25


def sigmoid_beta_schedule(timesteps, start=-3, end=3, tau=1, clamp_min=1e-5):
    """Sigmoid noise schedule — from open-oasis utils.py.

    Parameters
    ----------
    timesteps : int
        Number of diffusion steps; the returned tensor has this many betas.
    start, end, tau : float
        Parameters of the sigmoid curve the cumulative alphas are sampled from.
    clamp_min : float
        Accepted for signature compatibility; not used by this implementation.

    Returns
    -------
    torch.Tensor
        1-D float64 tensor of betas, clipped to [0, 0.999].
    """
    steps = timesteps + 1
    t = torch.linspace(0, timesteps, steps, dtype=torch.float64) / timesteps
    v_start = torch.tensor(start / tau).sigmoid()
    v_end = torch.tensor(end / tau).sigmoid()
    alphas_cumprod = (
        -((t * (end - start) + start) / tau).sigmoid() + v_end
    ) / (v_end - v_start)
    # Normalise so the cumulative product starts at exactly 1.
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)


def _ensure_src_importable():
    """Make the bundled ``src`` package importable.

    ``src/`` is imported as a package (``from src.dit import ...``), so it is
    the directory *containing* ``src/`` — the directory of this file — that
    must be on ``sys.path``. Nothing is changed when ``src/`` does not exist
    or the directory is already on the path.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    src_dir = os.path.join(this_dir, "src")
    if os.path.isdir(src_dir) and this_dir not in sys.path:
        sys.path.insert(0, this_dir)


class OasisWorldModel:
    """
    Oasis 500M world model — sai_wm third-party wrapper.

    Loads DiT backbone + ViT-VAE from the bundled src/ package,
    downloads weights from Etched/oasis-500m, wraps the diffusion
    sampling loop (matching generate.py) into forward/predict.

    Typical use: ``reset()`` → ``forward(first_frame)`` → repeated
    ``predict(action)``.
    """

    def __init__(
        self,
        world_config: dict,
        np_random=None,
        device: str = "cpu",
        ddim_steps: int = 10,
        noise_abs_max: float = 20.0,
    ):
        """Download the config and weights, then build the DiT and VAE.

        Parameters
        ----------
        world_config : dict
            Read for ``"repo_id"`` (HF repo holding ``config.json``) and the
            optional ``"model_file"`` sub-folder inside that repo.
        np_random : optional
            NumPy random generator; a fresh ``default_rng()`` is created
            when omitted.
        device : str
            Torch device string, e.g. ``"cpu"`` or ``"cuda:0"``.
        ddim_steps : int
            Number of denoising steps per generated frame. A falsy value
            (0/None) falls back to the config metadata, then to 10.
        noise_abs_max : float
            Absolute clamp applied to the initial noise of each new frame.
        """
        self.device = device
        self.np_random = (
            np_random if np_random is not None else np.random.default_rng()
        )

        # ── Load config ───────────────────────────────────────
        repo_id = world_config.get("repo_id", "")
        model_file = world_config.get("model_file", "")

        cache_dir = os.path.expanduser("~/.cache/sai/world_models")
        os.makedirs(cache_dir, exist_ok=True)

        config_filename = f"{model_file}/config.json" if model_file else "config.json"
        config_path = hf_hub_download(
            repo_id=repo_id,
            filename=config_filename,
            local_dir=cache_dir,
        )
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

        metadata = config.get("metadata", {})
        self.ddim_steps = ddim_steps or metadata.get("ddim_steps", 10)
        self.max_noise_level = metadata.get("max_noise_level", 1000)
        self.stabilization_level = metadata.get("stabilization_level", 15)
        self.scaling_factor = metadata.get("scaling_factor", 0.07843137255)
        self.noise_abs_max = noise_abs_max
        self.max_frames = metadata.get("max_frames", 32)

        # Build world_spec
        self.world_spec = type("WorldModelSpec", (), {
            "name": config.get("name", "oasis-500m"),
            "env": config.get("env", "Minecraft"),
            "model_type": "oasis",
            "metadata": metadata,
            "validate": lambda self: None,
        })()

        # ── Import model definitions from src/ ────────────────
        _ensure_src_importable()
        from src.dit import DiT
        from src.vae import VAE_models

        # ── Download and load weights ─────────────────────────
        weight_repo = metadata.get("weight_repo", "Etched/oasis-500m")
        dit_file = metadata.get("dit_file", "oasis500m.pt")
        vae_file = metadata.get("vae_file", "vit-l-20.pt")

        weight_dir = os.path.join(cache_dir, "oasis_weights")
        os.makedirs(weight_dir, exist_ok=True)

        dit_path = hf_hub_download(
            repo_id=weight_repo,
            filename=dit_file,
            local_dir=weight_dir,
        )
        vae_path = hf_hub_download(
            repo_id=weight_repo,
            filename=vae_file,
            local_dir=weight_dir,
        )

        # ── Load DiT (matching generate.py: DiT_models["DiT-S/2"]) ─
        logger.info("Loading Oasis DiT from %s", dit_path)
        self.dit = DiT(
            input_h=18,
            input_w=32,
            patch_size=2,
            in_channels=16,
            hidden_size=1024,
            depth=16,
            num_heads=16,
            mlp_ratio=4.0,
            external_cond_dim=NUM_ACTION_KEYS,
            max_frames=self.max_frames,
        )
        ckpt = torch.load(
            dit_path, map_location=torch.device(device), weights_only=True,
        )
        # strict=False tolerates key mismatches; surface them instead of
        # silently running with partially initialised weights.
        load_result = self.dit.load_state_dict(ckpt, strict=False)
        missing = list(getattr(load_result, "missing_keys", ()))
        unexpected = list(getattr(load_result, "unexpected_keys", ()))
        if missing or unexpected:
            logger.warning(
                "DiT checkpoint key mismatch: %d missing, %d unexpected",
                len(missing),
                len(unexpected),
            )
        self.dit = self.dit.to(device).eval()

        # ── Load VAE ──────────────────────────────────────────
        logger.info("Loading Oasis ViT-VAE from %s", vae_path)
        self.vae = VAE_models["vit-l-20-shallow-encoder"]()
        vae_ckpt = torch.load(
            vae_path, map_location=torch.device(device), weights_only=True,
        )
        self.vae.load_state_dict(vae_ckpt)
        self.vae = self.vae.to(device).eval()

        # ── Precompute noise schedule ─────────────────────────
        betas = sigmoid_beta_schedule(self.max_noise_level).float().to(device)
        alphas = 1.0 - betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Trailing singleton dims let the table broadcast against latents
        # once it is indexed with a (B, T) tensor of noise levels.
        self.alphas_cumprod = rearrange(self.alphas_cumprod, "T -> T 1 1 1")

        # Noise range (matching generate.py)
        self.noise_range = torch.linspace(
            -1, self.max_noise_level - 1, self.ddim_steps + 1,
        )

        # ── State buffers ─────────────────────────────────────
        self._latent_buffer = None  # (1, T, C, H, W)
        self._action_buffer = None  # (1, T, num_action_keys)
        self._frame_idx = 0  # total frames encoded/generated since reset

        logger.info("Oasis 500M loaded on %s", device)

    # ── sai_wm interface ──────────────────────────────────────

    def reset(self, seed=None):
        """Clear the rollout state; optionally reseed torch and NumPy.

        Note that ``torch.manual_seed`` reseeds the process-global torch RNG.
        """
        if seed is not None:
            torch.manual_seed(seed)
            self.np_random = np.random.default_rng(seed)
        self._latent_buffer = None
        self._action_buffer = None
        self._frame_idx = 0

    def forward(self, obs: np.ndarray) -> dict:
        """
        Encode the initial frame and (re)initialise the rollout buffers.

        Parameters
        ----------
        obs : np.ndarray
            RGB image, CHW or HWC, [0,1] or [0,255].

        Returns
        -------
        dict
            ``"latent_state"``: latent of the frame as a NumPy array with the
            batch dimension removed; ``"recon"``: the VAE reconstruction as
            produced by ``_decode``.
        """
        img = self._to_tensor(obs)  # (1, C, H, W)

        with torch.no_grad():
            with self._autocast():
                # The VAE is fed images rescaled from [0, 1] to [-1, 1].
                z = self.vae.encode(img * 2 - 1).mean * self.scaling_factor

        # z: (1, seq_h*seq_w, latent_dim) → (1, C, H, W)
        ph = self.vae.seq_h
        pw = self.vae.seq_w
        z = rearrange(z, "b (h w) c -> b c h w", h=ph, w=pw)

        # Init buffers
        self._latent_buffer = z.unsqueeze(1)  # (1, 1, C, H, W)
        # Initial "no-op" action
        self._action_buffer = torch.zeros(
            1, 1, NUM_ACTION_KEYS, device=self.device,
        )
        self._frame_idx = 1

        recon = self._decode(z)
        return {
            "latent_state": z.squeeze(0).cpu().numpy(),
            "recon": recon,
        }

    def predict(self, action) -> dict:
        """
        Generate next frame. Sampling loop follows generate.py.

        Parameters
        ----------
        action : int or np.ndarray
            If int: index into the 25 action keys (sets that key to 1).
            If np.ndarray of shape (25,): raw one-hot/continuous action vector.
            A torch.Tensor is also accepted (see ``_encode_action``).

        Returns
        -------
        dict
            ``"latent_state"`` (new frame latent, NumPy), ``"recon"`` (decoded
            frame), ``"reward"`` (always ``None``) and ``"terminated"``
            (always ``False``).

        Raises
        ------
        RuntimeError
            If ``forward()`` has not been called since construction/reset.
        """
        if self._latent_buffer is None:
            raise RuntimeError("Call forward() first.")

        # ── Prepare action ────────────────────────────────────
        act = self._encode_action(action)  # (1, 1, 25)
        # Kept local until sampling succeeds so a failure cannot leave the
        # action and latent buffers with different lengths.
        actions = torch.cat([self._action_buffer, act], dim=1)

        # ── Append noise chunk ────────────────────────────────
        B = 1
        chunk = torch.randn(
            (B, 1, *self._latent_buffer.shape[-3:]),
            device=self.device,
        )
        chunk = chunk.clamp(-self.noise_abs_max, self.noise_abs_max)
        x = torch.cat([self._latent_buffer, chunk], dim=1)

        # Context length comes from the (possibly trimmed) buffer, not from
        # the absolute frame counter: once the buffer has been trimmed the two
        # differ, and indexing with the counter desynchronises the frame,
        # timestep and action windows.
        n_ctx = self._latent_buffer.shape[1]
        start_frame = max(0, n_ctx + 1 - self.max_frames)

        # Loop-invariant pieces of the sliding window.
        actions_slice = actions[:, start_frame:n_ctx + 1]
        # Context frames are always presented at the stabilization noise level.
        t_ctx = torch.full(
            (B, n_ctx), self.stabilization_level - 1,
            dtype=torch.long, device=self.device,
        )

        # ── Diffusion denoising loop (from generate.py) ───────
        for noise_idx in reversed(range(1, self.ddim_steps + 1)):
            # Last frame gets the actual noise level of this step.
            t = torch.full(
                (B, 1), int(self.noise_range[noise_idx].item()),
                dtype=torch.long, device=self.device,
            )
            t_next = torch.full(
                (B, 1), int(self.noise_range[noise_idx - 1].item()),
                dtype=torch.long, device=self.device,
            )
            # noise_range starts at -1; a negative level is not a valid index
            # into alphas_cumprod, so reuse the current level instead.
            t_next = torch.where(t_next < 0, t, t_next)

            t_slice = torch.cat([t_ctx, t], dim=1)[:, start_frame:]
            t_next_slice = torch.cat([t_ctx, t_next], dim=1)[:, start_frame:]

            # Sliding window (copied so the model never sees a view of x)
            x_curr = x[:, start_frame:].clone()

            # DiT forward
            with torch.no_grad():
                with self._autocast():
                    v = self.dit(x_curr, t_slice, external_cond=actions_slice)

            # v-prediction → x_start, x_noise (matching generate.py)
            alpha_t = self.alphas_cumprod[t_slice]
            x_start = alpha_t.sqrt() * x_curr - (1 - alpha_t).sqrt() * v
            x_noise = (
                (1 / alpha_t).sqrt() * x_curr - x_start
            ) / (1 / alpha_t - 1).sqrt()

            # Frame prediction
            alpha_next = self.alphas_cumprod[t_next_slice]
            alpha_next[:, :-1] = torch.ones_like(alpha_next[:, :-1])
            if noise_idx == 1:
                # Final step: alpha of 1 leaves only x_start in the new frame.
                alpha_next[:, -1:] = torch.ones_like(alpha_next[:, -1:])
            x_pred = alpha_next.sqrt() * x_start + x_noise * (1 - alpha_next).sqrt()
            # Only the frame being generated is updated; context is untouched.
            x[:, -1:] = x_pred[:, -1:]

        # ── Update state ──────────────────────────────────────
        new_latent = x[:, -1:]

        # Trim to max context (both buffers by the same amount)
        if x.shape[1] > self.max_frames:
            trim = x.shape[1] - self.max_frames
            x = x[:, trim:]
            actions = actions[:, trim:]

        self._latent_buffer = x  # keep full buffer (includes new frame)
        self._action_buffer = actions
        self._frame_idx += 1

        recon = self._decode(new_latent.squeeze(1))
        return {
            "latent_state": new_latent.squeeze(0).squeeze(0).cpu().numpy(),
            "recon": recon,
            "reward": None,
            "terminated": False,
        }

    # ── Helpers ────────────────────────────────────────────────

    def _autocast(self):
        """Return a float16 autocast context for this model's device.

        ``torch.autocast`` expects a device *type* (``"cuda"``, ``"cpu"``),
        not a full device string such as ``"cuda:0"``, so the index is
        stripped via ``torch.device(...).type``.
        """
        return torch.autocast(
            torch.device(self.device).type, dtype=torch.float16,
        )

    def _encode_action(self, action) -> torch.Tensor:
        """Convert action to (1, 1, 25) tensor.

        Accepts an ``np.ndarray`` of shape ``(25,)``, an integer key index
        (one-hot encoded), or a ``torch.Tensor`` (reshaped to ``(1, 1, -1)``).

        Raises
        ------
        ValueError
            For any other input, including arrays of a different shape.
        """
        if isinstance(action, np.ndarray) and action.shape == (NUM_ACTION_KEYS,):
            return torch.from_numpy(action).float().reshape(1, 1, -1).to(self.device)
        elif isinstance(action, (int, np.integer)):
            act = torch.zeros(1, 1, NUM_ACTION_KEYS, device=self.device)
            act[0, 0, int(action)] = 1.0
            return act
        elif isinstance(action, torch.Tensor):
            return action.float().reshape(1, 1, -1).to(self.device)
        else:
            raise ValueError(
                f"Action must be int (action key index), np.ndarray(25,), "
                f"or torch.Tensor. Got {type(action)}."
            )

    def _to_tensor(self, obs: np.ndarray) -> torch.Tensor:
        """Convert obs to (1, C, H, W) float [0,1].

        A 3-D input whose last axis has 1, 3 or 4 entries is treated as HWC
        and transposed to CHW. Values are divided by 255 when the maximum
        exceeds 1.0.
        """
        img = np.asarray(obs, dtype=np.float32)
        if img.ndim == 3 and img.shape[-1] in (1, 3, 4):
            img = np.transpose(img, (2, 0, 1))
        if img.max() > 1.0:
            img = img / 255.0
        return torch.from_numpy(img).unsqueeze(0).to(self.device)

    def _decode(self, z: torch.Tensor) -> np.ndarray:
        """Decode latent (1, C, H, W) → RGB (C, H, W) in [0,1]."""
        with torch.no_grad():
            with self._autocast():
                z_flat = rearrange(z, "b c h w -> b (h w) c")
                decoded = self.vae.decode(z_flat / self.scaling_factor)
                # VAE output is mapped from [-1, 1] back to [0, 1].
                decoded = (decoded + 1) / 2
        return decoded.squeeze(0).clamp(0, 1).float().cpu().numpy()