# ais-api / app.py
# Last commit by csmith715: "Adding bounding box data to output" (2e46b64)
# import base64
import io
import os
import numpy as np
import cv2
import uvicorn
import PIL.Image as Image
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
from pydantic import BaseModel
from typing import List
from ultralytics import YOLO
from weld_tiling import detect_tiled_softnms
# Request limits; each can be overridden by an environment variable of the same name.
MAX_UPLOAD_BYTES = int(os.environ.get("MAX_UPLOAD_BYTES", 8 << 20))  # default: 8 MB (8 * 1024 * 1024 bytes)
MAX_SIDE = int(os.environ.get("MAX_SIDE", 2000))  # longest image side kept before downscaling

# Default the Ultralytics config dir to /tmp/Ultralytics (intent: silence its config-dir warning).
# NOTE(review): this runs after `from ultralytics import YOLO` at the top of the file.
# If Ultralytics resolves YOLO_CONFIG_DIR at import time, this arrives too late to
# have any effect. Consider moving it above that import.
os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics")

# Detector class labels, kept in two groups plus their concatenation.
# NOTE(review): none of these lists is referenced elsewhere in this file as shown.
HIGH_CLASS_NAMES = [
    "Valve",
    "Butterfly Valve",
    "Flange",
    "PRV",
    "Reducer",
    "Union",
    "Weld-o-let",
    "field_sw",
]
LOW_CLASS_NAMES = [
    "shop_bw",
    "shop_sw",
    "field_bw",
    "Insulation",
]
ALL_CLASS_NAMES = [*HIGH_CLASS_NAMES, *LOW_CLASS_NAMES]
# -----------------------------
# App setup
# -----------------------------
app = FastAPI(title="YOLO Weld Type Detector API", version="1.0.0")

# The weights are loaded once, at import time. The single `model` object is used by
# every /predict request. MODEL_PATH lets a deployment swap checkpoints without a
# code change; the default is the checkpoint this service was built around.
MODEL_PATH = os.getenv("MODEL_PATH", "best_7-15-25.pt")
model = YOLO(MODEL_PATH)
# -----------------------------
# Schemas
# -----------------------------
class PredictResponse(BaseModel):
    """Response body returned by ``POST /predict``."""

    # Per-class detection counts, e.g. {"Flange": 3}, as produced by normalize_prediction.
    detections: dict
    # One list of floats per detected box, taken from the detector's 'xyxy' output.
    bounding_boxes: List[List[float]]
class PredictQuery(BaseModel):
    """JSON request body carrying an image as a string (presumably base64, per the field name)."""

    # NOTE(review): no endpoint in this file uses this model, and the base64 import is
    # commented out. It may be left over from a JSON upload path; confirm before removing.
    image_base64: str
# -----------------------------
# Utils
# -----------------------------
def pil_to_numpy_rgb(img: Image.Image) -> np.ndarray:
    """Return ``img`` as a NumPy array in RGB channel order.

    The image is first converted with ``convert("RGB")``. Palette, greyscale and
    alpha-bearing inputs therefore all come out as three-channel RGB; any alpha is dropped.
    """
    rgb_image = img.convert("RGB")
    return np.array(rgb_image)
def numpy_rgb_to_bgr(img: np.ndarray) -> np.ndarray:
    """Swap an RGB array to OpenCV's BGR channel order and return the new array."""
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    return bgr
def _fit_within(w: int, h: int, max_side: int):
    """Return the ``(new_w, new_h)`` that fits a ``w`` x ``h`` image inside ``max_side``.

    Returns ``None`` when the longer side is already ``<= max_side``.

    Otherwise both sides are scaled by the same factor, rounded to the nearest
    pixel, and clamped to at least 1. Without the clamp a very thin image
    (e.g. 5000x1) would get a zero-length side, and ``cv2.resize`` rejects
    such a target with an error.
    """
    longest = max(w, h)
    if longest <= max_side:
        return None
    scale = max_side / longest
    return max(1, round(w * scale)), max(1, round(h * scale))


def downscale_if_needed(img_rgb: np.ndarray, max_side=None) -> np.ndarray:
    """Shrink ``img_rgb`` so that its longer side is at most ``max_side`` pixels.

    Args:
        img_rgb: image array whose first two axes are (height, width).
        max_side: optional size limit in pixels; defaults to the module-level MAX_SIDE.

    Returns:
        ``img_rgb`` itself (not a copy) when it already fits. Otherwise a copy
        resized with ``cv2.INTER_AREA``, with the aspect ratio preserved.
    """
    limit = MAX_SIDE if max_side is None else max_side
    h, w = img_rgb.shape[:2]
    target = _fit_within(w, h, limit)
    if target is None:
        return img_rgb
    # cv2.resize expects the target size as (width, height).
    return cv2.resize(img_rgb, target, interpolation=cv2.INTER_AREA)
def normalize_prediction(output):
    """Count detections per class name.

    ``output['cls']`` is iterated, and each entry is mapped to a label through
    ``output['names']``. The result is a plain dict ``{label: count}``, ordered
    by each label's first appearance. ``output['names']`` is only consulted
    when there is at least one detection.
    """
    labels = [output['names'][class_id] for class_id in output['cls']]
    tally = dict.fromkeys(labels, 0)
    for label in labels:
        tally[label] += 1
    return tally
def detect_weld_types(image_bgr: np.ndarray, model, **overrides) -> tuple:
    """Run tiled detection on one BGR image and summarise the result.

    Args:
        image_bgr: image in OpenCV BGR channel order.
        model: the loaded detection model, passed straight to ``detect_tiled_softnms``.
        **overrides: optional keyword arguments that replace any of the default
            ``detect_tiled_softnms`` settings below (e.g. ``final_conf=0.5``).
            They are forwarded unchanged.

    Returns:
        A ``(counts, boxes)`` tuple. ``counts`` is the ``{class_name: count}`` dict
        from ``normalize_prediction``. ``boxes`` is the detector output's ``'xyxy'`` entry.
        (The previous ``-> dict`` annotation did not match this tuple return.)
    """
    # Tuned defaults for the tiled detector; callers may override any of them.
    params = dict(
        tile_size=1024, overlap=0.23,
        per_tile_conf=0.2, per_tile_iou=0.7,
        softnms_iou=0.6, softnms_method="hard", softnms_sigma=0.5,
        final_conf=0.38, device=None, imgsz=1280,
    )
    params.update(overrides)
    out = detect_tiled_softnms(model, image_bgr, **params)
    counts = normalize_prediction(out)
    return counts, out['xyxy']
# -----------------------------
# Endpoints
# -----------------------------
@app.get("/health")
def health():
    """Liveness check; always answers ``{"status": "ok"}``."""
    body = {"status": "ok"}
    return body
def _rescale_boxes(boxes, sx: float, sy: float) -> np.ndarray:
    """Return a float copy of ``boxes`` with x coordinates scaled by ``sx`` and y by ``sy``.

    Assumes the first four values of each box are ``x1, y1, x2, y2``. The detector
    output key ``'xyxy'`` suggests this layout; confirm against weld_tiling.
    Columns 0 and 2 are scaled by ``sx`` and columns 1 and 3 by ``sy``.
    Any further columns are left untouched. Empty input yields an empty array.
    """
    scaled = np.array(boxes, dtype=float)
    scaled[..., 0:4:2] *= sx
    scaled[..., 1:4:2] *= sy
    return scaled


@app.post("/predict", response_model=PredictResponse)
async def predict_multipart(file: UploadFile = File(default=None)):
    """Detect weld/fitting types in an uploaded image.

    Expects multipart form-data with the image in the field ``file``. The image may
    be downscaled before detection. The returned boxes are mapped back to the pixel
    coordinates of the image as uploaded.

    Raises:
        HTTPException(400): no file was sent, or it could not be decoded as an image.
        HTTPException(413): the upload is larger than MAX_UPLOAD_BYTES.
    """
    if file is None:
        raise HTTPException(status_code=400, detail="Provide a file via multipart form-data with field name 'file'.")

    # Read at most one byte past the limit. That is enough to detect an oversized
    # upload without copying all of it into memory.
    raw = await file.read(MAX_UPLOAD_BYTES + 1)
    if len(raw) > MAX_UPLOAD_BYTES:
        raise HTTPException(
            status_code=413,
            detail=f"File too large. Max is {MAX_UPLOAD_BYTES/1024/1024:.2f} MB."
        )

    try:
        img = Image.open(io.BytesIO(raw))
        # Image.open only parses the header. Force the full decode and the RGB
        # conversion here, so corrupt or truncated data becomes a 400 instead of
        # an unhandled 500 further down.
        img.load()
        original_rgb = pil_to_numpy_rgb(img)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image.") from None

    orig_h, orig_w = original_rgb.shape[:2]
    img_rgb = downscale_if_needed(original_rgb)
    img_bgr = numpy_rgb_to_bgr(img_rgb)

    # NOTE(review): inference runs synchronously inside this async handler, so it
    # blocks the event loop while it runs. Moving it to a worker thread would need
    # care, because the single shared `model` would then be used concurrently.
    welds, boxes = detect_weld_types(img_bgr, model)

    # Detection ran on the (possibly) downscaled image. Map the boxes back to the
    # pixel coordinates of the image the client actually uploaded.
    new_h, new_w = img_rgb.shape[:2]
    if (new_h, new_w) != (orig_h, orig_w):
        boxes = _rescale_boxes(boxes, orig_w / new_w, orig_h / new_h)

    # Convert a numpy array to a list of lists for JSON serialization.
    boxes_list = boxes.tolist() if isinstance(boxes, np.ndarray) else boxes
    return PredictResponse(detections=welds, bounding_boxes=boxes_list)
@app.post("/ping")
async def ping():
    """Minimal POST round-trip check; always answers ``{"ok": True}``."""
    return dict(ok=True)
@app.post("/echo")
async def echo(req: Request):
    """Report the request's Content-Type header back to the caller.

    Only the header is inspected; this handler never reads the request body, so
    large payloads are not loaded. (The previous comment claimed it echoed
    JSON/form-data keys. It does not; only the content type is returned.)
    The header comes back as ``""`` when it is absent.
    """
    content_type = req.headers.get("content-type", "")
    return {"ok": True, "content_type": content_type}
if __name__ == "__main__":
    # Pass the app object, not the "app:app" import string. When this file is
    # executed directly it runs as __main__, so an import string would make uvicorn
    # import it a second time as module "app" and load the YOLO weights twice.
    # An import string is only required for reload/multi-worker mode, which is not used here.
    # PORT may override the default 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))