| |
| |
|
|
| |
| |
| |
|
|
| |
| import torch |
|
|
| |
# Register omegaconf container/node classes with torch's safe-unpickler
# allow-list, so checkpoints that embed omegaconf configs can be loaded
# under torch>=2.6's weights_only=True default.
try:
    import omegaconf

    safe_classes = [
        omegaconf.DictConfig,
        omegaconf.ListConfig,
        omegaconf.base.ContainerMetadata,
    ]

    # Node classes vary across omegaconf versions; getattr() with a default
    # never raises, so missing names are simply skipped.
    nodes_module = getattr(omegaconf, 'nodes', None)
    if nodes_module is not None:
        for name in ['ValueNode', 'AnyNode', 'StringNode', 'IntegerNode', 'FloatNode', 'BooleanNode']:
            cls = getattr(nodes_module, name, None)
            if cls:
                safe_classes.append(cls)
    torch.serialization.add_safe_globals(safe_classes)
    print(f"[INFO] Added {len(safe_classes)} omegaconf classes to safe globals")
except Exception as e:
    # Best-effort: omegaconf may be missing, or this torch may predate
    # add_safe_globals. The demo can still proceed.
    print(f"[WARNING] Could not add safe globals: {e}")
|
|
| |
| import torch.serialization as _torch_ser |
| _original_torch_load = torch.load |
| def _patched_torch_load(*args, **kwargs): |
| |
| kwargs['weights_only'] = False |
| return _original_torch_load(*args, **kwargs) |
| torch.load = _patched_torch_load |
| _torch_ser.load = _patched_torch_load |
| print("[INFO] Patched torch.load to force weights_only=False") |
|
|
import functools
import os
import sys
import tempfile

import gradio

from huggingface_hub import hf_hub_download

# Make the vendored MASt3R / DUSt3R / pixelSplat sources importable.
# NOTE(review): these are relative paths — assumes the demo is launched from
# the repository root; confirm against the deployment setup.
sys.path.append('src/mast3r_src')
sys.path.append('src/mast3r_src/dust3r')
sys.path.append('src/pixelsplat_src')
from dust3r.utils.image import load_images
from mast3r.utils.misc import hash_md5
import main
import utils.export as export
|
|
| |
def get_reconstructed_scene(outdir, weights_path, silent, image_size, img1, img2):
    """Reconstruct a Gaussian splat from one or two images and export a .ply.

    Returns the path of the written .ply file inside ``outdir``. Any failure
    is logged with a traceback and re-raised for gradio to surface.
    """
    import traceback
    try:
        print(f"[DEBUG] Starting reconstruction...", flush=True)
        print(f"[DEBUG] img1 type: {type(img1)}, img2 type: {type(img2)}", flush=True)

        # Normalise whatever gradio hands us (str, file-like, dict) to a path.
        def _to_path(candidate):
            if candidate is None:
                return None
            if isinstance(candidate, str):
                return candidate
            if hasattr(candidate, 'name'):
                return candidate.name
            if isinstance(candidate, dict):
                return candidate.get('path') or candidate.get('name') or candidate.get('url')
            return str(candidate)

        path1 = _to_path(img1)
        path2 = path1 if img2 is None else _to_path(img2)

        print(f"[DEBUG] Paths: {path1}, {path2}", flush=True)

        if path1 is None:
            raise ValueError("Please provide at least one image")

        # A single image is simply paired with itself.
        paths = [path1, path2 or path1]

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"[DEBUG] Device: {device}", flush=True)
        model = main.MAST3RGaussians.load_from_checkpoint(weights_path, device)
        print(f"[DEBUG] Model loaded", flush=True)

        imgs = load_images(paths, size=image_size, verbose=not silent)

        # Move every view's tensors onto the compute device.
        for view in imgs:
            view['img'] = view['img'].to(device)
            view['original_img'] = view['original_img'].to(device)
            view['true_shape'] = torch.from_numpy(view['true_shape'])
        model = model.to(device)

        print(f"[DEBUG] Running model inference...", flush=True)
        pred1, pred2 = model(imgs[0], imgs[1])
        print(f"[DEBUG] Model inference complete", flush=True)

        plyfile = os.path.join(outdir, 'gaussians.ply')
        print(f"[DEBUG] Saving PLY to {plyfile}", flush=True)
        export.save_as_ply(pred1, pred2, plyfile)
        print(f"[DEBUG] PLY saved", flush=True)
        return plyfile
    except Exception as e:
        print(f"[ERROR] Exception in get_reconstructed_scene: {e}", flush=True)
        traceback.print_exc()
        raise
|
|
def diagnose():
    """Diagnostic endpoint to check torch version and model loading."""
    import traceback
    lines = []
    log = lines.append

    # Environment summary.
    log(f"torch version: {torch.__version__}")
    log(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        log(f"CUDA device: {torch.cuda.get_device_name(0)}")

    # Report which torch.load implementation is active (it is monkey-patched
    # at import time, so this shows whether the patch took effect).
    import inspect
    load_source = inspect.getsourcefile(torch.load)
    log(f"torch.load source: {load_source}")

    # Exercise the checkpoint download and both load paths end to end.
    try:
        weights_path = hf_hub_download(
            repo_id="brandonsmart/splatt3r_v1.0",
            filename="epoch=19-step=1200.ckpt",
        )
        log(f"Checkpoint downloaded to: {weights_path}")

        log("Attempting raw torch.load...")
        checkpoint = torch.load(weights_path, map_location='cpu')
        log(f"Raw load succeeded! Keys: {list(checkpoint.keys())[:5]}...")

        log("Attempting model load_from_checkpoint...")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = main.MAST3RGaussians.load_from_checkpoint(weights_path, device)
        log("Model loaded successfully!")

    except Exception as e:
        log(f"ERROR: {type(e).__name__}: {e}")
        log(traceback.format_exc())

    return "\n".join(lines)
|
|
if __name__ == '__main__':

    image_size = 512
    silent = False

    # Fetch the pretrained checkpoint from the Hugging Face Hub.
    model_name = "brandonsmart/splatt3r_v1.0"
    filename = "epoch=19-step=1200.ckpt"
    weights_path = hf_hub_download(repo_id=model_name, filename=filename)
    chkpt_tag = hash_md5(weights_path)

    # Each example is [image 1, image 2, precomputed .ply splat]: eight
    # ScanNet++ scenes followed by eight 'in the wild' captures.
    scene_names = (
        [f"scannet++_{i}" for i in range(1, 9)]
        + [f"in_the_wild_{i}" for i in range(1, 9)]
    )
    examples = [
        [
            f"demo_examples/{scene}_img_1.jpg",
            f"demo_examples/{scene}_img_2.jpg",
            f"demo_examples/{scene}.ply",
        ]
        for scene in scene_names
    ]

    # Resolve every example asset to a locally cached file.
    examples = [
        [hf_hub_download(repo_id=model_name, filename=asset) for asset in row]
        for row in examples
    ]

    with tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') as tmpdirname:

        cache_path = os.path.join(tmpdirname, chkpt_tag)
        os.makedirs(cache_path, exist_ok=True)

        # Bind the fixed configuration; gradio supplies only the two images.
        recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, weights_path, silent, image_size)

        css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
        with gradio.Blocks(css=css, title="Splatt3R Demo") as demo:

            gradio.HTML('<h2 style="text-align: center;">Splatt3R Demo</h2>')

            with gradio.Accordion("Diagnostics", open=False):
                diag_btn = gradio.Button("Run Diagnostics")
                diag_output = gradio.Textbox(label="Diagnostic Output", lines=20)
                diag_btn.click(fn=diagnose, inputs=[], outputs=[diag_output], api_name="diagnose")

            with gradio.Column():
                gradio.Markdown('''
                Upload two images to generate a 3D Gaussian splat.
                Images will be cropped to squares for reconstruction.
                ''')
                with gradio.Row():
                    img1 = gradio.Image(label="Image 1", type="filepath")
                    img2 = gradio.Image(label="Image 2 (optional)", type="filepath")
                run_btn = gradio.Button("Generate Splat", variant="primary")
                gradio.Markdown('''
                ## Output
                Below we show the generated 3D Gaussian Splat.
                The generated splats are 30-40MB, so please allow up to 30 seconds for them to be downloaded from Hugging Face before rendering.
                As it downloads your previous generations may be visible.
                The arrow in the top right of the window below can be used to download the .ply for rendering with other viewers,
                such as [here](https://projects.markkellogg.org/threejs/demo_gaussian_splats_3d.php?art=1&cu=0,-1,0&cp=0,1,0&cla=1,0,0&aa=false&2d=false&sh=0) or [here](https://playcanvas.com/supersplat/editor).
                ''')
                outmodel = gradio.Model3D(
                    clear_color=[1.0, 1.0, 1.0, 0.0],
                )
                run_btn.click(fn=recon_fun, inputs=[img1, img2], outputs=[outmodel], api_name="predict")

            gradio.Markdown('''
            ## Examples
            A gallery of examples generated from ScanNet++ and from 'in the wild' images taken with a mobile phone.
            These examples are 30-40MB, so please allow up to 30 seconds for them to be downloaded from Hugging Face before rendering.
            As it downloads your previous generations may be visible.
            ''')

            # Hidden image components so Examples can populate the inputs
            # without rendering duplicate thumbnails.
            snapshot_1 = gradio.Image(None, visible=False)
            snapshot_2 = gradio.Image(None, visible=False)

            gradio.Examples(
                examples=examples,
                inputs=[snapshot_1, snapshot_2, outmodel],
                examples_per_page=5
            )

        demo.launch()
|
|