| import subprocess |
| import os |
| import torch |
| import sys |
|
|
def install_cuda_toolkit():
    """Download and silently install the CUDA 12.2 toolkit, then export its paths.

    The NVIDIA ``.run`` installer is fetched into ``/tmp`` with ``wget``, made
    executable and executed with ``--silent --toolkit``.  Afterwards the
    environment of the current process is updated: ``CUDA_HOME``, ``PATH``,
    ``LD_LIBRARY_PATH`` and ``TORCH_CUDA_ARCH_LIST``.

    The exit codes of the three subprocesses are deliberately not checked
    (best effort), so a failed download or installation does not raise here.

    Raises:
        KeyError: if ``PATH`` is not present in the environment.
        OSError: if ``wget``, ``chmod`` or the installer cannot be started.
    """
    cuda_toolkit_url = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
    cuda_toolkit_file = f"/tmp/{os.path.basename(cuda_toolkit_url)}"
    subprocess.call(["wget", "-q", cuda_toolkit_url, "-O", cuda_toolkit_file])
    subprocess.call(["chmod", "+x", cuda_toolkit_file])
    subprocess.call([cuda_toolkit_file, "--silent", "--toolkit"])

    cuda_home = "/usr/local/cuda"
    os.environ["CUDA_HOME"] = cuda_home
    os.environ["PATH"] = f"{cuda_home}/bin:{os.environ['PATH']}"

    # Append the previous value only when there is one.  A trailing ":" would
    # leave an empty entry in LD_LIBRARY_PATH, which the dynamic loader treats
    # as the current working directory.
    cuda_lib = f"{cuda_home}/lib"
    previous_ld_path = os.environ.get("LD_LIBRARY_PATH")
    os.environ["LD_LIBRARY_PATH"] = (
        f"{cuda_lib}:{previous_ld_path}" if previous_ld_path else cuda_lib
    )

    # Pin the CUDA architectures used when building torch extensions.
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
|
|
| |
| |
| |
| |
| |
| |
| |
| |
# Install iopath at start-up through a shell call to pip.
os.system('pip install iopath')

# Build a "py3<minor>_cu<cuda>_pyt<torch>" tag from the running interpreter
# and the installed torch build (e.g. "py310_cu121_pyt210").
# NOTE(review): version_str is not read by any code visible in this part of
# the file; it is kept because other code may rely on the module-level name.
pyt_version_str = torch.__version__.split("+")[0].replace(".", "")
version_str = f"py3{sys.version_info.minor}_cu{torch.version.cuda.replace('.', '')}_pyt{pyt_version_str}"
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import spaces |
| import mast3r.utils.path_to_dust3r |
| import dust3r.utils.path_to_croco |
| import mast3r.utils.path_to_dust3r |
| import sys |
| import os.path as path |
| import torch |
| import tempfile |
| import gradio |
| import shutil |
| import math |
| from mast3r.model import AsymmetricMASt3R |
| import matplotlib.pyplot as pl |
| from dust3r.utils.image import load_images |
| import torch.nn.functional as F |
| from dust3r.utils.geometry import xy_grid |
| import numpy as np |
| import cv2 |
| from dust3r.utils.device import to_numpy |
| import trimesh |
| from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes |
| from scipy.spatial.transform import Rotation |
|
|
|
|
# Switch matplotlib's pyplot into interactive mode.
pl.ion()

# Allow TF32 for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True
batch_size = 1
inf = float('inf')

# Single source of truth for where the networks live.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the pretrained FLARE weights straight away.  Building a randomly
# initialised AsymmetricMASt3R(...) first is unnecessary: that instance was
# overwritten by from_pretrained() on the very next line and only cost
# start-up time and memory.
model = AsymmetricMASt3R.from_pretrained("zhang3z/FLARE").to(device).eval()

image_size = 512  # `size` argument handed to load_images
silent = True  # suppresses verbose loading and export messages
# Passed as both elements of gradio.Blocks(delete_cache=...) below.
gradio_delete_cache = 7200

# DINOv2 backbone used to pick the reference frame.
backbone = torch.hub.load(
    "facebookresearch/dinov2", "dinov2_vitb14_reg"
)
# Follow `device` instead of hard-coding .cuda() so the CPU fallback chosen
# above is honoured here as well.
backbone = backbone.eval().to(device)
|
|
class FileState:
    """Own an output file path and delete that file when the object dies.

    NOTE: this module defines a class with the same name and body again
    further down; that later definition is the one bound to ``FileState``
    once the module has finished loading.
    """

    def __init__(self, outfile_name=None):
        # Path of the owned file, or None when nothing is owned.
        self.outfile_name = outfile_name

    def __del__(self):
        owned_path = self.outfile_name
        is_existing_file = owned_path is not None and os.path.isfile(owned_path)
        if is_existing_file:
            os.remove(owned_path)
        # Forget the path whether or not a file was removed.
        self.outfile_name = None
|
|
def pad_to_square(reshaped_image):
    """Zero-pad a 4-D image batch so that its last two dimensions are equal.

    The shorter of the two trailing dimensions is padded on both sides up to
    the length of the longer one; when the missing amount is odd, the extra
    row/column goes to the bottom/right.

    Args:
        reshaped_image: tensor with exactly four dimensions; the last two are
            treated as height and width.

    Returns:
        The padded tensor, square in its last two dimensions.
    """
    _, _, height, width = reshaped_image.shape
    side = max(height, width)
    missing_w = side - width
    missing_h = side - height
    left = missing_w // 2
    top = missing_h // 2
    # F.pad expects (left, right, top, bottom) for the last two dimensions.
    return F.pad(
        reshaped_image,
        (left, missing_w - left, top, missing_h - top),
        mode='constant',
        value=0,
    )
|
|
def generate_rank_by_dino(
    reshaped_image, backbone, query_frame_num, image_size=336
):
    """Return the index of the frame most similar to all other frames.

    The frames are padded to a square, resized to ``image_size`` x
    ``image_size``, normalised and passed through ``backbone``.  Patch-token
    similarities between every pair of frames are averaged, self-similarity
    is masked out, and the frame with the largest summed similarity wins.

    Args:
        reshaped_image: batch of frames as a 4-D tensor (one frame per entry
            of dim 0) — presumably values in [0, 1]; confirm against caller.
        backbone: callable invoked as ``backbone(x, is_training=True)`` whose
            result provides an ``"x_norm_patchtokens"`` entry.
        query_frame_num: not used by this function; kept so existing callers
            keep working.
        image_size: side length the frames are resized to before the backbone.

    Returns:
        int: index (along dim 0 of ``reshaped_image``) of the selected frame.
    """
    # Only an integer index leaves this function, so no autograd graph is
    # needed for the backbone forward pass.
    with torch.no_grad():
        # Pad first so the square resize below keeps the aspect ratio.  The
        # padded tensor used to be computed and then discarded because the
        # resize was fed the unpadded input.
        rgbs = pad_to_square(reshaped_image)
        rgbs = F.interpolate(
            rgbs,
            (image_size, image_size),
            mode="bilinear",
            align_corners=True,
        )
        rgbs = _resnet_normalize_image(rgbs.cuda())

        # Patch-level features; assumed layout (frames, patches, channels).
        frame_feat = backbone(rgbs, is_training=True)["x_norm_patchtokens"]
        frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)

        # One frames-by-frames similarity matrix per patch, then their mean.
        frame_feat_norm = frame_feat_norm.permute(1, 0, 2)
        similarity_matrix = torch.bmm(
            frame_feat_norm, frame_feat_norm.transpose(-1, -2)
        )
        similarity_matrix = similarity_matrix.mean(dim=0)

        # Ignore self-pairing.
        similarity_matrix.fill_diagonal_(-100)

        similarity_sum = similarity_matrix.sum(dim=1)
        return torch.argmax(similarity_sum).item()
|
|
# Per-channel mean / standard deviation used to normalise RGB inputs.
_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]
# Keep the statistics on the GPU when one is available (as before), but fall
# back to the CPU so that importing this module does not fail without CUDA.
_resnet_stats_device = "cuda" if torch.cuda.is_available() else "cpu"
_resnet_mean = torch.tensor(_RESNET_MEAN).view(1, 3, 1, 1).to(_resnet_stats_device)
_resnet_std = torch.tensor(_RESNET_STD).view(1, 3, 1, 1).to(_resnet_stats_device)


def _resnet_normalize_image(img: torch.Tensor) -> torch.Tensor:
    """Normalise ``img`` channel-wise with the ``_RESNET_MEAN``/``_RESNET_STD`` statistics.

    Args:
        img: tensor broadcastable against shape ``(1, 3, 1, 1)``, i.e. with
            three channels in dim 1 of a 4-D batch.

    Returns:
        ``(img - mean) / std`` on the same device as ``img``.
    """
    # Follow the input's device; this is a no-op when it already matches.
    mean = _resnet_mean.to(img.device)
    std = _resnet_std.to(img.device)
    return (img - mean) / std
|
|
def calculate_index_mappings(query_index, S, device=None):
    """Build an index order that swaps position ``query_index`` with position 0.

    Args:
        query_index: position whose content should end up first.
        S: total number of positions.
        device: optional device the resulting tensor is moved to.

    Returns:
        A 1-D tensor equal to ``arange(S)`` except that entry 0 holds
        ``query_index`` and entry ``query_index`` holds 0.
    """
    permutation = torch.arange(S)
    permutation[0] = query_index
    permutation[query_index] = 0
    return permutation if device is None else permutation.to(device)
|
|
def _convert_scene_output_to_glb(outfile, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
                                 cam_color=None, as_pointcloud=False,
                                 transparent_cams=False, silent=False):
    """Assemble a trimesh scene from per-view points and cameras and export it.

    Args:
        outfile: destination handed to ``scene.export(file_obj=...)``.
        imgs: per-view colour images, indexed like ``pts3d``.
        pts3d: per-view 3-D points.
        mask: per-view boolean masks selecting the points to keep.
        focals: one focal length per camera.
        cams2world: one camera-to-world pose per camera; the first one also
            defines the frame the exported scene is expressed in.
        cam_size: ``screen_width`` used when drawing each camera.
        cam_color: a list with one colour per camera, a single colour, or
            None/falsy to cycle through ``CAM_COLORS``.
        as_pointcloud: export one coloured point cloud instead of a mesh.
        transparent_cams: when True, cameras are drawn without their image.
        silent: suppress the export message.

    Returns:
        ``outfile``, unchanged.
    """
    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
    pts3d, imgs, focals, mask, cams2world = (
        to_numpy(item) for item in (pts3d, imgs, focals, mask, cams2world)
    )

    scene = trimesh.Scene()

    if as_pointcloud:
        # Gather the masked points and their colours from every view.
        xyz = np.concatenate(
            [view_pts[view_mask] for view_pts, view_mask in zip(pts3d, mask)]
        ).reshape(-1, 3)
        rgb = np.concatenate(
            [view_img[view_mask] for view_img, view_mask in zip(imgs, mask)]
        ).reshape(-1, 3)
        # Drop points with a non-finite coordinate.
        finite_rows = np.isfinite(xyz.sum(axis=1))
        scene.add_geometry(trimesh.PointCloud(xyz[finite_rows], colors=rgb[finite_rows]))
    else:
        # One mesh per view, merged into a single mesh afterwards.
        view_meshes = []
        for view_idx, view_img in enumerate(imgs):
            view_pts = pts3d[view_idx].reshape(view_img.shape)
            view_mask = mask[view_idx] & np.isfinite(view_pts.sum(axis=-1))
            view_meshes.append(pts3d_to_trimesh(view_img, view_pts, view_mask))
        scene.add_geometry(trimesh.Trimesh(**cat_meshes(view_meshes)))

    # Add one camera per pose.
    for cam_idx, pose_c2w in enumerate(cams2world):
        edge_color = (
            cam_color[cam_idx]
            if isinstance(cam_color, list)
            else (cam_color or CAM_COLORS[cam_idx % len(CAM_COLORS)])
        )
        cam_image = None if transparent_cams else imgs[cam_idx]
        add_scene_cam(scene, pose_c2w, edge_color,
                      cam_image, focals[cam_idx],
                      imsize=imgs[cam_idx].shape[1::-1], screen_width=cam_size)

    # Express the scene relative to the first camera, including the OPENGL
    # matrix and a 180 degree turn about the y axis.
    half_turn = np.eye(4)
    half_turn[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ half_turn))

    if not silent:
        print('(exporting 3D scene to', outfile, ')')

    scene.export(file_obj=outfile)
    return outfile
|
|
|
|
class FileState:
    """Own an output file path and delete that file when the object dies.

    NOTE: this is the second definition of ``FileState`` in this module and
    replaces the identical one defined earlier.
    """

    def __init__(self, outfile_name=None):
        # Path of the owned file, or None when nothing is owned.
        self.outfile_name = outfile_name

    def __del__(self):
        # Release ownership first so the attribute is cleared even when the
        # removal below fails.
        owned_path, self.outfile_name = self.outfile_name, None
        if owned_path is None:
            return
        try:
            if os.path.isfile(owned_path):
                os.remove(owned_path)
        except FileNotFoundError:
            # The file vanished between the check and the removal; there is
            # nothing left to clean up, and raising inside __del__ would only
            # produce an "Exception ignored" message.
            pass
|
|
|
|
@spaces.GPU(duration=180)
def local_get_reconstructed_scene(inputfiles, min_conf_thr, cam_size):
    """Reconstruct a scene from the uploaded images and export it as a GLB file.

    Steps: load the images, move the frame picked by ``generate_rank_by_dino``
    to the front, run the model, threshold the predicted confidence, estimate
    one shared focal length from the first view, recover a pose per view with
    PnP + RANSAC, and export a coloured point cloud with the cameras.

    Args:
        inputfiles: uploaded image files, forwarded to ``load_images``.
        min_conf_thr: quantile in [0, 1]; per view, points whose confidence is
            not above this quantile are masked out.
        cam_size: size of the cameras drawn in the exported scene.

    Returns:
        Path of the exported ``.glb`` file inside a fresh temporary directory.

    Raises:
        gradio.Error: if no input files were provided.
    """
    # NOTE(review): an unused function-scope `from pytorch3d.ops import
    # knn_points`, an unused CUDA copy of all points and a "Cleaning point
    # cloud with knn..." message used to live here although no kNN cleaning
    # is performed; they were dropped so a missing pytorch3d install cannot
    # break the demo.
    if not inputfiles:
        raise gradio.Error("Please upload input images before running the reconstruction.")

    outdir = tempfile.mkdtemp(suffix='_FLARE_gradio_demo')
    batch = load_images(inputfiles, size=image_size, verbose=not silent)

    # Choose the reference frame on images rescaled with x / 2 + 0.5 and move
    # it to the front of the batch.
    images = torch.cat([gt['img'] for gt in batch], dim=0)
    images = images / 2 + 0.5
    index = generate_rank_by_dino(images, backbone, query_frame_num=1)
    sorted_order = calculate_index_mappings(index, len(images), device=device)
    # assumes one image per batch entry, so len(images) == len(batch)
    batch = [batch[position] for position in sorted_order.tolist()]

    # Move every view to the target device; float32 entries are cast to
    # bfloat16 unless listed in ignore_dtype_keys.
    ignore_keys = {'depthmap', 'dataset', 'label', 'instance', 'idx', 'rng', 'vid'}
    ignore_dtype_keys = {'true_shape', 'camera_pose', 'pts3d', 'fxfycxcy', 'img_org',
                         'camera_intrinsics', 'depthmap', 'depth_anything', 'fxfycxcy_unorm'}
    dtype = torch.bfloat16
    for view in batch:
        for name in list(view):
            if name in ignore_keys:
                continue
            if isinstance(view[name], torch.Tensor):
                view[name] = view[name].to(device, non_blocking=True)
            else:
                view[name] = torch.tensor(view[name]).to(device, non_blocking=True)
            if view[name].dtype == torch.float32 and name not in ignore_dtype_keys:
                view[name] = view[name].to(dtype)

    # The first (reference) view goes in alone, all remaining views together.
    view1 = batch[:1]
    view2 = batch[1:]
    with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
        _, pred2, _ = model(view1, view2, True, dtype)

    conf = pred2['conf']
    pts3d = pred2['pts3d'].detach().cpu()
    B, N, H, W, _ = pts3d.shape

    # Per-view confidence threshold at the requested quantile (batch entry 0).
    thres = torch.quantile(conf.flatten(2, 3), float(min_conf_thr), dim=-1)[0]
    masks_conf = (conf > thres[None, :, None, None]).cpu()

    all_views = view1 + view2
    shape = torch.stack([view['true_shape'] for view in all_views], dim=1).detach().cpu().numpy()
    images = torch.stack([view['img'] for view in all_views], 1)
    images = images.float().permute(0, 1, 3, 4, 2).detach().cpu().numpy()
    images = images / 2 + 0.5
    images = images.reshape(B, N, H, W, 3)

    # Only the first batch entry is used from here on.
    images = images[0]
    pts3d = pts3d[0]
    masks_conf = masks_conf[0]

    # Focal length: every pixel of the first view votes f = u * z / x and
    # f = v * z / y, with (u, v) measured from the image centre; the median
    # of all votes (ignoring NaNs) is used for every camera.
    pp = torch.tensor((W / 2, H / 2)).to(pts3d)
    pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(-1, 1, 2)
    u, v = pixels[:1].unbind(dim=-1)
    x, y, z = pts3d[:1].reshape(-1, 3).unbind(dim=-1)
    fx_votes = (u * z) / x
    fy_votes = (v * z) / y
    f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1)
    focal = torch.nanmedian(f_votes, dim=-1).values.item()
    pts3d = pts3d.numpy()

    # Camera pose per view from 2D-3D correspondences (PnP with RANSAC).
    # The intrinsics are identical for every view, so build them once.
    intrinsics = np.float32([(focal, 0, W / 2), (0, focal, H / 2), (0, 0, 1)])
    confidence = 0.9999
    iterationsCount = 10_000
    pred_poses = []
    for i in range(pts3d.shape[0]):
        shape_input_each = shape[:, i]
        mesh_grid = xy_grid(shape_input_each[0, 1], shape_input_each[0, 0])
        # Keep only points whose confidence is above the view's 0.6 quantile.
        cur_inlier = conf[0, i] > torch.quantile(conf[0, i], 0.6)
        cur_inlier = cur_inlier.detach().cpu().numpy()
        cur_pts3d = pts3d[i]
        # NOTE(review): the success flag is not checked; a failed solve is
        # assumed to surface as an error in cv2.Rodrigues below — confirm.
        success, r_pose, t_pose, _ = cv2.solvePnPRansac(
            cur_pts3d[cur_inlier].astype(np.float64),
            mesh_grid[cur_inlier].astype(np.float64),
            intrinsics, None,
            flags=cv2.SOLVEPNP_SQPNP,
            iterationsCount=iterationsCount,
            reprojectionError=1,
            confidence=confidence)
        r_pose = cv2.Rodrigues(r_pose)[0]
        # solvePnP yields world-to-camera; invert it to get camera-to-world.
        RT = np.r_[np.c_[r_pose, t_pose], [(0, 0, 0, 1)]]
        pred_poses.append(np.linalg.inv(RT))
    pred_poses = torch.tensor(np.stack(pred_poses, axis=0))

    focals = [focal] * len(images)
    # mkstemp creates the file atomically (tempfile.mktemp is deprecated);
    # the export below overwrites the empty placeholder.
    fd, outfile_name = tempfile.mkstemp(suffix='_scene.glb', dir=outdir)
    os.close(fd)

    _convert_scene_output_to_glb(outfile_name, images, pts3d, masks_conf, focals, pred_poses,
                                 as_pointcloud=True, transparent_cams=False,
                                 cam_size=cam_size, silent=silent)
    return outfile_name
|
|
css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
title = "FLARE Demo"

# Build the UI: a multi-file upload, two sliders and a Run button feeding
# local_get_reconstructed_scene, whose result is shown in a 3D model viewer.
# NOTE(review): the nesting of the layout containers was reconstructed from
# a copy of the file that had lost its indentation — confirm the layout.
with gradio.Blocks(css=css, title=title, delete_cache=(gradio_delete_cache, gradio_delete_cache)) as demo:
    gradio.HTML('<h2 style="text-align: center;">3D Reconstruction with FLARE</h2>')
    with gradio.Column():
        inputfiles = gradio.File(file_count="multiple")
        snapshot = gradio.Image(None, visible=False)
        with gradio.Row():
            # Confidence quantile below which points are discarded.
            min_conf_thr = gradio.Slider(label="min_conf_thr", value=0.1, minimum=0.0, maximum=1, step=0.05)
            # Size of the cameras drawn in the exported scene.
            cam_size = gradio.Slider(label="cam_size", value=0.2, minimum=0.001, maximum=1.0, step=0.001)
        run_btn = gradio.Button("Run")
        outmodel = gradio.Model3D()

        # Events.
        run_btn.click(fn=local_get_reconstructed_scene,
                      inputs=[inputfiles, min_conf_thr, cam_size],
                      outputs=[outmodel])

demo.launch(show_error=True, share=None, server_name=None, server_port=None)
# The former trailing `shutil.rmtree(tmpdirname)` was removed: `tmpdirname`
# is never defined in this file, so the call raised NameError as soon as
# launch() returned.