from models.pipeline_stable_video_diffusion import StableVideoDiffusionPipeline
from models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from models.pipeline_ctrl_world import CtrlWorldDiffusionPipeline
import torch
import torch.nn as nn
import json
import einops
import numpy as np
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, CLIPTextModelWithProjection


class Action_encoder2(nn.Module):
    """Encode a sequence of actions (optionally fused with a CLIP text
    embedding) into 1024-d conditioning tokens for the diffusion UNet.

    Args:
        action_dim: Dimensionality of a single action vector.
        action_num: Number of action steps per sample (history + future).
        hidden_size: Output token width (the MLP below emits 1024).
        text_cond: If True, a projected CLIP text embedding is added to the
            action tokens when ``texts`` is supplied at forward time.
    """

    def __init__(self, action_dim, action_num, hidden_size, text_cond=True):
        super().__init__()
        self.action_dim = action_dim
        self.action_num = action_num
        self.hidden_size = hidden_size
        self.text_cond = text_cond
        input_dim = int(action_dim)
        # 3-layer MLP: action_dim -> 1024 -> 1024 -> 1024.
        self.action_encode = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.SiLU(),
            nn.Linear(1024, 1024),
            nn.SiLU(),
            nn.Linear(1024, 1024),
        )
        # Kaiming init on the first two Linear layers (Sequential indices 0, 2).
        # NOTE(review): the last Linear (index 4) keeps PyTorch's default
        # init — confirm that is intentional.
        nn.init.kaiming_normal_(self.action_encode[0].weight, mode='fan_in', nonlinearity='relu')
        nn.init.kaiming_normal_(self.action_encode[2].weight, mode='fan_in', nonlinearity='relu')

    def forward(self, action, texts=None, text_tokinizer=None, text_encoder=None, frame_level_cond=True):
        """Project actions to conditioning tokens.

        Args:
            action: Tensor of shape (B, action_num, action_dim).
            texts: Optional list of strings, one per batch element.
            text_tokinizer: Tokenizer matching ``text_encoder`` (name kept
                as-is for caller compatibility, despite the typo).
            text_encoder: CLIPTextModelWithProjection; its 512-d
                ``text_embeds`` are tiled to 1024-d and added to the tokens.
            frame_level_cond: If True, one token per frame is returned;
                otherwise the trajectory is flattened into a single token.

        Returns:
            (B, T, 1024) if ``frame_level_cond`` else (B, 1, 1024).
        """
        B, T, D = action.shape  # action: (B, action_num, action_dim)
        if not frame_level_cond:
            # Collapse the whole trajectory into one flat conditioning token.
            action = einops.rearrange(action, 'b t d -> b 1 (t d)')
        action = self.action_encode(action)
        if texts is not None and self.text_cond:
            # Text tower is frozen: run tokenizer + encoder without grads.
            # (The add below stays OUTSIDE no_grad so gradients reach
            # action_encode; the original stale "50% probability" comment
            # was removed — there is no randomness here.)
            with torch.no_grad():
                inputs = text_tokinizer(
                    texts, padding='max_length', return_tensors="pt", truncation=True
                ).to(text_encoder.device)
                outputs = text_encoder(**inputs)
                hidden_text = outputs.text_embeds  # (B, 512)
            # Tile 512 -> 1024 and broadcast-add over the T token axis.
            hidden_text = einops.repeat(hidden_text, 'b c -> b 1 (n c)', n=2)  # (B, 1, 1024)
            action = action + hidden_text  # (B, T, hidden_size)
        return action  # (B, 1, hidden_size) or (B, T, hidden_size) if frame_level_cond


class CrtlWorld(nn.Module):
    """Ctrl-World world model: a pretrained Stable Video Diffusion pipeline
    whose UNet is swapped for the project variant and conditioned on
    action/text tokens from :class:`Action_encoder2`.

    (Class name typo ``CrtlWorld`` is preserved — callers reference it.)
    """

    def __init__(self, config: dict):
        super(CrtlWorld, self).__init__()
        self.config = config

        # --- Load pretrained Stable Video Diffusion -----------------------
        model_local_path = snapshot_download(
            repo_id=config["svd_model_path"],  # e.g. "stabilityai/stable-video-diffusion-img2vid"
            repo_type="model",
        )
        self.pipeline = StableVideoDiffusionPipeline.from_pretrained(
            model_local_path, torch_dtype="auto"
        )

        # Swap in the project UNet, initialised from the SVD weights.
        # strict=False: the variant may add or drop parameters vs. stock SVD.
        unet = UNetSpatioTemporalConditionModel()
        unet.load_state_dict(self.pipeline.unet.state_dict(), strict=False)
        self.pipeline.unet = unet

        self.unet = self.pipeline.unet
        self.vae = self.pipeline.vae
        self.image_encoder = self.pipeline.image_encoder
        self.scheduler = self.pipeline.scheduler

        # Freeze VAE + image encoder; only the UNet trains (with grad ckpt).
        self.vae.requires_grad_(False)
        self.image_encoder.requires_grad_(False)
        self.unet.requires_grad_(True)
        self.unet.enable_gradient_checkpointing()

        # --- SVD is image-to-video, so load a separate CLIP text encoder --
        model_local_path = snapshot_download(
            repo_id=config["clip_model_path"],
            repo_type="model",
        )
        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
            model_local_path, torch_dtype="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_local_path, use_fast=False)
        self.text_encoder.requires_grad_(False)

        # Action projector: one token per (history + future) frame.
        self.action_encoder = Action_encoder2(
            action_dim=config["action_dim"],
            action_num=int(config["num_history"] + config["num_frames"]),
            hidden_size=1024,
            text_cond=config["text_cond"],
        )

        # Dataset statistics used by normalize_bound.
        # FIX: the original f"{config["data_stat_path"]}" is a SyntaxError
        # before Python 3.12 (reused double quotes) and the f-string was
        # redundant; pass the path directly.
        with open(config["data_stat_path"], 'r') as f:
            data_stat = json.load(f)
        self.state_p01 = np.array(data_stat['state_01'])[None, :]
        self.state_p99 = np.array(data_stat['state_99'])[None, :]

    def normalize_bound(
        self,
        data: np.ndarray,
        clip_min: float = -1,
        clip_max: float = 1,
        eps: float = 1e-8,
    ) -> np.ndarray:
        """Map ``data`` to roughly [-1, 1] using the 1st/99th-percentile
        statistics loaded in ``__init__``, then clip to [clip_min, clip_max].

        ``eps`` guards against a zero denominator when p01 == p99.
        """
        ndata = 2 * (data - self.state_p01) / (self.state_p99 - self.state_p01 + eps) - 1
        return np.clip(ndata, clip_min, clip_max)

    def decode(self, latents: torch.Tensor):
        """Decode (B, F, C, H, W) latents to pixel videos.

        Returns:
            Float CPU tensor of shape (B, F, 3, H*8, W*8)-like (VAE output
            shape), values clamped to [0, 1].
        """
        bsz, frame_num = latents.shape[:2]
        x = latents.flatten(0, 1)  # (B*F, C, H, W)
        decoded = []
        chunk_size = self.config["decode_chunk_size"]
        # Decode in chunks to bound peak VAE memory.
        for i in range(0, x.shape[0], chunk_size):
            chunk = x[i:i + chunk_size] / self.pipeline.vae.config.scaling_factor
            decode_kwargs = {"num_frames": chunk.shape[0]}
            out = self.pipeline.vae.decode(chunk, **decode_kwargs).sample
            decoded.append(out)
        videos = torch.cat(decoded, dim=0)
        videos = videos.reshape(bsz, frame_num, *videos.shape[1:])
        videos = (videos / 2.0 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
        videos = videos.detach().float().cpu()
        # BUG FIX: the original built `videos` but fell off the end,
        # returning None.
        return videos

    def encode(self, img: torch.Tensor):
        """Encode a single image (C, H, W) in [0, 1] to a scaled VAE latent.

        Returns a detached latent of shape (1, 4, H/8, W/8)-like.
        """
        x = img.unsqueeze(0)
        x = x * 2 - 1  # [0, 1] -> [-1, 1]
        vae = self.pipeline.vae
        with torch.no_grad():
            latent = vae.encode(x).latent_dist.sample()
            latent = latent * vae.config.scaling_factor
        return latent.detach()

    def action_text_encode(self, action: torch.Tensor, text):
        """Encode one action trajectory (T, action_dim) plus an optional text
        prompt into detached conditioning tokens (1, T, 1024)."""
        action_tensor = action.unsqueeze(0)
        with torch.no_grad():
            if text is not None and self.config["text_cond"]:
                text_token = self.action_encoder(
                    action_tensor, [text], self.tokenizer, self.text_encoder
                )
            else:
                text_token = self.action_encoder(action_tensor)
        return text_token.detach()

    def get_latent_views(self, frames, current_latent, text_token):
        """Roll the world model forward and return predicted latents.

        Args:
            frames: List of history latents, concatenated to
                (1, num_history, 4, stacked_H, W).
            current_latent: Latent of the current observation.
            text_token: Conditioning tokens from ``action_text_encode``.
        """
        his_cond = torch.cat(frames, dim=0).unsqueeze(0)  # (1, num_history, 4, stacked_H, W)
        # Invoke CtrlWorldDiffusionPipeline.__call__ unbound on the SVD
        # pipeline instance so it reuses its loaded components.
        with torch.no_grad():
            _, latents = CtrlWorldDiffusionPipeline.__call__(
                self.pipeline,
                image=current_latent,
                text=text_token,
                width=self.config["width"],
                height=int(self.config["height"] * 3),  # 3 views stacked vertically
                num_frames=self.config["num_frames"],
                history=his_cond,
                num_inference_steps=self.config["num_inference_steps"],
                decode_chunk_size=self.config["decode_chunk_size"],
                max_guidance_scale=self.config["guidance_scale"],
                fps=self.config["fps"],
                motion_bucket_id=self.config["motion_bucket_id"],
                mask=None,
                output_type="latent",
                return_dict=False,
                frame_level_cond=True,
            )
        return latents