File size: 3,616 Bytes
ceeef4a
 
 
 
1376b12
 
 
 
 
ceeef4a
1376b12
 
ceeef4a
1376b12
 
ceeef4a
1376b12
 
ceeef4a
1376b12
 
 
 
 
ceeef4a
 
1376b12
ceeef4a
1376b12
 
 
 
 
 
 
 
 
 
ceeef4a
1376b12
 
ceeef4a
1376b12
ceeef4a
 
1376b12
ceeef4a
1376b12
 
 
 
 
ceeef4a
1376b12
 
 
 
 
 
 
 
 
ceeef4a
1376b12
 
 
 
 
 
 
 
 
 
 
ceeef4a
1376b12
 
 
 
 
ceeef4a
1376b12
ceeef4a
1376b12
 
 
 
ceeef4a
 
1376b12
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import numpy as np
from PIL import Image
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import io
import traceback
import logging
import torch
from transformers import DepthProImageProcessorFast, DepthProForDepthEstimation

# Module-level logging: one logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ──────────────────────────────────────────────
# FastAPI application exposing Apple's Depth Pro metric-depth model.
# NOTE(review): several string literals look mojibake (UTF-8 read as
# Latin-1, e.g. "β€”", "ModΓ¨le") — verify the file's encoding upstream.
app = FastAPI(title="Depth Pro (Apple) β€” Metric Depth API")

# Inference device and Hugging Face Hub model id.
DEVICE = "cpu"
MODEL_ID = "apple/DepthPro-hf"

# Load processor and model ONCE at import time so every request reuses
# the same weights; .eval() disables dropout/batch-norm training behavior.
logger.info(f"Chargement {MODEL_ID} sur {DEVICE} ...")
processor = DepthProImageProcessorFast.from_pretrained(MODEL_ID)
model = DepthProForDepthEstimation.from_pretrained(MODEL_ID)
model = model.to(DEVICE).eval()
logger.info("ModΓ¨le Depth Pro prΓͺt.")


# ──────────────────────────────────────────────
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run Depth Pro on an uploaded image and return metric distances.

    Pipeline: read upload -> decode to RGB -> preprocess -> inference ->
    post-process to a depth map at the original resolution (meters,
    per the Depth Pro model card).

    Returns:
        JSON with ``closest_distance`` (minimum over the depth map) and
        ``center_distance`` (depth at the image center), both floats.

    Raises:
        HTTPException 400: missing/invalid content type or unreadable image.
        HTTPException 500: preprocessing, inference or post-processing failure.
    """
    # 1) Read the upload
    logger.info(f"Fichier reΓ§u : {file.filename} | type : {file.content_type}")
    # UploadFile.content_type is Optional: when the client omits the header
    # it is None and .startswith() would raise AttributeError (an HTTP 500).
    # Guard explicitly so the client gets a proper 400 instead.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail=f"Content-type invalide : {file.content_type}")

    contents = await file.read()
    logger.info(f"Taille : {len(contents)} octets")

    # 2) Decode to 3-channel RGB (an empty/corrupt payload fails here -> 400)
    try:
        pil_img = Image.open(io.BytesIO(contents)).convert("RGB")
        logger.info(f"Image : {pil_img.size[0]}x{pil_img.size[1]}")
    except Exception as e:
        # logger.exception logs the full traceback; no traceback.format_exc() needed.
        logger.exception("Lecture de l'image impossible")
        raise HTTPException(status_code=400, detail=f"Impossible de lire l'image : {e}") from e

    # 3) Preprocess into model tensors on the target device
    try:
        inputs = processor(images=pil_img, return_tensors="pt").to(DEVICE)
        logger.info(f"Inputs prΓͺts : {list(inputs.keys())}")
    except Exception as e:
        logger.exception("Erreur de prΓ©traitement")
        raise HTTPException(status_code=500, detail=f"Erreur prΓ©traitement : {e}") from e

    # 4) Inference (no_grad: pure inference, skip autograd bookkeeping)
    try:
        logger.info("InfΓ©rence en cours ...")
        with torch.no_grad():
            outputs = model(**inputs)
        logger.info("InfΓ©rence terminΓ©e.")
    except Exception as e:
        logger.exception("Erreur d'infΓ©rence")
        raise HTTPException(status_code=500, detail=f"Erreur infΓ©rence : {e}") from e

    # 5) Post-process: depth map resized back to the original (H, W)
    try:
        post = processor.post_process_depth_estimation(
            outputs,
            target_sizes=[(pil_img.height, pil_img.width)],
        )
        depth_map = post[0]["predicted_depth"].squeeze().cpu().numpy()  # [H, W] float32, mètres
        logger.info(f"depth_map shape={depth_map.shape} min={depth_map.min():.3f} max={depth_map.max():.3f}")
    except Exception as e:
        logger.exception("Erreur de post-traitement")
        raise HTTPException(status_code=500, detail=f"Erreur post-traitement : {e}") from e

    # 6) Summary metrics: nearest point anywhere, and depth at the center pixel
    H, W = depth_map.shape
    closest_distance = float(np.min(depth_map))
    cy, cx = H // 2, W // 2
    center_distance = float(depth_map[cy, cx])

    logger.info(f"closest={closest_distance:.3f}m | center={center_distance:.3f}m")

    return JSONResponse(content={
        "closest_distance": closest_distance,
        "center_distance":  center_distance,
    })


# ──────────────────────────────────────────────
@app.get("/")
def root():
    """Liveness probe: report service status plus the model id and device in use."""
    payload = dict(status="ok", model=MODEL_ID, device=DEVICE)
    return payload