# AnchorCrafter / inference.py
import os
import logging
import math
from typing import Tuple, List, Optional
from pathlib import Path
import numpy as np
import torch
import decord
from PIL import Image
from torchvision.transforms import Resize
from torchvision.io import write_video
import torch.nn.functional as F
from torchvision.transforms.functional import resize, center_crop, to_pil_image
from anchorcrafter.utils.geglu_patch import patch_geglu_inplace
from anchorcrafter.dwpose.preprocess import get_video_pose, get_image_pose
from anchorcrafter.pipelines.pipeline import AnchorCrafterPipeline
from constants import ASPECT_RATIO
# Initialize GEGLU patch at import time (monkey-patches GEGLU in place;
# the exact effect is defined in anchorcrafter.utils.geglu_patch).
patch_geglu_inplace()
logging.basicConfig(level=logging.INFO, format="%(asctime)s: [%(levelname)s] %(message)s")
# Module-level logger; set_logger() below can attach a file handler to it.
logger = logging.getLogger(__name__)
# Constants
# NOTE(review): these defaults are not referenced elsewhere in this file's
# visible code — presumably consumed by callers importing this module.
DEFAULT_RESOLUTION = 576
DEFAULT_SAMPLE_STRIDE = 2
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def align_image(
    image: np.ndarray,
    scale: Tuple[float, float],
    bias: Tuple[float, float]
) -> np.ndarray:
    """
    Warp an HWC image by the inverse of a per-axis scale/offset transform.

    Args:
        image: Source image array, shape (H, W, C)
        scale: (col_scale, row_scale) scaling factors; scale[0] applies to
            columns and scale[1] to rows
        bias: (col_bias, row_bias) offsets normalized to [0, 1]; bias[0] is
            a fraction of the width, bias[1] a fraction of the height

    Returns:
        Warped image of the same shape; pixels whose source coordinate
        falls outside the image stay zero.
    """
    h, w, _ = image.shape
    col_scale, row_scale = scale
    col_bias, row_bias = bias
    # Convert fractional offsets to pixels (row offset scales with height,
    # column offset with width).
    row_off = row_bias * h
    col_off = col_bias * w
    # Per-pixel destination coordinates.
    rows, cols = np.indices((h, w))
    # Invert the forward transform dst = src * scale + off, truncating to int
    # (astype(int) truncates toward zero, matching the original behavior).
    src_rows = ((rows - row_off) / row_scale).astype(int)
    src_cols = ((cols - col_off) / col_scale).astype(int)
    # Keep only lookups that land inside the source image.
    inside = (src_rows >= 0) & (src_rows < h) & (src_cols >= 0) & (src_cols < w)
    warped = np.zeros_like(image)
    warped[inside] = image[src_rows[inside], src_cols[inside]]
    return warped
def align_track_video(track_path, total_frames, sample_stride, scale, bias, resolution=512, visual=False, visual_tag='obj'):
    """
    Load a tracking video, resize it, and warp every frame into the anchor
    image's coordinate frame using align_image.

    Args:
        track_path: Path to the tracking video (object or hand track)
        total_frames: Number of frames to keep (<= 0 keeps all sampled frames)
        sample_stride: Frame sampling interval
        scale: (col_scale, row_scale) affine scale passed to align_image
        bias: (col_bias, row_bias) normalized affine offsets for align_image
        resolution: Target width; height is derived from ASPECT_RATIO
        visual: If True, dump the aligned track to ./outputs for inspection
        visual_tag: Filename tag used for the visualization video

    Returns:
        numpy array of shape (F + 1, C, H, W): the aligned frames with an
        all-zero frame prepended so indexing lines up with pose_pixels,
        whose first entry is the reference-image pose.
    """
    try:
        vr = decord.VideoReader(track_path, ctx=decord.cpu(0))
        frames = vr.get_batch(list(range(0, len(vr), sample_stride))).asnumpy()
    except Exception as e:
        # Bug fix: the original referenced an undefined name `track_type`
        # here, so any load failure raised NameError and masked the real
        # error. Log the path instead.
        logger.error(f"Error loading track video {track_path}: {str(e)}")
        raise
    if total_frames > 0:
        frames = frames[:total_frames]
    # Resize via torchvision (expects NCHW), then go back to numpy HWC for
    # align_image, which is a pure-numpy routine.
    tensor = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
    tensor = Resize([int(resolution / ASPECT_RATIO), resolution])(tensor)
    frames_hwc = tensor.permute(0, 2, 3, 1).numpy()
    aligned_track = np.stack([align_image(f, scale, bias) for f in frames_hwc], axis=0)  # (f, h, w, c)
    if visual:
        # NOTE(review): write_video is fed float frames here — torchvision
        # expects uint8; confirm this debug path actually works.
        write_video(f'./outputs/aligned_{visual_tag}.mp4', torch.tensor(aligned_track), fps=7)
    aligned_track = np.transpose(aligned_track, (0, 3, 1, 2))  # (f, c, h, w)
    # Prepend an all-zero frame (slot for the reference image, which has no
    # track). Stay in numpy throughout instead of mixing in a torch zero.
    zero = np.zeros_like(aligned_track[:1])
    return np.concatenate([zero, aligned_track])
def load_reference_objects(
    obj_template_path: str,
    max_objects: int = 3
) -> torch.Tensor:
    """
    Load reference object images with automatic path resolution.

    The template path is expected to end in an index (e.g. "obj_0.jpg");
    the trailing "_<idx>" is rewritten to 0..max_objects-1 to locate
    sibling files.

    Args:
        obj_template_path: Path template for object images (e.g., "obj_0.jpg")
        max_objects: Maximum number of objects to load

    Returns:
        Tensor of object images, shape (max_objects, 3, 518, 518).
        NOTE: values are raw pixel intensities in [0, 255]; normalization
        to [-1, 1] is done by the caller (process_inputs).

    Raises:
        FileNotFoundError: If even the first object image does not exist.
    """
    obj_images = []
    base_path = Path(obj_template_path)
    for idx in range(max_objects):
        # "obj_0.jpg" -> "obj_<idx>.jpg" (strip the trailing "_0" from the stem).
        obj_path = base_path.parent / f"{base_path.stem[:-2]}_{idx}{base_path.suffix}"
        if not obj_path.exists():
            if idx == 0:
                raise FileNotFoundError(f"No object images found matching pattern: {obj_path}")
            break
        try:
            image = Image.open(obj_path).convert("RGB")
            tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float()
            # Pad the shorter side so the image becomes square before resizing.
            # F.pad on a CHW tensor takes (left, right, top, bottom): the first
            # pair pads the width axis, the second pads the height axis.
            # Bug fix: the original swapped the axes — it padded the width when
            # the image was already wider (and the height when taller), and
            # only on one side, producing a non-square, off-center result.
            h, w = tensor.shape[1:]
            pad_w = max(0, h - w)  # widen a tall image
            pad_h = max(0, w - h)  # heighten a wide image
            pad_dims = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
            tensor = F.pad(tensor, pad_dims, mode="constant", value=0)
            # Resize to 518x518 (input size assumed by the downstream object
            # encoder — TODO confirm).
            tensor = resize(tensor, [518, 518], antialias=None)
            obj_images.append(tensor)
        except Exception as e:
            # Best-effort: skip unreadable images and keep loading the rest.
            logger.warning(f"Error loading object image {obj_path}: {str(e)}")
    # Pad the batch to max_objects by repeating the last image (or zeros if
    # every load failed).
    while len(obj_images) < max_objects:
        obj_images.append(obj_images[-1].clone() if obj_images else torch.zeros(3, 518, 518))
    return torch.stack(obj_images)
def process_inputs(
    video_path: str,
    image_pixels: np.ndarray,
    obj_path: str,
    obj_track_path: str,
    hand_path: str,
    total_frames: int,
    resolution: int = 576,
    sample_stride: int = 2,
    visual: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Process input data for the video generation pipeline.

    Args:
        video_path: Path to the driving video file
        image_pixels: Anchor image in numpy array format (H, W, C), uint8 [0, 255]
        obj_path: Path template for object reference images (e.g., "object_0.jpg")
        obj_track_path: Path to object tracking video
        hand_path: Path to hand tracking video
        total_frames: Total number of frames to process (-1 for all frames)
        resolution: Target resolution for processing
        sample_stride: Frame sampling interval
        visual: Enable visualization outputs (written under ./outputs)

    Returns:
        Tuple of tensors, each normalized to [-1, 1]:
            - pose_pixels: Pose sequence tensor (N, C, H, W); frame 0 is the anchor-image pose
            - image_pixels: Anchor image tensor (1, C, H, W)
            - obj_pixels: Object references tensor (N, C, H, W)
            - obj_track_pixels: Object track tensor (N, C, H, W); frame 0 is all zeros
            - hand_pixels: Hand track tensor (N, C, H, W); frame 0 is all zeros
    """
    image_pixels=torch.from_numpy(image_pixels).permute(2, 0, 1)  # (C, H, W)
    h, w = image_pixels.shape[-2:]
    ############################ compute target h/w according to original aspect ratio ###############################
    # The non-resolution side is snapped down to a multiple of 64 (UNet-friendly).
    # NOTE(review): when h > w the computed h_target ends up as the side derived
    # from ASPECT_RATIO; whether that is the larger or smaller side depends on
    # the value of constants.ASPECT_RATIO — confirm against that constant.
    if h > w:
        w_target, h_target = resolution, int(resolution / ASPECT_RATIO // 64) * 64
    else:
        w_target, h_target = int(resolution / ASPECT_RATIO // 64) * 64, resolution
    h_w_ratio = float(h) / float(w)
    # Resize so the image fully covers the target box, then center-crop to it
    # (cover-and-crop, preserving the original aspect ratio).
    if h_w_ratio < h_target / w_target:
        h_resize, w_resize = h_target, math.ceil(h_target / h_w_ratio)
    else:
        h_resize, w_resize = math.ceil(w_target * h_w_ratio), w_target
    image_pixels = resize(image_pixels, [h_resize, w_resize], antialias=None)
    image_pixels = center_crop(image_pixels, [h_target, w_target])  # (C, H, W)
    image_pixels = image_pixels.permute((1, 2, 0)).numpy()  # (H, W, C)
    ##################################### get image&video pose value #################################################
    image_pose = get_image_pose(image_pixels)  # (C, H, W)
    # get_video_pose also returns the affine (scale, bias) that aligned the
    # driving video's pose to the anchor image; the same transform is reused
    # below to warp the object/hand tracking videos via align_track_video.
    video_pose, scale, bias = get_video_pose(video_path, image_pixels, sample_stride=sample_stride, total_frames=total_frames)
    if visual:
        write_video('./outputs/pose_align.mp4', torch.tensor(video_pose).permute((0, 2, 3, 1)), fps=7)
    # Prepend the anchor-image pose as frame 0 of the pose sequence.
    pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose])
    image_pixels = np.transpose(np.expand_dims(image_pixels, 0), (0, 3, 1, 2))  # (1, C, H, W)
    ############ object reference ############
    obj_pixels = load_reference_objects(obj_path)
    ############ object track ############
    obj_track_pixels = align_track_video(obj_track_path, total_frames, sample_stride, scale, bias, resolution=resolution, visual=visual)
    ############ hand track ############
    hand_pixels = align_track_video(hand_path, total_frames, sample_stride, scale, bias, resolution=resolution, visual=visual, visual_tag='hand')
    # Map uint8-range [0, 255] values to [-1, 1] via x / 127.5 - 1.
    # NOTE(review): the hand_pixels line divides *before* torch.from_numpy —
    # numerically equivalent to the other lines, just inconsistent in style.
    return (torch.from_numpy(pose_pixels.copy()) / 127.5 - 1,
            torch.from_numpy(image_pixels) / 127.5 - 1,
            obj_pixels / 127.5 - 1,
            torch.from_numpy(obj_track_pixels) / 127.5 - 1,
            torch.from_numpy(hand_pixels / 127.5 - 1))
def run_pipeline(
    pipeline: AnchorCrafterPipeline,
    image_pixels: torch.Tensor,
    pose_pixels: torch.Tensor,
    obj_pixels: torch.Tensor,
    obj_track_pixels: torch.Tensor,
    hand_pixels: torch.Tensor,
    total_frames: int,
    device: torch.device,
    task_config: object
) -> torch.Tensor:
    """
    Execute the video generation pipeline.

    Args:
        pipeline: Initialized AnchorCrafter pipeline
        image_pixels: Normalized anchor image tensor (1, C, H, W), in [-1, 1]
        pose_pixels: Normalized pose sequence tensor (N, C, H, W)
        obj_pixels: Normalized object references tensor (N, C, H, W)
        obj_track_pixels: Normalized object track tensor (N, C, H, W)
        hand_pixels: Normalized hand track tensor (N, C, H, W)
        total_frames: Number of frames to generate (clamped to the shortest
            conditioning sequence)
        device: Target computation device
        task_config: Configuration object containing:
            - seed: Random seed
            - num_frames: Base frame count (tile size)
            - frames_overlap: Tile overlap size
            - noise_aug_strength: Noise augmentation strength
            - num_inference_steps: Diffusion steps
            - guidance_scale: CFG scale

    Returns:
        Generated video frames tensor in uint8 format (F - 1, C, H, W);
        the first generated frame is dropped because it reproduces the
        reference image.
    """
    # De-normalize [-1, 1] tensors back to [0, 255] PIL images for the pipeline.
    image_pixels = [to_pil_image(img.to(torch.uint8)) for img in (image_pixels + 1.0) * 127.5]
    obj_pixels = [to_pil_image(img.to(torch.uint8)) for img in (obj_pixels + 1.0) * 127.5]
    generator = torch.Generator(device=device)
    generator.manual_seed(task_config.seed)
    # Clamp the frame count to the shortest conditioning sequence.
    total_frames = min(total_frames, pose_pixels.size(0), obj_track_pixels.size(0), hand_pixels.size(0))
    frames = pipeline(
        image_pixels, pose_pixels[:total_frames], obj_pixels, obj_track_pixels[:total_frames],
        hand_pixels=hand_pixels[:total_frames], num_frames=total_frames,
        tile_size=task_config.num_frames, tile_overlap=task_config.frames_overlap,
        height=pose_pixels.shape[-2], width=pose_pixels.shape[-1], fps=7,
        noise_aug_strength=task_config.noise_aug_strength, num_inference_steps=task_config.num_inference_steps,
        generator=generator, min_guidance_scale=task_config.guidance_scale,
        max_guidance_scale=task_config.guidance_scale, decode_chunk_size=8, output_type="pt", device=device,
    ).frames.cpu()
    video_frames = (frames * 255.0).to(torch.uint8)
    # Use the module logger instead of a bare print for diagnostics.
    logger.info("video_frames shape: %s", tuple(video_frames.shape))
    # Only one video is produced per call; drop its first frame, which merely
    # duplicates the reference image.
    # Bug fix: the original iterated over the batch but returned inside the
    # loop on the first iteration (dead loop), and raised NameError when the
    # batch was empty.
    return video_frames[0, 1:]
def set_logger(log_file=None, log_level=logging.INFO):
    """
    Attach a file handler to the module logger.

    Args:
        log_file: Path of the log file (opened in "w" mode). If None, no
            handler is added. (Bug fix: the original passed None straight
            to logging.FileHandler, which raises.)
        log_level: Minimum level for the file handler.
    """
    if log_file is None:
        return
    log_handler = logging.FileHandler(log_file, "w")
    log_handler.setFormatter(
        logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s]: %(message)s")
    )
    log_handler.setLevel(log_level)
    logger.addHandler(log_handler)