import os import torchvision from PIL import Image, ImageDraw import imageio import cv2 import torch import torch.nn.functional as F import numpy as np import zipfile _gauss_mask_cache = {} def load_gauss_mask(mask_path): if not mask_path: return None abs_path = os.path.abspath(mask_path) mask = _gauss_mask_cache.get(abs_path) if mask is None: mask = torch.load(abs_path, weights_only=False, map_location="cpu") if not torch.is_tensor(mask): mask = torch.tensor(mask) _gauss_mask_cache[abs_path] = mask return mask def apply_alpha_shift(latents, gauss_mask, shift_mean): if gauss_mask is None: return latents mask = gauss_mask if mask.ndim == 3: mask = mask.unsqueeze(0).unsqueeze(0) elif mask.ndim == 4: if mask.shape[0] != 1: mask = mask.unsqueeze(0) if mask.shape[1] != 1: mask = mask.unsqueeze(1) elif mask.ndim != 5: return latents mask = mask.to(device=latents.device, dtype=latents.dtype) target_shape = latents.shape[2:] if mask.shape[-3:] != target_shape: mask = F.interpolate(mask, size=target_shape, mode="trilinear", align_corners=False) shift_mean = torch.as_tensor(shift_mean, dtype=latents.dtype, device=latents.device) return latents + (1.0 - mask) * shift_mean def render_video(tensor_fgr, tensor_pha, nrow=8, normalize=True, value_range=(-1, 1)): def to_tensor(arr_list): tensor_list= [torch.from_numpy(arr).float().div_(127.5).sub_(1) for arr in arr_list] tensor_list = torch.stack(tensor_list, dim = 0).permute(3,0,1,2).unsqueeze(0) return tensor_list if not torch.is_tensor(tensor_fgr): tensor_fgr = to_tensor(tensor_fgr) if not torch.is_tensor(tensor_pha): tensor_pha = to_tensor(tensor_pha) tensor_fgr = tensor_fgr.clamp(min(value_range), max(value_range)) tensor_fgr = torch.stack([ torchvision.utils.make_grid( u, nrow=nrow, normalize=normalize, value_range=value_range) for u in tensor_fgr.unbind(2) ], dim=1).permute(1, 2, 3, 0) tensor_fgr = (tensor_fgr * 255).type(torch.uint8).cpu() tensor_pha = tensor_pha.clamp(min(value_range), max(value_range)) tensor_pha = torch.stack([ torchvision.utils.make_grid( u, nrow=nrow, normalize=normalize, value_range=value_range) for u in tensor_pha.unbind(2) ], dim=1).permute(1, 2, 3, 0) tensor_pha = (tensor_pha * 255).type(torch.uint8).cpu() frames = [] frames_fgr = [] frames_pha = [] for frame_fgr, frame_pha in zip(tensor_fgr.numpy(), tensor_pha.numpy()): if frame_pha.shape[-1] == 1: frame_pha = frame_pha[:,:,0] else: frame_pha = (0.0 + frame_pha[:,:,0:1] + frame_pha[:,:,1:2] + frame_pha[:,:,2:3]) / 3. frame = np.concatenate([frame_fgr[:,:,::-1], frame_pha.astype(np.uint8)], axis=2) frames.append(frame) frames_fgr.append(frame_fgr) frames_pha.append(frame_pha) def create_checkerboard(size=30, pattern_size=(830, 480), color1=(140, 140, 140), color2=(113, 113, 113)): img = Image.new('RGB', (pattern_size[0], pattern_size[1]), color1) draw = ImageDraw.Draw(img) for i in range(0, pattern_size[0], size): for j in range(0, pattern_size[1], size): if (i + j) // size % 2 == 0: draw.rectangle([i, j, i+size, j+size], fill=color2) return img def blender_background(frame_rgba, checkerboard): alpha_channel = frame_rgba[:, :, 3:] / 255. checkerboard = np.array(checkerboard) checkerboard = cv2.resize(checkerboard, (frame_rgba.shape[1], frame_rgba.shape[0])) frame_rgb = frame_rgba[:, :, :3] * alpha_channel + checkerboard * (1-alpha_channel) return frame_rgb.astype(np.uint8)[:,:,::-1] checkerboard = create_checkerboard() video_checkerboard = [torch.from_numpy(blender_background(f, checkerboard).copy()).float().div_(127.5).sub_(1) for f in frames] video_checkerboard = torch.stack(video_checkerboard ).permute(3, 0, 1, 2) return video_checkerboard, frames def from_BRGA_numpy_to_RGBA_torch(video): video = [torch.from_numpy(f.copy()).float().div_(127.5).sub_(1) for f in video] video = torch.stack(video).permute(3, 0, 1, 2) video[[0, 2], ...] = video[[2, 0], ...] return video def write_zip_file(zip_path, frames): # frames in BGRA format with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf: for idx, img in enumerate(frames): success, buffer = cv2.imencode(".png", img) if not success: print(f"Failed to encode image {idx}, skipping...") continue filename = f"img_{idx:03d}.png" zipf.writestr(filename, buffer.tobytes())