| """ |
Oasis 500M — sai_wm third-party wrapper.
| |
| Loaded via trust_remote_code=True: |
| |
| wm = AutoWorldModel.from_pretrained( |
| "your-org/oasis-minecraft", |
| trust_remote_code=True, |
| device="cuda:0", |
| ) |
| |
| The src/ directory (dit.py, vae.py, utils/) is included alongside |
| this file in the HF repo. Weights are downloaded from Etched/oasis-500m. |
| """ |
|
|
| import json |
| import logging |
| import os |
| import sys |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from einops import rearrange |
| from huggingface_hub import hf_hub_download |
|
|
| logger = logging.getLogger(__name__) |
|
|
| |
| NUM_ACTION_KEYS = 25 |
|
|
|
|
def sigmoid_beta_schedule(timesteps, start=-3, end=3, tau=1, clamp_min=1e-5):
    """Sigmoid noise schedule β from open-oasis utils.py.

    Returns a float64 tensor of ``timesteps`` per-step betas in [0, 0.999].
    ``clamp_min`` is accepted for signature compatibility with upstream but
    is not used there either.
    """
    # Evaluation grid in [0, 1] with one extra point so consecutive
    # cumulative-alpha ratios yield exactly `timesteps` betas.
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64) / timesteps
    lo = torch.tensor(start / tau).sigmoid()
    hi = torch.tensor(end / tau).sigmoid()
    # Sigmoid rescaled so the cumulative alphas run from 1 down to 0,
    # then normalized to start exactly at 1.
    cum_alphas = (hi - ((grid * (end - start) + start) / tau).sigmoid()) / (hi - lo)
    cum_alphas = cum_alphas / cum_alphas[0]
    # Per-step beta is one minus the ratio of consecutive cumulative alphas.
    step_betas = 1 - cum_alphas[1:] / cum_alphas[:-1]
    return torch.clip(step_betas, 0, 0.999)
|
|
|
|
| def _ensure_src_importable(): |
| """Add the src/ directory next to this file to sys.path.""" |
| this_dir = os.path.dirname(os.path.abspath(__file__)) |
| src_dir = os.path.join(this_dir, "src") |
| if os.path.isdir(src_dir) and src_dir not in sys.path: |
| |
| |
| parent = this_dir |
| if parent not in sys.path: |
| sys.path.insert(0, parent) |
|
|
|
|
class OasisWorldModel:
    """
    Oasis 500M world model — sai_wm third-party wrapper.

    Loads DiT backbone + ViT-VAE from the bundled src/ package,
    downloads weights from Etched/oasis-500m, wraps the diffusion
    sampling loop (matching generate.py) into forward/predict.

    Usage: call ``forward(obs)`` once with an initial RGB frame to seed
    the latent/action history, then call ``predict(action)`` repeatedly
    to roll the model forward one frame at a time.  ``reset()`` clears
    the history.
    """


    def __init__(
        self,
        world_config: dict,
        np_random=None,
        device: str = "cpu",
        ddim_steps: int = 10,
        noise_abs_max: float = 20.0,
    ):
        """
        Download config + weights, build the DiT and VAE, and precompute
        the diffusion schedule.

        Parameters
        ----------
        world_config : dict
            Provides "repo_id" (HF repo holding config.json) and
            optionally "model_file" (subdirectory inside that repo).
        np_random : optional
            Numpy RNG; a fresh ``default_rng()`` is created when None.
        device : str
            Torch device string, e.g. "cpu" or "cuda:0".
        ddim_steps : int
            DDIM denoising steps per generated frame.  NOTE: a falsy
            value (0/None) falls back to the config's "ddim_steps".
        noise_abs_max : float
            Absolute clamp applied to the initial noise latent.
        """
        self.device = device
        self.np_random = np_random or np.random.default_rng()


        # --- Fetch and parse the wrapper's config.json from the HF hub.
        repo_id = world_config.get("repo_id", "")
        model_file = world_config.get("model_file", "")


        cache_dir = os.path.expanduser("~/.cache/sai/world_models")
        os.makedirs(cache_dir, exist_ok=True)


        # config.json may live in a subdirectory of the repo.
        config_filename = f"{model_file}/config.json" if model_file else "config.json"
        config_path = hf_hub_download(
            repo_id=repo_id, filename=config_filename,
            local_dir=cache_dir,
        )
        with open(config_path) as f:
            config = json.load(f)


        # Sampling hyper-parameters, with defaults matching open-oasis.
        metadata = config.get("metadata", {})
        self.ddim_steps = ddim_steps or metadata.get("ddim_steps", 10)
        self.max_noise_level = metadata.get("max_noise_level", 1000)
        self.stabilization_level = metadata.get("stabilization_level", 15)
        self.scaling_factor = metadata.get("scaling_factor", 0.07843137255)
        self.noise_abs_max = noise_abs_max
        self.max_frames = metadata.get("max_frames", 32)


        # Lightweight anonymous spec object (duck-types a WorldModelSpec:
        # name/env/model_type/metadata plus a no-op validate()).
        self.world_spec = type("WorldModelSpec", (), {
            "name": config.get("name", "oasis-500m"),
            "env": config.get("env", "Minecraft"),
            "model_type": "oasis",
            "metadata": metadata,
            "validate": lambda self: None,
        })()


        # Import the bundled model code (src/ shipped next to this file).
        _ensure_src_importable()
        from src.dit import DiT
        from src.vae import VAE_models


        # Weight files live in a separate repo from the config.
        weight_repo = metadata.get("weight_repo", "Etched/oasis-500m")
        dit_file = metadata.get("dit_file", "oasis500m.pt")
        vae_file = metadata.get("vae_file", "vit-l-20.pt")


        weight_dir = os.path.join(cache_dir, "oasis_weights")
        os.makedirs(weight_dir, exist_ok=True)


        dit_path = hf_hub_download(
            repo_id=weight_repo, filename=dit_file,
            local_dir=weight_dir,
        )
        vae_path = hf_hub_download(
            repo_id=weight_repo, filename=vae_file,
            local_dir=weight_dir,
        )


        # --- DiT backbone: 18x32 latent grid, 16 latent channels,
        # action vector (25 keys) as external conditioning.
        logger.info("Loading Oasis DiT from %s", dit_path)
        self.dit = DiT(
            input_h=18, input_w=32, patch_size=2,
            in_channels=16, hidden_size=1024, depth=16,
            num_heads=16, mlp_ratio=4.0,
            external_cond_dim=NUM_ACTION_KEYS,
            max_frames=self.max_frames,
        )
        ckpt = torch.load(dit_path, map_location=torch.device(device), weights_only=True)
        # NOTE(review): strict=False silently ignores missing/unexpected
        # keys -- a wrong checkpoint would load without complaint.
        self.dit.load_state_dict(ckpt, strict=False)
        self.dit = self.dit.to(device).eval()


        # --- ViT-VAE for pixel <-> latent encoding/decoding.
        logger.info("Loading Oasis ViT-VAE from %s", vae_path)
        self.vae = VAE_models["vit-l-20-shallow-encoder"]()
        vae_ckpt = torch.load(vae_path, map_location=torch.device(device), weights_only=True)
        self.vae.load_state_dict(vae_ckpt)
        self.vae = self.vae.to(device).eval()


        # Precompute cumulative alphas, broadcastable over (B, F, C, H, W)
        # latents via the trailing singleton dims.
        betas = sigmoid_beta_schedule(self.max_noise_level).float().to(device)
        alphas = 1.0 - betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.alphas_cumprod = rearrange(self.alphas_cumprod, "T -> T 1 1 1")


        # DDIM timestep ladder; index 0 is the -1 sentinel ("fully clean").
        # Kept on CPU -- only scalar .item() reads happen in predict().
        self.noise_range = torch.linspace(
            -1, self.max_noise_level - 1, self.ddim_steps + 1,
        )


        # Rolling history: (1, F, C, H, W) latents, (1, F, 25) actions,
        # and the count of frames produced since the last reset.
        self._latent_buffer = None
        self._action_buffer = None
        self._frame_idx = 0


        logger.info("Oasis 500M loaded on %s", device)




    def reset(self, seed=None):
        """Clear latent/action history; optionally reseed torch + numpy RNGs."""
        if seed is not None:
            torch.manual_seed(seed)
            self.np_random = np.random.default_rng(seed)
        self._latent_buffer = None
        self._action_buffer = None
        self._frame_idx = 0


    def forward(self, obs: np.ndarray) -> dict:
        """
        Encode initial frame(s) and seed the rolling history.

        Parameters
        ----------
        obs : np.ndarray
            RGB image, CHW or HWC, [0,1] or [0,255].

        Returns
        -------
        dict
            "latent_state": (C, H, W) numpy latent of the seed frame;
            "recon": decoded RGB (C, H, W) in [0, 1].
        """
        img = self._to_tensor(obs)


        # VAE expects inputs in [-1, 1]; latents are scaled by the
        # config's scaling_factor.
        # NOTE(review): torch.autocast takes a device *type* ("cuda"/"cpu");
        # a string like "cuda:0" raises -- confirm what callers pass as
        # `device`.  Same pattern in predict() and _decode().
        with torch.no_grad():
            with torch.autocast(self.device, dtype=torch.float16):
                z = self.vae.encode(img * 2 - 1).mean * self.scaling_factor


        # Reshape the VAE's token sequence into a (C, H, W) latent grid.
        ph = self.vae.seq_h
        pw = self.vae.seq_w
        z = rearrange(z, "b (h w) c -> b c h w", h=ph, w=pw)


        # Seed history: one latent frame and one all-zero action.
        self._latent_buffer = z.unsqueeze(1)
        
        self._action_buffer = torch.zeros(
            1, 1, NUM_ACTION_KEYS, device=self.device,
        )
        self._frame_idx = 1


        recon = self._decode(z)


        return {
            "latent_state": z.squeeze(0).cpu().numpy(),
            "recon": recon,
        }


    def predict(self, action) -> dict:
        """
        Generate next frame. Sampling loop matches generate.py exactly.

        Parameters
        ----------
        action : int or np.ndarray
            If int: index into the 25 action keys (sets that key to 1).
            If np.ndarray of shape (25,): raw one-hot/continuous action vector.

        Returns
        -------
        dict
            "latent_state": (C, H, W) numpy latent of the new frame;
            "recon": decoded RGB (C, H, W) in [0, 1];
            "reward": always None; "terminated": always False.

        Raises
        ------
        RuntimeError
            If forward() has not been called since construction/reset.
        """
        if self._latent_buffer is None:
            raise RuntimeError("Call forward() first.")


        # Record the action taken at the current frame.
        act = self._encode_action(action)
        self._action_buffer = torch.cat(
            [self._action_buffer, act], dim=1,
        )


        # Append pure Gaussian noise as the frame-to-be; history frames
        # stay as-is.  Noise is clamped to +/- noise_abs_max.
        B = 1
        chunk = torch.randn(
            (B, 1, *self._latent_buffer.shape[-3:]), device=self.device,
        )
        chunk = chunk.clamp(-self.noise_abs_max, self.noise_abs_max)
        x = torch.cat([self._latent_buffer, chunk], dim=1)


        i = self._frame_idx
        # Sliding attention window over the most recent max_frames frames.
        # NOTE(review): _frame_idx keeps growing after the buffer is
        # trimmed to max_frames (see below), so once i > max_frames this
        # start_frame assumes an untrimmed history and the slices of x
        # vs. t_full end up with different lengths -- verify long
        # rollouts against generate.py.
        start_frame = max(0, i + 1 - self.max_frames)


        # DDIM loop from high noise down to clean for the last frame only.
        for noise_idx in reversed(range(1, self.ddim_steps + 1)):
            # Context frames are pinned at a fixed low noise level
            # (stabilization_level - 1); only the new frame moves down
            # the noise_range ladder.
            t_ctx = torch.full(
                (B, i), self.stabilization_level - 1,
                dtype=torch.long, device=self.device,
            )
            t = torch.full(
                (B, 1), int(self.noise_range[noise_idx].item()),
                dtype=torch.long, device=self.device,
            )
            t_next = torch.full(
                (B, 1), int(self.noise_range[noise_idx - 1].item()),
                dtype=torch.long, device=self.device,
            )
            # noise_range[0] is the -1 sentinel; never index with a
            # negative timestep.
            t_next = torch.where(t_next < 0, t, t_next)


            t_full = torch.cat([t_ctx, t], dim=1)
            t_next_full = torch.cat([t_ctx, t_next], dim=1)


            # Restrict everything to the sliding window.  x is cloned so
            # the in-place alpha edits below never touch the history.
            x_curr = x.clone()[:, start_frame:]
            t_slice = t_full[:, start_frame:]
            t_next_slice = t_next_full[:, start_frame:]
            actions_slice = self._action_buffer[:, start_frame:i + 1]


            # DiT output v; see autocast NOTE(review) in forward().
            with torch.no_grad():
                with torch.autocast(self.device, dtype=torch.float16):
                    v = self.dit(x_curr, t_slice, external_cond=actions_slice)


            # v-prediction -> (x_start, noise):
            #   x0 = sqrt(ac)*x - sqrt(1-ac)*v, eps recovered from x and x0.
            x_start = (
                self.alphas_cumprod[t_slice].sqrt() * x_curr
                - (1 - self.alphas_cumprod[t_slice]).sqrt() * v
            )
            x_noise = (
                (1 / self.alphas_cumprod[t_slice]).sqrt() * x_curr - x_start
            ) / (1 / self.alphas_cumprod[t_slice] - 1).sqrt()


            # Context frames step to fully-denoised (alpha=1); on the
            # final DDIM step the new frame does too.  Advanced indexing
            # returns a copy, so self.alphas_cumprod is not mutated.
            alpha_next = self.alphas_cumprod[t_next_slice]
            alpha_next[:, :-1] = torch.ones_like(alpha_next[:, :-1])
            if noise_idx == 1:
                alpha_next[:, -1:] = torch.ones_like(alpha_next[:, -1:])


            # Only the newly generated frame in x is updated.
            x_pred = alpha_next.sqrt() * x_start + x_noise * (1 - alpha_next).sqrt()
            x[:, -1:] = x_pred[:, -1:]


        new_latent = x[:, -1:]
        self._latent_buffer = x


        # Cap stored history at max_frames (latents and actions together).
        if self._latent_buffer.shape[1] > self.max_frames:
            trim = self._latent_buffer.shape[1] - self.max_frames
            self._latent_buffer = self._latent_buffer[:, trim:]
            self._action_buffer = self._action_buffer[:, trim:]


        self._frame_idx += 1


        recon = self._decode(new_latent.squeeze(1))


        return {
            "latent_state": new_latent.squeeze(0).squeeze(0).cpu().numpy(),
            "recon": recon,
            "reward": None,
            "terminated": False,
        }




    def _encode_action(self, action) -> torch.Tensor:
        """Convert action to (1, 1, 25) tensor.

        Accepts an int key index (one-hot), a (25,) numpy vector, or a
        torch tensor (reshaped to (1, 1, 25)); anything else raises
        ValueError.  Note a numpy array of any shape other than (25,)
        falls through to the ValueError branch.
        """
        if isinstance(action, np.ndarray) and action.shape == (NUM_ACTION_KEYS,):
            return torch.from_numpy(action).float().reshape(1, 1, -1).to(self.device)
        elif isinstance(action, (int, np.integer)):
            act = torch.zeros(1, 1, NUM_ACTION_KEYS, device=self.device)
            act[0, 0, int(action)] = 1.0
            return act
        elif isinstance(action, torch.Tensor):
            return action.float().reshape(1, 1, -1).to(self.device)
        else:
            raise ValueError(
                f"Action must be int (action key index), np.ndarray(25,), "
                f"or torch.Tensor. Got {type(action)}."
            )


    def _to_tensor(self, obs: np.ndarray) -> torch.Tensor:
        """Convert obs to (1, C, H, W) float [0,1]."""
        img = np.asarray(obs, dtype=np.float32)
        # HWC -> CHW when the trailing dim looks like a channel count.
        if img.ndim == 3 and img.shape[-1] in (1, 3, 4):
            img = np.transpose(img, (2, 0, 1))
        # Heuristic: any value above 1 means 8-bit [0, 255] input.
        # NOTE(review): a genuinely [0,1] float image containing a value
        # > 1 (or a uint8 image that happens to be all <= 1) would be
        # misclassified -- acceptable for typical RGB frames.
        if img.max() > 1.0:
            img = img / 255.0
        return torch.from_numpy(img).unsqueeze(0).to(self.device)


    def _decode(self, z: torch.Tensor) -> np.ndarray:
        """Decode latent (1, C, H, W) -> RGB (C, H, W) in [0,1]."""
        # See the autocast NOTE(review) in forward().
        with torch.no_grad():
            with torch.autocast(self.device, dtype=torch.float16):
                # Undo the grid layout and latent scaling applied in
                # forward(); VAE decodes to [-1, 1], mapped back to [0, 1].
                z_flat = rearrange(z, "b c h w -> b (h w) c")
                decoded = self.vae.decode(z_flat / self.scaling_factor)
                decoded = (decoded + 1) / 2
        return decoded.squeeze(0).clamp(0, 1).float().cpu().numpy()