"""Sahl Express AI service: estimate a parcel's dimensions from one photo.

For a delivery order the service:
1. downloads the photo;
2. uses a YOLO model to find a reference object of known size and a package;
3. estimates the package height from a MiDaS depth map;
4. writes the dimensions and volume to Firestore ``orders/{delivery_id}``.
"""
import os

# These must be set BEFORE torch / ultralytics are imported. ultralytics reads
# YOLO_CONFIG_DIR when it is first imported, so assigning it after the import
# (as an earlier version did) had no effect. /tmp is used as a writable location.
os.environ['TORCH_HOME'] = '/tmp/torch_cache'
os.environ['YOLO_CONFIG_DIR'] = '/tmp/ultralytics_config'

import cv2  # noqa: E402  (imports intentionally follow the env setup above)
import firebase_admin  # noqa: E402
import numpy as np  # noqa: E402
import requests  # noqa: E402
import torch  # noqa: E402
from fastapi import BackgroundTasks, FastAPI  # noqa: E402
from firebase_admin import credentials, firestore  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from ultralytics import YOLO  # noqa: E402


# --- Setup & Environment ---
def _trust_all(*args, **kwargs):
    """No-op replacement for torch.hub's repository-trust check."""
    pass


# HACK: disables torch.hub's "trust this repo?" check so loading never blocks
# on a prompt. The loads below also pass trust_repo=True.
# NOTE(review): this trusts *every* hub repo for the whole process; consider
# removing it if trust_repo=True alone is sufficient.
torch.hub._check_repo_is_trusted = _trust_all

app = FastAPI()

# --- Model Loading ---
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"🚀 Running on: {device}")

try:
    yolo_model = YOLO('best.pt')
    print("✅ YOLOv8 Loaded")
except Exception as e:
    print(f"❌ YOLO Load Error: {e}")
    yolo_model = None

try:
    midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small", trust_repo=True)
    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms", trust_repo=True)
    midas.to(device)
    midas.eval()
    transform = midas_transforms.small_transform
    print("✅ MiDaS Loaded")
except Exception as e:
    print(f"❌ MiDaS Load Failed: {e}")
    midas = None
    # Bind the name so get_depth_map can detect the failure explicitly instead
    # of raising a NameError.
    transform = None

# --- Firebase Initialization ---
try:
    if not firebase_admin._apps:
        cred = credentials.Certificate("serviceAccount.json")
        firebase_admin.initialize_app(cred)
    db = firestore.client()
    print("✅ Firebase Connected")
except Exception as e:
    print(f"⚠️ Firebase Error: {e}")
    db = None

# --- Constants ---
# Real-world size (cm) of each reference object, keyed by lower-cased YOLO
# class name. Keys must match the class names in the model's data.yaml
# (e.g. '1dinar_coin'); several aliases map to the same physical object.
# The value is compared against the reference detection's bounding-box
# *width* in pixels (see _find_reference_ratio). Presumably 8.56 is the long
# side of an ID-1 card and 21.0 the short side of A4 — confirm against how the
# photos are framed.
REFERENCE_SIZES = {
    'id_card': 8.56,
    'id_cards': 8.56,
    '1dinar_coin': 2.8,
    'reference_coin': 2.8,
    'coin': 2.8,
    'a4_paper': 21.0,
    'reference_paper': 21.0,
}

# Substrings that identify a package class (e.g. 'package', 'package-box').
# NOTE(review): '0 0 0' presumably matches an unnamed/placeholder class in the
# dataset — confirm against data.yaml.
_PACKAGE_LABEL_HINTS = ('pack', 'box', '0 0 0')

# Side length (px) of the square kernel used to grow the package mask into the
# surrounding "ground" ring whose depth is compared with the package's.
_GROUND_RING_KERNEL_PX = 20

# Empirical scale applied to the depth difference when converting to cm.
# NOTE(review): origin of this factor is not documented here.
_HEIGHT_SCALE = 0.5

# Heights below _MIN_HEIGHT_CM are treated as unreliable and replaced by
# _DEFAULT_HEIGHT_CM.
_MIN_HEIGHT_CM = 2.0
_DEFAULT_HEIGHT_CM = 5.0


class ImageRequest(BaseModel):
    """Body of POST /measure."""

    image_url: str
    delivery_id: str


# --- Core Logic ---
def get_depth_map(img):
    """Run MiDaS on ``img`` and return a depth map resized to its height/width.

    Args:
        img: BGR image as returned by ``cv2.imdecode``.

    Returns:
        2-D numpy array with the same (H, W) as ``img``. MiDaS predicts
        *relative* (inverse) depth per its documentation, so values are
        unitless.

    Raises:
        RuntimeError: if the MiDaS model or its transform failed to load.
    """
    if midas is None or transform is None:
        raise RuntimeError("MiDaS model is not loaded")
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    input_batch = transform(img_rgb).to(device)
    with torch.no_grad():
        prediction = midas(input_batch)
        # Upsample the low-resolution prediction back to the input image size.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()
    return prediction.cpu().numpy()


def _mark_failed(doc_ref, reason):
    """Set the order's status to "Failed" with ``reason``; never raises.

    This runs inside a background task. An exception here would escape the
    task and leave the order without any status update, so errors are logged
    instead.
    """
    try:
        doc_ref.update({"status": "Failed", "reason": reason})
    except Exception as e:
        print(f"❌ Firestore update failed: {e}")


def _find_reference_ratio(yolo_results):
    """Return ``(label, pixels_per_cm)`` from the first usable reference detection.

    A reference detection is one whose label is in REFERENCE_SIZES and whose
    box has positive width. Returns ``(None, None)`` if there is none.
    """
    for box in yolo_results.boxes:
        label = yolo_results.names[int(box.cls[0])].lower()
        if label not in REFERENCE_SIZES:
            continue
        x1, _, x2, _ = box.xyxy[0].tolist()
        if x2 > x1:  # a zero-width box would give a ratio of 0 (division by zero later)
            return label, (x2 - x1) / REFERENCE_SIZES[label]
    return None, None


def _find_package(yolo_results):
    """Locate the first package-like detection.

    Prefers the segmentation mask (measured with a rotated min-area
    rectangle). Falls back to the axis-aligned bounding box when the model
    produced no masks, or when the mask polygon is degenerate (< 3 points).

    Returns:
        ``(polygon_points, width_px, height_px)``, or ``None`` if no package
        was detected.
    """
    masks_xy = yolo_results.masks.xy if yolo_results.masks is not None else None
    for i, box in enumerate(yolo_results.boxes):
        label = yolo_results.names[int(box.cls[0])].lower()
        if not any(hint in label for hint in _PACKAGE_LABEL_HINTS):
            continue
        # masks.xy is indexed in the same order as boxes.
        if masks_xy is not None and len(masks_xy[i]) >= 3:
            points = masks_xy[i]
            (_, _), (w, h), _ = cv2.minAreaRect(points.astype(np.int32))
            print("✅ Found Package Mask")
            return points, w, h
        x1, y1, x2, y2 = box.xyxy[0].tolist()
        print("⚠️ Found Package Box (Fallback)")
        corners = np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]])
        return corners, x2 - x1, y2 - y1
    return None


def _depth_delta(depth_map, pkg_points):
    """Return |median depth of a ring around the package - median depth on it|.

    The ring is the package mask dilated by _GROUND_RING_KERNEL_PX, minus the
    mask itself. Returns 0.0 when either region is empty (e.g. a degenerate
    polygon, or a package filling the frame). Taking the median of an empty
    array would otherwise produce NaN and propagate into the stored volume.
    """
    mask_img = np.zeros(depth_map.shape, dtype=np.uint8)
    cv2.fillPoly(mask_img, [pkg_points.astype(np.int32)], 1)
    kernel = np.ones((_GROUND_RING_KERNEL_PX, _GROUND_RING_KERNEL_PX), np.uint8)
    dilated = cv2.dilate(mask_img, kernel, iterations=1)
    pkg_vals = depth_map[mask_img == 1]
    ground_vals = depth_map[(dilated - mask_img) == 1]
    if pkg_vals.size == 0 or ground_vals.size == 0:
        return 0.0
    return float(abs(np.median(ground_vals) - np.median(pkg_vals)))


def perform_3d_measurement(image_url: str, delivery_id: str):
    """Measure the parcel in ``image_url`` and record the result on the order.

    Writes to Firestore ``orders/{delivery_id}``:
    - on success: ``volume_cm3``, ``dimensions``, ``status="Measured_3D"`` and
      ``processedAt``;
    - on failure: ``status="Failed"`` and a human-readable ``reason``.

    Does nothing (beyond logging) when Firestore is unavailable. Intended to
    run as a FastAPI background task, so it never raises.
    """
    if db is None:
        print(f"⚠️ Firestore unavailable; skipping {delivery_id}")
        return
    doc_ref = db.collection("orders").document(delivery_id)
    try:
        if yolo_model is None or midas is None:
            _mark_failed(doc_ref, "Modèles IA non chargés.")
            return

        # 1. Download Image
        resp = requests.get(image_url, timeout=15)
        if resp.status_code != 200:
            _mark_failed(doc_ref, "Erreur téléchargement image")
            return
        img_array = np.asarray(bytearray(resp.content), dtype=np.uint8)
        img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
        if img is None:  # imdecode returns None for undecodable payloads
            _mark_failed(doc_ref, "Image illisible (format non supporté).")
            return

        # 2. YOLO Detection
        yolo_results = yolo_model.predict(source=img, conf=0.25)[0]
        detected_labels = [yolo_results.names[int(b.cls[0])].lower() for b in yolo_results.boxes]
        print(f"DEBUG: Labels found: {detected_labels}")

        # 3. Find Reference Object
        ref_label, pixel_cm_ratio = _find_reference_ratio(yolo_results)
        if pixel_cm_ratio is None:
            _mark_failed(doc_ref, "Objet de référence (carte/pièce) non détecté.")
            return
        print(f"✅ Found Reference: {ref_label}")

        # 4. Find Package (Smart Matching for 'package', 'package-box', etc.)
        package = _find_package(yolo_results)
        if package is None:
            _mark_failed(doc_ref, "Colis non détecté.")
            return
        pkg_mask_points, pkg_w_px, pkg_h_px = package

        # 5. Depth & Volume Calculation (depth is computed only once both
        # objects are known to exist, since it is the most expensive step).
        depth_map = get_depth_map(img)
        depth_delta = _depth_delta(depth_map, pkg_mask_points)

        # Convert to plain Python floats: Firestore rejects numpy scalars.
        # NOTE(review): MiDaS depth is relative and unitless, so dividing it
        # by a pixels-per-cm ratio is a heuristic, not a metric conversion.
        real_h = float(round((depth_delta / pixel_cm_ratio) * _HEIGHT_SCALE, 1))
        real_w = float(round(pkg_w_px / pixel_cm_ratio, 1))
        real_l = float(round(pkg_h_px / pixel_cm_ratio, 1))
        if real_h < _MIN_HEIGHT_CM:
            real_h = _DEFAULT_HEIGHT_CM
        volume = float(round(real_w * real_l * real_h, 2))

        # 6. Final Update
        doc_ref.update({
            "volume_cm3": volume,
            "dimensions": f"{real_l}x{real_w}x{real_h} cm",
            "status": "Measured_3D",
            "processedAt": firestore.SERVER_TIMESTAMP,
        })
        print(f"✅ Success: {delivery_id} -> {volume} cm3")
    except Exception as e:
        # Top-level boundary of a background task: record the failure on the
        # order rather than letting the exception vanish into server logs.
        print(f"❌ Error: {str(e)}")
        _mark_failed(doc_ref, str(e))


# --- Endpoints ---
@app.get("/")
def home():
    """Health check."""
    return {"message": "Sahl Express AI v2.1 is Online"}


@app.post("/measure")
async def measure_endpoint(request: ImageRequest, background_tasks: BackgroundTasks):
    """Queue a measurement and return immediately; results land in Firestore."""
    background_tasks.add_task(perform_3d_measurement, request.image_url, request.delivery_id)
    return {"status": "processing", "id": request.delivery_id}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)