# Author: arij155
# Commit: 5e44d5f (verified) - "Update main.py"
import os

# --- 0. RUNTIME ENVIRONMENT ---
# Set writable directories for Hugging Face.
# These MUST be set before the third-party imports below: ultralytics reads
# YOLO_CONFIG_DIR once, when it is first imported, so assigning it afterwards
# has no effect. (torch.hub reads TORCH_HOME lazily, but it is kept here too so
# all cache locations are configured in one place.)
os.environ['TORCH_HOME'] = '/tmp/torch_cache'
os.environ['YOLO_CONFIG_DIR'] = '/tmp/ultralytics_config'

# Imports intentionally follow the environment setup above (see note).
import cv2
import numpy as np
import requests
import torch
import torch.hub
import firebase_admin
from fastapi import FastAPI, BackgroundTasks
from pydantic import BaseModel
from ultralytics import YOLO
from firebase_admin import credentials, firestore


def dummy_check_dependencies(*args, **kwargs):
    """No-op stand-in that accepts any arguments and returns True.

    Kept only for backward compatibility; see the note on the assignment below.
    """
    return True


# NOTE(review): nothing in torch.hub appears to read a public
# `check_dependencies` attribute (its own check is the private
# `_check_dependencies`), so this assignment only adds an unused attribute.
# It does NOT suppress the "(y/N)" trust prompt. That prompt is skipped by
# passing `trust_repo=True` to `torch.hub.load`, as the MiDaS loader does.
torch.hub.check_dependencies = dummy_check_dependencies
# --- 1. WEB APP & MODEL INITIALIZATION ---
app = FastAPI()


@app.get("/")
def home():
    """Health-check route: report that the service is up and where it runs."""
    payload = {"status": "Sahl Express AI is Online", "region": "Tunisia"}
    return payload
# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("πŸš€ Starting Sahl Express Engine...")

# Pre-declare every model handle. If a load below fails, the name still exists
# (as None), so the failure shows up clearly at use time instead of as a
# NameError inside the background measurement task.
yolo_model = None
midas = None
midas_transforms = None
transform = None

# Load YOLOv8 (2D Segmentation)
try:
    yolo_model = YOLO('best.pt')
    print("βœ… YOLOv8 Loaded")
except Exception as e:
    print(f"❌ YOLO Load Error: {e}")

# Load MiDaS (Depth Estimation)
try:
    print("πŸ“₯ Loading MiDaS (Security Bypass Active)...")
    # `trust_repo=True` is what skips torch.hub's interactive "(y/N)" prompt
    # for repositories that are not yet trusted.
    midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small", trust_repo=True)
    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms", trust_repo=True)
    midas.to(device)
    midas.eval()  # inference mode (disables dropout / batch-norm updates)
    transform = midas_transforms.small_transform
    print("βœ… MiDaS Loaded Successfully!")
except Exception as e:
    print(f"❌ MiDaS Load Failed: {e}")
    # Never leave a half-initialized pair (a model without its transform).
    midas = None
    midas_transforms = None
    transform = None
# --- 2. FIREBASE SETUP ---
# Pre-declare the Firestore client so a failed setup leaves `db` as None
# instead of an undefined name.
db = None
try:
    # initialize_app() raises ValueError if the default app already exists
    # (e.g. when this module is imported a second time), so only initialize
    # when get_app() reports that no default app exists yet.
    try:
        firebase_admin.get_app()
    except ValueError:
        # Ensure serviceAccount.json is uploaded to your HF Space Files
        cred = credentials.Certificate("serviceAccount.json")
        firebase_admin.initialize_app(cred)
    db = firestore.client()
    print("βœ… Firebase Connected")
except Exception as e:
    print(f"⚠️ Firebase Error: {e}")
# Tunisian Reference Constants (cm).
# Real-world size of each calibration object, keyed by the YOLO class label
# that detects it. Calibration divides the detection's horizontal bounding-box
# width (in pixels) by this value, so each entry is the dimension expected to
# lie along the image's x-axis.
# NOTE(review): 21.0 is the SHORT side of A4; confirm the expected orientation.
REFERENCE_SIZES = dict(
    reference_card=8.56,   # ID card (long side of an ISO/IEC 7810 ID-1 card)
    reference_paper=21.0,  # A4 paper (short side)
    reference_coin=2.8,    # 1 Dinar coin (diameter)
)
class ImageRequest(BaseModel):
    """JSON body accepted by ``POST /measure``."""

    image_url: str  # URL of the photo to download and measure
    delivery_id: str  # id of the document in the Firestore "orders" collection to update
def get_depth_map(img):
    """Run MiDaS on a BGR image and return a per-pixel relative depth map.

    The network output is resized back to the input's height and width with
    bicubic interpolation, so the returned NumPy array covers every input pixel.
    NOTE(review): per the MiDaS docs the values are relative inverse depth
    (larger = closer), not metric depth; assumes the model returns
    (1, H', W') for a single image.
    """
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    batch = transform(rgb).to(device)
    target_size = img.shape[:2]

    with torch.no_grad():
        raw = midas(batch)
        # interpolate() expects (N, C, H, W): add a channel axis, then resize.
        resized = torch.nn.functional.interpolate(
            raw.unsqueeze(1),
            size=target_size,
            mode="bicubic",
            align_corners=False,
        ).squeeze()

    return resized.cpu().numpy()
def perform_3d_measurement(image_url: str, delivery_id: str, *, timeout: float = 30.0):
    """Estimate a package's dimensions from a photo and store them in Firestore.

    Pipeline:
      1. Download and decode the image at ``image_url``.
      2. Run YOLO detection/segmentation on it.
      3. Calibrate pixels-per-cm from the first detected reference object
         (a label present in ``REFERENCE_SIZES``).
      4. Take the first ``'package'`` detection with a usable mask and fit a
         minimum-area rotated rectangle to get its footprint in pixels.
      5. Convert the footprint to cm. Estimate height from the MiDaS depth
         difference between the package and a band of surrounding pixels.
      6. Write ``volume_cm3``, ``dimensions`` and ``status`` to
         ``orders/<delivery_id>``.

    Runs as a FastAPI background task, so errors are logged, not raised. If no
    reference object or no package mask is found, a message is logged and
    nothing is written to Firestore.

    Args:
        image_url: URL of the photo to fetch.
        delivery_id: Document id in the Firestore ``orders`` collection.
        timeout: Seconds to wait for the image download before giving up.
    """
    import traceback

    try:
        img = _download_image(image_url, timeout)

        yolo_results = yolo_model.predict(source=img, conf=0.4)[0]

        pixel_cm_ratio = _find_pixel_cm_ratio(yolo_results)
        polygon = _find_package_polygon(yolo_results)
        if pixel_cm_ratio is None or polygon is None:
            print(
                f"Measurement skipped for {delivery_id}: "
                f"reference found={pixel_cm_ratio is not None}, "
                f"package found={polygon is not None}"
            )
            return

        # Depth is only needed once both calibration and package are known.
        depth_map = get_depth_map(img)

        (_, _), (pkg_w_px, pkg_h_px), _ = cv2.minAreaRect(polygon)
        real_h = _estimate_height_cm(depth_map, polygon, pixel_cm_ratio)
        real_w = round(pkg_w_px / pixel_cm_ratio, 1)
        real_l = round(pkg_h_px / pixel_cm_ratio, 1)
        volume = round(real_w * real_l * real_h, 2)

        db.collection("orders").document(delivery_id).update({
            "volume_cm3": volume,
            "dimensions": f"{real_l}x{real_w}x{real_h} cm",
            "status": "Measured_3D"
        })
        print(f"πŸ“¦ Success: {delivery_id} | Vol: {volume}cm3")
    except Exception as e:
        # Top-level boundary of a background task: log with traceback.
        print(f"❌ Measurement Error: {e}")
        traceback.print_exc()


def _download_image(image_url, timeout):
    """Fetch ``image_url`` and decode it to a BGR image; raise on any failure."""
    # NOTE(review): image_url comes from the request body, so the server will
    # fetch arbitrary URLs (SSRF risk). Consider an allow-list of hosts.
    with requests.get(image_url, timeout=timeout) as resp:
        resp.raise_for_status()
        content = resp.content
    if not content:
        raise ValueError(f"empty response body from {image_url}")
    img = cv2.imdecode(np.frombuffer(content, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError(f"could not decode an image from {image_url}")
    return img


def _find_pixel_cm_ratio(results):
    """Return pixels-per-cm from the first detected reference object, or None."""
    for box in results.boxes:
        label = results.names[int(box.cls[0])]
        if label in REFERENCE_SIZES:
            # Uses only the horizontal bounding-box width of the reference.
            x1, _, x2, _ = box.xyxy[0].tolist()
            ratio = (x2 - x1) / REFERENCE_SIZES[label]
            # A zero/negative width cannot calibrate anything.
            return ratio if ratio > 0 else None
    return None


def _find_package_polygon(results):
    """Return the int32 mask polygon of the first usable 'package' detection, or None."""
    if results.masks is None:
        return None
    for i, box in enumerate(results.boxes):
        if results.names[int(box.cls[0])] != 'package':
            continue
        # masks.xy[i] is the segmentation outline (pixel coords) of box i;
        # it can be empty, which cv2.minAreaRect cannot handle.
        polygon = results.masks.xy[i]
        if len(polygon) >= 3:
            return polygon.astype(np.int32)
    return None


def _estimate_height_cm(depth_map, polygon, pixel_cm_ratio):
    """Heuristic package height in cm from the depth contrast with its surroundings."""
    mask_img = np.zeros(depth_map.shape, dtype=np.uint8)
    cv2.fillPoly(mask_img, [polygon], 1)
    inside = mask_img == 1

    # Estimate ground depth from a band just outside the package: dilate the
    # mask and keep only the newly covered pixels.
    kernel = np.ones((30, 30), np.uint8)
    dilated = cv2.dilate(mask_img, kernel, iterations=2)
    ring = (dilated == 1) & ~inside

    # np.median of an empty selection is NaN, which would poison the volume.
    if not inside.any() or not ring.any():
        raise ValueError("package mask or its surrounding band is empty")

    pkg_depth_val = np.median(depth_map[inside])
    ground_depth_val = np.median(depth_map[ring])

    # Tuning constant 0.5 is for MiDaS relative scaling.
    # NOTE(review): MiDaS values are relative (inverse) depth, so this is an
    # empirical heuristic, not a metric height.
    depth_delta = abs(ground_depth_val - pkg_depth_val)
    real_h = round((depth_delta / pixel_cm_ratio) * 0.5, 1)
    # Heights below 0.5 cm are replaced by a 1.0 cm floor.
    if real_h < 0.5:
        real_h = 1.0
    return real_h
@app.post("/measure")
async def measure_endpoint(request: ImageRequest, background_tasks: BackgroundTasks):
    """Queue a 3D measurement for the given image and answer immediately.

    The heavy work runs in ``perform_3d_measurement`` as a FastAPI background
    task; the client only receives a "processing" acknowledgement.
    """
    url, order_id = request.image_url, request.delivery_id
    background_tasks.add_task(perform_3d_measurement, url, order_id)
    return {"status": "processing"}
if __name__ == "__main__":
    # Imported here because uvicorn is only needed when run as a script.
    import uvicorn

    # Hugging Face Spaces use port 7860
    uvicorn.run(app, port=7860, host="0.0.0.0")