# FlashImgs Hugging Face Space # Copyright (C) 2025 OpsiClear # # This program (the Gradio wrapper) is free software: you can redistribute it # and/or modify it under the terms of the GNU Affero General Public License v3.0 # as published by the Free Software Foundation. See the LICENSE file for the full # text. It is distributed WITHOUT ANY WARRANTY. # # The bundled FlashImgs engine wheels (wheels/) are a separate work under the # OpsiClear Restrictive License and are NOT covered by the AGPL; see NOTICE. from __future__ import annotations import gc import math import os import shutil import struct import time import uuid from pathlib import Path import gradio as gr import numpy as np import torch from PIL import Image, ImageOps try: import spaces # Hugging Face ZeroGPU runtime except Exception: # effect-free fallback when not running on ZeroGPU class _SpacesStub: @staticmethod def GPU(*args, **kwargs): def _decorator(func): return func if args and callable(args[0]): return args[0] return _decorator spaces = _SpacesStub() os.environ.setdefault("FLASHIMGS_GAUSSIFIER_AOTI", "0") OUTPUT_ROOT = Path(os.environ.get("FLASHIMGS_OUTPUT_DIR", "/tmp/flashimgs_hf_outputs")) OUTPUT_ROOT.mkdir(parents=True, exist_ok=True) APP_CSS = """ #workspace { max-width: 1440px; margin: 0 auto; gap: 16px; align-items: flex-start; } .settings-panel { flex: 0 0 330px !important; max-width: 350px !important; min-width: 300px !important; } .results-panel { flex: 1 1 0 !important; min-width: min(760px, 100%) !important; } .preview-tabs, .preview-tabs > div { min-width: 0; } .download-row { gap: 12px; } @media (max-width: 920px) { #workspace { display: block; } .settings-panel, .results-panel { max-width: none !important; min-width: 0 !important; flex-basis: auto !important; } } """ def _load_flashimgs(): import flashimgs return flashimgs def _device_name() -> str: if not torch.cuda.is_available(): return "CPU" return torch.cuda.get_device_name(torch.cuda.current_device()) def _prepare_image(image: Image.Image, max_side: int) -> tuple[torch.Tensor, Image.Image, torch.Tensor | None, str]: if image is None: raise gr.Error("Upload an image first.") image = ImageOps.exif_transpose(image) original_w, original_h = image.size has_alpha = image.mode in {"RGBA", "LA"} or (image.mode == "P" and "transparency" in image.info) image_rgba = image.convert("RGBA") max_side = int(max_side) if max_side > 0 and max(original_w, original_h) > max_side: scale = max_side / float(max(original_w, original_h)) new_size = (max(1, round(original_w * scale)), max(1, round(original_h * scale))) image_rgba = image_rgba.resize(new_size, Image.Resampling.LANCZOS) rgba = np.asarray(image_rgba, dtype=np.uint8) rgb = rgba[..., :3].astype(np.float32) / 255.0 image_rgb = Image.fromarray(rgba[..., :3], mode="RGB") image_tensor = torch.from_numpy(rgb).permute(2, 0, 1).contiguous() mask_tensor = None if has_alpha: alpha = rgba[..., 3] if alpha.min() < 255: mask_array = (alpha > 127).astype(np.float32) if np.any(mask_array): mask_tensor = torch.from_numpy(mask_array).unsqueeze(0).contiguous() size_note = f"{original_w}x{original_h}" if image_rgba.size != (original_w, original_h): size_note += f" -> {image_rgba.size[0]}x{image_rgba.size[1]}" return image_tensor, image_rgb, mask_tensor, size_note def _tensor_to_pil(tensor: torch.Tensor) -> Image.Image: array = ( tensor.detach() .clamp(0.0, 1.0) .mul(255.0) .byte() .permute(1, 2, 0) .cpu() .numpy() ) return Image.fromarray(array, mode="RGB") def _prune_old_outputs(keep: int = 20) -> None: runs = [p for p in OUTPUT_ROOT.iterdir() if p.is_dir()] runs.sort(key=lambda p: p.stat().st_mtime, reverse=True) for path in runs[keep:]: shutil.rmtree(path, ignore_errors=True) def _load_splat2d(path: Path): with path.open("rb") as f: magic = f.read(4) if magic != b"GS2D": raise gr.Error("The cached Splat2D file is not a FlashImgs GS2D file.") header = f.read(8) if len(header) != 8: raise gr.Error("The cached Splat2D file is truncated.") n_gaussians, height, width = struct.unpack(" np.ndarray: import cv2 mask_u8 = (mask_array > 0).astype(np.uint8) padded = np.pad(mask_u8, 1, mode="constant", constant_values=0) dist = cv2.distanceTransform(padded, cv2.DIST_L2, cv2.DIST_MASK_PRECISE) return dist[1:-1, 1:-1].astype(np.float32, copy=False) def _write_ply_from_splat_arrays( path: Path, xy: np.ndarray, scale_px: np.ndarray, rot: np.ndarray, feat: np.ndarray, height: int, width: int, *, mask_array: np.ndarray | None = None, mask_edge_clamp: float = 0.35, opacity_scale: float, min_opacity: float, max_opacity: float, scale_multiplier: float, thickness: float, y_axis: str, ) -> None: n_gaussians = int(xy.shape[0]) height_f = max(float(height), 1.0) width_f = max(float(width), 1.0) aspect = width_f / height_f world_height = 1.0 thickness = max(float(thickness), 1e-8) opacity_mode = "factor_rgb" opacity_scale = max(float(opacity_scale), 1e-6) min_opacity = min(1.0 - 1e-6, max(1e-6, float(min_opacity))) max_opacity = min(1.0 - 1e-6, max(min_opacity, float(max_opacity))) scale_multiplier = max(float(scale_multiplier), 1e-6) y_axis = str(y_axis).lower().strip() if y_axis not in {"down", "up"}: y_axis = "down" if mask_array is not None and float(mask_edge_clamp) > 0: dist = _mask_edge_distance(mask_array) x_px = np.clip(np.floor(xy[:, 0] * float(width)).astype(np.int64), 0, int(width) - 1) y_px = np.clip(np.floor(xy[:, 1] * float(height)).astype(np.int64), 0, int(height) - 1) max_scale = np.maximum(0.25, dist[y_px, x_px] * float(mask_edge_clamp)).astype(np.float32) scale_px = np.minimum(scale_px, max_scale[:, None]).astype(np.float32, copy=False) if feat.shape[1] == 1: rgb = np.repeat(feat[:, :1], 3, axis=1) elif feat.shape[1] < 3: rgb = np.pad(feat, ((0, 0), (0, 3 - feat.shape[1])), mode="constant") else: rgb = feat[:, :3] rgb = np.clip(rgb, 0.0, 1.0).astype(np.float32, copy=False) amp = np.max(rgb, axis=1).astype(np.float32, copy=False) alpha = np.clip(opacity_scale * amp, min_opacity, max_opacity).astype(np.float32) ply_rgb = np.clip(rgb / alpha[:, None], 0.0, 1.0).astype(np.float32, copy=False) alpha = np.clip(alpha, 1e-6, 1.0 - 1e-6).astype(np.float32, copy=False) opacity_logit = np.log(alpha / (1.0 - alpha)).astype(np.float32, copy=False) f_dc = (ply_rgb - 0.5) / 0.28209479177387814 names = ["x", "y", "z", "nx", "ny", "nz", "f_dc_0", "f_dc_1", "f_dc_2"] names += [f"f_rest_{i}" for i in range(45)] names += ["opacity", "scale_0", "scale_1", "scale_2", "rot_0", "rot_1", "rot_2", "rot_3"] names += ["fi_u", "fi_v", "fi_scale_x_px", "fi_scale_y_px", "fi_rot", "fi_r", "fi_g", "fi_b"] vertices = np.zeros(n_gaussians, dtype=np.dtype([(name, " None: nonlocal last_update if completed_step < total_steps and completed_step - last_update < update_every: return last_update = int(completed_step) frac = min(1.0, max(0.0, completed_step / total_steps)) progress( 0.10 + 0.78 * frac, desc=f"Fitting splats: {completed_step:,}/{total_steps:,} steps", ) elapsed_s, _, _ = _main.train_loop( session.model, total_steps, verbose=False, snapshot_hook=on_step, ) on_step(total_steps, session.model) return session, 0.0, 0.0, float(elapsed_s) def _fit_duration(*args, **kwargs) -> float: """Seconds of ZeroGPU time to reserve for a fit, estimated from its arguments. Mirrors ``fit_image``'s positional signature ``(image, max_side, gaussians, steps, ...)``. Fits are fast, but reserve headroom for large images, high step counts, and first-call CUDA + extension warmup. """ def _arg(idx: int, default: float) -> float: try: return float(args[idx]) except (IndexError, TypeError, ValueError): return default max_side = _arg(1, 1536.0) gaussians = _arg(2, 0.0) steps = _arg(3, 900.0) est = 30.0 + 0.06 * steps + (gaussians / 100000.0) * 40.0 if max_side == 0.0 or max_side > 2048.0: est += 30.0 return float(min(180.0, max(45.0, est))) @spaces.GPU(duration=_fit_duration) def fit_image( image: Image.Image, max_side: int, gaussians: int, steps: int, lr: float, scale: float, loss: str, render_height: int, seed: int, mask_border_padding: bool, mask_edge_clamp: float, ply_opacity_scale: float, ply_min_opacity: float, ply_max_opacity: float, ply_scale_multiplier: float, ply_thickness: float, ply_y_axis: str, progress: gr.Progress = gr.Progress(track_tqdm=False), ): if not torch.cuda.is_available(): raise gr.Error( "FlashImgs fitting needs CUDA. In Space settings, select a GPU hardware tier." ) progress(0.02, desc="Preparing image") image_tensor, resized_image, mask_tensor, size_note = _prepare_image(image, max_side) fit_w, fit_h = resized_image.size mask_source = "alpha" if mask_tensor is not None else "none" if mask_tensor is None and bool(mask_border_padding): mask_tensor = torch.ones((1, fit_h, fit_w), dtype=torch.float32).contiguous() mask_source = "image border" run_dir = OUTPUT_ROOT / uuid.uuid4().hex run_dir.mkdir(parents=True, exist_ok=True) splat_path = run_dir / "flashimgs_fit.splat2d" ply_path = run_dir / "flashimgs_fit.ply" mask_path = run_dir / "mask.npy" flashimgs = _load_flashimgs() session = None psnr = 0.0 ssim = 0.0 train_elapsed_s = 0.0 start = time.time() try: progress(0.1, desc=f"Fitting splats: 0/{max(1, int(steps)):,} steps") session, psnr, ssim, train_elapsed_s = _fit_with_progress( flashimgs, image_tensor, gaussians=gaussians, steps=steps, lr=lr, scale=scale, loss=loss, seed=seed, mask_tensor=mask_tensor, mask_border_padding=mask_border_padding, progress=progress, ) if mask_tensor is not None: progress(0.86, desc="Pruning mask boundary") session.mask_prune() session.materialize_gate() psnr, ssim = session.evaluate() progress(0.88, desc="Rendering") render_at = None if int(render_height) <= 0 else int(render_height) reconstruction = _tensor_to_pil(session.render(height=render_at)) progress(0.94, desc="Exporting") session.export_splat2d(str(splat_path)) mask_state_path = None mask_array = None if mask_tensor is not None: mask_array = mask_tensor.squeeze(0).cpu().numpy().astype(bool) np.save(mask_path, mask_array.astype(np.uint8)) mask_state_path = str(mask_path) xy, scale_px, rot, feat, height, width = _load_splat2d(splat_path) _write_ply_from_splat_arrays( ply_path, xy, scale_px, rot, feat, height, width, mask_array=mask_array, mask_edge_clamp=float(mask_edge_clamp), opacity_scale=float(ply_opacity_scale), min_opacity=float(ply_min_opacity), max_opacity=float(ply_max_opacity), scale_multiplier=float(ply_scale_multiplier), thickness=float(ply_thickness), y_axis=ply_y_axis, ) total_s = time.time() - start n_gaussians = int(session.n_gaussians) metrics = "\n".join( [ f"Device: {_device_name()}", f"Input: {size_note}", f"Fit resolution: {fit_w}x{fit_h}", f"Mask: {mask_source}", f"Mask border padding: {'on' if mask_tensor is not None and mask_border_padding else 'off'}", f"Mask edge clamp: {float(mask_edge_clamp):.2f}x" if mask_tensor is not None else "Mask edge clamp: off", f"Gaussians: {n_gaussians:,}", f"Steps: {int(steps):,}", f"PSNR: {psnr:.2f} dB", f"SSIM: {ssim:.5f}", f"Training time: {train_elapsed_s:.2f} s", f"Total request time: {total_s:.2f} s", ] ) _prune_old_outputs() progress(1.0, desc="Done") return reconstruction, metrics, str(splat_path), str(ply_path), str(ply_path), str(splat_path), str(ply_path), mask_state_path finally: del session gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() with gr.Blocks(title="FlashImgs") as demo: gr.Markdown("# FlashImgs") splat_state = gr.State(value=None) ply_state = gr.State(value=None) mask_state = gr.State(value=None) with gr.Row(elem_id="workspace"): with gr.Column(scale=1, min_width=300, elem_classes=["settings-panel"]): image = gr.Image( label="Image", type="pil", sources=["upload", "clipboard"], image_mode="RGBA", height=240, ) run = gr.Button("Fit", variant="primary") with gr.Accordion("Fit", open=True): max_side = gr.Slider(0, 4096, value=1536, step=64, label="Max side (0 = original)") gaussians = gr.Slider(0, 100000, value=0, step=1000, label="Gaussians") steps = gr.Slider(100, 3200, value=900, step=100, label="Steps") loss = gr.Dropdown(["l2", "l1", "l2+ssim", "l1+ssim"], value="l2", label="Loss") mask_border_padding = gr.Checkbox(value=True, label="Treat image border as mask border") mask_edge_clamp = gr.Slider(0.0, 1.0, value=0.35, step=0.05, label="Mask edge scale clamp") with gr.Accordion("PLY", open=True): ply_opacity_scale = gr.Slider(0.1, 4.0, value=1.0, step=0.1, label="Opacity scale") ply_min_opacity = gr.Slider(0.001, 0.25, value=0.01, step=0.001, label="Min opacity") ply_max_opacity = gr.Slider(0.05, 0.99, value=0.95, step=0.01, label="Max opacity") ply_scale_multiplier = gr.Slider(0.25, 3.0, value=1.0, step=0.05, label="XY scale multiplier") ply_thickness = gr.Slider(0.00001, 0.002, value=0.0001, step=0.00001, label="Z thickness") ply_y_axis = gr.Dropdown(["down", "up"], value="down", label="PLY Y axis") update_ply = gr.Button("Update PLY Preview") with gr.Accordion("Advanced", open=False): lr = gr.Slider(1.0, 16.0, value=11.0, step=0.5, label="LR") scale = gr.Slider(0.5, 4.0, value=1.5, step=0.1, label="Scale") render_height = gr.Slider(0, 2048, value=0, step=128, label="Render height") seed = gr.Number(value=42, precision=0, label="Seed") with gr.Column(scale=3, min_width=520, elem_classes=["results-panel"]): with gr.Tabs(elem_classes=["preview-tabs"]): with gr.Tab("3DGS"): ply_preview = gr.Model3D( label="3DGS Preview", display_mode="point_cloud", clear_color=(0.03, 0.03, 0.03, 1.0), height=620, ) with gr.Tab("Image"): reconstruction = gr.Image(label="Reconstruction", type="pil", height=620) with gr.Accordion("Details and downloads", open=False): metrics = gr.Textbox(label="Metrics", lines=9) with gr.Row(elem_classes=["download-row"]): splat = gr.File(label="Splat2D") ply = gr.File(label="Edited 3DGS PLY") gr.Markdown( "FlashImgs Space · © 2025 OpsiClear · wrapper licensed " "[AGPL-3.0](https://www.gnu.org/licenses/agpl-3.0.html) — source available " "in this Space's **Files** tab. The bundled FlashImgs engine is separately " "licensed (non-commercial); see NOTICE. No warranty.", elem_id="app-footer", ) run.click( fn=fit_image, inputs=[ image, max_side, gaussians, steps, lr, scale, loss, render_height, seed, mask_border_padding, mask_edge_clamp, ply_opacity_scale, ply_min_opacity, ply_max_opacity, ply_scale_multiplier, ply_thickness, ply_y_axis, ], outputs=[reconstruction, metrics, splat, ply, ply_preview, splat_state, ply_state, mask_state], show_progress_on=[ply_preview], ) ply_update_inputs = [ splat_state, ply_state, mask_state, mask_edge_clamp, ply_opacity_scale, ply_min_opacity, ply_max_opacity, ply_scale_multiplier, ply_thickness, ply_y_axis, ] ply_update_outputs = [ply, ply_preview, ply_state] update_ply.click( fn=update_ply_preview, inputs=ply_update_inputs, outputs=ply_update_outputs, show_progress="minimal", show_progress_on=[ply_preview], trigger_mode="always_last", concurrency_limit=1, concurrency_id="ply_update", ) for control in [mask_edge_clamp, ply_opacity_scale, ply_min_opacity, ply_max_opacity, ply_scale_multiplier, ply_thickness]: control.input( fn=update_ply_preview, inputs=ply_update_inputs, outputs=ply_update_outputs, show_progress="hidden", trigger_mode="always_last", concurrency_limit=1, concurrency_id="ply_update", ) for control in [ply_y_axis]: control.change( fn=update_ply_preview, inputs=ply_update_inputs, outputs=ply_update_outputs, show_progress="hidden", trigger_mode="always_last", concurrency_limit=1, concurrency_id="ply_update", ) if __name__ == "__main__": demo.queue(default_concurrency_limit=1, max_size=8).launch( server_name="0.0.0.0", server_port=int(os.environ.get("PORT", "7860")), css=APP_CSS, )