| |
import io
import os
import threading
from typing import List

import cv2
import numpy as np
import PIL.Image as Image
import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from ultralytics import YOLO

from weld_tiling import detect_tiled_softnms
|
|
| MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", 8 * 1024 * 1024)) |
| MAX_SIDE = int(os.getenv("MAX_SIDE", 2000)) |
| os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics") |
|
|
|
|
# Detector class labels, split into two groups. Nothing shown in this file
# reads these lists. NOTE(review): the HIGH/LOW split is presumably a
# priority or confidence tier; confirm its meaning with the model owner.
HIGH_CLASS_NAMES = [
    "Valve",
    "Butterfly Valve",
    "Flange",
    "PRV",
    "Reducer",
    "Union",
    "Weld-o-let",
    "field_sw",
]

LOW_CLASS_NAMES = [
    "shop_bw",
    "shop_sw",
    "field_bw",
    "Insulation",
]

# Both groups in one new list, HIGH entries first.
ALL_CLASS_NAMES = [*HIGH_CLASS_NAMES, *LOW_CLASS_NAMES]
|
|
| |
| |
| |
|
|
app = FastAPI(title="YOLO Weld Type Detector API", version="1.0.0")

# Detector weights, loaded once at import time; /predict uses this single
# `model` object for every request. Set WELD_MODEL_PATH to serve a different
# checkpoint without editing the code (mirrors MAX_UPLOAD_BYTES / MAX_SIDE).
MODEL_PATH = os.getenv("WELD_MODEL_PATH", "best_7-15-25.pt")
model = YOLO(MODEL_PATH)
|
|
| |
| |
| |
|
|
class PredictResponse(BaseModel):
    """Body returned by POST /predict."""

    # Detection count per class name, as built by normalize_prediction().
    detections: dict
    # One list per detected box, taken from the tiled detector's 'xyxy'
    # output. Coordinates refer to the image that was fed to the detector,
    # which /predict may have downscaled first (downscale_if_needed).
    # NOTE(review): presumably [x1, y1, x2, y2] in pixels; confirm.
    bounding_boxes: List[List[float]]
|
|
class PredictQuery(BaseModel):
    """Request model with a single string field named for a base64 image.

    Not used by the endpoints in this file: /predict accepts a multipart
    upload instead.
    """

    image_base64: str
|
|
| |
| |
| |
def pil_to_numpy_rgb(img: Image.Image) -> np.ndarray:
    """Convert a PIL image to RGB mode and return its pixels as a NumPy array.

    NOTE(review): PIL's convert("RGB") drops any alpha channel without
    compositing over a background. Fully transparent pixels therefore keep
    whatever RGB values they store (often black). Confirm this is acceptable
    for transparent PNG uploads.
    """
    rgb_image = img.convert("RGB")
    pixels = np.array(rgb_image)
    return pixels
|
|
def numpy_rgb_to_bgr(img: np.ndarray) -> np.ndarray:
    """Return a copy of an RGB image with channels reordered to BGR (OpenCV order)."""
    conversion = cv2.COLOR_RGB2BGR
    bgr_image = cv2.cvtColor(img, conversion)
    return bgr_image
|
|
| def downscale_if_needed(img_rgb: np.ndarray) -> np.ndarray: |
| h, w = img_rgb.shape[:2] |
| m = max(h, w) |
| if m <= MAX_SIDE: |
| return img_rgb |
| scale = MAX_SIDE / m |
| new_w, new_h = int(w * scale), int(h * scale) |
| return cv2.resize(img_rgb, (new_w, new_h), interpolation=cv2.INTER_AREA) |
|
|
def normalize_prediction(output):
    """Count detections per class name.

    *output* must provide ``'cls'`` (an iterable of class indices) and
    ``'names'`` (indexable by those indices, yielding class names).
    ``'names'`` is only consulted when ``'cls'`` is non-empty. Keys appear
    in the order each class name is first seen.
    """
    tally = {}
    for class_index in output['cls']:
        label = output['names'][class_index]
        if label in tally:
            tally[label] += 1
        else:
            tally[label] = 1
    return tally
|
|
# Default keyword arguments for detect_tiled_softnms as used by this service.
# detect_weld_types() lets callers replace any of them per call.
_TILING_DEFAULTS = dict(
    tile_size=1024, overlap=0.23,
    per_tile_conf=0.2, per_tile_iou=0.7,
    softnms_iou=0.6, softnms_method="hard", softnms_sigma=0.5,
    final_conf=0.38, device=None, imgsz=1280,
)


def detect_weld_types(image_bgr: np.ndarray, model, **overrides) -> tuple:
    """Run tiled detection on a BGR image and count detections per class.

    Args:
        image_bgr: Image array in BGR channel order (as produced by
            numpy_rgb_to_bgr).
        model: Loaded detector, passed straight through to
            detect_tiled_softnms.
        **overrides: Keyword arguments for detect_tiled_softnms that replace
            the matching entries of _TILING_DEFAULTS (e.g. final_conf=0.5).
            With no overrides the call is exactly the service's defaults.

    Returns:
        A ``(counts, boxes)`` tuple (the old ``-> dict`` annotation was
        wrong). ``counts`` maps class name to number of detections (see
        normalize_prediction). ``boxes`` is the detector's ``'xyxy'`` output;
        presumably one [x1, y1, x2, y2] row per detection (confirm in
        weld_tiling).
    """
    params = dict(_TILING_DEFAULTS, **overrides)
    out = detect_tiled_softnms(model, image_bgr, **params)
    counts = normalize_prediction(out)
    return counts, out['xyxy']
|
|
| |
| |
| |
@app.get("/health")
def health():
    """Report service status; the payload is always {"status": "ok"}."""
    status_payload = {"status": "ok"}
    return status_payload
|
|
|
|
# Guards the module-level `model`. /predict runs its decode + inference
# pipeline in worker threads (see predict_multipart), and Ultralytics documents
# that one YOLO instance should not be driven from several threads at once.
# Holding the lock across decoding too keeps at most one full-size decoded
# image in memory at a time.
_MODEL_LOCK = threading.Lock()


def _decode_to_bgr(raw: bytes) -> np.ndarray:
    """Decode uploaded bytes into a BGR array, downscaled by downscale_if_needed.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        with Image.open(io.BytesIO(raw)) as img:
            # Image.open only parses the header; converting here forces the
            # full pixel decode inside the try, so truncated or corrupt files
            # produce a 400 instead of an unhandled error (500) later on.
            img_rgb = pil_to_numpy_rgb(img)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image.")
    return numpy_rgb_to_bgr(downscale_if_needed(img_rgb))


def _run_prediction(raw: bytes) -> PredictResponse:
    """Decode *raw*, run detection and build the response (blocking, CPU-bound)."""
    with _MODEL_LOCK:
        img_bgr = _decode_to_bgr(raw)
        welds, boxes = detect_weld_types(img_bgr, model)
    boxes_list = boxes.tolist() if isinstance(boxes, np.ndarray) else boxes
    return PredictResponse(detections=welds, bounding_boxes=boxes_list)


@app.post("/predict", response_model=PredictResponse)
async def predict_multipart(file: UploadFile = File(default=None)):
    """Run the weld-type detector on an image uploaded as multipart field ``file``.

    Returns:
        PredictResponse with per-class counts and bounding boxes. Boxes are in
        the coordinate frame of the image after downscale_if_needed, so they
        do not match the upload's pixel grid when its longest side exceeded
        MAX_SIDE. NOTE(review): confirm clients expect this.

    Raises:
        HTTPException: 400 if no file was sent or it is not a decodable image;
            413 if it is larger than MAX_UPLOAD_BYTES.
    """
    if file is None:
        raise HTTPException(status_code=400, detail="Provide a file via multipart form-data with field name 'file'.")

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling an arbitrarily large body into memory first.
    raw = await file.read(MAX_UPLOAD_BYTES + 1)
    if len(raw) > MAX_UPLOAD_BYTES:
        raise HTTPException(
            status_code=413,
            detail=f"File too large. Max is {MAX_UPLOAD_BYTES/1024/1024:.2f} MB.",
        )

    # Decoding and YOLO inference are CPU-bound. Running them in the thread
    # pool keeps the event loop free to serve other requests (e.g. /health)
    # while a prediction is in progress.
    return await run_in_threadpool(_run_prediction, raw)
|
|
@app.post("/ping")
async def ping():
    """Acknowledge a POST with a fixed {"ok": True} body."""
    return dict(ok=True)
|
|
@app.post("/echo")
async def echo(req: Request):
    """Report the request's Content-Type header ('' when the header is absent)."""
    content_type = req.headers.get("content-type", "")
    return dict(ok=True, content_type=content_type)
|
|
|
|
if __name__ == "__main__":
    # Pass the app object rather than the "app:app" import string. When this
    # file runs as a script it is module __main__, so the string form makes
    # uvicorn import it again as a separate module named "app". That re-runs
    # all module-level code, including loading the YOLO weights a second time,
    # and fails outright if the file is not named app.py.
    # PORT overrides the listening port; 7860 remains the default.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))
|
|