"""YOLO Weld Type Detector API.

FastAPI service that runs a YOLO model over an uploaded image using tiled
inference with soft-NMS merging (``weld_tiling.detect_tiled_softnms``) and
returns per-class detection counts plus bounding boxes.
"""
import base64
import io
import os
import threading
from collections import Counter
from typing import Any, Dict, List, Tuple

import cv2
import numpy as np
import PIL.Image as Image
import uvicorn
from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

# Ultralytics reads YOLO_CONFIG_DIR when the package is first imported, so the
# default must be set *before* importing ultralytics (directly, or possibly via
# weld_tiling); setting it after the import has no effect.
os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics")  # silence config-dir warning

from ultralytics import YOLO  # noqa: E402
from weld_tiling import detect_tiled_softnms  # noqa: E402

MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", 8 * 1024 * 1024))  # 8 MB default
MAX_SIDE = int(os.getenv("MAX_SIDE", 2000))  # downscale largest side to this
MODEL_PATH = os.getenv("MODEL_PATH", "best_7-15-25.pt")  # YOLO weights file

# Class-name groupings; not referenced by any code in this file.
HIGH_CLASS_NAMES = [
    "Valve",
    "Butterfly Valve",
    "Flange",
    "PRV",
    "Reducer",
    "Union",
    "Weld-o-let",
    "field_sw",
]
LOW_CLASS_NAMES = ["shop_bw", "shop_sw", "field_bw", "Insulation"]
ALL_CLASS_NAMES = HIGH_CLASS_NAMES + LOW_CLASS_NAMES

# -----------------------------
# App setup
# -----------------------------
app = FastAPI(title="YOLO Weld Type Detector API", version="1.0.0")
model = YOLO(MODEL_PATH)

# Inference runs in worker threads (see /predict) but the single model instance
# is shared; Ultralytics models are not safe for concurrent predict calls, so
# serialize every use of it through this lock.
_inference_lock = threading.Lock()


# -----------------------------
# Schemas
# -----------------------------
class PredictResponse(BaseModel):
    """Response body of ``POST /predict``."""

    detections: dict  # {class_name: count}
    bounding_boxes: List[List[float]]  # one list of coordinates per detection


class PredictQuery(BaseModel):
    """Base64-encoded image request body; not used by any endpoint in this file."""

    image_base64: str


# -----------------------------
# Utils
# -----------------------------
def pil_to_numpy_rgb(img: Image.Image) -> np.ndarray:
    """Convert a PIL image (any mode) to an RGB numpy array.

    ``convert`` forces PIL to decode the pixel data, so decoding errors for
    lazily-opened images surface here.
    """
    return np.array(img.convert("RGB"))


def numpy_rgb_to_bgr(img: np.ndarray) -> np.ndarray:
    """Reorder channels RGB -> BGR (OpenCV's channel order)."""
    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)


def downscale_if_needed(img_rgb: np.ndarray) -> np.ndarray:
    """Shrink *img_rgb* so its longest side is at most ``MAX_SIDE`` pixels.

    Images already within the limit are returned unchanged (same object).
    Aspect ratio is preserved. Each side is clamped to at least 1 pixel so very
    thin images do not produce an empty target size for ``cv2.resize``.
    """
    h, w = img_rgb.shape[:2]
    longest = max(h, w)
    if longest <= MAX_SIDE:
        return img_rgb
    scale = MAX_SIDE / longest
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    # INTER_AREA is OpenCV's recommended interpolation for shrinking images.
    return cv2.resize(img_rgb, (new_w, new_h), interpolation=cv2.INTER_AREA)


def normalize_prediction(output) -> Dict[str, int]:
    """Count detections per class name.

    Args:
        output: mapping returned by ``detect_tiled_softnms``. This function reads
            ``output['cls']`` (iterable of class ids) and ``output['names']``
            (class-id -> name lookup).

    Returns:
        ``{class_name: count}`` in first-seen order; classes with no
        detections are absent.
    """
    names = output["names"]
    # Class ids may arrive as floats (e.g. from a float array -- TODO confirm
    # against weld_tiling); casting makes list and int-keyed dict lookups work.
    return dict(Counter(names[int(cls_id)] for cls_id in output["cls"]))


def detect_weld_types(image_bgr: np.ndarray, model) -> Tuple[Dict[str, int], Any]:
    """Run tiled YOLO inference on *image_bgr* and summarize the result.

    Args:
        image_bgr: image in OpenCV BGR channel order.
        model: the YOLO model passed through to ``detect_tiled_softnms``.

    Returns:
        ``(counts, boxes)``: ``counts`` is the per-class tally from
        :func:`normalize_prediction`; ``boxes`` is ``out['xyxy']`` from
        ``detect_tiled_softnms`` (presumably x1, y1, x2, y2 per detection in
        *image_bgr* pixel coordinates -- confirm in weld_tiling).
    """
    out = detect_tiled_softnms(
        model,
        image_bgr,
        tile_size=1024,
        overlap=0.23,
        per_tile_conf=0.2,
        per_tile_iou=0.7,
        softnms_iou=0.6,
        softnms_method="hard",
        softnms_sigma=0.5,
        final_conf=0.38,
        device=None,
        imgsz=1280,
    )
    counts = normalize_prediction(out)
    return counts, out["xyxy"]


def _decode_image_rgb(raw: bytes) -> np.ndarray:
    """Decode uploaded bytes into an RGB array, raising HTTP 400 if they are not an image."""
    try:
        with Image.open(io.BytesIO(raw)) as img:
            # Image.open is lazy: pixels are decoded only on access, so the
            # conversion must happen inside the try to catch corrupt/truncated files.
            return pil_to_numpy_rgb(img)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image.") from None


def _predict_bytes(raw: bytes) -> PredictResponse:
    """Decode, downscale and run detection on *raw*.

    Blocking and CPU-heavy; call it off the event loop.
    """
    img_rgb = downscale_if_needed(_decode_image_rgb(raw))
    img_bgr = numpy_rgb_to_bgr(img_rgb)
    with _inference_lock:
        welds, boxes = detect_weld_types(img_bgr, model)
    # NOTE(review): boxes are relative to the possibly-downscaled image, not the
    # original upload -- confirm clients expect this coordinate space.
    # Convert numpy array to list of lists for JSON serialization.
    boxes_list = boxes.tolist() if isinstance(boxes, np.ndarray) else boxes
    return PredictResponse(detections=welds, bounding_boxes=boxes_list)


# -----------------------------
# Endpoints
# -----------------------------
@app.get("/health")
def health():
    return {"status": "ok"}


@app.post("/predict", response_model=PredictResponse)
async def predict_multipart(file: UploadFile = File(default=None)):
    """Detect weld types in an image uploaded as multipart field ``file``.

    Raises:
        HTTPException: 400 if no file is given or it is not a decodable image;
            413 if it exceeds ``MAX_UPLOAD_BYTES``.
    """
    if file is None:
        raise HTTPException(status_code=400, detail="Provide a file via multipart form-data with field name 'file'.")

    # Read at most one byte past the limit so an oversized upload is rejected
    # without loading all of it into memory.
    raw = await file.read(MAX_UPLOAD_BYTES + 1)
    if len(raw) > MAX_UPLOAD_BYTES:
        raise HTTPException(
            status_code=413,
            detail=f"File too large. Max is {MAX_UPLOAD_BYTES / 1024 / 1024:.2f} MB.",
        )

    # Decoding and inference are CPU-bound; run them in a worker thread so the
    # event loop keeps serving other requests (e.g. /health) meanwhile.
    return await run_in_threadpool(_predict_bytes, raw)


@app.post("/ping")
async def ping():
    return {"ok": True}


@app.post("/echo")
async def echo(req: Request):
    """Report the request's Content-Type header; the body is never read."""
    ct = req.headers.get("content-type", "")
    return {"ok": True, "content_type": ct}


if __name__ == "__main__":
    # Pass the app object, not the "app:app" import string: with the string,
    # uvicorn re-imports this file as module "app", loading the YOLO model a
    # second time (and failing outright if the file is not named app.py).
    uvicorn.run(app, host="0.0.0.0", port=7860)