File size: 18,366 Bytes
4eeefd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import cv2
import torch
import numpy as np
import sys
import shutil
from datetime import datetime
import glob
import gc
import time
from pathlib import Path
from argparse import ArgumentParser
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map

sys.path.append("vggt/")

from visual_util import predictions_to_glb
from vggt.models.vggt import VGGT
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from vggt.utils.geometry import unproject_depth_map_to_point_map

from rec_utils.datasets import ARKitDataset
from PIL import Image
from torchvision import transforms as TF

val_split = ['47334096', '47895367', '41125696', '41125756', '45662926', '47429925', '42898581', '48018972', '48018387', '44358455', '45261150', '42898538', '47430490', '47334109', '45663114', '42897508', '47430475', '47332901', '42899461', '45662942', '47331964', '47204552', '45261144', '41069021', '42899736', '42899737', '47430026', '48018566', '48458489', '42444955', '42446536', '47895341', '47430034', '45663154', '47430489', '42444950', '42898862', '44358451', '47331069', '41254405', '42445028', '44358448', '48458481', '47895771', '47204566', '42898508', '47331990', '47332911', '48018358', '44358498', '41159519', '45260905', '42898854', '42446533', '47115548', '45261581', '45260899', '48018346', '47333940', '47332908', '48018386', '42897559', '42445022', '42897696', '42897541', '42446529', '47333927', '47331061', '45261190', '47331063', '41159558', '47429995', '47334110', '47333934', '47332905', '48018356', '42444953', '47334241', '47332895', '47895740', '47331333', '42446038', '42446156', '48458663', '48458657', '48458660', '47333924', '45260928', '47895536', '41125760', '42899691', '41254246', '42445991', '42445441', '45662987', '47334234', '47334367', '47430424', '44358442', '47430045', '45663105', '42897550', '47430005', '41254412', '44358532', '47331311', '42898816', '47895736', '47895738', '48458667', '47332893', '42899612', '47204605', '41142278', '42446517', '42446079', '41159553', '42899726', '42898574', '47115469', '47331963', '42899700', '47334237', '47430048', '48018957', '47334117', '42446540', '44358536', '42444954', '41125722', '41159504', '47430047', '41159566', '42897651', '47333456', '47331068', '42446519', '47333923', '47895739', '47430483', '45261142', '47430470', '45662970', '47334105', '47429922', '48018962', '41142281', '47895745', '42446546', '42897678', '47204554', '47331334', '42897667', '42897629', '42899720', '41159557', '47895556', '42897521', '42898486', '45663113', '47334093', '42899714', '45662944', '48458465', '42446137', 
'48458473', '45260898', '42445429', '47430036', '48458430', '47204559', '42898544', '47895353', '42899685', '44358505', '47430051', '45260900', '42899698', '47331316', '45260914', '48018572', '47333918', '47334238', '42899723', '44358513', '42899620', '47115460', '45261619', '47429912', '41159571', '47334362', '48458654', '42446163', '41254269', '45662975', '47331644', '41159530', '44358499', '47204609', '47333431', '41159555', '47429987', '42899688', '45662921', '47332890', '47895374', '47430001', '45261587', '45260856', '47430038', '42897599', '47332885', '42899679', '44358435', '42445966', '47895348', '48018353', '47895357', '47204573', '47333452', '45663115', '48458424', '42444976', '42444968', '42897564', '47331336', '42445448', '45260854', '42898527', '47334379', '45260925', '47430023', '47331662', '45662983', '42898826', '42899694', '42899617', '45662924', '42446049', '42899717', '48458650', '41069046', '42899699', '41254435', '47331972', '47895750', '47331339', '42446165', '41159525', '47895547', '47332899', '47895541', '42445031', '47895365', '42446535', '42899739', '45261631', '47333925', '47895554', '47430485', '47115463', '42897695', '47430468', '47333916', '47895776', '42899471', '44358446', '47334360', '47334381', '42897552', '42898868', '47333436', '48018562', '42898519', '42899680', '41254402', '47334256', '42897692', '42899725', '47331653', '41254400', '42445026', '45261588', '42899734', '45662943', '47334120', '47331314', '48018737', '48458472', '47331971', '45261193', '42446016', '45260920', '48018571', '42446056', '47333443', '41069025', '42897549', '44358515', '47115526', '42897688', '48458417', '47115474', '47430024', '47332916', '42898554', '48018732', '48018375', '47331989', '47115452', '45261615', '47334103', '41159572', '41159508', '42446541', '47115529', '44358440', '47115550', '45663165', '47895779', '47334240', '47331646', '48018970', '47430002', '42446527', '47334102', '47332000', '47895783', '47895542', '48458747', '42898570', 
'47331337', '42899613', '48018345', '48458665', '42446083', '41254382', '41125731', '48458732', '44358518', '42899696', '42897504', '41069051', '48018368', '48018741', '47429971', '47331266', '42897528', '42445981', '45663107', '42897501', '47895534', '42445029', '47430471', '47333440', '42445988', '45260903', '41159540', '42897566', '48458456', '47331651', '47332910', '47333904', '42445021', '45261575', '47895355', '45261140', '47331654', '47333920', '47895743', '45261143', '42898822', '47430479', '42446167', '47334361', '47334380', '45662981', '48018966', '44358436', '47334252', '41254432', '48458647', '48018560', '47334107', '47895549', '45261632', '45261128', '47895350', '44358538', '41159534', '42899611', '42898521', '47331988', '42899729', '48458656', '47115525', '42897538', '42897545', '47331970', '42897647', '42897554', '47430003', '47332904', '41159541', '48018379', '42897526', '41069043', '47331319', '47895371', '42446104', '41159538', '42898818', '48018956', '42899619', '48018381', '41069042', '48458735', '45261182', '42446151', '42898869', '47334368', '47333899', '47430033', '41125718', '47331645', '44358584', '48018739', '45261179', '47333931', '47333898', '42898817', '47332918', '45261121', '42446522', '45261637', '48018559', '45663164', '47332005', '41254386', '47331265', '45663175', '42898497', '48018367', '47429904', '41254262', '47115543', '41254425', '48458652', '42445984', '41069050', '48018960', '42898811', '41069048', '47895364', '48018382', '42446103', '48458427', '45260857', '42899731', '47895782', '47430419', '42446093', '47429913', '47332915', '44358452', '47333457', '47334091', '45261133', '42446532', '47895735', '47204607', '47204556', '47334115', '41254441', '42897561', '48458484', '47429998', '42446116', '47331071', '45261594', '47333937', '47204575', '47333932', '47331661', '47895732', '47332004', '42445998', '47429914', '44358582', '48018361', '47204563', '41125700', '42899690', '41159529', '41125763', '47115473', '48458415', 
'47204578', '47331668', '45261185', '47430043', '42446114', '47430422', '47331324', '42444949', '47334372', '45663150', '42444966', '42444946', '41125709', '48018360', '47429975', '42898867', '45261129', '47333435', '42899712', '48018730', '47429992', '42897542', '48018372', '41254398', '47429906', '41159503', '47332886', '42897672', '47331064', '47334239', '47333441', '45261181', '48018347', '45662979', '47895777', '45663149', '47895552', '47331974', '47331322', '47334254', '48458428', '42898849', '41142280', '44358583', '45261620', '47429977', '47430007', '42899459', '42446100', '45663099', '47331262', '47331331']

def _scaled_side(side, scale, multiple=14):
    """Scale ``side`` by ``scale`` and round to the nearest multiple of ``multiple``.

    The result is clamped to at least one ``multiple`` so that extreme aspect
    ratios cannot produce a zero-sized dimension (which PIL refuses to resize to).
    """
    return max(multiple, round(side * scale / multiple) * multiple)


def _pad_center_white(img, target_height, target_width):
    """Pad a (C, H, W) tensor with 1.0 so it is centred in the target size.

    The tensor is returned unchanged when it already reaches the target in
    both dimensions.
    """
    h_padding = target_height - img.shape[1]
    w_padding = target_width - img.shape[2]
    if h_padding <= 0 and w_padding <= 0:
        return img

    pad_top = h_padding // 2
    pad_bottom = h_padding - pad_top
    pad_left = w_padding // 2
    pad_right = w_padding - pad_left
    # Pad with white (value=1.0); F.pad takes (left, right, top, bottom).
    return torch.nn.functional.pad(
        img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
    )


def load_and_preprocess_images(image_list, mode="crop"):
    """
    Convert in-memory images into a single batched tensor for model input.

    Args:
        image_list (sequence): Images as arrays accepted by ``PIL.Image.fromarray``
            (not file paths). RGBA inputs are composited onto a white background.
        mode (str, optional): Preprocessing mode, either "crop" or "pad".
            - "crop" (default): Resizes so that width is 518px and height follows
              the aspect ratio. NOTE: despite the name, no cropping is performed;
              the center crop of the height is disabled in this script, so the
              height may exceed 518px.
            - "pad": Resizes so that the largest dimension is 518px and pads the
              smaller dimension with white to reach a 518x518 square.

    Returns:
        torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)
        and values in [0, 1].

    Raises:
        ValueError: If the input list is empty or if mode is invalid

    Notes:
        - Resized dimensions are rounded to a multiple of 14 (and never below 14).
        - If the processed images end up with different shapes, a warning is
          printed and every image is centre-padded with white (value=1.0) to the
          largest height and width found.
    """
    # Check for empty list (len() rather than truthiness: input may be an array)
    if len(image_list) == 0:
        raise ValueError("At least 1 image is required")

    # Validate mode
    if mode not in ("crop", "pad"):
        raise ValueError("Mode must be either 'crop' or 'pad'")

    to_tensor = TF.ToTensor()
    target_size = 518
    images = []

    for image in image_list:
        img = Image.fromarray(image)

        # If there's an alpha channel, blend onto a white background first so
        # transparent areas become white rather than black.
        if img.mode == "RGBA":
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, img)
        img = img.convert("RGB")

        width, height = img.size
        if mode == "pad" and height > width:
            # Portrait image in pad mode: the height is the largest dimension.
            new_height = target_size
            new_width = _scaled_side(width, new_height / height)
        else:
            # "crop" mode, or landscape/square image in pad mode: fix the width.
            new_width = target_size
            new_height = _scaled_side(height, new_width / width)

        # PIL's resize takes (width, height)
        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
        img = to_tensor(img)  # Convert to tensor (0, 1)

        if mode == "pad":
            img = _pad_center_white(img, target_size, target_size)

        images.append(img)

    # Images of different shapes cannot be stacked, so pad them to a common size.
    shapes = {(img.shape[1], img.shape[2]) for img in images}
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")
        max_height = max(shape[0] for shape in shapes)
        max_width = max(shape[1] for shape in shapes)
        images = [_pad_center_white(img, max_height, max_width) for img in images]

    # torch.stack always adds the batch dimension, so a single image already
    # yields a (1, C, H, W) tensor.
    return torch.stack(images)



# -------------------------------------------------------------------------
# 1) Core model inference
# -------------------------------------------------------------------------
def run_model(model, scene, device, max_images) -> dict:
    """
    Run the model on the frames of ``scene`` and return its predictions.

    Args:
        model: Callable invoked as ``model(images)``; it must return a dict that
            contains a "pose_enc" entry.
        scene: Scene object providing ``filter_valid_poses()``, ``images``,
            ``frames``, ``id`` and ``len()``. Each frame must expose ``image``
            (passed to ``load_and_preprocess_images``) and ``image_path``.
        device: Device the preprocessed images are moved to.
        max_images (int): Upper bound on the number of frames used. When the
            scene is longer, frames are taken at evenly spaced indices.

    Returns:
        dict: The model's predictions with every tensor converted to a NumPy
        array with the leading batch dimension removed, plus "poses" and "Ks"
        (decoded from "pose_enc") and "image_names" (the paths of the frames
        that were actually used).

    Raises:
        ValueError: If CUDA is not available or the scene has no images.
    """
    if not torch.cuda.is_available():
        raise ValueError("CUDA is not available. Check your environment.")

    scene.filter_valid_poses()

    print(f"Found {len(scene.images)} images")
    frames = scene.frames
    if len(scene.images) == 0:
        raise ValueError(f"No images found at {scene.id}. Check your upload.")
    if len(scene) > max_images:
        print(f"Downsampling {len(scene)} images to {max_images} images")
        # Evenly spaced indices that always include the first and last frame.
        keep = np.linspace(0, len(scene) - 1, max_images).round().astype(int)
        frames = [scene.frames[i] for i in keep]

    images = load_and_preprocess_images([frame.image for frame in frames]).to(device)
    print(f"Preprocessed images shape: {images.shape}")

    # Run inference
    print("Running inference...")
    # Query the GPU the images actually live on; the no-argument form reports
    # the *current* default device, which can differ on multi-GPU machines.
    capability_device = images.device if images.device.type == "cuda" else None
    major, _ = torch.cuda.get_device_capability(capability_device)
    # The original code's choice: bfloat16 for compute capability >= 8, else float16.
    dtype = torch.bfloat16 if major >= 8 else torch.float16

    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=dtype):
        predictions = model(images)

    # Convert pose encoding to extrinsic and intrinsic matrices
    print("Converting pose encoding to extrinsic and intrinsic matrices...")
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
    predictions["poses"] = extrinsic
    predictions["Ks"] = intrinsic

    # Convert tensors to numpy (values are replaced in place; no keys are added
    # or removed while iterating).
    for key, value in predictions.items():
        if isinstance(value, torch.Tensor):
            predictions[key] = value.cpu().numpy().squeeze(0)  # remove batch dimension

    # NOTE: world points are not re-derived from the depth map here.

    # Clean up
    torch.cuda.empty_cache()
    predictions["image_names"] = [frame.image_path for frame in frames]
    return predictions

def process_scene(
    model,
    scene_name,
    scene,
    output_dir,
    device,
    max_images=10000,
    force=False
):
    """
    Run the model on one scene and save the result to ``output_dir/predictions.npz``.

    Args:
        model: Model forwarded to ``run_model``.
        scene_name: Name used in log messages only.
        scene: Scene object forwarded to ``run_model``.
        output_dir (str | Path): Existing directory that receives ``predictions.npz``.
        device: Device forwarded to ``run_model``.
        max_images (int): Maximum number of frames forwarded to ``run_model``.
        force (bool): Recompute even if ``predictions.npz`` already exists.

    Returns:
        None. The scene is skipped (nothing is computed) when the output file
        already exists and ``force`` is false.
    """
    output_dir = Path(output_dir)
    predictions_path = output_dir / "predictions.npz"

    if not force and predictions_path.exists():
        print(f"Skipping scene {scene_name} because it already exists")
        return

    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    print("Running run_model...")
    with torch.no_grad():
        predictions = run_model(model, scene, device, max_images)

    # The input images are not saved; tolerate models that do not return them.
    predictions.pop("images", None)

    # Write to a temporary file and rename it into place, so an interrupted run
    # never leaves a truncated predictions.npz that later runs would skip over.
    # The temporary name keeps the ".npz" suffix because np.savez appends it
    # to any file name that lacks it.
    tmp_path = output_dir / "predictions.tmp.npz"
    np.savez(tmp_path, **predictions)
    tmp_path.replace(predictions_path)

    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    print(f"Finished scene {scene_name} in {time.time() - start_time:.1f}s")

import pickle

# NOTE: everything below runs at import time, not only when executed as a script.
# NOTE(review): the variable is called `val_path`, but the file name refers to
# the *train* infos — confirm this is intended.
val_path = Path("../") / "Indoor/OKNO/data/arkitscenes/arkitscenes_offline_infos_train.pkl"
# NOTE(review): `out_dir` is not referenced anywhere in the visible code.
out_dir = Path("data/arkit_gt/processed")
# SECURITY: pickle.load can execute arbitrary code; only load trusted files.
with open(val_path, "rb") as f:
    data = pickle.load(f)

data_list = data["data_list"]
# This rebinds `val_split`, replacing the literal list of scene ids defined
# earlier in the file with the first 2500 entries of the infos file.
val_split = [scene["lidar_points"]["lidar_path"] for scene in data_list][:2500]
# Keep only the part of each lidar path before the first underscore
# (presumably the scene id — verify against the infos file).
val_split = [a.split("_")[0] for a in val_split]
print(val_split)
if __name__ == "__main__":
    # Script entry point: run the model on each selected scene and write
    # <output_dir>/<scene_name>/predictions.npz, recording failed scene names
    # in a timestamped file under logs/.
    import traceback

    parser = ArgumentParser()
    parser.add_argument("--scene_names", nargs="+", default=val_split)
    parser.add_argument("--input_dir", type=str, default='/workspace-SR006.nfs2/datasets/arkitscenes/offline_prepared_data/posed_images/')
    parser.add_argument("--output_dir", type=str, default='output/arkit_new')
    parser.add_argument("--max_images", type=int, default=100)
    # NOTE(review): --conf_thres is parsed but not used below; kept so existing
    # command lines keep working.
    parser.add_argument("--conf_thres", type=float, default=3.0)
    parser.add_argument("--job_num", "-n", type=int, default=1)
    parser.add_argument("--job_id", "-i", type=int, default=0)
    parser.add_argument("--device", type=str, default="2")
    parser.add_argument("--force", action="store_true")
    args = parser.parse_args()

    # Reject shard settings that would crash the slice below (job_num == 0) or
    # silently select the wrong scenes or none at all.
    if args.job_num < 1 or not 0 <= args.job_id < args.job_num:
        parser.error("--job_num must be >= 1 and --job_id must satisfy 0 <= job_id < job_num")

    model = VGGT()
    _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
    model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
    model.eval()

    # Shard the scene list: job i of n takes every n-th scene starting at i.
    scene_names = args.scene_names[args.job_id::args.job_num]
    device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"

    model = model.to(device)
    errors_path = Path(f"logs/errors_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt")

    dataset = ARKitDataset(args.input_dir)
    for scene_name in tqdm(scene_names):
        print(f"Processing scene {scene_name}")
        try:
            scene = dataset[scene_name]
            output_dir = Path(args.output_dir) / scene_name
            output_dir.mkdir(parents=True, exist_ok=True)
            process_scene(
                model,
                scene_name,
                scene,
                output_dir,
                device=device,
                max_images=args.max_images,
                force=args.force,
            )
        except Exception as e:
            # Best effort: one failing scene must not stop the batch. Print the
            # full traceback so the failure can be diagnosed, and record the
            # scene name (one per line) so the failed scenes can be re-run.
            print(f"Error processing scene {scene_name}: {e}")
            traceback.print_exc()
            errors_path.parent.mkdir(parents=True, exist_ok=True)
            with open(errors_path, "a") as f:
                f.write(f"{scene_name}\n")