import warnings
import gradio as gr
import onnxruntime as ort
import torch
import torch.nn.functional as F
import numpy as np
import tempfile
import os
import time
import logging
import argparse
import sys
from pathlib import Path

# Make the in-repo `sharp` package importable when running from the repo root.
sys.path.append(os.path.join(os.getcwd(), 'src'))

from sharp.utils import io
from sharp.utils.gaussians import Gaussians3D, save_ply, unproject_gaussians

warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger(__name__)

# Currently active ONNX Runtime session and the model path it was built from.
SESSION = None
CURRENT_MODEL_PATH = None


def load_model(model_path):
    """Load (or reuse) an ONNX Runtime CPU session for ``model_path``.

    The session is cached in the module-level ``SESSION`` global. Calling this
    again with the path that is already loaded returns the cached session.

    Args:
        model_path: Path to an ``.onnx`` file. ``None``/empty is treated as
            "not found".

    Returns:
        The active ``ort.InferenceSession``, or ``None`` if the file does not
        exist or the session could not be created (the error is logged).
        When a load attempt fails after the previous session was released,
        ``SESSION`` is left as ``None``.
    """
    global SESSION, CURRENT_MODEL_PATH
    if SESSION is not None and CURRENT_MODEL_PATH == model_path:
        return SESSION
    # Guard None/empty first: Path(None) would raise TypeError outside the try.
    if not model_path or not Path(model_path).exists():
        LOGGER.error(f"Model file not found: {model_path}")
        return None

    try:
        LOGGER.info(f"Loading model: {model_path}...")
        cpu_count = os.cpu_count() or 4
        options = ort.SessionOptions()
        options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        options.intra_op_num_threads = cpu_count
        options.inter_op_num_threads = min(4, cpu_count)
        options.enable_mem_pattern = True
        options.enable_cpu_mem_arena = True
        providers = ['CPUExecutionProvider']

        # Release the previous session before building the new one so its
        # memory can be reclaimed first. Rebind to None instead of
        # `del SESSION`: deleting a `global` name unbinds it at module level,
        # and every later `SESSION is None` check would raise NameError if
        # this load failed.
        SESSION = None
        CURRENT_MODEL_PATH = None

        SESSION = ort.InferenceSession(model_path, sess_options=options, providers=providers)
        CURRENT_MODEL_PATH = model_path
        LOGGER.info(f"Model loaded successfully: {model_path}")
        return SESSION
    except Exception as e:
        LOGGER.error(f"Failed to load model {model_path}: {e}")
        return None


def get_available_models():
    """Return the ``*.onnx`` files in the current working directory, sorted by name.

    Sorting makes the dropdown order (and its fallback default) deterministic;
    raw glob order is filesystem-dependent.
    """
    return sorted(str(m) for m in Path('.').glob('*.onnx'))


def process_image(image_filepath, opacity_threshold, downsample_rate):
    """Run the loaded model on one image and write the resulting Gaussians to a PLY.

    Args:
        image_filepath: Path of the input image (Gradio ``type="filepath"``).
        opacity_threshold: Gaussians whose opacity is not strictly greater than
            this value are dropped.
        downsample_rate: After the opacity filter, keep every
            ``downsample_rate``-th Gaussian. Values below 1 are treated as 1.

    Returns:
        The path (str) of ``output.ply`` inside a fresh temporary directory, or
        ``None`` when no image was given or no model is loaded.
    """
    if not image_filepath:
        return None
    # Take one snapshot of the global: the model dropdown can swap SESSION
    # (briefly to None) while this request is running.
    session = SESSION
    if session is None:
        gr.Warning("Model not loaded. Check the console for errors.")
        return None

    start_time = time.perf_counter()

    img, _, f_px = io.load_rgb(Path(image_filepath), auto_rotate=True, remove_alpha=True)
    height, width = img.shape[:2]

    # HWC in [0, 255] -> CHW float in [0, 1].
    image_pt = torch.from_numpy(img.copy()).float().permute(2, 0, 1) / 255.0
    disparity_factor = torch.tensor([f_px / width]).float()

    # (width, height) the network consumes; F.interpolate takes (H, W).
    internal_shape = (1536, 1536)
    image_resized_pt = F.interpolate(
        image_pt[None],
        size=(internal_shape[1], internal_shape[0]),
        mode="bilinear",
        align_corners=True,
    )

    # Half-precision exports declare float16 inputs; match them.
    if session.get_inputs()[0].type == 'tensor(float16)':
        image_resized_pt = image_resized_pt.half()
        disparity_factor = disparity_factor.half()

    inputs = {'image': image_resized_pt.numpy(), 'disparity_factor': disparity_factor.numpy()}
    outputs = session.run(None, inputs)

    gaussians_ndc = Gaussians3D(
        mean_vectors=torch.from_numpy(outputs[0]).float(),
        singular_values=torch.from_numpy(outputs[1]).float(),
        quaternions=torch.from_numpy(outputs[2]).float(),
        colors=torch.from_numpy(outputs[3]).float(),
        opacities=torch.from_numpy(outputs[4]).float(),
    )

    # Keep Gaussians above the opacity threshold, then stride-subsample.
    # max(1, ...) avoids a zero slice step, which would raise ValueError.
    mask = gaussians_ndc.opacities[0] > opacity_threshold
    sampler = slice(0, None, max(1, int(downsample_rate)))

    def apply_mask_and_sampling(tensor):
        # Index dim 1 with the mask, then apply the stride.
        return tensor[:, mask][:, sampler]

    filtered_gaussians_ndc = Gaussians3D(
        mean_vectors=apply_mask_and_sampling(gaussians_ndc.mean_vectors),
        singular_values=apply_mask_and_sampling(gaussians_ndc.singular_values),
        quaternions=apply_mask_and_sampling(gaussians_ndc.quaternions),
        colors=apply_mask_and_sampling(gaussians_ndc.colors),
        opacities=apply_mask_and_sampling(gaussians_ndc.opacities),
    )

    # Pinhole intrinsics for the original image, rescaled to the network input size.
    intrinsics = torch.tensor([
        [f_px, 0, width / 2, 0],
        [0, f_px, height / 2, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ]).float()
    intrinsics_resized = intrinsics.clone()
    intrinsics_resized[0] *= internal_shape[0] / width
    intrinsics_resized[1] *= internal_shape[1] / height

    gaussians = unproject_gaussians(filtered_gaussians_ndc, torch.eye(4), intrinsics_resized, internal_shape)

    # NOTE(review): each request creates a new temp dir that is never removed;
    # consider periodic cleanup if the server is long-running.
    out_dir = Path(tempfile.mkdtemp())
    out_path = out_dir / "output.ply"
    save_ply(gaussians, f_px, (height, width), out_path)

    LOGGER.info(f"Processed {image_filepath} in {time.perf_counter() - start_time:.2f}s")
    return str(out_path)


custom_css = """
body, .gradio-container {
    background: radial-gradient(circle at top left, #0d0d12 0%, #000000 100%) !important;
    color: #e0e0e0 !important;
    font-family: 'Inter', system-ui, -apple-system, sans-serif !important;
    margin: 0 !important;
    padding: 0 !important;
}
.panel-box {
    background: rgba(20, 20, 25, 0.8) !important;
    backdrop-filter: blur(10px);
    border: 1px solid rgba(255, 255, 255, 0.1) !important;
    border-radius: 20px !important;
    padding: 24px;
    box-shadow: 0 8px 32px 0 rgba(0, 0, 0, 0.8);
    transition: all 0.3s ease;
    margin-bottom: 16px;
}
#spark-container {
    width: 100%;
    height: 70vh; /* Responsive height */
    min-height: 400px;
    max-height: 720px;
    background: #000;
    border-radius: 12px;
    border: 1px solid rgba(255, 255, 255, 0.1);
    position: relative;
    overflow: hidden;
}
#generate-btn {
    background: linear-gradient(135deg, #6366f1 0%, #a855f7 100%) !important;
    color: white !important;
    font-weight: 700 !important;
    border-radius: 12px !important;
    border: none !important;
    margin-top: 10px;
    padding: 16px 24px !important; /* Larger for touch */
    text-transform: uppercase;
    letter-spacing: 1px;
    font-size: 1.1rem !important;
    transition: transform 0.2s, box-shadow 0.2s !important;
}
header h1 {
    background: linear-gradient(to right, #fff, #a5a5a5);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 2rem !important;
    font-weight: 900 !important;
    text-align: center;
    margin: 20px 0 !important;
}
/* Mobile Optimizations */
@media (max-width: 768px) {
    .panel-box {
        padding: 16px;
        border-radius: 16px !important;
    }
    #spark-container {
        height: 50vh; /* Shorter on mobile to leave room for controls */
        min-height: 300px;
    }
    header h1 {
        font-size: 1.5rem !important;
    }
    .gr-row {
        flex-direction: column !important;
    }
    /* Make inputs full width on mobile */
    .gr-form {
        width: 100% !important;
    }
}
"""

# TODO(review): the event callbacks below call window.initSpark, window.resetBtn,
# window.startTimer, window.stopTimer and window.loadSplat, but nothing visible
# here defines them. Presumably they are meant to be injected via this <head>
# content; confirm and restore the scripts.
head_content = """ """


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="ml-sharp_int4.onnx")
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--ssl_cert", type=str, default="cert.pem")
    parser.add_argument("--ssl_key", type=str, default="key.pem")
    args = parser.parse_args()

    # Offer --model in the dropdown even when it lives outside the working
    # directory, so the dropdown can show the model that is actually loaded.
    available_models = get_available_models()
    if Path(args.model).exists() and args.model not in available_models:
        available_models.insert(0, args.model)
    if args.model in available_models:
        default_model = args.model
    else:
        default_model = available_models[0] if available_models else None

    # Pre-load the model the dropdown will display so the UI and SESSION agree.
    load_model(default_model or args.model)

    def on_model_change(model_path):
        """Load the model picked in the dropdown; warn in the UI if it fails."""
        if load_model(model_path) is None:
            gr.Warning(f"Could not load model: {model_path}")

    with gr.Blocks(css=custom_css, theme=gr.themes.Default(), head=head_content, title="SHARP 3D Recon") as demo:
        # NOTE(review): markup reconstructed to match the `header h1` and
        # `#spark-container` selectors in custom_css; confirm against the original.
        gr.HTML("<header><h1>SHARP 3D RECONSTRUCTION</h1></header>")
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group(elem_classes="panel-box"):
                    model_selector = gr.Dropdown(
                        choices=available_models,
                        value=default_model,
                        label="Select ONNX Model (Precision)",
                        interactive=True,
                    )
                    input_image = gr.Image(
                        type="filepath",
                        label="Capture or Upload Image",
                        height=400,
                        sources=["upload", "webcam"],
                    )
                    with gr.Accordion("Advanced Settings", open=False):
                        opacity_val = gr.Slider(0.0, 1.0, value=0.0, label="Opacity Threshold")
                        downsample_val = gr.Slider(1, 10, step=1, value=1, label="Downsample Rate")
                    submit_btn = gr.Button(
                        "Generate 3D Gaussians",
                        variant="primary",
                        elem_id="generate-btn",
                        interactive=False,
                    )
                    gr.Markdown("Capture a photo from your phone or upload an image to start the real-time 3D conversion.")
            with gr.Column(scale=2):
                with gr.Group(elem_classes="panel-box"):
                    gr.HTML('<div id="spark-container"></div>')
                    output_file = gr.File(label="Output Model", visible=False)

        demo.load(
            fn=None,
            inputs=None,
            outputs=None,
            js="() => { setTimeout(window.initSpark, 500); window.resetBtn(false); }",
        )
        model_selector.change(fn=on_model_change, inputs=[model_selector], outputs=None)
        # Frontend-only toggle of the generate button; no backend round-trip needed.
        input_image.change(
            fn=None,
            inputs=[input_image],
            outputs=None,
            js="(img) => { window.resetBtn(!!img); }",
        )
        submit_btn.click(
            fn=process_image,
            inputs=[input_image, opacity_val, downsample_val],
            outputs=[output_file],
            js="(img, op, down) => { window.startTimer(); return [img, op, down]; }",
        )
        input_image.upload(
            fn=process_image,
            inputs=[input_image, opacity_val, downsample_val],
            outputs=[output_file],
            js="(img, op, down) => { window.startTimer(); return [img, op, down]; }",
        )
        output_file.change(
            fn=None,
            inputs=[output_file],
            js="(f) => { window.stopTimer(); if (f && f.url) { window.loadSplat(f.url); } }",
        )

    # Launch with HTTPS when the cert/key files exist (needed for webcam access
    # from phones); otherwise fall back to plain HTTP.
    demo.queue().launch(
        server_name=args.host,
        server_port=args.port,
        share=False,
        ssl_certfile=args.ssl_cert if os.path.exists(args.ssl_cert) else None,
        ssl_keyfile=args.ssl_key if os.path.exists(args.ssl_key) else None,
        ssl_verify=False,
    )