File size: 8,050 Bytes
6b92ff7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import os
import torch
import numpy as np
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms
import rembg

_SUPPORTED_IMAGE_EXTS = {
    '.png', '.jpg', '.jpeg', '.webp', '.bmp', '.tif', '.tiff'
}

def _expand_image_inputs(image_path: str) -> tuple[list[str], bool]:
    """Return (image_paths, is_directory).

    If image_path is a directory, returns all supported images under it (non-recursive),
    sorted by filename. Otherwise returns [image_path].
    """
    if image_path is None:
        raise ValueError('image_path is None')

    image_path = str(image_path)
    if os.path.isdir(image_path):
        entries = []
        for name in sorted(os.listdir(image_path)):
            full = os.path.join(image_path, name)
            if not os.path.isfile(full):
                continue
            ext = os.path.splitext(name)[1].lower()
            if ext in _SUPPORTED_IMAGE_EXTS:
                entries.append(full)
        return entries, True

    return [image_path], False

def load_dsine(device='cuda'):
    """Load the DSINE surface-normal model onto `device` in eval mode.

    Prefers the local checkpoint at 'ckpts/dsine/dsine.pt'; when the file is
    missing, falls back to torch.hub ("hugoycj/DSINE-hub").

    Args:
        device: torch device string the model is moved to.

    Returns:
        The DSINE model, in eval mode.

    Raises:
        ValueError: if neither the local checkpoint nor the hub fallback
            could produce a model.
    """
    # Imported lazily to avoid circular imports / sys.path ordering issues
    # at module import time. Based on test_minimal.py in the DSINE repo.
    from models.dsine.v02 import DSINE_v02 as DSINE

    # Manually defined args since projects.dsine.config is missing.
    class Args:
        def __init__(self):
            self.NNET_architecture = 'v02'
            self.NNET_encoder_B = 5
            self.NNET_decoder_NF = 2048
            self.NNET_decoder_BN = False
            self.NNET_decoder_down = 8
            self.NNET_learned_upsampling = True
            self.NRN_prop_ps = 5
            self.NRN_num_iter_train = 5
            self.NRN_num_iter_test = 5
            self.NRN_ray_relu = True
            self.NNET_output_dim = 3
            self.NNET_output_type = 'R'
            self.NNET_feature_dim = 64
            self.NNET_hidden_dim = 64

    model = DSINE(Args()).to(device)

    ckpt_path = 'ckpts/dsine/dsine.pt'
    if os.path.exists(ckpt_path):
        print(f"Loading DSINE checkpoint from {ckpt_path}")
        state_dict = torch.load(ckpt_path, map_location='cpu')
        # Some checkpoints wrap the weights under a 'model' key.
        if 'model' in state_dict:
            state_dict = state_dict['model']
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        return model

    print(f"DSINE checkpoint not found at {ckpt_path}. Trying torch.hub...")
    try:
        # The hub entry point builds and returns a ready model directly, so
        # the locally constructed `model` above is simply discarded.
        model = torch.hub.load("hugoycj/DSINE-hub", "DSINE", trust_repo=True)
        model.to(device)
        model.eval()
        return model
    except Exception as e:
        print(f"Failed to load DSINE from hub: {e}")
        # Chain the original exception so the root cause is not lost.
        raise ValueError("Could not load DSINE model.") from e

def intrins_from_fov(new_fov, H, W, device):
    """Build a 3x3 pinhole intrinsics matrix from a horizontal field of view.

    Args:
        new_fov: horizontal field of view in degrees.
        H, W: image height and width in pixels.
        device: torch device for the returned matrix.

    Returns:
        (3, 3) float tensor [[f, 0, cx], [0, f, cy], [0, 0, 1]] with the
        principal point at the image center.
    """
    # Focal length from FOV: f = (W/2) / tan(fov/2). Computed with plain
    # floats — the previous version built a torch.tensor from a list that
    # contained a 0-dim tensor, which emits a copy-construction warning and
    # forces a device sync on CUDA.
    f = 0.5 * W / float(np.tan(0.5 * np.deg2rad(new_fov)))
    cx = 0.5 * W
    cy = 0.5 * H
    return torch.tensor(
        [[f, 0.0, cx], [0.0, f, cy], [0.0, 0.0, 1.0]], device=device
    )

def estimate_normal(image, model, device='cuda'):
    """Predict surface normals for a PIL RGB image with a DSINE model.

    Returns a (1, 3, H, W) tensor with normal components mapped to [0, 1]
    and the X axis flipped relative to the raw model output.
    """
    width, height = image.size

    # HWC uint8 -> NCHW float in [0, 1] on the target device.
    batch = torch.from_numpy(np.array(image)).float() / 255.0
    batch = batch.permute(2, 0, 1).unsqueeze(0).to(device)

    # Standard ImageNet normalization.
    batch = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )(batch)

    # The network expects spatial dims divisible by 32; pad right/bottom only.
    pad_right = (32 - width % 32) % 32
    pad_bottom = (32 - height % 32) % 32
    batch = F.pad(batch, (0, pad_right, 0, pad_bottom), mode='constant', value=0)

    # Intrinsics assuming a 60-degree FOV. Padding is applied only on the
    # right/bottom edges, so the principal point needs no offset.
    intrins = intrins_from_fov(60.0, height, width, device).unsqueeze(0)

    with torch.no_grad():
        prediction = model(batch, intrins=intrins)[-1]

    # Undo the padding, flip the X axis, then map [-1, 1] -> [0, 1].
    prediction = prediction[:, :, :height, :width]
    prediction[:, 0, :, :] = -prediction[:, 0, :, :]
    return (prediction + 1) / 2.0  # (1, 3, H, W)

def preprocess_image(input_image, dsine_model=None, device='cuda'):
    """Prepare an image and its estimated normal map for the downstream model.

    Pipeline: estimate normals (DSINE), remove the background with rembg
    unless a non-trivial alpha channel already exists, crop both images to a
    square around the foreground with a 1.2x margin, resize to 518x518, and
    premultiply RGB by alpha so the background is black.

    Args:
        input_image: PIL image. An RGBA alpha that is not fully opaque is
            treated as an existing foreground mask.
        dsine_model: optional DSINE model; when None a constant
            (128, 128, 255) normal image (facing the camera) is used instead.
        device: torch device used for normal estimation.

    Returns:
        Tuple (output, normal_output) of 518x518 PIL RGB images.
    """
    # 1. DSINE Normal Estimation on Original Image
    input_rgb = input_image.convert('RGB')
    if dsine_model is not None:
        normal_tensor = estimate_normal(input_rgb, dsine_model, device) # (1, 3, H, W)
        normal_np = normal_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() # (H, W, 3)
        normal_image = Image.fromarray((normal_np * 255).astype(np.uint8))
    else:
        # Placeholder: flat normal pointing toward the viewer.
        normal_image = Image.new('RGB', input_image.size, (128, 128, 255))

    # 2. Foreground mask: reuse the alpha channel if it carries information
    # (i.e. it is not fully opaque everywhere).
    has_alpha = False
    if input_image.mode == 'RGBA':
        alpha = np.array(input_image)[:, :, 3]
        if not np.all(alpha == 255):
            has_alpha = True
    if has_alpha:
        output = input_image
    else:
        # No usable alpha: cap the longest side at 1024 px, then run rembg
        # background removal to synthesize an alpha channel.
        input_image = input_image.convert('RGB')
        max_size = max(input_image.size)
        scale = min(1, 1024 / max_size)
        if scale < 1:
            input_image = input_image.resize((int(input_image.width * scale), int(input_image.height * scale)), Image.Resampling.LANCZOS)
            # Also resize normal image if we resized input
            normal_image = normal_image.resize((int(normal_image.width * scale), int(normal_image.height * scale)), Image.Resampling.LANCZOS)
        
        session = rembg.new_session('birefnet-general')
        output = rembg.remove(input_image, session=session)
        
    # 3. Square crop around pixels with alpha > ~80%, expanded by 1.2x.
    output_np = np.array(output)
    alpha = output_np[:, :, 3]
    bbox = np.argwhere(alpha > 0.8 * 255)
    if len(bbox) == 0:
        # Empty mask: fall back to the full image. (The list assigned to
        # `bbox` here is never read again; `bbox_crop` drives the crop.)
        bbox = [0, 0, output.height, output.width]
        bbox_crop = (0, 0, output.width, output.height)
    else:
        # np.argwhere yields (row, col) pairs; reorder into (x0, y0, x1, y1).
        bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
        center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
        size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
        size = int(size * 1.2)
        # NOTE: this box may extend past the image borders; PIL's crop fills
        # out-of-bounds regions with zeros (transparent black), which the
        # later alpha premultiplication keeps black.
        bbox_crop = (int(center[0] - size // 2), int(center[1] - size // 2), int(center[0] + size // 2), int(center[1] + size // 2))
    
    # 4. Crop, resize to the model's 518x518 input, premultiply RGB by alpha.
    output = output.crop(bbox_crop)
    output = output.resize((518, 518), Image.Resampling.LANCZOS)
    output = np.array(output).astype(np.float32) / 255
    output = output[:, :, :3] * output[:, :, 3:4]
    output = Image.fromarray((output * 255).astype(np.uint8))

    # Process Normal
    # Apply the same mask, crop, and resize so the normal map stays
    # pixel-aligned with the RGB output.
    normal_rgba = normal_image.convert('RGBA')
    
    # Create alpha mask image
    alpha_img = Image.fromarray(alpha)
    normal_rgba.putalpha(alpha_img)
    
    normal_crop = normal_rgba.crop(bbox_crop)
    normal_crop = normal_crop.resize((518, 518), Image.Resampling.LANCZOS)
    
    normal_np = np.array(normal_crop).astype(np.float32) / 255
    normal_np = normal_np[:, :, :3] * normal_np[:, :, 3:4]
    normal_output = Image.fromarray((normal_np * 255).astype(np.uint8))

    return output, normal_output

def encode_image(image, image_cond_model, device):
    """Encode a PIL image into layer-normalized patch tokens.

    Args:
        image: PIL image (converted to RGB internally).
        image_cond_model: conditioning backbone; called with
            is_training=True and expected to return a dict containing
            'x_prenorm' features.
        device: torch device for the input tensor.

    Returns:
        The 'x_prenorm' features layer-normalized over the last dimension.
    """
    # HWC uint8 -> NCHW float in [0, 1] on the target device.
    pixels = np.array(image.convert('RGB')).astype(np.float32) / 255
    batch = torch.from_numpy(pixels).permute(2, 0, 1).float().unsqueeze(0).to(device)

    # Standard ImageNet normalization (Normalize applied directly — the
    # original single-element Compose added nothing).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    batch = normalize(batch)

    with torch.no_grad():
        tokens = image_cond_model(batch, is_training=True)['x_prenorm']
        return F.layer_norm(tokens, tokens.shape[-1:])