# zeidImigine — Update app.py (commit f3fecc9, verified)
import torch
import numpy as np
import cv2
from PIL import Image
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import io
import traceback
import logging
# Module-level logging: one logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ──────────────────────────────────────────────
app = FastAPI(title="Metric3D vit_large β€” Metric Depth API")
# Inference device — this deployment is CPU-only.
DEVICE = "cpu"
# Model canvas (H, W) that preprocess() resizes/pads into.
# NOTE(review): the official Metric3D ViT recipe commonly uses (616, 1064);
# confirm the square (616, 616) canvas is intentional here.
INPUT_SIZE = (616, 616)
# Canonical focal length (px) assumed by the network; postprocess() rescales
# predictions by FOCAL_CANONICAL / focal_px to obtain metric depth.
FOCAL_CANONICAL = 1000.0
logger.info(f"Chargement metric3d_vit_large sur {DEVICE} ...")
# Load the pretrained Metric3D ViT-Large model from torch hub at import time
# (downloads weights on first run; blocks startup until ready).
model = torch.hub.load(
    "yvanyin/metric3d",
    "metric3d_vit_large",
    pretrain=True,
    trust_repo=True,
)
model = model.to(DEVICE).eval()
logger.info("ModΓ¨le prΓͺt.")
# ──────────────────────────────────────────────
def preprocess(img_np: np.ndarray):
    """Resize, pad and normalize an RGB image for the Metric3D network.

    Returns:
        tensor:   (1, 3, H, W) float32 tensor ready for model.inference().
        scale:    resize ratio applied to the original image.
        pad_info: (top, bottom, left, right) padding added around the image.
        H_orig, W_orig: original image size.
        focal_px: estimated focal length in pixels.
    """
    orig_h, orig_w = img_np.shape[:2]
    # Rough focal-length guess from image size (no EXIF intrinsics available).
    focal_px = 1.2 * float(max(orig_h, orig_w))

    # Fit the image inside the model canvas while keeping aspect ratio.
    ratio = min(INPUT_SIZE[0] / orig_h, INPUT_SIZE[1] / orig_w)
    new_h, new_w = int(orig_h * ratio), int(orig_w * ratio)
    resized = cv2.resize(img_np, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    # Centre the resized image in the canvas, padding with mid-gray (128).
    gap_h = INPUT_SIZE[0] - new_h
    gap_w = INPUT_SIZE[1] - new_w
    top, left = gap_h // 2, gap_w // 2
    bottom, right = gap_h - top, gap_w - left
    padded = np.pad(
        resized,
        ((top, bottom), (left, right), (0, 0)),
        mode="constant",
        constant_values=128,
    )

    # ImageNet mean/std expressed in the 0-255 range.
    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
    normalized = (padded.astype(np.float32) - mean) / std

    # HWC -> 1x3xHxW.
    tensor = torch.from_numpy(normalized).permute(2, 0, 1).unsqueeze(0)

    return tensor, ratio, (top, bottom, left, right), orig_h, orig_w, focal_px
def postprocess(pred_depth_raw, scale, pad_info, H_orig, W_orig, focal_px):
    """Undo padding/resizing and convert canonical depth to metric depth.

    `scale` is accepted for symmetry with preprocess() but is not used:
    resizing back to (W_orig, H_orig) already compensates for it.
    """
    top, bottom, left, right = pad_info
    depth = pred_depth_raw.squeeze().cpu().numpy()
    # Strip the padding that preprocess() added around the image.
    h, w = depth.shape
    depth = depth[top: h - bottom, left: w - right]
    # Back to the original resolution.
    depth = cv2.resize(depth, (W_orig, H_orig), interpolation=cv2.INTER_LINEAR)
    # De-canonicalize: the network predicts depth at FOCAL_CANONICAL; rescale
    # to the focal length estimated for this image.
    return depth * (FOCAL_CANONICAL / focal_px)
# ──────────────────────────────────────────────
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run metric depth estimation on an uploaded image.

    Returns JSON with the closest distance in the scene and the distance at
    the image centre (both in metres, per the model's metric calibration).

    Raises:
        HTTPException 400: non-image content type or unreadable image.
        HTTPException 500: preprocessing, inference or postprocessing failure.
    """
    # 1) Lecture image
    logger.info(f"Fichier reΓ§u : {file.filename} | content_type : {file.content_type}")
    # BUGFIX: content_type may be None when the client omits the header; the
    # previous bare `.startswith` then raised AttributeError (a 500) instead
    # of the intended 400.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail=f"Content-type invalide : {file.content_type}")
    contents = await file.read()
    logger.info(f"Taille fichier : {len(contents)} octets")
    try:
        pil_img = Image.open(io.BytesIO(contents)).convert("RGB")
        logger.info(f"Image PIL : {pil_img.size}")
    except Exception as e:
        logger.error(f"Erreur lecture image : {e}")
        raise HTTPException(status_code=400, detail=f"Impossible de lire l'image : {e}") from e
    img_np = np.array(pil_img)
    logger.info(f"img_np shape : {img_np.shape} dtype : {img_np.dtype}")
    # 2) PrΓ©traitement
    try:
        tensor, scale, pad_info, H_orig, W_orig, focal_px = preprocess(img_np)
        logger.info(f"Tensor shape : {tensor.shape} | focal_px : {focal_px:.1f} | scale : {scale:.4f}")
    except Exception as e:
        logger.error(f"Erreur prΓ©traitement :\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Erreur prΓ©traitement : {e}") from e
    # 3) InfΓ©rence
    try:
        logger.info("Lancement infΓ©rence ...")
        # no_grad: pure inference, skip autograd bookkeeping.
        with torch.no_grad():
            pred_depth, confidence, output_dict = model.inference({"input": tensor})
        logger.info(f"pred_depth shape : {pred_depth.shape} | min : {pred_depth.min():.3f} | max : {pred_depth.max():.3f}")
    except Exception as e:
        logger.error(f"Erreur infΓ©rence :\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Erreur infΓ©rence : {e}") from e
    # 4) Post-traitement
    try:
        depth_map = postprocess(pred_depth, scale, pad_info, H_orig, W_orig, focal_px)
        logger.info(f"depth_map shape : {depth_map.shape} | min : {depth_map.min():.3f} | max : {depth_map.max():.3f}")
    except Exception as e:
        logger.error(f"Erreur post-traitement :\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Erreur post-traitement : {e}") from e
    # 5) RΓ©sultats
    closest_distance = float(np.min(depth_map))
    # Distance at the geometric centre of the original image.
    cy, cx = H_orig // 2, W_orig // 2
    center_distance = float(depth_map[cy, cx])
    logger.info(f"closest_distance : {closest_distance:.3f} m | center_distance : {center_distance:.3f} m")
    return JSONResponse(content={
        "closest_distance": closest_distance,
        "center_distance": center_distance,
    })
@app.get("/")
def root():
    """Health-check endpoint: service status, model name and device."""
    payload = {"status": "ok", "model": "metric3d_vit_large"}
    payload["device"] = DEVICE
    return payload