# DepthPro-Api / app.py
# Last commit by zeidImigine: 1376b12 "Update app.py" (verified)
import asyncio
import io
import logging
import traceback

import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import DepthProForDepthEstimation, DepthProImageProcessorFast
# Configure the root logger at INFO level (basicConfig only installs its
# default handler if the root logger has none yet).
logging.basicConfig(level="INFO")
# Module-level logger shared by the endpoints below.
logger = logging.getLogger(name=__name__)
# ──────────────────────────────────────────────
# FastAPI application; the title appears in the generated OpenAPI docs.
app = FastAPI(title="Depth Pro (Apple) — Metric Depth API")

# Inference device and Hugging Face checkpoint identifier.
DEVICE = "cpu"
MODEL_ID = "apple/DepthPro-hf"

# Load the processor and model once, at import time, so every request
# reuses the same instances instead of reloading the checkpoint.
logger.info(f"Chargement {MODEL_ID} sur {DEVICE} ...")
processor = DepthProImageProcessorFast.from_pretrained(MODEL_ID)
model = DepthProForDepthEstimation.from_pretrained(MODEL_ID)
# Move to the target device and switch to evaluation mode for inference.
model = model.to(DEVICE).eval()
logger.info("Modèle Depth Pro prêt.")
# ──────────────────────────────────────────────
def _decode_image(contents: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 if PIL cannot open or convert the bytes.
    """
    try:
        pil_img = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=400, detail=f"Impossible de lire l'image : {e}") from e
    logger.info(f"Image : {pil_img.size[0]}x{pil_img.size[1]}")
    return pil_img


def _estimate_depth(pil_img: Image.Image) -> np.ndarray:
    """Run Depth Pro on *pil_img* and return a 2-D (height, width) depth map.

    Blocking and CPU-bound (preprocessing, forward pass, post-processing), so
    ``predict`` runs it in a worker thread. The processor resizes the map to
    the original image resolution; values are the processor's
    ``predicted_depth`` (metric depth in metres).

    Raises:
        HTTPException: 500 if preprocessing, inference or post-processing fails.
    """
    # Preprocessing: turn the PIL image into model input tensors.
    try:
        inputs = processor(images=pil_img, return_tensors="pt").to(DEVICE)
        logger.info(f"Inputs prêts : {list(inputs.keys())}")
    except Exception as e:
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Erreur prétraitement : {e}") from e

    # Inference; gradients are never needed here.
    try:
        logger.info("Inférence en cours ...")
        with torch.no_grad():
            outputs = model(**inputs)
        logger.info("Inférence terminée.")
    except Exception as e:
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Erreur inférence : {e}") from e

    # Post-processing: depth in metres at the original resolution.
    height, width = pil_img.height, pil_img.width
    try:
        post = processor.post_process_depth_estimation(
            outputs,
            target_sizes=[(height, width)],
        )
        # reshape rather than squeeze: squeeze would collapse a 1-pixel-tall
        # or 1-pixel-wide image to 1-D and break the 2-D indexing later.
        depth_map = post[0]["predicted_depth"].cpu().numpy().reshape(height, width)
        logger.info(
            f"depth_map shape={depth_map.shape} min={depth_map.min():.3f} max={depth_map.max():.3f}"
        )
    except Exception as e:
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Erreur post-traitement : {e}") from e
    return depth_map


def _summarize_depth(depth_map: np.ndarray) -> dict:
    """Return the minimum depth and the depth at the central pixel.

    *depth_map* must be 2-D (height, width); the centre is (H // 2, W // 2).
    """
    height, width = depth_map.shape
    cy, cx = height // 2, width // 2
    return {
        "closest_distance": float(np.min(depth_map)),
        "center_distance": float(depth_map[cy, cx]),
    }


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Estimate metric depth for an uploaded image and summarize it.

    Returns:
        JSON with ``closest_distance`` (minimum over the whole depth map) and
        ``center_distance`` (depth at the central pixel), both in metres.

    Raises:
        HTTPException: 400 for a missing/non-image content type or an
            undecodable image; 500 if preprocessing, inference or
            post-processing fails.
    """
    # 1) Read and validate the upload.
    logger.info(f"Fichier reçu : {file.filename} | type : {file.content_type}")
    # content_type is None when the multipart part has no Content-Type header;
    # treat that as invalid (400) instead of crashing on None.startswith.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail=f"Content-type invalide : {file.content_type}")
    contents = await file.read()
    logger.info(f"Taille : {len(contents)} octets")
    pil_img = _decode_image(contents)

    # 2-4) Model work is blocking and CPU-heavy; run it in a worker thread so
    # the event loop (and other endpoints) stay responsive meanwhile.
    # NOTE(review): concurrent requests now run inference in parallel threads;
    # if memory is tight, consider serializing with a lock.
    depth_map = await asyncio.to_thread(_estimate_depth, pil_img)

    # 5) Summarize the depth map.
    result = _summarize_depth(depth_map)
    logger.info(
        f"closest={result['closest_distance']:.3f}m | center={result['center_distance']:.3f}m"
    )
    return JSONResponse(content=result)
# ──────────────────────────────────────────────
@app.get("/")
def root():
    """Health/info endpoint: report status, checkpoint id and inference device."""
    payload = {"status": "ok"}
    payload["model"] = MODEL_ID
    payload["device"] = DEVICE
    return payload