# ctrl_world/src/world_model.py
# Provenance: Hugging Face repo sai_wm, "Upload 6 files (#1)" by SaiResearch, revision 4cd55fa
from models.pipeline_stable_video_diffusion import StableVideoDiffusionPipeline
from models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from models.pipeline_ctrl_world import CtrlWorldDiffusionPipeline
import torch
import torch.nn as nn
import json
import einops
import numpy as np
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, CLIPTextModelWithProjection
class Action_encoder2(nn.Module):
    """Project per-frame action vectors into the UNet conditioning space.

    A 3-layer MLP lifts raw actions to 1024-d tokens. When a caption is
    supplied (and ``text_cond`` is enabled), a frozen CLIP text embedding
    is tiled from 512 to 1024 dims and added to every action token.
    """

    def __init__(self, action_dim, action_num, hidden_size, text_cond=True):
        super().__init__()
        self.action_dim = action_dim
        self.action_num = action_num
        self.hidden_size = hidden_size
        self.text_cond = text_cond
        self.action_encode = nn.Sequential(
            nn.Linear(int(action_dim), 1024),
            nn.SiLU(),
            nn.Linear(1024, 1024),
            nn.SiLU(),
            nn.Linear(1024, 1024),
        )
        # Kaiming init for the first two Linear layers; the final Linear
        # keeps PyTorch's default initialization.
        for layer_idx in (0, 2):
            nn.init.kaiming_normal_(
                self.action_encode[layer_idx].weight, mode='fan_in', nonlinearity='relu'
            )

    def forward(self, action, texts=None, text_tokinizer=None, text_encoder=None, frame_level_cond=True,):
        """Encode actions, optionally mixing in a CLIP text embedding.

        action: (B, action_num, action_dim) tensor.
        Returns (B, T, hidden) tokens, or (B, 1, hidden) when
        ``frame_level_cond`` is False (whole trajectory as one token).
        """
        batch, steps, dim = action.shape
        if not frame_level_cond:
            # Collapse the whole trajectory into a single conditioning token.
            action = einops.rearrange(action, 'b t d -> b 1 (t d)')
        encoded = self.action_encode(action)
        if self.text_cond and texts is not None:
            # Text encoder is frozen, so no gradients are needed here.
            with torch.no_grad():
                tokenized = text_tokinizer(
                    texts, padding='max_length', return_tensors="pt", truncation=True
                ).to(text_encoder.device)
                text_out = text_encoder(**tokenized)
                text_emb = text_out.text_embeds  # (B, 512)
            # Tile 512 -> 1024 so the text embedding matches the action tokens.
            text_emb = einops.repeat(text_emb, 'b c -> b 1 (n c)', n=2)  # (B, 1, 1024)
            encoded = encoded + text_emb  # broadcast over the T axis
        return encoded
class CrtlWorld(nn.Module):
    """Ctrl-World world model.

    Wraps a pretrained Stable Video Diffusion pipeline, swaps in a custom
    spatio-temporal UNet, adds a frozen CLIP text encoder, and an action
    projector used to condition multi-view video rollouts on robot actions.

    NOTE(review): the class name keeps the original (misspelled) public
    identifier so existing imports continue to work.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config

        # ── Stable Video Diffusion backbone ───────────────────────────
        model_local_path = snapshot_download(
            repo_id=config["svd_model_path"],  # e.g. "stabilityai/stable-video-diffusion-img2vid"
            repo_type="model",
        )
        self.pipeline = StableVideoDiffusionPipeline.from_pretrained(
            model_local_path,
            torch_dtype="auto",
        )
        # Swap in the project UNet; strict=False because the custom UNet
        # may add parameters the pretrained checkpoint does not have.
        unet = UNetSpatioTemporalConditionModel()
        unet.load_state_dict(self.pipeline.unet.state_dict(), strict=False)
        self.pipeline.unet = unet
        self.unet = self.pipeline.unet
        self.vae = self.pipeline.vae
        self.image_encoder = self.pipeline.image_encoder
        self.scheduler = self.pipeline.scheduler

        # Freeze VAE and image encoder; only the UNet is trained, with
        # gradient checkpointing to reduce activation memory.
        self.vae.requires_grad_(False)
        self.image_encoder.requires_grad_(False)
        self.unet.requires_grad_(True)
        self.unet.enable_gradient_checkpointing()

        # ── Frozen CLIP text encoder ──────────────────────────────────
        # SVD itself is image-conditioned, so text conditioning needs a
        # separate CLIP text tower.
        model_local_path = snapshot_download(
            repo_id=config["clip_model_path"],
            repo_type="model",
        )
        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
            model_local_path,
            torch_dtype="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_local_path, use_fast=False)
        self.text_encoder.requires_grad_(False)

        # Action projector over (history + future) action steps.
        self.action_encoder = Action_encoder2(
            action_dim=config["action_dim"],
            action_num=int(config["num_history"] + config["num_frames"]),
            hidden_size=1024,
            text_cond=config["text_cond"],
        )

        # Per-dimension 1st/99th state percentiles used for normalization.
        # FIX: the original wrapped the path in f"{config["..."]}" — nested
        # same-quote f-strings are a SyntaxError before Python 3.12, and the
        # f-string added nothing.
        with open(config["data_stat_path"], 'r') as f:
            data_stat = json.load(f)
        self.state_p01 = np.array(data_stat['state_01'])[None, :]
        self.state_p99 = np.array(data_stat['state_99'])[None, :]

    def normalize_bound(
        self,
        data: np.ndarray,
        clip_min: float = -1,
        clip_max: float = 1,
        eps: float = 1e-8,
    ) -> np.ndarray:
        """Affinely map `data` from [p01, p99] per dimension to [-1, 1], then clip.

        `eps` guards against division by zero on constant dimensions.
        """
        ndata = 2 * (data - self.state_p01) / (self.state_p99 - self.state_p01 + eps) - 1
        return np.clip(ndata, clip_min, clip_max)

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode VAE latents (B, T, C, h, w) into pixel videos in [0, 1].

        Frames are decoded in chunks of `decode_chunk_size` to bound memory.
        Returns a detached float CPU tensor shaped (B, T, ...).
        """
        bsz, frame_num = latents.shape[:2]
        x = latents.flatten(0, 1)
        decoded = []
        chunk_size = self.config["decode_chunk_size"]
        for i in range(0, x.shape[0], chunk_size):
            # Undo the VAE scaling applied at encode time.
            chunk = x[i:i + chunk_size] / self.pipeline.vae.config.scaling_factor
            decode_kwargs = {"num_frames": chunk.shape[0]}
            out = self.pipeline.vae.decode(chunk, **decode_kwargs).sample
            decoded.append(out)
        videos = torch.cat(decoded, dim=0)
        videos = videos.reshape(bsz, frame_num, *videos.shape[1:])
        # Map pixel range [-1, 1] -> [0, 1].
        videos = (videos / 2.0 + 0.5).clamp(0, 1)
        # BUG FIX: the original computed `videos` but never returned it,
        # so callers silently received None.
        return videos.detach().float().cpu()

    def encode(self, img: torch.Tensor) -> torch.Tensor:
        """Encode a single image in [0, 1] into a scaled VAE latent.

        Adds a batch dim, maps [0, 1] -> [-1, 1], and applies the VAE
        scaling factor. Returns a detached latent tensor.
        """
        x = img.unsqueeze(0)
        x = x * 2 - 1  # [0,1] -> [-1,1]
        vae = self.pipeline.vae
        with torch.no_grad():
            latent = vae.encode(x).latent_dist.sample()
            latent = latent * vae.config.scaling_factor
        return latent.detach()

    def action_text_encode(self, action: torch.Tensor, text) -> torch.Tensor:
        """Encode an action sequence (+ optional caption) into conditioning tokens.

        `action` is unbatched (T, D); a batch dim is added before encoding.
        Returns detached tokens from the action encoder.
        """
        action_tensor = action.unsqueeze(0)
        with torch.no_grad():
            if text is not None and self.config["text_cond"]:
                text_token = self.action_encoder(action_tensor, [text], self.tokenizer, self.text_encoder)
            else:
                text_token = self.action_encoder(action_tensor)
        return text_token.detach()

    def get_latent_views(self, frames, current_latent, text_token):
        """Roll out future multi-view latents with the Ctrl-World pipeline.

        `frames` is a list of history latents; they are stacked into a
        (1, num_history, 4, stacked_H, W) conditioning tensor. The height is
        tripled because three camera views are stacked vertically.
        Returns the predicted latents (output_type="latent").
        """
        his_cond = torch.cat(frames, dim=0).unsqueeze(0)  # (1, num_history, 4, stacked_H, W)
        with torch.no_grad():
            # Invoke the Ctrl-World pipeline's __call__ on the SVD pipeline
            # instance (borrows its components without re-wrapping them).
            _, latents = CtrlWorldDiffusionPipeline.__call__(
                self.pipeline,
                image=current_latent,
                text=text_token,
                width=self.config["width"],
                height=int(self.config["height"] * 3),  # 3 views stacked
                num_frames=self.config["num_frames"],
                history=his_cond,
                num_inference_steps=self.config["num_inference_steps"],
                decode_chunk_size=self.config["decode_chunk_size"],
                max_guidance_scale=self.config["guidance_scale"],
                fps=self.config["fps"],
                motion_bucket_id=self.config["motion_bucket_id"],
                mask=None,
                output_type="latent",
                return_dict=False,
                frame_level_cond=True,
            )
        return latents