Spaces:
Sleeping
Sleeping
File size: 10,214 Bytes
6241844 9993c90 bfc0af1 1143d23 9993c90 6241844 a918efd 9993c90 6241844 a918efd bfc0af1 9993c90 6241844 9993c90 a918efd 9993c90 a918efd 9993c90 d35ef57 9993c90 a918efd 9993c90 a918efd 9993c90 d35ef57 6241844 9993c90 a918efd 9993c90 a918efd 9993c90 1143d23 bfc0af1 68e5ccd 5079cae bfc0af1 68e5ccd 5079cae 796dd6f bfc0af1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 |
import logging
import math
import time
from typing import List, Literal, Optional

import cv2
import numpy as np
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from pydantic import BaseModel, validator

from .utils import (
    _bucket_key,
    _key_exists,
    extract_metadata,
    load_json,
    process_image,
    save_image,
    save_json,
    validate_form,
    validate_user_and_camera,
)
# Module-level logger, and the router the endpoints in this module register on.
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/predict")
async def predict(
    user_id: str = Form(...),
    camera_name: str = Form(...),
    images: list[UploadFile] = File(...),
):
    """Run the detector on uploaded images and record the results.

    Each image is decoded and passed to ``process_image``, then stored with
    ``save_image``. A record (filename, image_url, detections, metadata) is
    appended to ``{camera_name}_detections.json`` for this user/camera.

    Returns:
        dict with a status message, the camera name and the new records.

    Raises:
        HTTPException: 400 if any upload cannot be decoded as an image. This
            is raised before any image is saved and before the detections
            JSON is written, so a rejected request leaves the bucket untouched.
    """
    images = validate_form(user_id, camera_name, images)
    validate_user_and_camera(user_id, camera_name)

    # Phase 1: read every upload. These are the only awaits in the handler.
    uploads = []
    for file in images:
        uploads.append((file.filename, await file.read()))

    # Phase 2: decode and run inference on everything before touching the
    # bucket. A bad image then aborts the request without orphaned uploads.
    # Only the raw bytes are kept; each decoded array is dropped after
    # inference to keep peak memory close to that of the one-at-a-time loop.
    pending = []
    for filename, raw in uploads:
        metadata = extract_metadata(raw)
        img = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            raise HTTPException(400, f"Invalid image: {filename}")
        t0 = time.perf_counter()
        detections = process_image(img)
        logger.info(
            "[%s] inference: %sms",
            filename,
            round((time.perf_counter() - t0) * 1000, 2),
        )
        pending.append((filename, raw, detections, metadata))

    # Phase 3: persist. There are no awaits from here to save_json. Within this
    # worker's event loop, another /predict call therefore cannot interleave
    # between load and save and drop records (a lost update). Separate worker
    # processes can still race.
    json_path = _bucket_key(user_id, camera_name, f"{camera_name}_detections.json")
    data = load_json(json_path)
    new_results = []
    for filename, raw, detections, metadata in pending:
        url = save_image(user_id, camera_name, filename, raw)
        record = {
            "filename": filename,
            "image_url": url,
            "detections": detections,
            "metadata": metadata,
        }
        data.append(record)
        new_results.append(record)
    save_json(json_path, data)

    return {
        "message": "Images processed successfully",
        "camera": camera_name,
        "results": new_results,
    }
# ─────────────────
# Label validation
# ─────────────────
# ================================================================
# VALID LABELS — must exactly match the label strings emitted by the
# 3-stage YOLO pipeline in .utils (originally documented as
# process_images_batch(); NOTE(review): this module calls process_image()):
#   Stage 1: deer detected
#   Stage 2: Buck → Stage 3: White Tail | Mule
#   Stage 2: Doe
# ================================================================
# Single source of truth, in display order. The set and the display list
# are both derived from it so they cannot drift apart.
_CANONICAL_LABELS = (
    "Deer | Doe ",  # trailing space matches pipeline output
    "Deer | Buck | White Tail Bucks",
    "Deer | Buck | Mule Bucks",
)
# Exact pipeline strings (including the Doe trailing space), for membership checks.
VALID_LABELS = set(_CANONICAL_LABELS)
# Whitespace-stripped versions, shown to users in error messages.
VALID_LABELS_DISPLAY = [label.strip() for label in _CANONICAL_LABELS]
def _normalise_label(label: str) -> str:
    """Return *label* with leading and trailing whitespace removed.

    This lets user input such as 'Deer | Doe' compare equal to the
    pipeline's 'Deer | Doe '.
    """
    cleaned = label.strip()
    return cleaned
def _validate_label(label: str) -> str:
    """Return the canonical pipeline form of *label*, or reject it.

    The comparison ignores surrounding whitespace. 'Deer | Doe' is accepted
    and mapped back to the exact string stored in VALID_LABELS (for Doe, the
    form with the trailing space).

    Raises:
        HTTPException: 422 if *label* does not match any entry of VALID_LABELS.
    """
    # One lookup table instead of a throw-away set plus a second linear scan.
    # The old trailing `return label` could never be reached.
    canonical_by_normalised = {valid.strip(): valid for valid in VALID_LABELS}
    canonical = canonical_by_normalised.get(_normalise_label(label))
    if canonical is None:
        raise HTTPException(
            status_code=422,
            detail=(
                f"Invalid label '{label}'. "
                f"Must be one of: {VALID_LABELS_DISPLAY}"
            ),
        )
    return canonical
# ─────────────────
# Request Models
# ─────────────────
class DetectionOperation(BaseModel):
    """A single add / update / delete edit to one image's detection list.

    Which fields each action requires is checked by MultiUpdateRequest.
    The validators here only check the individual field values.
    """

    action: Literal["add", "update", "delete"]
    detection_index: Optional[int] = None  # position in the image's detections list
    label: Optional[str] = None
    bbox: Optional[List[float]] = None  # [x1, y1, x2, y2]

    @validator("label")
    def label_must_be_valid(cls, v):
        """Accept only pipeline labels (whitespace-insensitive) and return the canonical form."""
        if v is None:
            return v
        canonical_by_normalised = {valid.strip(): valid for valid in VALID_LABELS}
        canonical = canonical_by_normalised.get(v.strip())
        if canonical is None:
            raise ValueError(
                f"Invalid label '{v}'. Must be one of: {VALID_LABELS_DISPLAY}"
            )
        return canonical

    @validator("bbox")
    def bbox_must_be_four_values(cls, v):
        """Require exactly four finite coordinates with x2 > x1 and y2 > y1."""
        if v is None:
            return v
        if len(v) != 4:
            raise ValueError("bbox must have exactly 4 values: [x1, y1, x2, y2]")
        # NaN compares False against everything, so it would slip past the
        # ordering check below, and inf satisfies it. Neither value can be
        # written as standard JSON once the bbox is stored or returned.
        if not all(math.isfinite(coord) for coord in v):
            raise ValueError("bbox values must be finite numbers")
        x1, y1, x2, y2 = v
        if x2 <= x1 or y2 <= y1:
            raise ValueError("bbox must satisfy x2 > x1 and y2 > y1")
        return v

    @validator("detection_index")
    def index_must_be_non_negative(cls, v):
        """Reject negative indices (Python's from-the-end indexing is not allowed)."""
        if v is not None and v < 0:
            raise ValueError("detection_index must be >= 0")
        return v
class MultiUpdateRequest(BaseModel):
    """Request body for /modify_detections: a batch of edits to one image."""

    user_id: str
    camera_name: str
    image_url: str
    operations: List[DetectionOperation]

    @validator("operations")
    def operations_must_not_be_empty(cls, v):
        """Refuse a request that carries no operations at all."""
        if len(v) == 0:
            raise ValueError("operations list cannot be empty")
        return v

    @validator("operations", each_item=True)
    def validate_operation_fields(cls, op):
        """Ensure each operation carries the fields its action depends on."""
        kind = op.action

        if kind == "add":
            if op.label is None:
                raise ValueError("'add' operation requires a label")
            if op.bbox is None:
                raise ValueError("'add' operation requires a bbox")
            return op

        # update and delete both address an existing detection by index.
        if kind in ("update", "delete") and op.detection_index is None:
            raise ValueError(
                "'update' operation requires detection_index"
                if kind == "update"
                else "'delete' operation requires detection_index"
            )

        if kind == "update" and op.label is None and op.bbox is None:
            raise ValueError("'update' operation requires at least label or bbox")

        return op
# ─────────────────
# Endpoint
# ─────────────────
def _find_image_record(data: List[dict], target_filename: str) -> Optional[dict]:
    """Return the first record whose stored image filename equals *target_filename*, else None."""
    if not target_filename:
        # An image_url ending in "/" yields an empty name; never match on that.
        return None
    for item in data:
        # Prefer image_url and fall back to filename. Using `or` also skips
        # stored None / "" values, which would otherwise crash `.split`.
        stored = item.get("image_url") or item.get("filename") or ""
        if stored.split("/")[-1].split("?")[0] == target_filename:
            return item
    return None


def _apply_operations(dets: List[dict], operations: List[DetectionOperation]) -> None:
    """Apply *operations* to *dets* in place: all deletes first, then adds/updates in request order.

    Delete indices refer to positions in the list as loaded. Update indices
    are resolved after the deletes have been applied, against the current
    list, which also includes detections appended by earlier "add"
    operations in the same request.

    Raises:
        HTTPException: 400 if a delete or update index is out of range.
    """
    # Deduplicate so a repeated delete index removes one detection rather than
    # it and its neighbour. Pop highest-first so earlier pops don't shift the
    # positions of later ones.
    delete_indices = sorted(
        {op.detection_index for op in operations if op.action == "delete"},
        reverse=True,
    )
    for index in delete_indices:
        if index >= len(dets):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid delete index {index} — only {len(dets)} detection(s) exist",
            )
        dets.pop(index)

    for op in operations:
        if op.action == "add":
            dets.append({
                "label": op.label,  # already validated & canonicalised by validator
                "confidence": 1.0,
                "bbox": op.bbox,
                "manually_edited": True,
            })
        elif op.action == "update":
            if op.detection_index >= len(dets):
                raise HTTPException(
                    status_code=400,
                    detail=(
                        f"Invalid update index {op.detection_index} — "
                        f"only {len(dets)} detection(s) exist"
                    ),
                )
            target = dets[op.detection_index]
            if op.label is not None:
                target["label"] = op.label
            if op.bbox is not None:
                target["bbox"] = op.bbox
            target["manually_edited"] = True


@router.post("/modify_detections")
async def modify_detections(req: MultiUpdateRequest):
    """
    Add, update, and delete detections (tags) for a given image.
    Supports multiple operations in a single request.
    Labels must match the detection pipeline format exactly.

    Raises:
        HTTPException: 404 if the camera's detections file or the image record
            is missing; 400 for an out-of-range delete/update index. Every
            error is raised before save_json, so nothing is written back.
    """
    validate_user_and_camera(req.user_id, req.camera_name)

    json_key = _bucket_key(req.user_id, req.camera_name, f"{req.camera_name}_detections.json")
    if not _key_exists(json_key):
        raise HTTPException(status_code=404, detail="Detections file not found")
    data = load_json(json_key)

    # Match on the bare file name: the last path segment, query string removed.
    target_filename = req.image_url.split("/")[-1].split("?")[0]
    record = _find_image_record(data, target_filename)
    if record is None:
        raise HTTPException(status_code=404, detail="Image not found")

    if not isinstance(record.get("detections"), list):
        record["detections"] = []
    dets = record["detections"]

    _apply_operations(dets, req.operations)

    save_json(json_key, data)
    logger.info(
        "Detections modified | user=%s camera=%s file=%s ops=%d final_count=%d",
        req.user_id,
        req.camera_name,
        target_filename,
        len(req.operations),
        len(dets),
    )
    return {
        "success": True,
        "message": "Detections modified successfully",
        "filename": target_filename,
        "total_detections": len(dets),
        "detections": dets,
    }
|