import time
from typing import Callable, Dict, Optional, Tuple

import cv2
import einops
import numpy as np
import skimage.metrics
import torch

from data import DATA_FREQ_TABLE
from datasets.utils import get_image_encoder
from genie.st_mar import STMAR
from genie.st_mask_git import STMaskGIT
from train_diffusion import SVD_SCALE

class Simulator:

    def set_initial_state(self, state):
        """
        Set the initial state of the simulated scene, e.g.:
        1. in robomimic, the scene state vector;
        2. in genie, the initial frames used to prompt the model.
        """
        raise NotImplementedError

    def step(self, action):
        raise NotImplementedError

    def reset(self):
        raise NotImplementedError

    def close(self):
        raise NotImplementedError

    def dt(self):
        raise NotImplementedError
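
# A minimal sketch of the expected lifecycle for any concrete Simulator.
# `MySimulator` is a hypothetical subclass used only for illustration:
#
#   sim = MySimulator()
#   sim.set_initial_state(initial_state)
#   obs = sim.reset()                    # current frame / observation
#   for action in actions:
#       result = sim.step(action)        # e.g., result['pred_next_frame']
#   sim.close()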

class PhysicsSimulator(Simulator):

    def __init__(self):
        super().__init__()

    # a physics engine should be able to update dt
    def set_dt(self, dt):
        raise NotImplementedError

    # a physics engine should be able to report the raw scene state,
    # e.g., robot joint positions, object positions, etc.
    def get_raw_state(self, port: Optional[str] = None):
        raise NotImplementedError

    def action_dimension(self):
        raise NotImplementedError

class LearnedSimulator(Simulator):

    def __init__(self):
        super().__init__()


# Replayed data respects physics, so ReplaySimulator inherits from
# PhysicsSimulator: it can be considered a special case of one.
class ReplaySimulator(PhysicsSimulator):

    def __init__(self,
                 frames,
                 prompt_horizon: int = 0,
                 dt: Optional[float] = None):
        super().__init__()
        self.frames = frames
        self.frame_idx = prompt_horizon
        assert self.frame_idx < len(self.frames)
        self._dt = dt
        self.prompt_horizon = prompt_horizon

    def __len__(self):
        return len(self.frames) - self.prompt_horizon

    def step(self, action):
        # check bounds before indexing so we fail with a clear assertion
        assert self.frame_idx < len(self.frames)
        frame = self.frames[self.frame_idx]
        self.frame_idx += 1
        return {
            'pred_next_frame': frame
        }

    def reset(self):
        # return the current frame, i.e., the last frame of the prompt
        self.frame_idx = self.prompt_horizon
        return self.prompt()[-1]

    def prompt(self):
        return self.frames[:self.prompt_horizon]

    def dt(self):
        return self._dt
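
# Example (illustrative sketch): replay a recorded clip, treating the first
# 4 frames as the prompt; `frames` is assumed to be a (T, H, W, 3) uint8 array
# recorded at 10 Hz. ReplaySimulator ignores the action passed to step():
#
#   sim = ReplaySimulator(frames, prompt_horizon=4, dt=0.1)
#   last_prompt_frame = sim.reset()
#   for _ in range(len(sim)):
#       frame = sim.step(action=None)['pred_next_frame']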

class GenieSimulator(LearnedSimulator):

    average_delta_psnr_over = 5

    def __init__(self,
                 # image preprocessing
                 max_image_resolution: int = 1024,
                 resize_image: bool = True,
                 resize_image_resolution: int = 256,
                 # tokenizer setting
                 image_encoder_type: str = "temporalvae",
                 image_encoder_ckpt: str = "stabilityai/stable-video-diffusion-img2vid",
                 quantize: bool = False,
                 quantization_slice_size: int = 16,
                 # dynamics backbone setting
                 backbone_type: str = "stmar",
                 backbone_ckpt: str = "data/mar_ckpt/robomimic",
                 prompt_horizon: int = 11,
                 inference_iterations: Optional[int] = None,
                 sampling_temperature: float = 0.0,
                 action_stride: Optional[int] = None,
                 domain: str = "robomimic",
                 genie_frequency: int = 2,
                 # misc
                 measure_step_time: bool = False,
                 compute_psnr: bool = False,
                 compute_delta_psnr: bool = False,  # acts as a signal for controllability
                 gaussian_action_perturbation_scale: Optional[float] = None,
                 device: str = 'cuda',
                 physics_simulator: Optional[PhysicsSimulator] = None,
                 physics_simulator_teacher_force: Optional[int] = None,
                 post_processor: Optional[Callable] = None,  # applied to the predicted image, e.g., to overlay the action
                 allow_external_prompt: bool = False):
        super().__init__()

        assert image_encoder_type in ["magvit", "temporalvae"], \
            "Image encoder type must be either 'magvit' or 'temporalvae'."
        assert quantize == (image_encoder_type == "magvit"), \
            "Currently, quantization is used if and only if the image encoder is 'magvit'."
        assert backbone_type in ["stmaskgit", "stmar"], \
            "Backbone type must be either 'stmaskgit' or 'stmar'."
        if physics_simulator is None:
            assert physics_simulator_teacher_force is None, \
                "Teacher forcing is only available when a physics simulator is provided."
            assert compute_psnr is False, \
                "PSNR computation is only available when a physics simulator is provided."
            assert compute_delta_psnr is False, \
                "Delta PSNR computation is only available when a physics simulator is provided."

        if action_stride is None:
            action_stride = DATA_FREQ_TABLE[domain] // genie_frequency
        if compute_delta_psnr:
            compute_psnr = True  # delta PSNR requires PSNR
        if inference_iterations is None:
            # both backbones currently default to 2 inference iterations
            inference_iterations = 2

        # misc
        self.device = torch.device(device)
        self.measure_step_time = measure_step_time
        self.compute_psnr = compute_psnr
        self.compute_delta_psnr = compute_delta_psnr
        self.allow_external_prompt = allow_external_prompt

        # image preprocessing
        self.max_image_resolution = max_image_resolution
        self.resize_image = resize_image
        self.resize_image_resolution = resize_image_resolution

        # load image encoder
        self.image_encoding_dtype = torch.bfloat16
        self.quantize = quantize
        self.quant_slice_size = quantization_slice_size
        self.image_encoder_type = image_encoder_type
        self.image_encoder = get_image_encoder(
            image_encoder_type,
            image_encoder_ckpt
        ).to(device=self.device, dtype=self.image_encoding_dtype).eval()

        # load dynamics backbone (STMAR inherits from STMaskGIT)
        self.prompt_horizon = prompt_horizon
        self.domain = domain
        self.genie_frequency = genie_frequency
        self.inference_iterations = inference_iterations
        self.sampling_temperature = sampling_temperature
        self.action_stride = action_stride
        self.gauss_act_perturb_scale = gaussian_action_perturbation_scale
        self.backbone_type = backbone_type
        if backbone_type == "stmaskgit":
            self.backbone = STMaskGIT.from_pretrained(backbone_ckpt)
        else:
            self.backbone = STMAR.from_pretrained(backbone_ckpt)
        self.backbone = self.backbone.to(device=self.device).eval()
        self.post_processor = post_processor

        # optional physics simulator used to get ground-truth images;
        # assumed to have prompt frames aligned with the model's prompt
        self.gt_phys_sim = physics_simulator
        self.gt_teacher_force = physics_simulator_teacher_force

        # history buffer, i.e., the input to the model
        self.cached_actions = None        # (prompt_horizon, action_stride, A)
        self.cached_latent_frames = None  # (prompt_horizon, ...)
        self.init_prompt = None           # (prompt_frames, prompt_actions)
        self.step_count = 0

        # report model size
        print(
            "================ Model Size Report ================\n"
            f" encoder size: {sum(p.numel() for p in self.image_encoder.parameters()) / 1e6:.3f}M\n"
            f" backbone size: {sum(p.numel() for p in self.backbone.parameters()) / 1e6:.3f}M\n"
            "==================================================="
        )

    def set_initial_state(self, state: Tuple[np.ndarray, np.ndarray]):
        if not self.allow_external_prompt and self.gt_phys_sim is not None:
            raise NotImplementedError("Initial state is set by the physics simulator.")
        self.init_prompt = state
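
    # step() contract, as implemented below:
    #   input:  `action` of shape (action_stride, A) or (A,), in which case it
    #           is tiled across the action stride
    #   model:  the prompt_horizon cached latent frames plus one masked slot
    #           are fed to the backbone, which fills in the next-frame latent
    #           conditioned on the action history
    #   output: dict with 'pred_next_frame' (H, W, 3) uint8, plus optional
    #           'gt_next_frame', 'psnr', 'delta_psnr', and 'step_time' entries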

    def step(self, action: np.ndarray) -> Dict:
        # action: (action_stride, A) or (A,)
        # returns a dict; step_result['pred_next_frame'] is (H, W, 3)
        assert self.cached_latent_frames is not None and self.cached_actions is not None, \
            "Model is not prompted yet. Please call `set_initial_state` first."
        if action.ndim == 1:
            action = np.tile(action, (self.action_stride, 1))

        # optionally perturb the action
        if self.gauss_act_perturb_scale is not None:
            action = np.random.normal(action, self.gauss_act_perturb_scale)

        # encoding: append an empty slot for the frame to be predicted
        input_latent_states = torch.cat([
            self.cached_latent_frames,
            torch.zeros_like(self.cached_latent_frames[[0]]),
        ]).unsqueeze(0).to(torch.float32)
        input_latent_states = input_latent_states[:, :self.prompt_horizon + 1]

        # dtype conversion and mask token
        if self.backbone_type == "stmaskgit":
            input_latent_states = input_latent_states.long()
            input_latent_states[:, -1] = self.backbone.mask_token_id
        elif self.backbone_type == "stmar":
            input_latent_states[:, -1] = self.backbone.mask_token

        # dynamics rollout
        action = torch.from_numpy(action).to(device=self.device)
        input_actions = torch.cat([
            self.cached_actions,
            action.unsqueeze(0),
            # the last action is not used, but we need the (s_{t-1}, a_{t-1}) pair to predict s_t
            action.unsqueeze(0)
            # flatten each chunk of low-level actions into one vector per model frame
        ]).view(1, -1, action.shape[-2] * action.shape[-1]).to(torch.float32)
        # (1, prompt_horizon + 1, action_stride * A)
        input_actions = input_actions[:, :self.prompt_horizon + 1]

        if self.measure_step_time:
            start_time = time.time()

        pred_next_latent_state = self.backbone.maskgit_generate(
            input_latent_states,
            out_t=input_latent_states.shape[1] - 1,
            maskgit_steps=self.inference_iterations,
            temperature=self.sampling_temperature,
            action_ids=input_actions,
            domain=[self.domain]
        )[0].squeeze(0)

        # decoding
        pred_next_frame = self._decode_image(pred_next_latent_state)

        # timing
        if self.measure_step_time:
            end_time = time.time()

        step_result = {'pred_next_frame': pred_next_frame}
        if self.measure_step_time:
            step_result['step_time'] = end_time - start_time

        # physics simulation
        if self.gt_phys_sim is not None:
            for a in action.cpu().numpy():
                gt_result = self.gt_phys_sim.step(a)
            # cv2.resize expects (width, height)
            gt_next_frame = cv2.resize(
                gt_result['pred_next_frame'],
                (pred_next_frame.shape[1], pred_next_frame.shape[0])
            )
            step_result['gt_next_frame'] = gt_next_frame
            gt_result.pop('pred_next_frame')
            step_result.update(gt_result)

            # ground-truth state observation
            try:
                raw_state = self.gt_phys_sim.get_raw_state()
                step_result.update(raw_state)
            except NotImplementedError:
                pass

            # compute PSNR against the ground truth
            if self.compute_psnr:
                psnr = skimage.metrics.peak_signal_noise_ratio(
                    image_true=gt_next_frame / 255.,
                    image_test=pred_next_frame / 255.,
                    data_range=1.0
                )
                step_result['psnr'] = psnr

            # controllability metric
            if self.compute_delta_psnr:
                delta_psnr = 0.0
                for _ in range(self.average_delta_psnr_over):
                    # re-mask the input latent states for masked prediction
                    if self.backbone_type == "stmaskgit":
                        input_latent_states = input_latent_states.long()
                        input_latent_states[:, self.prompt_horizon] = self.backbone.mask_token_id
                    elif self.backbone_type == "stmar":
                        input_latent_states[:, self.prompt_horizon] = self.backbone.mask_token
                    # sample a random action from N(0, 1)
                    random_input_actions = torch.randn_like(input_actions)
                    random_pred_next_latent_state = self.backbone.maskgit_generate(
                        input_latent_states,
                        out_t=self.prompt_horizon,
                        maskgit_steps=self.inference_iterations,
                        temperature=self.sampling_temperature,
                        action_ids=random_input_actions,
                        domain=[self.domain],
                        skip_normalization=True
                    )[0].squeeze(0)
                    random_pred_next_frame = self._decode_image(random_pred_next_latent_state)
                    this_delta_psnr = step_result['psnr'] - skimage.metrics.peak_signal_noise_ratio(
                        image_true=gt_next_frame / 255.,
                        image_test=random_pred_next_frame / 255.,
                        data_range=1.0
                    )
                    delta_psnr += this_delta_psnr / self.average_delta_psnr_over
                step_result['delta_psnr'] = delta_psnr

            # teacher forcing: periodically re-anchor on the ground-truth frame
            if self.gt_teacher_force is not None and self.step_count % self.gt_teacher_force == 0:
                pred_next_latent_state = self._encode_image(gt_next_frame)

        # update history buffer
        self.cached_latent_frames = torch.cat([
            self.cached_latent_frames[1:], pred_next_latent_state.unsqueeze(0)
        ])
        self.cached_actions = torch.cat([
            self.cached_actions[1:], action.unsqueeze(0)
        ])

        # post processing on the predicted image, e.g., overlaying the action
        if self.post_processor is not None:
            pred_next_frame = self.post_processor(pred_next_frame, action)
            step_result['pred_next_frame'] = pred_next_frame

        self.step_count += 1
        return step_result

    def _encode_image(self, image: np.ndarray) -> torch.Tensor:
        # image: (H, W, 3)
        image = torch.from_numpy(
            self._normalize_image(image).transpose(2, 0, 1)
        ).to(device=self.device, dtype=self.image_encoding_dtype).unsqueeze(0)
        H, W = image.shape[-2:]
        if self.quantize:
            H //= self.quant_slice_size
            W //= self.quant_slice_size
            _, _, indices, _ = self.image_encoder.encode(image, flip=True)
            indices = einops.rearrange(indices, "(h w) -> h w", h=H, w=W)
            indices = indices.to(torch.int32)
            return indices
        else:
            if self.image_encoder_type == "magvit":
                latent = self.image_encoder.encode_without_quantize(image)
            elif self.image_encoder_type == "temporalvae":
                latent_dist = self.image_encoder.encode(image).latent_dist
                latent = latent_dist.mean
                latent *= SVD_SCALE
                latent = einops.rearrange(latent, "b c h w -> b h w c")
            else:
                raise ValueError(f"Unknown image encoder type: {self.image_encoder_type}")
            latent = latent.squeeze(0).to(torch.float32)
            return latent

    def _decode_image(self, latent: torch.Tensor) -> np.ndarray:
        # `latent` can be either quantized indices or a raw latent
        # returns (H, W, 3)
        latent = latent.to(device=self.device).unsqueeze(0)
        if self.quantize:
            latent = self.image_encoder.quantize.get_codebook_entry(
                einops.rearrange(latent, "b h w -> b (h w)"),
                bhwc=(*latent.shape, self.image_encoder.quantize.codebook_dim)
            ).flip(1)
        latent = latent.to(device=self.device, dtype=self.image_encoding_dtype)
        if self.image_encoder_type == "magvit":
            decoded_image = self.image_encoder.decode(latent)
        elif self.image_encoder_type == "temporalvae":
            latent = einops.rearrange(latent, "b h w c -> b c h w")
            latent /= SVD_SCALE
            # HACK: clip to reduce visual artifacts
            latent = torch.clamp(latent, -25, 25)
            decoded_image = self.image_encoder.decode(latent, num_frames=1).sample
        decoded_image = decoded_image.squeeze(0).to(torch.float32).detach().cpu().numpy()
        decoded_image = self._unnormalize_image(decoded_image).transpose(1, 2, 0)
        return decoded_image
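
    # Note: _encode_image and _decode_image are intended to be approximate
    # inverses. For the default temporalvae path, a (256, 256, 3) uint8 frame
    # maps to a (32, 32, 4) float latent (the SVD VAE downsamples by 8x) and
    # back; the exact latent shape is an assumption that depends on the encoder.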

    def _normalize_image(self, image: np.ndarray) -> np.ndarray:
        # (H, W, 3) normalized to [-1, 1];
        # if `resize_image`, resize the shorter side to `resize_image_resolution`
        # and then take a center crop
        image = np.asarray(image, dtype=np.float32)
        image /= 255.
        H, W = image.shape[:2]
        # resize if asked
        if self.resize_image:
            resized_res = self.resize_image_resolution
            if H < W:
                Hnew, Wnew = resized_res, int(resized_res * W / H)
            else:
                Hnew, Wnew = int(resized_res * H / W), resized_res
            image = cv2.resize(image, (Wnew, Hnew))
            # center crop
            H, W = image.shape[:2]
            Hstart = (H - resized_res) // 2
            Wstart = (W - resized_res) // 2
            image = image[Hstart:Hstart + resized_res, Wstart:Wstart + resized_res]
        # downsize if the resolution is too large
        elif H > self.max_image_resolution or W > self.max_image_resolution:
            if H < W:
                Hnew, Wnew = int(self.max_image_resolution * H / W), self.max_image_resolution
            else:
                Hnew, Wnew = self.max_image_resolution, int(self.max_image_resolution * W / H)
            image = cv2.resize(image, (Wnew, Hnew))
        image = image * 2 - 1.
        return image

    def _unnormalize_image(self, image: np.ndarray) -> np.ndarray:
        # (H, W, 3) from [-1, 1] to [0, 255]
        # NOTE: clipping happens here
        image = (image + 1.) * 127.5
        image = np.clip(image, 0, 255).astype(np.uint8)
        return image
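
    # Round trip: _normalize_image maps uint8 [0, 255] -> float32 [-1, 1]
    # (optionally resizing the short side to resize_image_resolution and
    # center-cropping), and _unnormalize_image maps back to uint8 [0, 255],
    # so _unnormalize_image(_normalize_image(frame)) standardizes a prompt
    # frame, as done in reset().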

    def reset(self) -> np.ndarray:
        # if a ground-truth physics simulator is provided,
        # return the side-by-side concatenated image;
        # get the initial prompt from the physics simulator if not set externally
        if not self.allow_external_prompt and self.gt_phys_sim is not None:
            image_prompt = np.tile(
                self.gt_phys_sim.reset(), (self.prompt_horizon, 1, 1, 1)
            ).astype(np.uint8)
            action_prompt = np.zeros(
                (self.prompt_horizon, self.action_stride, self.gt_phys_sim.action_dimension())
            ).astype(np.float32)
        else:
            assert self.init_prompt is not None, "Initial state is not set."
            image_prompt, action_prompt = self.init_prompt

        # standardize the prompt images
        image_prompt = [self._unnormalize_image(self._normalize_image(frame)) for frame in image_prompt]
        current_image = image_prompt[-1]
        action_prompt = torch.from_numpy(action_prompt).to(device=self.device)
        self.cached_actions = action_prompt

        # convert the prompt frames to latents
        self.cached_latent_frames = torch.stack([
            self._encode_image(frame) for frame in image_prompt
        ], dim=0)

        if self.resize_image:
            current_image = cv2.resize(
                current_image,
                (self.resize_image_resolution, self.resize_image_resolution)
            )
        if self.gt_phys_sim is not None:
            # at reset, prediction and ground truth coincide, so tile the same frame
            current_image = np.concatenate([current_image, current_image], axis=1)
        self.step_count = 0
        return current_image

    def close(self):
        if self.gt_phys_sim is not None:
            try:
                self.gt_phys_sim.close()
            except NotImplementedError:
                pass

    def dt(self):
        return 1.0 / self.genie_frequency
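

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module: roll the
    # learned simulator forward with zero actions. This assumes the default
    # checkpoints configured above exist locally, a CUDA device is available,
    # and a 7-dim action space -- all of these are assumptions; adjust for
    # your setup.
    sim = GenieSimulator(device="cuda")
    prompt_frames = np.zeros((sim.prompt_horizon, 256, 256, 3), dtype=np.uint8)
    prompt_actions = np.zeros(
        (sim.prompt_horizon, sim.action_stride, 7), dtype=np.float32
    )
    sim.set_initial_state((prompt_frames, prompt_actions))
    obs = sim.reset()
    for _ in range(5):
        result = sim.step(np.zeros(7, dtype=np.float32))
        print(result['pred_next_frame'].shape)
    sim.close()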