|
|
import base64
import io as python_io
import logging
import shutil
import sys
import tempfile
import zipfile
from pathlib import Path

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sys.path.append(str(Path(__file__).parents[2])) |
|
|
|
|
|
from sharp.models import PredictorParams, RGBGaussianPredictor, create_predictor |
|
|
from sharp.utils import io as sharp_io |
|
|
from sharp.utils.gaussians import save_ply |
|
|
from sharp.cli.predict import predict_image, DEFAULT_MODEL_URL |
|
|
|
|
|
# Root logging at INFO so model-loading progress and per-file errors show up.
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger("sharp.api")

# Cross-origin policy: any origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True means
# credentialed browser requests are accepted from any site. No cookies or auth
# are visible in this service, so allow_credentials=False is probably enough —
# confirm with the front-end clients before tightening.
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app = FastAPI()
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Shared model state, filled in by the startup hook; both stay ``None`` until
# then (and ``predictor`` stays ``None`` if loading fails).
predictor: RGBGaussianPredictor | None = None
device: torch.device | None = None
|
|
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Pick a torch device and load the SHARP predictor into module globals.

    Device preference is CUDA, then MPS, then CPU. Any failure is logged rather
    than raised so the server still starts. In that case ``predictor`` is left
    as ``None`` (never a partially initialised model), and the request handlers
    can detect that the model is unavailable.
    """
    global predictor, device
    try:
        if torch.cuda.is_available():
            device_str = "cuda"
        elif torch.backends.mps.is_available():
            device_str = "mps"
        else:
            device_str = "cpu"
        device = torch.device(device_str)
        LOGGER.info("Using device: %s", device)

        LOGGER.info("Loading SHARP model state dict...")
        # load_state_dict copies values into the model's own parameters, so the
        # state dict's device does not matter. Keeping it on the CPU avoids
        # holding a second full copy of the weights in accelerator memory while
        # the model is moved to ``device`` below.
        state_dict = torch.hub.load_state_dict_from_url(
            DEFAULT_MODEL_URL, progress=True, map_location="cpu"
        )

        # Build into a local and publish only after every step succeeded.
        # Previously the global was set before load_state_dict, so a failing
        # load left a randomly initialised model visible to the handlers.
        model = create_predictor(PredictorParams())
        model.load_state_dict(state_dict)
        model.eval()
        model.to(device)
        predictor = model
        LOGGER.info("Model loaded and ready.")
    except Exception as e:
        # Deliberate best-effort: keep the server up so /health can report it.
        LOGGER.exception("Failed during startup/model init: %s", e)
|
|
|
|
|
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe: always "ok", plus the chosen device and model state."""
    payload = {"status": "ok"}
    payload["device"] = str(device) if device else None
    payload["model_loaded"] = predictor is not None
    return payload
|
|
|
|
|
|
|
|
@app.post("/predict")
async def predict(files: list[UploadFile] = File(...)):
    """Accept images and return JSON with per-image metadata and PLY as base64.

    Each upload is handled independently. On success, its entry in ``results``
    holds ``filename``, ``ply_filename``, base64 ``ply_data``, ``width``,
    ``height`` and ``focal_length``. On failure, the error is logged and the
    entry is ``{"filename": ..., "error": ...}``, so one bad file does not fail
    the whole request.

    Returns a 500 JSON error if the model is not loaded.

    NOTE(review): inference runs synchronously inside this ``async`` handler,
    so it blocks the event loop (including /health) for the whole request.
    Offloading to a worker thread would also allow concurrent inference on the
    shared model — choose a concurrency policy before changing this.
    """
    if predictor is None:
        return JSONResponse({"error": "Model not loaded"}, status_code=500)

    results = []
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        for index, file in enumerate(files):
            try:
                # ``filename`` is client-controlled. Keep only its last path
                # component so "../x" or "/abs/x" cannot escape the temp dir.
                safe_name = Path(file.filename or "").name
                if safe_name in ("", ".", ".."):
                    safe_name = f"upload_{index}"
                # Per-upload directory: identical names cannot clobber each other.
                work_dir = temp_path / str(index)
                work_dir.mkdir()
                file_path = work_dir / safe_name
                with open(file_path, "wb") as buffer:
                    shutil.copyfileobj(file.file, buffer)

                image, _, f_px = sharp_io.load_rgb(file_path)
                gaussians = predict_image(predictor, image, f_px, device)

                ply_filename = f"{file_path.stem}.ply"
                ply_path = work_dir / ply_filename
                height, width = image.shape[:2]
                save_ply(gaussians, f_px, (height, width), ply_path)

                ply_data = base64.b64encode(ply_path.read_bytes()).decode("utf-8")
                results.append(
                    {
                        "filename": file.filename,
                        "ply_filename": ply_filename,
                        "ply_data": ply_data,
                        "width": width,
                        "height": height,
                        "focal_length": f_px,
                    }
                )
            except Exception as e:
                LOGGER.exception("Error processing %s: %s", file.filename, e)
                results.append({"filename": file.filename, "error": str(e)})

    return {"results": results}
|
|
|
|
|
|
|
|
@app.post("/predict/download")
async def predict_download(files: list[UploadFile] = File(...)):
    """Accept images and return a ZIP of generated PLY files.

    Files that fail to process are logged and left out of the archive
    (best-effort). Entry names are ``<stem>.ply``. If several uploads share a
    stem, the later ones get a ``_1``, ``_2``, … suffix instead of producing
    duplicate ZIP entries.

    Returns a 500 JSON error if the model is not loaded.

    NOTE(review): as in /predict, inference blocks the event loop while it runs.
    """
    if predictor is None:
        return JSONResponse({"error": "Model not loaded"}, status_code=500)

    output_zip = python_io.BytesIO()
    used_arcnames: set[str] = set()
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        with zipfile.ZipFile(output_zip, "w") as zf:
            for index, file in enumerate(files):
                try:
                    # ``filename`` is client-controlled. Keep only its last
                    # component so it cannot escape the temp dir.
                    safe_name = Path(file.filename or "").name
                    if safe_name in ("", ".", ".."):
                        safe_name = f"upload_{index}"
                    work_dir = temp_path / str(index)
                    work_dir.mkdir()
                    file_path = work_dir / safe_name
                    with open(file_path, "wb") as buffer:
                        shutil.copyfileobj(file.file, buffer)

                    image, _, f_px = sharp_io.load_rgb(file_path)
                    gaussians = predict_image(predictor, image, f_px, device)

                    ply_filename = f"{file_path.stem}.ply"
                    ply_path = work_dir / ply_filename
                    height, width = image.shape[:2]
                    save_ply(gaussians, f_px, (height, width), ply_path)

                    # Avoid duplicate entries (zipfile only warns about them and
                    # extractors silently overwrite one with the other).
                    arcname = ply_filename
                    counter = 1
                    while arcname in used_arcnames:
                        arcname = f"{file_path.stem}_{counter}.ply"
                        counter += 1
                    zf.write(ply_path, arcname)
                    used_arcnames.add(arcname)
                except Exception as e:
                    LOGGER.exception("Error processing %s: %s", file.filename, e)

    # The archive is already fully in memory. StreamingResponse would iterate
    # the BytesIO line by line, turning binary data into thousands of tiny
    # chunks, so send it as one body (this also sets Content-Length).
    return Response(
        content=output_zip.getvalue(),
        media_type="application/zip",
        headers={"Content-Disposition": "attachment; filename=gaussians.zip"},
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Development entry point: serve the app on every interface, port 8000.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)
|
|
|