zeidImigine commited on
Commit
1376b12
·
verified ·
1 Parent(s): ea22bc2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -80
app.py CHANGED
@@ -1,103 +1,93 @@
1
- import io
2
- import base64
3
  import numpy as np
4
  from PIL import Image
5
- import torch
6
- import uvicorn
7
  from fastapi import FastAPI, File, UploadFile, HTTPException
8
  from fastapi.responses import JSONResponse
9
- from contextlib import asynccontextmanager
10
- from depth_pro import create_model_and_transforms
11
- from depth_pro.depth_pro import DepthProConfig
12
-
13
- CHECKPOINT_PATH = "/app/ml-depth-pro/checkpoints/depth_pro.pt"
14
-
15
- model = None
16
- transform = None
17
-
18
-
19
- @asynccontextmanager
20
- async def lifespan(app: FastAPI):
21
- global model, transform
22
- print("Chargement du modèle DepthPro...")
23
- config = DepthProConfig(
24
- patch_encoder_preset="dinov2l16_384",
25
- image_encoder_preset="dinov2l16_384",
26
- checkpoint_uri=CHECKPOINT_PATH,
27
- decoder_features=256,
28
- use_fov_head=True,
29
- fov_encoder_preset="dinov2l16_384",
30
- )
31
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
- precision = torch.half if torch.cuda.is_available() else torch.float32
33
- model, transform = create_model_and_transforms(
34
- config=config,
35
- device=device,
36
- precision=precision,
37
- )
38
- model.eval()
39
- print(f"Modèle chargé sur {device}.")
40
- yield
41
-
42
-
43
- app = FastAPI(
44
- title="DepthPro API",
45
- description="API de prédiction de profondeur monoculaire avec Apple DepthPro",
46
- version="1.0.0",
47
- lifespan=lifespan,
48
- )
49
 
 
 
50
 
51
- @app.get("/")
52
- def root():
53
- return {"message": "DepthPro API is running", "docs": "/docs"}
54
 
 
 
55
 
56
- @app.get("/health")
57
- def health():
58
- return {
59
- "status": "ok",
60
- "device": "cuda" if torch.cuda.is_available() else "cpu",
61
- "model_loaded": model is not None,
62
- }
63
 
64
 
 
65
  @app.post("/predict")
66
- async def predict_depth(file: UploadFile = File(...)):
 
 
 
 
 
 
 
 
 
67
  try:
68
- contents = await file.read()
69
- image = Image.open(io.BytesIO(contents)).convert("RGB")
70
  except Exception as e:
 
71
  raise HTTPException(status_code=400, detail=f"Impossible de lire l'image : {e}")
72
 
 
73
  try:
74
- image_tensor = transform(image)
75
- with torch.no_grad():
76
- prediction = model.infer(image_tensor)
 
 
77
 
78
- depth = prediction["depth"].squeeze().cpu().numpy()
79
- focal_length_px = float(prediction["focallength_px"].cpu())
 
 
 
 
 
 
 
80
 
81
- depth_min = float(depth.min())
82
- depth_max = float(depth.max())
83
- depth_normalized = ((depth - depth_min) / (depth_max - depth_min + 1e-8) * 65535).astype(np.uint16)
 
 
 
 
 
 
 
 
84
 
85
- depth_img = Image.fromarray(depth_normalized, mode="I;16")
86
- buf = io.BytesIO()
87
- depth_img.save(buf, format="PNG")
88
- depth_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
 
89
 
90
- return JSONResponse({
91
- "depth_map_base64": depth_b64,
92
- "focal_length_px": focal_length_px,
93
- "depth_min_meters": depth_min,
94
- "depth_max_meters": depth_max,
95
- "image_size": {"width": image.width, "height": image.height},
96
- })
97
 
98
- except Exception as e:
99
- raise HTTPException(status_code=500, detail=f"Erreur lors de l'inférence : {e}")
 
 
100
 
101
 
102
- if __name__ == "__main__":
103
- uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
 
 
 
 
 
1
  import numpy as np
2
  from PIL import Image
 
 
3
  from fastapi import FastAPI, File, UploadFile, HTTPException
4
  from fastapi.responses import JSONResponse
5
+ import io
6
+ import traceback
7
+ import logging
8
+ import torch
9
+ from transformers import DepthProImageProcessorFast, DepthProForDepthEstimation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ──────────────────────────────────────────────
app = FastAPI(title="Depth Pro (Apple) — Metric Depth API")

# Pick the best available device instead of hard-coding "cpu"; on a
# CPU-only host (e.g. a free HF Space) this is identical to before,
# and on a GPU host the model + inputs both follow this global.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "apple/DepthPro-hf"

# Load processor + weights once at import time so the first request does
# not pay the (multi-second) download/initialization cost.
logger.info(f"Chargement {MODEL_ID} sur {DEVICE} ...")
processor = DepthProImageProcessorFast.from_pretrained(MODEL_ID)
model = DepthProForDepthEstimation.from_pretrained(MODEL_ID)
model = model.to(DEVICE).eval()  # eval(): inference only, disables dropout etc.
logger.info("Modèle Depth Pro prêt.")
 
 
25
 
26
 
27
+ # ──────────────────────────────────────────────
28
  @app.post("/predict")
29
+ async def predict(file: UploadFile = File(...)):
30
+
31
+ # 1) Lecture image
32
+ logger.info(f"Fichier reçu : {file.filename} | type : {file.content_type}")
33
+ if not file.content_type.startswith("image/"):
34
+ raise HTTPException(status_code=400, detail=f"Content-type invalide : {file.content_type}")
35
+
36
+ contents = await file.read()
37
+ logger.info(f"Taille : {len(contents)} octets")
38
+
39
  try:
40
+ pil_img = Image.open(io.BytesIO(contents)).convert("RGB")
41
+ logger.info(f"Image : {pil_img.size[0]}x{pil_img.size[1]}")
42
  except Exception as e:
43
+ logger.error(traceback.format_exc())
44
  raise HTTPException(status_code=400, detail=f"Impossible de lire l'image : {e}")
45
 
46
+ # 2) Prétraitement
47
  try:
48
+ inputs = processor(images=pil_img, return_tensors="pt").to(DEVICE)
49
+ logger.info(f"Inputs prêts : {list(inputs.keys())}")
50
+ except Exception as e:
51
+ logger.error(traceback.format_exc())
52
+ raise HTTPException(status_code=500, detail=f"Erreur prétraitement : {e}")
53
 
54
+ # 3) Inférence
55
+ try:
56
+ logger.info("Inférence en cours ...")
57
+ with torch.no_grad():
58
+ outputs = model(**inputs)
59
+ logger.info("Inférence terminée.")
60
+ except Exception as e:
61
+ logger.error(traceback.format_exc())
62
+ raise HTTPException(status_code=500, detail=f"Erreur inférence : {e}")
63
 
64
+ # 4) Post-traitement → profondeur en mètres à la résolution originale
65
+ try:
66
+ post = processor.post_process_depth_estimation(
67
+ outputs,
68
+ target_sizes=[(pil_img.height, pil_img.width)],
69
+ )
70
+ depth_map = post[0]["predicted_depth"].squeeze().cpu().numpy() # [H, W] float32, mètres
71
+ logger.info(f"depth_map shape={depth_map.shape} min={depth_map.min():.3f} max={depth_map.max():.3f}")
72
+ except Exception as e:
73
+ logger.error(traceback.format_exc())
74
+ raise HTTPException(status_code=500, detail=f"Erreur post-traitement : {e}")
75
 
76
+ # 5) Résultats
77
+ H, W = depth_map.shape
78
+ closest_distance = float(np.min(depth_map))
79
+ cy, cx = H // 2, W // 2
80
+ center_distance = float(depth_map[cy, cx])
81
 
82
+ logger.info(f"closest={closest_distance:.3f}m | center={center_distance:.3f}m")
 
 
 
 
 
 
83
 
84
+ return JSONResponse(content={
85
+ "closest_distance": closest_distance,
86
+ "center_distance": center_distance,
87
+ })
88
 
89
 
90
+ # ──────────────────────────────────────────────
91
+ @app.get("/")
92
+ def root():
93
+ return {"status": "ok", "model": MODEL_ID, "device": DEVICE}