add gradio ui with seperate frontent/backend
Browse files- pyproject_amd.toml +69 -0
- requirements.txt +5 -171
- requirements_api.txt +118 -0
- src/sharp/web/README.md +84 -16
- src/sharp/web/api_server.py +168 -0
- src/sharp/web/app.py +117 -176
pyproject_amd.toml
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "sharp"
|
| 3 |
+
version = "0.1"
|
| 4 |
+
description = "Inference/Network/Model code for SHARP view synthesis model."
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
dependencies = [
|
| 7 |
+
"click",
|
| 8 |
+
"gsplat",
|
| 9 |
+
"imageio[ffmpeg]",
|
| 10 |
+
"matplotlib",
|
| 11 |
+
"pillow-heif",
|
| 12 |
+
"plyfile",
|
| 13 |
+
"scipy",
|
| 14 |
+
"timm",
|
| 15 |
+
"torch",
|
| 16 |
+
"torchvision",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
[project.scripts]
|
| 20 |
+
sharp = "sharp.cli:main_cli"
|
| 21 |
+
|
| 22 |
+
[project.urls]
|
| 23 |
+
Homepage = "https://github.com/apple/ml-sharp"
|
| 24 |
+
Repository = "https://github.com/apple/ml-sharp"
|
| 25 |
+
|
| 26 |
+
[build-system]
|
| 27 |
+
requires = ["setuptools", "setuptools-scm"]
|
| 28 |
+
build-backend = "setuptools.build_meta"
|
| 29 |
+
|
| 30 |
+
[tool.setuptools.packages.find]
|
| 31 |
+
where = ["src"]
|
| 32 |
+
|
| 33 |
+
[tool.pyright]
|
| 34 |
+
include = ["src"]
|
| 35 |
+
exclude = [
|
| 36 |
+
"**/node_modules",
|
| 37 |
+
"**/__pycache__",
|
| 38 |
+
]
|
| 39 |
+
pythonVersion = "3.12"
|
| 40 |
+
|
| 41 |
+
[tool.pytest.ini_options]
|
| 42 |
+
minversion = "6.0"
|
| 43 |
+
addopts = "-ra -q"
|
| 44 |
+
testpaths = [
|
| 45 |
+
"tests"
|
| 46 |
+
]
|
| 47 |
+
filterwarnings = [
|
| 48 |
+
"ignore::DeprecationWarning"
|
| 49 |
+
]
|
| 50 |
+
|
| 51 |
+
[tool.lint.per-file-ignores]
|
| 52 |
+
"__init__.py" = ["F401", "D100", "D104"]
|
| 53 |
+
|
| 54 |
+
[tool.ruff]
|
| 55 |
+
line-length = 100
|
| 56 |
+
lint.select = ["E", "F", "D", "I"]
|
| 57 |
+
lint.ignore = ["D100", "D105",
|
| 58 |
+
# Imperative mood of docstring.
|
| 59 |
+
"D401",
|
| 60 |
+
]
|
| 61 |
+
extend-exclude = [
|
| 62 |
+
"*external*",
|
| 63 |
+
"third_party",
|
| 64 |
+
]
|
| 65 |
+
src = ["sharp"]
|
| 66 |
+
target-version = "py39"
|
| 67 |
+
|
| 68 |
+
[tool.ruff.lint.pydocstyle]
|
| 69 |
+
convention = "google"
|
requirements.txt
CHANGED
|
@@ -1,172 +1,6 @@
|
|
| 1 |
-
#
|
| 2 |
-
#
|
| 3 |
-
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
# via requests
|
| 7 |
-
charset-normalizer==3.4.3
|
| 8 |
-
# via requests
|
| 9 |
-
click==8.3.0
|
| 10 |
-
# via sharp
|
| 11 |
-
colorama==0.4.6 ; sys_platform == 'win32'
|
| 12 |
-
# via
|
| 13 |
-
# click
|
| 14 |
-
# tqdm
|
| 15 |
-
contourpy==1.3.3
|
| 16 |
-
# via matplotlib
|
| 17 |
-
cycler==0.12.1
|
| 18 |
-
# via matplotlib
|
| 19 |
-
filelock==3.19.1
|
| 20 |
-
# via
|
| 21 |
-
# huggingface-hub
|
| 22 |
-
# torch
|
| 23 |
-
fonttools==4.61.0
|
| 24 |
-
# via matplotlib
|
| 25 |
-
fsspec==2025.9.0
|
| 26 |
-
# via
|
| 27 |
-
# huggingface-hub
|
| 28 |
-
# torch
|
| 29 |
-
gsplat==1.5.3
|
| 30 |
-
# via sharp
|
| 31 |
-
hf-xet==1.1.10 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
|
| 32 |
-
# via huggingface-hub
|
| 33 |
-
huggingface-hub==0.35.3
|
| 34 |
-
# via timm
|
| 35 |
-
idna==3.10
|
| 36 |
-
# via requests
|
| 37 |
-
imageio==2.37.0
|
| 38 |
-
# via sharp
|
| 39 |
-
imageio-ffmpeg==0.6.0
|
| 40 |
-
# via imageio
|
| 41 |
-
jaxtyping==0.3.3
|
| 42 |
-
# via gsplat
|
| 43 |
-
jinja2==3.1.6
|
| 44 |
-
# via torch
|
| 45 |
-
kiwisolver==1.4.9
|
| 46 |
-
# via matplotlib
|
| 47 |
-
markdown-it-py==4.0.0
|
| 48 |
-
# via rich
|
| 49 |
-
markupsafe==3.0.3
|
| 50 |
-
# via jinja2
|
| 51 |
-
matplotlib==3.10.6
|
| 52 |
-
# via sharp
|
| 53 |
-
mdurl==0.1.2
|
| 54 |
-
# via markdown-it-py
|
| 55 |
-
mpmath==1.3.0
|
| 56 |
-
# via sympy
|
| 57 |
-
networkx==3.5
|
| 58 |
-
# via torch
|
| 59 |
-
ninja==1.13.0
|
| 60 |
-
# via gsplat
|
| 61 |
-
numpy==2.3.3
|
| 62 |
-
# via
|
| 63 |
-
# contourpy
|
| 64 |
-
# gsplat
|
| 65 |
-
# imageio
|
| 66 |
-
# matplotlib
|
| 67 |
-
# plyfile
|
| 68 |
-
# scipy
|
| 69 |
-
# torchvision
|
| 70 |
-
nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 71 |
-
# via
|
| 72 |
-
# nvidia-cudnn-cu12
|
| 73 |
-
# nvidia-cusolver-cu12
|
| 74 |
-
# torch
|
| 75 |
-
nvidia-cuda-cupti-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 76 |
-
# via torch
|
| 77 |
-
nvidia-cuda-nvrtc-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 78 |
-
# via torch
|
| 79 |
-
nvidia-cuda-runtime-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 80 |
-
# via torch
|
| 81 |
-
nvidia-cudnn-cu12==9.10.2.21 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 82 |
-
# via torch
|
| 83 |
-
nvidia-cufft-cu12==11.3.3.83 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 84 |
-
# via torch
|
| 85 |
-
nvidia-cufile-cu12==1.13.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 86 |
-
# via torch
|
| 87 |
-
nvidia-curand-cu12==10.3.9.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 88 |
-
# via torch
|
| 89 |
-
nvidia-cusolver-cu12==11.7.3.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 90 |
-
# via torch
|
| 91 |
-
nvidia-cusparse-cu12==12.5.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 92 |
-
# via
|
| 93 |
-
# nvidia-cusolver-cu12
|
| 94 |
-
# torch
|
| 95 |
-
nvidia-cusparselt-cu12==0.7.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 96 |
-
# via torch
|
| 97 |
-
nvidia-nccl-cu12==2.27.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 98 |
-
# via torch
|
| 99 |
-
nvidia-nvjitlink-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 100 |
-
# via
|
| 101 |
-
# nvidia-cufft-cu12
|
| 102 |
-
# nvidia-cusolver-cu12
|
| 103 |
-
# nvidia-cusparse-cu12
|
| 104 |
-
# torch
|
| 105 |
-
nvidia-nvtx-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 106 |
-
# via torch
|
| 107 |
-
packaging==25.0
|
| 108 |
-
# via
|
| 109 |
-
# huggingface-hub
|
| 110 |
-
# matplotlib
|
| 111 |
-
pillow==11.3.0
|
| 112 |
-
# via
|
| 113 |
-
# imageio
|
| 114 |
-
# matplotlib
|
| 115 |
-
# pillow-heif
|
| 116 |
-
# torchvision
|
| 117 |
-
pillow-heif==1.1.1
|
| 118 |
-
# via sharp
|
| 119 |
-
plyfile==1.1.2
|
| 120 |
-
# via sharp
|
| 121 |
-
psutil==7.1.0
|
| 122 |
-
# via imageio
|
| 123 |
-
pygments==2.19.2
|
| 124 |
-
# via rich
|
| 125 |
-
pyparsing==3.2.5
|
| 126 |
-
# via matplotlib
|
| 127 |
-
python-dateutil==2.9.0.post0
|
| 128 |
-
# via matplotlib
|
| 129 |
-
pyyaml==6.0.3
|
| 130 |
-
# via
|
| 131 |
-
# huggingface-hub
|
| 132 |
-
# timm
|
| 133 |
requests==2.32.5
|
| 134 |
-
# via huggingface-hub
|
| 135 |
-
rich==14.1.0
|
| 136 |
-
# via gsplat
|
| 137 |
-
safetensors==0.6.2
|
| 138 |
-
# via timm
|
| 139 |
-
scipy==1.16.2
|
| 140 |
-
# via sharp
|
| 141 |
-
setuptools==80.9.0
|
| 142 |
-
# via
|
| 143 |
-
# torch
|
| 144 |
-
# triton
|
| 145 |
-
six==1.17.0
|
| 146 |
-
# via python-dateutil
|
| 147 |
-
sympy==1.14.0
|
| 148 |
-
# via torch
|
| 149 |
-
timm==1.0.20
|
| 150 |
-
# via sharp
|
| 151 |
-
torch==2.8.0
|
| 152 |
-
# via
|
| 153 |
-
# gsplat
|
| 154 |
-
# sharp
|
| 155 |
-
# timm
|
| 156 |
-
# torchvision
|
| 157 |
-
torchvision==0.23.0
|
| 158 |
-
# via
|
| 159 |
-
# sharp
|
| 160 |
-
# timm
|
| 161 |
-
tqdm==4.67.1
|
| 162 |
-
# via huggingface-hub
|
| 163 |
-
triton==3.4.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 164 |
-
# via torch
|
| 165 |
-
typing-extensions==4.15.0
|
| 166 |
-
# via
|
| 167 |
-
# huggingface-hub
|
| 168 |
-
# torch
|
| 169 |
-
urllib3==2.6.0
|
| 170 |
-
# via requests
|
| 171 |
-
wadler-lindig==0.1.7
|
| 172 |
-
# via jaxtyping
|
|
|
|
| 1 |
+
# Front-end requirements for Hugging Face Spaces (Gradio UI)
|
| 2 |
+
# Deploy the Gradio app in src/sharp/web/app.py using these minimal dependencies.
|
| 3 |
+
# Install with: pip install -r requirements.txt
|
| 4 |
+
|
| 5 |
+
gradio==4.44.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
requests==2.32.5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements_api.txt
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# This file was autogenerated by uv via the following command:
|
| 3 |
+
# uv pip compile requirements.in -o requirements.txt --universal
|
| 4 |
+
-e .
|
| 5 |
+
# via -r requirements.in
|
| 6 |
+
certifi==2025.8.3
|
| 7 |
+
# via requests
|
| 8 |
+
charset-normalizer==3.4.3
|
| 9 |
+
# via requests
|
| 10 |
+
click==8.3.0
|
| 11 |
+
# via sharp
|
| 12 |
+
colorama==0.4.6 ; sys_platform == 'win32'
|
| 13 |
+
# via
|
| 14 |
+
# click
|
| 15 |
+
# tqdm
|
| 16 |
+
contourpy==1.3.3
|
| 17 |
+
# via matplotlib
|
| 18 |
+
cycler==0.12.1
|
| 19 |
+
# via matplotlib
|
| 20 |
+
filelock==3.19.1
|
| 21 |
+
# via
|
| 22 |
+
# huggingface-hub
|
| 23 |
+
# torch
|
| 24 |
+
fonttools==4.61.0
|
| 25 |
+
# via matplotlib
|
| 26 |
+
fsspec==2025.9.0
|
| 27 |
+
# via
|
| 28 |
+
# huggingface-hub
|
| 29 |
+
# torch
|
| 30 |
+
# gsplat==1.5.3
|
| 31 |
+
# via sharp
|
| 32 |
+
hf-xet==1.1.10 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
|
| 33 |
+
# via huggingface-hub
|
| 34 |
+
huggingface-hub==0.35.3
|
| 35 |
+
# via timm
|
| 36 |
+
idna==3.10
|
| 37 |
+
# via requests
|
| 38 |
+
imageio==2.37.0
|
| 39 |
+
# via sharp
|
| 40 |
+
imageio-ffmpeg==0.6.0
|
| 41 |
+
# via imageio
|
| 42 |
+
jaxtyping==0.3.3
|
| 43 |
+
# via gsplat
|
| 44 |
+
jinja2==3.1.6
|
| 45 |
+
# via torch
|
| 46 |
+
kiwisolver==1.4.9
|
| 47 |
+
# via matplotlib
|
| 48 |
+
markdown-it-py==4.0.0
|
| 49 |
+
# via rich
|
| 50 |
+
markupsafe==3.0.3
|
| 51 |
+
# via jinja2
|
| 52 |
+
matplotlib==3.10.6
|
| 53 |
+
# via sharp
|
| 54 |
+
mdurl==0.1.2
|
| 55 |
+
# via markdown-it-py
|
| 56 |
+
mpmath==1.3.0
|
| 57 |
+
# via sympy
|
| 58 |
+
networkx==3.5
|
| 59 |
+
# via torch
|
| 60 |
+
ninja==1.13.0
|
| 61 |
+
# via gsplat
|
| 62 |
+
# numpy==2.3.3
|
| 63 |
+
numpy<2
|
| 64 |
+
packaging==25.0
|
| 65 |
+
# via
|
| 66 |
+
# huggingface-hub
|
| 67 |
+
# matplotlib
|
| 68 |
+
pillow==11.3.0
|
| 69 |
+
# via
|
| 70 |
+
# imageio
|
| 71 |
+
# matplotlib
|
| 72 |
+
# pillow-heif
|
| 73 |
+
# torchvision
|
| 74 |
+
pillow-heif==1.1.1
|
| 75 |
+
# via sharp
|
| 76 |
+
plyfile==1.1.2
|
| 77 |
+
# via sharp
|
| 78 |
+
psutil==7.1.0
|
| 79 |
+
# via imageio
|
| 80 |
+
pygments==2.19.2
|
| 81 |
+
# via rich
|
| 82 |
+
pyparsing==3.2.5
|
| 83 |
+
# via matplotlib
|
| 84 |
+
python-dateutil==2.9.0.post0
|
| 85 |
+
# via matplotlib
|
| 86 |
+
pyyaml==6.0.3
|
| 87 |
+
# via
|
| 88 |
+
# huggingface-hub
|
| 89 |
+
# timm
|
| 90 |
+
requests==2.32.5
|
| 91 |
+
# via huggingface-hub
|
| 92 |
+
rich==14.1.0
|
| 93 |
+
# via gsplat
|
| 94 |
+
safetensors==0.6.2
|
| 95 |
+
# via timm
|
| 96 |
+
scipy==1.16.2
|
| 97 |
+
# via sharp
|
| 98 |
+
setuptools==80.9.0
|
| 99 |
+
# via
|
| 100 |
+
# torch
|
| 101 |
+
# triton
|
| 102 |
+
six==1.17.0
|
| 103 |
+
# via python-dateutil
|
| 104 |
+
sympy==1.14.0
|
| 105 |
+
# via torch
|
| 106 |
+
timm==1.0.20
|
| 107 |
+
# via sharp
|
| 108 |
+
tqdm==4.67.1
|
| 109 |
+
# via huggingface-hub
|
| 110 |
+
typing-extensions
|
| 111 |
+
urllib3==2.6.0
|
| 112 |
+
# via requests
|
| 113 |
+
wadler-lindig==0.1.7
|
| 114 |
+
# via jaxtyping
|
| 115 |
+
# Backend API server runtime deps
|
| 116 |
+
fastapi
|
| 117 |
+
uvicorn[standard]
|
| 118 |
+
python-multipart
|
src/sharp/web/README.md
CHANGED
|
@@ -1,33 +1,101 @@
|
|
| 1 |
-
#
|
| 2 |
|
| 3 |
-
This
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
|
| 6 |
|
| 7 |
-
|
| 8 |
-
Install the web dependencies:
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
```bash
|
| 11 |
-
|
|
|
|
| 12 |
```
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
|
| 17 |
|
|
|
|
|
|
|
|
|
|
| 18 |
```bash
|
| 19 |
-
|
| 20 |
```
|
|
|
|
| 21 |
|
| 22 |
-
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
```bash
|
| 25 |
-
|
| 26 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
##
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
4. A zip file containing the resulting `.ply` files will be downloaded automatically.
|
|
|
|
| 1 |
+
# SHARP Web: Frontend (Gradio) + Backend (FastAPI)
|
| 2 |
|
| 3 |
+
This directory provides a separated deployment:
|
| 4 |
+
- Backend API (`api_server.py`) runs on a GPU cloud with FastAPI
|
| 5 |
+
- Frontend UI (`app.py`) runs on Hugging Face Spaces with Gradio
|
| 6 |
|
| 7 |
+
The UI calls the API via HTTP; the API performs model inference and returns PLY results.
|
| 8 |
|
| 9 |
+
## Repository Layout
|
|
|
|
| 10 |
|
| 11 |
+
- `src/sharp/web/api_server.py` — FastAPI backend hosting inference endpoints
|
| 12 |
+
- `src/sharp/web/app.py` — Gradio frontend calling the backend
|
| 13 |
+
- `requirements_api.txt` — Backend dependencies (GPU cloud)
|
| 14 |
+
- `requirements.txt` — Frontend dependencies (HF Spaces)
|
| 15 |
+
|
| 16 |
+
## Backend (GPU Cloud)
|
| 17 |
+
|
| 18 |
+
### Install
|
| 19 |
+
|
| 20 |
+
On your GPU cloud instance:
|
| 21 |
```bash
|
| 22 |
+
# From repository root
|
| 23 |
+
pip install -r requirements_api.txt
|
| 24 |
```
|
| 25 |
|
| 26 |
+
Notes:
|
| 27 |
+
- Ensure CUDA is available if using NVIDIA GPUs. The Torch version in `requirements_api.txt` is compiled for CUDA 12 on Linux.
|
| 28 |
+
- On macOS, MPS (Apple Silicon) may be detected; otherwise CPU fallback is used.
|
| 29 |
+
|
| 30 |
+
### Run
|
| 31 |
+
|
| 32 |
+
From repository root:
|
| 33 |
+
```bash
|
| 34 |
+
python src/sharp/web/api_server.py
|
| 35 |
+
```
|
| 36 |
+
or with Uvicorn:
|
| 37 |
+
```bash
|
| 38 |
+
uvicorn src.sharp.web.api_server:app --host 0.0.0.0 --port 8000
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
### Endpoints
|
| 42 |
+
|
| 43 |
+
- `GET /health` — Basic health check, device info, and model-loaded flag
|
| 44 |
+
- `POST /predict` — Multipart upload of one or more images (`files` field); returns JSON with per-image metadata and PLY contents base64-encoded
|
| 45 |
+
- `POST /predict/download` — Multipart upload of one or more images; returns a ZIP stream containing PLY files
|
| 46 |
+
|
| 47 |
+
CORS is enabled by default to allow calls from the Hugging Face Space. For production, set `allow_origins` to your specific Space domain.
|
| 48 |
|
| 49 |
+
## Frontend (Hugging Face Spaces)
|
| 50 |
|
| 51 |
+
### Install
|
| 52 |
+
|
| 53 |
+
On HF Spaces:
|
| 54 |
```bash
|
| 55 |
+
pip install -r requirements.txt
|
| 56 |
```
|
| 57 |
+
This installs only Gradio and Requests.
|
| 58 |
|
| 59 |
+
### Configure
|
| 60 |
|
| 61 |
+
Set environment variable `API_BASE_URL` in your Space to point to the public backend URL, for example:
|
| 62 |
+
```
|
| 63 |
+
API_BASE_URL=https://your-api.example.com
|
| 64 |
+
```
|
| 65 |
+
If running locally for testing, `API_BASE_URL` defaults to `http://localhost:8000`.
|
| 66 |
+
|
| 67 |
+
### Run
|
| 68 |
+
|
| 69 |
+
Locally:
|
| 70 |
```bash
|
| 71 |
+
python src/sharp/web/app.py
|
| 72 |
```
|
| 73 |
+
Gradio will start on port `7860` by default (configured to `0.0.0.0` in the script).
|
| 74 |
+
|
| 75 |
+
On HF Spaces, simply setting the Space’s “Run” command to `python src/sharp/web/app.py` is sufficient.
|
| 76 |
+
|
| 77 |
+
### Usage (Frontend)
|
| 78 |
+
|
| 79 |
+
- Single Image tab: upload one image and click Predict to download its PLY.
|
| 80 |
+
- Batch tab: upload multiple images and click Predict Batch to download a ZIP containing PLY files.
|
| 81 |
+
- The frontend calls the backend `POST /predict` and assembles results for user download.
|
| 82 |
+
|
| 83 |
+
## Quick Local Test
|
| 84 |
+
|
| 85 |
+
1. Start backend:
|
| 86 |
+
```bash
|
| 87 |
+
uvicorn src.sharp.web.api_server:app --host 0.0.0.0 --port 8000
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
2. Start frontend (in another terminal):
|
| 91 |
+
```bash
|
| 92 |
+
API_BASE_URL=http://localhost:8000 python src/sharp/web/app.py
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
3. Open the Gradio UI (http://localhost:7860), upload images, and verify outputs.
|
| 96 |
|
| 97 |
+
## Notes & Troubleshooting
|
| 98 |
|
| 99 |
+
- If imports like `fastapi` or `gradio` show unresolved in your IDE, ensure the correct environment is selected and dependencies installed via the respective requirements file.
|
| 100 |
+
- Network access from HF Spaces to the GPU API must be allowed; ensure your API endpoint is accessible over HTTPS where possible.
|
| 101 |
+
- For security, consider locking down CORS to your Space origin and adding authentication (e.g., an API key header) if needed.
|
|
|
src/sharp/web/api_server.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import logging
|
| 3 |
+
import shutil
|
| 4 |
+
import tempfile
|
| 5 |
+
import zipfile
|
| 6 |
+
import io as python_io
|
| 7 |
+
import base64
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
from fastapi import FastAPI, UploadFile, File
|
| 11 |
+
from fastapi.responses import StreamingResponse, JSONResponse
|
| 12 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
# Ensure we can import the project package: add top-level 'src' to sys.path
|
| 16 |
+
# This file resides at: <repo_root>/src/sharp/web/api_server.py
|
| 17 |
+
# Path(__file__).parents[2] == <repo_root>/src
|
| 18 |
+
sys.path.append(str(Path(__file__).parents[2]))
|
| 19 |
+
|
| 20 |
+
from sharp.models import PredictorParams, RGBGaussianPredictor, create_predictor
|
| 21 |
+
from sharp.utils import io as sharp_io
|
| 22 |
+
from sharp.utils.gaussians import save_ply
|
| 23 |
+
from sharp.cli.predict import predict_image, DEFAULT_MODEL_URL
|
| 24 |
+
|
| 25 |
+
logging.basicConfig(level=logging.INFO)
|
| 26 |
+
LOGGER = logging.getLogger("sharp.api")
|
| 27 |
+
|
| 28 |
+
app = FastAPI()
|
| 29 |
+
|
| 30 |
+
# CORS - allow HF Spaces frontend to call this API.
|
| 31 |
+
# Consider tightening allow_origins to your Space domain for production.
|
| 32 |
+
app.add_middleware(
|
| 33 |
+
CORSMiddleware,
|
| 34 |
+
allow_origins=["*"],
|
| 35 |
+
allow_credentials=True,
|
| 36 |
+
allow_methods=["*"],
|
| 37 |
+
allow_headers=["*"],
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
predictor: RGBGaussianPredictor | None = None
|
| 41 |
+
device: torch.device | None = None
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@app.on_event("startup")
|
| 45 |
+
async def startup_event():
|
| 46 |
+
global predictor, device
|
| 47 |
+
try:
|
| 48 |
+
device_str = (
|
| 49 |
+
"cuda"
|
| 50 |
+
if torch.cuda.is_available()
|
| 51 |
+
else ("mps" if torch.backends.mps.is_available() else "cpu")
|
| 52 |
+
)
|
| 53 |
+
device = torch.device(device_str)
|
| 54 |
+
LOGGER.info(f"Using device: {device}")
|
| 55 |
+
|
| 56 |
+
LOGGER.info("Loading SHARP model state dict...")
|
| 57 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
| 58 |
+
DEFAULT_MODEL_URL, progress=True, map_location=device
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
predictor = create_predictor(PredictorParams())
|
| 62 |
+
predictor.load_state_dict(state_dict)
|
| 63 |
+
predictor.eval()
|
| 64 |
+
predictor.to(device)
|
| 65 |
+
LOGGER.info("Model loaded and ready.")
|
| 66 |
+
except Exception as e:
|
| 67 |
+
LOGGER.exception("Failed during startup/model init: %s", e)
|
| 68 |
+
# Leave predictor as None; endpoints will return error until fixed.
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@app.get("/health")
|
| 72 |
+
async def health():
|
| 73 |
+
return {
|
| 74 |
+
"status": "ok",
|
| 75 |
+
"device": str(device) if device else None,
|
| 76 |
+
"model_loaded": predictor is not None,
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@app.post("/predict")
|
| 81 |
+
async def predict(files: list[UploadFile] = File(...)):
|
| 82 |
+
"""Accept images and return JSON with per-image metadata and PLY as base64."""
|
| 83 |
+
if not predictor:
|
| 84 |
+
return JSONResponse({"error": "Model not loaded"}, status_code=500)
|
| 85 |
+
|
| 86 |
+
results = []
|
| 87 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
| 88 |
+
temp_path = Path(temp_dir)
|
| 89 |
+
|
| 90 |
+
for file in files:
|
| 91 |
+
try:
|
| 92 |
+
# Persist upload to temp
|
| 93 |
+
file_path = temp_path / file.filename
|
| 94 |
+
with open(file_path, "wb") as buffer:
|
| 95 |
+
shutil.copyfileobj(file.file, buffer)
|
| 96 |
+
|
| 97 |
+
# Load input and run prediction
|
| 98 |
+
image, _, f_px = sharp_io.load_rgb(file_path)
|
| 99 |
+
gaussians = predict_image(predictor, image, f_px, device)
|
| 100 |
+
|
| 101 |
+
# Save PLY
|
| 102 |
+
ply_filename = f"{file_path.stem}.ply"
|
| 103 |
+
ply_path = temp_path / ply_filename
|
| 104 |
+
height, width = image.shape[:2]
|
| 105 |
+
save_ply(gaussians, f_px, (height, width), ply_path)
|
| 106 |
+
|
| 107 |
+
# Encode PLY to base64 for transport
|
| 108 |
+
with open(ply_path, "rb") as f:
|
| 109 |
+
ply_data = base64.b64encode(f.read()).decode("utf-8")
|
| 110 |
+
|
| 111 |
+
results.append(
|
| 112 |
+
{
|
| 113 |
+
"filename": file.filename,
|
| 114 |
+
"ply_filename": ply_filename,
|
| 115 |
+
"ply_data": ply_data,
|
| 116 |
+
"width": width,
|
| 117 |
+
"height": height,
|
| 118 |
+
"focal_length": f_px,
|
| 119 |
+
}
|
| 120 |
+
)
|
| 121 |
+
except Exception as e:
|
| 122 |
+
LOGGER.exception("Error processing %s: %s", file.filename, e)
|
| 123 |
+
results.append({"filename": file.filename, "error": str(e)})
|
| 124 |
+
|
| 125 |
+
return {"results": results}
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@app.post("/predict/download")
|
| 129 |
+
async def predict_download(files: list[UploadFile] = File(...)):
|
| 130 |
+
"""Accept images and return a ZIP of generated PLY files."""
|
| 131 |
+
if not predictor:
|
| 132 |
+
return JSONResponse({"error": "Model not loaded"}, status_code=500)
|
| 133 |
+
|
| 134 |
+
output_zip = python_io.BytesIO()
|
| 135 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
| 136 |
+
temp_path = Path(temp_dir)
|
| 137 |
+
with zipfile.ZipFile(output_zip, "w") as zf:
|
| 138 |
+
for file in files:
|
| 139 |
+
try:
|
| 140 |
+
file_path = temp_path / file.filename
|
| 141 |
+
with open(file_path, "wb") as buffer:
|
| 142 |
+
shutil.copyfileobj(file.file, buffer)
|
| 143 |
+
|
| 144 |
+
image, _, f_px = sharp_io.load_rgb(file_path)
|
| 145 |
+
gaussians = predict_image(predictor, image, f_px, device)
|
| 146 |
+
|
| 147 |
+
ply_filename = f"{file_path.stem}.ply"
|
| 148 |
+
ply_path = temp_path / ply_filename
|
| 149 |
+
height, width = image.shape[:2]
|
| 150 |
+
save_ply(gaussians, f_px, (height, width), ply_path)
|
| 151 |
+
|
| 152 |
+
zf.write(ply_path, ply_filename)
|
| 153 |
+
except Exception as e:
|
| 154 |
+
LOGGER.exception("Error processing %s: %s", file.filename, e)
|
| 155 |
+
continue
|
| 156 |
+
|
| 157 |
+
output_zip.seek(0)
|
| 158 |
+
return StreamingResponse(
|
| 159 |
+
output_zip,
|
| 160 |
+
media_type="application/zip",
|
| 161 |
+
headers={"Content-Disposition": "attachment; filename=gaussians.zip"},
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
if __name__ == "__main__":
|
| 166 |
+
import uvicorn
|
| 167 |
+
|
| 168 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
src/sharp/web/app.py
CHANGED
|
@@ -1,184 +1,125 @@
|
|
| 1 |
-
import
|
| 2 |
-
|
| 3 |
-
import logging
|
| 4 |
-
import shutil
|
| 5 |
-
import tempfile
|
| 6 |
import zipfile
|
| 7 |
-
import io as python_io
|
| 8 |
import base64
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
from fastapi import FastAPI, Request, UploadFile, File
|
| 11 |
-
from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse
|
| 12 |
-
from fastapi.staticfiles import StaticFiles
|
| 13 |
-
from fastapi.templating import Jinja2Templates
|
| 14 |
-
import torch
|
| 15 |
-
import numpy as np
|
| 16 |
-
|
| 17 |
-
# Add src to path so we can import sharp
|
| 18 |
-
sys.path.append(str(Path(__file__).parent.parent / "src"))
|
| 19 |
-
|
| 20 |
-
from sharp.models import (
|
| 21 |
-
PredictorParams,
|
| 22 |
-
RGBGaussianPredictor,
|
| 23 |
-
create_predictor,
|
| 24 |
-
)
|
| 25 |
-
from sharp.utils import io as sharp_io
|
| 26 |
-
from sharp.utils.gaussians import save_ply
|
| 27 |
-
from sharp.cli.predict import predict_image, DEFAULT_MODEL_URL
|
| 28 |
-
|
| 29 |
-
# Configure logging
|
| 30 |
-
logging.basicConfig(level=logging.INFO)
|
| 31 |
-
LOGGER = logging.getLogger(__name__)
|
| 32 |
-
|
| 33 |
-
app = FastAPI()
|
| 34 |
-
|
| 35 |
-
# Mount static files if needed (we created the dir)
|
| 36 |
-
app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")
|
| 37 |
-
|
| 38 |
-
templates = Jinja2Templates(directory=Path(__file__).parent / "templates")
|
| 39 |
-
|
| 40 |
-
# Global variables for the model
|
| 41 |
-
predictor: RGBGaussianPredictor = None
|
| 42 |
-
device: torch.device = None
|
| 43 |
-
|
| 44 |
-
@app.on_event("startup")
|
| 45 |
-
async def startup_event():
|
| 46 |
-
global predictor, device
|
| 47 |
-
|
| 48 |
-
# Determine device
|
| 49 |
-
if torch.cuda.is_available():
|
| 50 |
-
device_str = "cuda"
|
| 51 |
-
elif torch.mps.is_available():
|
| 52 |
-
device_str = "mps"
|
| 53 |
-
else:
|
| 54 |
-
device_str = "cpu"
|
| 55 |
-
|
| 56 |
-
device = torch.device(device_str)
|
| 57 |
-
LOGGER.info(f"Using device: {device}")
|
| 58 |
-
|
| 59 |
-
# Load model
|
| 60 |
-
LOGGER.info("Loading model...")
|
| 61 |
try:
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
predictor = create_predictor(PredictorParams())
|
| 66 |
-
predictor.load_state_dict(state_dict)
|
| 67 |
-
predictor.eval()
|
| 68 |
-
predictor.to(device)
|
| 69 |
-
LOGGER.info("Model loaded successfully.")
|
| 70 |
except Exception as e:
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
ply_path = temp_path / ply_filename
|
| 107 |
-
|
| 108 |
-
height, width = image.shape[:2]
|
| 109 |
-
save_ply(gaussians, f_px, (height, width), ply_path)
|
| 110 |
-
|
| 111 |
-
# Read PLY file and encode as base64
|
| 112 |
-
with open(ply_path, "rb") as f:
|
| 113 |
-
ply_data = base64.b64encode(f.read()).decode("utf-8")
|
| 114 |
-
|
| 115 |
-
results.append({
|
| 116 |
-
"filename": file.filename,
|
| 117 |
-
"ply_filename": ply_filename,
|
| 118 |
-
"ply_data": ply_data,
|
| 119 |
-
"width": width,
|
| 120 |
-
"height": height,
|
| 121 |
-
"focal_length": f_px,
|
| 122 |
-
})
|
| 123 |
-
|
| 124 |
-
except Exception as e:
|
| 125 |
-
LOGGER.error(f"Error processing {file.filename}: {e}")
|
| 126 |
-
results.append({
|
| 127 |
-
"filename": file.filename,
|
| 128 |
-
"error": str(e),
|
| 129 |
-
})
|
| 130 |
-
|
| 131 |
-
return JSONResponse({"results": results})
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
@app.post("/predict/download")
|
| 135 |
-
async def predict_download(files: list[UploadFile] = File(...)):
|
| 136 |
-
"""Process images and return a ZIP file for download."""
|
| 137 |
-
if not predictor:
|
| 138 |
-
return HTMLResponse("Model not loaded", status_code=500)
|
| 139 |
-
|
| 140 |
-
# Create a temporary directory to process files
|
| 141 |
-
with tempfile.TemporaryDirectory() as temp_dir:
|
| 142 |
-
temp_path = Path(temp_dir)
|
| 143 |
-
output_zip = python_io.BytesIO()
|
| 144 |
-
|
| 145 |
-
with zipfile.ZipFile(output_zip, "w") as zf:
|
| 146 |
-
for file in files:
|
| 147 |
-
try:
|
| 148 |
-
# Save uploaded file
|
| 149 |
-
file_path = temp_path / file.filename
|
| 150 |
-
with open(file_path, "wb") as buffer:
|
| 151 |
-
shutil.copyfileobj(file.file, buffer)
|
| 152 |
-
|
| 153 |
-
LOGGER.info(f"Processing {file.filename}")
|
| 154 |
-
|
| 155 |
-
# Load image using sharp's IO to get focal length and handle rotation
|
| 156 |
-
image, _, f_px = sharp_io.load_rgb(file_path)
|
| 157 |
-
|
| 158 |
-
# Run prediction
|
| 159 |
-
gaussians = predict_image(predictor, image, f_px, device)
|
| 160 |
-
|
| 161 |
-
# Save PLY
|
| 162 |
-
ply_filename = f"{file_path.stem}.ply"
|
| 163 |
-
ply_path = temp_path / ply_filename
|
| 164 |
-
|
| 165 |
-
height, width = image.shape[:2]
|
| 166 |
-
save_ply(gaussians, f_px, (height, width), ply_path)
|
| 167 |
-
|
| 168 |
-
# Add to zip
|
| 169 |
-
zf.write(ply_path, ply_filename)
|
| 170 |
-
|
| 171 |
-
except Exception as e:
|
| 172 |
-
LOGGER.error(f"Error processing {file.filename}: {e}")
|
| 173 |
-
continue
|
| 174 |
-
|
| 175 |
-
output_zip.seek(0)
|
| 176 |
-
return StreamingResponse(
|
| 177 |
-
output_zip,
|
| 178 |
-
media_type="application/zip",
|
| 179 |
-
headers={"Content-Disposition": "attachment; filename=gaussians.zip"}
|
| 180 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
if __name__ == "__main__":
|
| 183 |
-
|
| 184 |
-
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import io
|
|
|
|
|
|
|
|
|
|
| 3 |
import zipfile
|
|
|
|
| 4 |
import base64
|
| 5 |
+
import tempfile
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import requests
|
| 9 |
+
import gradio as gr
|
| 10 |
+
|
| 11 |
+
# Front-end Gradio app that calls the backend FastAPI service hosted on GPU cloud.
|
| 12 |
+
# Configure the backend base URL through environment variable on Hugging Face Spaces.
|
| 13 |
+
# Example: API_BASE_URL = "https://your-api.example.com"
|
| 14 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _files_payload(images):
|
| 18 |
+
"""Prepare multipart/form-data payload for requests.post(files=...)."""
|
| 19 |
+
files = []
|
| 20 |
+
for img in images:
|
| 21 |
+
if img is None:
|
| 22 |
+
continue
|
| 23 |
+
# gr.Image(type="filepath") returns a string path
|
| 24 |
+
if isinstance(img, str):
|
| 25 |
+
path = img
|
| 26 |
+
files.append(("files", (Path(path).name, open(path, "rb"), "image/*")))
|
| 27 |
+
continue
|
| 28 |
+
# gr.File returns objects with a .name attribute (path), or dict-like in some cases
|
| 29 |
+
path = getattr(img, "name", None)
|
| 30 |
+
if path is None and isinstance(img, dict) and "name" in img:
|
| 31 |
+
path = img["name"]
|
| 32 |
+
if path:
|
| 33 |
+
files.append(("files", (Path(path).name, open(path, "rb"), "image/*")))
|
| 34 |
+
return files
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def predict_single(image):
|
| 38 |
+
"""Call /predict on backend for a single image and return one PLY file to download."""
|
| 39 |
+
if not image:
|
| 40 |
+
return None, "No image provided."
|
| 41 |
+
files = _files_payload([image])
|
| 42 |
+
if not files:
|
| 43 |
+
return None, "Invalid image input."
|
| 44 |
+
|
| 45 |
+
try:
|
| 46 |
+
resp = requests.post(f"{API_BASE_URL}/predict", files=files, timeout=120)
|
| 47 |
+
resp.raise_for_status()
|
| 48 |
+
data = resp.json()
|
| 49 |
+
except Exception as e:
|
| 50 |
+
return None, f"Backend error: {e}"
|
| 51 |
+
|
| 52 |
+
results = data.get("results", [])
|
| 53 |
+
if not results:
|
| 54 |
+
return None, "No result."
|
| 55 |
+
item = results[0]
|
| 56 |
+
if "error" in item:
|
| 57 |
+
return None, item["error"]
|
| 58 |
+
|
| 59 |
+
# Decode base64 PLY to a temporary file
|
| 60 |
+
ply_bytes = base64.b64decode(item["ply_data"])
|
| 61 |
+
with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as tmpf:
|
| 62 |
+
tmpf.write(ply_bytes)
|
| 63 |
+
ply_path = tmpf.name
|
| 64 |
+
|
| 65 |
+
meta = f"{item['ply_filename']} ({item['width']}x{item['height']}), f={item['focal_length']:.2f}"
|
| 66 |
+
return ply_path, meta
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def predict_batch(images):
|
| 70 |
+
"""Call /predict on backend for multiple images and return a ZIP of PLY files."""
|
| 71 |
+
if not images:
|
| 72 |
+
return None, "No images provided."
|
| 73 |
+
files = _files_payload(images)
|
| 74 |
+
if not files:
|
| 75 |
+
return None, "Invalid inputs."
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
try:
|
| 78 |
+
resp = requests.post(f"{API_BASE_URL}/predict", files=files, timeout=300)
|
| 79 |
+
resp.raise_for_status()
|
| 80 |
+
data = resp.json()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
except Exception as e:
|
| 82 |
+
return None, f"Backend error: {e}"
|
| 83 |
+
|
| 84 |
+
results = data.get("results", [])
|
| 85 |
+
buf = io.BytesIO()
|
| 86 |
+
with zipfile.ZipFile(buf, "w") as zf:
|
| 87 |
+
metas = []
|
| 88 |
+
for item in results:
|
| 89 |
+
if "error" in item:
|
| 90 |
+
metas.append(f"{item.get('filename', '?')}: ERROR {item['error']}")
|
| 91 |
+
continue
|
| 92 |
+
ply_bytes = base64.b64decode(item["ply_data"])
|
| 93 |
+
zf.writestr(item["ply_filename"], ply_bytes)
|
| 94 |
+
metas.append(
|
| 95 |
+
f"{item['filename']} -> {item['ply_filename']} "
|
| 96 |
+
f"({item['width']}x{item['height']}, f={item['focal_length']:.2f})"
|
| 97 |
+
)
|
| 98 |
+
buf.seek(0)
|
| 99 |
+
return buf, "\n".join(metas)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
with gr.Blocks(title="SHARP View Synthesis") as demo:
|
| 103 |
+
gr.Markdown(
|
| 104 |
+
"# SHARP View Synthesis\nUpload image(s) to generate 3D Gaussian PLY files via the backend API."
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
with gr.Tab("Single Image"):
|
| 108 |
+
in_img = gr.Image(type="filepath", label="Input Image")
|
| 109 |
+
out_file = gr.File(label="Generated PLY")
|
| 110 |
+
out_info = gr.Textbox(label="Info")
|
| 111 |
+
btn = gr.Button("Predict")
|
| 112 |
+
btn.click(predict_single, inputs=[in_img], outputs=[out_file, out_info])
|
| 113 |
+
|
| 114 |
+
with gr.Tab("Batch"):
|
| 115 |
+
in_imgs = gr.File(
|
| 116 |
+
file_count="multiple", file_types=["image"], label="Input Images"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
)
|
| 118 |
+
out_zip = gr.File(label="PLY ZIP")
|
| 119 |
+
out_info2 = gr.Textbox(label="Info")
|
| 120 |
+
btn2 = gr.Button("Predict Batch")
|
| 121 |
+
btn2.click(predict_batch, inputs=[in_imgs], outputs=[out_zip, out_info2])
|
| 122 |
|
| 123 |
if __name__ == "__main__":
|
| 124 |
+
# On Hugging Face Spaces, API_BASE_URL must point to your GPU cloud FastAPI server
|
| 125 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|