# anigen/utils/image_utils.py
# Image loading, normal estimation, and background-removal utilities for AniGen.
import os
import torch
import numpy as np
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms
import rembg
_SUPPORTED_IMAGE_EXTS = {
'.png', '.jpg', '.jpeg', '.webp', '.bmp', '.tif', '.tiff'
}
def _expand_image_inputs(image_path: str) -> tuple[list[str], bool]:
"""Return (image_paths, is_directory).
If image_path is a directory, returns all supported images under it (non-recursive),
sorted by filename. Otherwise returns [image_path].
"""
if image_path is None:
raise ValueError('image_path is None')
image_path = str(image_path)
if os.path.isdir(image_path):
entries = []
for name in sorted(os.listdir(image_path)):
full = os.path.join(image_path, name)
if not os.path.isfile(full):
continue
ext = os.path.splitext(name)[1].lower()
if ext in _SUPPORTED_IMAGE_EXTS:
entries.append(full)
return entries, True
return [image_path], False
def load_dsine(device='cuda'):
    """Load the DSINE surface-normal estimator and move it to *device*.

    Prefers the local checkpoint at ``ckpts/dsine/dsine.pt``; if that file is
    absent, falls back to fetching the model from ``torch.hub``. The returned
    model is in eval mode.

    Args:
        device: torch device string the model is moved to (default 'cuda').

    Returns:
        The DSINE model, ready for inference.

    Raises:
        ValueError: if neither the local checkpoint nor the hub fallback
            could be loaded.
    """
    # Load DSINE model
    # We need to import DSINE here to avoid circular imports or path issues if possible,
    # but since we added sys.path, we can try importing.
    # Based on test_minimal.py in dsine repo
    from models.dsine.v02 import DSINE_v02 as DSINE
    # Manually define args since projects.dsine.config is missing.
    # NOTE(review): these values appear to mirror the upstream DSINE v02
    # config -- confirm against the dsine repo before changing any of them.
    class Args:
        def __init__(self):
            self.NNET_architecture = 'v02'
            self.NNET_encoder_B = 5
            self.NNET_decoder_NF = 2048
            self.NNET_decoder_BN = False
            self.NNET_decoder_down = 8
            self.NNET_learned_upsampling = True
            self.NRN_prop_ps = 5
            self.NRN_num_iter_train = 5
            self.NRN_num_iter_test = 5
            self.NRN_ray_relu = True
            self.NNET_output_dim = 3
            self.NNET_output_type = 'R'
            self.NNET_feature_dim = 64
            self.NNET_hidden_dim = 64
    args = Args()
    model = DSINE(args).to(device)
    # Load checkpoint
    ckpt_path = 'ckpts/dsine/dsine.pt'
    if os.path.exists(ckpt_path):
        print(f"Loading DSINE checkpoint from {ckpt_path}")
        state_dict = torch.load(ckpt_path, map_location='cpu')
        # Some checkpoints wrap the weights under a 'model' key; unwrap it.
        if 'model' in state_dict:
            state_dict = state_dict['model']
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        return model
    else:
        print(f"DSINE checkpoint not found at {ckpt_path}. Trying torch.hub...")
        try:
            # Fallback to torch.hub if local ckpt not found
            # Note: This might fail if the hub model expects different args structure,
            # but usually it handles it internally.
            # However, since we are using local class definition, we should load weights into it.
            # If we use torch.hub.load, it returns the model object directly.
            model = torch.hub.load("hugoycj/DSINE-hub", "DSINE", trust_repo=True)
            model.to(device)
            model.eval()
            return model
        except Exception as e:
            # NOTE(review): the original cause is printed but not chained
            # (`raise ... from e`); kept as-is to preserve behavior.
            print(f"Failed to load DSINE from hub: {e}")
            raise ValueError("Could not load DSINE model.")
def intrins_from_fov(new_fov, H, W, device):
    """Build a 3x3 pinhole camera intrinsics matrix from a field of view.

    The focal length is derived from the horizontal FOV:
    ``f = (W / 2) / tan(fov / 2)``; the principal point is the image center.

    Args:
        new_fov: field of view in degrees.
        H: image height in pixels.
        W: image width in pixels.
        device: torch device for the returned matrix.

    Returns:
        A (3, 3) float32 tensor ``[[f, 0, cx], [0, f, cy], [0, 0, 1]]``.
    """
    # Compute the focal length as a plain float. The previous version built a
    # torch tensor from a list containing a 0-dim tensor, which triggers a
    # PyTorch warning and forced three separate device transfers.
    f = 0.5 * W / float(np.tan(0.5 * np.deg2rad(new_fov)))
    cx = 0.5 * W
    cy = 0.5 * H
    intrins = torch.tensor(
        [[f, 0.0, cx], [0.0, f, cy], [0.0, 0.0, 1.0]],
        dtype=torch.float32,
        device=device,
    )
    return intrins
def estimate_normal(image, model, device='cuda'):
    """Run DSINE surface-normal estimation on a PIL RGB image.

    Args:
        image: PIL image in RGB mode.
        model: loaded DSINE model (see load_dsine).
        device: torch device string for inference.

    Returns:
        A (1, 3, H, W) tensor of normals mapped to [0, 1].
    """
    width, height = image.size

    # HWC uint8 -> NCHW float in [0, 1] on the target device.
    rgb = np.array(image, dtype=np.float32) / 255.0
    batch = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0).to(device)

    # ImageNet normalization expected by the DSINE backbone.
    batch = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )(batch)

    # Pad right/bottom so both sides are multiples of 32. No left/top padding
    # is added, so the principal point needs no adjustment.
    batch = F.pad(
        batch,
        (0, (32 - width % 32) % 32, 0, (32 - height % 32) % 32),
        mode='constant',
        value=0,
    )

    # Camera intrinsics assuming a 60-degree field of view.
    intrins = intrins_from_fov(60.0, height, width, device).unsqueeze(0)

    with torch.no_grad():
        pred = model(batch, intrins=intrins)[-1]

    # Drop the padded border, flip the X axis to match the target convention,
    # and remap normals from [-1, 1] to [0, 1].
    pred = pred[:, :, :height, :width]
    pred[:, 0, :, :] = -pred[:, 0, :, :]
    return (pred + 1) / 2.0
def preprocess_image(input_image, dsine_model=None, device='cuda'):
    """Prepare an input image and a matching normal map for 3D generation.

    Estimates surface normals, removes the background unless the image
    already has a meaningful alpha channel, crops a square around the
    foreground with 20% margin, and resizes both outputs to 518x518 with
    the background composited to black.

    Args:
        input_image: PIL image (any mode; an RGBA alpha matte is used when
            it is not uniformly opaque).
        dsine_model: optional DSINE model; when None a flat normal map
            (128, 128, 255) is used instead.
        device: torch device string used for normal estimation.

    Returns:
        (rgb_image, normal_image): two 518x518 PIL RGB images.
    """
    # 1. DSINE Normal Estimation on Original Image
    input_rgb = input_image.convert('RGB')
    if dsine_model is not None:
        normal_tensor = estimate_normal(input_rgb, dsine_model, device)  # (1, 3, H, W)
        normal_np = normal_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()  # (H, W, 3)
        normal_image = Image.fromarray((normal_np * 255).astype(np.uint8))
    else:
        # No model available: constant normal pointing toward the camera.
        normal_image = Image.new('RGB', input_image.size, (128, 128, 255))
    # 2. Use the existing alpha matte only if it is not uniformly opaque.
    has_alpha = False
    if input_image.mode == 'RGBA':
        alpha = np.array(input_image)[:, :, 3]
        if not np.all(alpha == 255):
            has_alpha = True
    if has_alpha:
        # Trust the provided matte; skip background removal entirely.
        output = input_image
    else:
        # 3. Background-removal path: downscale so the longest side is at
        # most 1024 px, then run rembg to obtain an RGBA cutout.
        input_image = input_image.convert('RGB')
        max_size = max(input_image.size)
        scale = min(1, 1024 / max_size)
        if scale < 1:
            input_image = input_image.resize((int(input_image.width * scale), int(input_image.height * scale)), Image.Resampling.LANCZOS)
            # Also resize normal image if we resized input
            normal_image = normal_image.resize((int(normal_image.width * scale), int(normal_image.height * scale)), Image.Resampling.LANCZOS)
        # NOTE(review): a new rembg session is created per call -- presumably
        # acceptable for one-off use, but costly if called in a loop.
        session = rembg.new_session('birefnet-general')
        output = rembg.remove(input_image, session=session)
    # 4. Find the foreground bounding box from pixels with alpha > ~0.8.
    output_np = np.array(output)
    alpha = output_np[:, :, 3]
    bbox = np.argwhere(alpha > 0.8 * 255)
    if len(bbox) == 0:
        # Fully transparent result: keep the whole frame.
        bbox = [0, 0, output.height, output.width]
        bbox_crop = (0, 0, output.width, output.height)
    else:
        # (left, top, right, bottom) in pixel coordinates.
        bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
        center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
        # Square crop with a 20% margin around the tight bounding box.
        size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
        size = int(size * 1.2)
        # NOTE(review): the crop box can extend past the image borders; PIL
        # fills such regions with zeros (transparent), which is benign here.
        bbox_crop = (int(center[0] - size // 2), int(center[1] - size // 2), int(center[0] + size // 2), int(center[1] + size // 2))
    output = output.crop(bbox_crop)
    output = output.resize((518, 518), Image.Resampling.LANCZOS)
    # 5. Composite onto black: RGB * alpha, back to uint8 RGB.
    output = np.array(output).astype(np.float32) / 255
    output = output[:, :, :3] * output[:, :, 3:4]
    output = Image.fromarray((output * 255).astype(np.uint8))
    # Process Normal: apply the same matte, crop, resize, and composite so
    # the normal map stays pixel-aligned with the RGB output.
    normal_rgba = normal_image.convert('RGBA')
    # Create alpha mask image
    alpha_img = Image.fromarray(alpha)
    normal_rgba.putalpha(alpha_img)
    normal_crop = normal_rgba.crop(bbox_crop)
    normal_crop = normal_crop.resize((518, 518), Image.Resampling.LANCZOS)
    normal_np = np.array(normal_crop).astype(np.float32) / 255
    normal_np = normal_np[:, :, :3] * normal_np[:, :, 3:4]
    normal_output = Image.fromarray((normal_np * 255).astype(np.uint8))
    return output, normal_output
def encode_image(image, image_cond_model, device):
    """Encode an image into normalized patch tokens via the condition model.

    Args:
        image: PIL image (converted to RGB internally).
        image_cond_model: DINO-style model; called with is_training=True and
            expected to return a dict containing 'x_prenorm'.
        device: torch device string for inference.

    Returns:
        Layer-normalized patch-token features.
    """
    # HWC [0, 255] -> NCHW float [0, 1] on the target device.
    pixels = np.array(image.convert('RGB')).astype(np.float32) / 255
    batch = torch.from_numpy(pixels).permute(2, 0, 1).float().unsqueeze(0).to(device)

    # ImageNet channel statistics, applied via broadcasting (numerically the
    # same operation transforms.Normalize performs).
    mean = torch.tensor([0.485, 0.456, 0.406], device=batch.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=batch.device).view(1, 3, 1, 1)
    batch = (batch - mean) / std

    with torch.no_grad():
        features = image_cond_model(batch, is_training=True)['x_prenorm']
        patchtokens = F.layer_norm(features, features.shape[-1:])
    return patchtokens