# Buck_Tracker / api/detection.py
# Author: codewithRiz · commit 68e5ccd ("retagging images") · 10.2 kB
import logging
import math
import time
from typing import Optional, List, Literal

import cv2
import numpy as np
from fastapi import APIRouter, UploadFile, File, Form, HTTPException
from pydantic import BaseModel, validator

from .utils import (
    validate_form,
    process_image,
    save_image,
    load_json,
    save_json,
    validate_user_and_camera,
    extract_metadata,
    _bucket_key,
    _key_exists,
)
# Module-scoped logger for the detection endpoints.
logger = logging.getLogger(__name__)
# Router collecting this module's endpoints (presumably mounted by the app elsewhere — verify).
router = APIRouter()
@router.post("/predict")
async def predict(
    user_id: str = Form(...),
    camera_name: str = Form(...),
    images: list[UploadFile] = File(...)
):
    """Run detection on uploaded images and append the results to the camera's detections JSON.

    Every image is read, decoded and run through ``process_image`` before
    anything is stored. One undecodable upload therefore rejects the whole
    batch. Earlier images in the batch are not left saved by ``save_image``
    without a matching record in the detections JSON.

    Args:
        user_id: Owner of the camera (form field).
        camera_name: Camera the images belong to (form field).
        images: Uploaded image files.

    Returns:
        dict with a status message, the camera name, and one record
        (filename, image_url, detections, metadata) per uploaded image.

    Raises:
        HTTPException: 400 if any upload cannot be decoded as an image.
    """
    images = validate_form(user_id, camera_name, images)
    validate_user_and_camera(user_id, camera_name)
    json_path = _bucket_key(user_id, camera_name, f"{camera_name}_detections.json")
    data = load_json(json_path)

    # Phase 1: read, decode and run inference on every image before storing anything.
    # Only the raw bytes and small results are kept. Each decoded array is dropped
    # after inference, so peak memory stays at one decoded image.
    analysed = []
    for file in images:
        raw = await file.read()
        metadata = extract_metadata(raw)
        img = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            raise HTTPException(400, f"Invalid image: {file.filename}")
        t0 = time.perf_counter()
        detections = process_image(img)
        logger.info(
            "[%s] inference: %.2fms", file.filename, (time.perf_counter() - t0) * 1000
        )
        analysed.append((file.filename, raw, detections, metadata))

    # Phase 2: all images are valid. Upload them and record the results.
    new_results = []
    for filename, raw, detections, metadata in analysed:
        url = save_image(user_id, camera_name, filename, raw)
        new_results.append({
            "filename": filename,
            "image_url": url,
            "detections": detections,
            "metadata": metadata,
        })
    data.extend(new_results)
    save_json(json_path, data)

    return {
        "message": "Images processed successfully",
        "camera": camera_name,
        "results": new_results
    }
# ─────────────────
# Label Definitions & Validation
# ─────────────────
# ================================================================
# VALID LABELS — must exactly match what process_images_batch()
# produces from the 3-stage YOLO pipeline:
#   Stage 1: deer detected
#   Stage 2: Buck → Stage 3: Whitetail | Mule
#   Stage 2: Doe
# ================================================================
# Exact label strings emitted by the detection pipeline. This tuple is the single
# source of truth, so the lookup set and the user-facing list cannot drift apart.
_PIPELINE_LABELS = (
    "Deer | Doe ",  # trailing space matches pipeline output
    "Deer | Buck | White Tail Bucks",
    "Deer | Buck | Mule Bucks",
)

# Accepted labels in the pipeline's exact spelling (used for membership/canonicalisation).
VALID_LABELS = set(_PIPELINE_LABELS)

# Whitespace-stripped labels in a stable order, shown in error messages.
VALID_LABELS_DISPLAY = [label.strip() for label in _PIPELINE_LABELS]
def _normalise_label(label: str) -> str:
"""Strip and normalise label so 'Deer | Doe' == 'Deer | Doe '."""
return label.strip()
def _validate_label(label: str) -> str:
    """Check *label* against the pipeline labels and return its canonical spelling.

    The comparison ignores surrounding whitespace. The returned value is the
    exact string stored in ``VALID_LABELS``, so ``"Deer | Doe "`` keeps its
    trailing space and matches pipeline output.

    Raises:
        HTTPException: 422 if the label is not one of the pipeline labels.
    """
    canonical_by_normalised = {_normalise_label(valid): valid for valid in VALID_LABELS}
    canonical = canonical_by_normalised.get(_normalise_label(label))
    if canonical is None:
        raise HTTPException(
            status_code=422,
            detail=(
                f"Invalid label '{label}'. "
                f"Must be one of: {VALID_LABELS_DISPLAY}"
            )
        )
    return canonical
# ─────────────────
# Request Models
# ─────────────────
class DetectionOperation(BaseModel):
    """One add / update / delete edit against an image's detection list.

    Field-level checks live here. Which optional fields each ``action``
    requires is enforced by the request model that contains the operation.
    """

    action: Literal["add", "update", "delete"]
    # Position in the image's detection list (non-negative); used by update/delete.
    detection_index: Optional[int] = None
    # Pipeline label; stored in its canonical VALID_LABELS spelling.
    label: Optional[str] = None
    bbox: Optional[List[float]] = None  # [x1, y1, x2, y2]

    @validator("label")
    def label_must_be_valid(cls, v):
        """Accept a pipeline label (surrounding whitespace ignored) and return its canonical form."""
        if v is None:
            return v
        canonical_by_normalised = {_normalise_label(valid): valid for valid in VALID_LABELS}
        canonical = canonical_by_normalised.get(_normalise_label(v))
        if canonical is None:
            raise ValueError(
                f"Invalid label '{v}'. Must be one of: {VALID_LABELS_DISPLAY}"
            )
        return canonical

    @validator("bbox")
    def bbox_must_be_four_values(cls, v):
        """Require four finite coordinates describing a box with positive width and height."""
        if v is None:
            return v
        if len(v) != 4:
            raise ValueError("bbox must have exactly 4 values: [x1, y1, x2, y2]")
        # Every comparison with NaN is False, so without this check a NaN
        # coordinate would pass the ordering test below. Request JSON may
        # carry NaN/Infinity literals.
        if not all(math.isfinite(coord) for coord in v):
            raise ValueError("bbox values must be finite numbers")
        x1, y1, x2, y2 = v
        if x2 <= x1 or y2 <= y1:
            raise ValueError("bbox must satisfy x2 > x1 and y2 > y1")
        return v

    @validator("detection_index")
    def index_must_be_non_negative(cls, v):
        """Reject negative indices (Python would otherwise index from the end)."""
        if v is not None and v < 0:
            raise ValueError("detection_index must be >= 0")
        return v
class MultiUpdateRequest(BaseModel):
    """Request body for editing one image's detections with a non-empty batch of operations."""

    user_id: str
    camera_name: str
    image_url: str
    operations: List[DetectionOperation]

    @validator("operations")
    def operations_must_not_be_empty(cls, v):
        """Reject a request that carries no operations at all."""
        if len(v) == 0:
            raise ValueError("operations list cannot be empty")
        return v

    @validator("operations", each_item=True)
    def validate_operation_fields(cls, op):
        """Ensure each operation carries the fields its action needs.

        delete -> detection_index; update -> detection_index plus label and/or bbox;
        add -> label and bbox.
        """
        action = op.action
        if action == "delete":
            if op.detection_index is None:
                raise ValueError("'delete' operation requires detection_index")
        elif action == "update":
            if op.detection_index is None:
                raise ValueError("'update' operation requires detection_index")
            if op.bbox is None and op.label is None:
                raise ValueError("'update' operation requires at least label or bbox")
        elif action == "add":
            if op.label is None:
                raise ValueError("'add' operation requires a label")
            if op.bbox is None:
                raise ValueError("'add' operation requires a bbox")
        return op
# ─────────────────
# Endpoint helpers
# ─────────────────
def _filename_from_url(url: str) -> str:
    """Return the final path segment of *url* with any query string removed."""
    return url.split("/")[-1].split("?")[0]


def _find_image_record(data: list, target_filename: str) -> Optional[dict]:
    """Return the first record in *data* whose stored image has *target_filename*, else None."""
    for item in data:
        # Fall back to "filename" when "image_url" is missing, None or empty,
        # so a null URL does not crash the split in _filename_from_url.
        stored = item.get("image_url") or item.get("filename") or ""
        if _filename_from_url(stored) == target_filename:
            return item
    return None


def _apply_operations(dets: list, operations: List[DetectionOperation]) -> None:
    """Apply *operations* to *dets* in place.

    Every ``detection_index`` refers to the list as it was before this request.
    Updates are applied first. Deletes follow, from the highest index down, so
    earlier pops do not shift later targets. Adds are appended last. All indices
    are validated before anything is mutated.

    Raises:
        HTTPException: 400 if a delete/update index is out of range.
    """
    original_count = len(dets)
    updates = [op for op in operations if op.action == "update"]
    adds = [op for op in operations if op.action == "add"]
    # Use a set so a repeated delete index removes that detection once, rather
    # than also removing whichever detection slid into its slot.
    delete_indices = sorted(
        {op.detection_index for op in operations if op.action == "delete"},
        reverse=True,
    )

    for index in delete_indices:
        if index >= original_count:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid delete index {index} — only {original_count} detection(s) exist"
            )
    for op in updates:
        if op.detection_index >= original_count:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid update index {op.detection_index} — only {original_count} detection(s) exist"
            )

    for op in updates:
        det = dets[op.detection_index]
        if op.label is not None:
            det["label"] = op.label
        if op.bbox is not None:
            det["bbox"] = op.bbox
        det["manually_edited"] = True

    for index in delete_indices:
        dets.pop(index)

    for op in adds:
        dets.append({
            "label": op.label,  # already validated & canonicalised by DetectionOperation
            "confidence": 1.0,
            "bbox": op.bbox,
            "manually_edited": True,
        })


# ─────────────────
# Endpoint
# ─────────────────
@router.post("/modify_detections")
async def modify_detections(req: MultiUpdateRequest):
    """
    Add, update, and delete detections (tags) for a given image.

    Supports multiple operations in a single request. All detection indices refer
    to the image's detection list as stored before the request. Labels must be
    detection-pipeline labels (surrounding whitespace is ignored).

    Raises:
        HTTPException: 404 if the detections file or the image record is missing;
            400 for an out-of-range detection index.
    """
    # ── Validate user & camera ────────────────────────────────────
    validate_user_and_camera(req.user_id, req.camera_name)

    # ── Validate detections JSON exists in bucket, then load it ───
    json_key = _bucket_key(req.user_id, req.camera_name, f"{req.camera_name}_detections.json")
    if not _key_exists(json_key):
        raise HTTPException(status_code=404, detail="Detections file not found")
    data = load_json(json_key)

    # ── Find image record by filename ─────────────────────────────
    target_filename = _filename_from_url(req.image_url)
    record = _find_image_record(data, target_filename)
    if record is None:
        raise HTTPException(status_code=404, detail="Image not found")

    # ── Ensure detections list exists (repair missing/malformed field) ──
    if not isinstance(record.get("detections"), list):
        record["detections"] = []
    dets = record["detections"]

    _apply_operations(dets, req.operations)

    # ── Save back to bucket; `record` is an element of `data` ─────
    save_json(json_key, data)
    logger.info(
        "Detections modified | user=%s camera=%s file=%s ops=%d final_count=%d",
        req.user_id,
        req.camera_name,
        target_filename,
        len(req.operations),
        len(dets)
    )
    return {
        "success": True,
        "message": "Detections modified successfully",
        "filename": target_filename,
        "total_detections": len(dets),
        "detections": dets
    }