| |
| import argparse |
| import os |
| import sys |
| import time |
| from functools import partial |
| from PIL import Image |
| from diffusers.video_processor import VideoProcessor |
| from diffusers.utils import export_to_video |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from diffusers.pipelines.wan.pipeline_wan import prompt_clean |
| from einops import rearrange |
| from tqdm import tqdm |
|
|
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) |
|
|
| from configs import VA_CONFIGS |
| from distributed.fsdp import shard_model |
| from distributed.util import _configure_model, init_distributed |
| from modules.utils import ( |
| WanVAEStreamingWrapper, |
| load_text_encoder, |
| load_tokenizer, |
| load_transformer, |
| load_vae, |
| ) |
| from utils import ( |
| FlowMatchScheduler, |
| data_seq_to_patch, |
| get_mesh_id, |
| init_logger, |
| logger, |
| run_async_server_mode, |
| save_async, |
| ) |
|
|
|
|
class VA_Server:
    """Inference server for a joint video/action diffusion model.

    The server owns a VAE (wrapped for streaming encoding), a tokenizer and
    text encoder, and a sharded transformer that is called in two modes
    (``action_mode=False`` for video latents, ``action_mode=True`` for
    actions).  Two entry points are exposed:

    * :meth:`infer` - request dispatcher (reset / KV-cache update / predict
      one chunk), intended to be driven by the async server loop.
    * :meth:`generate` - offline rollout that decodes the predicted latents
      to ``demo.mp4``.
    """

    def __init__(self, job_config):
        """Create the schedulers and load every model component.

        Args:
            job_config: Job configuration object.  The constructor reads
                ``save_root``, ``param_dtype``, ``local_rank``, ``snr_shift``,
                ``action_snr_shift``, ``wan22_pretrained_model_name_or_path``
                and ``env_type``; the other methods read further fields.
        """
        self.cache_name = 'pos'
        self.job_config = job_config
        self.save_root = job_config.save_root
        self.dtype = job_config.param_dtype
        self.device = torch.device(f"cuda:{job_config.local_rank}")

        # One flow-matching scheduler per stream (video latents / actions).
        self.scheduler = FlowMatchScheduler(shift=job_config.snr_shift,
                                            sigma_min=0.0,
                                            extra_one_step=True)
        self.action_scheduler = FlowMatchScheduler(
            shift=job_config.action_snr_shift,
            sigma_min=0.0,
            extra_one_step=True)
        self.scheduler.set_timesteps(1000, training=True)
        self.action_scheduler.set_timesteps(1000, training=True)

        pretrained_root = job_config.wan22_pretrained_model_name_or_path
        vae_dir = os.path.join(pretrained_root, 'vae')

        self.vae = load_vae(
            vae_dir,
            torch_dtype=self.dtype,
            torch_device=self.device,
        )
        self.streaming_vae = WanVAEStreamingWrapper(self.vae)

        self.tokenizer = load_tokenizer(
            os.path.join(pretrained_root, 'tokenizer'))

        self.text_encoder = load_text_encoder(
            os.path.join(pretrained_root, 'text_encoder'),
            torch_dtype=self.dtype,
            torch_device=self.device,
        )

        self.transformer = load_transformer(
            os.path.join(pretrained_root, 'transformer'),
            torch_dtype=self.dtype,
            torch_device=self.device,
        )
        shard_fn = partial(shard_model, device_id=job_config.local_rank)
        self.transformer = _configure_model(model=self.transformer,
                                            shard_fn=shard_fn,
                                            param_dtype=self.dtype,
                                            device=self.device)

        self.env_type = job_config.env_type
        self.streaming_vae_half = None
        if self.env_type == 'robotwin_tshape':
            # The T-shape layout encodes the secondary cameras at half
            # resolution with their own VAE instance (and therefore their own
            # streaming cache).
            vae_half = load_vae(
                vae_dir,
                torch_dtype=self.dtype,
                torch_device=self.device,
            )
            self.streaming_vae_half = WanVAEStreamingWrapper(vae_half)

    def _get_t5_prompt_embeds(
        self,
        prompt=None,
        num_videos_per_prompt=1,
        max_sequence_length=512,
        device=None,
        dtype=None,
    ):
        """Tokenize and encode ``prompt`` into fixed-length text embeddings.

        Embeddings beyond each prompt's real token count are replaced with
        zeros so every row has length ``max_sequence_length``.

        Args:
            prompt: A string or a list of strings.
            num_videos_per_prompt: Number of copies to produce per prompt.
            max_sequence_length: Tokenizer padding/truncation length.
            device: Target device; defaults to ``self.device``.
            dtype: Target dtype; defaults to ``self.dtype``.

        Returns:
            Tensor of shape
            ``(len(prompt) * num_videos_per_prompt, max_sequence_length, dim)``.
        """
        device = device or self.device
        dtype = dtype or self.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [prompt_clean(text) for text in prompt]
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        mask = text_inputs.attention_mask
        seq_lens = mask.gt(0).sum(dim=1).long()

        prompt_embeds = self.text_encoder(text_input_ids.to(device),
                                          mask.to(device)).last_hidden_state
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # Drop the padded positions, then re-pad with exact zeros.
        trimmed = [emb[:length] for emb, length in zip(prompt_embeds, seq_lens)]
        prompt_embeds = torch.stack([
            torch.cat([
                emb,
                emb.new_zeros(max_sequence_length - emb.size(0), emb.size(1)),
            ]) for emb in trimmed
        ],
                                    dim=0)

        # Duplicate along the sequence axis and fold the copies into the batch.
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt,
                                           seq_len, -1)

        return prompt_embeds

    def encode_prompt(
        self,
        prompt,
        negative_prompt=None,
        do_classifier_free_guidance=True,
        num_videos_per_prompt=1,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        max_sequence_length=226,
        device=None,
        dtype=None,
    ):
        r"""Encode the positive (and optionally the negative) prompt.

        Args:
            prompt: String or list of strings; may be ``None`` when
                ``prompt_embeds`` is supplied.
            negative_prompt: String or list of strings used for
                classifier-free guidance; ``None`` or empty means ``""``.
            do_classifier_free_guidance: When true and
                ``negative_prompt_embeds`` is not given, the negative prompt
                is encoded as well.
            num_videos_per_prompt: Number of copies to produce per prompt.
            prompt_embeds: Pre-computed positive embeddings (skips encoding).
            negative_prompt_embeds: Pre-computed negative embeddings.
            max_sequence_length: Tokenizer padding/truncation length.
            device: Target device; defaults to ``self.device``.
            dtype: Target dtype; defaults to ``self.dtype``.

        Returns:
            ``(prompt_embeds, negative_prompt_embeds)``; the second item is
            whatever was passed in (possibly ``None``) when no negative
            encoding was performed.

        Raises:
            TypeError: ``prompt`` and ``negative_prompt`` have different types.
            ValueError: ``negative_prompt`` batch size differs from ``prompt``.
        """
        device = device or self.device
        dtype = dtype or self.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            if isinstance(negative_prompt, str):
                negative_prompt = batch_size * [negative_prompt]

            if prompt is not None and type(prompt) is not type(
                    negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}.")
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`.")

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )
        return prompt_embeds, negative_prompt_embeds

    def normalize_latents(
        self,
        latents: torch.Tensor,
        latents_mean: torch.Tensor,
        latents_std: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``(latents - latents_mean) * latents_std`` per channel.

        Note that ``latents_std`` is applied as a *multiplier*; callers pass
        the reciprocal of the standard deviation.  The statistics are
        broadcast over axis 1 of a 5-D tensor and the result keeps the dtype
        and device of ``latents``.
        """
        mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
        scale = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
        return ((latents.float() - mean) * scale).to(latents)

    def preprocess_action(self, action):
        """Convert a raw action array into a normalized model input.

        Args:
            action: 3-D ``numpy`` array whose first axis indexes the used
                action channels.

        Returns:
            Tensor of shape ``(1, C, F, N, 1)`` where ``C`` is
            ``len(job_config.inverse_used_action_channel_ids)``.

        Raises:
            ValueError: ``action`` is not 3-D.
            NotImplementedError: Unsupported ``action_norm_method``.
        """
        action_model_input = torch.from_numpy(action)
        # Unpacking doubles as a rank check: anything but 3-D raises here.
        _channels, _frames, _steps = action_model_input.shape

        # Append one all-zero channel at the end of axis 0 so that channel
        # ids pointing at that slot select zeros.
        padded = F.pad(action_model_input, [0, 0, 0, 0, 0, 1],
                       mode='constant',
                       value=0)
        action_model_input = padded[
            self.job_config.inverse_used_action_channel_ids]

        if self.action_norm_method != 'quantiles':
            raise NotImplementedError
        # Map [q01, q99] to [-1, 1].
        span = self.actions_q99 - self.actions_q01 + 1e-6
        action_model_input = (action_model_input -
                              self.actions_q01) / span * 2. - 1.
        return action_model_input.unsqueeze(0).unsqueeze(-1)

    def postprocess_action(self, action):
        """Invert :meth:`preprocess_action` on a predicted action tensor.

        Args:
            action: Tensor of shape ``(1, C, F, N, 1)``.

        Returns:
            ``numpy`` array restricted to
            ``job_config.used_action_channel_ids`` along the first axis.

        Raises:
            NotImplementedError: Unsupported ``action_norm_method``.
        """
        action = action.cpu()
        action = action[0, ..., 0]

        if self.action_norm_method != 'quantiles':
            raise NotImplementedError
        # Map [-1, 1] back to [q01, q99].
        span = self.actions_q99 - self.actions_q01 + 1e-6
        action = (action + 1) / 2 * span + self.actions_q01

        action = action.squeeze(0).detach().numpy()
        return action[self.job_config.used_action_channel_ids]

    def _repeat_input_for_cfg(self, input_dict):
        """Add the batch axis to a transformer input dict (in place).

        With classifier-free guidance enabled the batch is doubled: row 0
        carries the positive prompt embedding and row 1 the negative one.
        Without it only a leading batch axis is added to ``grid_id`` and
        ``timesteps``.

        Returns:
            The same ``input_dict`` object, mutated.
        """
        if self.use_cfg:
            input_dict['noisy_latents'] = input_dict['noisy_latents'].repeat(
                2, 1, 1, 1, 1)
            input_dict['text_emb'] = torch.cat([
                self.prompt_embeds.to(self.dtype).clone(),
                self.negative_prompt_embeds.to(self.dtype).clone(),
            ],
                                               dim=0)
            input_dict['grid_id'] = input_dict['grid_id'][None].repeat(2, 1, 1)
            input_dict['timesteps'] = input_dict['timesteps'][None].repeat(2, 1)
        else:
            input_dict['grid_id'] = input_dict['grid_id'][None]
            input_dict['timesteps'] = input_dict['timesteps'][None]
        return input_dict

    def _prepare_latent_input(self,
                              latent_model_input,
                              action_model_input,
                              latent_t=0,
                              action_t=0,
                              latent_cond=None,
                              action_cond=None,
                              frame_st_id=0,
                              patch_size=(1, 2, 2)):
        """Assemble the per-stream input dictionaries for the transformer.

        NOTE: the conditioning frame and the channel mask are written into
        ``latent_model_input`` / ``action_model_input`` in place.

        Args:
            latent_model_input: Video latents (5-D, frames on axis 2) or
                ``None`` to skip the video stream.
            action_model_input: Action tensor (5-D, frames on axis 2) or
                ``None`` to skip the action stream.
            latent_t: Timestep assigned to every video frame.
            action_t: Timestep assigned to every action frame.
            latent_cond: Optional clean latent whose first frame overwrites
                frame 0 (that frame's timestep is set to 0).
            action_cond: Same as ``latent_cond`` for the action stream.
            frame_st_id: Frame offset forwarded to ``get_mesh_id``.
            patch_size: Patch size used to derive the video token grid.

        Returns:
            Dict with optional keys ``'latent_res_lst'`` and
            ``'action_res_lst'``; each value holds ``noisy_latents``,
            ``timesteps``, ``grid_id`` and ``text_emb``.
        """
        logger.info(f"FRAME START ID: {frame_st_id}")
        input_dict = dict()

        if latent_model_input is not None:
            video_shape = latent_model_input.shape
            video_entry = {
                'noisy_latents':
                latent_model_input,
                'timesteps':
                torch.ones([video_shape[2]],
                           dtype=torch.float32,
                           device=self.device) * latent_t,
                'grid_id':
                get_mesh_id(video_shape[-3] // patch_size[0],
                            video_shape[-2] // patch_size[1],
                            video_shape[-1] // patch_size[2], 0, 1,
                            frame_st_id).to(self.device),
                'text_emb':
                self.prompt_embeds.to(self.dtype).clone(),
            }
            if latent_cond is not None:
                video_entry['noisy_latents'][:, :, 0:1] = latent_cond[:, :, 0:1]
                video_entry['timesteps'][0:1] *= 0
            input_dict['latent_res_lst'] = video_entry

        if action_model_input is not None:
            action_shape = action_model_input.shape
            action_entry = {
                'noisy_latents':
                action_model_input,
                'timesteps':
                torch.ones([action_shape[2]],
                           dtype=torch.float32,
                           device=self.device) * action_t,
                'grid_id':
                get_mesh_id(action_shape[-3],
                            action_shape[-2],
                            action_shape[-1],
                            1,
                            1,
                            frame_st_id,
                            action=True).to(self.device),
                'text_emb':
                self.prompt_embeds.to(self.dtype).clone(),
            }
            if action_cond is not None:
                action_entry['noisy_latents'][:, :, 0:1] = action_cond[:, :, 0:1]
                action_entry['timesteps'][0:1] *= 0
            # Unused action channels are always forced to zero.
            action_entry['noisy_latents'][:, ~self.action_mask] *= 0
            input_dict['action_res_lst'] = action_entry

        return input_dict

    def _encode_obs(self, obs):
        """Encode camera observations into normalized VAE latents.

        Args:
            obs: Dict whose ``'obs'`` entry is a frame dict or a list of
                frame dicts keyed by camera name.  Each image is stacked and
                permuted with ``(3, 0, 1, 2)``, i.e. it is expected to be
                ``H x W x C``; pixel values are assumed to lie in
                ``[0, 255]`` (they are rescaled to ``[-1, 1]``).

        Returns:
            Latent tensor with all cameras concatenated along the last
            (width) axis, or ``None`` when ``obs['obs']`` is an empty list.
        """
        images = obs['obs']
        if not isinstance(images, list):
            images = [images]
        if len(images) < 1:
            return None

        is_tshape = self.env_type == 'robotwin_tshape'
        videos = []
        for cam_idx, cam_key in enumerate(self.job_config.obs_cam_keys):
            # T-shape: first camera at full resolution, the rest at half.
            if is_tshape and cam_idx != 0:
                target_size = (self.height // 2, self.width // 2)
            else:
                target_size = (self.height, self.width)

            frames = np.stack([frame[cam_key] for frame in images])
            clip = torch.from_numpy(frames).float().permute(3, 0, 1, 2)
            clip = F.interpolate(clip,
                                 size=target_size,
                                 mode='bilinear',
                                 align_corners=False).unsqueeze(0)
            videos.append(clip)

        if is_tshape:
            videos_high = videos[0] / 255.0 * 2.0 - 1.0
            videos_left_and_right = torch.cat(videos[1:],
                                              dim=0) / 255.0 * 2.0 - 1.0
            enc_out_high = self.streaming_vae.encode_chunk(
                videos_high.to(self.device).to(self.dtype))
            enc_out_left_and_right = self.streaming_vae_half.encode_chunk(
                videos_left_and_right.to(self.device).to(self.dtype))
            # Half-resolution cameras side by side, stacked above the
            # full-resolution camera along the height axis.
            side_by_side = torch.cat(enc_out_left_and_right.split(1, dim=0),
                                     dim=-1)
            enc_out = torch.cat([side_by_side, enc_out_high], dim=-2)
        else:
            videos = torch.cat(videos, dim=0) / 255.0 * 2.0 - 1.0
            videos_chunk = videos.to(self.device).to(self.dtype)
            enc_out = self.streaming_vae.encode_chunk(videos_chunk)

        # Only the first half of the channels (the mean) is used.
        mu, _logvar = torch.chunk(enc_out, 2, dim=1)
        latents_mean = torch.tensor(self.vae.config.latents_mean).to(mu.device)
        latents_std = torch.tensor(self.vae.config.latents_std).to(mu.device)
        mu_norm = self.normalize_latents(mu, latents_mean, 1.0 / latents_std)
        # Fold the camera (batch) axis into the width axis.
        video_latent = torch.cat(mu_norm.split(1, dim=0), dim=-1)
        return video_latent

    def _reset(self, prompt=None):
        """Start a new episode: clear caches, derive sizes, encode the prompt.

        Args:
            prompt: Task prompt.  With ``None`` no text embeddings are
                available and the prediction methods cannot be used.
        """
        logger.info('Reset.')
        self.use_cfg = (self.job_config.guidance_scale > 1) or (
            self.job_config.action_guidance_scale > 1)

        self.frame_st_id = 0
        self.init_latent = None

        self.transformer.clear_cache(self.cache_name)
        self.streaming_vae.clear_cache()

        self.action_per_frame = self.job_config.action_per_frame
        self.height, self.width = self.job_config.height, self.job_config.width

        if self.env_type == 'robotwin_tshape':
            self.latent_height = ((self.height // 16) * 3) // 2
            self.latent_width = self.width // 16
            self.streaming_vae_half.clear_cache()
        else:
            self.latent_height = self.height // 16
            self.latent_width = self.width // 16 * len(
                self.job_config.obs_cam_keys)

        patch_size = self.job_config.patch_size
        frame_chunk_size = self.job_config.frame_chunk_size
        latent_token_per_chunk = (
            frame_chunk_size * self.latent_height *
            self.latent_width) // (patch_size[0] * patch_size[1] *
                                   patch_size[2])
        action_token_per_chunk = frame_chunk_size * self.action_per_frame
        self.transformer.create_empty_cache(
            self.cache_name,
            self.job_config.attn_window,
            latent_token_per_chunk,
            action_token_per_chunk,
            dtype=self.dtype,
            device=self.device,
            batch_size=2 if self.use_cfg else 1,
        )

        self.action_mask = torch.zeros([self.job_config.action_dim]).bool()
        self.action_mask[self.job_config.used_action_channel_ids] = True

        self.actions_q01 = torch.tensor(self.job_config.norm_stat['q01'],
                                        dtype=torch.float32).reshape(-1, 1, 1)
        self.actions_q99 = torch.tensor(self.job_config.norm_stat['q99'],
                                        dtype=torch.float32).reshape(-1, 1, 1)
        self.action_norm_method = self.job_config.action_norm_method

        if prompt is None:
            self.prompt_embeds = self.negative_prompt_embeds = None
        else:
            # The negative embedding is required whenever the batch is
            # doubled for guidance, i.e. when *either* guidance scale is > 1
            # (see `_repeat_input_for_cfg`), not only for video guidance.
            self.prompt_embeds, self.negative_prompt_embeds = self.encode_prompt(
                prompt=prompt,
                negative_prompt=None,
                do_classifier_free_guidance=self.use_cfg,
                num_videos_per_prompt=1,
                prompt_embeds=None,
                negative_prompt_embeds=None,
                max_sequence_length=512,
                device=self.device,
                dtype=self.dtype,
            )

        # NOTE(review): the raw prompt is used as a directory name; prompts
        # containing path separators will create nested directories.
        self.exp_name = f"{prompt}_{time.strftime('%Y%m%d_%H%M%S')}" if prompt else "default"
        self.exp_save_root = os.path.join(self.save_root, 'real', self.exp_name)
        os.makedirs(self.exp_save_root, exist_ok=True)
        torch.cuda.empty_cache()

    def _infer(self, obs, frame_st_id=0):
        """Denoise one chunk of video latents, then one chunk of actions.

        Args:
            obs: Observation dict; only encoded when ``frame_st_id == 0``,
                where its first latent frame conditions the chunk.
            frame_st_id: Index of the first frame of this chunk.

        Returns:
            ``(actions, latents)`` - de-normalized actions as a ``numpy``
            array and the predicted latent tensor.
        """
        frame_chunk_size = self.job_config.frame_chunk_size
        first_chunk = frame_st_id == 0
        if first_chunk:
            init_latent = self._encode_obs(obs)
            self.init_latent = init_latent

        latents = torch.randn(1,
                              48,
                              frame_chunk_size,
                              self.latent_height,
                              self.latent_width,
                              device=self.device,
                              dtype=self.dtype)
        actions = torch.randn(1,
                              self.job_config.action_dim,
                              frame_chunk_size,
                              self.action_per_frame,
                              1,
                              device=self.device,
                              dtype=self.dtype)

        video_inference_step = self.job_config.num_inference_steps
        action_inference_step = self.job_config.action_num_inference_steps
        video_step = self.job_config.video_exec_step

        self.scheduler.set_timesteps(video_inference_step)
        self.action_scheduler.set_timesteps(action_inference_step)

        # A trailing t=0 step is appended to each schedule; on that final
        # step the transformer is run with update_cache=1.
        timesteps = F.pad(self.scheduler.timesteps, (0, 1),
                          mode='constant',
                          value=0)
        if video_step != -1:
            # Optionally execute only the first `video_step` video steps.
            timesteps = timesteps[:video_step]
        action_timesteps = F.pad(self.action_scheduler.timesteps, (0, 1),
                                 mode='constant',
                                 value=0)

        cfg_batch = 2 if self.use_cfg else 1
        video_scale = self.job_config.guidance_scale
        action_scale = self.job_config.action_guidance_scale

        with torch.amp.autocast('cuda', dtype=self.dtype), torch.no_grad():
            # Conditioning tensors do not change across denoising steps.
            latent_cond = init_latent[:, :, 0:1].to(
                self.dtype) if first_chunk else None
            for i, t in enumerate(tqdm(timesteps)):
                last_step = i == len(timesteps) - 1
                input_dict = self._prepare_latent_input(
                    latents,
                    None,
                    t,
                    t,
                    latent_cond,
                    None,
                    frame_st_id=frame_st_id)

                video_noise_pred = self.transformer(
                    self._repeat_input_for_cfg(input_dict['latent_res_lst']),
                    update_cache=1 if last_step else 0,
                    cache_name=self.cache_name,
                    action_mode=False)

                if not last_step or video_step != -1:
                    video_noise_pred = data_seq_to_patch(
                        self.job_config.patch_size,
                        video_noise_pred,
                        frame_chunk_size,
                        self.latent_height,
                        self.latent_width,
                        batch_size=cfg_batch)
                    if video_scale > 1:
                        # Row 0 is conditional, row 1 unconditional.
                        video_noise_pred = video_noise_pred[1:] + video_scale * (
                            video_noise_pred[:1] - video_noise_pred[1:])
                    else:
                        video_noise_pred = video_noise_pred[:1]
                    latents = self.scheduler.step(video_noise_pred,
                                                  t,
                                                  latents,
                                                  return_dict=False)

                if first_chunk:
                    # Keep the observed frame fixed.
                    latents[:, :, 0:1] = latent_cond

            action_cond = torch.zeros(
                [1, self.job_config.action_dim, 1, self.action_per_frame, 1],
                device=self.device,
                dtype=self.dtype) if first_chunk else None
            for i, t in enumerate(tqdm(action_timesteps)):
                last_step = i == len(action_timesteps) - 1
                input_dict = self._prepare_latent_input(
                    None,
                    actions,
                    t,
                    t,
                    None,
                    action_cond,
                    frame_st_id=frame_st_id)
                action_noise_pred = self.transformer(
                    self._repeat_input_for_cfg(input_dict['action_res_lst']),
                    update_cache=1 if last_step else 0,
                    cache_name=self.cache_name,
                    action_mode=True)

                if not last_step:
                    action_noise_pred = rearrange(action_noise_pred,
                                                  'b (f n) c -> b c f n 1',
                                                  f=frame_chunk_size)
                    if action_scale > 1:
                        action_noise_pred = action_noise_pred[1:] + action_scale * (
                            action_noise_pred[:1] - action_noise_pred[1:])
                    else:
                        action_noise_pred = action_noise_pred[:1]
                    actions = self.action_scheduler.step(action_noise_pred,
                                                         t,
                                                         actions,
                                                         return_dict=False)

                if first_chunk:
                    actions[:, :, 0:1] = action_cond

            actions[:, ~self.action_mask] *= 0

        save_async(latents, os.path.join(self.exp_save_root, f'latents_{frame_st_id}.pt'))
        save_async(actions, os.path.join(self.exp_save_root, f'actions_{frame_st_id}.pt'))

        actions = self.postprocess_action(actions)
        torch.cuda.empty_cache()
        return actions, latents

    def _compute_kv_cache(self, obs):
        """Feed observed frames and executed actions into the KV cache.

        Args:
            obs: Dict with ``'obs'`` (camera frames) and ``'state'`` (raw
                action array accepted by :meth:`preprocess_action`).

        Advances ``self.frame_st_id`` by the number of latent frames consumed.
        """
        self.transformer.clear_pred_cache(self.cache_name)
        save_async(obs['obs'], os.path.join(self.exp_save_root, f'obs_data_{self.frame_st_id}.pt'))

        latent_model_input = self._encode_obs(obs)
        if self.frame_st_id == 0:
            # The very first update also contains the initial latent stored
            # by `_infer`.
            if latent_model_input is not None:
                latent_model_input = torch.cat(
                    [self.init_latent, latent_model_input], dim=2)
            else:
                latent_model_input = self.init_latent

        action_model_input = self.preprocess_action(obs['state'])
        action_model_input = action_model_input.to(latent_model_input)
        logger.info(
            f"get KV cache obs: {latent_model_input.shape} {action_model_input.shape}"
        )
        input_dict = self._prepare_latent_input(latent_model_input,
                                                action_model_input,
                                                frame_st_id=self.frame_st_id)

        with torch.amp.autocast('cuda', dtype=self.dtype), torch.no_grad():
            self.transformer(
                self._repeat_input_for_cfg(input_dict['latent_res_lst']),
                update_cache=2,
                cache_name=self.cache_name,
                action_mode=False)
            self.transformer(
                self._repeat_input_for_cfg(input_dict['action_res_lst']),
                update_cache=2,
                cache_name=self.cache_name,
                action_mode=True)

        torch.cuda.empty_cache()
        self.frame_st_id += latent_model_input.shape[2]

    @torch.no_grad()
    def infer(self, obs):
        """Dispatch one request.

        Args:
            obs: Request dict.  ``obs['reset']`` triggers :meth:`_reset` with
                ``obs.get('prompt')``; otherwise ``obs['compute_kv_cache']``
                triggers :meth:`_compute_kv_cache`; otherwise one chunk is
                predicted.

        Returns:
            ``{}`` for reset / KV-cache requests, ``{'action': ndarray}`` for
            a prediction request.
        """
        if obs.get('reset', False):
            logger.info("******************* Reset server ******************")
            self._reset(prompt=obs.get('prompt', None))
            return dict()

        if obs.get('compute_kv_cache', False):
            logger.info(
                "################# Compute KV Cache #################")
            self._compute_kv_cache(obs)
            return dict()

        logger.info("################# Infer One Chunk #################")
        action, _ = self._infer(obs, frame_st_id=self.frame_st_id)
        return dict(action=action)

    def decode_one_video(self, latents, output_type):
        """Decode normalized latents to a video.

        Args:
            latents: Normalized latent tensor (channels on axis 1).
            output_type: Forwarded to ``VideoProcessor.postprocess_video``.

        Returns:
            The post-processed video in the requested ``output_type``.
        """
        # `generate()` normally creates the processor; create it on demand so
        # this method also works when called on its own.
        processor = getattr(self, 'video_processor', None)
        if processor is None:
            processor = self.video_processor = VideoProcessor(
                vae_scale_factor=1)

        vae_config = self.vae.config
        latents = latents.to(self.vae.dtype)
        latents_mean = (torch.tensor(vae_config.latents_mean).view(
            1, vae_config.z_dim, 1, 1, 1).to(latents.device, latents.dtype))
        # Reciprocal std, so dividing by it undoes `normalize_latents`.
        latents_std = 1.0 / torch.tensor(vae_config.latents_std).view(
            1, vae_config.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
        latents = latents / latents_std + latents_mean
        video = self.vae.decode(latents, return_dict=False)[0]
        return processor.postprocess_video(video, output_type=output_type)

    def load_init_obs(self):
        """Load ``<input_img_path>/<camera key>.png`` for every camera.

        Returns:
            ``{'obs': [frame]}`` where ``frame`` maps each key in
            ``job_config.obs_cam_keys`` to an RGB ``numpy`` array.
        """
        frame = {}
        for cam_key in self.job_config.obs_cam_keys:
            img_path = os.path.join(self.job_config.input_img_path,
                                    f"{cam_key}.png")
            # Close the file handle as soon as the pixels are copied out.
            with Image.open(img_path) as img:
                frame[cam_key] = np.array(img.convert("RGB"))
        return {'obs': [frame]}

    @torch.no_grad()
    def generate(self):
        """Roll out ``num_chunks_to_infer`` chunks and write ``demo.mp4``.

        The transformer, text encoder and half-resolution VAE wrapper are
        deleted before decoding to free GPU memory, so the instance cannot
        run inference again afterwards.
        """
        self.video_processor = VideoProcessor(vae_scale_factor=1)
        self._reset(self.job_config.prompt)
        init_obs = self.load_init_obs()

        frame_chunk_size = self.job_config.frame_chunk_size
        pred_latent_lst = []
        for chunk_id in range(self.job_config.num_chunks_to_infer):
            _, latents = self._infer(init_obs,
                                     frame_st_id=chunk_id * frame_chunk_size)
            pred_latent_lst.append(latents)
        pred_latent = torch.cat(pred_latent_lst, dim=2)

        self.transformer.clear_cache(self.cache_name)
        self.streaming_vae.clear_cache()
        if self.streaming_vae_half is not None:
            self.streaming_vae_half.clear_cache()
        del self.transformer
        del self.streaming_vae_half
        del self.text_encoder
        torch.cuda.empty_cache()

        decoded_video = self.decode_one_video(pred_latent, 'np')[0]
        export_to_video(decoded_video, os.path.join(self.save_root, "demo.mp4"), fps=10)
|
|
def _debug_infer_once(model):
    """Replay three recorded requests through ``model.infer``."""
    from pathlib import Path

    from utils.Simple_Remote_Infer.deploy.msgpack_numpy import unpackb

    replay_plan = (
        ("******************* debug_infer_once: reset ******************",
         "debug/place_fan/call1_reset.msgpack"),
        ("******************* debug_infer_once: first infer ******************",
         "debug/place_fan/call2.msgpack"),
        ("******************* debug_infer_once: kv cache ******************",
         "debug/place_fan/call3.msgpack"),
    )
    for banner, recording in replay_plan:
        logger.info(banner)
        request = unpackb(Path(recording).read_bytes())
        model.infer(request)


def run(args):
    """Build the server from the selected config and start the chosen mode.

    Args:
        args: Parsed command-line namespace providing ``config_name``,
            ``port``, ``save_root`` and (optionally) ``debug_infer_once``.

    Raises:
        ValueError: ``config.infer_mode`` is neither ``"i2va"`` nor
            ``"server"``.
    """
    config = VA_CONFIGS[args.config_name]
    port = config.port if args.port is None else args.port
    if args.save_root is not None:
        config.save_root = args.save_root

    rank = int(os.getenv("RANK", 0))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    init_distributed(world_size, local_rank, rank)
    config.rank = rank
    config.local_rank = local_rank
    config.world_size = world_size

    model = VA_Server(config)

    if getattr(args, "debug_infer_once", False):
        _debug_infer_once(model)
        # The flag is documented as "... then exit": do not fall through
        # into server / generation mode after the replay.
        return

    if config.infer_mode == "i2va":
        logger.info("******************************USE I2AV mode******************************")
        model.generate()
    elif config.infer_mode == "server":
        logger.info("******************************USE Server mode******************************")
        run_async_server_mode(model, local_rank, config.host, port)
    else:
        raise ValueError(f"Unknown infer mode: {config.infer_mode}")
|
|
def main():
    """Parse the command line and hand the result to :func:`run`.

    Options: ``--config-name`` (default ``'robotwin'``), ``--port``,
    ``--save_root`` and the ``--debug_infer_once`` switch.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--config-name",
                     type=str,
                     required=False,
                     default='robotwin',
                     help="config name.")
    cli.add_argument("--port", type=int, default=None, help='(start) port')
    cli.add_argument("--save_root", type=str, default=None, help='save root')
    cli.add_argument(
        "--debug_infer_once",
        action="store_true",
        help="Run one infer with dummy observation then exit (for debugging infer() without WebSocket client)",
    )
    run(cli.parse_args())
    logger.info("Finish all process!!!!!!!!!!!!")
| logger.info("Finish all process!!!!!!!!!!!!") |
|
|
|
|
if __name__ == "__main__":
    # Script entry point: set up logging first so that everything `main()`
    # emits is captured.
    init_logger()
    main()
|
|