# oasis-500m / oasis_wrapper.py
# Author: ShaswatRobotics
# Commit c4274b4 (verified): "Update oasis_wrapper.py"
"""
Oasis 500M — sai_wm third-party wrapper.

Loaded via trust_remote_code=True:

    wm = AutoWorldModel.from_pretrained(
        "your-org/oasis-minecraft",
        trust_remote_code=True,
        device="cuda:0",
    )

The src/ directory (dit.py, vae.py, utils/) is included alongside
this file in the HF repo. Weights are downloaded from Etched/oasis-500m
by default.
"""
import json
import logging
import os
import sys
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from huggingface_hub import hf_hub_download
logger = logging.getLogger(__name__)

# Length of the per-frame action vector passed to the DiT as external_cond
# (same value as NUM_ACTION_KEYS in open-oasis generate.py).
NUM_ACTION_KEYS: int = 25
def sigmoid_beta_schedule(timesteps, start=-3, end=3, tau=1, clamp_min=1e-5):
    """Return the sigmoid noise schedule used by open-oasis (utils.py).

    A sigmoid, evaluated on ``timesteps + 1`` evenly spaced points of [0, 1]
    mapped onto [start, end] and divided by ``tau``, is shifted and
    normalised into a cumulative-alpha curve that starts at 1. Each beta is
    one minus the ratio of consecutive cumulative alphas, clipped to
    [0, 0.999].

    Args:
        timesteps: Number of diffusion steps; the result has this length.
        start: Sigmoid input at t = 0.
        end: Sigmoid input at t = 1.
        tau: Temperature dividing the sigmoid input.
        clamp_min: Accepted for signature compatibility with open-oasis;
            not used.

    Returns:
        A float64 tensor of shape ``(timesteps,)``.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64) / timesteps
    sig_lo = torch.tensor(start / tau).sigmoid()
    sig_hi = torch.tensor(end / tau).sigmoid()
    sig_grid = ((grid * (end - start) + start) / tau).sigmoid()
    curve = (sig_hi - sig_grid) / (sig_hi - sig_lo)
    curve = curve / curve[0]
    step_ratio = curve[1:] / curve[:-1]
    return torch.clip(1 - step_ratio, 0, 0.999)
def _ensure_src_importable():
    """Make the bundled ``src`` package importable as ``src.*``.

    ``src/`` (dit.py, vae.py, utils/) sits next to this file and uses
    relative imports, so it must be imported as a package. That means the
    directory *containing* ``src/`` (this file's directory) is what has to
    be on ``sys.path``. Does nothing when ``src/`` is absent, and never
    inserts the same directory twice.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    if not os.path.isdir(os.path.join(this_dir, "src")):
        return
    # Test the directory that is actually inserted. Checking whether src/
    # itself is on sys.path is the wrong condition: if it were, the parent
    # would never be added and ``from src.dit import DiT`` would fail.
    # NOTE(review): a different top-level package also named ``src`` that is
    # already imported (sys.modules["src"]) would still shadow this one.
    if this_dir not in sys.path:
        sys.path.insert(0, this_dir)
class OasisWorldModel:
"""
Oasis 500M world model β€” sai_wm third-party wrapper.
Loads DiT backbone + ViT-VAE from the bundled src/ package,
downloads weights from Etched/oasis-500m, wraps the diffusion
sampling loop (matching generate.py) into forward/predict.
"""
def __init__(
self,
world_config: dict,
np_random=None,
device: str = "cpu",
ddim_steps: int = 10,
noise_abs_max: float = 20.0,
):
self.device = device
self.np_random = np_random or np.random.default_rng()
# ── Load config ───────────────────────────────────────
repo_id = world_config.get("repo_id", "")
model_file = world_config.get("model_file", "")
cache_dir = os.path.expanduser("~/.cache/sai/world_models")
os.makedirs(cache_dir, exist_ok=True)
config_filename = f"{model_file}/config.json" if model_file else "config.json"
config_path = hf_hub_download(
repo_id=repo_id, filename=config_filename,
local_dir=cache_dir,
)
with open(config_path) as f:
config = json.load(f)
metadata = config.get("metadata", {})
self.ddim_steps = ddim_steps or metadata.get("ddim_steps", 10)
self.max_noise_level = metadata.get("max_noise_level", 1000)
self.stabilization_level = metadata.get("stabilization_level", 15)
self.scaling_factor = metadata.get("scaling_factor", 0.07843137255)
self.noise_abs_max = noise_abs_max
self.max_frames = metadata.get("max_frames", 32)
# Build world_spec
self.world_spec = type("WorldModelSpec", (), {
"name": config.get("name", "oasis-500m"),
"env": config.get("env", "Minecraft"),
"model_type": "oasis",
"metadata": metadata,
"validate": lambda self: None,
})()
# ── Import model definitions from src/ ────────────────
_ensure_src_importable()
from src.dit import DiT
from src.vae import VAE_models
# ── Download and load weights ─────────────────────────
weight_repo = metadata.get("weight_repo", "Etched/oasis-500m")
dit_file = metadata.get("dit_file", "oasis500m.pt")
vae_file = metadata.get("vae_file", "vit-l-20.pt")
weight_dir = os.path.join(cache_dir, "oasis_weights")
os.makedirs(weight_dir, exist_ok=True)
dit_path = hf_hub_download(
repo_id=weight_repo, filename=dit_file,
local_dir=weight_dir,
)
vae_path = hf_hub_download(
repo_id=weight_repo, filename=vae_file,
local_dir=weight_dir,
)
# ── Load DiT (matching generate.py: DiT_models["DiT-S/2"]) ─
logger.info("Loading Oasis DiT from %s", dit_path)
self.dit = DiT(
input_h=18, input_w=32, patch_size=2,
in_channels=16, hidden_size=1024, depth=16,
num_heads=16, mlp_ratio=4.0,
external_cond_dim=NUM_ACTION_KEYS,
max_frames=self.max_frames,
)
ckpt = torch.load(dit_path, map_location=torch.device(device), weights_only=True)
self.dit.load_state_dict(ckpt, strict=False)
self.dit = self.dit.to(device).eval()
# ── Load VAE ──────────────────────────────────────────
logger.info("Loading Oasis ViT-VAE from %s", vae_path)
self.vae = VAE_models["vit-l-20-shallow-encoder"]()
vae_ckpt = torch.load(vae_path, map_location=torch.device(device), weights_only=True)
self.vae.load_state_dict(vae_ckpt)
self.vae = self.vae.to(device).eval()
# ── Precompute noise schedule ─────────────────────────
betas = sigmoid_beta_schedule(self.max_noise_level).float().to(device)
alphas = 1.0 - betas
self.alphas_cumprod = torch.cumprod(alphas, dim=0)
self.alphas_cumprod = rearrange(self.alphas_cumprod, "T -> T 1 1 1")
# Noise range (matching generate.py)
self.noise_range = torch.linspace(
-1, self.max_noise_level - 1, self.ddim_steps + 1,
)
# ── State buffers ─────────────────────────────────────
self._latent_buffer = None # (1, T, C, H, W)
self._action_buffer = None # (1, T, num_action_keys)
self._frame_idx = 0
logger.info("Oasis 500M loaded on %s", device)
# ── sai_wm interface ──────────────────────────────────────
def reset(self, seed=None):
if seed is not None:
torch.manual_seed(seed)
self.np_random = np.random.default_rng(seed)
self._latent_buffer = None
self._action_buffer = None
self._frame_idx = 0
def forward(self, obs: np.ndarray) -> dict:
"""
Encode initial frame(s).
Parameters
----------
obs : np.ndarray
RGB image, CHW or HWC, [0,1] or [0,255].
"""
img = self._to_tensor(obs) # (1, C, H, W)
with torch.no_grad():
with torch.autocast(self.device, dtype=torch.float16):
z = self.vae.encode(img * 2 - 1).mean * self.scaling_factor
# z: (1, seq_h*seq_w, latent_dim) β†’ (1, C, H, W)
ph = self.vae.seq_h
pw = self.vae.seq_w
z = rearrange(z, "b (h w) c -> b c h w", h=ph, w=pw)
# Init buffers
self._latent_buffer = z.unsqueeze(1) # (1, 1, C, H, W)
# Initial "no-op" action
self._action_buffer = torch.zeros(
1, 1, NUM_ACTION_KEYS, device=self.device,
)
self._frame_idx = 1
recon = self._decode(z)
return {
"latent_state": z.squeeze(0).cpu().numpy(),
"recon": recon,
}
def predict(self, action) -> dict:
"""
Generate next frame. Sampling loop matches generate.py exactly.
Parameters
----------
action : int or np.ndarray
If int: index into the 25 action keys (sets that key to 1).
If np.ndarray of shape (25,): raw one-hot/continuous action vector.
"""
if self._latent_buffer is None:
raise RuntimeError("Call forward() first.")
# ── Prepare action ────────────────────────────────────
act = self._encode_action(action) # (1, 1, 25)
self._action_buffer = torch.cat(
[self._action_buffer, act], dim=1,
)
# ── Append noise chunk ────────────────────────────────
B = 1
chunk = torch.randn(
(B, 1, *self._latent_buffer.shape[-3:]), device=self.device,
)
chunk = chunk.clamp(-self.noise_abs_max, self.noise_abs_max)
x = torch.cat([self._latent_buffer, chunk], dim=1)
i = self._frame_idx # current frame index (0-based)
start_frame = max(0, i + 1 - self.max_frames)
# ── Diffusion denoising loop (from generate.py) ───────
for noise_idx in reversed(range(1, self.ddim_steps + 1)):
# Noise levels: context frames get stabilization_level, last frame gets actual noise
t_ctx = torch.full(
(B, i), self.stabilization_level - 1,
dtype=torch.long, device=self.device,
)
t = torch.full(
(B, 1), int(self.noise_range[noise_idx].item()),
dtype=torch.long, device=self.device,
)
t_next = torch.full(
(B, 1), int(self.noise_range[noise_idx - 1].item()),
dtype=torch.long, device=self.device,
)
t_next = torch.where(t_next < 0, t, t_next)
t_full = torch.cat([t_ctx, t], dim=1)
t_next_full = torch.cat([t_ctx, t_next], dim=1)
# Sliding window
x_curr = x.clone()[:, start_frame:]
t_slice = t_full[:, start_frame:]
t_next_slice = t_next_full[:, start_frame:]
actions_slice = self._action_buffer[:, start_frame:i + 1]
# DiT forward
with torch.no_grad():
with torch.autocast(self.device, dtype=torch.float16):
v = self.dit(x_curr, t_slice, external_cond=actions_slice)
# v-prediction β†’ x_start, x_noise (matching generate.py)
x_start = (
self.alphas_cumprod[t_slice].sqrt() * x_curr
- (1 - self.alphas_cumprod[t_slice]).sqrt() * v
)
x_noise = (
(1 / self.alphas_cumprod[t_slice]).sqrt() * x_curr - x_start
) / (1 / self.alphas_cumprod[t_slice] - 1).sqrt()
# Frame prediction
alpha_next = self.alphas_cumprod[t_next_slice]
alpha_next[:, :-1] = torch.ones_like(alpha_next[:, :-1])
if noise_idx == 1:
alpha_next[:, -1:] = torch.ones_like(alpha_next[:, -1:])
x_pred = alpha_next.sqrt() * x_start + x_noise * (1 - alpha_next).sqrt()
x[:, -1:] = x_pred[:, -1:]
# ── Update state ──────────────────────────────────────
new_latent = x[:, -1:]
self._latent_buffer = x # keep full buffer (includes new frame)
# Trim to max context
if self._latent_buffer.shape[1] > self.max_frames:
trim = self._latent_buffer.shape[1] - self.max_frames
self._latent_buffer = self._latent_buffer[:, trim:]
self._action_buffer = self._action_buffer[:, trim:]
self._frame_idx += 1
recon = self._decode(new_latent.squeeze(1))
return {
"latent_state": new_latent.squeeze(0).squeeze(0).cpu().numpy(),
"recon": recon,
"reward": None,
"terminated": False,
}
# ── Helpers ────────────────────────────────────────────────
def _encode_action(self, action) -> torch.Tensor:
"""Convert action to (1, 1, 25) tensor."""
if isinstance(action, np.ndarray) and action.shape == (NUM_ACTION_KEYS,):
return torch.from_numpy(action).float().reshape(1, 1, -1).to(self.device)
elif isinstance(action, (int, np.integer)):
act = torch.zeros(1, 1, NUM_ACTION_KEYS, device=self.device)
act[0, 0, int(action)] = 1.0
return act
elif isinstance(action, torch.Tensor):
return action.float().reshape(1, 1, -1).to(self.device)
else:
raise ValueError(
f"Action must be int (action key index), np.ndarray(25,), "
f"or torch.Tensor. Got {type(action)}."
)
def _to_tensor(self, obs: np.ndarray) -> torch.Tensor:
"""Convert obs to (1, C, H, W) float [0,1]."""
img = np.asarray(obs, dtype=np.float32)
if img.ndim == 3 and img.shape[-1] in (1, 3, 4):
img = np.transpose(img, (2, 0, 1))
if img.max() > 1.0:
img = img / 255.0
return torch.from_numpy(img).unsqueeze(0).to(self.device)
def _decode(self, z: torch.Tensor) -> np.ndarray:
"""Decode latent (1, C, H, W) β†’ RGB (C, H, W) in [0,1]."""
with torch.no_grad():
with torch.autocast(self.device, dtype=torch.float16):
z_flat = rearrange(z, "b c h w -> b (h w) c")
decoded = self.vae.decode(z_flat / self.scaling_factor)
decoded = (decoded + 1) / 2
return decoded.squeeze(0).clamp(0, 1).float().cpu().numpy()