# sahl-express / main.py
# Author: arij155 — commit ca8962c ("Update main.py", verified)
import os
import cv2
import numpy as np
import requests
import torch
import firebase_admin
from fastapi import FastAPI, BackgroundTasks
from pydantic import BaseModel
from ultralytics import YOLO
from firebase_admin import credentials, firestore
# --- Setup & Environment ---
def _trust_all(*args, **kwargs):
    """No-op stand-in for torch.hub's repository trust check."""
    return None


# HACK: disables torch.hub's trust check for every hub load in this process.
# The MiDaS loads below already pass trust_repo=True, so this is presumably
# belt-and-braces — consider removing it.
torch.hub._check_repo_is_trusted = _trust_all

os.environ.update({
    'TORCH_HOME': '/tmp/torch_cache',
    # NOTE(review): ultralytics is imported above this line; if it reads
    # YOLO_CONFIG_DIR at import time this assignment has no effect — confirm,
    # or move it before the ultralytics import.
    'YOLO_CONFIG_DIR': '/tmp/ultralytics_config',
})

app = FastAPI()
# --- Model Loading ---
# Models are loaded once at import time. A failed load is logged and leaves the
# corresponding global(s) set to None instead of crashing the service.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🚀 Running on: {device}")

try:
    yolo_model = YOLO('best.pt')
    print("✅ YOLOv8 Loaded")
except Exception as e:
    print(f"❌ YOLO Load Error: {e}")
    yolo_model = None

try:
    midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small", trust_repo=True)
    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms", trust_repo=True)
    midas.to(device)
    midas.eval()
    # Preprocessing matching the MiDaS_small model.
    transform = midas_transforms.small_transform
    print("✅ MiDaS Loaded")
except Exception as e:
    print(f"❌ MiDaS Load Failed: {e}")
    midas = None
    # Keep `transform` bound even on failure so callers see None rather than
    # an unrelated NameError.
    transform = None
# --- Firebase Initialization ---
# Connect to Firestore once at import time. On any failure `db` is None, and
# perform_3d_measurement returns without doing anything.
try:
    # Only initialise when firebase_admin has no app registered yet.
    if not firebase_admin._apps:
        service_account = credentials.Certificate("serviceAccount.json")
        firebase_admin.initialize_app(service_account)
    db = firestore.client()
    print("✅ Firebase Connected")
except Exception as e:
    db = None
    print(f"⚠️ Firebase Error: {e}")
# --- Constants ---
# Real-world size, in centimetres, of each YOLO class usable as a scale
# reference. Keys are matched against lower-cased model class names, so every
# spelling the model may emit (e.g. '1dinar_coin' from data.yaml) is listed.
# The reference's bounding-box *width* in pixels is divided by this value to
# get pixels-per-cm.
# NOTE(review): each value must be the side that lies horizontally in the
# photo — confirm for rotated cards or paper.
REFERENCE_SIZES = {
    **dict.fromkeys(('id_card', 'id_cards'), 8.56),
    **dict.fromkeys(('1dinar_coin', 'reference_coin', 'coin'), 2.8),
    **dict.fromkeys(('a4_paper', 'reference_paper'), 21.0),
}
class ImageRequest(BaseModel):
    """Request body for ``POST /measure``.

    NOTE(review): ``image_url`` is fetched server-side with ``requests.get``
    and no host allow-list, so callers can make this service request arbitrary
    URLs (SSRF risk) — consider restricting accepted hosts.
    """

    # URL of the photo to download and measure.
    image_url: str
    # Id of the document in the Firestore "orders" collection that receives
    # the measurement result or failure reason.
    delivery_id: str
# --- Core Logic ---
def get_depth_map(img):
    """Run MiDaS on a BGR image and return its depth prediction at full size.

    Args:
        img: BGR image as produced by ``cv2.imdecode`` (``H x W x 3``).

    Returns:
        A ``numpy`` array of MiDaS' prediction, bicubically resized to the
        input's height and width (presumably ``(H, W)`` for a single image).
        NOTE(review): MiDaS output is relative and unitless (presumably
        inverse depth), not a metric distance.

    Raises:
        RuntimeError: if the MiDaS model failed to load at startup.
    """
    if midas is None:
        raise RuntimeError("MiDaS model is not loaded")
    # The transform is applied to RGB data; OpenCV images arrive as BGR.
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    input_batch = transform(img_rgb).to(device)
    with torch.no_grad():
        prediction = midas(input_batch)
        # Upsample the low-resolution prediction back to the source image size.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()
    return prediction.cpu().numpy()
def _find_reference_ratio(yolo_results):
    """Return pixels-per-cm from the first detected reference object, or None.

    The ratio is the reference box's pixel width divided by its known size in
    ``REFERENCE_SIZES``. A zero-width box yields 0, which callers treat as
    "not found".
    """
    for box in yolo_results.boxes:
        label = yolo_results.names[int(box.cls[0])].lower()
        if label in REFERENCE_SIZES:
            x1, _, x2, _ = box.xyxy[0].tolist()
            print(f"✅ Found Reference: {label}")
            return (x2 - x1) / REFERENCE_SIZES[label]
    return None


def _find_package(yolo_results):
    """Return ``(outline_points, width_px, length_px)`` for the first package, or None.

    Prefers the segmentation mask, measured with a rotated minimum-area
    rectangle. Falls back to the axis-aligned bounding box when the model
    produced no masks or the mask polygon is degenerate.
    """
    for i, box in enumerate(yolo_results.boxes):
        label = yolo_results.names[int(box.cls[0])].lower()
        # Loose matching for 'package', 'package-box', etc. '0 0 0' presumably
        # mirrors a class name in the training data — TODO confirm.
        if not ('pack' in label or 'box' in label or '0 0 0' in label):
            continue
        if yolo_results.masks is not None:
            points = yolo_results.masks.xy[i]
            # A polygon with fewer than 3 vertices encloses no area: its
            # rectangle would be degenerate and its depth mask empty.
            if len(points) >= 3:
                (_, _), (w, h), _ = cv2.minAreaRect(points.astype(np.int32))
                print("✅ Found Package Mask")
                return points, w, h
        x1, y1, x2, y2 = box.xyxy[0].tolist()
        print("⚠️ Found Package Box (Fallback)")
        corners = np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]])
        return corners, x2 - x1, y2 - y1
    return None


def _estimate_height_cm(depth_map, outline_points, pixel_cm_ratio):
    """Heuristically estimate package height from the depth map, or None.

    Compares the median depth inside the package outline with the median
    depth of the band obtained by dilating that outline with a 20x20 kernel
    (the surrounding "ground"). Returns None when either region has no pixels,
    which would otherwise give a NaN median.

    NOTE(review): MiDaS depth is relative and unitless, so dividing the delta
    by pixels-per-cm and halving it is an empirical heuristic, not a metric
    conversion — confirm calibration.
    """
    pkg_mask = np.zeros(depth_map.shape, dtype=np.uint8)
    cv2.fillPoly(pkg_mask, [outline_points.astype(np.int32)], 1)
    dilated = cv2.dilate(pkg_mask, np.ones((20, 20), np.uint8), iterations=1)
    # Dilation only grows the mask, so this uint8 difference is exactly the band.
    ring = dilated - pkg_mask
    pkg_depths = depth_map[pkg_mask == 1]
    ground_depths = depth_map[ring == 1]
    if pkg_depths.size == 0 or ground_depths.size == 0:
        return None
    depth_delta = abs(np.median(ground_depths) - np.median(pkg_depths))
    return (depth_delta / pixel_cm_ratio) * 0.5


def perform_3d_measurement(image_url: str, delivery_id: str):
    """Measure the package shown at *image_url* and record the result in Firestore.

    Scheduled as a FastAPI background task by ``/measure``. The steps are:

    1. Download and decode the image.
    2. Run YOLO detection.
    3. Derive pixels-per-cm from a reference object listed in ``REFERENCE_SIZES``.
    4. Locate the package outline.
    5. Estimate its height from a MiDaS depth map.

    On success, ``orders/<delivery_id>`` gets ``volume_cm3``, ``dimensions``
    (``"LxWxH cm"``), ``status="Measured_3D"`` and ``processedAt``. On failure
    it gets ``status="Failed"`` plus a ``reason``.

    Returns without doing anything when Firestore is unavailable (``db is None``).
    """
    if db is None:
        print(f"⚠️ Firestore unavailable, skipping {delivery_id}")
        return
    doc_ref = db.collection("orders").document(delivery_id)

    def fail(reason):
        # Record a failure on the order document.
        doc_ref.update({"status": "Failed", "reason": reason})

    try:
        if yolo_model is None:
            fail("Modèle YOLO non chargé.")
            return

        # 1. Download & decode image
        resp = requests.get(image_url, timeout=15)
        if resp.status_code != 200:
            fail("Erreur téléchargement image")
            return
        img = None
        if resp.content:
            img_array = np.asarray(bytearray(resp.content), dtype=np.uint8)
            # cv2.imdecode signals undecodable data by returning None.
            img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
        if img is None:
            fail("Image illisible ou format non supporté.")
            return

        # 2. YOLO detection
        yolo_results = yolo_model.predict(source=img, conf=0.25)[0]
        detected_labels = [yolo_results.names[int(b.cls[0])].lower() for b in yolo_results.boxes]
        print(f"DEBUG: Labels found: {detected_labels}")

        # 3. Reference object -> pixels per centimetre
        pixel_cm_ratio = _find_reference_ratio(yolo_results)
        if not pixel_cm_ratio:
            fail("Objet de référence (carte/pièce) non détecté.")
            return

        # 4. Package outline and footprint in pixels
        package = _find_package(yolo_results)
        if package is None:
            fail("Colis non détecté.")
            return
        pkg_points, pkg_w_px, pkg_h_px = package

        # 5. Depth & volume. Depth inference is deferred until it is needed.
        depth_map = get_depth_map(img)
        raw_h = _estimate_height_cm(depth_map, pkg_points, pixel_cm_ratio)
        # Convert to plain Python floats for Firestore (numpy scalars are not stored).
        real_h = float(round(raw_h, 1)) if raw_h is not None else 0.0
        real_w = float(round(pkg_w_px / pixel_cm_ratio, 1))
        real_l = float(round(pkg_h_px / pixel_cm_ratio, 1))
        # Implausibly flat or unmeasurable heights fall back to a 5 cm default.
        if real_h < 2.0:
            real_h = 5.0
        volume = float(round(real_w * real_l * real_h, 2))

        # 6. Final update
        doc_ref.update({
            "volume_cm3": volume,
            "dimensions": f"{real_l}x{real_w}x{real_h} cm",
            "status": "Measured_3D",
            "processedAt": firestore.SERVER_TIMESTAMP
        })
        print(f"✅ Success: {delivery_id} -> {volume} cm3")
    except Exception as e:
        # Top-level boundary of a background task: record the error on the
        # order instead of letting it vanish in the worker.
        print(f"❌ Error: {str(e)}")
        try:
            fail(str(e))
        except Exception as update_err:
            print(f"❌ Could not record failure for {delivery_id}: {update_err}")
# --- Endpoints ---
@app.get("/")
def home():
    """Liveness probe: report that the service is running."""
    banner = "Sahl Express AI v2.1 is Online"
    return {"message": banner}
@app.post("/measure")
async def measure_endpoint(request: ImageRequest, background_tasks: BackgroundTasks):
    """Queue ``perform_3d_measurement`` for this order and answer immediately.

    The measurement result is written to Firestore by the background task; the
    response only acknowledges that processing has started.
    """
    order_id = request.delivery_id
    background_tasks.add_task(perform_3d_measurement, request.image_url, order_id)
    return {"status": "processing", "id": order_id}
if __name__ == "__main__":
    # uvicorn is only needed when this file is executed directly.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)