# Metric3D (vit_large) metric-depth API — FastAPI app (Hugging Face Space).
import io
import logging
import traceback

import cv2
import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ──────────────────────────────────────────────
app = FastAPI(title="Metric3D vit_large – Metric Depth API")

# Inference device; tensors created by preprocess() stay on CPU to match.
DEVICE = "cpu"
# Fixed (H, W) network input resolution for the ViT-Large Metric3D checkpoint.
INPUT_SIZE = (616, 616)
# Focal length (px) of Metric3D's canonical camera, used to de-canonicalize depth.
FOCAL_CANONICAL = 1000.0

# NOTE: loading at import time blocks app startup until torch.hub has
# downloaded and initialized the checkpoint (network I/O on first run).
logger.info(f"Chargement metric3d_vit_large sur {DEVICE} ...")
model = torch.hub.load(
    "yvanyin/metric3d",
    "metric3d_vit_large",
    pretrain=True,
    trust_repo=True,
)
model = model.to(DEVICE).eval()
logger.info("Modèle prêt.")
# ──────────────────────────────────────────────
def preprocess(img_np: np.ndarray):
    """Resize, pad and normalize an RGB image for the Metric3D network.

    Args:
        img_np: HWC uint8 RGB image array.

    Returns:
        Tuple of (tensor, scale, pad_info, H_orig, W_orig, focal_px) where
        ``tensor`` is a 1xCxHxW float tensor at INPUT_SIZE, ``scale`` is the
        resize factor applied, and ``pad_info`` is the (top, bottom, left,
        right) padding added after resizing.
    """
    H_orig, W_orig = img_np.shape[:2]
    # Heuristic focal length when no camera intrinsics are supplied.
    focal_px = 1.2 * float(max(H_orig, W_orig))

    # Uniform scale so the image fits inside the network input, keeping aspect.
    scale = min(INPUT_SIZE[0] / H_orig, INPUT_SIZE[1] / W_orig)
    H_new, W_new = int(H_orig * scale), int(W_orig * scale)
    img_resized = cv2.resize(img_np, (W_new, H_new), interpolation=cv2.INTER_LINEAR)

    # Center-pad to exactly INPUT_SIZE with mid-gray (128).
    pad_h = INPUT_SIZE[0] - H_new
    pad_w = INPUT_SIZE[1] - W_new
    top, left = pad_h // 2, pad_w // 2
    bottom, right = pad_h - top, pad_w - left
    img_padded = np.pad(
        img_resized,
        ((top, bottom), (left, right), (0, 0)),
        mode="constant",
        constant_values=128,
    )

    # ImageNet-style per-channel normalization used by the Metric3D checkpoints.
    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
    img_norm = (img_padded.astype(np.float32) - mean) / std

    # HWC -> 1xCxHxW for the model.
    tensor = torch.tensor(img_norm).permute(2, 0, 1).unsqueeze(0)
    return tensor, scale, (top, bottom, left, right), H_orig, W_orig, focal_px
def postprocess(pred_depth_raw, scale, pad_info, H_orig, W_orig, focal_px):
    """Undo padding/resizing and convert canonical depth to metric depth.

    Args:
        pred_depth_raw: model output depth tensor at network resolution.
        scale: resize factor applied by preprocess() before padding.
        pad_info: (top, bottom, left, right) padding added by preprocess().
        H_orig, W_orig: original image size in pixels.
        focal_px: estimated focal length of the original image, in pixels.

    Returns:
        np.ndarray of shape (H_orig, W_orig) with metric depth in meters.
    """
    top, bottom, left, right = pad_info
    depth_np = pred_depth_raw.squeeze().cpu().numpy()

    # Crop away the padding added in preprocess().
    H_crop = depth_np.shape[0] - top - bottom
    W_crop = depth_np.shape[1] - left - right
    depth_crop = depth_np[top: top + H_crop, left: left + W_crop]

    # Back to the original resolution.
    depth_resized = cv2.resize(depth_crop, (W_orig, H_orig), interpolation=cv2.INTER_LINEAR)

    # De-canonicalize: Metric3D predicts depth for a canonical camera of focal
    # FOCAL_CANONICAL, so metric depth = canonical depth * f_net / FOCAL_CANONICAL,
    # where f_net = focal_px * scale is the focal length at network resolution.
    # (The previous code inverted this ratio and ignored `scale` entirely —
    # `scale` was a parameter that was never used.)
    depth_metric = depth_resized * (focal_px * scale / FOCAL_CANONICAL)
    return depth_metric
# ──────────────────────────────────────────────
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Predict a metric depth map for an uploaded image.

    Returns JSON with:
        closest_distance: smallest depth value in the map (meters).
        center_distance: depth at the image center pixel (meters).

    Raises:
        HTTPException 400 for invalid uploads, 500 for pipeline failures.
    """
    # 1) Read and validate the upload.
    logger.info(f"Fichier reçu : {file.filename} | content_type : {file.content_type}")
    # content_type may be None depending on the client; reject that as well.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail=f"Content-type invalide : {file.content_type}")
    contents = await file.read()
    logger.info(f"Taille fichier : {len(contents)} octets")
    try:
        pil_img = Image.open(io.BytesIO(contents)).convert("RGB")
        logger.info(f"Image PIL : {pil_img.size}")
    except Exception as e:
        logger.error(f"Erreur lecture image : {e}")
        raise HTTPException(status_code=400, detail=f"Impossible de lire l'image : {e}") from e
    img_np = np.array(pil_img)
    logger.info(f"img_np shape : {img_np.shape} dtype : {img_np.dtype}")

    # 2) Preprocessing (resize/pad/normalize).
    try:
        tensor, scale, pad_info, H_orig, W_orig, focal_px = preprocess(img_np)
        logger.info(f"Tensor shape : {tensor.shape} | focal_px : {focal_px:.1f} | scale : {scale:.4f}")
    except Exception as e:
        logger.error(f"Erreur prétraitement :\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Erreur prétraitement : {e}") from e

    # 3) Inference.
    try:
        logger.info("Lancement inférence ...")
        with torch.no_grad():
            pred_depth, confidence, output_dict = model.inference({"input": tensor})
        logger.info(f"pred_depth shape : {pred_depth.shape} | min : {pred_depth.min():.3f} | max : {pred_depth.max():.3f}")
    except Exception as e:
        logger.error(f"Erreur inférence :\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Erreur inférence : {e}") from e

    # 4) Postprocessing (unpad, resize back, de-canonicalize).
    try:
        depth_map = postprocess(pred_depth, scale, pad_info, H_orig, W_orig, focal_px)
        logger.info(f"depth_map shape : {depth_map.shape} | min : {depth_map.min():.3f} | max : {depth_map.max():.3f}")
    except Exception as e:
        logger.error(f"Erreur post-traitement :\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Erreur post-traitement : {e}") from e

    # 5) Summary statistics returned to the client.
    closest_distance = float(np.min(depth_map))
    cy, cx = H_orig // 2, W_orig // 2
    center_distance = float(depth_map[cy, cx])
    logger.info(f"closest_distance : {closest_distance:.3f} m | center_distance : {center_distance:.3f} m")
    return JSONResponse(content={
        "closest_distance": closest_distance,
        "center_distance": center_distance,
    })
@app.get("/")
def root():
    """Health-check endpoint reporting service status, model name and device."""
    return {"status": "ok", "model": "metric3d_vit_large", "device": DEVICE}