import os
import torch
import numpy as np
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms
import rembg
_SUPPORTED_IMAGE_EXTS = {
'.png', '.jpg', '.jpeg', '.webp', '.bmp', '.tif', '.tiff'
}
def _expand_image_inputs(image_path: str) -> tuple[list[str], bool]:
"""Return (image_paths, is_directory).
If image_path is a directory, returns all supported images under it (non-recursive),
sorted by filename. Otherwise returns [image_path].
"""
if image_path is None:
raise ValueError('image_path is None')
image_path = str(image_path)
if os.path.isdir(image_path):
entries = []
for name in sorted(os.listdir(image_path)):
full = os.path.join(image_path, name)
if not os.path.isfile(full):
continue
ext = os.path.splitext(name)[1].lower()
if ext in _SUPPORTED_IMAGE_EXTS:
entries.append(full)
return entries, True
return [image_path], False
def load_dsine(device='cuda'):
    """Load the DSINE surface-normal estimation model.

    Tries a local checkpoint at ``ckpts/dsine/dsine.pt`` first; if it is
    absent, falls back to downloading the model via ``torch.hub``.

    Args:
        device: torch device string the model is moved to (default ``'cuda'``).

    Returns:
        The DSINE model, moved to ``device`` and set to eval mode.

    Raises:
        ValueError: if neither the local checkpoint nor the hub load succeeds.
    """
    # Load DSINE model
    # We need to import DSINE here to avoid circular imports or path issues if possible,
    # but since we added sys.path, we can try importing.
    # Based on test_minimal.py in dsine repo
    from models.dsine.v02 import DSINE_v02 as DSINE
    # Manually define args since projects.dsine.config is missing.
    # NOTE(review): these hyperparameters must match the checkpoint being
    # loaded (strict=True below); they appear to mirror the DSINE v02
    # defaults — confirm against the upstream repo config.
    class Args:
        def __init__(self):
            self.NNET_architecture = 'v02'
            self.NNET_encoder_B = 5
            self.NNET_decoder_NF = 2048
            self.NNET_decoder_BN = False
            self.NNET_decoder_down = 8
            self.NNET_learned_upsampling = True
            self.NRN_prop_ps = 5
            self.NRN_num_iter_train = 5
            self.NRN_num_iter_test = 5
            self.NRN_ray_relu = True
            self.NNET_output_dim = 3
            self.NNET_output_type = 'R'
            self.NNET_feature_dim = 64
            self.NNET_hidden_dim = 64
    args = Args()
    model = DSINE(args).to(device)
    # Load checkpoint from the expected local path (relative to the CWD).
    ckpt_path = 'ckpts/dsine/dsine.pt'
    if os.path.exists(ckpt_path):
        print(f"Loading DSINE checkpoint from {ckpt_path}")
        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects; only load checkpoints from trusted sources.
        state_dict = torch.load(ckpt_path, map_location='cpu')
        # Some checkpoints wrap the weights under a 'model' key; unwrap them.
        if 'model' in state_dict:
            state_dict = state_dict['model']
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        return model
    else:
        print(f"DSINE checkpoint not found at {ckpt_path}. Trying torch.hub...")
        try:
            # Fallback to torch.hub if local ckpt not found
            # Note: This might fail if the hub model expects different args structure,
            # but usually it handles it internally.
            # However, since we are using local class definition, we should load weights into it.
            # If we use torch.hub.load, it returns the model object directly.
            model = torch.hub.load("hugoycj/DSINE-hub", "DSINE", trust_repo=True)
            model.to(device)
            model.eval()
            return model
        except Exception as e:
            print(f"Failed to load DSINE from hub: {e}")
            raise ValueError("Could not load DSINE model.")
def intrins_from_fov(new_fov, H, W, device):
    """Build a 3x3 pinhole intrinsics matrix from a field of view.

    Args:
        new_fov: field of view in degrees. The focal length is derived from
            W, so the FOV is interpreted along the width axis.
        H: image height in pixels.
        W: image width in pixels.
        device: torch device for the returned tensor.

    Returns:
        (3, 3) float32 tensor ``[[f, 0, cx], [0, f, cy], [0, 0, 1]]`` with
        the principal point at the image center.
    """
    # Do the scalar math outside of torch: the original code wrapped a 0-dim
    # tensor inside torch.tensor([...]), which copy-constructs (UserWarning on
    # recent PyTorch) and built the matrix on CPU before moving it.
    focal = 0.5 * W / np.tan(np.deg2rad(0.5 * new_fov))
    cx = 0.5 * W
    cy = 0.5 * H
    return torch.tensor(
        [[focal, 0.0, cx],
         [0.0, focal, cy],
         [0.0, 0.0, 1.0]],
        dtype=torch.float32, device=device,
    )
def estimate_normal(image, model, device='cuda'):
    """Predict surface normals for a PIL RGB image with a DSINE model.

    Args:
        image: PIL.Image in RGB mode.
        model: DSINE model; called as ``model(x, intrins=...)`` and expected
            to return a list whose last element is the final prediction.
        device: torch device to run on.

    Returns:
        (1, 3, H, W) tensor of normals mapped from [-1, 1] into [0, 1].
    """
    width, height = image.size
    # HWC uint8 -> NCHW float in [0, 1].
    x = torch.from_numpy(np.array(image)).float() / 255.0
    x = x.permute(2, 0, 1).unsqueeze(0).to(device)
    # Normalize with ImageNet channel statistics.
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
    x = (x - mean) / std
    # Zero-pad right/bottom so both spatial dims are multiples of 32.
    x = F.pad(x, (0, -width % 32, 0, -height % 32), mode='constant', value=0)
    # Camera intrinsics assuming a 60 degree FOV. Padding is applied only on
    # the right/bottom, so the principal point needs no offset.
    cam_intrins = intrins_from_fov(60.0, height, width, device).unsqueeze(0)
    with torch.no_grad():
        prediction = model(x, intrins=cam_intrins)[-1]
        # Drop the padded border.
        prediction = prediction[:, :, :height, :width]
    # Flip the X axis to match the target normal-map convention.
    prediction[:, 0] = -prediction[:, 0]
    # Map from [-1, 1] to [0, 1].
    return (prediction + 1) / 2.0
def preprocess_image(input_image, dsine_model=None, device='cuda'):
    """Prepare an image and its normal map for downstream conditioning.

    Estimates normals (or uses a flat fallback), removes the background when
    the input has no usable alpha channel, crops to the foreground bounding
    box with 20% margin, and resizes both outputs to 518x518 with alpha
    premultiplied onto black.

    Args:
        input_image: PIL.Image; RGBA inputs with a non-trivial alpha channel
            are used as-is, everything else goes through rembg.
        dsine_model: optional DSINE model for normal estimation; when None a
            constant (128, 128, 255) "flat" normal map is used instead.
        device: torch device for normal estimation.

    Returns:
        (output, normal_output): two 518x518 RGB PIL images — the masked,
        cropped input and the matching masked, cropped normal map.
    """
    # 1. DSINE Normal Estimation on Original Image
    input_rgb = input_image.convert('RGB')
    if dsine_model is not None:
        normal_tensor = estimate_normal(input_rgb, dsine_model, device) # (1, 3, H, W)
        normal_np = normal_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() # (H, W, 3)
        normal_image = Image.fromarray((normal_np * 255).astype(np.uint8))
    else:
        # Flat "facing the camera" normal map fallback.
        normal_image = Image.new('RGB', input_image.size, (128, 128, 255))
    # 2. Decide whether the input already carries a usable alpha matte.
    has_alpha = False
    if input_image.mode == 'RGBA':
        alpha = np.array(input_image)[:, :, 3]
        if not np.all(alpha == 255):
            has_alpha = True
    if has_alpha:
        # Keep the existing matte untouched.
        # NOTE(review): this branch skips the 1024px downscale applied below,
        # so very large RGBA inputs are processed at full size — confirm
        # whether that is intended.
        output = input_image
    else:
        input_image = input_image.convert('RGB')
        # Cap the longest side at 1024px before background removal.
        max_size = max(input_image.size)
        scale = min(1, 1024 / max_size)
        if scale < 1:
            input_image = input_image.resize((int(input_image.width * scale), int(input_image.height * scale)), Image.Resampling.LANCZOS)
            # Also resize normal image if we resized input
            normal_image = normal_image.resize((int(normal_image.width * scale), int(normal_image.height * scale)), Image.Resampling.LANCZOS)
        session = rembg.new_session('birefnet-general')
        output = rembg.remove(input_image, session=session)
    # 3. Foreground bounding box from the alpha channel.
    # NOTE: `alpha` is (re)bound here from the final RGBA output, shadowing
    # the earlier `alpha` read from the raw input in the RGBA check above.
    output_np = np.array(output)
    alpha = output_np[:, :, 3]
    # Threshold at 80% opacity; argwhere yields (row, col) = (y, x) pairs.
    bbox = np.argwhere(alpha > 0.8 * 255)
    if len(bbox) == 0:
        # Nothing above the threshold: fall back to the whole image.
        bbox = [0, 0, output.height, output.width]
        bbox_crop = (0, 0, output.width, output.height)
    else:
        # (min_x, min_y, max_x, max_y) in PIL's (x, y) convention.
        bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
        center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
        size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
        # Square crop with 20% margin around the foreground.
        size = int(size * 1.2)
        bbox_crop = (int(center[0] - size // 2), int(center[1] - size // 2), int(center[0] + size // 2), int(center[1] + size // 2))
    # 4. Crop, resize, and premultiply alpha onto black for the RGB output.
    output = output.crop(bbox_crop)
    output = output.resize((518, 518), Image.Resampling.LANCZOS)
    output = np.array(output).astype(np.float32) / 255
    output = output[:, :, :3] * output[:, :, 3:4]
    output = Image.fromarray((output * 255).astype(np.uint8))
    # Process Normal: apply the same matte, crop, and resize to the normals.
    normal_rgba = normal_image.convert('RGBA')
    # Create alpha mask image
    alpha_img = Image.fromarray(alpha)
    normal_rgba.putalpha(alpha_img)
    normal_crop = normal_rgba.crop(bbox_crop)
    normal_crop = normal_crop.resize((518, 518), Image.Resampling.LANCZOS)
    normal_np = np.array(normal_crop).astype(np.float32) / 255
    normal_np = normal_np[:, :, :3] * normal_np[:, :, 3:4]
    normal_output = Image.fromarray((normal_np * 255).astype(np.uint8))
    return output, normal_output
def encode_image(image, image_cond_model, device):
    """Encode an image into pre-norm patch tokens with a conditioning model.

    Args:
        image: PIL.Image; converted to RGB before encoding.
        image_cond_model: vision backbone called as
            ``model(x, is_training=True)`` and expected to return a dict
            containing ``'x_prenorm'``.
        device: torch device for the forward pass.

    Returns:
        Layer-normalized patch token tensor.
    """
    pixels = np.array(image.convert('RGB')).astype(np.float32) / 255
    pixels = torch.from_numpy(pixels).permute(2, 0, 1).float().unsqueeze(0).to(device)
    # Normalize with ImageNet channel statistics.
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
    pixels = (pixels - mean) / std
    with torch.no_grad():
        features = image_cond_model(pixels, is_training=True)['x_prenorm']
        # LayerNorm over the feature dimension (no learned affine).
        tokens = F.layer_norm(features, features.shape[-1:])
    return tokens
|