# app.py
"""FastAPI service that serves a small web UI and a mask-generation endpoint.

The heavy lifting is delegated to the local ``depth_texture_mask`` module;
this file only handles HTTP concerns: decoding the upload, calling the
generator, and encoding the resulting mask as a PNG.
"""
import io
import logging
import time
import traceback

import numpy as np
from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from PIL import Image
from starlette.responses import RedirectResponse

import depth_texture_mask  # make sure depth_texture_mask.py is at repo root

logger = logging.getLogger("uvicorn.error")

app = FastAPI(title="Depth & Structural Masking API with UI")

# Mount static folder for CSS/JS/images
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")


@app.on_event("startup")
async def startup_event():
    """Initialize the MiDaS model once when the application starts.

    Initialization failures are logged (with traceback) but deliberately do
    not abort startup, so the UI stays reachable.
    """
    try:
        logger.info("Initializing MiDaS model...")
        # This will initialize the heavy model once
        depth_texture_mask.init_midas()
        logger.info("MiDaS initialized.")
    except Exception as e:
        logger.exception("Error initializing MiDaS: %s", e)


@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the ``index.html`` template."""
    return templates.TemplateResponse("index.html", {"request": request})


def pil_image_from_uploadfile(upload_file: UploadFile) -> Image.Image:
    """Read an uploaded file fully and decode it as an RGB PIL image.

    The underlying file object is always closed, even when reading fails.

    Raises:
        OSError: if the bytes cannot be decoded as an image (this includes
            ``PIL.UnidentifiedImageError``, which subclasses ``OSError``).
    """
    try:
        contents = upload_file.file.read()
    finally:
        # Close unconditionally so a failed read does not leak the spooled
        # temporary file backing the upload.
        upload_file.file.close()
    return Image.open(io.BytesIO(contents)).convert("RGB")


def numpy_from_pil(pil_img: Image.Image) -> np.ndarray:
    """Return the image's pixel data as a NumPy array."""
    return np.asarray(pil_img)


def pil_from_mask_array(mask: np.ndarray) -> Image.Image:
    """Convert a mask array into an 8-bit grayscale (mode ``"L"``) PIL image.

    Conversion rules:
      * boolean masks map ``False``/``True`` to 0/255;
      * floating masks whose maximum is <= 1.0 are scaled by 255, other
        floating masks are used as-is; NaNs become 0 and the result is
        clipped to [0, 255] before the uint8 cast;
      * integer masks are clipped to [0, 255];
      * ``(H, W, 3)`` arrays are reduced to luminance, ``(H, W, 1)`` arrays
        are squeezed to ``(H, W)``.

    Raises:
        ValueError: if the array cannot be reduced to two dimensions.
    """
    arr = np.asarray(mask)

    if arr.dtype == np.bool_:
        # A 0/1 image is practically invisible; expand to the full 8-bit range.
        arr = arr.astype("uint8") * 255
    elif np.issubdtype(arr.dtype, np.floating):
        # NaN -> uint8 is undefined behaviour in NumPy; treat NaN as "no mask".
        arr = np.nan_to_num(arr, nan=0.0)
        if arr.max() <= 1.0:
            arr = arr * 255.0
        # Clip in both branches: without this, negative values in a
        # normalized mask would wrap around when cast to uint8.
        arr = np.clip(arr, 0, 255).astype("uint8")
    else:
        arr = np.clip(arr, 0, 255).astype("uint8")

    if arr.ndim == 3 and arr.shape[2] == 3:
        # Weighted RGB -> luminance reduction.
        arr = (
            0.2989 * arr[..., 0] + 0.5870 * arr[..., 1] + 0.1140 * arr[..., 2]
        ).astype("uint8")
    elif arr.ndim == 3 and arr.shape[2] == 1:
        arr = arr[..., 0]

    if arr.ndim != 2:
        raise ValueError(
            f"Cannot convert mask of shape {arr.shape} to a grayscale image."
        )
    return Image.fromarray(arr, mode="L")


@app.post("/mask/", response_class=StreamingResponse)
async def generate_mask_endpoint(file: UploadFile = File(...)):
    """
    Accept an image file and return a PNG mask.
    Adds a response header 'X-Inference-Time-ms' with inference time in milliseconds.

    Raises:
        HTTPException: 415 if the upload's content type is missing or not
            ``image/*``; 400 if the payload cannot be decoded as an image;
            500 if mask generation fails or any unexpected error occurs.
    """
    try:
        # content_type is optional on uploads; treat a missing value as
        # unsupported instead of crashing on None.startswith(...).
        if not (file.content_type or "").startswith("image/"):
            raise HTTPException(status_code=415, detail="Unsupported file type.")

        # read + convert
        try:
            pil_img = pil_image_from_uploadfile(file)
        except (OSError, Image.DecompressionBombError) as e:
            # An undecodable upload is a client error, not a server fault.
            raise HTTPException(
                status_code=400, detail="Invalid or corrupted image file."
            ) from e
        input_np = numpy_from_pil(pil_img)

        # Call model & measure time
        start = time.perf_counter()
        mask = depth_texture_mask.generate_texture_depth_mask(input_np, mask_only=True)
        end = time.perf_counter()
        infer_ms = int((end - start) * 1000)

        if mask is None:
            raise HTTPException(status_code=500, detail="Mask generation failed.")

        mask_pil = pil_from_mask_array(mask)
        buf = io.BytesIO()
        mask_pil.save(buf, format="PNG")
        buf.seek(0)

        headers = {"X-Inference-Time-ms": str(infer_ms)}
        return StreamingResponse(buf, media_type="image/png", headers=headers)

    except HTTPException:
        # Already a well-formed HTTP error; pass it through untouched.
        raise
    except Exception as e:
        logger.error("Error in /mask/: %s", e)
        logger.debug(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/ui")
async def redirect_ui():
    """Redirect ``/ui`` to the index page."""
    return RedirectResponse("/")