Spaces:
Running
Running
| import numpy as np | |
| from PIL import Image | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import JSONResponse | |
| import io | |
| import traceback | |
| import logging | |
| import torch | |
| from transformers import DepthProImageProcessorFast, DepthProForDepthEstimation | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ──────────────────────────────────────────────
# Application and model setup (runs once, at import time)
# ──────────────────────────────────────────────
app = FastAPI(title="Depth Pro (Apple) — Metric Depth API")

# Device the model is moved to below; request tensors are sent here as well.
DEVICE = "cpu"
# Checkpoint name passed to `from_pretrained` for both processor and model.
MODEL_ID = "apple/DepthPro-hf"

# Load processor and model once so every request reuses the same instances.
logger.info("Chargement %s sur %s ...", MODEL_ID, DEVICE)
processor = DepthProImageProcessorFast.from_pretrained(MODEL_ID)
model = DepthProForDepthEstimation.from_pretrained(MODEL_ID)
# eval() puts the network in inference mode (e.g. disables dropout).
model = model.to(DEVICE).eval()
logger.info("Modèle Depth Pro prêt.")
# ──────────────────────────────────────────────
# NOTE(review): no route decorator is visible on this handler, so `app` never
# registers it; presumably an `@app.post(...)` line was lost — confirm.
async def predict(file: UploadFile = File(...)):
    """Estimate metric depth for an uploaded image and summarise it.

    Runs Depth Pro on the upload, post-processes the prediction back to the
    image's original resolution, and reports two distances from it.

    Args:
        file: Multipart upload. Its declared content type must start with
            ``image/`` (compared case-insensitively).

    Returns:
        JSONResponse with ``closest_distance`` (minimum of the depth map) and
        ``center_distance`` (depth at the central pixel), both floats in
        metres.

    Raises:
        HTTPException: 400 if the content type is missing or not an image, or
            if the bytes cannot be decoded as an image; 500 if
            pre-processing, inference or post-processing fails.
    """
    # 1) Read and decode the upload.
    logger.info("Fichier reçu : %s | type : %s", file.filename, file.content_type)
    # content_type is None when the client sends no Content-Type for the part;
    # MIME types are case-insensitive, so normalise before checking.
    if not (file.content_type or "").lower().startswith("image/"):
        raise HTTPException(
            status_code=400,
            detail=f"Content-type invalide : {file.content_type}",
        )
    contents = await file.read()
    logger.info("Taille : %d octets", len(contents))
    try:
        pil_img = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:
        logger.exception("Impossible de lire l'image")
        raise HTTPException(
            status_code=400, detail=f"Impossible de lire l'image : {e}"
        ) from e
    logger.info("Image : %dx%d", pil_img.width, pil_img.height)

    # 2) Pre-processing: turn the PIL image into PyTorch input tensors on DEVICE.
    try:
        inputs = processor(images=pil_img, return_tensors="pt").to(DEVICE)
    except Exception as e:
        logger.exception("Erreur prétraitement")
        raise HTTPException(
            status_code=500, detail=f"Erreur prétraitement : {e}"
        ) from e
    logger.info("Inputs prêts : %s", list(inputs.keys()))

    # 3) Inference, with autograd disabled (forward pass only).
    # NOTE(review): this CPU-bound call runs directly inside an `async def`,
    # so it blocks the event loop (and all other requests) while it runs.
    # Offloading it (e.g. to a worker thread) would also allow concurrent
    # forward passes — check memory headroom before changing.
    logger.info("Inférence en cours ...")
    try:
        with torch.no_grad():
            outputs = model(**inputs)
    except Exception as e:
        logger.exception("Erreur inférence")
        raise HTTPException(status_code=500, detail=f"Erreur inférence : {e}") from e
    logger.info("Inférence terminée.")

    # 4) Post-processing -> depth map in metres at the original resolution.
    height, width = pil_img.height, pil_img.width
    try:
        post = processor.post_process_depth_estimation(
            outputs,
            target_sizes=[(height, width)],
        )
        # reshape rather than squeeze: squeeze would also drop a size-1 height
        # or width, giving a 1-D array that breaks the 2-D indexing below.
        depth_map = post[0]["predicted_depth"].cpu().numpy().reshape(height, width)
    except Exception as e:
        logger.exception("Erreur post-traitement")
        raise HTTPException(
            status_code=500, detail=f"Erreur post-traitement : {e}"
        ) from e
    logger.info(
        "depth_map shape=%s min=%.3f max=%.3f",
        depth_map.shape,
        depth_map.min(),
        depth_map.max(),
    )

    # 5) Summary: nearest point anywhere in the frame, and depth at the centre.
    closest_distance = float(np.min(depth_map))
    center_distance = float(depth_map[height // 2, width // 2])
    logger.info("closest=%.3fm | center=%.3fm", closest_distance, center_distance)
    return JSONResponse(content={
        "closest_distance": closest_distance,
        "center_distance": center_distance,
    })
# ──────────────────────────────────────────────
def root():
    """Health-check payload: liveness flag plus the loaded checkpoint and device."""
    health = {
        "status": "ok",
        "model": MODEL_ID,
        "device": DEVICE,
    }
    return health