lidavidsh commited on
Commit
a54f9b5
·
1 Parent(s): 804c8c9

add gradio ui with separate frontend/backend

Browse files
pyproject_amd.toml ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "sharp"
3
+ version = "0.1"
4
+ description = "Inference/Network/Model code for SHARP view synthesis model."
5
+ readme = "README.md"
6
+ dependencies = [
7
+ "click",
8
+ "gsplat",
9
+ "imageio[ffmpeg]",
10
+ "matplotlib",
11
+ "pillow-heif",
12
+ "plyfile",
13
+ "scipy",
14
+ "timm",
15
+ "torch",
16
+ "torchvision",
17
+ ]
18
+
19
+ [project.scripts]
20
+ sharp = "sharp.cli:main_cli"
21
+
22
+ [project.urls]
23
+ Homepage = "https://github.com/apple/ml-sharp"
24
+ Repository = "https://github.com/apple/ml-sharp"
25
+
26
+ [build-system]
27
+ requires = ["setuptools", "setuptools-scm"]
28
+ build-backend = "setuptools.build_meta"
29
+
30
+ [tool.setuptools.packages.find]
31
+ where = ["src"]
32
+
33
+ [tool.pyright]
34
+ include = ["src"]
35
+ exclude = [
36
+ "**/node_modules",
37
+ "**/__pycache__",
38
+ ]
39
+ pythonVersion = "3.12"
40
+
41
+ [tool.pytest.ini_options]
42
+ minversion = "6.0"
43
+ addopts = "-ra -q"
44
+ testpaths = [
45
+ "tests"
46
+ ]
47
+ filterwarnings = [
48
+ "ignore::DeprecationWarning"
49
+ ]
50
+
51
+ [tool.ruff.lint.per-file-ignores]
52
+ "__init__.py" = ["F401", "D100", "D104"]
53
+
54
+ [tool.ruff]
55
+ line-length = 100
56
+ lint.select = ["E", "F", "D", "I"]
57
+ lint.ignore = ["D100", "D105",
58
+ # Imperative mood of docstring.
59
+ "D401",
60
+ ]
61
+ extend-exclude = [
62
+ "*external*",
63
+ "third_party",
64
+ ]
65
+ src = ["sharp"]
66
+ target-version = "py39"
67
+
68
+ [tool.ruff.lint.pydocstyle]
69
+ convention = "google"
requirements.txt CHANGED
@@ -1,172 +1,6 @@
1
- # This file was autogenerated by uv via the following command:
2
- # uv pip compile requirements.in -o requirements.txt --universal
3
- -e .
4
- # via -r requirements.in
5
- certifi==2025.8.3
6
- # via requests
7
- charset-normalizer==3.4.3
8
- # via requests
9
- click==8.3.0
10
- # via sharp
11
- colorama==0.4.6 ; sys_platform == 'win32'
12
- # via
13
- # click
14
- # tqdm
15
- contourpy==1.3.3
16
- # via matplotlib
17
- cycler==0.12.1
18
- # via matplotlib
19
- filelock==3.19.1
20
- # via
21
- # huggingface-hub
22
- # torch
23
- fonttools==4.61.0
24
- # via matplotlib
25
- fsspec==2025.9.0
26
- # via
27
- # huggingface-hub
28
- # torch
29
- gsplat==1.5.3
30
- # via sharp
31
- hf-xet==1.1.10 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
32
- # via huggingface-hub
33
- huggingface-hub==0.35.3
34
- # via timm
35
- idna==3.10
36
- # via requests
37
- imageio==2.37.0
38
- # via sharp
39
- imageio-ffmpeg==0.6.0
40
- # via imageio
41
- jaxtyping==0.3.3
42
- # via gsplat
43
- jinja2==3.1.6
44
- # via torch
45
- kiwisolver==1.4.9
46
- # via matplotlib
47
- markdown-it-py==4.0.0
48
- # via rich
49
- markupsafe==3.0.3
50
- # via jinja2
51
- matplotlib==3.10.6
52
- # via sharp
53
- mdurl==0.1.2
54
- # via markdown-it-py
55
- mpmath==1.3.0
56
- # via sympy
57
- networkx==3.5
58
- # via torch
59
- ninja==1.13.0
60
- # via gsplat
61
- numpy==2.3.3
62
- # via
63
- # contourpy
64
- # gsplat
65
- # imageio
66
- # matplotlib
67
- # plyfile
68
- # scipy
69
- # torchvision
70
- nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
71
- # via
72
- # nvidia-cudnn-cu12
73
- # nvidia-cusolver-cu12
74
- # torch
75
- nvidia-cuda-cupti-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
76
- # via torch
77
- nvidia-cuda-nvrtc-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
78
- # via torch
79
- nvidia-cuda-runtime-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
80
- # via torch
81
- nvidia-cudnn-cu12==9.10.2.21 ; platform_machine == 'x86_64' and sys_platform == 'linux'
82
- # via torch
83
- nvidia-cufft-cu12==11.3.3.83 ; platform_machine == 'x86_64' and sys_platform == 'linux'
84
- # via torch
85
- nvidia-cufile-cu12==1.13.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
86
- # via torch
87
- nvidia-curand-cu12==10.3.9.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
88
- # via torch
89
- nvidia-cusolver-cu12==11.7.3.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
90
- # via torch
91
- nvidia-cusparse-cu12==12.5.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
92
- # via
93
- # nvidia-cusolver-cu12
94
- # torch
95
- nvidia-cusparselt-cu12==0.7.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
96
- # via torch
97
- nvidia-nccl-cu12==2.27.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
98
- # via torch
99
- nvidia-nvjitlink-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
100
- # via
101
- # nvidia-cufft-cu12
102
- # nvidia-cusolver-cu12
103
- # nvidia-cusparse-cu12
104
- # torch
105
- nvidia-nvtx-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
106
- # via torch
107
- packaging==25.0
108
- # via
109
- # huggingface-hub
110
- # matplotlib
111
- pillow==11.3.0
112
- # via
113
- # imageio
114
- # matplotlib
115
- # pillow-heif
116
- # torchvision
117
- pillow-heif==1.1.1
118
- # via sharp
119
- plyfile==1.1.2
120
- # via sharp
121
- psutil==7.1.0
122
- # via imageio
123
- pygments==2.19.2
124
- # via rich
125
- pyparsing==3.2.5
126
- # via matplotlib
127
- python-dateutil==2.9.0.post0
128
- # via matplotlib
129
- pyyaml==6.0.3
130
- # via
131
- # huggingface-hub
132
- # timm
133
  requests==2.32.5
134
- # via huggingface-hub
135
- rich==14.1.0
136
- # via gsplat
137
- safetensors==0.6.2
138
- # via timm
139
- scipy==1.16.2
140
- # via sharp
141
- setuptools==80.9.0
142
- # via
143
- # torch
144
- # triton
145
- six==1.17.0
146
- # via python-dateutil
147
- sympy==1.14.0
148
- # via torch
149
- timm==1.0.20
150
- # via sharp
151
- torch==2.8.0
152
- # via
153
- # gsplat
154
- # sharp
155
- # timm
156
- # torchvision
157
- torchvision==0.23.0
158
- # via
159
- # sharp
160
- # timm
161
- tqdm==4.67.1
162
- # via huggingface-hub
163
- triton==3.4.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
164
- # via torch
165
- typing-extensions==4.15.0
166
- # via
167
- # huggingface-hub
168
- # torch
169
- urllib3==2.6.0
170
- # via requests
171
- wadler-lindig==0.1.7
172
- # via jaxtyping
 
1
+ # Front-end requirements for Hugging Face Spaces (Gradio UI)
2
+ # Deploy the Gradio app in src/sharp/web/app.py using these minimal dependencies.
3
+ # Install with: pip install -r requirements.txt
4
+
5
+ gradio==4.44.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  requests==2.32.5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements_api.txt ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # This file was autogenerated by uv via the following command:
3
+ # uv pip compile requirements.in -o requirements.txt --universal
4
+ -e .
5
+ # via -r requirements.in
6
+ certifi==2025.8.3
7
+ # via requests
8
+ charset-normalizer==3.4.3
9
+ # via requests
10
+ click==8.3.0
11
+ # via sharp
12
+ colorama==0.4.6 ; sys_platform == 'win32'
13
+ # via
14
+ # click
15
+ # tqdm
16
+ contourpy==1.3.3
17
+ # via matplotlib
18
+ cycler==0.12.1
19
+ # via matplotlib
20
+ filelock==3.19.1
21
+ # via
22
+ # huggingface-hub
23
+ # torch
24
+ fonttools==4.61.0
25
+ # via matplotlib
26
+ fsspec==2025.9.0
27
+ # via
28
+ # huggingface-hub
29
+ # torch
30
+ # gsplat==1.5.3
31
+ # via sharp
32
+ hf-xet==1.1.10 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
33
+ # via huggingface-hub
34
+ huggingface-hub==0.35.3
35
+ # via timm
36
+ idna==3.10
37
+ # via requests
38
+ imageio==2.37.0
39
+ # via sharp
40
+ imageio-ffmpeg==0.6.0
41
+ # via imageio
42
+ jaxtyping==0.3.3
43
+ # via gsplat
44
+ jinja2==3.1.6
45
+ # via torch
46
+ kiwisolver==1.4.9
47
+ # via matplotlib
48
+ markdown-it-py==4.0.0
49
+ # via rich
50
+ markupsafe==3.0.3
51
+ # via jinja2
52
+ matplotlib==3.10.6
53
+ # via sharp
54
+ mdurl==0.1.2
55
+ # via markdown-it-py
56
+ mpmath==1.3.0
57
+ # via sympy
58
+ networkx==3.5
59
+ # via torch
60
+ ninja==1.13.0
61
+ # via gsplat
62
+ # numpy==2.3.3
63
+ numpy<2
64
+ packaging==25.0
65
+ # via
66
+ # huggingface-hub
67
+ # matplotlib
68
+ pillow==11.3.0
69
+ # via
70
+ # imageio
71
+ # matplotlib
72
+ # pillow-heif
73
+ # torchvision
74
+ pillow-heif==1.1.1
75
+ # via sharp
76
+ plyfile==1.1.2
77
+ # via sharp
78
+ psutil==7.1.0
79
+ # via imageio
80
+ pygments==2.19.2
81
+ # via rich
82
+ pyparsing==3.2.5
83
+ # via matplotlib
84
+ python-dateutil==2.9.0.post0
85
+ # via matplotlib
86
+ pyyaml==6.0.3
87
+ # via
88
+ # huggingface-hub
89
+ # timm
90
+ requests==2.32.5
91
+ # via huggingface-hub
92
+ rich==14.1.0
93
+ # via gsplat
94
+ safetensors==0.6.2
95
+ # via timm
96
+ scipy==1.16.2
97
+ # via sharp
98
+ setuptools==80.9.0
99
+ # via
100
+ # torch
101
+ # triton
102
+ six==1.17.0
103
+ # via python-dateutil
104
+ sympy==1.14.0
105
+ # via torch
106
+ timm==1.0.20
107
+ # via sharp
108
+ tqdm==4.67.1
109
+ # via huggingface-hub
110
+ typing-extensions
111
+ urllib3==2.6.0
112
+ # via requests
113
+ wadler-lindig==0.1.7
114
+ # via jaxtyping
115
+ # Backend API server runtime deps
116
+ fastapi
117
+ uvicorn[standard]
118
+ python-multipart
src/sharp/web/README.md CHANGED
@@ -1,33 +1,101 @@
1
- # Sharp Web Interface
2
 
3
- This is a web interface for the Sharp 3D prediction model.
 
 
4
 
5
- ## Prerequisites
6
 
7
- Make sure you have the `sharp` package installed (see root README).
8
- Install the web dependencies:
9
 
 
 
 
 
 
 
 
 
 
 
10
  ```bash
11
- pip install -r requirements.txt
 
12
  ```
13
 
14
- ## Running the Server
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- Run the following command from the `web` directory:
17
 
 
 
 
18
  ```bash
19
- python app.py
20
  ```
 
21
 
22
- Or using uvicorn directly:
23
 
 
 
 
 
 
 
 
 
 
24
  ```bash
25
- uvicorn app:app --reload --host 0.0.0.0 --port 8000
26
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- ## Usage
29
 
30
- 1. Open your browser and navigate to `http://localhost:8000`.
31
- 2. Drag and drop images or click to select them.
32
- 3. Click "Predict 3D Gaussians".
33
- 4. A zip file containing the resulting `.ply` files will be downloaded automatically.
 
1
+ # SHARP Web: Frontend (Gradio) + Backend (FastAPI)
2
 
3
+ This directory provides a separated deployment:
4
+ - Backend API (`api_server.py`) runs on a GPU cloud with FastAPI
5
+ - Frontend UI (`app.py`) runs on Hugging Face Spaces with Gradio
6
 
7
+ The UI calls the API via HTTP; the API performs model inference and returns PLY results.
8
 
9
+ ## Repository Layout
 
10
 
11
+ - `src/sharp/web/api_server.py` — FastAPI backend hosting inference endpoints
12
+ - `src/sharp/web/app.py` — Gradio frontend calling the backend
13
+ - `requirements_api.txt` — Backend dependencies (GPU cloud)
14
+ - `requirements.txt` — Frontend dependencies (HF Spaces)
15
+
16
+ ## Backend (GPU Cloud)
17
+
18
+ ### Install
19
+
20
+ On your GPU cloud instance:
21
  ```bash
22
+ # From repository root
23
+ pip install -r requirements_api.txt
24
  ```
25
 
26
+ Notes:
27
+ - Ensure CUDA is available if using NVIDIA GPUs. The Torch version in `requirements_api.txt` is compiled for CUDA 12 on Linux.
28
+ - On macOS, MPS (Apple Silicon) may be detected; otherwise CPU fallback is used.
29
+
30
+ ### Run
31
+
32
+ From repository root:
33
+ ```bash
34
+ python src/sharp/web/api_server.py
35
+ ```
36
+ or with Uvicorn:
37
+ ```bash
38
+ uvicorn src.sharp.web.api_server:app --host 0.0.0.0 --port 8000
39
+ ```
40
+
41
+ ### Endpoints
42
+
43
+ - `GET /health` — Basic health check, device info, and model-loaded flag
44
+ - `POST /predict` — Multipart upload of one or more images (`files` field); returns JSON with per-image metadata and PLY contents base64-encoded
45
+ - `POST /predict/download` — Multipart upload of one or more images; returns a ZIP stream containing PLY files
46
+
47
+ CORS is enabled by default to allow calls from the Hugging Face Space. For production, set `allow_origins` to your specific Space domain.
48
 
49
+ ## Frontend (Hugging Face Spaces)
50
 
51
+ ### Install
52
+
53
+ On HF Spaces:
54
  ```bash
55
+ pip install -r requirements.txt
56
  ```
57
+ This installs only Gradio and Requests.
58
 
59
+ ### Configure
60
 
61
+ Set environment variable `API_BASE_URL` in your Space to point to the public backend URL, for example:
62
+ ```
63
+ API_BASE_URL=https://your-api.example.com
64
+ ```
65
+ If running locally for testing, `API_BASE_URL` defaults to `http://localhost:8000`.
66
+
67
+ ### Run
68
+
69
+ Locally:
70
  ```bash
71
+ python src/sharp/web/app.py
72
  ```
73
+ Gradio will start on port `7860` by default (configured to `0.0.0.0` in the script).
74
+
75
+ On HF Spaces, simply setting the Space’s “Run” command to `python src/sharp/web/app.py` is sufficient.
76
+
77
+ ### Usage (Frontend)
78
+
79
+ - Single Image tab: upload one image and click Predict to download its PLY.
80
+ - Batch tab: upload multiple images and click Predict Batch to download a ZIP containing PLY files.
81
+ - The frontend calls the backend `POST /predict` and assembles results for user download.
82
+
83
+ ## Quick Local Test
84
+
85
+ 1. Start backend:
86
+ ```bash
87
+ uvicorn src.sharp.web.api_server:app --host 0.0.0.0 --port 8000
88
+ ```
89
+
90
+ 2. Start frontend (in another terminal):
91
+ ```bash
92
+ API_BASE_URL=http://localhost:8000 python src/sharp/web/app.py
93
+ ```
94
+
95
+ 3. Open the Gradio UI (http://localhost:7860), upload images, and verify outputs.
96
 
97
+ ## Notes & Troubleshooting
98
 
99
+ - If imports like `fastapi` or `gradio` show unresolved in your IDE, ensure the correct environment is selected and dependencies installed via the respective requirements file.
100
+ - Network access from HF Spaces to the GPU API must be allowed; ensure your API endpoint is accessible over HTTPS where possible.
101
+ - For security, consider locking down CORS to your Space origin and adding authentication (e.g., an API key header) if needed.
 
src/sharp/web/api_server.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import logging
3
+ import shutil
4
+ import tempfile
5
+ import zipfile
6
+ import io as python_io
7
+ import base64
8
+ from pathlib import Path
9
+
10
+ from fastapi import FastAPI, UploadFile, File
11
+ from fastapi.responses import StreamingResponse, JSONResponse
12
+ from fastapi.middleware.cors import CORSMiddleware
13
+ import torch
14
+
15
+ # Ensure we can import the project package: add top-level 'src' to sys.path
16
+ # This file resides at: <repo_root>/src/sharp/web/api_server.py
17
+ # Path(__file__).parents[2] == <repo_root>/src
18
+ sys.path.append(str(Path(__file__).parents[2]))
19
+
20
+ from sharp.models import PredictorParams, RGBGaussianPredictor, create_predictor
21
+ from sharp.utils import io as sharp_io
22
+ from sharp.utils.gaussians import save_ply
23
+ from sharp.cli.predict import predict_image, DEFAULT_MODEL_URL
24
+
25
# --- Service wiring: logging, FastAPI app, CORS, and model globals ---------
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger("sharp.api")

app = FastAPI()

# Allow the Hugging Face Spaces frontend (a different origin) to call this API.
# NOTE(review): "*" is permissive — restrict allow_origins to the Space domain
# for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Populated once by the startup hook; endpoints return an error while None.
predictor: RGBGaussianPredictor | None = None
device: torch.device | None = None
42
+
43
+
44
@app.on_event("startup")  # NOTE(review): deprecated in newer FastAPI; migrate to lifespan handlers eventually
async def startup_event():
    """Load the SHARP predictor onto the best available device at boot.

    Selects CUDA, then Apple MPS, then CPU. On any failure the exception is
    logged and the global ``predictor`` stays ``None`` so ``/health`` and the
    predict endpoints report the problem instead of crashing the process.

    Fix over the original: the model is built and loaded into a local first
    and the global ``predictor`` is assigned only after the weights loaded
    successfully. Previously ``predictor`` was assigned before
    ``load_state_dict``, so a mid-load failure left a non-None, unloaded
    model that the endpoints would happily run.
    """
    global predictor, device
    try:
        # Prefer CUDA, then MPS, falling back to CPU.
        if torch.cuda.is_available():
            device_str = "cuda"
        elif torch.backends.mps.is_available():
            device_str = "mps"
        else:
            device_str = "cpu"
        device = torch.device(device_str)
        LOGGER.info(f"Using device: {device}")

        LOGGER.info("Loading SHARP model state dict...")
        state_dict = torch.hub.load_state_dict_from_url(
            DEFAULT_MODEL_URL, progress=True, map_location=device
        )

        model = create_predictor(PredictorParams())
        model.load_state_dict(state_dict)
        model.eval()
        model.to(device)
        predictor = model  # publish only once fully initialized
        LOGGER.info("Model loaded and ready.")
    except Exception as e:
        LOGGER.exception("Failed during startup/model init: %s", e)
        # Leave predictor as None; endpoints will return error until fixed.
69
+
70
+
71
@app.get("/health")
async def health():
    """Liveness probe: reports the active device and whether the model loaded."""
    return {
        "status": "ok",
        "device": None if device is None else str(device),
        "model_loaded": predictor is not None,
    }
78
+
79
+
80
@app.post("/predict")
async def predict(files: list[UploadFile] = File(...)):
    """Accept images and return JSON with per-image metadata and PLY as base64."""
    if not predictor:
        return JSONResponse({"error": "Model not loaded"}, status_code=500)

    outcomes = []
    with tempfile.TemporaryDirectory() as scratch:
        root = Path(scratch)

        for upload in files:
            try:
                # Persist the upload so sharp's loader can read it from disk.
                saved = root / upload.filename
                with open(saved, "wb") as sink:
                    shutil.copyfileobj(upload.file, sink)

                # Decode the image (with focal length) and run inference.
                image, _, f_px = sharp_io.load_rgb(saved)
                gaussians = predict_image(predictor, image, f_px, device)

                # Write the PLY to the scratch dir, then base64 it for JSON transport.
                ply_name = f"{saved.stem}.ply"
                ply_file = root / ply_name
                height, width = image.shape[:2]
                save_ply(gaussians, f_px, (height, width), ply_file)

                with open(ply_file, "rb") as src:
                    encoded = base64.b64encode(src.read()).decode("utf-8")

                outcomes.append(
                    {
                        "filename": upload.filename,
                        "ply_filename": ply_name,
                        "ply_data": encoded,
                        "width": width,
                        "height": height,
                        "focal_length": f_px,
                    }
                )
            except Exception as e:
                # Per-image failure: report it in the results, keep the batch going.
                LOGGER.exception("Error processing %s: %s", upload.filename, e)
                outcomes.append({"filename": upload.filename, "error": str(e)})

    return {"results": outcomes}
126
+
127
+
128
@app.post("/predict/download")
async def predict_download(files: list[UploadFile] = File(...)):
    """Accept images and return a ZIP of generated PLY files."""
    if not predictor:
        return JSONResponse({"error": "Model not loaded"}, status_code=500)

    archive = python_io.BytesIO()
    with tempfile.TemporaryDirectory() as scratch:
        root = Path(scratch)
        with zipfile.ZipFile(archive, "w") as bundle:
            for upload in files:
                try:
                    # Persist the upload, run inference, save the PLY, zip it.
                    saved = root / upload.filename
                    with open(saved, "wb") as sink:
                        shutil.copyfileobj(upload.file, sink)

                    image, _, f_px = sharp_io.load_rgb(saved)
                    gaussians = predict_image(predictor, image, f_px, device)

                    ply_name = f"{saved.stem}.ply"
                    ply_file = root / ply_name
                    height, width = image.shape[:2]
                    save_ply(gaussians, f_px, (height, width), ply_file)

                    bundle.write(ply_file, ply_name)
                except Exception as e:
                    # Best-effort batch: log and skip the failed image.
                    LOGGER.exception("Error processing %s: %s", upload.filename, e)
                    continue

    archive.seek(0)
    return StreamingResponse(
        archive,
        media_type="application/zip",
        headers={"Content-Disposition": "attachment; filename=gaussians.zip"},
    )
163
+
164
+
165
if __name__ == "__main__":
    # Dev entry point; in production run via `uvicorn src.sharp.web.api_server:app`.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
src/sharp/web/app.py CHANGED
@@ -1,184 +1,125 @@
1
- import sys
2
- from pathlib import Path
3
- import logging
4
- import shutil
5
- import tempfile
6
  import zipfile
7
- import io as python_io
8
  import base64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- from fastapi import FastAPI, Request, UploadFile, File
11
- from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse
12
- from fastapi.staticfiles import StaticFiles
13
- from fastapi.templating import Jinja2Templates
14
- import torch
15
- import numpy as np
16
-
17
- # Add src to path so we can import sharp
18
- sys.path.append(str(Path(__file__).parent.parent / "src"))
19
-
20
- from sharp.models import (
21
- PredictorParams,
22
- RGBGaussianPredictor,
23
- create_predictor,
24
- )
25
- from sharp.utils import io as sharp_io
26
- from sharp.utils.gaussians import save_ply
27
- from sharp.cli.predict import predict_image, DEFAULT_MODEL_URL
28
-
29
- # Configure logging
30
- logging.basicConfig(level=logging.INFO)
31
- LOGGER = logging.getLogger(__name__)
32
-
33
- app = FastAPI()
34
-
35
- # Mount static files if needed (we created the dir)
36
- app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")
37
-
38
- templates = Jinja2Templates(directory=Path(__file__).parent / "templates")
39
-
40
- # Global variables for the model
41
- predictor: RGBGaussianPredictor = None
42
- device: torch.device = None
43
-
44
- @app.on_event("startup")
45
- async def startup_event():
46
- global predictor, device
47
-
48
- # Determine device
49
- if torch.cuda.is_available():
50
- device_str = "cuda"
51
- elif torch.mps.is_available():
52
- device_str = "mps"
53
- else:
54
- device_str = "cpu"
55
-
56
- device = torch.device(device_str)
57
- LOGGER.info(f"Using device: {device}")
58
-
59
- # Load model
60
- LOGGER.info("Loading model...")
61
  try:
62
- # Try to load from cache or download
63
- state_dict = torch.hub.load_state_dict_from_url(DEFAULT_MODEL_URL, progress=True, map_location=device)
64
-
65
- predictor = create_predictor(PredictorParams())
66
- predictor.load_state_dict(state_dict)
67
- predictor.eval()
68
- predictor.to(device)
69
- LOGGER.info("Model loaded successfully.")
70
  except Exception as e:
71
- LOGGER.error(f"Failed to load model: {e}")
72
- raise e
73
-
74
- @app.get("/", response_class=HTMLResponse)
75
- async def read_root(request: Request):
76
- return templates.TemplateResponse("index.html", {"request": request})
77
-
78
- @app.post("/predict")
79
- async def predict(files: list[UploadFile] = File(...)):
80
- """Process images and return PLY data for viewing or download."""
81
- if not predictor:
82
- return JSONResponse({"error": "Model not loaded"}, status_code=500)
83
-
84
- # Create a temporary directory to process files
85
- with tempfile.TemporaryDirectory() as temp_dir:
86
- temp_path = Path(temp_dir)
87
- results = []
88
-
89
- for file in files:
90
- try:
91
- # Save uploaded file
92
- file_path = temp_path / file.filename
93
- with open(file_path, "wb") as buffer:
94
- shutil.copyfileobj(file.file, buffer)
95
-
96
- LOGGER.info(f"Processing {file.filename}")
97
-
98
- # Load image using sharp's IO to get focal length and handle rotation
99
- image, _, f_px = sharp_io.load_rgb(file_path)
100
-
101
- # Run prediction
102
- gaussians = predict_image(predictor, image, f_px, device)
103
-
104
- # Save PLY
105
- ply_filename = f"{file_path.stem}.ply"
106
- ply_path = temp_path / ply_filename
107
-
108
- height, width = image.shape[:2]
109
- save_ply(gaussians, f_px, (height, width), ply_path)
110
-
111
- # Read PLY file and encode as base64
112
- with open(ply_path, "rb") as f:
113
- ply_data = base64.b64encode(f.read()).decode("utf-8")
114
-
115
- results.append({
116
- "filename": file.filename,
117
- "ply_filename": ply_filename,
118
- "ply_data": ply_data,
119
- "width": width,
120
- "height": height,
121
- "focal_length": f_px,
122
- })
123
-
124
- except Exception as e:
125
- LOGGER.error(f"Error processing {file.filename}: {e}")
126
- results.append({
127
- "filename": file.filename,
128
- "error": str(e),
129
- })
130
-
131
- return JSONResponse({"results": results})
132
-
133
-
134
- @app.post("/predict/download")
135
- async def predict_download(files: list[UploadFile] = File(...)):
136
- """Process images and return a ZIP file for download."""
137
- if not predictor:
138
- return HTMLResponse("Model not loaded", status_code=500)
139
-
140
- # Create a temporary directory to process files
141
- with tempfile.TemporaryDirectory() as temp_dir:
142
- temp_path = Path(temp_dir)
143
- output_zip = python_io.BytesIO()
144
-
145
- with zipfile.ZipFile(output_zip, "w") as zf:
146
- for file in files:
147
- try:
148
- # Save uploaded file
149
- file_path = temp_path / file.filename
150
- with open(file_path, "wb") as buffer:
151
- shutil.copyfileobj(file.file, buffer)
152
-
153
- LOGGER.info(f"Processing {file.filename}")
154
-
155
- # Load image using sharp's IO to get focal length and handle rotation
156
- image, _, f_px = sharp_io.load_rgb(file_path)
157
-
158
- # Run prediction
159
- gaussians = predict_image(predictor, image, f_px, device)
160
-
161
- # Save PLY
162
- ply_filename = f"{file_path.stem}.ply"
163
- ply_path = temp_path / ply_filename
164
-
165
- height, width = image.shape[:2]
166
- save_ply(gaussians, f_px, (height, width), ply_path)
167
-
168
- # Add to zip
169
- zf.write(ply_path, ply_filename)
170
-
171
- except Exception as e:
172
- LOGGER.error(f"Error processing {file.filename}: {e}")
173
- continue
174
-
175
- output_zip.seek(0)
176
- return StreamingResponse(
177
- output_zip,
178
- media_type="application/zip",
179
- headers={"Content-Disposition": "attachment; filename=gaussians.zip"}
180
  )
 
 
 
 
181
 
182
  if __name__ == "__main__":
183
- import uvicorn
184
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ import os
2
+ import io
 
 
 
3
  import zipfile
 
4
  import base64
5
+ import tempfile
6
+ from pathlib import Path
7
+
8
+ import requests
9
+ import gradio as gr
10
+
11
+ # Front-end Gradio app that calls the backend FastAPI service hosted on GPU cloud.
12
+ # Configure the backend base URL through environment variable on Hugging Face Spaces.
13
+ # Example: API_BASE_URL = "https://your-api.example.com"
14
+ API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000")
15
+
16
+
17
+ def _files_payload(images):
18
+ """Prepare multipart/form-data payload for requests.post(files=...)."""
19
+ files = []
20
+ for img in images:
21
+ if img is None:
22
+ continue
23
+ # gr.Image(type="filepath") returns a string path
24
+ if isinstance(img, str):
25
+ path = img
26
+ files.append(("files", (Path(path).name, open(path, "rb"), "image/*")))
27
+ continue
28
+ # gr.File returns objects with a .name attribute (path), or dict-like in some cases
29
+ path = getattr(img, "name", None)
30
+ if path is None and isinstance(img, dict) and "name" in img:
31
+ path = img["name"]
32
+ if path:
33
+ files.append(("files", (Path(path).name, open(path, "rb"), "image/*")))
34
+ return files
35
+
36
+
37
def predict_single(image):
    """Send one image to the backend ``/predict`` and return (ply_path, info).

    Returns ``(None, <message>)`` on any failure: missing input, backend
    error, empty response, or a per-image error reported by the API.
    """
    if not image:
        return None, "No image provided."
    payload = _files_payload([image])
    if not payload:
        return None, "Invalid image input."

    try:
        response = requests.post(f"{API_BASE_URL}/predict", files=payload, timeout=120)
        response.raise_for_status()
        body = response.json()
    except Exception as e:
        return None, f"Backend error: {e}"

    entries = body.get("results", [])
    if not entries:
        return None, "No result."
    entry = entries[0]
    if "error" in entry:
        return None, entry["error"]

    # Materialize the base64 PLY as a temp file so gr.File can serve it.
    decoded = base64.b64decode(entry["ply_data"])
    with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as handle:
        handle.write(decoded)
        result_path = handle.name

    summary = (
        f"{entry['ply_filename']} ({entry['width']}x{entry['height']}), "
        f"f={entry['focal_length']:.2f}"
    )
    return result_path, summary
67
+
68
+
69
def predict_batch(images):
    """Send multiple images to the backend ``/predict`` endpoint.

    Returns ``(zip_path, info)`` where ``zip_path`` is a temporary ``.zip``
    file containing one PLY per successful image and ``info`` is a
    newline-separated per-image status report. Returns ``(None, <message>)``
    on missing input or backend failure.

    Fix over the original: the ZIP is written to a named temp file and its
    path is returned. The original returned an in-memory ``io.BytesIO`` to a
    ``gr.File`` output; Gradio's File component serves a filesystem path and
    the anonymous buffer cannot be served (this also matches how
    ``predict_single`` delivers its PLY).
    """
    if not images:
        return None, "No images provided."
    payload = _files_payload(images)
    if not payload:
        return None, "Invalid inputs."

    try:
        resp = requests.post(f"{API_BASE_URL}/predict", files=payload, timeout=300)
        resp.raise_for_status()
        data = resp.json()
    except Exception as e:
        return None, f"Backend error: {e}"

    # Reserve a temp .zip path, then fill it with the decoded PLYs.
    with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmpf:
        zip_path = tmpf.name

    metas = []
    with zipfile.ZipFile(zip_path, "w") as zf:
        for item in data.get("results", []):
            if "error" in item:
                metas.append(f"{item.get('filename', '?')}: ERROR {item['error']}")
                continue
            zf.writestr(item["ply_filename"], base64.b64decode(item["ply_data"]))
            metas.append(
                f"{item['filename']} -> {item['ply_filename']} "
                f"({item['width']}x{item['height']}, f={item['focal_length']:.2f})"
            )
    return zip_path, "\n".join(metas)
100
+
101
+
102
with gr.Blocks(title="SHARP View Synthesis") as demo:
    gr.Markdown(
        "# SHARP View Synthesis\nUpload image(s) to generate 3D Gaussian PLY files via the backend API."
    )

    # Tab 1: one image in -> one PLY file out.
    with gr.Tab("Single Image"):
        single_input = gr.Image(type="filepath", label="Input Image")
        single_output = gr.File(label="Generated PLY")
        single_status = gr.Textbox(label="Info")
        single_button = gr.Button("Predict")
        single_button.click(
            predict_single,
            inputs=[single_input],
            outputs=[single_output, single_status],
        )

    # Tab 2: many images in -> one ZIP of PLYs out.
    with gr.Tab("Batch"):
        batch_input = gr.File(
            file_count="multiple", file_types=["image"], label="Input Images"
        )
        batch_output = gr.File(label="PLY ZIP")
        batch_status = gr.Textbox(label="Info")
        batch_button = gr.Button("Predict Batch")
        batch_button.click(
            predict_batch,
            inputs=[batch_input],
            outputs=[batch_output, batch_status],
        )

if __name__ == "__main__":
    # On Hugging Face Spaces, API_BASE_URL must point to your GPU cloud FastAPI server
    demo.launch(server_name="0.0.0.0", server_port=7860)