import os

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
import rembg

# Image extensions accepted by _expand_image_inputs (lowercase, with dot).
_SUPPORTED_IMAGE_EXTS = {
    '.png', '.jpg', '.jpeg', '.webp', '.bmp', '.tif', '.tiff'
}


def _expand_image_inputs(image_path: str) -> tuple[list[str], bool]:
    """Return (image_paths, is_directory).

    If image_path is a directory, returns all supported images under it
    (non-recursive), sorted by filename. Otherwise returns [image_path].

    Raises:
        ValueError: if image_path is None.
    """
    if image_path is None:
        raise ValueError('image_path is None')
    image_path = str(image_path)
    if os.path.isdir(image_path):
        entries = []
        for name in sorted(os.listdir(image_path)):
            full = os.path.join(image_path, name)
            if not os.path.isfile(full):
                continue
            ext = os.path.splitext(name)[1].lower()
            if ext in _SUPPORTED_IMAGE_EXTS:
                entries.append(full)
        return entries, True
    return [image_path], False


def load_dsine(device='cuda'):
    """Load the DSINE surface-normal estimation model onto `device`.

    Prefers a local checkpoint at ckpts/dsine/dsine.pt; if absent, falls
    back to torch.hub. Returns the model in eval mode.

    Raises:
        ValueError: if neither the local checkpoint nor torch.hub works.
    """
    # Import here (not at module top) to avoid circular imports or path
    # issues; based on test_minimal.py in the dsine repo.
    from models.dsine.v02 import DSINE_v02 as DSINE

    class Args:
        """Manually defined args since projects.dsine.config is missing."""

        def __init__(self):
            self.NNET_architecture = 'v02'
            self.NNET_encoder_B = 5
            self.NNET_decoder_NF = 2048
            self.NNET_decoder_BN = False
            self.NNET_decoder_down = 8
            self.NNET_learned_upsampling = True
            self.NRN_prop_ps = 5
            self.NRN_num_iter_train = 5
            self.NRN_num_iter_test = 5
            self.NRN_ray_relu = True
            self.NNET_output_dim = 3
            self.NNET_output_type = 'R'
            self.NNET_feature_dim = 64
            self.NNET_hidden_dim = 64

    args = Args()
    model = DSINE(args).to(device)

    ckpt_path = 'ckpts/dsine/dsine.pt'
    if os.path.exists(ckpt_path):
        print(f"Loading DSINE checkpoint from {ckpt_path}")
        # NOTE: torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        state_dict = torch.load(ckpt_path, map_location='cpu')
        # Some checkpoints nest the weights under a 'model' key.
        if 'model' in state_dict:
            state_dict = state_dict['model']
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        return model

    print(f"DSINE checkpoint not found at {ckpt_path}. Trying torch.hub...")
    try:
        # Fallback to torch.hub if the local checkpoint is not found.
        # torch.hub.load returns a ready-to-use model object directly.
        model = torch.hub.load("hugoycj/DSINE-hub", "DSINE", trust_repo=True)
        model.to(device)
        model.eval()
        return model
    except Exception as e:
        print(f"Failed to load DSINE from hub: {e}")
        # Chain the original exception so the root cause is preserved.
        raise ValueError("Could not load DSINE model.") from e


def intrins_from_fov(new_fov, H, W, device):
    """Build a 3x3 pinhole intrinsics matrix from a field of view.

    Args:
        new_fov: field of view in degrees.
        H, W: image height and width in pixels.
        device: torch device for the returned tensor.

    Returns:
        (3, 3) tensor [[f, 0, cx], [0, f, cy], [0, 0, 1]] with the
        principal point at the image center.
    """
    fov = torch.tensor(new_fov).to(device)
    f = 0.5 * W / torch.tan(0.5 * fov * np.pi / 180.0)
    cx = 0.5 * W
    cy = 0.5 * H
    intrins = torch.tensor([[f, 0, cx],
                            [0, f, cy],
                            [0, 0, 1]]).to(device)
    return intrins


def estimate_normal(image, model, device='cuda'):
    """Predict a surface-normal map for a PIL RGB image with DSINE.

    Args:
        image: PIL.Image in RGB mode.
        model: loaded DSINE model (see load_dsine).
        device: torch device string.

    Returns:
        (1, 3, H, W) tensor of normals mapped to [0, 1], with the X axis
        flipped relative to the raw model output.
    """
    w, h = image.size

    # HWC uint8 -> NCHW float in [0, 1].
    im_tensor = torch.from_numpy(np.array(image)).float() / 255.0
    im_tensor = im_tensor.permute(2, 0, 1).unsqueeze(0).to(device)

    # ImageNet normalization expected by the encoder.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    im_tensor = normalize(im_tensor)

    # Pad right/bottom to a multiple of 32 (network stride requirement).
    # Because no left/top padding is added, the principal point in the
    # intrinsics needs no offset.
    pad_h = (32 - h % 32) % 32
    pad_w = (32 - w % 32) % 32
    im_tensor = F.pad(im_tensor, (0, pad_w, 0, pad_h), mode='constant', value=0)

    # Intrinsics: assume a 60-degree FOV for the (unpadded) image.
    intrins = intrins_from_fov(60.0, h, w, device).unsqueeze(0)

    with torch.no_grad():
        pred_norm = model(im_tensor, intrins=intrins)[-1]

    # Crop away the padding, then flip the X axis to match the expected
    # normal-map convention.
    pred_norm = pred_norm[:, :, :h, :w]
    pred_norm[:, 0, :, :] = -pred_norm[:, 0, :, :]

    # Map from [-1, 1] to [0, 1].
    pred_norm = (pred_norm + 1) / 2.0
    return pred_norm  # (1, 3, H, W)


def preprocess_image(input_image, dsine_model=None, device='cuda'):
    """Prepare an input image (and its normal map) for the pipeline.

    Steps: estimate normals with DSINE on the original image, remove the
    background (unless the image already carries a non-trivial alpha
    channel), crop to the foreground bounding box with 20% margin, and
    resize both outputs to 518x518 composited on black.

    Args:
        input_image: PIL.Image, any mode.
        dsine_model: optional DSINE model; if None, a flat "up" normal
            map (128, 128, 255) is used instead.
        device: torch device string.

    Returns:
        (rgb_image, normal_image): two 518x518 PIL RGB images.
    """
    # 1. DSINE normal estimation on the original image.
    input_rgb = input_image.convert('RGB')
    if dsine_model is not None:
        normal_tensor = estimate_normal(input_rgb, dsine_model, device)  # (1, 3, H, W)
        normal_np = normal_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()  # (H, W, 3)
        normal_image = Image.fromarray((normal_np * 255).astype(np.uint8))
    else:
        # Flat normal pointing at the camera.
        normal_image = Image.new('RGB', input_image.size, (128, 128, 255))

    # 2. Decide whether the image already has a usable alpha matte.
    has_alpha = False
    if input_image.mode == 'RGBA':
        alpha = np.array(input_image)[:, :, 3]
        if not np.all(alpha == 255):
            has_alpha = True

    if has_alpha:
        output = input_image
    else:
        # No alpha: downscale large images, then matte with rembg.
        input_image = input_image.convert('RGB')
        max_size = max(input_image.size)
        scale = min(1, 1024 / max_size)
        if scale < 1:
            input_image = input_image.resize(
                (int(input_image.width * scale), int(input_image.height * scale)),
                Image.Resampling.LANCZOS)
            # Keep the normal map aligned with the resized input.
            normal_image = normal_image.resize(
                (int(normal_image.width * scale), int(normal_image.height * scale)),
                Image.Resampling.LANCZOS)
        session = rembg.new_session('birefnet-general')
        output = rembg.remove(input_image, session=session)

    # 3. Crop to the foreground (alpha > 0.8) with a 20% margin.
    output_np = np.array(output)
    alpha = output_np[:, :, 3]
    fg = np.argwhere(alpha > 0.8 * 255)  # rows are (y, x)
    if len(fg) == 0:
        # No foreground found: keep the full frame.
        bbox_crop = (0, 0, output.width, output.height)
    else:
        bbox = (np.min(fg[:, 1]), np.min(fg[:, 0]),
                np.max(fg[:, 1]), np.max(fg[:, 0]))
        center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
        size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
        size = int(size * 1.2)
        # May extend past the image edges; PIL.crop pads with zeros
        # (transparent for RGBA), which is the desired behavior here.
        bbox_crop = (int(center[0] - size // 2), int(center[1] - size // 2),
                     int(center[0] + size // 2), int(center[1] + size // 2))

    # 4. Crop, resize, and composite RGB over black using alpha.
    output = output.crop(bbox_crop)
    output = output.resize((518, 518), Image.Resampling.LANCZOS)
    output = np.array(output).astype(np.float32) / 255
    output = output[:, :, :3] * output[:, :, 3:4]
    output = Image.fromarray((output * 255).astype(np.uint8))

    # 5. Apply the same matte/crop/resize to the normal map.
    normal_rgba = normal_image.convert('RGBA')
    alpha_img = Image.fromarray(alpha)
    normal_rgba.putalpha(alpha_img)
    normal_crop = normal_rgba.crop(bbox_crop)
    normal_crop = normal_crop.resize((518, 518), Image.Resampling.LANCZOS)
    normal_np = np.array(normal_crop).astype(np.float32) / 255
    normal_np = normal_np[:, :, :3] * normal_np[:, :, 3:4]
    normal_output = Image.fromarray((normal_np * 255).astype(np.uint8))

    return output, normal_output


def encode_image(image, image_cond_model, device):
    """Encode a PIL image into DINO-style patch tokens.

    Args:
        image: PIL.Image (converted to RGB internally).
        image_cond_model: vision backbone exposing
            model(x, is_training=True)['x_prenorm'].
        device: torch device string.

    Returns:
        Layer-normalized patch-token tensor from the backbone's
        pre-norm features.
    """
    transform = transforms.Compose([
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    image_tensor = np.array(image.convert('RGB')).astype(np.float32) / 255
    image_tensor = torch.from_numpy(image_tensor).permute(2, 0, 1).float().unsqueeze(0).to(device)
    image_tensor = transform(image_tensor)
    with torch.no_grad():
        features = image_cond_model(image_tensor, is_training=True)['x_prenorm']
        patchtokens = F.layer_norm(features, features.shape[-1:])
    return patchtokens