import sys
import logging
import shutil
import tempfile
import zipfile
import io as python_io
import base64
from pathlib import Path

from fastapi import FastAPI, UploadFile, File
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import torch

# Ensure we can import the project package: add top-level 'src' to sys.path
# This file resides at: <repo_root>/src/sharp/web/api_server.py
# Path(__file__).parents[2] == <repo_root>/src
sys.path.append(str(Path(__file__).parents[2]))

from sharp.models import PredictorParams, RGBGaussianPredictor, create_predictor
from sharp.utils import io as sharp_io
from sharp.utils.gaussians import save_ply
from sharp.cli.predict import predict_image, DEFAULT_MODEL_URL

logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger("sharp.api")

app = FastAPI()

# CORS - allow HF Spaces frontend to call this API.
# Consider tightening allow_origins to your Space domain for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

predictor: RGBGaussianPredictor | None = None
device: torch.device | None = None


# NOTE: `on_event` is deprecated in recent FastAPI releases in favor of
# lifespan handlers; it is kept here for simplicity.
@app.on_event("startup")
async def startup_event():
    global predictor, device
    try:
        device_str = (
            "cuda"
            if torch.cuda.is_available()
            else ("mps" if torch.backends.mps.is_available() else "cpu")
        )
        device = torch.device(device_str)
        LOGGER.info(f"Using device: {device}")

        LOGGER.info("Loading SHARP model state dict...")
        state_dict = torch.hub.load_state_dict_from_url(
            DEFAULT_MODEL_URL, progress=True, map_location=device
        )

        predictor = create_predictor(PredictorParams())
        predictor.load_state_dict(state_dict)
        predictor.eval()
        predictor.to(device)
        LOGGER.info("Model loaded and ready.")
    except Exception as e:
        LOGGER.exception("Failed during startup/model init: %s", e)
        # Leave predictor as None; endpoints will return error until fixed.


@app.get("/health")
async def health():
    return {
        "status": "ok",
        "device": str(device) if device else None,
        "model_loaded": predictor is not None,
    }
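
# Illustrative usage (assumes the server is running locally on port 7860;
# the exact "device" value depends on the host):
#   curl http://localhost:7860/health
#   -> {"status": "ok", "device": "cuda", "model_loaded": true}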


@app.post("/predict")
async def predict(files: list[UploadFile] = File(...)):
    """Accept images and return JSON with per-image metadata and PLY as base64."""
    if predictor is None:
        return JSONResponse({"error": "Model not loaded"}, status_code=500)

    results = []
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)

        for file in files:
            try:
                # Persist upload to temp; sanitize the client-supplied name,
                # which may be empty or contain path components
                file_path = temp_path / Path(file.filename or "upload").name
                with open(file_path, "wb") as buffer:
                    shutil.copyfileobj(file.file, buffer)

                # Load input and run prediction
                image, _, f_px = sharp_io.load_rgb(file_path)
                gaussians = predict_image(predictor, image, f_px, device)

                # Save PLY
                ply_filename = f"{file_path.stem}.ply"
                ply_path = temp_path / ply_filename
                height, width = image.shape[:2]
                save_ply(gaussians, f_px, (height, width), ply_path)

                # Encode PLY to base64 for transport
                with open(ply_path, "rb") as f:
                    ply_data = base64.b64encode(f.read()).decode("utf-8")

                results.append(
                    {
                        "filename": file.filename,
                        "ply_filename": ply_filename,
                        "ply_data": ply_data,
                        "width": width,
                        "height": height,
                        "focal_length": f_px,
                    }
                )
            except Exception as e:
                LOGGER.exception("Error processing %s: %s", file.filename, e)
                results.append({"filename": file.filename, "error": str(e)})

    return {"results": results}
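
# Client sketch (illustrative, not part of the server): decode the base64
# PLY payload returned by /predict. Assumes the `requests` package and a
# server on localhost:7860; adjust the URL and file paths for your setup.
#
#   import base64, requests
#
#   with open("photo.jpg", "rb") as fh:
#       resp = requests.post(
#           "http://localhost:7860/predict",
#           files=[("files", ("photo.jpg", fh, "image/jpeg"))],
#       )
#   for item in resp.json()["results"]:
#       if "ply_data" in item:
#           with open(item["ply_filename"], "wb") as out:
#               out.write(base64.b64decode(item["ply_data"]))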


@app.post("/predict/download")
async def predict_download(files: list[UploadFile] = File(...)):
    """Accept images and return a ZIP of generated PLY files."""
    if predictor is None:
        return JSONResponse({"error": "Model not loaded"}, status_code=500)

    output_zip = python_io.BytesIO()
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        with zipfile.ZipFile(output_zip, "w") as zf:
            for file in files:
                try:
                    # Sanitize the client-supplied name (may be empty or
                    # contain path components)
                    file_path = temp_path / Path(file.filename or "upload").name
                    with open(file_path, "wb") as buffer:
                        shutil.copyfileobj(file.file, buffer)

                    image, _, f_px = sharp_io.load_rgb(file_path)
                    gaussians = predict_image(predictor, image, f_px, device)

                    ply_filename = f"{file_path.stem}.ply"
                    ply_path = temp_path / ply_filename
                    height, width = image.shape[:2]
                    save_ply(gaussians, f_px, (height, width), ply_path)

                    zf.write(ply_path, ply_filename)
                except Exception as e:
                    LOGGER.exception("Error processing %s: %s", file.filename, e)
                    continue

    output_zip.seek(0)
    return StreamingResponse(
        output_zip,
        media_type="application/zip",
        headers={"Content-Disposition": "attachment; filename=gaussians.zip"},
    )
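
# Illustrative usage: fetch a ZIP of PLY files with curl, assuming a local
# server and two example images:
#   curl -F "files=@photo.jpg" -F "files=@other.png" \
#        http://localhost:7860/predict/download -o gaussians.zip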


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)