File size: 5,697 Bytes
a54f9b5 25c920b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import sys
import logging
import shutil
import tempfile
import zipfile
import io as python_io
import base64
from pathlib import Path
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import torch
# Ensure we can import the project package: add top-level 'src' to sys.path
# This file resides at: <repo_root>/src/sharp/web/api_server.py
# Path(__file__).parents[2] == <repo_root>/src
sys.path.append(str(Path(__file__).parents[2]))
from sharp.models import PredictorParams, RGBGaussianPredictor, create_predictor
from sharp.utils import io as sharp_io
from sharp.utils.gaussians import save_ply
from sharp.cli.predict import predict_image, DEFAULT_MODEL_URL
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger("sharp.api")
# FastAPI application instance; served by uvicorn (see the __main__ guard).
app = FastAPI()
# CORS - allow HF Spaces frontend to call this API.
# Consider tightening allow_origins to your Space domain for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Module-level singletons filled in by the startup handler. Both remain None
# if model loading fails; endpoints check for that and return an error.
predictor: RGBGaussianPredictor | None = None
device: torch.device | None = None
@app.on_event("startup")
async def startup_event():
    """Load the SHARP predictor once when the server starts.

    Fills in the module-level ``predictor`` and ``device`` globals. On any
    failure the globals stay ``None`` so the process keeps running and the
    endpoints report "model not loaded" instead of crashing.
    """
    global predictor, device
    try:
        # Pick the best available accelerator, falling back to CPU.
        if torch.cuda.is_available():
            chosen = "cuda"
        elif torch.backends.mps.is_available():
            chosen = "mps"
        else:
            chosen = "cpu"
        device = torch.device(chosen)
        LOGGER.info(f"Using device: {device}")

        LOGGER.info("Loading SHARP model state dict...")
        weights = torch.hub.load_state_dict_from_url(
            DEFAULT_MODEL_URL, progress=True, map_location=device
        )
        predictor = create_predictor(PredictorParams())
        predictor.load_state_dict(weights)
        predictor.eval()
        predictor.to(device)
        LOGGER.info("Model loaded and ready.")
    except Exception as e:
        LOGGER.exception("Failed during startup/model init: %s", e)
        # Leave predictor as None; endpoints will return an error until fixed.
@app.get("/health")
async def health():
    """Liveness probe: report the active device and whether the model loaded."""
    model_ready = predictor is not None
    return {
        "status": "ok",
        "device": None if device is None else str(device),
        "model_loaded": model_ready,
    }
@app.post("/predict")
async def predict(files: list[UploadFile] = File(...)):
    """Accept images and return JSON with per-image metadata and PLY as base64.

    Each upload is persisted to a temporary directory, run through the SHARP
    predictor, and converted to a Gaussian-splat PLY that is returned
    base64-encoded. A failure on one file is reported in its result entry
    rather than aborting the whole batch.
    """
    if predictor is None:
        return JSONResponse({"error": "Model not loaded"}, status_code=500)
    results = []
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        for file in files:
            try:
                # Sanitize the client-supplied filename: keep only the basename
                # so a name like "../../evil.png" cannot escape the temp dir,
                # and tolerate a missing filename.
                safe_name = Path(file.filename or "upload").name
                file_path = temp_path / safe_name
                with open(file_path, "wb") as buffer:
                    shutil.copyfileobj(file.file, buffer)
                # Load input and run prediction.
                image, _, f_px = sharp_io.load_rgb(file_path)
                gaussians = predict_image(predictor, image, f_px, device)
                # Save PLY next to the upload.
                ply_filename = f"{file_path.stem}.ply"
                ply_path = temp_path / ply_filename
                height, width = image.shape[:2]
                save_ply(gaussians, f_px, (height, width), ply_path)
                # Encode PLY to base64 for JSON transport.
                with open(ply_path, "rb") as f:
                    ply_data = base64.b64encode(f.read()).decode("utf-8")
                results.append(
                    {
                        "filename": file.filename,
                        "ply_filename": ply_filename,
                        "ply_data": ply_data,
                        "width": width,
                        "height": height,
                        "focal_length": f_px,
                    }
                )
            except Exception as e:
                LOGGER.exception("Error processing %s: %s", file.filename, e)
                results.append({"filename": file.filename, "error": str(e)})
    return {"results": results}
@app.post("/predict/download")
async def predict_download(files: list[UploadFile] = File(...)):
    """Accept images and return a ZIP of generated PLY files.

    Files that fail to process are skipped (logged) so the ZIP still contains
    results for the uploads that succeeded.
    """
    if predictor is None:
        return JSONResponse({"error": "Model not loaded"}, status_code=500)
    output_zip = python_io.BytesIO()
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        with zipfile.ZipFile(output_zip, "w") as zf:
            for file in files:
                try:
                    # Sanitize the client-supplied filename: keep only the
                    # basename so "../../evil.png" cannot escape the temp dir,
                    # and tolerate a missing filename.
                    safe_name = Path(file.filename or "upload").name
                    file_path = temp_path / safe_name
                    with open(file_path, "wb") as buffer:
                        shutil.copyfileobj(file.file, buffer)
                    image, _, f_px = sharp_io.load_rgb(file_path)
                    gaussians = predict_image(predictor, image, f_px, device)
                    ply_filename = f"{file_path.stem}.ply"
                    ply_path = temp_path / ply_filename
                    height, width = image.shape[:2]
                    save_ply(gaussians, f_px, (height, width), ply_path)
                    zf.write(ply_path, ply_filename)
                except Exception as e:
                    LOGGER.exception("Error processing %s: %s", file.filename, e)
                    continue
    output_zip.seek(0)
    return StreamingResponse(
        output_zip,
        media_type="application/zip",
        headers={"Content-Disposition": "attachment; filename=gaussians.zip"},
    )
if __name__ == "__main__":
    # Standalone dev entry point; port 7860 matches the HF Spaces convention
    # mentioned in the CORS note above.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
|