# File: web/main.py
# Uploaded by: pestdetectionai
# Sync from GitHub via hub-sync (commit a7d0f38, verified)
from fastapi import FastAPI, UploadFile, File, HTTPException, Request, Query
from fastapi.responses import JSONResponse, FileResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from ultralytics import YOLO
from huggingface_hub import hf_hub_download
from pathlib import Path
from datetime import datetime, timedelta
from collections import Counter, defaultdict
from dotenv import load_dotenv
import uuid
import cv2
import numpy as np
import os
import base64
import json
import firebase_admin
from firebase_admin import credentials, db
import cloudinary
import cloudinary.uploader
# =========================================================
# ENV
# =========================================================
# Load variables from a local .env file before any os.getenv() call below.
load_dotenv()
# =========================================================
# CONFIG
# =========================================================
# All paths are anchored at the directory that contains this file.
BASE_DIR = Path(__file__).resolve().parent
MODEL_DIR = BASE_DIR / "models"
# MODEL_PATH may be overridden wholesale; otherwise it is MODEL_DIR / MODEL_FILENAME.
MODEL_FILENAME = os.getenv("MODEL_FILENAME", "best.pt").strip()
MODEL_PATH = Path(os.getenv("MODEL_PATH", str(MODEL_DIR / MODEL_FILENAME))).resolve()
# Hugging Face repo/file used by ensure_model_available() when the local model is missing.
HF_MODEL_REPO = os.getenv(
"HF_MODEL_REPO",
"underdogquality/yolo11s-pest-detection"
).strip()
HF_MODEL_FILE = os.getenv(
"HF_MODEL_FILE",
MODEL_FILENAME
).strip()
# An empty token is normalised to None.
HF_TOKEN = os.getenv("HF_TOKEN", "").strip() or None
# Truthy spellings accepted: 1 / true / yes / on (case-insensitive).
AUTO_DOWNLOAD_MODEL = os.getenv(
"AUTO_DOWNLOAD_MODEL",
"true"
).strip().lower() in {"1", "true", "yes", "on"}
# Working directories; each one is also mounted as a static route further down.
UPLOAD_DIR = BASE_DIR / "uploads"
RESULT_DIR = BASE_DIR / "results"
DEBUG_DIR = BASE_DIR / "debug"
WEB_DIR = BASE_DIR / "web"
# Create the directories at import time so the StaticFiles mounts find them.
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
RESULT_DIR.mkdir(parents=True, exist_ok=True)
DEBUG_DIR.mkdir(parents=True, exist_ok=True)
WEB_DIR.mkdir(parents=True, exist_ok=True)
# Upload validation compares the lower-cased file suffix against this set.
ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
# When set, overrides the request-derived base URL in get_base_url().
APP_PUBLIC_BASE_URL = os.getenv("APP_PUBLIC_BASE_URL", "").strip()
# Firebase: the database URL plus either a base64-encoded service-account JSON
# or a path to the service-account file (the base64 variant is checked first).
FIREBASE_DATABASE_URL = os.getenv("FIREBASE_DATABASE_URL", "").strip()
FIREBASE_SERVICE_ACCOUNT_PATH = os.getenv("FIREBASE_SERVICE_ACCOUNT_PATH", "").strip()
FIREBASE_SERVICE_ACCOUNT_JSON_B64 = os.getenv("FIREBASE_SERVICE_ACCOUNT_JSON_B64", "").strip()
# Database reference path under which analysis logs are pushed and read.
FIREBASE_LOGS_PATH = "/api/analyze/logs"
# Cloudinary: all three credentials must be non-empty for uploads to be enabled.
CLOUDINARY_CLOUD_NAME = os.getenv("CLOUDINARY_CLOUD_NAME", "").strip()
CLOUDINARY_API_KEY = os.getenv("CLOUDINARY_API_KEY", "").strip()
CLOUDINARY_API_SECRET = os.getenv("CLOUDINARY_API_SECRET", "").strip()
CLOUDINARY_FOLDER = os.getenv("CLOUDINARY_FOLDER", "smart-pest-detection").strip()
# Passed to model.predict() as conf / iou / imgsz respectively.
YOLO_CONFIDENCE = 0.08
YOLO_IOU = 0.40
YOLO_IMAGE_SIZE = 1280
# filter_bad_yolo_boxes(): drop any box covering more than this fraction of the image.
MAX_YOLO_BOX_AREA_RATIO = 0.10
# filter_bad_yolo_boxes(): also drop boxes that are both below this confidence
# and above this area fraction.
LOW_CONF_LARGE_BOX_CONF = 0.20
LOW_CONF_LARGE_BOX_AREA_RATIO = 0.040
# NOTE(review): presumably toggles the visual fallback counter — confirm where it is read.
VISUAL_COUNTER_ENABLED = True
# Main duplicate rule: a fallback ("unknown_pest") box is removed when its
# intersection with a YOLO box covers at least this fraction of the smaller
# of the two boxes.
UNKNOWN_OVERLAP_WITH_YOLO = 0.45
# IoU threshold for the final class-agnostic NMS over YOLO + fallback boxes.
FINAL_NMS_IOU = 0.10
# Drawing colours handed to OpenCV drawing calls (BGR channel order).
GREEN = (0, 255, 0)
ORANGE = (0, 165, 255)
BLACK = (0, 0, 0)
# =========================================================
# APP INIT
# =========================================================
# The ASGI application object; routes and startup hooks below attach to it.
app = FastAPI(
title="Smart Pest Trap Detection API",
description="YOLO pest identification + visual pest counter + Cloudinary storage + Firebase logs + static website.",
version="11.0.0"
)
# Serve the working directories as static files so that the local image URLs
# built by get_local_image_urls() resolve, and expose the web assets under /web.
app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")
app.mount("/results", StaticFiles(directory=str(RESULT_DIR)), name="results")
app.mount("/debug", StaticFiles(directory=str(DEBUG_DIR)), name="debug")
app.mount("/web", StaticFiles(directory=str(WEB_DIR)), name="web")
# =========================================================
# GLOBALS
# =========================================================
# Populated by load_model() at startup; None until then.
model = None
# Set by init_firebase() / init_cloudinary(); helpers check these before calling the services.
firebase_ready = False
cloudinary_ready = False
# =========================================================
# STARTUP
# =========================================================
@app.on_event("startup")
def startup():
    """Run the boot sequence: load the model, then set up Firebase and Cloudinary."""
    for boot_step in (load_model, init_firebase, init_cloudinary):
        boot_step()
def ensure_model_available():
    """
    Make sure the model weights exist at MODEL_PATH, downloading them if allowed.

    If a non-empty file already exists at MODEL_PATH nothing is downloaded.
    Otherwise, when AUTO_DOWNLOAD_MODEL is enabled, the file HF_MODEL_FILE is
    fetched from the Hugging Face repo HF_MODEL_REPO into MODEL_PATH's directory
    and copied to MODEL_PATH if it landed under a different name.

    Raises:
        RuntimeError: if the model is missing and auto-download is disabled, or
            if the download fails or leaves a missing/empty file. Download
            failures keep the underlying exception as ``__cause__``.
    """
    # Local import: only the download path needs it.
    import shutil

    MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
    if MODEL_PATH.exists() and MODEL_PATH.stat().st_size > 0:
        print("====================================")
        print("[MODEL] Local model found")
        print(f"[MODEL] Path: {MODEL_PATH}")
        print("====================================")
        return
    if not AUTO_DOWNLOAD_MODEL:
        raise RuntimeError(
            f"Model not found: {MODEL_PATH}\n"
            "AUTO_DOWNLOAD_MODEL=false, so boot download is disabled."
        )
    print("====================================")
    print("[MODEL] Local model not found")
    print("[MODEL] Downloading model from Hugging Face...")
    print(f"[MODEL] Repo: {HF_MODEL_REPO}")
    print(f"[MODEL] File: {HF_MODEL_FILE}")
    print(f"[MODEL] Save to: {MODEL_PATH.parent}")
    print("====================================")
    try:
        downloaded_path = hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=HF_MODEL_FILE,
            local_dir=str(MODEL_PATH.parent),
            local_dir_use_symlinks=False,
            token=HF_TOKEN
        )
        downloaded_path = Path(downloaded_path).resolve()
        if downloaded_path != MODEL_PATH and downloaded_path.exists():
            # Stream the copy instead of loading the whole weights file into memory.
            shutil.copyfile(downloaded_path, MODEL_PATH)
        if not MODEL_PATH.exists() or MODEL_PATH.stat().st_size <= 0:
            raise RuntimeError(f"Downloaded model is missing or empty: {MODEL_PATH}")
        print("====================================")
        print("[MODEL] Download complete")
        print(f"[MODEL] Path: {MODEL_PATH}")
        print("====================================")
    except Exception as e:
        # Re-wrap with context for the operator, but keep the original traceback chained.
        raise RuntimeError(
            f"Model download failed.\n"
            f"Repo: {HF_MODEL_REPO}\n"
            f"File: {HF_MODEL_FILE}\n"
            f"Target: {MODEL_PATH}\n"
            f"Error: {e}"
        ) from e
def load_model():
    """Ensure the weights file is present, then load it into the module-level ``model``."""
    global model
    ensure_model_available()
    for line in (
        "====================================",
        "[MODEL] Loading pest detection model",
        f"[MODEL] Path: {MODEL_PATH}",
        "====================================",
    ):
        print(line)
    model = YOLO(str(MODEL_PATH))
    print("[MODEL] Loaded successfully")
def init_firebase():
    """
    Initialise the Firebase Admin SDK and record the outcome in ``firebase_ready``.

    Credentials come from FIREBASE_SERVICE_ACCOUNT_JSON_B64 when set, otherwise
    from the file named by FIREBASE_SERVICE_ACCOUNT_PATH (relative paths are
    resolved against BASE_DIR). Every failure is logged and leaves
    ``firebase_ready`` False; nothing is raised.
    """
    global firebase_ready
    # Pessimistic default; only the success paths below flip it to True.
    firebase_ready = False
    if not FIREBASE_DATABASE_URL:
        print("[FIREBASE] Disabled: FIREBASE_DATABASE_URL is missing")
        return
    try:
        if firebase_admin._apps:
            firebase_ready = True
            print("[FIREBASE] Already initialized")
            return
        if FIREBASE_SERVICE_ACCOUNT_JSON_B64:
            raw_json = base64.b64decode(FIREBASE_SERVICE_ACCOUNT_JSON_B64).decode("utf-8")
            cred = credentials.Certificate(json.loads(raw_json))
            print("[FIREBASE] Using FIREBASE_SERVICE_ACCOUNT_JSON_B64")
        elif FIREBASE_SERVICE_ACCOUNT_PATH:
            key_file = Path(FIREBASE_SERVICE_ACCOUNT_PATH)
            if not key_file.is_absolute():
                key_file = BASE_DIR / key_file
            if not key_file.exists():
                print(f"[FIREBASE] Service account file not found: {key_file}")
                return
            cred = credentials.Certificate(str(key_file))
            print(f"[FIREBASE] Using service account file: {key_file}")
        else:
            print("[FIREBASE] Disabled: service account is missing")
            return
        firebase_admin.initialize_app(cred, {"databaseURL": FIREBASE_DATABASE_URL})
        firebase_ready = True
        print("[FIREBASE] Initialized successfully")
    except Exception as e:
        firebase_ready = False
        print(f"[FIREBASE] Init failed: {e}")
def init_cloudinary():
    """
    Configure the Cloudinary SDK and record the outcome in ``cloudinary_ready``.

    Requires cloud name, API key and API secret to all be non-empty. Failures
    are logged and leave ``cloudinary_ready`` False; nothing is raised.
    """
    global cloudinary_ready
    cloudinary_ready = False
    have_all_credentials = all(
        (CLOUDINARY_CLOUD_NAME, CLOUDINARY_API_KEY, CLOUDINARY_API_SECRET)
    )
    if not have_all_credentials:
        print("[CLOUDINARY] Disabled: missing CLOUDINARY_CLOUD_NAME/API_KEY/API_SECRET")
        return
    try:
        cloudinary.config(
            cloud_name=CLOUDINARY_CLOUD_NAME,
            api_key=CLOUDINARY_API_KEY,
            api_secret=CLOUDINARY_API_SECRET,
            secure=True,
        )
        cloudinary_ready = True
        print("[CLOUDINARY] Initialized successfully")
    except Exception as e:
        cloudinary_ready = False
        print(f"[CLOUDINARY] Init failed: {e}")
# =========================================================
# BASIC HELPERS
# =========================================================
def now_dt():
    """Return the current local time as a naive ``datetime``."""
    current = datetime.now()
    return current
def now_string():
    """Return the current time formatted as ``YYYY-MM-DD HH:MM``."""
    current = now_dt()
    return current.strftime("%Y-%m-%d %H:%M")
def now_iso():
    """Return the current time as an ISO-8601 string with second precision."""
    current = now_dt()
    return current.isoformat(timespec="seconds")
def now_timestamp_ms():
    """Return the current time as integer milliseconds since the Unix epoch."""
    epoch_seconds = now_dt().timestamp()
    return int(epoch_seconds * 1000)
def get_base_url(request: Request):
    """
    Return the public base URL without a trailing slash.

    APP_PUBLIC_BASE_URL wins when configured; otherwise the base URL of the
    incoming request is used.
    """
    base = APP_PUBLIC_BASE_URL or str(request.base_url)
    return base.rstrip("/")
def validate_image_file(file: UploadFile):
    """
    Check the upload's file extension and return it lower-cased.

    Raises:
        HTTPException: 400 when the extension is not in ALLOWED_EXTENSIONS.
    """
    suffix = Path(file.filename or "").suffix.lower()
    if suffix in ALLOWED_EXTENSIONS:
        return suffix
    raise HTTPException(
        status_code=400,
        detail=f"Invalid image type. Allowed: {', '.join(sorted(ALLOWED_EXTENSIONS))}"
    )
async def save_upload(file: UploadFile, ext: str) -> Path:
    """
    Persist an uploaded file under UPLOAD_DIR with a unique name and return its path.

    The name is ``<YYYYmmdd_HHMMSS>_<uuid4 hex><ext>``.

    Raises:
        HTTPException: 400 when the upload is empty, 500 when the file is
            missing or empty after writing.
    """
    UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    token = uuid.uuid4().hex
    save_path = UPLOAD_DIR / f"{stamp}_{token}{ext}"

    payload = await file.read()
    if not payload:
        raise HTTPException(status_code=400, detail="Uploaded image is empty.")
    save_path.write_bytes(payload)

    written_ok = save_path.exists() and save_path.stat().st_size > 0
    if not written_ok:
        raise HTTPException(
            status_code=500,
            detail=f"Upload save failed: {save_path}"
        )
    return save_path
def safe_label(label: str):
    """Return ``label`` with underscores turned into spaces and outer whitespace stripped."""
    spaced = label.replace("_", " ")
    return spaced.strip()
def get_box(det):
    """Return a detection's box as a ``[x1, y1, x2, y2]`` list of floats."""
    coords = det["box"]
    return [float(coords[key]) for key in ("x1", "y1", "x2", "y2")]
def box_area(box):
    """Return the area of ``[x1, y1, x2, y2]``; inverted or degenerate boxes give 0."""
    left, top, right, bottom = box
    span_x = max(0.0, right - left)
    span_y = max(0.0, bottom - top)
    return span_x * span_y
def clamp_box(box, width, height):
    """
    Clamp a box to the image bounds and put its corners in ascending order.

    X values are limited to ``[0, width - 1]`` and Y values to
    ``[0, height - 1]``; swapped corners are exchanged afterwards.
    Returns ``[x1, y1, x2, y2]``.
    """
    def bound(value, size):
        # Same expression for all four coordinates.
        return max(0, min(float(value), size - 1))

    left, top, right, bottom = box
    left, right = bound(left, width), bound(right, width)
    top, bottom = bound(top, height), bound(bottom, height)
    if right < left:
        left, right = right, left
    if bottom < top:
        top, bottom = bottom, top
    return [left, top, right, bottom]
def expand_box(box, pad, width, height):
    """Grow ``box`` by ``pad`` pixels on every side, then clamp it to the image."""
    left, top, right, bottom = box
    padded = [left - pad, top - pad, right + pad, bottom + pad]
    return clamp_box(padded, width, height)
def iou(box_a, box_b):
    """Return intersection-over-union of two ``[x1, y1, x2, y2]`` boxes (0.0 when the union is empty)."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b

    # Width and height of the overlapping rectangle, floored at zero.
    overlap_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = overlap_w * overlap_h

    union = box_area(box_a) + box_area(box_b) - inter
    return inter / union if union > 0 else 0.0
def overlap_ratio_small(box_a, box_b):
    """
    Return the intersection area divided by the area of the smaller box.

    A result of 1.0 means the smaller box lies entirely inside the other;
    0.0 is returned when the smaller box has no area.
    """
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b

    overlap_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = overlap_w * overlap_h

    smaller = min(box_area(box_a), box_area(box_b))
    return inter / smaller if smaller > 0 else 0.0
def nms_detections(detections, iou_threshold=0.10, class_aware=False):
    """
    Greedy non-maximum suppression over detection dicts.

    Detections are visited from highest to lowest ``confidence``; one is kept
    unless its IoU with an already-kept detection exceeds ``iou_threshold``.
    With ``class_aware=True`` only detections of the same ``type`` suppress
    each other. Returns the kept detections, highest confidence first.
    """
    if not detections:
        return []

    def suppressed_by(winner, candidate):
        overlap = iou(get_box(winner), get_box(candidate))
        if class_aware:
            return winner["type"] == candidate["type"] and overlap > iou_threshold
        return overlap > iou_threshold

    ranked = sorted(
        detections,
        key=lambda d: float(d.get("confidence", 0)),
        reverse=True,
    )
    kept = []
    for candidate in ranked:
        if any(suppressed_by(winner, candidate) for winner in kept):
            continue
        kept.append(candidate)
    return kept
# =========================================================
# CONFIDENCE HELPERS
# =========================================================
def compute_avg_confidence(item: dict) -> float:
    """
    Return the mean of the positive detection confidences in ``item``, rounded to 4 places.

    Entries whose confidence is missing, non-numeric or not positive are
    ignored; 0.0 is returned when nothing usable remains.
    """
    scores = []
    for det in item.get("detections", []) or []:
        try:
            value = float(det.get("confidence", 0) or 0)
        except Exception:
            # Best effort: malformed entries are skipped.
            continue
        if value > 0:
            scores.append(value)
    if not scores:
        return 0.0
    return round(sum(scores) / len(scores), 4)
def compute_top_confidence(item: dict) -> float:
    """
    Return the highest positive detection confidence in ``item``, rounded to 4 places.

    Entries whose confidence is missing, non-numeric or not positive are
    ignored; 0.0 is returned when nothing usable remains.
    """
    scores = []
    for det in item.get("detections", []) or []:
        try:
            value = float(det.get("confidence", 0) or 0)
        except Exception:
            # Best effort: malformed entries are skipped.
            continue
        if value > 0:
            scores.append(value)
    if not scores:
        return 0.0
    return round(max(scores), 4)
# =========================================================
# CLOUDINARY HELPERS
# =========================================================
def upload_image_to_cloudinary(file_path: Path, folder_name: str, public_id_prefix: str):
    """
    Upload one image file to Cloudinary and return a summary of the result.

    The asset goes to ``<CLOUDINARY_FOLDER>/<folder_name>`` with public id
    ``<public_id_prefix>_<file stem>`` and overwrites an existing asset of the
    same id. Returns a dict of selected response fields, or None when
    Cloudinary is not ready, the file is missing/empty, or the upload fails
    (each case is logged).
    """
    if not cloudinary_ready:
        print("[CLOUDINARY] Skipped: cloudinary_ready=False")
        return None
    if not file_path:
        print("[CLOUDINARY] Skipped: file_path is None")
        return None

    file_path = Path(file_path)
    if not file_path.exists():
        print(f"[CLOUDINARY] Skipped: file does not exist: {file_path}")
        return None
    if file_path.stat().st_size <= 0:
        print(f"[CLOUDINARY] Skipped: file is empty: {file_path}")
        return None

    reported_fields = (
        "secure_url",
        "url",
        "public_id",
        "asset_id",
        "format",
        "bytes",
        "width",
        "height",
    )
    try:
        result = cloudinary.uploader.upload(
            str(file_path.resolve()),
            folder=f"{CLOUDINARY_FOLDER}/{folder_name}",
            public_id=f"{public_id_prefix}_{file_path.stem}",
            resource_type="image",
            overwrite=True,
        )
        return {field: result.get(field) for field in reported_fields}
    except Exception as e:
        print(f"[CLOUDINARY] Upload failed for {file_path}: {e}")
        return None
def upload_analysis_images_to_cloudinary(uploaded_path: Path, result_image_path: Path, debug_mask_path: Path | None):
    """
    Upload the original, annotated and (optional) debug images for one analysis.

    Assets are grouped under a ``YYYY/MM/DD`` folder for the current date.
    Returns ``(original, annotated, debug)``; each element is the dict from
    ``upload_image_to_cloudinary`` or None.
    """
    day_folder = datetime.now().strftime("%Y/%m/%d")
    original_cloud = upload_image_to_cloudinary(
        uploaded_path, f"{day_folder}/original", "original"
    )
    annotated_cloud = upload_image_to_cloudinary(
        result_image_path, f"{day_folder}/annotated", "annotated"
    )
    if debug_mask_path:
        debug_cloud = upload_image_to_cloudinary(
            debug_mask_path, f"{day_folder}/debug", "debug"
        )
    else:
        debug_cloud = None
    return original_cloud, annotated_cloud, debug_cloud
# =========================================================
# IMAGE PREPROCESSING
# =========================================================
def remove_colored_markup_if_present(image):
    """
    Inpaint over red and green markup found in a BGR image.

    Targets red user marks and green boxes/labels from a previously annotated
    frame that was uploaded again. When fewer than 50 pixels match the colour
    ranges the image is returned untouched. Clean original frames remain the
    better input.
    """
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

    # Red wraps around the hue axis, so it needs a low and a high range.
    red_low = cv2.inRange(hsv, np.array([0, 70, 70]), np.array([14, 255, 255]))
    red_high = cv2.inRange(hsv, np.array([165, 70, 70]), np.array([180, 255, 255]))
    green = cv2.inRange(hsv, np.array([35, 50, 50]), np.array([95, 255, 255]))

    markup = cv2.bitwise_or(cv2.bitwise_or(red_low, red_high), green)
    if cv2.countNonZero(markup) < 50:
        return image

    # Grow the mask slightly so the edges of the strokes get inpainted too.
    markup = cv2.dilate(markup, np.ones((3, 3), np.uint8), iterations=1)
    return cv2.inpaint(image, markup, 5, cv2.INPAINT_TELEA)
def enhance_image(image):
    """
    Return a contrast-boosted, sharpened copy of a BGR image.

    CLAHE is applied to the L channel in LAB space, followed by an unsharp
    mask (1.8 * enhanced - 0.8 * Gaussian blur).
    """
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(image, cv2.COLOR_BGR2LAB))
    equalizer = cv2.createCLAHE(clipLimit=3.5, tileGridSize=(8, 8))
    boosted_lab = cv2.merge((equalizer.apply(lightness), chan_a, chan_b))
    enhanced = cv2.cvtColor(boosted_lab, cv2.COLOR_LAB2BGR)

    softened = cv2.GaussianBlur(enhanced, (0, 0), 1.0)
    return cv2.addWeighted(enhanced, 1.8, softened, -0.8, 0)
def find_trap_floor_crop(image):
    """
    Crop the image to a large bright region found by Otsu thresholding.

    The ten largest external contours of the thresholded image are checked in
    order of area; the first whose bounding box covers more than 16% of the
    image and more than 25% of each dimension is used, padded by 2 px. If none
    qualifies, a fixed central window (16-82% of the width, 24-82% of the
    height) is used instead.

    Returns:
        ``(crop, x_offset, y_offset)`` where the offsets locate the crop's
        top-left corner in the input image.
    """
    h, w = image.shape[:2]
    image_area = w * h

    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    smoothed = cv2.GaussianBlur(gray, (7, 7), 0)
    _, binary = cv2.threshold(
        smoothed, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
    )
    contours, _ = cv2.findContours(
        binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    pad = 2
    largest_first = sorted(contours, key=cv2.contourArea, reverse=True)
    for cnt in largest_first[:10]:
        x, y, cw, ch = cv2.boundingRect(cnt)
        big_enough = cw * ch > image_area * 0.16 and cw > w * 0.25 and ch > h * 0.25
        if not big_enough:
            continue
        x1 = max(0, x - pad)
        y1 = max(0, y - pad)
        x2 = min(w, x + cw + pad)
        y2 = min(h, y + ch + pad)
        return image[y1:y2, x1:x2].copy(), x1, y1

    # Fallback: fixed central window.
    x1, x2 = int(w * 0.16), int(w * 0.82)
    y1, y2 = int(h * 0.24), int(h * 0.82)
    return image[y1:y2, x1:x2].copy(), x1, y1
# =========================================================
# YOLO DETECTION
# =========================================================
def yolo_predict(image, offset_x=0, offset_y=0, source_name="image"):
    """
    Run the YOLO model on ``image`` and return detections as plain dicts.

    Args:
        image: image array handed to ``model.predict``.
        offset_x, offset_y: added to every box coordinate so boxes found on a
            crop or tile can be reported in the parent image's coordinates.
        source_name: stored in each detection's ``source`` field.

    Returns:
        A list of ``{"type", "confidence", "source", "box"}`` dicts; empty
        when the model returns nothing.
    """
    results = model.predict(
        source=image,
        imgsz=YOLO_IMAGE_SIZE,
        conf=YOLO_CONFIDENCE,
        iou=YOLO_IOU,
        verbose=False,
    )
    if not results:
        return []

    first = results[0]
    if first.boxes is None or len(first.boxes) == 0:
        return []

    class_names = first.names
    found = []
    for box in first.boxes:
        cls_id = int(box.cls[0].item())
        confidence = float(box.conf[0].item())
        label = safe_label(class_names.get(cls_id, str(cls_id)))
        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(float).tolist()
        found.append({
            "type": label,
            "confidence": round(confidence, 4),
            "source": source_name,
            "box": {
                "x1": round(x1 + offset_x, 2),
                "y1": round(y1 + offset_y, 2),
                "x2": round(x2 + offset_x, 2),
                "y2": round(y2 + offset_y, 2),
            },
        })
    return found
def _tile_starts(length, tile_size, step):
    """Return tile start offsets along one axis, always ending flush with the far edge."""
    # range(0, max(1, ...), step) always yields at least 0, so the list is never empty.
    starts = list(range(0, max(1, length - tile_size + 1), step))
    last = max(0, length - tile_size)
    if starts[-1] != last:
        starts.append(last)
    return starts


def yolo_tiled(image, offset_x=0, offset_y=0, tile_size=448, overlap=220):
    """
    Run YOLO over overlapping square tiles of ``image`` and pool the detections.

    Args:
        image: image array; tiles are sliced from it and copied.
        offset_x, offset_y: position of ``image`` in the parent image; combined
            with each tile's position so boxes come back in parent coordinates.
        tile_size: side length of each tile in pixels (default 448).
        overlap: pixels shared by neighbouring tiles (default 220).

    Returns:
        The concatenated detections of all tiles, each with source ``"tile"``.
        Overlapping tiles can report the same object more than once; callers
        are expected to de-duplicate.

    Raises:
        ValueError: if ``tile_size`` is not positive or ``overlap`` is not
            smaller than ``tile_size``.
    """
    if tile_size <= 0:
        raise ValueError("tile_size must be positive")
    step = tile_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than tile_size")

    h, w = image.shape[:2]
    detections = []
    for y in _tile_starts(h, tile_size, step):
        for x in _tile_starts(w, tile_size, step):
            tile = image[y:y + tile_size, x:x + tile_size].copy()
            detections.extend(
                yolo_predict(
                    tile,
                    offset_x=offset_x + x,
                    offset_y=offset_y + y,
                    source_name="tile"
                )
            )
    return detections
def filter_bad_yolo_boxes(detections, image_width, image_height):
    """
    Drop YOLO detections whose boxes are implausible for a single pest.

    A detection is removed when its box is under 4 px wide or tall, covers
    more than MAX_YOLO_BOX_AREA_RATIO of the image, or is both below
    LOW_CONF_LARGE_BOX_CONF and above LOW_CONF_LARGE_BOX_AREA_RATIO.
    """
    image_area = image_width * image_height
    survivors = []
    for det in detections:
        box = get_box(det)
        area_ratio = box_area(box) / max(image_area, 1)
        conf = det["confidence"]
        box_w, box_h = box[2] - box[0], box[3] - box[1]

        too_small = box_w < 4 or box_h < 4
        too_large = area_ratio > MAX_YOLO_BOX_AREA_RATIO
        weak_and_large = (
            conf < LOW_CONF_LARGE_BOX_CONF
            and area_ratio > LOW_CONF_LARGE_BOX_AREA_RATIO
        )
        if too_small or too_large or weak_and_large:
            continue
        survivors.append(det)
    return survivors
# =========================================================
# HARD VISUAL COUNTER
# =========================================================
def hard_visual_counter(original, floor_crop, floor_x, floor_y):
    """
    Find dark blobs on the trap floor and report them as ``unknown_pest`` detections.

    This is a fallback counter that does not use YOLO: the floor crop is
    thresholded with a fixed grey level, cleaned with morphology, and each
    remaining connected component is screened by size, shape, contrast,
    darkness and edge density.

    Args:
        original: full image; only its shape is used, to clamp output boxes.
        floor_crop: crop of ``original`` that is analysed.
        floor_x, floor_y: offset of ``floor_crop``'s top-left corner inside
            ``original``.

    Returns:
        ``(detections, mask)``: detection dicts with boxes in ``original``
        coordinates, and the binary mask computed on ``floor_crop``.
    """
    detections = []
    crop_h, crop_w = floor_crop.shape[:2]
    original_h, original_w = original.shape[:2]
    gray = cv2.cvtColor(floor_crop, cv2.COLOR_BGR2GRAY)
    # A fixed threshold is intentional: per the original author, a threshold
    # derived dynamically from the floor tends to include stains.
    # The value can be overridden through the VISUAL_DARK_THRESHOLD env var.
    threshold_value = int(os.getenv("VISUAL_DARK_THRESHOLD", "112"))
    # Inverse binary threshold: pixels at or below the threshold become 255.
    mask = cv2.threshold(
        gray,
        threshold_value,
        255,
        cv2.THRESH_BINARY_INV
    )[1]
    # Remove crop edge artifacts by blanking an 8 px frame around the mask.
    border = 8
    mask[:border, :] = 0
    mask[-border:, :] = 0
    mask[:, :border] = 0
    mask[:, -border:] = 0
    # Clean and merge wings/body: a small opening removes specks, then a
    # larger closing joins nearby fragments.
    k2 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2, 2))
    k5 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, k2, iterations=1)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, k5, iterations=1)
    contours, _ = cv2.findContours(
        mask,
        cv2.RETR_EXTERNAL,
        cv2.CHAIN_APPROX_SIMPLE
    )
    for cnt in contours:
        x, y, bw, bh = cv2.boundingRect(cnt)
        contour_area = cv2.contourArea(cnt)
        box_area_local = bw * bh
        # Size and shape gates, all on the bounding box in crop pixels.
        if box_area_local < 120:
            continue
        if box_area_local > 6000:
            continue
        if bw < 7 or bh < 7:
            continue
        aspect = bw / max(bh, 1)
        if aspect < 0.20 or aspect > 3.80:
            continue
        roi = gray[y:y + bh, x:x + bw]
        if roi.size == 0:
            continue
        # Texture statistics of the grey-level patch inside the bounding box.
        contrast = float(np.std(roi))
        darkness = 255.0 - float(np.mean(roi))
        edges = cv2.Canny(roi, 10, 70)
        edge_density = cv2.countNonZero(edges) / max(1, bw * bh)
        # Real pests have strong darkness/contrast or wing/body edges.
        # Soft stains should fail this.
        if contrast < 18 and edge_density < 0.045:
            continue
        if darkness < 45 and contrast < 25:
            continue
        # Translate to full-image coordinates, pad by 5 px and clamp.
        x1 = floor_x + x
        y1 = floor_y + y
        x2 = floor_x + x + bw
        y2 = floor_y + y + bh
        x1, y1, x2, y2 = expand_box(
            [x1, y1, x2, y2],
            pad=5,
            width=original_w,
            height=original_h
        )
        # Heuristic pseudo-confidence: starts at 0.18, gains fixed bonuses for
        # stronger evidence, and is capped at 0.60.
        score = 0.18
        if contrast >= 25:
            score += 0.10
        if contrast >= 35:
            score += 0.10
        if edge_density >= 0.08:
            score += 0.08
        if darkness >= 90:
            score += 0.08
        if box_area_local >= 400:
            score += 0.04
        score = min(score, 0.60)
        detections.append({
            "type": "unknown_pest",
            "confidence": round(score, 4),
            "source": "hard_visual_counter",
            # Raw measurements, in crop coordinates, kept for troubleshooting.
            "debug": {
                "threshold": threshold_value,
                "local_x": int(x),
                "local_y": int(y),
                "local_w": int(bw),
                "local_h": int(bh),
                "box_area": int(box_area_local),
                "contour_area": round(float(contour_area), 2),
                "contrast": round(contrast, 2),
                "darkness": round(darkness, 2),
                "edge_density": round(edge_density, 4)
            },
            "box": {
                "x1": round(x1, 2),
                "y1": round(y1, 2),
                "x2": round(x2, 2),
                "y2": round(y2, 2)
            }
        })
    # Collapse overlapping blobs (IoU above 0.08) into the highest-scoring one.
    detections = nms_detections(
        detections,
        iou_threshold=0.08,
        class_aware=False
    )
    return detections, mask
def remove_unknown_duplicates(yolo_detections, unknown_detections):
    """
    Return the fallback detections that do not duplicate a YOLO detection.

    A fallback box counts as a duplicate when its overlap with any YOLO box,
    measured relative to the smaller of the two, reaches
    UNKNOWN_OVERLAP_WITH_YOLO.
    """
    unique = []
    for unknown in unknown_detections:
        unknown_box = get_box(unknown)
        is_duplicate = any(
            overlap_ratio_small(unknown_box, get_box(known)) >= UNKNOWN_OVERLAP_WITH_YOLO
            for known in yolo_detections
        )
        if not is_duplicate:
            unique.append(unknown)
    return unique
# =========================================================
# FULL DETECTION PIPELINE
# =========================================================
def run_detection_pipeline(image_path: Path):
    """
    Run the full detection pipeline on one saved image.

    Steps:
      1. Read the image and inpaint over any red/green markup.
      2. Locate the trap-floor crop and build an enhanced copy of it.
      3. Run YOLO on the full image, the crop, the enhanced crop, and on
         tiles of the enhanced crop; clamp, filter and class-aware-NMS the
         pooled boxes.
      4. If VISUAL_COUNTER_ENABLED, run the visual dark-blob counter on the
         crop and drop the blobs that duplicate a YOLO box.
      5. Merge both sets and apply a final class-agnostic NMS.

    Returns:
        ``(final_detections, original, visual_mask)``. ``original`` is the
        image after markup removal; ``visual_mask`` is None when the visual
        counter is disabled.

    Raises:
        HTTPException: 500 if the file does not exist, 400 if it cannot be
            decoded as an image.
    """
    image_path = Path(image_path)
    if not image_path.exists():
        raise HTTPException(
            status_code=500,
            detail=f"Uploaded image file does not exist before processing: {image_path}"
        )
    original = cv2.imread(str(image_path))
    if original is None:
        raise HTTPException(status_code=400, detail="Unable to read uploaded image.")

    original = remove_colored_markup_if_present(original)
    h, w = original.shape[:2]
    floor_crop, floor_x, floor_y = find_trap_floor_crop(original)
    enhanced_floor = enhance_image(floor_crop)

    # Several YOLO passes over different views of the same scene.
    yolo_detections = []
    yolo_detections.extend(
        yolo_predict(original, offset_x=0, offset_y=0, source_name="original")
    )
    yolo_detections.extend(
        yolo_predict(floor_crop, offset_x=floor_x, offset_y=floor_y, source_name="trap_floor")
    )
    yolo_detections.extend(
        yolo_predict(enhanced_floor, offset_x=floor_x, offset_y=floor_y, source_name="enhanced_floor")
    )
    yolo_detections.extend(
        yolo_tiled(enhanced_floor, offset_x=floor_x, offset_y=floor_y)
    )

    # Clamp every box to the image before the area-based filters run.
    for det in yolo_detections:
        b = det["box"]
        x1, y1, x2, y2 = clamp_box(
            [b["x1"], b["y1"], b["x2"], b["y2"]],
            width=w,
            height=h
        )
        det["box"] = {
            "x1": round(x1, 2),
            "y1": round(y1, 2),
            "x2": round(x2, 2),
            "y2": round(y2, 2)
        }

    yolo_detections = filter_bad_yolo_boxes(
        yolo_detections,
        image_width=w,
        image_height=h
    )
    yolo_detections = nms_detections(
        yolo_detections,
        iou_threshold=0.20,
        class_aware=True
    )

    if VISUAL_COUNTER_ENABLED:
        visual_detections, visual_mask = hard_visual_counter(
            original,
            floor_crop,
            floor_x,
            floor_y
        )
        visual_detections = remove_unknown_duplicates(
            yolo_detections,
            visual_detections
        )
    else:
        # Honour the module-level switch: YOLO-only output, no debug mask.
        visual_detections, visual_mask = [], None

    final_detections = nms_detections(
        [*yolo_detections, *visual_detections],
        iou_threshold=FINAL_NMS_IOU,
        class_aware=False
    )
    return final_detections, original, visual_mask
# =========================================================
# DRAWING AND RESPONSE HELPERS
# =========================================================
def draw_annotated_image(original_image, image_path: Path, detections):
    """
    Draw boxes and labels on a copy of the image and save it under RESULT_DIR.

    ``unknown_pest`` detections are drawn in orange, everything else in
    green. The output file is ``result_<input stem>.jpg``.

    Returns:
        Path of the written image.

    Raises:
        HTTPException: 500 if the image could not be written.
    """
    RESULT_DIR.mkdir(parents=True, exist_ok=True)
    canvas = original_image.copy()

    # Label typography is the same for every detection.
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.55
    thickness = 2

    for det in detections:
        box = det["box"]
        label = det["type"]
        confidence = det["confidence"]
        x1, y1 = int(box["x1"]), int(box["y1"])
        x2, y2 = int(box["x2"]), int(box["y2"])
        color = ORANGE if label == "unknown_pest" else GREEN
        text = f"{label} {confidence:.2f}"

        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)

        (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)
        # Filled label background above the box, kept inside the image.
        label_top = max(y1 - text_h - 10, 0)
        label_bottom = max(y1, text_h + 12)
        label_right = min(x1 + text_w + 8, canvas.shape[1] - 1)
        cv2.rectangle(canvas, (x1, label_top), (label_right, label_bottom), color, -1)
        cv2.putText(
            canvas,
            text,
            (x1 + 4, max(y1 - 6, text_h + 4)),
            font,
            font_scale,
            BLACK,
            thickness,
            cv2.LINE_AA,
        )

    output_path = RESULT_DIR / f"result_{image_path.stem}.jpg"
    written = cv2.imwrite(str(output_path), canvas)
    if not written or not output_path.exists():
        raise HTTPException(
            status_code=500,
            detail=f"Failed to save annotated image: {output_path}"
        )
    return output_path
def save_debug_mask(image_path: Path, mask):
    """
    Save the debug mask as ``mask_<input stem>.jpg`` under DEBUG_DIR.

    Returns the written path, or None when there is no mask or the write
    fails (a failure is logged, not raised).
    """
    if mask is None:
        return None
    DEBUG_DIR.mkdir(parents=True, exist_ok=True)
    output_path = DEBUG_DIR / f"mask_{image_path.stem}.jpg"
    written = cv2.imwrite(str(output_path), mask)
    if written and output_path.exists():
        return output_path
    print(f"[DEBUG] Failed to save debug mask: {output_path}")
    return None
def build_summary(detections):
    """Return ``[{"type", "count"}, ...]`` per detection type, sorted by type name."""
    tally = Counter()
    for det in detections:
        tally[det["type"]] += 1
    summary = []
    for pest_type, count in sorted(tally.items()):
        summary.append({"type": pest_type, "count": count})
    return summary
def get_local_image_urls(request: Request, uploaded_path: Path, result_image_path: Path, debug_mask_path: Path | None):
    """
    Build the URLs under which this server exposes the analysis images.

    Returns ``(original_url, annotated_url, debug_mask_url)``; the last
    element is None when no debug mask was saved.
    """
    base_url = get_base_url(request)
    original_image_url = f"{base_url}/uploads/{uploaded_path.name}"
    annotated_image_url = f"{base_url}/results/{result_image_path.name}"
    if debug_mask_path is None:
        debug_mask_url = None
    else:
        debug_mask_url = f"{base_url}/debug/{debug_mask_path.name}"
    return original_image_url, annotated_image_url, debug_mask_url
# =========================================================
# FIREBASE LOG FUNCTIONS
# =========================================================
def firebase_logs_ref():
    """Return the Firebase database reference for FIREBASE_LOGS_PATH."""
    logs_ref = db.reference(FIREBASE_LOGS_PATH)
    return logs_ref
def save_analysis_log_to_firebase(log_payload: dict):
    """
    Push one analysis log to Firebase and return its generated key.

    ``log_payload`` is modified in place: ``id``, ``firebase_saved`` and
    ``firebase_path`` are set before it is written. Returns None without
    touching the payload when Firebase is not ready.
    """
    if not firebase_ready:
        return None
    new_ref = firebase_logs_ref().push()
    log_id = new_ref.key
    log_payload.update({
        "id": log_id,
        "firebase_saved": True,
        "firebase_path": f"{FIREBASE_LOGS_PATH}/{log_id}",
    })
    new_ref.set(log_payload)
    return log_id
def get_all_logs_from_firebase():
    """
    Fetch every analysis log stored under FIREBASE_LOGS_PATH.

    Non-dict children are skipped; each returned log has an ``id``, falling
    back to its database key when the stored record has none.

    Raises:
        HTTPException: 503 when Firebase is not initialised.
    """
    if not firebase_ready:
        raise HTTPException(
            status_code=503,
            detail="Firebase is not initialized. Check FIREBASE_DATABASE_URL and service account."
        )
    snapshot = firebase_logs_ref().get()
    if not snapshot:
        return []
    collected = []
    for key, entry in snapshot.items():
        if isinstance(entry, dict):
            entry["id"] = entry.get("id", key)
            collected.append(entry)
    return collected
def get_log_from_firebase(log_id: str):
    """
    Fetch one analysis log by its key.

    The returned record always has an ``id``, falling back to ``log_id``.

    Raises:
        HTTPException: 503 when Firebase is not initialised, 404 when no log
            exists under ``log_id``.
    """
    if not firebase_ready:
        raise HTTPException(
            status_code=503,
            detail="Firebase is not initialized. Check FIREBASE_DATABASE_URL and service account."
        )
    record = firebase_logs_ref().child(log_id).get()
    if not record:
        raise HTTPException(status_code=404, detail="Log not found")
    record["id"] = record.get("id", log_id)
    return record
def parse_date_filter(value: str | None, end_of_day=False):
if not value:
return None
try:
if len(value) == 10:
parsed = datetime.strptime(value, "%Y-%m-%d")
if end_of_day:
parsed = parsed.replace(hour=23, minute=59, second=59, microsecond=999000)
return parsed
return datetime.fromisoformat(value)
except Exception:
raise HTTPException(
status_code=400,
detail=f"Invalid date format: {value}. Use YYYY-MM-DD or ISO datetime."
)
def filter_logs(
    logs,
    pest_type=None,
    date_from=None,
    date_to=None,
    min_total=None,
    max_total=None,
    search=None
):
    """
    Return the logs that pass every supplied filter, in their original order.

    Args:
        logs: iterable of log dicts.
        pest_type: keep logs with a ``data`` row or detection of this type
            (case-insensitive, surrounding whitespace ignored).
        date_from, date_to: date strings understood by ``parse_date_filter``;
            ``date_to`` given as a plain date includes that whole day. Logs
            without ``timestamp_ms`` are not excluded by the date filters.
        min_total, max_total: inclusive bounds on the log's ``total``.
        search: case-insensitive substring looked up in the log's JSON dump.
    """
    date_from_dt = parse_date_filter(date_from, end_of_day=False)
    date_to_dt = parse_date_filter(date_to, end_of_day=True)

    def type_of(record):
        return str(record.get("type", "")).lower().strip()

    matches = []
    for item in logs:
        total = int(item.get("total", 0) or 0)
        below_min = min_total is not None and total < min_total
        above_max = max_total is not None and total > max_total
        if below_min or above_max:
            continue

        timestamp_ms = item.get("timestamp_ms")
        item_dt = datetime.fromtimestamp(int(timestamp_ms) / 1000) if timestamp_ms else None
        if date_from_dt and item_dt and item_dt < date_from_dt:
            continue
        if date_to_dt and item_dt and item_dt > date_to_dt:
            continue

        data = item.get("data", [])
        detections = item.get("detections", [])
        if pest_type:
            wanted = pest_type.lower().strip()
            in_summary = any(type_of(row) == wanted for row in data)
            in_detections = any(type_of(det) == wanted for det in detections)
            if not (in_summary or in_detections):
                continue

        if search:
            needle = search.lower().strip()
            if needle not in json.dumps(item, ensure_ascii=False).lower():
                continue

        matches.append(item)
    return matches
def paginate_items(items, page, page_size):
    """
    Slice ``items`` into one page and return it with paging metadata.

    A non-positive ``page`` becomes 1; a non-positive ``page_size`` becomes
    10 and anything above 100 is capped at 100. ``total_pages`` is at least
    1, and a page past the end yields an empty ``items`` list.
    """
    page = page if page > 0 else 1
    page_size = min(page_size if page_size > 0 else 10, 100)

    total_items = len(items)
    total_pages = max(1, (total_items + page_size - 1) // page_size)

    if page <= total_pages:
        offset = (page - 1) * page_size
        page_items = items[offset:offset + page_size]
    else:
        page_items = []

    return {
        "page": page,
        "page_size": page_size,
        "total_items": total_items,
        "total_pages": total_pages,
        "has_next": page < total_pages,
        "has_prev": page > 1,
        "items": page_items,
    }
def compact_log_item(item):
    """
    Return a trimmed copy of a log with only the fields used for listings.

    ``avg_confidence`` and ``top_confidence`` fall back to values computed
    from the log's detections when the stored record lacks them.
    """
    # Computed up front; used only when the stored field is absent.
    avg_fallback = compute_avg_confidence(item)
    top_fallback = compute_top_confidence(item)
    compact = {
        "id": item.get("id"),
        "datatime": item.get("datatime"),
        "timestamp_ms": item.get("timestamp_ms"),
        "total": item.get("total", 0),
        "data": item.get("data", []),
        "avg_confidence": item.get("avg_confidence", avg_fallback),
        "top_confidence": item.get("top_confidence", top_fallback),
        "annotated_image": item.get("annotated_image"),
        "original_image": item.get("original_image"),
        "debug_mask": item.get("debug_mask"),
        "cloudinary": item.get("cloudinary", {}),
    }
    return compact
# =========================================================
# DASHBOARD HELPERS
# =========================================================
def build_dashboard_data(logs):
    """
    Aggregate analysis logs into the payload used by the dashboard.

    Args:
        logs: iterable of log dicts; ``timestamp_ms``, ``total`` and ``data``
            are read from each.

    Returns:
        A dict with:
          * ``summary``: overall, today and last-7-days counts plus the ten
            most frequent pest types;
          * ``chart``: per-day totals for the last seven calendar days and
            per-hour totals for today;
          * ``live_camera_stream``: the newest log in compact form;
          * ``logs``: the ten newest logs in compact form.

    The "last 7 days" figures cover the same seven calendar days as the daily
    chart (today and the six days before it, starting at midnight), so the
    summary and the chart agree. Logs without ``timestamp_ms`` count towards
    the overall totals only.
    """
    logs = sorted(logs, key=lambda x: int(x.get("timestamp_ms", 0) or 0), reverse=True)

    # Take one clock reading so every "today" computation agrees, even if the
    # call straddles midnight.
    now = datetime.now()
    today = now.date()
    window_start = datetime.combine(today - timedelta(days=6), datetime.min.time())

    total_logs = len(logs)
    total_pests = 0
    today_logs_count = 0
    today_pests = 0
    last_7_days_logs_count = 0
    last_7_days_pests = 0
    pest_counter = Counter()
    daily_counter = defaultdict(int)
    hourly_today_counter = defaultdict(int)

    for item in logs:
        item_total = int(item.get("total", 0) or 0)
        total_pests += item_total

        for row in item.get("data", []):
            pest_counter[row.get("type", "unknown")] += int(row.get("count", 0) or 0)

        timestamp_ms = item.get("timestamp_ms")
        if not timestamp_ms:
            continue
        item_dt = datetime.fromtimestamp(int(timestamp_ms) / 1000)

        daily_counter[item_dt.strftime("%Y-%m-%d")] += item_total
        if item_dt.date() == today:
            today_logs_count += 1
            today_pests += item_total
            hourly_today_counter[item_dt.strftime("%H:00")] += item_total
        if item_dt >= window_start:
            last_7_days_logs_count += 1
            last_7_days_pests += item_total

    top_pests = [
        {"type": pest_type, "count": count}
        for pest_type, count in pest_counter.most_common(10)
    ]

    # Oldest day first, today last.
    daily_chart = []
    for days_back in range(6, -1, -1):
        key = (now - timedelta(days=days_back)).strftime("%Y-%m-%d")
        daily_chart.append({"date": key, "total": daily_counter.get(key, 0)})

    hourly_chart = []
    for hour in range(24):
        key = f"{hour:02d}:00"
        hourly_chart.append({"hour": key, "total": hourly_today_counter.get(key, 0)})

    latest_log = logs[0] if logs else None
    recent_logs = [compact_log_item(item) for item in logs[:10]]

    return {
        "summary": {
            "total_logs": total_logs,
            "total_pests": total_pests,
            "today_logs": today_logs_count,
            "today_pests": today_pests,
            "last_7_days_logs": last_7_days_logs_count,
            "last_7_days_pests": last_7_days_pests,
            "top_pests": top_pests
        },
        "chart": {
            "daily_last_7_days": daily_chart,
            "hourly_today": hourly_chart
        },
        "live_camera_stream": {
            "latest": compact_log_item(latest_log) if latest_log else None,
            "polling_route": "/api/live/latest",
            "note": "Use latest.annotated_image as the latest processed camera frame. Frontend can poll every 1 to 3 seconds."
        },
        "logs": recent_logs
    }
# =========================================================
# WEB UI ROUTES
# =========================================================
@app.get("/", include_in_schema=False)
def web_root():
    """Redirect requests for the site root to the dashboard UI at ``/ui``."""
    dashboard_url = "/ui"
    return RedirectResponse(url=dashboard_url)
@app.get("/ui", include_in_schema=False)
def ui_dashboard():
    """Serve the dashboard page (``index.html``) from ``WEB_DIR``.

    Raises:
        HTTPException: 404 when the HTML file does not exist on disk.
    """
    page = WEB_DIR / "index.html"
    if page.exists():
        return FileResponse(page)
    raise HTTPException(status_code=404, detail=f"Missing web file: {page}")
@app.get("/ui/logs", include_in_schema=False)
def ui_logs():
    """Serve the log list page (``logs.html``) from ``WEB_DIR``.

    Raises:
        HTTPException: 404 when the HTML file does not exist on disk.
    """
    page = WEB_DIR / "logs.html"
    if page.exists():
        return FileResponse(page)
    raise HTTPException(status_code=404, detail=f"Missing web file: {page}")
@app.get("/ui/logs/{log_id}", include_in_schema=False)
def ui_log_detail(log_id: str):
    """Serve the log detail page (``detail.html``) from ``WEB_DIR``.

    Args:
        log_id: Path parameter from the URL. It is not read here, so the same
            static page is returned for every id.
            NOTE(review): presumably the page resolves the id client-side
            from the URL -- confirm against detail.html.

    Raises:
        HTTPException: 404 when the HTML file does not exist on disk.
    """
    page = WEB_DIR / "detail.html"
    if page.exists():
        return FileResponse(page)
    raise HTTPException(status_code=404, detail=f"Missing web file: {page}")
@app.get("/ui/access", include_in_schema=False)
def ui_access():
    """Serve the access page (``access.html``) from ``WEB_DIR``.

    Raises:
        HTTPException: 404 when the HTML file does not exist on disk.
    """
    page = WEB_DIR / "access.html"
    if page.exists():
        return FileResponse(page)
    raise HTTPException(status_code=404, detail=f"Missing web file: {page}")
# =========================================================
# API ROUTES
# =========================================================
@app.get("/api/status")
def api_status():
    """Report that the API is up, along with its routes and configuration.

    The payload lists the public route paths, the Firebase/Cloudinary
    readiness flags, the model location and download settings, and the
    upload field name expected by the analyze endpoint.

    ``VISUAL_DARK_THRESHOLD`` is read from the environment on every call
    (default ``"112"``) and converted with ``int``; a non-numeric value
    raises ``ValueError``.
    """
    routes = {
        "ui_route": "/ui",
        "analyze_route": "/api/analyze",
        "logs_route": "/api/logs",
        "log_detail_route": "/api/logs/{id}",
        "dashboard_route": "/api/dashboard",
        "live_latest_route": "/api/live/latest",
    }
    integrations = {
        "firebase_ready": firebase_ready,
        "cloudinary_ready": cloudinary_ready,
    }
    model_info = {
        "model_path": str(MODEL_PATH),
        "model_exists": MODEL_PATH.exists(),
        "auto_download_model": AUTO_DOWNLOAD_MODEL,
        "hf_model_repo": HF_MODEL_REPO,
        "hf_model_file": HF_MODEL_FILE,
    }
    storage = {
        "firebase_logs_path": FIREBASE_LOGS_PATH,
        "cloudinary_folder": CLOUDINARY_FOLDER,
    }
    dark_threshold = int(os.getenv("VISUAL_DARK_THRESHOLD", "112"))

    # Merge order is deliberate: it reproduces the key order of the response.
    return {
        "message": "Smart Pest Trap Detection API is running",
        **routes,
        **integrations,
        **model_info,
        **storage,
        "visual_dark_threshold": dark_threshold,
        "field_name": "image",
    }
@app.post("/api/analyze")
async def analyze_pest(request: Request, image: UploadFile = File(...)):
    """Analyze an uploaded image for pests and record the result.

    Steps, in order:
      1. Validate the upload and save it to disk.
      2. Run the detection pipeline on the saved file.
      3. Write the annotated image and the debug mask.
      4. Summarize the detections (per-type data, total, confidences).
      5. Build local URLs and attempt the Cloudinary upload; each image URL
         uses the Cloudinary ``secure_url`` when that upload result is
         truthy and falls back to the local URL otherwise.
      6. Hand the full log payload to ``save_analysis_log_to_firebase``.

    Args:
        request: Incoming request, forwarded to ``get_local_image_urls``.
        image: The uploaded file (form field ``image``).

    Returns:
        JSONResponse carrying the detection summary, image URLs, the
        Cloudinary upload results and the Firebase log id/path. ``id`` and
        ``firebase_path`` are ``None`` when no log id was returned.

    Raises:
        HTTPException: Propagated unchanged when raised by a helper. Any
            other exception is converted to a 500 whose detail carries the
            error text, with the original exception attached as the cause.
    """
    try:
        ext = validate_image_file(image)
        uploaded_path = await save_upload(image, ext)

        detections, processed_original, visual_mask = run_detection_pipeline(uploaded_path)

        result_image_path = draw_annotated_image(
            processed_original,
            uploaded_path,
            detections
        )
        debug_mask_path = save_debug_mask(
            uploaded_path,
            visual_mask
        )

        data = build_summary(detections)
        total = len(detections)
        avg_confidence = compute_avg_confidence({"detections": detections})
        top_confidence = compute_top_confidence({"detections": detections})

        local_original_url, local_annotated_url, local_debug_url = get_local_image_urls(
            request,
            uploaded_path,
            result_image_path,
            debug_mask_path
        )
        original_cloud, annotated_cloud, debug_cloud = upload_analysis_images_to_cloudinary(
            uploaded_path,
            result_image_path,
            debug_mask_path
        )

        # Prefer the hosted Cloudinary URL; fall back to the locally served one.
        original_image_url = original_cloud.get("secure_url") if original_cloud else local_original_url
        annotated_image_url = annotated_cloud.get("secure_url") if annotated_cloud else local_annotated_url
        debug_mask_url = debug_cloud.get("secure_url") if debug_cloud else local_debug_url

        # The debug mask is not required for the upload to count as saved.
        cloudinary_saved = bool(original_cloud and annotated_cloud)

        log_payload = {
            "id": None,
            "datatime": now_string(),
            "created_at": now_iso(),
            "timestamp_ms": now_timestamp_ms(),
            "data": data,
            "total": total,
            "detections": detections,
            "avg_confidence": avg_confidence,
            "top_confidence": top_confidence,
            "original_image": original_image_url,
            "annotated_image": annotated_image_url,
            "debug_mask": debug_mask_url,
            "local_images": {
                "original_image": local_original_url,
                "annotated_image": local_annotated_url,
                "debug_mask": local_debug_url
            },
            "image_files": {
                "original_filename": uploaded_path.name,
                "annotated_filename": result_image_path.name,
                "debug_mask_filename": debug_mask_path.name if debug_mask_path else None
            },
            "cloudinary_saved": cloudinary_saved,
            "cloudinary": {
                "original": original_cloud,
                "annotated": annotated_cloud,
                "debug_mask": debug_cloud
            },
            "firebase_saved": False,
            "firebase_path": None,
            "note": "Green boxes are YOLO identified pests. Orange boxes are hard visual counter detections that YOLO could not classify."
        }

        # A falsy log id means the log was not stored; the response reports that.
        log_id = save_analysis_log_to_firebase(log_payload)

        response = {
            "datatime": log_payload["datatime"],
            "id": log_id,
            "data": data,
            "total": total,
            "detections": detections,
            "avg_confidence": avg_confidence,
            "top_confidence": top_confidence,
            "original_image": original_image_url,
            "annotated_image": annotated_image_url,
            "debug_mask": debug_mask_url,
            "cloudinary_saved": cloudinary_saved,
            "firebase_saved": bool(log_id),
            "firebase_path": f"{FIREBASE_LOGS_PATH}/{log_id}" if log_id else None,
            # Built as a separate dict so the response does not alias the
            # dict stored inside log_payload.
            "cloudinary": {
                "original": original_cloud,
                "annotated": annotated_cloud,
                "debug_mask": debug_cloud
            }
        }
        return JSONResponse(content=response)
    except HTTPException:
        # Helper-raised HTTP errors already carry the right status code.
        raise
    except Exception as e:
        # Chain explicitly so the root cause stays attached to the 500.
        raise HTTPException(
            status_code=500,
            detail=f"Analysis failed: {str(e)}"
        ) from e
@app.get("/api/logs")
def list_logs(
    page: int = Query(1, description="Page number. If page=0, it becomes page=1."),
    page_size: int = Query(10, description="Items per page. Max 100."),
    pest_type: str | None = Query(None, description="Filter by pest type, example: unknown_pest"),
    date_from: str | None = Query(None, description="YYYY-MM-DD or ISO datetime"),
    date_to: str | None = Query(None, description="YYYY-MM-DD or ISO datetime"),
    min_total: int | None = Query(None),
    max_total: int | None = Query(None),
    search: str | None = Query(None),
    sort: str = Query("desc", description="desc or asc")
):
    """Return a filtered, time-sorted and paginated list of compact logs.

    All logs are loaded, narrowed with ``filter_logs``, ordered by
    ``timestamp_ms`` and reduced with ``compact_log_item`` before being
    paginated. Any ``sort`` value other than ``"asc"`` (case-insensitive)
    yields newest-first order.

    Returns:
        A dict with ``success``, an echo of the applied ``filters``, and the
        keys produced by ``paginate_items``.
    """
    # Collected once: passed to filter_logs and echoed back in the response.
    criteria = {
        "pest_type": pest_type,
        "date_from": date_from,
        "date_to": date_to,
        "min_total": min_total,
        "max_total": max_total,
        "search": search,
    }

    def _timestamp(entry):
        # Missing or empty timestamps sort as 0.
        return int(entry.get("timestamp_ms", 0) or 0)

    matching = filter_logs(get_all_logs_from_firebase(), **criteria)
    newest_first = sort.lower() != "asc"
    ordered = sorted(matching, key=_timestamp, reverse=newest_first)
    compact = [compact_log_item(entry) for entry in ordered]
    page_data = paginate_items(compact, page, page_size)

    return {
        "success": True,
        "filters": {**criteria, "sort": sort},
        **page_data,
    }
@app.get("/api/logs/{log_id}")
def get_log_detail(log_id: str):
    """Return the log fetched by ``get_log_from_firebase`` for ``log_id``.

    The fetched item is returned under ``data`` next to ``success: True``.
    """
    return {"success": True, "data": get_log_from_firebase(log_id)}
@app.get("/api/dashboard")
@app.get("/api/dashboard/")
def dashboard():
    """Return dashboard data built from every stored log.

    Registered both with and without a trailing slash. The keys produced by
    ``build_dashboard_data`` are merged into the top level of the response
    after ``success`` and ``datatime``.
    """
    aggregates = build_dashboard_data(get_all_logs_from_firebase())
    payload = {"success": True, "datatime": now_string()}
    payload.update(aggregates)
    return payload
@app.get("/api/live/latest")
def live_latest():
    """Return the most recent log in compact form.

    The most recent log is the one with the largest ``timestamp_ms``; a
    missing or empty timestamp counts as 0.

    Returns:
        ``{"success": True, "latest": <compact log>}``, or ``latest: None``
        when there are no logs.
    """
    logs = get_all_logs_from_firebase()
    if not logs:
        return {
            "success": True,
            "latest": None
        }

    # Only one element is needed, so a linear max() replaces a full sort.
    # max() keeps the first item among equal keys, which is the same item a
    # stable descending sort would place at index 0.
    newest = max(
        logs,
        key=lambda x: int(x.get("timestamp_ms", 0) or 0)
    )
    return {
        "success": True,
        "latest": compact_log_item(newest)
    }