File size: 3,125 Bytes
f3f6f5d
 
4bce717
f3f6f5d
 
 
 
 
 
 
 
 
 
4bce717
 
f3f6f5d
 
 
 
 
 
 
4bce717
 
 
f3f6f5d
 
 
 
4bce717
 
 
 
 
 
 
f3f6f5d
4bce717
f3f6f5d
 
 
 
 
 
 
 
 
 
 
 
 
 
4bce717
 
 
 
f3f6f5d
 
 
 
 
 
4bce717
f3f6f5d
4bce717
 
 
 
 
 
 
 
 
f3f6f5d
 
 
 
 
 
 
 
 
4bce717
f3f6f5d
 
 
4bce717
 
 
 
 
 
 
 
 
f3f6f5d
 
 
 
 
4bce717
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""FastAPI server for aerial car detection."""

import logging
from contextlib import asynccontextmanager
from typing import AsyncIterator

import cv2
import numpy as np
from fastapi import FastAPI, File, Query, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from server.detect import (
    MODEL_CLASSES,
    MODEL_PATHS,
    annotate_image,
    image_to_data_uri,
    load_model,
    run_detection,
)
from server.heatmap import generate_heatmap

logger = logging.getLogger(__name__)

# Registry of loaded model sessions keyed by model name. It is populated by
# ``lifespan`` at startup and emptied at shutdown; request handlers only read it.
_sessions: dict = dict()


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Load the available detection models at startup and drop them at shutdown.

    For every ``(key, path)`` in ``MODEL_PATHS`` whose file exists, the result
    of ``load_model(path)`` is stored in ``_sessions[key]``. Missing files are
    logged as warnings and skipped, so the server can start with only a
    subset of models.

    Args:
        app: The FastAPI application. The lifespan protocol requires it;
            it is not used here.
    """
    # ``_sessions`` is only mutated in place (never rebound), so no ``global``
    # declaration is needed.
    try:
        for key, path in MODEL_PATHS.items():
            # ``path`` is expected to be a ``pathlib.Path``-like object
            # exposing ``.exists()``.
            if not path.exists():
                logger.warning("Model %s not found at %s, skipping", key, path)
                continue
            _sessions[key] = load_model(path)
            logger.info("Loaded model %s from %s", key, path)
        yield
    finally:
        # Always release loaded sessions. This also runs when a load fails
        # partway through startup or the application exits with an error.
        _sessions.clear()


app = FastAPI(lifespan=lifespan, title="Parking Car Detection")

# Fully permissive CORS: any origin may call the API with any method/header.
# Credentials are not enabled, which is what makes the "*" origin acceptable.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)


@app.get("/health")
async def health() -> dict:
    """Report liveness and which models are currently loaded.

    Returns:
        ``{"status": "ok", "models_loaded": [...]}``, where the list holds
        the keys of the session registry in insertion order.
    """
    loaded_models = [name for name in _sessions]
    return {"status": "ok", "models_loaded": loaded_models}


def _decode_image(contents: bytes) -> "np.ndarray | None":
    """Decode raw upload bytes into an OpenCV colour image, or ``None`` on failure."""
    # cv2.imdecode asserts on an empty buffer (raising cv2.error) instead of
    # returning None, so an empty upload must be short-circuited here.
    if not contents:
        return None
    arr = np.frombuffer(contents, dtype=np.uint8)
    try:
        return cv2.imdecode(arr, cv2.IMREAD_COLOR)
    except cv2.error:
        # Some malformed payloads make the codec raise rather than return None;
        # treat them as undecodable so the client gets a 400, not a 500.
        logger.debug("Could not decode %d-byte upload", len(contents), exc_info=True)
        return None


def _occupancy_summary(detections: list) -> dict:
    """Summarize parking-spot detections into empty/occupied counts and a rate.

    Only detections whose ``class_name`` is ``"empty"`` or ``"occupied"`` are
    counted. ``occupancy_rate`` is ``occupied / total`` rounded to three
    decimals, or ``0.0`` when no spots were detected.
    """
    empty = sum(1 for d in detections if d.get("class_name") == "empty")
    occupied = sum(1 for d in detections if d.get("class_name") == "occupied")
    total = empty + occupied
    return {
        "empty_count": empty,
        "occupied_count": occupied,
        "total_spots": total,
        "occupancy_rate": round(occupied / total, 3) if total > 0 else 0.0,
    }


@app.post("/detect")
async def detect(
    file: UploadFile = File(...),
    threshold: float = Query(0.5, ge=0.0, le=1.0),
    model: str = Query("cars"),
) -> JSONResponse:
    """Run the selected detection model on an uploaded image.

    Args:
        file: Uploaded image, in any format ``cv2.imdecode`` can read.
        threshold: Confidence threshold in [0, 1], forwarded to ``run_detection``.
        model: Key of a loaded model in the session registry.

    Returns:
        On success, a JSON object with the model name, ``car_count``, the raw
        detections, and the annotated and heatmap images as data URIs. For the
        ``"spots"`` model an ``occupancy`` summary is added. A 400
        ``JSONResponse`` with an ``error`` message is returned when the model
        is not loaded, or when the upload is empty or cannot be decoded.
    """
    if model not in _sessions:
        return JSONResponse(
            status_code=400,
            content={"error": f"Model '{model}' not loaded. Available: {list(_sessions.keys())}"},
        )

    session = _sessions[model]
    class_names = MODEL_CLASSES.get(model, ["object"])

    image = _decode_image(await file.read())
    if image is None:
        return JSONResponse(
            status_code=400,
            content={"error": "Could not decode image"},
        )

    # NOTE(review): the work below is presumably CPU-bound and runs directly on
    # the event loop, blocking other requests (including /health) meanwhile.
    # Consider offloading it to a threadpool once the model session is
    # confirmed to be thread-safe.
    detections = run_detection(session, image, threshold, class_names)
    annotated = annotate_image(image, detections)

    # For the spot model, the heatmap shows only occupied spots.
    if model == "spots":
        heatmap_dets = [d for d in detections if d.get("class_name") == "occupied"]
    else:
        heatmap_dets = detections
    heatmap = generate_heatmap(image, heatmap_dets)

    response: dict = {
        "model": model,
        # NOTE(review): for "spots" this counts every detected spot (empty and
        # occupied), not only cars; confirm that clients expect this.
        "car_count": len(detections),
        "detections": detections,
        "annotated_image": image_to_data_uri(annotated),
        "heatmap_image": image_to_data_uri(heatmap),
    }

    if model == "spots":
        response["occupancy"] = _occupancy_summary(detections)

    # The JSONResponse annotation only describes the error paths. On success a
    # plain dict is returned and FastAPI JSON-encodes it itself.
    return response