# splatt3r / demo.py
# Commit 4ea5669 (dmfenton): "Force weights_only=False in torch.load patch"
#!/usr/bin/env python3
# The MASt3R Gradio demo, modified for predicting 3D Gaussian Splats
# --- Original License ---
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# MUST BE FIRST: Fix for PyTorch 2.6+ weights_only default change
import torch

# Try safe_globals approach first (PyTorch 2.6+): allow-list the omegaconf
# classes pickled inside the Lightning checkpoint so a weights_only load
# could deserialize them.
try:
    import omegaconf
    # Add all omegaconf classes that might be in the checkpoint
    safe_classes = [
        omegaconf.DictConfig,
        omegaconf.ListConfig,
        omegaconf.base.ContainerMetadata,
    ]
    # omegaconf's internal node classes vary across versions, so resolve them
    # defensively via getattr with a default instead of a bare `except: pass`
    # (which would swallow even KeyboardInterrupt/SystemExit). If the `nodes`
    # submodule is absent we still register the base classes above.
    nodes_mod = getattr(omegaconf, 'nodes', None)
    if nodes_mod is not None:
        for name in ['ValueNode', 'AnyNode', 'StringNode', 'IntegerNode', 'FloatNode', 'BooleanNode']:
            cls = getattr(nodes_mod, name, None)
            if cls is not None:
                safe_classes.append(cls)
    torch.serialization.add_safe_globals(safe_classes)
    print(f"[INFO] Added {len(safe_classes)} omegaconf classes to safe globals")
except Exception as e:
    # Best-effort: missing omegaconf (or an old torch without
    # add_safe_globals) must not prevent the demo from starting.
    print(f"[WARNING] Could not add safe globals: {e}")

# Force weights_only=False since Lightning explicitly passes weights_only=True
import torch.serialization as _torch_ser

_original_torch_load = torch.load


def _patched_torch_load(*args, **kwargs):
    # Always force weights_only=False for this trusted checkpoint.
    # NOTE: this disables PyTorch 2.6+'s safe-unpickling protection, so the
    # patched loader must only ever be used on trusted files.
    kwargs['weights_only'] = False
    return _original_torch_load(*args, **kwargs)


# Patch both the public alias and the serialization module so callers that
# resolve either name (e.g. Lightning) get the forced behaviour.
torch.load = _patched_torch_load
_torch_ser.load = _patched_torch_load
print("[INFO] Patched torch.load to force weights_only=False")
import functools
import os
import sys
import tempfile
# import spaces # Not needed on dedicated GPU
import gradio
# torch already imported at top with patch applied
from huggingface_hub import hf_hub_download
sys.path.append('src/mast3r_src')
sys.path.append('src/mast3r_src/dust3r')
sys.path.append('src/pixelsplat_src')
from dust3r.utils.image import load_images
from mast3r.utils.misc import hash_md5
import main
import utils.export as export
# No @spaces.GPU decorator needed on dedicated GPU
def get_reconstructed_scene(outdir, weights_path, silent, image_size, img1, img2):
import traceback
try:
print(f"[DEBUG] Starting reconstruction...", flush=True)
print(f"[DEBUG] img1 type: {type(img1)}, img2 type: {type(img2)}", flush=True)
# Extract file paths from Image components
def get_path(img):
if img is None:
return None
if isinstance(img, str):
return img
if hasattr(img, 'name'):
return img.name
if isinstance(img, dict):
return img.get('path') or img.get('name') or img.get('url')
return str(img)
path1 = get_path(img1)
path2 = get_path(img2) if img2 is not None else path1
print(f"[DEBUG] Paths: {path1}, {path2}", flush=True)
if path1 is None:
raise ValueError("Please provide at least one image")
paths = [path1, path2 if path2 else path1]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"[DEBUG] Device: {device}", flush=True)
model = main.MAST3RGaussians.load_from_checkpoint(weights_path, device)
print(f"[DEBUG] Model loaded", flush=True)
imgs = load_images(paths, size=image_size, verbose=not silent)
for img in imgs:
img['img'] = img['img'].to(device)
img['original_img'] = img['original_img'].to(device)
img['true_shape'] = torch.from_numpy(img['true_shape'])
model = model.to(device)
print(f"[DEBUG] Running model inference...", flush=True)
output = model(imgs[0], imgs[1])
print(f"[DEBUG] Model inference complete", flush=True)
pred1, pred2 = output
plyfile = os.path.join(outdir, 'gaussians.ply')
print(f"[DEBUG] Saving PLY to {plyfile}", flush=True)
export.save_as_ply(pred1, pred2, plyfile)
print(f"[DEBUG] PLY saved", flush=True)
return plyfile
except Exception as e:
print(f"[ERROR] Exception in get_reconstructed_scene: {e}", flush=True)
traceback.print_exc()
raise
def diagnose():
    """Diagnostic endpoint to check torch version and model loading."""
    import inspect
    import traceback

    report = []
    # 1. Torch / CUDA environment
    report.append(f"torch version: {torch.__version__}")
    report.append(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        report.append(f"CUDA device: {torch.cuda.get_device_name(0)}")
    # 2. Check if patch is applied — the patched loader is defined in this
    # file, so its source file reveals whether the monkey-patch took effect.
    report.append(f"torch.load source: {inspect.getsourcefile(torch.load)}")
    # 3. Exercise checkpoint download and both loading paths end to end.
    try:
        model_name = "brandonsmart/splatt3r_v1.0"
        filename = "epoch=19-step=1200.ckpt"
        weights_path = hf_hub_download(repo_id=model_name, filename=filename)
        report.append(f"Checkpoint downloaded to: {weights_path}")

        # Raw torch.load first, to isolate deserialization problems.
        report.append("Attempting raw torch.load...")
        checkpoint = torch.load(weights_path, map_location='cpu')
        report.append(f"Raw load succeeded! Keys: {list(checkpoint.keys())[:5]}...")

        # Then the full Lightning model restore.
        report.append("Attempting model load_from_checkpoint...")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        main.MAST3RGaussians.load_from_checkpoint(weights_path, device)
        report.append("Model loaded successfully!")
    except Exception as e:
        report.append(f"ERROR: {type(e).__name__}: {e}")
        report.append(traceback.format_exc())
    return "\n".join(report)
if __name__ == '__main__':
    image_size = 512
    silent = False
    model_name = "brandonsmart/splatt3r_v1.0"
    filename = "epoch=19-step=1200.ckpt"
    weights_path = hf_hub_download(repo_id=model_name, filename=filename)
    chkpt_tag = hash_md5(weights_path)

    # Define example inputs and their corresponding precalculated outputs.
    # Each row is [input image 1, input image 2, precalculated splat .ply],
    # and every asset is fetched into the local HF cache up front.
    scene_names = [f"scannet++_{i}" for i in range(1, 9)] + [f"in_the_wild_{i}" for i in range(1, 9)]
    examples = [
        [
            hf_hub_download(repo_id=model_name, filename=f"demo_examples/{scene}_img_1.jpg"),
            hf_hub_download(repo_id=model_name, filename=f"demo_examples/{scene}_img_2.jpg"),
            hf_hub_download(repo_id=model_name, filename=f"demo_examples/{scene}.ply"),
        ]
        for scene in scene_names
    ]

    with tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') as tmpdirname:
        cache_path = os.path.join(tmpdirname, chkpt_tag)
        os.makedirs(cache_path, exist_ok=True)
        # Bind the fixed configuration so the click handler only takes images.
        recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, weights_path, silent, image_size)

        # Don't modify examples for File component - just use img1, img2, output format
        css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
        with gradio.Blocks(css=css, title="Splatt3R Demo") as demo:
            gradio.HTML('<h2 style="text-align: center;">Splatt3R Demo</h2>')

            # Diagnostic section
            with gradio.Accordion("Diagnostics", open=False):
                diag_btn = gradio.Button("Run Diagnostics")
                diag_output = gradio.Textbox(label="Diagnostic Output", lines=20)
                diag_btn.click(fn=diagnose, inputs=[], outputs=[diag_output], api_name="diagnose")

            with gradio.Column():
                gradio.Markdown('''
                Upload two images to generate a 3D Gaussian splat.
                Images will be cropped to squares for reconstruction.
                ''')
                with gradio.Row():
                    img1 = gradio.Image(label="Image 1", type="filepath")
                    img2 = gradio.Image(label="Image 2 (optional)", type="filepath")
                run_btn = gradio.Button("Generate Splat", variant="primary")
                gradio.Markdown('''
                ## Output
                Below we show the generated 3D Gaussian Splat.
                The generated splats are 30-40MB, so please allow up to 30 seconds for them to be downloaded from Hugging Face before rendering.
                As it downloads your previous generations may be visible.
                The arrow in the top right of the window below can be used to download the .ply for rendering with other viewers,
                such as [here](https://projects.markkellogg.org/threejs/demo_gaussian_splats_3d.php?art=1&cu=0,-1,0&cp=0,1,0&cla=1,0,0&aa=false&2d=false&sh=0) or [here](https://playcanvas.com/supersplat/editor).
                ''')
                outmodel = gradio.Model3D(
                    clear_color=[1.0, 1.0, 1.0, 0.0],
                )
                run_btn.click(fn=recon_fun, inputs=[img1, img2], outputs=[outmodel], api_name="predict")
                gradio.Markdown('''
                ## Examples
                A gallery of examples generated from ScanNet++ and from 'in the wild' images taken with a mobile phone.
                These examples are 30-40MB, so please allow up to 30 seconds for them to be downloaded from Hugging Face before rendering.
                As it downloads your previous generations may be visible.
                ''')
                snapshot_1 = gradio.Image(None, visible=False)
                snapshot_2 = gradio.Image(None, visible=False)
                # Examples are pre-computed, just display images and load outputs
                gradio.Examples(
                    examples=examples,
                    inputs=[snapshot_1, snapshot_2, outmodel],
                    examples_per_page=5
                )
        demo.launch()