| from models.pipeline_stable_video_diffusion import StableVideoDiffusionPipeline |
| from models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel |
| from models.pipeline_ctrl_world import CtrlWorldDiffusionPipeline |
| import torch |
| import torch.nn as nn |
|
|
| import json |
| import einops |
| import numpy as np |
| from huggingface_hub import snapshot_download |
| from transformers import AutoTokenizer, CLIPTextModelWithProjection |
|
|
class Action_encoder2(nn.Module):
    """Encode per-frame action vectors into conditioning tokens for the
    video-diffusion UNet, optionally fused with a pooled CLIP text embedding.

    Args:
        action_dim: Dimensionality of a single action vector.
        action_num: Total number of conditioned frames (history + rollout).
            Stored for reference only; the MLP itself operates frame-wise.
        hidden_size: Width of the output tokens (default 1024, matching the
            original hard-coded widths).
        text_cond: If True, ``forward`` adds a text embedding to the encoded
            actions whenever ``texts`` is provided.
    """

    def __init__(self, action_dim, action_num, hidden_size=1024, text_cond=True):
        super().__init__()
        self.action_dim = action_dim
        self.action_num = action_num
        self.hidden_size = hidden_size
        self.text_cond = text_cond

        # Frame-wise MLP: (B, T, action_dim) -> (B, T, hidden_size).
        # Generalized: widths were previously hard-coded to 1024 and
        # `hidden_size` was ignored; defaults preserve the old behavior.
        input_dim = int(action_dim)
        self.action_encode = nn.Sequential(
            nn.Linear(input_dim, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

        # Kaiming init for the two hidden linears (SiLU treated as relu-like);
        # the final projection keeps PyTorch's default init, as in the original.
        nn.init.kaiming_normal_(self.action_encode[0].weight, mode='fan_in', nonlinearity='relu')
        nn.init.kaiming_normal_(self.action_encode[2].weight, mode='fan_in', nonlinearity='relu')

    def forward(self, action, texts=None, text_tokinizer=None, text_encoder=None, frame_level_cond=True):
        """Encode actions (and optionally text) into conditioning tokens.

        Args:
            action: Tensor of shape (B, T, action_dim).
            texts: Optional list of strings; fused in only if ``self.text_cond``.
            text_tokinizer: Tokenizer for ``texts`` (name kept for backward
                compatibility with keyword callers, despite the typo).
            text_encoder: CLIP text model with a ``text_embeds`` output.
            frame_level_cond: If True (default), one token per frame;
                otherwise the whole trajectory is collapsed into one token.

        Returns:
            Tensor of shape (B, T, hidden_size) (or (B, 1, hidden_size) when
            ``frame_level_cond`` is False).
        """
        B, T, D = action.shape
        if not frame_level_cond:
            # Collapse the trajectory into a single (t*d)-dim token.
            # NOTE(review): this feeds a (t*d)-dim input into a Linear built
            # for action_dim -- only valid if the encoder was constructed for
            # flattened trajectories; confirm against callers.
            action = einops.rearrange(action, 'b t d -> b 1 (t d)')
        action = self.action_encode(action)

        if texts is not None and self.text_cond:
            with torch.no_grad():
                inputs = text_tokinizer(texts, padding='max_length', return_tensors="pt", truncation=True).to(text_encoder.device)
                outputs = text_encoder(**inputs)
                hidden_text = outputs.text_embeds
                # Tile the pooled embedding up to hidden_size (e.g. a 512-dim
                # CLIP embed tiled x2 -> 1024). Previously hard-coded n=2;
                # now derived so other embed/hidden sizes also work.
                n = self.hidden_size // hidden_text.shape[-1]
                hidden_text = einops.repeat(hidden_text, 'b c -> b 1 (n c)', n=n)
            # Broadcast-add the single text token across all action tokens.
            action = action + hidden_text
        return action
|
|
|
|
class CrtlWorld(nn.Module):
    """Ctrl-World model: a Stable Video Diffusion pipeline with a custom
    spatio-temporal UNet, a frozen CLIP text encoder, and an action encoder
    for action/text-conditioned video prediction.

    Only the UNet is trainable; the VAE, image encoder, and text encoder are
    frozen. Robot-state normalization statistics (1st/99th percentiles) are
    loaded from ``config["data_stat_path"]``.
    """

    def __init__(self, config: dict):
        super().__init__()

        self.config = config

        # Fetch (or reuse cached) SVD weights from the Hub.
        model_local_path = snapshot_download(
            repo_id=config["svd_model_path"],
            repo_type="model",
        )
        self.pipeline = StableVideoDiffusionPipeline.from_pretrained(
            model_local_path,
            torch_dtype="auto",
        )

        # Swap in the project's UNet, seeded from the pretrained weights.
        # strict=False lets the custom UNet carry extra parameters (e.g. for
        # action conditioning) that are absent from the SVD checkpoint.
        unet = UNetSpatioTemporalConditionModel()
        unet.load_state_dict(self.pipeline.unet.state_dict(), strict=False)
        self.pipeline.unet = unet

        # Convenience handles onto the pipeline's submodules.
        self.unet = self.pipeline.unet
        self.vae = self.pipeline.vae
        self.image_encoder = self.pipeline.image_encoder
        self.scheduler = self.pipeline.scheduler

        # Freeze everything except the UNet; checkpointing trades compute
        # for activation memory during training.
        self.vae.requires_grad_(False)
        self.image_encoder.requires_grad_(False)
        self.unet.requires_grad_(True)
        self.unet.enable_gradient_checkpointing()

        # Frozen CLIP text encoder + (slow) tokenizer.
        model_local_path = snapshot_download(
            repo_id=config["clip_model_path"],
            repo_type="model",
        )
        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
            model_local_path,
            torch_dtype="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_local_path, use_fast=False)
        self.text_encoder.requires_grad_(False)

        self.action_encoder = Action_encoder2(
            action_dim=config["action_dim"],
            action_num=int(config["num_history"] + config["num_frames"]),
            hidden_size=1024,
            text_cond=config["text_cond"],
        )

        # BUG FIX: the original used f"{config["data_stat_path"]}" -- nested
        # double quotes inside an f-string are a SyntaxError before
        # Python 3.12, and the f-string wrapper was redundant anyway.
        with open(config["data_stat_path"], 'r') as f:
            data_stat = json.load(f)
        # Percentile bounds shaped (1, state_dim) so they broadcast over rows.
        self.state_p01 = np.array(data_stat['state_01'])[None, :]
        self.state_p99 = np.array(data_stat['state_99'])[None, :]

    def normalize_bound(
        self,
        data: np.ndarray,
        clip_min: float = -1,
        clip_max: float = 1,
        eps: float = 1e-8,
    ) -> np.ndarray:
        """Map state rows into [-1, 1] using the stored p01/p99 bounds.

        ``eps`` guards against division by zero on constant dimensions;
        results are clipped to [clip_min, clip_max].
        """
        ndata = 2 * (data - self.state_p01) / (self.state_p99 - self.state_p01 + eps) - 1
        return np.clip(ndata, clip_min, clip_max)

    def decode(self, latents: torch.Tensor):
        """Decode VAE latents of shape (B, T, ...) into pixel videos in [0, 1].

        Decoding is chunked by ``config["decode_chunk_size"]`` to bound VAE
        memory use. Returns a detached float tensor on CPU with shape
        (B, T, C, H, W) -- spatial upscale factor depends on the VAE config.
        """
        bsz, frame_num = latents.shape[:2]
        x = latents.flatten(0, 1)  # (B*T, ...) for chunked decoding

        decoded = []
        chunk_size = self.config["decode_chunk_size"]
        for i in range(0, x.shape[0], chunk_size):
            # Undo the VAE scaling applied at encode time.
            chunk = x[i:i + chunk_size] / self.pipeline.vae.config.scaling_factor
            decode_kwargs = {"num_frames": chunk.shape[0]}
            out = self.pipeline.vae.decode(chunk, **decode_kwargs).sample
            decoded.append(out)

        videos = torch.cat(decoded, dim=0)
        videos = videos.reshape(bsz, frame_num, *videos.shape[1:])
        videos = (videos / 2.0 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
        videos = videos.detach().float().cpu()
        # BUG FIX: the original computed `videos` but never returned it,
        # so decode() silently returned None.
        return videos

    def encode(self, img: torch.Tensor):
        """Encode one image (C, H, W) in [0, 1] into a scaled VAE latent.

        Adds a batch dim, maps to [-1, 1], samples the VAE posterior, and
        applies the VAE scaling factor. Returns a detached latent tensor.
        """
        x = img.unsqueeze(0)
        x = x * 2 - 1  # [0, 1] -> [-1, 1] expected by the VAE

        vae = self.pipeline.vae
        with torch.no_grad():
            latent = vae.encode(x).latent_dist.sample()
            latent = latent * vae.config.scaling_factor

        return latent.detach()

    def action_text_encode(self, action: torch.Tensor, text):
        """Encode an action trajectory (T, action_dim), optionally fused with
        a text prompt, into conditioning tokens (no gradients)."""
        action_tensor = action.unsqueeze(0)  # add batch dim

        with torch.no_grad():
            if text is not None and self.config["text_cond"]:
                text_token = self.action_encoder(action_tensor, [text], self.tokenizer, self.text_encoder)
            else:
                text_token = self.action_encoder(action_tensor)

        return text_token.detach()

    def get_latent_views(self, frames, current_latent, text_token):
        """Run the Ctrl-World diffusion sampler and return predicted latents.

        Args:
            frames: List of history latent tensors, concatenated along dim 0
                and batched as the history condition.
            current_latent: Latent of the current observation.
            text_token: Action/text conditioning tokens from
                ``action_text_encode``.

        Returns:
            Latents for ``config["num_frames"]`` predicted frames
            (output_type="latent"; decode with ``decode``).
        """
        his_cond = torch.cat(frames, dim=0).unsqueeze(0)

        with torch.no_grad():
            # Invoke the Ctrl-World sampler unbound, reusing this model's
            # StableVideoDiffusionPipeline instance as `self`.
            _, latents = CtrlWorldDiffusionPipeline.__call__(
                self.pipeline,
                image=current_latent,
                text=text_token,
                width=self.config["width"],
                # Height x3: presumably three stacked camera views -- confirm
                # against the dataset layout.
                height=int(self.config["height"] * 3),
                num_frames=self.config["num_frames"],
                history=his_cond,
                num_inference_steps=self.config["num_inference_steps"],
                decode_chunk_size=self.config["decode_chunk_size"],
                max_guidance_scale=self.config["guidance_scale"],
                fps=self.config["fps"],
                motion_bucket_id=self.config["motion_bucket_id"],
                mask=None,
                output_type="latent",
                return_dict=False,
                frame_level_cond=True,
            )

        return latents