# sharp-onnx-int8 / webui.py
# NOTE: removed Hugging Face page-header residue ("Olsc's picture / 更新代码 /
# 04b19b4 verified") — it is not valid Python and broke the module at import.
import warnings
import gradio as gr
import onnxruntime as ort
import torch
import torch.nn.functional as F
import numpy as np
import tempfile
import os
import time
import logging
import argparse
import sys
from pathlib import Path
# Make the vendored `src/` package importable before pulling in `sharp.*`.
sys.path.append(os.path.join(os.getcwd(), 'src'))
from sharp.utils import io
from sharp.utils.gaussians import Gaussians3D, save_ply, unproject_gaussians

warnings.filterwarnings("ignore")  # silence noisy library warnings in the UI log
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger(__name__)

# Module-level cache: at most one ONNX Runtime session is kept alive at a
# time, identified by the path it was loaded from (see load_model()).
SESSION = None
CURRENT_MODEL_PATH = None
def load_model(model_path):
    """Load an ONNX Runtime CPU session for *model_path*, caching it globally.

    Reuses the cached ``SESSION`` when *model_path* is already loaded.

    Args:
        model_path: Filesystem path of the ``.onnx`` model to load.

    Returns:
        The ``ort.InferenceSession`` on success, or ``None`` when the file
        is missing or session creation fails (the error is logged).
    """
    global SESSION, CURRENT_MODEL_PATH
    if SESSION is not None and CURRENT_MODEL_PATH == model_path:
        return SESSION
    if not Path(model_path).exists():
        LOGGER.error(f"Model file not found: {model_path}")
        return None
    try:
        LOGGER.info(f"Loading model: {model_path}...")
        options = ort.SessionOptions()
        options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        options.intra_op_num_threads = os.cpu_count() or 4
        options.inter_op_num_threads = min(4, os.cpu_count() or 4)
        options.enable_mem_pattern = True
        options.enable_cpu_mem_arena = True
        providers = ['CPUExecutionProvider']
        # Drop the reference to the old session so its memory can be
        # reclaimed.  Rebind to None rather than `del SESSION`: deleting the
        # global would leave the name unbound if InferenceSession() raises,
        # making every later `SESSION is None` check crash with NameError.
        SESSION = None
        CURRENT_MODEL_PATH = None
        SESSION = ort.InferenceSession(model_path, sess_options=options, providers=providers)
        CURRENT_MODEL_PATH = model_path
        LOGGER.info(f"Model loaded successfully: {model_path}")
        return SESSION
    except Exception as e:
        LOGGER.error(f"Failed to load model {model_path}: {e}")
        return None
def get_available_models():
    """Return the paths (as strings) of every ``*.onnx`` file in the CWD."""
    return [str(model_file) for model_file in Path('.').glob('*.onnx')]
def process_image(image_filepath, opacity_threshold, downsample_rate):
    """Convert one RGB image into a filtered 3D Gaussian splat ``.ply`` file.

    Args:
        image_filepath: Path of the uploaded/captured image; falsy skips.
        opacity_threshold: Gaussians with opacity <= this value are dropped.
        downsample_rate: Keep every N-th Gaussian after opacity filtering.

    Returns:
        Path (str) of ``output.ply`` inside a fresh temp directory, or
        ``None`` when there is no input image or no model is loaded.
    """
    if not image_filepath:
        return None
    if SESSION is None:
        gr.Warning("Model not loaded. Using dummy processing or check console.")
        return None
    # NOTE(review): start_time is never read again in this function — the
    # elapsed time shown to the user comes from the client-side JS timer.
    start_time = time.perf_counter()
    # f_px is the focal length in pixels as reported by the loader.
    img, _, f_px = io.load_rgb(Path(image_filepath), auto_rotate=True, remove_alpha=True)
    height, width = img.shape[:2]
    # HWC uint8 image -> CHW float tensor in [0, 1].
    image_pt = torch.from_numpy(img.copy()).float().permute(2, 0, 1) / 255.0
    disparity_factor = torch.tensor([f_px / width]).float()
    # Fixed network input resolution.  Presumably (width, height) — TODO
    # confirm the order; it is square here so it does not matter in practice.
    internal_shape = (1536, 1536)
    image_resized_pt = F.interpolate(
        image_pt[None], size=(internal_shape[1], internal_shape[0]), mode="bilinear", align_corners=True
    )
    # Match the exported model's input precision (fp16 exports expect half).
    model_inputs = SESSION.get_inputs()
    if model_inputs[0].type == 'tensor(float16)':
        image_resized_pt = image_resized_pt.half()
        disparity_factor = disparity_factor.half()
    inputs = {'image': image_resized_pt.numpy(), 'disparity_factor': disparity_factor.numpy()}
    outputs = SESSION.run(None, inputs)
    # Assumes the model emits outputs in the order (means, singular values,
    # quaternions, colors, opacities) in NDC space — TODO confirm against
    # the export script; the Gaussians3D field order below relies on it.
    gaussians_ndc = Gaussians3D(
        mean_vectors=torch.from_numpy(outputs[0]).float(),
        singular_values=torch.from_numpy(outputs[1]).float(),
        quaternions=torch.from_numpy(outputs[2]).float(),
        colors=torch.from_numpy(outputs[3]).float(),
        opacities=torch.from_numpy(outputs[4]).float()
    )
    # Boolean mask over the Gaussians of batch element 0, then a strided
    # subsample of the survivors.
    mask = gaussians_ndc.opacities[0] > opacity_threshold
    sampler = slice(0, None, int(downsample_rate))
    def apply_mask_and_sampling(tensor):
        # Index along the per-Gaussian axis, keeping the batch axis intact.
        return tensor[:, mask][:, sampler]
    filtered_gaussians_ndc = Gaussians3D(
        mean_vectors=apply_mask_and_sampling(gaussians_ndc.mean_vectors),
        singular_values=apply_mask_and_sampling(gaussians_ndc.singular_values),
        quaternions=apply_mask_and_sampling(gaussians_ndc.quaternions),
        colors=apply_mask_and_sampling(gaussians_ndc.colors),
        opacities=apply_mask_and_sampling(gaussians_ndc.opacities)
    )
    # 4x4 pinhole intrinsics for the ORIGINAL image, principal point at the
    # image center.
    intrinsics = torch.tensor([
        [f_px, 0, width / 2, 0],
        [0, f_px, height / 2, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ]).float()
    # Rescale focal/principal rows to the network's internal resolution.
    intrinsics_resized = intrinsics.clone()
    intrinsics_resized[0] *= internal_shape[0] / width
    intrinsics_resized[1] *= internal_shape[1] / height
    # Lift NDC Gaussians to 3D using identity extrinsics (camera at origin).
    gaussians = unproject_gaussians(filtered_gaussians_ndc, torch.eye(4), intrinsics_resized, internal_shape)
    # A throwaway temp dir per request so concurrent users never collide.
    out_dir = Path(tempfile.mkdtemp())
    out_path = out_dir / "output.ply"
    save_ply(gaussians, f_px, (height, width), out_path)
    return str(out_path)
# Dark "glassmorphism" theme for the Gradio UI, with a mobile breakpoint at
# 768px (stacked columns, shorter viewer, full-width inputs).  Injected via
# gr.Blocks(css=...).
custom_css = """
body, .gradio-container {
background: radial-gradient(circle at top left, #0d0d12 0%, #000000 100%) !important;
color: #e0e0e0 !important;
font-family: 'Inter', system-ui, -apple-system, sans-serif !important;
margin: 0 !important;
padding: 0 !important;
}
.panel-box {
background: rgba(20, 20, 25, 0.8) !important;
backdrop-filter: blur(10px);
border: 1px solid rgba(255, 255, 255, 0.1) !important;
border-radius: 20px !important;
padding: 24px;
box-shadow: 0 8px 32px 0 rgba(0, 0, 0, 0.8);
transition: all 0.3s ease;
margin-bottom: 16px;
}
#spark-container {
width: 100%;
height: 70vh; /* Responsive height */
min-height: 400px;
max-height: 720px;
background: #000;
border-radius: 12px;
border: 1px solid rgba(255, 255, 255, 0.1);
position: relative;
overflow: hidden;
}
#generate-btn {
background: linear-gradient(135deg, #6366f1 0%, #a855f7 100%) !important;
color: white !important;
font-weight: 700 !important;
border-radius: 12px !important;
border: none !important;
margin-top: 10px;
padding: 16px 24px !important; /* Larger for touch */
text-transform: uppercase;
letter-spacing: 1px;
font-size: 1.1rem !important;
transition: transform 0.2s, box-shadow 0.2s !important;
}
header h1 {
background: linear-gradient(to right, #fff, #a5a5a5);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-size: 2rem !important;
font-weight: 900 !important;
text-align: center;
margin: 20px 0 !important;
}
/* Mobile Optimizations */
@media (max-width: 768px) {
.panel-box {
padding: 16px;
border-radius: 16px !important;
}
#spark-container {
height: 50vh; /* Shorter on mobile to leave room for controls */
min-height: 300px;
}
header h1 {
font-size: 1.5rem !important;
}
.gr-row {
flex-direction: column !important;
}
/* Make inputs full width on mobile */
.gr-form {
width: 100% !important;
}
}
"""
# Extra <head> markup injected via gr.Blocks(head=...): viewport/meta + fonts,
# an import map for three.js + Spark, and the client-side viewer logic
# (window.initSpark / loadSplat / focusModel) plus the generate-button timer
# helpers (startTimer / stopTimer / resetBtn) that the Gradio `js=` hooks call.
head_content = """
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;700;900&display=swap" rel="stylesheet">
<script type="importmap">
{
"imports": {
"three": "https://cdnjs.cloudflare.com/ajax/libs/three.js/0.178.0/three.module.js",
"@sparkjsdev/spark": "https://sparkjs.dev/releases/spark/0.1.10/spark.module.js"
}
}
</script>
<script type="module">
import * as THREE from "three";
import { OrbitControls } from "https://unpkg.com/three@0.178.0/examples/jsm/controls/OrbitControls.js";
import { SplatMesh } from "@sparkjsdev/spark";
let renderer, scene, camera, controls, splat, container;
let startTime, timerInterval;
window.initSpark = function() {
container = document.getElementById('spark-container');
if (!container || window.sparkInitialized) return;
scene = new THREE.Scene();
camera = new THREE.PerspectiveCamera(60, container.clientWidth / container.clientHeight, 0.1, 1000);
camera.position.set(0, 1, 4);
renderer = new THREE.WebGLRenderer({ antialias: true, alpha: true, logarithmicDepthBuffer: true });
renderer.setSize(container.clientWidth, container.clientHeight);
renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
container.appendChild(renderer.domElement);
controls = new OrbitControls(camera, renderer.domElement);
controls.enableDamping = true;
function animate() {
requestAnimationFrame(animate);
controls.update();
renderer.render(scene, camera);
}
animate();
window.addEventListener('resize', () => {
if (!container) return;
camera.aspect = container.clientWidth / container.clientHeight;
camera.updateProjectionMatrix();
renderer.setSize(container.clientWidth, container.clientHeight);
});
window.sparkInitialized = true;
};
window.loadSplat = async function(url) {
if (!window.sparkInitialized) window.initSpark();
if (splat) { scene.remove(splat); splat.dispose(); }
try {
splat = new SplatMesh({ url: url });
splat.rotation.x = Math.PI;
scene.add(splat);
setTimeout(window.focusModel, 500);
} catch (e) { console.error(e); }
};
window.focusModel = function() {
if (!splat || !controls || !camera) return;
const box = new THREE.Box3();
let pointsFound = 0;
splat.traverse((obj) => {
if (obj.geometry && obj.geometry.attributes.position) {
const pos = obj.geometry.attributes.position;
const count = pos.count;
const step = Math.max(1, Math.floor(count / 5000));
for (let i = 0; i < count; i += step) {
const p = new THREE.Vector3(pos.getX(i), pos.getY(i), pos.getZ(i));
p.applyMatrix4(obj.matrixWorld);
box.expandByPoint(p);
}
pointsFound += count;
}
});
let center = new THREE.Vector3();
let size = new THREE.Vector3();
if (pointsFound === 0 || box.isEmpty()) { center.set(0, 1.5, -3); size.set(2, 2, 2); }
else { box.getCenter(center); box.getSize(size); }
const maxDim = Math.max(size.x, size.y, size.z);
const fovRad = camera.fov * (Math.PI / 180);
let distance = (maxDim / 2) / Math.tan(fovRad / 2) * 1.5;
controls.target.copy(center);
camera.position.set(center.x, center.y, center.z + distance);
controls.update();
};
function getBtn() {
return document.getElementById('generate-btn') || document.querySelector('#generate-btn button');
}
window.startTimer = function() {
const btn = getBtn();
if (!btn) return;
btn.disabled = true;
btn.style.opacity = "0.6";
btn.style.cursor = "wait";
startTime = Date.now();
timerInterval = setInterval(() => {
const elapsed = ((Date.now() - startTime) / 1000).toFixed(1);
btn.innerText = `Generating... ${elapsed}s`;
}, 100);
};
window.stopTimer = function() {
if (timerInterval) {
clearInterval(timerInterval);
const elapsed = ((Date.now() - startTime) / 1000).toFixed(1);
const btn = getBtn();
if (btn) btn.innerText = `Done in ${elapsed}s`;
}
};
window.resetBtn = function(hasImage) {
const btn = getBtn();
if (btn) {
btn.disabled = !hasImage;
btn.style.opacity = hasImage ? "1.0" : "0.5";
btn.style.cursor = hasImage ? "pointer" : "default";
btn.innerText = "Generate 3D Gaussians";
}
if (timerInterval) clearInterval(timerInterval);
};
</script>
"""
if __name__ == "__main__":
    # ---- CLI arguments -------------------------------------------------
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--model", type=str, default="ml-sharp_int4.onnx")
    arg_parser.add_argument("--host", type=str, default="0.0.0.0")
    arg_parser.add_argument("--port", type=int, default=7860)
    arg_parser.add_argument("--ssl_cert", type=str, default="cert.pem")
    arg_parser.add_argument("--ssl_key", type=str, default="key.pem")
    opts = arg_parser.parse_args()

    # Warm up the requested ONNX session before serving the UI.
    load_model(opts.model)

    # Dropdown choices: prefer the CLI model when present, else first found.
    model_files = get_available_models()
    default_model = (
        opts.model
        if opts.model in model_files
        else (model_files[0] if model_files else None)
    )

    # ---- UI layout -----------------------------------------------------
    with gr.Blocks(css=custom_css, theme=gr.themes.Default(), head=head_content, title="SHARP 3D Recon") as app:
        gr.HTML("<header><h1>SHARP 3D RECONSTRUCTION</h1></header>")
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group(elem_classes="panel-box"):
                    model_dropdown = gr.Dropdown(
                        choices=model_files,
                        value=default_model,
                        label="Select ONNX Model (Precision)",
                        interactive=True
                    )
                    image_input = gr.Image(
                        type="filepath",
                        label="Capture or Upload Image",
                        height=400,
                        sources=["upload", "webcam"]
                    )
                    with gr.Accordion("Advanced Settings", open=False):
                        opacity_slider = gr.Slider(0.0, 1.0, value=0.0, label="Opacity Threshold")
                        downsample_slider = gr.Slider(1, 10, step=1, value=1, label="Downsample Rate")
                    generate_button = gr.Button("Generate 3D Gaussians", variant="primary", elem_id="generate-btn", interactive=False)
                    gr.Markdown("Capture a photo from your phone or upload an image to start the real-time 3D conversion.")
            with gr.Column(scale=2):
                with gr.Group(elem_classes="panel-box"):
                    gr.HTML("<div id='spark-container'></div>")
                    result_file = gr.File(label="Output Model", visible=False)

        # ---- Event wiring (JS hooks drive the Spark viewer + timer) ----
        app.load(fn=None, inputs=None, outputs=None, js="() => { setTimeout(window.initSpark, 500); window.resetBtn(false); }")
        model_dropdown.change(fn=load_model, inputs=[model_dropdown], outputs=None)
        image_input.change(fn=lambda x: x is not None, inputs=[image_input], outputs=None, js="(img) => { window.resetBtn(!!img); }")
        generate_button.click(
            fn=process_image,
            inputs=[image_input, opacity_slider, downsample_slider],
            outputs=[result_file],
            js="(img, op, down) => { window.startTimer(); return [img, op, down]; }"
        )
        image_input.upload(
            fn=process_image,
            inputs=[image_input, opacity_slider, downsample_slider],
            outputs=[result_file],
            js="(img, op, down) => { window.startTimer(); return [img, op, down]; }"
        )
        result_file.change(
            fn=None,
            inputs=[result_file],
            js="(f) => { window.stopTimer(); if (f && f.url) { window.loadSplat(f.url); } }"
        )

    # ---- Launch (HTTPS only when the cert/key files actually exist) ----
    cert_path = opts.ssl_cert if os.path.exists(opts.ssl_cert) else None
    key_path = opts.ssl_key if os.path.exists(opts.ssl_key) else None
    app.queue().launch(
        server_name=opts.host,
        server_port=opts.port,
        share=False,
        ssl_certfile=cert_path,
        ssl_keyfile=key_path,
        ssl_verify=False
    )