# app.py
import io
import logging
import traceback
import time

from fastapi import FastAPI, UploadFile, File, HTTPException, Request
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from starlette.responses import RedirectResponse
from PIL import Image
import numpy as np

import depth_texture_mask  # make sure depth_texture_mask.py is at repo root

# Reuse uvicorn's error logger so app logs share the server's handlers/formatting.
logger = logging.getLogger("uvicorn.error")
app = FastAPI(title="Depth & Structural Masking API with UI")

# Mount static folder for CSS/JS/images (served under /static/...)
app.mount("/static", StaticFiles(directory="static"), name="static")
# Jinja2 environment for the HTML routes; expects templates/index.html to exist.
templates = Jinja2Templates(directory="templates")

@app.on_event("startup")
async def startup_event():
    """Eagerly load the MiDaS depth model once, at server start.

    Failures are logged (with traceback) rather than re-raised so the app
    still comes up; later requests will surface the underlying error.
    """
    try:
        logger.info("Initializing MiDaS model...")
        depth_texture_mask.init_midas()  # heavy one-time model load
        logger.info("MiDaS initialized.")
    except Exception as exc:
        logger.exception("Error initializing MiDaS: %s", exc)

@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Serve the single-page UI rendered from templates/index.html."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)

def pil_image_from_uploadfile(upload_file: UploadFile) -> Image.Image:
    """Read an uploaded file fully and decode it as an RGB PIL image.

    The underlying spooled file is always closed via ``finally`` — the
    original closed it only on the success path, leaking the handle
    whenever ``Image.open``/``convert`` raised on a corrupt upload.

    Args:
        upload_file: FastAPI upload whose ``.file`` is read to EOF.

    Returns:
        The decoded image, converted to RGB.

    Raises:
        PIL.UnidentifiedImageError: if the bytes are not a decodable image.
    """
    try:
        contents = upload_file.file.read()
        return Image.open(io.BytesIO(contents)).convert("RGB")
    finally:
        upload_file.file.close()

def numpy_from_pil(pil_img: Image.Image) -> np.ndarray:
    """Expose the PIL image's pixel data as a NumPy array (no copy if avoidable)."""
    arr = np.asarray(pil_img)
    return arr

def pil_from_mask_array(mask: np.ndarray) -> Image.Image:
    """Convert a mask array to an 8-bit grayscale ('L') PIL image.

    Accepted inputs:
      * float mask with values in [0, 1]   -> rescaled to 0-255
      * float mask already in 0-255 range  -> clipped and cast
      * integer mask                       -> clipped to 0-255 and cast
      * 3-channel (H, W, 3) mask           -> reduced to luma (BT.601 weights)

    Note: the original called ``mask.copy()`` first, but ``astype`` already
    returns a new array, so the extra copy was redundant and is dropped.
    An ``arr.size`` guard keeps ``arr.max()`` from raising on empty arrays.
    """
    arr = mask
    if np.issubdtype(arr.dtype, np.floating) and arr.size and arr.max() <= 1.0:
        # Normalized float mask: stretch to the full 8-bit range.
        # (truncation toward zero matches the original astype behavior)
        arr = (arr * 255.0).astype(np.uint8)
    else:
        arr = np.clip(arr, 0, 255).astype(np.uint8)
    if arr.ndim == 3 and arr.shape[2] == 3:
        # Collapse RGB to grayscale with ITU-R BT.601 luma weights.
        arr = (0.2989 * arr[..., 0] + 0.5870 * arr[..., 1] + 0.1140 * arr[..., 2]).astype(np.uint8)
    return Image.fromarray(arr, mode="L")

@app.post("/mask/", response_class=StreamingResponse)
async def generate_mask_endpoint(file: UploadFile = File(...)):
    """Accept an image upload and return its generated mask as a PNG.

    The response carries an 'X-Inference-Time-ms' header with the model
    inference time in milliseconds.

    Raises:
        HTTPException: 415 for non-image uploads, 500 when mask generation
            fails or any unexpected error occurs.
    """
    try:
        # file.content_type may be None (no Content-Type sent), which would
        # have crashed the original .startswith() call with AttributeError
        # instead of returning the intended 415.
        if not file.content_type or not file.content_type.startswith("image/"):
            raise HTTPException(status_code=415, detail="Unsupported file type.")

        # read + convert
        pil_img = pil_image_from_uploadfile(file)
        input_np = numpy_from_pil(pil_img)

        # Call model & measure wall-clock inference time.
        start = time.perf_counter()
        mask = depth_texture_mask.generate_texture_depth_mask(input_np, mask_only=True)
        infer_ms = int((time.perf_counter() - start) * 1000)

        if mask is None:
            raise HTTPException(status_code=500, detail="Mask generation failed.")

        # Encode the mask as PNG into an in-memory buffer.
        mask_pil = pil_from_mask_array(mask)
        buf = io.BytesIO()
        mask_pil.save(buf, format="PNG")
        buf.seek(0)

        headers = {"X-Inference-Time-ms": str(infer_ms)}
        return StreamingResponse(buf, media_type="image/png", headers=headers)
    except HTTPException:
        # Deliberate HTTP errors pass through untouched.
        raise
    except Exception as e:
        # logger.exception logs the message AND traceback at ERROR level in
        # one call (the original split it across error() + debug(traceback)).
        logger.exception("Error in /mask/: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.get("/ui")
async def redirect_ui():
    """Convenience alias: send /ui visitors to the main page at '/'."""
    target = "/"
    return RedirectResponse(target)