Spaces:
Running on Zero
Running on Zero
| from __future__ import annotations | |
| import logging | |
| import os | |
| import re | |
| import time | |
| import traceback | |
| from pathlib import Path | |
| from uuid import uuid4 | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| import torch | |
| import torch.nn.functional as F | |
| from huggingface_hub import hf_hub_download | |
| from sharp.models import PredictorParams, RGBGaussianPredictor, create_predictor | |
| from sharp.utils import io | |
| from sharp.utils.gaussians import Gaussians3D, save_ply, unproject_gaussians | |
# Module-level logger; basicConfig sets the root logger level to INFO.
LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Hub repository and checkpoint file to load, and the directory generated PLY
# files are written to. Each value can be overridden by an environment variable.
WEIGHTS_REPO_ID = os.environ.get("SHARP_WEIGHTS_REPO_ID", "IdlecloudX/ml-sharp-weights")
CHECKPOINT_FILENAME = os.environ.get("SHARP_CHECKPOINT_FILENAME", "sharp_2572gikvuh.pt")
OUTPUT_DIR = Path(os.environ.get("SHARP_OUTPUT_DIR", "outputs"))

# Resolution the input image is resized to before it is passed to the predictor.
INTERNAL_SHAPE = (1536, 1536)
def get_runtime_device() -> torch.device:
    """Pick the device SHARP inference runs on.

    Returns:
        torch.device: ``cuda`` when the ``SPACE_ID`` environment variable is
        set to a non-empty value or CUDA is available; ``cpu`` otherwise
        (e.g. a local smoke test on a machine without CUDA).
    """
    # The SPACE_ID check comes first, so CUDA availability is only queried
    # when the variable is unset or empty.
    prefer_cuda = bool(os.getenv("SPACE_ID")) or torch.cuda.is_available()
    return torch.device("cuda" if prefer_cuda else "cpu")
# Device used for every inference call; resolved once at import time.
DEVICE = get_runtime_device()

# Make sure the output directory exists before any PLY file is written.
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
def sanitize_stem(stem: str) -> str:
    """Turn an uploaded file's stem into a safe output file-name prefix.

    Args:
        stem: Original file name with the extension removed.

    Returns:
        str: A prefix containing only ASCII letters, digits, dots,
        underscores and hyphens, at most 64 characters long. Falls back to
        ``"sharp_scene"`` when nothing usable remains.
    """
    # Collapse every run of disallowed characters into a single underscore.
    collapsed = re.sub(r"[^A-Za-z0-9._-]+", "_", stem)
    # Drop leading/trailing separators, then cap the length.
    trimmed = collapsed.strip("._-")[:64]
    if not trimmed:
        return "sharp_scene"
    return trimmed
def resolve_checkpoint_path() -> Path:
    """Resolve the local path of the SHARP checkpoint via the Hugging Face Hub.

    The repository and file name come from ``WEIGHTS_REPO_ID`` and
    ``CHECKPOINT_FILENAME``.

    Returns:
        Path: Local path returned by ``hf_hub_download`` for the checkpoint.
    """
    return Path(
        hf_hub_download(
            repo_id=WEIGHTS_REPO_ID,
            filename=CHECKPOINT_FILENAME,
            repo_type="model",
        )
    )
def load_predictor() -> RGBGaussianPredictor:
    """Load the Apple SHARP weights and build the Gaussian predictor.

    Returns:
        RGBGaussianPredictor: The predictor with the checkpoint loaded,
        switched to eval mode and moved to ``DEVICE``.
    """
    ckpt_path = resolve_checkpoint_path()
    LOGGER.info("Loading SHARP checkpoint from %s", ckpt_path)

    # Deserialize on CPU first so that downloading and unpickling the weights
    # does not occupy real ZeroGPU memory.
    weights = torch.load(ckpt_path, map_location="cpu", weights_only=True)

    model = create_predictor(PredictorParams())
    model.load_state_dict(weights)
    model.eval()

    # The ZeroGPU docs recommend moving the model to cuda at module load time
    # and letting the runtime take over the actual GPU allocation.
    model.to(DEVICE)
    return model
@torch.no_grad()
def predict_image(
    predictor: RGBGaussianPredictor,
    image: np.ndarray,
    f_px: float,
    device: torch.device,
) -> Gaussians3D:
    """Convert a single RGB image into a 3D Gaussian representation.

    Runs under ``torch.no_grad()``: this is inference only, so no autograd
    graph is recorded for the forward pass and the resulting tensors do not
    require grad.

    Args:
        predictor: SHARP Gaussian predictor with its weights loaded.
        image: RGB image array laid out as HxWx3 (channels last).
        f_px: Focal length in pixels, derived from EXIF or a default.
        device: Device on which the tensors are created and inference runs.

    Returns:
        Gaussians3D: The predictor output after ``unproject_gaussians`` has
        mapped it from NDC space back to metric space.
    """
    # HxWx3 values in 0..255 -> 3xHxW float tensor scaled to 0..1.
    image_pt = torch.from_numpy(image.copy()).float().to(device).permute(2, 0, 1) / 255.0
    _, height, width = image_pt.shape
    disparity_factor = torch.tensor([f_px / width], dtype=torch.float32, device=device)

    # The official SHARP implementation always feeds the network a fixed
    # 1536x1536 input.
    image_resized_pt = F.interpolate(
        image_pt[None],
        size=(INTERNAL_SHAPE[1], INTERNAL_SHAPE[0]),
        mode="bilinear",
        align_corners=True,
    )

    # The network output lives in NDC space; it is converted to metric space
    # below using the camera intrinsics.
    gaussians_ndc = predictor(image_resized_pt, disparity_factor)

    intrinsics = torch.tensor(
        [
            [f_px, 0, width / 2, 0],
            [0, f_px, height / 2, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ],
        dtype=torch.float32,
        device=device,
    )
    # Rescale the intrinsics rows to match the resized image.
    intrinsics_resized = intrinsics.clone()
    intrinsics_resized[0] *= INTERNAL_SHAPE[0] / width
    intrinsics_resized[1] *= INTERNAL_SHAPE[1] / height

    # Same as the upstream CLI: transform the NDC Gaussians into metric 3D
    # space before export.
    return unproject_gaussians(
        gaussians_ndc,
        torch.eye(4, device=device),
        intrinsics_resized,
        INTERNAL_SHAPE,
    )
def save_uploaded_image_as_ply(image_path: str, predictor: RGBGaussianPredictor) -> tuple[Path, float]:
    """Load an uploaded image, run SHARP on it and write a 3DGS PLY file.

    Args:
        image_path: Local temporary path of the image uploaded through Gradio.
        predictor: SHARP Gaussian predictor with its weights loaded.

    Returns:
        tuple[Path, float]: Path of the written PLY file and the elapsed
        processing time in seconds (loading, inference and saving).
    """
    started_at = time.perf_counter()
    source = Path(image_path)

    # io.load_rgb takes care of EXIF orientation, HEIC input and the default
    # focal-length fallback when EXIF carries no focal length.
    rgb, _, focal_px = io.load_rgb(source)
    rows, cols = rgb.shape[:2]

    gaussians = predict_image(predictor, rgb, focal_px, DEVICE)

    # A random suffix keeps uploads with the same name from overwriting
    # each other.
    ply_path = OUTPUT_DIR / f"{sanitize_stem(source.stem)}_{uuid4().hex[:10]}.ply"

    # The file format follows Apple SHARP: vertex attributes plus intrinsics,
    # image size and colour-space metadata.
    save_ply(gaussians, focal_px, (rows, cols), ply_path)

    return ply_path, time.perf_counter() - started_at
# Result of the import-time model load. PREDICTOR stays None when loading is
# skipped or fails; MODEL_LOAD_ERROR then holds the traceback text, if any.
MODEL_LOAD_ERROR: str | None = None
PREDICTOR: RGBGaussianPredictor | None = None

try:
    if os.getenv("SHARP_SKIP_MODEL_LOAD") != "1":
        PREDICTOR = load_predictor()
    else:
        LOGGER.warning("Skipping SHARP model load because SHARP_SKIP_MODEL_LOAD=1.")
except Exception:
    # Keep the module importable so the failure can be reported to the user
    # at request time instead of crashing at import.
    LOGGER.exception("Failed to load SHARP model.")
    MODEL_LOAD_ERROR = traceback.format_exc(limit=8)
@spaces.GPU
def _generate_ply_on_gpu(image_path: str) -> tuple[Path, float]:
    """Run the SHARP pipeline inside a ZeroGPU allocation.

    ZeroGPU only attaches a GPU while a ``@spaces.GPU``-decorated function is
    executing, so the inference call has to live in one. The predictor is
    read from the module-level ``PREDICTOR`` rather than passed as an
    argument, so the model itself is not part of the call's arguments.

    Args:
        image_path: Local path of the uploaded image.

    Returns:
        tuple[Path, float]: Path of the written PLY file and the elapsed
        processing time in seconds.
    """
    return save_uploaded_image_as_ply(image_path, PREDICTOR)


def generate_ply(image_path: str | None) -> tuple[str | None, str]:
    """Gradio event handler: turn the uploaded image into a downloadable PLY.

    Input validation and the model-availability check run before the GPU
    helper is called, so requests that cannot succeed never ask for a GPU.

    Args:
        image_path: Local image path provided by the Gradio Image component,
            or ``None`` when nothing has been uploaded.

    Returns:
        tuple[str | None, str]: Path of the PLY file (``None`` on failure)
        and a Markdown status message shown to the user.
    """
    if image_path is None:
        return None, "请先上传一张 JPEG、PNG 或 HEIC 图片。"

    if PREDICTOR is None:
        detail = MODEL_LOAD_ERROR or "模型尚未加载,且没有捕获到详细异常。"
        return None, f"SHARP 模型加载失败,无法执行推理。\n\n```text\n{detail}\n```"

    try:
        output_path, elapsed_seconds = _generate_ply_on_gpu(image_path)
    except Exception:
        # Top-level UI boundary: log the failure and show a trimmed traceback
        # instead of letting the exception propagate.
        detail = traceback.format_exc(limit=8)
        LOGGER.exception("Failed to generate PLY.")
        return None, f"生成失败。请确认上传的是有效图片文件。\n\n```text\n{detail}\n```"

    file_size_mb = output_path.stat().st_size / (1024 * 1024)
    status = (
        f"生成完成:`{output_path.name}`\n\n"
        f"- 耗时:{elapsed_seconds:.2f} 秒\n"
        f"- 文件大小:{file_size_mb:.2f} MB\n"
        "- 输出格式:3D Gaussian Splatting `.ply`\n"
        "- 注:基于开源协议,该Apple SHARP模型权重仅限科学研究等非商业用途。"
    )
    return str(output_path), status
| # Build the Gradio UI: an image upload, a PLY download slot, a status area and a trigger button. | |
| with gr.Blocks(title="Apple SHARP ZeroGPU") as demo: | |
| gr.Markdown( | |
| """ | |
| # Apple SHARP ZeroGPU | |
| Upload one image and generate a downloadable 3D Gaussian Splatting `.ply` file. | |
| This Space is a research demo for Apple SHARP. The model weights are licensed | |
| for scientific research and non-commercial use only. The output is a 3DGS file, | |
| not a mesh or GLB model. | |
| """ | |
| ) | |
| # type="filepath" makes the component hand a file path to the handler. | |
| with gr.Row(): | |
| image_input = gr.Image( | |
| label="Input image", | |
| sources=["upload"], | |
| type="filepath", | |
| image_mode="RGB", | |
| ) | |
| with gr.Column(): | |
| output_file = gr.File(label="Generated 3DGS PLY") | |
| status_output = gr.Markdown(label="Status") | |
| run_button = gr.Button("Generate PLY", variant="primary") | |
| # Wire the button to generate_ply; its two return values fill the file and status outputs. | |
| run_button.click( | |
| fn=generate_ply, | |
| inputs=image_input, | |
| outputs=[output_file, status_output], | |
| concurrency_limit=1, | |
| show_progress="full", | |
| ) | |
| # Enable the queue and launch the app only when this file is executed as a script. | |
| if __name__ == "__main__": | |
| demo.queue(default_concurrency_limit=1).launch() | |