Spaces:
Running
Running
| import os | |
| import cv2 | |
| import numpy as np | |
| import requests | |
| import torch | |
| import firebase_admin | |
| from fastapi import FastAPI, BackgroundTasks | |
| from pydantic import BaseModel | |
| from ultralytics import YOLO | |
| from firebase_admin import credentials, firestore | |
# --- Setup & Environment ---
def _trust_all(*args, **kwargs):
    """No-op stand-in for torch.hub's repo-trust check; ignores all arguments."""
    return None


# HACK: replace a private torch.hub helper so hub repos (MiDaS, loaded below)
# skip the trust check entirely.
# NOTE(review): this disables the check process-wide. The MiDaS loads already
# pass trust_repo=True, so confirm whether this patch is still required.
torch.hub._check_repo_is_trusted = _trust_all

# Point model/config caches at /tmp (presumably the only writable location on
# the host — confirm).
# NOTE(review): `ultralytics` is imported before this runs. If it reads
# YOLO_CONFIG_DIR at import time, this assignment has no effect; consider
# moving these lines above the imports.
os.environ.update({
    'TORCH_HOME': '/tmp/torch_cache',
    'YOLO_CONFIG_DIR': '/tmp/ultralytics_config',
})

app = FastAPI()
# --- Model Loading ---
# Models are loaded once at import time. On failure the corresponding globals
# are set to None so request handlers can detect the missing model instead of
# crashing with NameError.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"🚀 Running on: {device}")

try:
    # Custom-trained detector weights shipped alongside this file.
    yolo_model = YOLO('best.pt')
    print("✅ YOLOv8 Loaded")
except Exception as e:
    print(f"❌ YOLO Load Error: {e}")
    yolo_model = None

try:
    midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small", trust_repo=True)
    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms", trust_repo=True)
    midas.to(device)
    midas.eval()
    # Preprocessing pipeline paired with the MiDaS_small model.
    transform = midas_transforms.small_transform
    print("✅ MiDaS Loaded")
except Exception as e:
    print(f"❌ MiDaS Load Failed: {e}")
    midas = None
    # Previously left undefined on failure, so any later use raised NameError.
    transform = None
# --- Firebase Initialization ---
# `db` is the shared Firestore client. It is left as None when initialisation
# fails (e.g. missing/invalid serviceAccount.json), which the measurement task
# checks before doing any work.
try:
    # Only initialise the SDK once; re-running this module must not create a
    # second default app.
    if len(firebase_admin._apps) == 0:
        cred = credentials.Certificate("serviceAccount.json")
        firebase_admin.initialize_app(cred)
    db = firestore.client()
    print("✅ Firebase Connected")
except Exception as e:
    print(f"⚠️ Firebase Error: {e}")
    db = None
# --- Constants ---
# Real-world size in centimetres of each reference-object class, keyed by the
# lower-cased YOLO label. The measurement code divides a reference detection's
# pixel width by this value to obtain a pixels-per-cm scale.
# Several labels are aliases of the same physical object.
REFERENCE_SIZES = {
    # Card: 8.56 cm (the long edge of an ISO/IEC 7810 ID-1 card).
    **dict.fromkeys(('id_card', 'id_cards'), 8.56),
    # Coin: 2.8 cm across. '1dinar_coin' must match the class name used in the
    # model's training data.yaml exactly.
    **dict.fromkeys(('1dinar_coin', 'reference_coin', 'coin'), 2.8),
    # Paper: 21.0 cm (the short edge of an A4 sheet).
    **dict.fromkeys(('a4_paper', 'reference_paper'), 21.0),
}
class ImageRequest(BaseModel):
    """JSON body accepted by the measurement endpoint."""

    # URL of the photo to download and analyse.
    image_url: str
    # Document id in the Firestore "orders" collection that receives the result.
    delivery_id: str
# --- Core Logic ---
def get_depth_map(img):
    """Run MiDaS on a BGR image and return a depth map at the image's resolution.

    Args:
        img: BGR image array as produced by ``cv2.imdecode`` (H x W x 3).

    Returns:
        A 2-D numpy array of shape ``img.shape[:2]``.
        NOTE: MiDaS models output *relative* (unitless) inverse depth, not
        metric distance — confirm against the MiDaS documentation before
        interpreting values as lengths.

    Raises:
        RuntimeError: if the MiDaS model failed to load at startup.
    """
    # Fail with a clear message rather than "'NoneType' object is not callable"
    # (or NameError on `transform`) when model loading failed.
    if midas is None:
        raise RuntimeError("MiDaS depth model is not loaded")

    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    input_batch = transform(img_rgb).to(device)
    with torch.no_grad():
        prediction = midas(input_batch)
        # Upsample the low-resolution prediction back to the input size so it
        # lines up pixel-for-pixel with YOLO masks/boxes.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()
    return prediction.cpu().numpy()
def _mark_failed(doc_ref, reason):
    """Best-effort: set status "Failed" with *reason* on the order document."""
    try:
        doc_ref.update({"status": "Failed", "reason": reason})
    except Exception as update_err:
        # e.g. the order document does not exist; never let reporting a
        # failure crash the background task.
        print(f"❌ Could not record failure: {update_err}")


def _find_reference_ratio(results):
    """Return pixels-per-cm from the first reference-object detection, or None."""
    for box in results.boxes:
        label = results.names[int(box.cls[0])].lower()
        if label in REFERENCE_SIZES:
            # NOTE(review): the box's horizontal extent is compared with a
            # single known dimension, so a rotated card/sheet skews the scale.
            x1, _, x2, _ = box.xyxy[0].tolist()
            print(f"✅ Found Reference: {label}")
            return (x2 - x1) / REFERENCE_SIZES[label]
    return None


def _find_package(results):
    """Return ``(polygon, width_px, length_px)`` for the first package detection, or None.

    Prefers the segmentation mask (rotated bounding rectangle of the polygon);
    falls back to the axis-aligned box when no usable mask is available.
    """
    for i, box in enumerate(results.boxes):
        label = results.names[int(box.cls[0])].lower()
        # Smart matching for 'package', 'package-box', etc. '0 0 0' is
        # presumably a placeholder class name in the dataset — confirm.
        if not ('pack' in label or 'box' in label or '0 0 0' in label):
            continue
        if results.masks is not None:
            # Assumes mask i belongs to box i, as in the original lookup.
            polygon = results.masks.xy[i]
            # An empty/degenerate polygon would make minAreaRect fail and the
            # depth mask empty, so fall back to the box in that case.
            if len(polygon) >= 3:
                (_, _), (w, h), _ = cv2.minAreaRect(polygon.astype(np.int32))
                print("✅ Found Package Mask")
                return polygon, w, h
        x1, y1, x2, y2 = box.xyxy[0].tolist()
        print("⚠️ Found Package Box (Fallback)")
        corners = np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]])
        return corners, x2 - x1, y2 - y1
    return None


def _estimate_height_cm(depth_map, polygon, pixel_cm_ratio):
    """Heuristic package height from the depth contrast between package and surroundings.

    Returns None when either region is empty (median would be NaN).
    """
    mask = np.zeros(depth_map.shape, dtype=np.uint8)
    cv2.fillPoly(mask, [polygon.astype(np.int32)], 1)
    dilated = cv2.dilate(mask, np.ones((20, 20), np.uint8), iterations=1)
    # A band around the package, used as the "ground" reference.
    ring = (dilated == 1) & (mask == 0)

    pkg_pixels = depth_map[mask == 1]
    ground_pixels = depth_map[ring]
    if pkg_pixels.size == 0 or ground_pixels.size == 0:
        return None

    depth_delta = abs(np.median(ground_pixels) - np.median(pkg_pixels))
    # HEURISTIC: MiDaS depth is relative, so dividing by px/cm and halving is
    # an empirical scaling, not a metric conversion. Calibrate before trusting it.
    # float() converts numpy scalars, which Firestore cannot store.
    return float(round((depth_delta / pixel_cm_ratio) * 0.5, 1))


def perform_3d_measurement(image_url: str, delivery_id: str):
    """Measure the package in a photo and write the result to Firestore.

    Downloads ``image_url``, detects a reference object (see
    ``REFERENCE_SIZES``) and a package with YOLO, and estimates package height
    from the MiDaS depth map. It then updates ``orders/{delivery_id}`` with
    either ``volume_cm3``, ``dimensions``, ``status="Measured_3D"`` and
    ``processedAt``, or with ``status="Failed"`` and a ``reason``.

    Args:
        image_url: URL of the photo to analyse (supplied by the API caller).
        delivery_id: Id of the document in the ``orders`` collection.

    Returns:
        None. Results are reported only through the Firestore document.
    """
    if db is None:
        print("⚠️ Firestore unavailable; measurement skipped")
        return
    doc_ref = db.collection("orders").document(delivery_id)

    # Bail out early instead of failing with "'NoneType' object has no
    # attribute 'predict'" after downloading the image.
    if yolo_model is None or midas is None:
        _mark_failed(doc_ref, "Modèles IA non chargés.")
        return

    try:
        # 1. Download Image
        # SECURITY(review): image_url comes straight from the request body, so
        # the server fetches arbitrary URLs (SSRF risk) with no size limit.
        # Consider restricting to trusted hosts (e.g. the app's storage bucket).
        resp = requests.get(image_url, timeout=15)
        if resp.status_code != 200:
            _mark_failed(doc_ref, "Erreur téléchargement image")
            return
        img = cv2.imdecode(np.frombuffer(resp.content, dtype=np.uint8), cv2.IMREAD_COLOR)
        if img is None:  # Not an image OpenCV can decode.
            _mark_failed(doc_ref, "Image illisible (format non supporté).")
            return

        # 2. YOLO Detection
        results = yolo_model.predict(source=img, conf=0.25)[0]
        detected_labels = [results.names[int(b.cls[0])].lower() for b in results.boxes]
        print(f"DEBUG: Labels found: {detected_labels}")

        # 3. Find Reference Object
        pixel_cm_ratio = _find_reference_ratio(results)
        if not pixel_cm_ratio:  # also rejects a zero-width box (division by zero)
            _mark_failed(doc_ref, "Objet de référence (carte/pièce) non détecté.")
            return

        # 4. Find Package
        package = _find_package(results)
        if package is None:
            _mark_failed(doc_ref, "Colis non détecté.")
            return
        polygon, pkg_w_px, pkg_h_px = package

        # 5. Depth & Volume Calculation (depth is only computed once both
        # objects were found, since it is the most expensive step).
        real_h = _estimate_height_cm(get_depth_map(img), polygon, pixel_cm_ratio)
        # Implausible (< 2 cm) or undeterminable heights fall back to a fixed
        # 5 cm — presumably a typical parcel thickness; confirm with product.
        if real_h is None or not np.isfinite(real_h) or real_h < 2.0:
            real_h = 5.0
        real_w = float(round(pkg_w_px / pixel_cm_ratio, 1))
        real_l = float(round(pkg_h_px / pixel_cm_ratio, 1))
        volume = float(round(real_w * real_l * real_h, 2))

        # 6. Final Update
        doc_ref.update({
            "volume_cm3": volume,
            "dimensions": f"{real_l}x{real_w}x{real_h} cm",
            "status": "Measured_3D",
            "processedAt": firestore.SERVER_TIMESTAMP,
        })
        print(f"✅ Success: {delivery_id} -> {volume} cm3")
    except Exception as e:
        # Top-level boundary of the background task: record and log the error.
        print(f"❌ Error: {str(e)}")
        _mark_failed(doc_ref, str(e))
# --- Endpoints ---
# Without a route decorator FastAPI never registers the function, so the
# health check was unreachable.
@app.get("/")
def home():
    """Health check: report that the service is online."""
    return {"message": "Sahl Express AI v2.1 is Online"}
# Without a route decorator FastAPI never registers this handler.
# NOTE(review): the "/measure" path is inferred from the function name —
# confirm it matches the path the client apps call.
@app.post("/measure")
async def measure_endpoint(request: ImageRequest, background_tasks: BackgroundTasks):
    """Queue a measurement for ``request.delivery_id`` and return immediately.

    The heavy work runs in ``perform_3d_measurement`` as a background task,
    which reports its outcome through the Firestore order document.
    """
    background_tasks.add_task(perform_3d_measurement, request.image_url, request.delivery_id)
    return {"status": "processing", "id": request.delivery_id}
if __name__ == "__main__":
    # Script entry point: serve the FastAPI app on all interfaces, port 7860
    # (presumably the port the hosting platform expects — confirm).
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=7860)