paint_defect_detector / src\api.py
therealestcoder's picture
Upload src\api.py with huggingface_hub
a30f9a1 verified
"""REST API детекции дефектов окраски кузова (по ТЗ АвтоВАЗа, таблица 3).
Эндпоинты:
POST /predict — приём фото детали (multipart), VIN — параметром формы;
возвращает JSON с дефектами, координатами и base64-визуализацией.
GET /defects/{vin} — последние результаты по VIN (in-memory история).
GET /health — проверка состояния сервиса.
Запуск:
uvicorn src.api:app --host 0.0.0.0 --port 8080
"""
from __future__ import annotations
import base64
import io
import time
from collections import defaultdict, deque
from datetime import datetime
from typing import Any
import cv2
import numpy as np
import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from . import config as C
from .infer import load_model, predict_image, render_visualization
app = FastAPI(
title="Paint Defect Detection API",
version="1.0.0",
description="Система автоматической детекции дефектов лакокрасочного покрытия "
"(крыша, капот, багажник). Соответствует требованиям ТЗ АвтоВАЗ.",
)
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_model = None # ленивая загрузка
_history: dict[str, deque] = defaultdict(lambda: deque(maxlen=20))
_STATIC_DIR = C.ROOT / "static"
if _STATIC_DIR.exists():
app.mount("/static", StaticFiles(directory=str(_STATIC_DIR)), name="static")
@app.get("/", include_in_schema=False)
def index():
"""Веб-интерфейс оператора (одностраничное приложение)."""
idx = _STATIC_DIR / "index.html"
if idx.exists():
return FileResponse(str(idx))
raise HTTPException(status_code=404, detail="UI not built")
def _ensure_model():
global _model
if _model is None:
_model = load_model(device=_device)
return _model
class DefectBox(BaseModel):
x: int; y: int; w: int; h: int
confidence: float
mean_prob: float
class PredictResponse(BaseModel):
vin: str
timestamp: str
is_defect: bool
defect_count: int
defect_ratio: float
max_prob: float
boxes: list[DefectBox]
panel_size: dict[str, int]
visualization_base64: str = Field(description="JPEG, base64-encoded, для отображения на ТВ-панели")
elapsed_ms: int
@app.get("/health")
def health() -> dict[str, Any]:
return {
"status": "ok",
"device": str(_device),
"model_loaded": _model is not None,
"checkpoint": str(C.CHECKPOINTS / "best.pt"),
}
@app.post("/predict", response_model=PredictResponse)
async def predict(
file: UploadFile = File(..., description="Фото детали кузова"),
vin: str = Form(..., description="VIN автомобиля"),
part: str = Form("unknown", description="Деталь: roof|hood|trunk"),
threshold: float = Form(C.DEFECT_THRESHOLD),
) -> PredictResponse:
if not file.content_type or not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail="Ожидался image/*")
raw = await file.read()
arr = np.frombuffer(raw, dtype=np.uint8)
bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if bgr is None:
raise HTTPException(status_code=400, detail="Не удалось декодировать изображение")
model = _ensure_model()
t0 = time.time()
result = predict_image(bgr, model, _device, threshold=threshold)
elapsed_ms = int((time.time() - t0) * 1000)
vis = render_visualization(result)
ok, buf = cv2.imencode(".jpg", vis, [cv2.IMWRITE_JPEG_QUALITY, 88])
vis_b64 = base64.b64encode(buf.tobytes()).decode("ascii") if ok else ""
response = PredictResponse(
vin=vin,
timestamp=datetime.utcnow().isoformat() + "Z",
is_defect=result["is_defect"],
defect_count=len(result["boxes"]),
defect_ratio=result["defect_ratio"],
max_prob=result["max_prob"],
boxes=[DefectBox(**b) for b in result["boxes"]],
panel_size=result["panel_size"],
visualization_base64=vis_b64,
elapsed_ms=elapsed_ms,
)
_history[vin].append({"part": part, "ts": response.timestamp,
"is_defect": response.is_defect,
"defect_count": response.defect_count})
return response
@app.get("/defects/{vin}")
def defects_by_vin(vin: str) -> dict[str, Any]:
return {"vin": vin, "results": list(_history.get(vin, []))}
def main():
import uvicorn
uvicorn.run("src.api:app", host=C.API_HOST, port=C.API_PORT, reload=False)
if __name__ == "__main__":
main()