| import warnings |
| import gradio as gr |
| import onnxruntime as ort |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| import tempfile |
| import os |
| import time |
| import logging |
| import argparse |
| import sys |
| from pathlib import Path |
|
|
# Make the local ``src`` directory importable for the ``sharp`` imports below.
# The directory is resolved against the current working directory, not
# against this file's location.
sys.path.append(str(Path.cwd() / "src"))
|
|
| from sharp.utils import io |
| from sharp.utils.gaussians import Gaussians3D, save_ply, unproject_gaussians |
|
|
# Process-wide logging: INFO level on the root logger, with this module
# writing through its own named logger.
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger(__name__)

# Suppress every warning category for the whole process.
warnings.filterwarnings("ignore")

# Inference state shared by ``load_model`` and ``process_image``: the active
# ONNX Runtime session and the model path it was created from. Both stay
# ``None`` until a model has been loaded.
SESSION: "ort.InferenceSession | None" = None
CURRENT_MODEL_PATH: "str | None" = None
|
|
def load_model(model_path):
    """Load the ONNX model at ``model_path``, reusing the cached session if possible.

    The session is cached in the module-level ``SESSION`` together with the
    path it was built from (``CURRENT_MODEL_PATH``). Asking for the path that
    is already loaded returns the cached session without touching the disk.

    Args:
        model_path: Filesystem path of the ``.onnx`` file. ``None`` or an
            empty string (for example a cleared dropdown) is treated as
            "nothing to load".

    Returns:
        The ``ort.InferenceSession`` on success, otherwise ``None``. Failures
        are logged and never raised to the caller.
    """
    global SESSION, CURRENT_MODEL_PATH

    if not model_path:
        LOGGER.error("No model path provided.")
        return None

    if SESSION is not None and CURRENT_MODEL_PATH == model_path:
        return SESSION

    if not Path(model_path).exists():
        # The previously loaded session (if any) is deliberately kept.
        LOGGER.error(f"Model file not found: {model_path}")
        return None

    try:
        LOGGER.info(f"Loading model: {model_path}...")
        cpu_count = os.cpu_count() or 4
        options = ort.SessionOptions()
        options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        options.intra_op_num_threads = cpu_count
        options.inter_op_num_threads = min(4, cpu_count)
        options.enable_mem_pattern = True
        options.enable_cpu_mem_arena = True
        providers = ['CPUExecutionProvider']

        # Drop our reference to the previous session before building the new
        # one. Rebinding to None (rather than ``del SESSION``) keeps the
        # global name defined, so a failed load leaves the module in a clean
        # "no model loaded" state instead of raising NameError on the next
        # access.
        SESSION = None
        CURRENT_MODEL_PATH = None

        new_session = ort.InferenceSession(model_path, sess_options=options, providers=providers)
    except Exception as e:
        LOGGER.exception("Failed to load model %s: %s", model_path, e)
        return None

    SESSION = new_session
    CURRENT_MODEL_PATH = model_path
    LOGGER.info(f"Model loaded successfully: {model_path}")
    return SESSION
|
|
def get_available_models():
    """List the ``*.onnx`` model files in the current working directory.

    Returns:
        Relative path strings in alphabetical order. Sorting makes the list
        (and therefore its first entry) stable across runs and filesystems,
        since ``Path.glob`` yields entries in arbitrary order. Directories
        that happen to match the pattern are skipped.
    """
    return sorted(str(path) for path in Path('.').glob('*.onnx') if path.is_file())
|
|
def process_image(image_filepath, opacity_threshold, downsample_rate):
    """Run the loaded model on one image and write the result as a ``.ply`` file.

    The image is resized to the fixed internal resolution, passed through the
    ONNX session, and the five model outputs are wrapped as ``Gaussians3D``.
    Gaussians are then filtered by opacity, subsampled, unprojected with the
    rescaled intrinsics and saved.

    Args:
        image_filepath: Path of the input image. A falsy value returns
            ``None`` immediately.
        opacity_threshold: Gaussians are kept only when their opacity is
            strictly greater than this value.
        downsample_rate: Keep every n-th Gaussian that survives the opacity
            filter. Converted with ``int()``; values below 1 are clamped to 1.

    Returns:
        The path (as ``str``) of ``output.ply`` inside a freshly created
        temporary directory, or ``None`` when no image was given or no model
        is loaded.
    """
    if not image_filepath:
        return None

    # Read the global once so the whole request uses the same session even if
    # load_model rebinds SESSION in the meantime.
    session = SESSION
    if session is None:
        gr.Warning("Model not loaded. Using dummy processing or check console.")
        return None

    start_time = time.perf_counter()

    img, _, f_px = io.load_rgb(Path(image_filepath), auto_rotate=True, remove_alpha=True)
    height, width = img.shape[:2]
    image_pt = torch.from_numpy(img.copy()).float().permute(2, 0, 1) / 255.0
    disparity_factor = torch.tensor([f_px / width]).float()

    # Fixed resolution fed to the model; indexed as (width, height) below.
    internal_shape = (1536, 1536)
    image_resized_pt = F.interpolate(
        image_pt[None],
        size=(internal_shape[1], internal_shape[0]),
        mode="bilinear",
        align_corners=True,
    )

    # Match the precision declared by the model's first input.
    if session.get_inputs()[0].type == 'tensor(float16)':
        image_resized_pt = image_resized_pt.half()
        disparity_factor = disparity_factor.half()

    inputs = {'image': image_resized_pt.numpy(), 'disparity_factor': disparity_factor.numpy()}
    outputs = session.run(None, inputs)

    gaussians_ndc = Gaussians3D(
        mean_vectors=torch.from_numpy(outputs[0]).float(),
        singular_values=torch.from_numpy(outputs[1]).float(),
        quaternions=torch.from_numpy(outputs[2]).float(),
        colors=torch.from_numpy(outputs[3]).float(),
        opacities=torch.from_numpy(outputs[4]).float(),
    )

    mask = gaussians_ndc.opacities[0] > opacity_threshold
    # A slice step of 0 raises ValueError, so never go below 1.
    step = max(1, int(downsample_rate))
    sampler = slice(0, None, step)

    def apply_mask_and_sampling(tensor):
        # Opacity filter first, then keep every ``step``-th remaining entry.
        return tensor[:, mask][:, sampler]

    filtered_gaussians_ndc = Gaussians3D(
        mean_vectors=apply_mask_and_sampling(gaussians_ndc.mean_vectors),
        singular_values=apply_mask_and_sampling(gaussians_ndc.singular_values),
        quaternions=apply_mask_and_sampling(gaussians_ndc.quaternions),
        colors=apply_mask_and_sampling(gaussians_ndc.colors),
        opacities=apply_mask_and_sampling(gaussians_ndc.opacities),
    )

    # Intrinsics built from the focal length and the original image size,
    # with the principal point at the image centre.
    intrinsics = torch.tensor([
        [f_px, 0, width / 2, 0],
        [0, f_px, height / 2, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ]).float()
    # Rescale the x and y rows to the internal resolution.
    intrinsics_resized = intrinsics.clone()
    intrinsics_resized[0] *= internal_shape[0] / width
    intrinsics_resized[1] *= internal_shape[1] / height

    gaussians = unproject_gaussians(filtered_gaussians_ndc, torch.eye(4), intrinsics_resized, internal_shape)

    # NOTE: the temporary directory is not removed here; the returned file
    # has to outlive this call so the caller can serve it.
    out_dir = Path(tempfile.mkdtemp())
    out_path = out_dir / "output.ply"
    save_ply(gaussians, f_px, (height, width), out_path)

    LOGGER.info("Wrote %s in %.2fs", out_path, time.perf_counter() - start_time)
    return str(out_path)
|
|
# Stylesheet passed to gr.Blocks(css=...): dark page theme, the translucent
# ".panel-box" cards, the "#spark-container" viewer area, the "#generate-btn"
# button and a mobile breakpoint at 768px.
custom_css = """
body, .gradio-container {
    background: radial-gradient(circle at top left, #0d0d12 0%, #000000 100%) !important;
    color: #e0e0e0 !important;
    font-family: 'Inter', system-ui, -apple-system, sans-serif !important;
    margin: 0 !important;
    padding: 0 !important;
}
.panel-box {
    background: rgba(20, 20, 25, 0.8) !important;
    -webkit-backdrop-filter: blur(10px); /* Prefixed form for older Safari */
    backdrop-filter: blur(10px);
    border: 1px solid rgba(255, 255, 255, 0.1) !important;
    border-radius: 20px !important;
    padding: 24px;
    box-shadow: 0 8px 32px 0 rgba(0, 0, 0, 0.8);
    transition: all 0.3s ease;
    margin-bottom: 16px;
}
#spark-container {
    width: 100%;
    height: 70vh; /* Responsive height */
    min-height: 400px;
    max-height: 720px;
    background: #000;
    border-radius: 12px;
    border: 1px solid rgba(255, 255, 255, 0.1);
    position: relative;
    overflow: hidden;
}
#generate-btn {
    background: linear-gradient(135deg, #6366f1 0%, #a855f7 100%) !important;
    color: white !important;
    font-weight: 700 !important;
    border-radius: 12px !important;
    border: none !important;
    margin-top: 10px;
    padding: 16px 24px !important; /* Larger for touch */
    text-transform: uppercase;
    letter-spacing: 1px;
    font-size: 1.1rem !important;
    transition: transform 0.2s, box-shadow 0.2s !important;
}
header h1 {
    background: linear-gradient(to right, #fff, #a5a5a5);
    -webkit-background-clip: text;
    background-clip: text; /* Standard property alongside the prefixed one */
    -webkit-text-fill-color: transparent;
    font-size: 2rem !important;
    font-weight: 900 !important;
    text-align: center;
    margin: 20px 0 !important;
}

/* Mobile Optimizations */
@media (max-width: 768px) {
    .panel-box {
        padding: 16px;
        border-radius: 16px !important;
    }
    #spark-container {
        height: 50vh; /* Shorter on mobile to leave room for controls */
        min-height: 300px;
    }
    header h1 {
        font-size: 1.5rem !important;
    }
    .gr-row {
        flex-direction: column !important;
    }
    /* Make inputs full width on mobile */
    .gr-form {
        width: 100% !important;
    }
}
"""
|
|
# Extra <head> markup passed to gr.Blocks(head=...): viewport and font tags,
# an import map for three.js and Spark, and a module script that publishes
# the viewer helpers (initSpark, loadSplat, focusModel) and the button/timer
# helpers (startTimer, stopTimer, resetBtn) on ``window`` so the ``js=``
# snippets in the UI wiring can call them.
head_content = """
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;700;900&display=swap" rel="stylesheet">
<script type="importmap">
{
    "imports": {
        "three": "https://cdnjs.cloudflare.com/ajax/libs/three.js/0.178.0/three.module.js",
        "@sparkjsdev/spark": "https://sparkjs.dev/releases/spark/0.1.10/spark.module.js"
    }
}
</script>
<script type="module">
import * as THREE from "three";
import { OrbitControls } from "https://unpkg.com/three@0.178.0/examples/jsm/controls/OrbitControls.js";
import { SplatMesh } from "@sparkjsdev/spark";

let renderer, scene, camera, controls, splat, container;
let startTime, timerInterval;

window.initSpark = function() {
    container = document.getElementById('spark-container');
    if (!container || window.sparkInitialized) return;
    scene = new THREE.Scene();
    camera = new THREE.PerspectiveCamera(60, container.clientWidth / container.clientHeight, 0.1, 1000);
    camera.position.set(0, 1, 4);
    renderer = new THREE.WebGLRenderer({ antialias: true, alpha: true, logarithmicDepthBuffer: true });
    renderer.setSize(container.clientWidth, container.clientHeight);
    renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
    container.appendChild(renderer.domElement);
    controls = new OrbitControls(camera, renderer.domElement);
    controls.enableDamping = true;
    function animate() {
        requestAnimationFrame(animate);
        controls.update();
        renderer.render(scene, camera);
    }
    animate();
    window.addEventListener('resize', () => {
        if (!container) return;
        camera.aspect = container.clientWidth / container.clientHeight;
        camera.updateProjectionMatrix();
        renderer.setSize(container.clientWidth, container.clientHeight);
    });
    window.sparkInitialized = true;
};

window.loadSplat = async function(url) {
    if (!window.sparkInitialized) window.initSpark();
    if (splat) { scene.remove(splat); splat.dispose(); }
    try {
        splat = new SplatMesh({ url: url });
        splat.rotation.x = Math.PI;
        scene.add(splat);
        setTimeout(window.focusModel, 500);
    } catch (e) { console.error(e); }
};

window.focusModel = function() {
    if (!splat || !controls || !camera) return;
    const box = new THREE.Box3();
    let pointsFound = 0;
    splat.traverse((obj) => {
        if (obj.geometry && obj.geometry.attributes.position) {
            const pos = obj.geometry.attributes.position;
            const count = pos.count;
            const step = Math.max(1, Math.floor(count / 5000));
            for (let i = 0; i < count; i += step) {
                const p = new THREE.Vector3(pos.getX(i), pos.getY(i), pos.getZ(i));
                p.applyMatrix4(obj.matrixWorld);
                box.expandByPoint(p);
            }
            pointsFound += count;
        }
    });
    let center = new THREE.Vector3();
    let size = new THREE.Vector3();
    if (pointsFound === 0 || box.isEmpty()) { center.set(0, 1.5, -3); size.set(2, 2, 2); }
    else { box.getCenter(center); box.getSize(size); }
    const maxDim = Math.max(size.x, size.y, size.z);
    const fovRad = camera.fov * (Math.PI / 180);
    let distance = (maxDim / 2) / Math.tan(fovRad / 2) * 1.5;
    controls.target.copy(center);
    camera.position.set(center.x, center.y, center.z + distance);
    controls.update();
};

function getBtn() {
    return document.getElementById('generate-btn') || document.querySelector('#generate-btn button');
}

window.startTimer = function() {
    const btn = getBtn();
    if (!btn) return;
    // Never leave an earlier interval running: it would keep rewriting the label.
    if (timerInterval) { clearInterval(timerInterval); timerInterval = null; }
    btn.disabled = true;
    btn.style.opacity = "0.6";
    btn.style.cursor = "wait";
    startTime = Date.now();
    timerInterval = setInterval(() => {
        const elapsed = ((Date.now() - startTime) / 1000).toFixed(1);
        btn.innerText = `Generating... ${elapsed}s`;
    }, 100);
};

window.stopTimer = function() {
    if (timerInterval) {
        clearInterval(timerInterval);
        // Forget the handle so a repeated call does not rewrite the final time.
        timerInterval = null;
        const elapsed = ((Date.now() - startTime) / 1000).toFixed(1);
        const btn = getBtn();
        if (btn) btn.innerText = `Done in ${elapsed}s`;
    }
};

window.resetBtn = function(hasImage) {
    const btn = getBtn();
    if (btn) {
        btn.disabled = !hasImage;
        btn.style.opacity = hasImage ? "1.0" : "0.5";
        btn.style.cursor = hasImage ? "pointer" : "default";
        btn.innerText = "Generate 3D Gaussians";
    }
    if (timerInterval) { clearInterval(timerInterval); timerInterval = null; }
};
</script>
"""
|
|
if __name__ == "__main__":
    # Command-line options: startup model, bind address/port, optional TLS files.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="ml-sharp_int4.onnx")
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--ssl_cert", type=str, default="cert.pem")
    parser.add_argument("--ssl_key", type=str, default="key.pem")
    args = parser.parse_args()

    # Decide on ONE startup model and use it both for the loaded session and
    # for the dropdown's initial value, so the UI never shows a model that is
    # not the one actually loaded.
    available_models = get_available_models()
    requested_model = str(Path(args.model))
    if Path(requested_model).is_file():
        initial_model = requested_model
        if initial_model not in available_models:
            # e.g. a model outside the working directory
            available_models.insert(0, initial_model)
    elif available_models:
        initial_model = available_models[0]
        LOGGER.warning("Model %s not found; falling back to %s", args.model, initial_model)
    else:
        initial_model = None
        LOGGER.warning("Model %s not found and no *.onnx file is available", args.model)

    if initial_model is not None:
        load_model(initial_model)

    # TLS needs both files; passing only one of them to launch() is a
    # misconfiguration, so fall back to plain HTTP unless both exist.
    cert_exists = os.path.exists(args.ssl_cert)
    key_exists = os.path.exists(args.ssl_key)
    use_ssl = cert_exists and key_exists
    if cert_exists != key_exists:
        LOGGER.warning("Only one of --ssl_cert/--ssl_key exists; starting without SSL")

    # Shared JS hooks for the two events that start a generation.
    start_js = "(img, op, down) => { window.startTimer(); return [img, op, down]; }"
    # When generation produced no file, unlock the button and stop the timer
    # so the user can retry; the success path is handled by output_file.change.
    recover_js = "(f) => { if (!f) { window.resetBtn(true); } }"

    with gr.Blocks(css=custom_css, theme=gr.themes.Default(), head=head_content, title="SHARP 3D Recon") as demo:
        gr.HTML("<header><h1>SHARP 3D RECONSTRUCTION</h1></header>")
        with gr.Row():
            # Left column: model choice, image input, settings, trigger button.
            with gr.Column(scale=1):
                with gr.Group(elem_classes="panel-box"):
                    model_selector = gr.Dropdown(
                        choices=available_models,
                        value=initial_model,
                        label="Select ONNX Model (Precision)",
                        interactive=True
                    )

                    input_image = gr.Image(
                        type="filepath",
                        label="Capture or Upload Image",
                        height=400,
                        sources=["upload", "webcam"]
                    )

                    with gr.Accordion("Advanced Settings", open=False):
                        opacity_val = gr.Slider(0.0, 1.0, value=0.0, label="Opacity Threshold")
                        downsample_val = gr.Slider(1, 10, step=1, value=1, label="Downsample Rate")

                    submit_btn = gr.Button("Generate 3D Gaussians", variant="primary", elem_id="generate-btn", interactive=False)

                    gr.Markdown("Capture a photo from your phone or upload an image to start the real-time 3D conversion.")

            # Right column: the viewer container plus a hidden file output
            # whose value carries the generated .ply to the browser.
            with gr.Column(scale=2):
                with gr.Group(elem_classes="panel-box"):
                    gr.HTML("<div id='spark-container'></div>")
                    output_file = gr.File(label="Output Model", visible=False)

        # Look the helpers up when they are called rather than when the page
        # load event fires, and tolerate them not being defined yet.
        demo.load(
            fn=None,
            inputs=None,
            outputs=None,
            js="() => { setTimeout(() => { if (window.initSpark) window.initSpark(); }, 500); if (window.resetBtn) window.resetBtn(false); }"
        )

        model_selector.change(fn=load_model, inputs=[model_selector], outputs=None)

        # Purely client-side: only the button state changes, so no Python
        # callback is needed.
        input_image.change(fn=None, inputs=[input_image], outputs=None, js="(img) => { window.resetBtn(!!img); }")

        # Generation starts on an explicit click and right after an upload.
        for start_event in (submit_btn.click, input_image.upload):
            start_event(
                fn=process_image,
                inputs=[input_image, opacity_val, downsample_val],
                outputs=[output_file],
                js=start_js
            ).then(
                fn=None,
                inputs=[output_file],
                outputs=None,
                js=recover_js
            )

        output_file.change(
            fn=None,
            inputs=[output_file],
            js="(f) => { window.stopTimer(); if (f && f.url) { window.loadSplat(f.url); } }"
        )

    demo.queue().launch(
        server_name=args.host,
        server_port=args.port,
        share=False,
        ssl_certfile=args.ssl_cert if use_ssl else None,
        ssl_keyfile=args.ssl_key if use_ssl else None,
        ssl_verify=False
    )
|
|