import logging
import math
import time
from typing import List, Literal, Optional

import cv2
import numpy as np
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from pydantic import BaseModel, validator

from .utils import (
    validate_form,
    process_image,
    save_image,
    load_json,
    save_json,
    validate_user_and_camera,
    extract_metadata,
    _bucket_key,
    _key_exists,
)

router = APIRouter()
logger = logging.getLogger(__name__)


# ─────────────────
# Prediction
# ─────────────────

@router.post("/predict")
async def predict(
    user_id: str = Form(...),
    camera_name: str = Form(...),
    images: list[UploadFile] = File(...),
):
    """Run detection on uploaded images and append results to the camera's detections JSON.

    Each file is decoded with OpenCV, passed to ``process_image``, stored via
    ``save_image``, and recorded (filename, image URL, detections, metadata)
    in ``<camera_name>_detections.json`` for this user/camera.

    Raises:
        HTTPException(400): if an uploaded file cannot be decoded as an image.
            The ``.utils`` validation helpers may raise their own errors.

    Returns:
        A dict with a status message, the camera name, and the records
        created by this request.
    """
    images = validate_form(user_id, camera_name, images)
    validate_user_and_camera(user_id, camera_name)

    json_path = _bucket_key(user_id, camera_name, f"{camera_name}_detections.json")
    data = load_json(json_path)
    new_results = []

    # NOTE(review): save_image() uploads each file before save_json() runs at
    # the end. If a later file fails to decode, earlier uploads stay stored
    # but are never recorded in the JSON. Consider validating every file
    # before uploading any of them.
    for file in images:
        raw = await file.read()
        metadata = extract_metadata(raw)

        img = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            raise HTTPException(400, f"Invalid image: {file.filename}")

        t0 = time.perf_counter()
        detections = process_image(img)
        logger.info(
            "[%s] inference: %sms",
            file.filename,
            round((time.perf_counter() - t0) * 1000, 2),
        )

        url = save_image(user_id, camera_name, file.filename, raw)

        record = {
            "filename": file.filename,
            "image_url": url,
            "detections": detections,
            "metadata": metadata,
        }
        data.append(record)
        new_results.append(record)

    save_json(json_path, data)

    return {
        "message": "Images processed successfully",
        "camera": camera_name,
        "results": new_results,
    }


# ─────────────────
# Labels
# ─────────────────

# ================================================================
# VALID LABELS — must exactly match what process_images_batch()
# produces from the 3-stage YOLO pipeline:
#   Stage 1: deer detected
#   Stage 2: Buck → Stage 3: Whitetail | Mule
#   Stage 2: Doe
# ================================================================
VALID_LABELS = {
    "Deer | Doe ",  # trailing space matches pipeline output
    "Deer | Buck | White Tail Bucks",
    "Deer | Buck | Mule Bucks",
}

VALID_LABELS_DISPLAY = [  # clean version shown in error messages
    "Deer | Doe",
    "Deer | Buck | White Tail Bucks",
    "Deer | Buck | Mule Bucks",
]


def _normalise_label(label: str) -> str:
    """Strip surrounding whitespace so 'Deer | Doe' == 'Deer | Doe '."""
    return label.strip()


def _canonical_label(label: str) -> Optional[str]:
    """Return the exact VALID_LABELS spelling matching *label*, or None if there is none.

    Matching ignores surrounding whitespace. The returned value keeps the
    pipeline's own spelling (e.g. the trailing space in "Deer | Doe ").
    """
    normalised = _normalise_label(label)
    return next(
        (valid for valid in VALID_LABELS if _normalise_label(valid) == normalised),
        None,
    )


def _validate_label(label: str) -> str:
    """Return the canonical pipeline form of *label*.

    Raises:
        HTTPException(422): if *label* does not match any entry in VALID_LABELS.
    """
    canonical = _canonical_label(label)
    if canonical is None:
        raise HTTPException(
            status_code=422,
            detail=(
                f"Invalid label '{label}'. "
                f"Must be one of: {VALID_LABELS_DISPLAY}"
            ),
        )
    return canonical


# ─────────────────
# Request Models
# ─────────────────

class DetectionOperation(BaseModel):
    """A single add / update / delete operation on an image's detection list."""

    action: Literal["add", "update", "delete"]
    detection_index: Optional[int] = None  # position in the image's "detections" list
    label: Optional[str] = None
    bbox: Optional[List[float]] = None  # [x1, y1, x2, y2]

    @validator("label")
    def label_must_be_valid(cls, v):
        """Reject labels the pipeline cannot produce; return the canonical spelling."""
        if v is None:
            return v
        canonical = _canonical_label(v)
        if canonical is None:
            raise ValueError(
                f"Invalid label '{v}'. Must be one of: {VALID_LABELS_DISPLAY}"
            )
        return canonical

    @validator("bbox")
    def bbox_must_be_four_values(cls, v):
        """Require exactly four finite coordinates with x2 > x1 and y2 > y1."""
        if v is None:
            return v
        if len(v) != 4:
            raise ValueError("bbox must have exactly 4 values: [x1, y1, x2, y2]")
        # NaN/inf would slip past the ordering check below (any comparison
        # with NaN is False) and then be persisted via save_json().
        if not all(math.isfinite(coord) for coord in v):
            raise ValueError("bbox values must be finite numbers")
        x1, y1, x2, y2 = v
        if x2 <= x1 or y2 <= y1:
            raise ValueError("bbox must satisfy x2 > x1 and y2 > y1")
        return v

    @validator("detection_index")
    def index_must_be_non_negative(cls, v):
        """Disallow negative indices, which Python would silently count from the end."""
        if v is not None and v < 0:
            raise ValueError("detection_index must be >= 0")
        return v


class MultiUpdateRequest(BaseModel):
    """A batch of detection operations targeting one image of one camera."""

    user_id: str
    camera_name: str
    image_url: str  # only the final path segment (the filename) is used for lookup
    operations: List[DetectionOperation]

    @validator("operations")
    def operations_must_not_be_empty(cls, v):
        """Reject requests that carry no operations."""
        if not v:
            raise ValueError("operations list cannot be empty")
        return v

    @validator("operations", each_item=True)
    def validate_operation_fields(cls, op):
        """Enforce the fields each action needs (add: label+bbox; update/delete: index)."""
        if op.action == "add":
            if op.label is None:
                raise ValueError("'add' operation requires a label")
            if op.bbox is None:
                raise ValueError("'add' operation requires a bbox")
        elif op.action == "update":
            if op.detection_index is None:
                raise ValueError("'update' operation requires detection_index")
            if op.label is None and op.bbox is None:
                raise ValueError("'update' operation requires at least label or bbox")
        elif op.action == "delete":
            if op.detection_index is None:
                raise ValueError("'delete' operation requires detection_index")
        return op


# ─────────────────
# Endpoint helpers
# ─────────────────

def _filename_from_url(url: str) -> str:
    """Return the last path segment of *url*, ignoring any query string."""
    # Drop the query first so a '/' inside it cannot be mistaken for a path separator.
    return url.split("?", 1)[0].rsplit("/", 1)[-1]


def _find_image_record(data: list, target_filename: str) -> Optional[dict]:
    """Return the first record whose stored image URL/filename ends in *target_filename*.

    Records are matched on ``image_url`` when it is set, falling back to
    ``filename``. Returns None when nothing matches or *target_filename* is empty.
    """
    if not target_filename:
        return None
    for item in data:
        # `or` (not a .get default) also falls back when image_url is None or "".
        stored = item.get("image_url") or item.get("filename") or ""
        if _filename_from_url(stored) == target_filename:
            return item
    return None


def _apply_deletes(dets: list, operations: List[DetectionOperation]) -> None:
    """Remove, in place, every detection targeted by a 'delete' operation.

    Indices refer to *dets* as it was before any deletion in this request.

    Raises:
        HTTPException(400): if any delete index is out of range. The check
            runs before anything is removed.
    """
    # A set collapses repeated indices. Popping the same index twice would
    # otherwise remove the neighbour that shifted into its slot.
    indices = sorted(
        {op.detection_index for op in operations if op.action == "delete"},
        reverse=True,
    )
    # Checking only the largest index is enough. Pops run in descending order,
    # so every later index is smaller than the shrinking list's length.
    if indices and indices[0] >= len(dets):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid delete index {indices[0]} — only {len(dets)} detection(s) exist",
        )
    for index in indices:
        dets.pop(index)


def _apply_adds_and_updates(dets: list, operations: List[DetectionOperation]) -> None:
    """Apply 'add' and 'update' operations to *dets* in request order, in place.

    Update indices are resolved against the list as it stands at that point:
    after all deletes, and after any earlier adds in the same request.

    Raises:
        HTTPException(400): if an update index is out of range.
    """
    for op in operations:
        if op.action == "add":
            dets.append({
                "label": op.label,  # already validated & canonicalised by the model validator
                "confidence": 1.0,
                "bbox": op.bbox,
                "manually_edited": True,
            })
        elif op.action == "update":
            if op.detection_index >= len(dets):
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid update index {op.detection_index} — only {len(dets)} detection(s) exist",
                )
            det = dets[op.detection_index]
            if op.label is not None:
                det["label"] = op.label
            if op.bbox is not None:
                det["bbox"] = op.bbox
            det["manually_edited"] = True


# ─────────────────
# Endpoint
# ─────────────────

@router.post("/modify_detections")
async def modify_detections(req: MultiUpdateRequest):
    """
    Add, update, and delete detections (tags) for a given image.

    Supports multiple operations in a single request.
    Labels must match the detection pipeline format exactly.

    Deletes are applied first. Their indices refer to the detection list as
    it was before the request, and a repeated index is deleted only once.
    Adds and updates are then applied in request order. Update indices refer
    to the list after deletions and any earlier adds.
    """
    # ── Validate user & camera ────────────────────────────────────
    validate_user_and_camera(req.user_id, req.camera_name)

    # ── Validate detections JSON exists in bucket, then load it ───
    json_key = _bucket_key(req.user_id, req.camera_name, f"{req.camera_name}_detections.json")
    if not _key_exists(json_key):
        raise HTTPException(status_code=404, detail="Detections file not found")
    data = load_json(json_key)

    # ── Find image record by filename ─────────────────────────────
    target_filename = _filename_from_url(req.image_url)
    record = _find_image_record(data, target_filename)
    if record is None:
        raise HTTPException(status_code=404, detail="Image not found")

    # ── Ensure detections list exists ─────────────────────────────
    if not isinstance(record.get("detections"), list):
        record["detections"] = []
    dets = record["detections"]

    # ── Apply operations (mutates `data` through `dets`) ──────────
    _apply_deletes(dets, req.operations)
    _apply_adds_and_updates(dets, req.operations)

    # ── Save back to bucket ───────────────────────────────────────
    save_json(json_key, data)

    logger.info(
        "Detections modified | user=%s camera=%s file=%s ops=%d final_count=%d",
        req.user_id,
        req.camera_name,
        target_filename,
        len(req.operations),
        len(dets),
    )

    return {
        "success": True,
        "message": "Detections modified successfully",
        "filename": target_filename,
        "total_detections": len(dets),
        "detections": dets,
    }