Spaces:
Running
Running
| from contextlib import contextmanager | |
| from pathlib import Path | |
| from typing import List | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from loguru import logger | |
| from torch.cuda.nvtx import range_pop, range_push | |
| from tqdm import tqdm | |
| from sorawm.cleaner.e2fgvi_hq_cleaner import * | |
| from sorawm.models.model.e2fgvi_hq import InpaintGenerator | |
| from sorawm.utils.video_utils import merge_frames_with_overlap | |
@contextmanager
def nvtx(msg: str):
    """Context manager that wraps its body in an NVTX range named *msg*.

    Bug fix: the ``@contextmanager`` decorator (already imported from
    ``contextlib``) was missing, so every ``with nvtx(...)`` in this file
    would fail with "TypeError: 'generator' object does not support the
    context manager protocol". The decorator turns this generator into a
    proper context manager.

    Parameters:
        msg: Label shown for the range in Nsight Systems / nvprof timelines.
    """
    range_push(msg)
    try:
        yield
    finally:
        # Pop in ``finally`` so ranges stay balanced even if the body raises.
        range_pop()
class ProfileInpaintGenerator(InpaintGenerator):
    """InpaintGenerator instrumented with NVTX ranges for profiling.

    Each stage of the forward pass is wrapped in an ``nvtx(...)`` range so
    its cost shows up as a named span in Nsight Systems timelines. The
    computation itself is intended to match the parent ``InpaintGenerator``;
    only profiling annotations are added.
    """

    def forward_bidirect_flow(self, masked_local_frames):
        """
        Estimate bidirectional optical flows between consecutive frames in a local masked sequence.

        Parameters:
            masked_local_frames (torch.Tensor): Input tensor of masked local frames with shape
                (batch, time, channels, height, width).

        Returns:
            tuple: A pair (pred_flows_forward, pred_flows_backward) where each is a torch.Tensor
                of shape (batch, time - 1, 2, height // 4, width // 4). Each tensor contains 2D
                optical flow vectors: `pred_flows_forward` maps each frame to the next (i -> i+1),
                and `pred_flows_backward` maps each frame to the previous (i+1 -> i).
        """
        with nvtx("InpaintGenerator.forward_bidirect_flow_total"):
            b, l_t, c, h, w = masked_local_frames.size()
            # Flow is estimated at quarter resolution; frames are downsampled
            # once here and all later shapes use h // 4, w // 4.
            with nvtx("flow_downsample_interpolate"):
                masked_local_frames = F.interpolate(
                    masked_local_frames.view(-1, c, h, w),
                    scale_factor=1 / 4,
                    mode="bilinear",
                    align_corners=True,
                    recompute_scale_factor=True,
                )
                masked_local_frames = masked_local_frames.view(
                    b, l_t, c, h // 4, w // 4
                )
            # Build consecutive-frame pairs: mlf_1 holds frames 0..T-2,
            # mlf_2 holds frames 1..T-1, flattened over batch*time.
            with nvtx("flow_prepare_pairs"):
                mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape(
                    -1, c, h // 4, w // 4
                )
                mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape(
                    -1, c, h // 4, w // 4
                )
            # update_spynet is defined on the parent model (SPyNet-style flow
            # estimator); argument order determines flow direction.
            with nvtx("spynet_forward"):
                pred_flows_forward = self.update_spynet(mlf_1, mlf_2)
            with nvtx("spynet_backward"):
                pred_flows_backward = self.update_spynet(mlf_2, mlf_1)
            with nvtx("flow_reshape"):
                pred_flows_forward = pred_flows_forward.view(
                    b, l_t - 1, 2, h // 4, w // 4
                )
                pred_flows_backward = pred_flows_backward.view(
                    b, l_t - 1, 2, h // 4, w // 4
                )
            return pred_flows_forward, pred_flows_backward

    def forward(self, masked_frames, num_local_frames):
        """
        Run inpainting generator on a sequence of masked frames, producing reconstructed frames and bidirectional flow estimates.

        Parameters:
            masked_frames (torch.Tensor): Tensor of shape (batch, time, channels, height, width) containing masked input frames (expected normalized to model range).
            num_local_frames (int): Number of initial frames in each sequence treated as local (used for flow estimation and local feature propagation).

        Returns:
            output (torch.Tensor): Reconstructed frames tensor of shape (batch * time, channels_out, height_out, width_out) with values in [-1, 1].
            pred_flows (tuple): A pair (pred_flows_forward, pred_flows_backward) of tensors holding predicted optical flows for forward and backward directions; each has shape (batch, time-1, 2, h_flow, w_flow).
        """
        with nvtx("InpaintGenerator.forward_total"):
            l_t = num_local_frames
            b, t, ori_c, ori_h, ori_w = masked_frames.size()
            with nvtx("forward_normalize_local_frames"):
                # Map local frames from the [-1, 1] model range to [0, 1]
                # before flow estimation.
                masked_local_frames = (masked_frames[:, :l_t, ...] + 1) / 2
            with nvtx("forward_bidirect_flow_call"):
                pred_flows = self.forward_bidirect_flow(masked_local_frames)
            # Encode all frames (local + reference) in one batched call.
            with nvtx("encoder_all_frames"):
                enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w))
            with nvtx("split_local_ref_feat"):
                _, c, h, w = enc_feat.size()
                fold_output_size = (h, w)
                # First l_t entries are local frames, the rest are references.
                local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
                ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
            # Flow-guided propagation only runs over the local frames.
            with nvtx("feat_prop_module"):
                local_feat = self.feat_prop_module(
                    local_feat, pred_flows[0], pred_flows[1]
                )
            with nvtx("concat_local_ref"):
                enc_feat = torch.cat((local_feat, ref_feat), dim=1)
            # ss / transformer / sc are the parent model's temporal-focal
            # transformer stages (soft split, transformer blocks, soft comp).
            with nvtx("temporal_focal_transformers_ss"):
                trans_feat = self.ss(enc_feat.view(-1, c, h, w), b, fold_output_size)
            with nvtx("temporal_transformer_blocks"):
                trans_feat = self.transformer([trans_feat, fold_output_size])
            with nvtx("sc_fuse"):
                trans_feat = self.sc(trans_feat[0], t, fold_output_size)
                trans_feat = trans_feat.view(b, t, -1, h, w)
            with nvtx("residual_add"):
                enc_feat = enc_feat + trans_feat
            with nvtx("decoder"):
                output = self.decoder(enc_feat.view(b * t, c, h, w))
                # tanh constrains the reconstruction to [-1, 1].
                output = torch.tanh(output)
            return output, pred_flows
class ProfileE2FGVIHDCleaner(E2FGVIHDCleaner):
    """E2FGVI-HQ cleaner instrumented with NVTX ranges for Nsight profiling.

    Behaves like the parent cleaner but uses :class:`ProfileInpaintGenerator`
    and overlaps the GPU->CPU transfer of one window's predictions with the
    model inference of the next window via a dedicated CUDA stream
    (double buffering in :meth:`process_frames_chunk`).
    """

    def __init__(
        self,
        ckpt_path: Path = E2FGVI_HQ_CHECKPOINT_PATH,
        config: E2FGVIHDConfig = E2FGVIHDConfig(),
    ):
        """Download the checkpoint if needed and load the model onto `device`.

        Parameters:
            ckpt_path: Path to the E2FGVI-HQ checkpoint file.
            config: Cleaner configuration (chunking / reference-frame settings).
        """
        with nvtx("cleaner_init_total"):
            with nvtx("ensure_model_downloaded"):
                ensure_model_downloaded(ckpt_path, E2FGVI_HQ_CHECKPOINT_REMOTE_URL)
            with nvtx("init_model"):
                self.model = ProfileInpaintGenerator().to(device)
            with nvtx("load_ckpt"):
                state = torch.load(ckpt_path, map_location=device)
                self.model.load_state_dict(state)
            with nvtx("model_eval_mode"):
                self.model.eval()
            self.config = config

    def clean(self, frames: np.ndarray, masks: np.ndarray) -> List[np.ndarray]:
        """
        Run the full cleaning pipeline on a video using chunked, overlapping processing and return reconstructed frames.

        Processes the input frames and masks in configurable chunks with overlap:
        converts inputs to tensors, runs per-chunk inpainting and fusion, merges
        chunk outputs handling overlaps, and returns the final list of cleaned
        frames in original order.

        Parameters:
            frames (np.ndarray): Input RGB frames of shape (T, H, W, C).
            masks (np.ndarray): Mask array of shape (T, H, W) where nonzero
                values indicate regions to inpaint.

        Returns:
            List[np.ndarray]: T reconstructed RGB frames (H, W, C), in input order.
        """
        with nvtx("ProfileE2FGVIHDCleaner.clean_total"):
            with nvtx("setup_basic_params"):
                video_length = len(frames)
                chunk_size = int(self.config.chunk_size_ratio * video_length)
                overlap_size = int(self.config.overlap_ratio * video_length)
                # Guard the chunk stride: for very short videos the two int()
                # truncations can give chunk_size <= overlap_size, which
                # previously caused a ZeroDivisionError (or a negative step).
                stride = max(1, chunk_size - overlap_size)
                num_chunks = int(np.ceil(video_length / stride))
                h, w = frames[0].shape[:2]
            with nvtx("numpy_to_tensor"):
                imgs_all, masks_all = numpy_to_tensor(frames, masks)
            with nvtx("prepare_binary_masks"):
                # (T, H, W, 1) uint8 masks used for CPU-side compositing.
                binary_masks = np.expand_dims(masks > 0, axis=-1).astype(np.uint8)
            comp_frames = [None] * video_length
            logger.debug(
                f"Processing {video_length} frames in {num_chunks} chunks "
                f"(chunk_size={chunk_size}, overlap={overlap_size})"
            )
            for chunk_idx in tqdm(
                range(num_chunks), desc="Chunk", position=0, leave=True
            ):
                with nvtx(f"chunk_{chunk_idx:03d}_total"):
                    with nvtx("chunk_compute_indices"):
                        start_idx = chunk_idx * stride
                        end_idx = min(start_idx + chunk_size, video_length)
                        actual_chunk_size = end_idx - start_idx
                    with nvtx("chunk_extract_and_to_device"):
                        # Only the current chunk is moved to the GPU to bound
                        # peak device memory.
                        imgs_chunk = imgs_all[:, start_idx:end_idx, :, :, :].to(device)
                        masks_chunk = masks_all[:, start_idx:end_idx, :, :, :].to(
                            device
                        )
                        frames_np_chunk = frames[start_idx:end_idx]
                        binary_masks_chunk = binary_masks[start_idx:end_idx]
                    with nvtx("chunk_process_frames_chunk"):
                        comp_frames_chunk = self.process_frames_chunk(
                            actual_chunk_size,
                            self.config.neighbor_stride,
                            imgs_chunk,
                            masks_chunk,
                            binary_masks_chunk,
                            frames_np_chunk,
                            h,
                            w,
                        )
                    with nvtx("merge_frames_with_overlap"):
                        # Overlapping frames between chunks are blended by the
                        # helper; the first chunk has nothing to blend with.
                        comp_frames = merge_frames_with_overlap(
                            result_frames=comp_frames,
                            chunk_frames=comp_frames_chunk,
                            start_idx=start_idx,
                            overlap_size=overlap_size,
                            is_first_chunk=(chunk_idx == 0),
                        )
                    with nvtx("chunk_cleanup"):
                        del imgs_chunk, masks_chunk, comp_frames_chunk
                        # Best-effort VRAM release; empty_cache raises when
                        # CUDA is unavailable, so swallow deliberately.
                        try:
                            torch.cuda.empty_cache()
                        except Exception:
                            pass
            return comp_frames

    def process_frames_chunk(
        self,
        chunk_length: int,
        neighbor_stride: int,
        imgs_chunk: torch.Tensor,
        masks_chunk: torch.Tensor,
        binary_masks_chunk: np.ndarray,
        frames_np_chunk: np.ndarray,
        h: int,
        w: int,
    ) -> List[np.ndarray]:
        """Inpaint one chunk window-by-window, overlapping D2H copies with compute.

        Sliding windows of neighbor frames (plus reference frames chosen by
        ``get_ref_index``) are fed to the model. While window k+1 runs on the
        default stream, window k's predictions are copied to pinned CPU memory
        on a dedicated transfer stream and composited once the copy finishes
        (double buffering).

        Parameters:
            chunk_length: Number of frames in this chunk.
            neighbor_stride: Half-window radius in frames; also the step
                between processed reference positions.
            imgs_chunk: (1, T, C, H, W) model-normalized frames on `device`.
            masks_chunk: (1, T, 1, H, W) masks on `device` (1 = inpaint).
            binary_masks_chunk: (T, H, W, 1) uint8 masks for CPU compositing.
            frames_np_chunk: (T, H, W, C) original uint8 frames.
            h: Original frame height.
            w: Original frame width.

        Returns:
            List of `chunk_length` composited frames; frames covered by two
            windows are averaged 50/50 (those entries become float32).
        """
        comp_frames_chunk = [None] * chunk_length
        # Dedicated stream so the device->host copy of window k can overlap
        # the inference of window k+1.
        transfer_stream = torch.cuda.Stream()
        # Previous window's pinned CPU buffer + ids, possibly still in flight.
        prev_pred_imgs = None
        prev_neighbor_ids = None
        all_windows = list(range(0, chunk_length, neighbor_stride))
        for f in tqdm(
            all_windows,
            desc=" Frame progress",
            position=1,
            leave=False,
        ):
            with nvtx(f"window_f_{f:05d}_total"):
                # ---- select neighbor and reference frame indices ----
                with nvtx("window_neighbor_ref_ids"):
                    neighbor_ids = [
                        i
                        for i in range(
                            max(0, f - neighbor_stride),
                            min(chunk_length, f + neighbor_stride + 1),
                        )
                    ]
                    ref_ids = get_ref_index(
                        f,
                        neighbor_ids,
                        chunk_length,
                        self.config.ref_length,
                        self.config.num_ref,
                    )
                with nvtx("window_select_tensors"):
                    selected_imgs = imgs_chunk[:1, neighbor_ids + ref_ids, :, :, :]
                    selected_masks = masks_chunk[:1, neighbor_ids + ref_ids, :, :, :]
                with torch.no_grad():
                    with nvtx("window_apply_mask"):
                        masked_imgs = selected_imgs * (1 - selected_masks)
                    with nvtx("window_pad_flip_concat"):
                        # Pad H to a multiple of 60 and W to a multiple of 108
                        # by mirror reflection; cropped back after inference.
                        mod_size_h = 60
                        mod_size_w = 108
                        h_pad = (mod_size_h - h % mod_size_h) % mod_size_h
                        w_pad = (mod_size_w - w % mod_size_w) % mod_size_w
                        masked_imgs = torch.cat(
                            [masked_imgs, torch.flip(masked_imgs, [3])], 3
                        )[:, :, :, : h + h_pad, :]
                        masked_imgs = torch.cat(
                            [masked_imgs, torch.flip(masked_imgs, [4])], 4
                        )[:, :, :, :, : w + w_pad]
                    # ---- model inference (default stream) ----
                    with nvtx("window_model_infer"):
                        pred_imgs, _ = self.model(masked_imgs, len(neighbor_ids))
                    # ---- GPU-side post-processing ----
                    with nvtx("window_crop_postprocess_gpu"):
                        pred_imgs = pred_imgs[:, :, :h, :w]
                        pred_imgs = (pred_imgs + 1) / 2
                        pred_imgs = (pred_imgs.permute(0, 2, 3, 1) * 255).contiguous()
                    # Event marking the point after which the output buffer
                    # is safe to read from another stream.
                    compute_done = torch.cuda.Event()
                    compute_done.record()
                    # ---- composite the previous window (if any) ----
                    if prev_pred_imgs is not None:
                        with nvtx("window_composite_prev"):
                            # Wait until the previous window's copy landed.
                            transfer_stream.synchronize()
                            self._composite_frames(
                                prev_pred_imgs.numpy(),
                                prev_neighbor_ids,
                                binary_masks_chunk,
                                frames_np_chunk,
                                comp_frames_chunk,
                            )
                    # ---- asynchronously copy current result to the CPU ----
                    with nvtx("window_async_transfer"):
                        transfer_stream.wait_event(compute_done)
                        with torch.cuda.stream(transfer_stream):
                            # Truly asynchronous D2H copy: pinned destination
                            # plus non_blocking=True. The previous
                            # ``.cpu().numpy()`` performed a blocking copy and
                            # immediate host conversion, defeating the overlap.
                            pinned = torch.empty(
                                pred_imgs.shape,
                                dtype=pred_imgs.dtype,
                                pin_memory=True,
                            )
                            pinned.copy_(pred_imgs, non_blocking=True)
                            # Prevent the caching allocator from reusing the
                            # source buffer before transfer_stream reads it.
                            pred_imgs.record_stream(transfer_stream)
                        prev_pred_imgs = pinned
                        prev_neighbor_ids = list(neighbor_ids)
        # ---- drain: composite the last in-flight window ----
        if prev_pred_imgs is not None:
            transfer_stream.synchronize()
            self._composite_frames(
                prev_pred_imgs.numpy(),
                prev_neighbor_ids,
                binary_masks_chunk,
                frames_np_chunk,
                comp_frames_chunk,
            )
        return comp_frames_chunk

    def _composite_frames(
        self,
        pred_imgs_np: np.ndarray,
        neighbor_ids: List[int],
        binary_masks_chunk: np.ndarray,
        frames_np_chunk: np.ndarray,
        comp_frames_chunk: List[np.ndarray],
    ):
        """Composite predictions onto the original frames (CPU side, in place).

        Masked pixels come from the prediction, unmasked pixels from the
        original frame. A frame already filled by an earlier window is
        averaged 50/50 with the new result (that entry becomes float32).
        """
        for i, idx in enumerate(neighbor_ids):
            img = np.array(pred_imgs_np[i]).astype(np.uint8) * binary_masks_chunk[
                idx
            ] + frames_np_chunk[idx] * (1 - binary_masks_chunk[idx])
            if comp_frames_chunk[idx] is None:
                comp_frames_chunk[idx] = img
            else:
                comp_frames_chunk[idx] = (
                    comp_frames_chunk[idx].astype(np.float32) * 0.5
                    + img.astype(np.float32) * 0.5
                )
if __name__ == "__main__":
    # Profiling entry point: load pre-dumped inputs, build the instrumented
    # cleaner, and run one full cleaning pass under NVTX ranges.
    profiling_dir = Path.cwd() / "profiling"
    with nvtx("load_numpy_inputs"):
        masks = np.load(profiling_dir / "masks.npy")
        frames = np.load(profiling_dir / "frames.npy")
    with nvtx("init_cleaner"):
        cleaner = ProfileE2FGVIHDCleaner()
    with nvtx("run_cleaner"):
        cleaned_frames = cleaner.clean(frames, masks)
    # np.save(profiling_dir / "cleaned_frames.npy", cleaned_frames)