Spaces:
Build error
Build error
| import os | |
| import logging | |
| import math | |
| from typing import Tuple, List, Optional | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import decord | |
| from PIL import Image | |
| from torchvision.transforms import Resize | |
| from torchvision.io import write_video | |
| import torch.nn.functional as F | |
| from torchvision.transforms.functional import resize, center_crop, to_pil_image | |
| from anchorcrafter.utils.geglu_patch import patch_geglu_inplace | |
| from anchorcrafter.dwpose.preprocess import get_video_pose, get_image_pose | |
| from anchorcrafter.pipelines.pipeline import AnchorCrafterPipeline | |
| from constants import ASPECT_RATIO | |
# Initialize GEGLU patch
# Called at module import so the in-place patch is active before any
# pipeline/model objects are constructed.
patch_geglu_inplace()

logging.basicConfig(level=logging.INFO, format="%(asctime)s: [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)

# Constants
DEFAULT_RESOLUTION = 576       # default target resolution, in pixels
DEFAULT_SAMPLE_STRIDE = 2      # default frame sampling interval
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def align_image(
    image: np.ndarray,
    scale: Tuple[float, float],
    bias: Tuple[float, float]
) -> np.ndarray:
    """
    Align an image with an affine (scale + translation) transformation.

    Args:
        image: Input image array in HWC format.
        scale: Scaling factors (y_scale, x_scale).
        bias: Translation biases (y_bias, x_bias) normalized to [0, 1].

    Returns:
        Aligned image array in HWC format; pixels whose source falls
        outside the input stay zero.
    """
    h, w, _ = image.shape
    sy, sx = scale
    by, bx = bias

    # Translation in pixel units.
    # NOTE(review): the x-bias scales with height and the y-bias with width,
    # matching the row/col pairing used below — confirm against the caller.
    bx_px = bx * h
    by_px = by * w

    # Row/column index grids covering the output image
    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')

    # Inverse-map each output pixel back to its source location
    src_rows = ((rows - bx_px) / sx).astype(int)
    src_cols = ((cols - by_px) / sy).astype(int)

    # Keep only pixels whose source lies inside the input image
    inside = (src_rows >= 0) & (src_rows < h) & (src_cols >= 0) & (src_cols < w)

    out = np.zeros_like(image)
    out[inside] = image[src_rows[inside], src_cols[inside]]
    return out
def align_track_video(track_path, total_frames, sample_stride, scale, bias, resolution=512, visual=False, visual_tag='obj'):
    """
    Load a tracking video and warp every frame with the alignment transform.

    Args:
        track_path: Path to the tracking video (object or hand track).
        total_frames: Number of frames to keep after stride sampling (<= 0 keeps all).
        sample_stride: Frame sampling interval.
        scale: (y_scale, x_scale) factors forwarded to align_image.
        bias: (y_bias, x_bias) normalized translation forwarded to align_image.
        resolution: Target width used when resizing frames.
        visual: If True, write the aligned track to ./outputs for inspection.
        visual_tag: Tag used in the visualization filename.

    Returns:
        np.ndarray of shape (f + 1, c, h, w): one all-zero leading frame
        followed by the aligned track frames.

    Raises:
        Re-raises whatever decord raises when the video cannot be read.
    """
    try:
        vr = decord.VideoReader(track_path, ctx=decord.cpu(0))
        frames = vr.get_batch(list(range(0, len(vr), sample_stride))).asnumpy()
    except Exception as e:
        # Bug fix: the original message referenced an undefined name
        # (`track_type`), turning every load failure into a NameError.
        logger.error(f"Error loading track {track_path}: {str(e)}")
        raise

    if total_frames > 0:
        frames = frames[:total_frames]

    tensor = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
    tensor = Resize([int(resolution / ASPECT_RATIO), resolution])(tensor)

    # Warp each frame into the anchor image's coordinate frame
    aligned_frames = []
    for frame in tensor:
        hwc = frame.permute(1, 2, 0).numpy()  # CHW -> HWC for align_image
        aligned_frames.append(align_image(hwc, scale, bias))
    aligned_track = np.stack(aligned_frames, axis=0)  # (f, h, w, c)

    if visual:
        write_video(f'./outputs/aligned_{visual_tag}.mp4', torch.tensor(aligned_track), fps=7)

    aligned_track = np.transpose(aligned_track, (0, 3, 1, 2))  # (f, c, h, w)
    # Prepend one all-zero frame so the track lines up with the pose sequence,
    # which gains a leading anchor-image pose frame in process_inputs.
    zero = np.zeros_like(aligned_track[:1])
    return np.concatenate([zero, aligned_track])
def load_reference_objects(
    obj_template_path: str,
    max_objects: int = 3
) -> torch.Tensor:
    """
    Load reference object images with automatic path resolution.

    Args:
        obj_template_path: Path template for object images (e.g., "obj_0.jpg");
            the trailing two-character "_N" of the stem is replaced with
            indices 0..max_objects-1.
        max_objects: Maximum number of objects to load.

    Returns:
        Tensor of object images in NCHW format (N = max_objects, 3x518x518).
        NOTE: values are raw 0-255 floats here; callers normalize to [-1, 1].

    Raises:
        FileNotFoundError: If not even the first (index 0) image exists.
    """
    obj_images: List[torch.Tensor] = []
    base_path = Path(obj_template_path)

    for idx in range(max_objects):
        # "name_X.ext" -> "name_<idx>.ext" (strips the 2-char "_X" suffix)
        obj_path = base_path.parent / f"{base_path.stem[:-2]}_{idx}{base_path.suffix}"
        if not obj_path.exists():
            if idx == 0:
                raise FileNotFoundError(f"No object images found matching pattern: {obj_path}")
            break
        try:
            image = Image.open(obj_path).convert("RGB")
            tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float()

            # Pad to square before resizing so the object is not distorted.
            # F.pad on a CHW tensor takes (w_left, w_right, h_top, h_bottom).
            # Bug fix: the original tuple mixed up the sides and padded the
            # image into a NON-square shape (e.g. 10x6 -> 12x8); the short
            # side is now padded symmetrically on both edges.
            h, w = tensor.shape[1:]
            pad_dims = (
                max(0, (h - w) // 2),            # width, left
                max(0, (h - w) - (h - w) // 2),  # width, right
                max(0, (w - h) // 2),            # height, top
                max(0, (w - h) - (w - h) // 2),  # height, bottom
            )
            tensor = F.pad(tensor, pad_dims, mode="constant", value=0)

            # Resize to the fixed 518x518 reference input size
            tensor = resize(tensor, [518, 518], antialias=None)
            obj_images.append(tensor)
        except Exception as e:
            logger.warning(f"Error loading object image {obj_path}: {str(e)}")

    # Pad the list up to max_objects: repeat the last image, or use zeros
    # when nothing could be loaded at all.
    while len(obj_images) < max_objects:
        obj_images.append(obj_images[-1].clone() if obj_images else torch.zeros(3, 518, 518))

    return torch.stack(obj_images)
def process_inputs(
    video_path: str,
    image_pixels: np.ndarray,
    obj_path: str,
    obj_track_path: str,
    hand_path: str,
    total_frames: int,
    resolution: int = 576,
    sample_stride: int = 2,
    visual: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Process input data for the video generation pipeline.

    Args:
        video_path: Path to input video file.
        image_pixels: Anchor image in numpy array format (H, W, C).
        obj_path: Path template for object reference images (e.g., "object_0.jpg").
        obj_track_path: Path to object tracking video.
        hand_path: Path to hand tracking video.
        total_frames: Total number of frames to process (-1 for all frames).
        resolution: Target resolution for processing.
        sample_stride: Frame sampling interval.
        visual: Enable visualization outputs under ./outputs.

    Returns:
        Tuple of tensors normalized to [-1, 1]:
            - pose_pixels: Pose sequence tensor (N, C, H, W)
            - image_pixels: Anchor image tensor (1, C, H, W)
            - obj_pixels: Object references tensor (N, C, H, W)
            - obj_track_pixels: Object track tensor (N, C, H, W)
            - hand_pixels: Hand track tensor (N, C, H, W)
    """
    image_pixels = torch.from_numpy(image_pixels).permute(2, 0, 1)  # (C, H, W)
    h, w = image_pixels.shape[-2:]

    # --- Compute target h/w according to the original aspect ratio ---
    # The short side becomes `resolution`; the long side is resolution/ASPECT_RATIO
    # floored to a multiple of 64 (presumably a model size constraint — TODO confirm).
    if h > w:
        w_target, h_target = resolution, int(resolution / ASPECT_RATIO // 64) * 64
    else:
        w_target, h_target = int(resolution / ASPECT_RATIO // 64) * 64, resolution
    h_w_ratio = float(h) / float(w)
    # Resize so the image fully covers the target box, then center-crop to it
    if h_w_ratio < h_target / w_target:
        h_resize, w_resize = h_target, math.ceil(h_target / h_w_ratio)
    else:
        h_resize, w_resize = math.ceil(w_target * h_w_ratio), w_target
    image_pixels = resize(image_pixels, [h_resize, w_resize], antialias=None)
    image_pixels = center_crop(image_pixels, [h_target, w_target])  # (C, H, W)
    image_pixels = image_pixels.permute((1, 2, 0)).numpy()  # (H, W, C)

    # --- Image & video pose extraction ---
    image_pose = get_image_pose(image_pixels)  # (C, H, W)
    video_pose, scale, bias = get_video_pose(
        video_path, image_pixels, sample_stride=sample_stride, total_frames=total_frames)
    if visual:
        write_video('./outputs/pose_align.mp4', torch.tensor(video_pose).permute((0, 2, 3, 1)), fps=7)
    # Prepend the anchor-image pose so pose_pixels[0] matches the reference frame
    pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose])
    image_pixels = np.transpose(np.expand_dims(image_pixels, 0), (0, 3, 1, 2))  # (1, C, H, W)

    # --- Object reference images ---
    obj_pixels = load_reference_objects(obj_path)

    # --- Object track, warped into the anchor image's frame ---
    obj_track_pixels = align_track_video(
        obj_track_path, total_frames, sample_stride, scale, bias,
        resolution=resolution, visual=visual)

    # --- Hand track ---
    hand_pixels = align_track_video(
        hand_path, total_frames, sample_stride, scale, bias,
        resolution=resolution, visual=visual, visual_tag='hand')

    # Normalize everything from [0, 255] to [-1, 1].
    # Consistency fix: the hand tensor now converts to torch BEFORE
    # normalizing, matching the other four outputs (numerically identical).
    return (torch.from_numpy(pose_pixels.copy()) / 127.5 - 1,
            torch.from_numpy(image_pixels) / 127.5 - 1,
            obj_pixels / 127.5 - 1,
            torch.from_numpy(obj_track_pixels) / 127.5 - 1,
            torch.from_numpy(hand_pixels) / 127.5 - 1)
def run_pipeline(
    pipeline: AnchorCrafterPipeline,
    image_pixels: torch.Tensor,
    pose_pixels: torch.Tensor,
    obj_pixels: torch.Tensor,
    obj_track_pixels: torch.Tensor,
    hand_pixels: torch.Tensor,
    total_frames: int,
    device: torch.device,
    task_config: object
) -> torch.Tensor:
    """
    Execute the video generation pipeline.

    Args:
        pipeline: Initialized AnchorCrafter pipeline.
        image_pixels: Normalized anchor image tensor (1, C, H, W).
        pose_pixels: Normalized pose sequence tensor (N, C, H, W).
        obj_pixels: Normalized object references tensor (N, C, H, W).
        obj_track_pixels: Normalized object track tensor (N, C, H, W).
        hand_pixels: Normalized hand track tensor (N, C, H, W).
        total_frames: Number of frames to generate.
        device: Target computation device.
        task_config: Configuration object containing:
            - seed: Random seed
            - num_frames: Base frame count (tile size)
            - frames_overlap: Tile overlap size
            - noise_aug_strength: Noise augmentation strength
            - num_inference_steps: Diffusion steps
            - guidance_scale: CFG scale

    Returns:
        Generated video frames tensor in uint8 format (F, C, H, W); the
        first generated frame is dropped because it reproduces the
        reference image. Returns None if the pipeline produced no videos.
    """
    # De-normalize [-1, 1] tensors back to 0-255 PIL images for the pipeline
    image_pixels = [to_pil_image(img.to(torch.uint8)) for img in (image_pixels + 1.0) * 127.5]
    obj_pixels = [to_pil_image(img.to(torch.uint8)) for img in (obj_pixels + 1.0) * 127.5]

    generator = torch.Generator(device=device)
    generator.manual_seed(task_config.seed)

    # Never request more frames than any conditioning stream provides
    total_frames = min(total_frames, pose_pixels.size(0), obj_track_pixels.size(0), hand_pixels.size(0))

    frames = pipeline(
        image_pixels, pose_pixels[:total_frames], obj_pixels, obj_track_pixels[:total_frames],
        hand_pixels=hand_pixels[:total_frames], num_frames=total_frames,
        tile_size=task_config.num_frames, tile_overlap=task_config.frames_overlap,
        height=pose_pixels.shape[-2], width=pose_pixels.shape[-1], fps=7,
        noise_aug_strength=task_config.noise_aug_strength, num_inference_steps=task_config.num_inference_steps,
        generator=generator, min_guidance_scale=task_config.guidance_scale,
        max_guidance_scale=task_config.guidance_scale, decode_chunk_size=8, output_type="pt", device=device,
    ).frames.cpu()

    video_frames = (frames * 255.0).to(torch.uint8)
    # Idiom fix: use the module logger (lazy %-args) instead of print
    logger.info('video_frames: %s', tuple(video_frames.shape))

    if video_frames.shape[0] == 0:
        # Mirror the original fall-through: no generated batch -> None
        return None
    # The original loop returned on its first iteration; only the first
    # video is used. Drop frame 0: it duplicates the reference image.
    return video_frames[0, 1:]
def set_logger(log_file=None, log_level=logging.INFO):
    """
    Attach a file handler to this module's logger.

    Args:
        log_file: Path of the log file, opened in "w" mode (truncates).
            Robustness fix: if None, no handler is added — the original
            crashed with a TypeError inside FileHandler(None).
        log_level: Minimum level captured by the file handler.
    """
    if log_file is None:
        return
    handler = logging.FileHandler(log_file, "w")
    handler.setFormatter(
        logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s]: %(message)s")
    )
    handler.setLevel(log_level)
    # getLogger(__name__) returns the same registry object as the
    # module-level `logger`, so behavior is unchanged for callers.
    logging.getLogger(__name__).addHandler(handler)