from __future__ import annotations import logging import os import re import time import traceback from pathlib import Path from uuid import uuid4 import gradio as gr import numpy as np import spaces import torch import torch.nn.functional as F from huggingface_hub import hf_hub_download from sharp.models import PredictorParams, RGBGaussianPredictor, create_predictor from sharp.utils import io from sharp.utils.gaussians import Gaussians3D, save_ply, unproject_gaussians LOGGER = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) WEIGHTS_REPO_ID = os.getenv("SHARP_WEIGHTS_REPO_ID", "IdlecloudX/ml-sharp-weights") CHECKPOINT_FILENAME = os.getenv("SHARP_CHECKPOINT_FILENAME", "sharp_2572gikvuh.pt") OUTPUT_DIR = Path(os.getenv("SHARP_OUTPUT_DIR", "outputs")) INTERNAL_SHAPE = (1536, 1536) def get_runtime_device() -> torch.device: """选择 SHARP 推理使用的运行设备。 Args: 无。 Returns: torch.device: ZeroGPU/真实 CUDA 环境返回 cuda,本地烟测环境无 CUDA 时返回 cpu。 """ if os.getenv("SPACE_ID") or torch.cuda.is_available(): return torch.device("cuda") return torch.device("cpu") DEVICE = get_runtime_device() OUTPUT_DIR.mkdir(parents=True, exist_ok=True) def sanitize_stem(stem: str) -> str: """清理上传文件名,生成可安全写入输出目录的文件名前缀。 Args: stem: 原始文件名去除扩展名后的文本。 Returns: str: 仅包含字母、数字、点、下划线和短横线的文件名前缀。 """ normalized = re.sub(r"[^A-Za-z0-9._-]+", "_", stem).strip("._-") return normalized[:64] or "sharp_scene" def resolve_checkpoint_path() -> Path: """从 Hugging Face Hub 缓存中解析 SHARP checkpoint 路径。 Args: 无。 Returns: Path: 已下载或已预加载的 checkpoint 本地路径。 """ checkpoint_path = hf_hub_download( repo_id=WEIGHTS_REPO_ID, filename=CHECKPOINT_FILENAME, repo_type="model", ) return Path(checkpoint_path) def load_predictor() -> RGBGaussianPredictor: """加载 Apple SHARP 权重并初始化 Gaussian predictor。 Args: 无。 Returns: RGBGaussianPredictor: 已切换为 eval 模式并移动到目标设备的预测模型。 """ checkpoint_path = resolve_checkpoint_path() LOGGER.info("Loading SHARP checkpoint from %s", checkpoint_path) # 先在 CPU 反序列化权重,避免下载和反序列化阶段占用 ZeroGPU 真实显存。 state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True) predictor = create_predictor(PredictorParams()) predictor.load_state_dict(state_dict) predictor.eval() # ZeroGPU 文档建议模型在模块加载阶段移动到 cuda,由运行时接管真实 GPU 分配。 predictor.to(DEVICE) return predictor @torch.no_grad() def predict_image( predictor: RGBGaussianPredictor, image: np.ndarray, f_px: float, device: torch.device, ) -> Gaussians3D: """将单张 RGB 图片转换为 3D Gaussian 表示。 Args: predictor: 已加载权重的 SHARP Gaussian predictor。 image: RGB 图像数组,形状为 HxWx3。 f_px: 由 EXIF 或默认参数推导出的像素焦距。 device: 执行张量推理的设备。 Returns: Gaussians3D: 已从 NDC 空间还原到度量空间的 3D Gaussian 数据。 """ image_pt = torch.from_numpy(image.copy()).float().to(device).permute(2, 0, 1) / 255.0 _, height, width = image_pt.shape disparity_factor = torch.tensor([f_px / width], dtype=torch.float32, device=device) # SHARP 官方实现固定使用 1536x1536 作为网络内部输入分辨率。 image_resized_pt = F.interpolate( image_pt[None], size=(INTERNAL_SHAPE[1], INTERNAL_SHAPE[0]), mode="bilinear", align_corners=True, ) # 网络输出位于 NDC 空间,后续需要结合相机内参还原到度量空间。 gaussians_ndc = predictor(image_resized_pt, disparity_factor) intrinsics = torch.tensor( [ [f_px, 0, width / 2, 0], [0, f_px, height / 2, 0], [0, 0, 1, 0], [0, 0, 0, 1], ], dtype=torch.float32, device=device, ) intrinsics_resized = intrinsics.clone() intrinsics_resized[0] *= INTERNAL_SHAPE[0] / width intrinsics_resized[1] *= INTERNAL_SHAPE[1] / height # 与 upstream CLI 保持一致:导出前把 NDC Gaussian 变换到 metric 3D 空间。 return unproject_gaussians( gaussians_ndc, torch.eye(4, device=device), intrinsics_resized, INTERNAL_SHAPE, ) def save_uploaded_image_as_ply(image_path: str, predictor: RGBGaussianPredictor) -> tuple[Path, float]: """读取用户上传图片,运行 SHARP,并保存为 3DGS PLY 文件。 Args: image_path: Gradio 上传图片的本地临时文件路径。 predictor: 已加载权重的 SHARP Gaussian predictor。 Returns: tuple[Path, float]: 输出 PLY 路径和本次处理耗时秒数。 """ start_time = time.perf_counter() input_path = Path(image_path) # io.load_rgb 会处理 EXIF 方向、HEIC 以及无焦距 EXIF 时的默认焦距回退。 image, _, f_px = io.load_rgb(input_path) height, width = image.shape[:2] gaussians = predict_image(predictor, image, f_px, DEVICE) output_name = f"{sanitize_stem(input_path.stem)}_{uuid4().hex[:10]}.ply" output_path = OUTPUT_DIR / output_name # 保存格式沿用 Apple SHARP,包含顶点属性、内参、图像尺寸和颜色空间元数据。 save_ply(gaussians, f_px, (height, width), output_path) elapsed_seconds = time.perf_counter() - start_time return output_path, elapsed_seconds MODEL_LOAD_ERROR: str | None = None PREDICTOR: RGBGaussianPredictor | None = None try: if os.getenv("SHARP_SKIP_MODEL_LOAD") == "1": LOGGER.warning("Skipping SHARP model load because SHARP_SKIP_MODEL_LOAD=1.") else: PREDICTOR = load_predictor() except Exception: MODEL_LOAD_ERROR = traceback.format_exc(limit=8) LOGGER.exception("Failed to load SHARP model.") @spaces.GPU(duration=60, size="large") def generate_ply(image_path: str | None) -> tuple[str | None, str]: """Gradio 事件函数:把上传图片转换为可下载的 3DGS PLY 文件。 Args: image_path: Gradio Image 组件传入的本地图片路径。 Returns: tuple[str | None, str]: PLY 文件路径和面向用户展示的状态文本。 """ if image_path is None: return None, "请先上传一张 JPEG、PNG 或 HEIC 图片。" if PREDICTOR is None: detail = MODEL_LOAD_ERROR or "模型尚未加载,且没有捕获到详细异常。" return None, f"SHARP 模型加载失败,无法执行推理。\n\n```text\n{detail}\n```" try: output_path, elapsed_seconds = save_uploaded_image_as_ply(image_path, PREDICTOR) except Exception: detail = traceback.format_exc(limit=8) LOGGER.exception("Failed to generate PLY.") return None, f"生成失败。请确认上传的是有效图片文件。\n\n```text\n{detail}\n```" file_size_mb = output_path.stat().st_size / (1024 * 1024) status = ( f"生成完成:`{output_path.name}`\n\n" f"- 耗时:{elapsed_seconds:.2f} 秒\n" f"- 文件大小:{file_size_mb:.2f} MB\n" "- 输出格式:3D Gaussian Splatting `.ply`\n" "- 注:基于开源协议,该Apple SHARP模型权重仅限科学研究等非商业用途。" ) return str(output_path), status with gr.Blocks(title="Apple SHARP ZeroGPU") as demo: gr.Markdown( """ # Apple SHARP ZeroGPU Upload one image and generate a downloadable 3D Gaussian Splatting `.ply` file. This Space is a research demo for Apple SHARP. The model weights are licensed for scientific research and non-commercial use only. The output is a 3DGS file, not a mesh or GLB model. """ ) with gr.Row(): image_input = gr.Image( label="Input image", sources=["upload"], type="filepath", image_mode="RGB", ) with gr.Column(): output_file = gr.File(label="Generated 3DGS PLY") status_output = gr.Markdown(label="Status") run_button = gr.Button("Generate PLY", variant="primary") run_button.click( fn=generate_ply, inputs=image_input, outputs=[output_file, status_output], concurrency_limit=1, show_progress="full", ) if __name__ == "__main__": demo.queue(default_concurrency_limit=1).launch()