File size: 4,003 Bytes
f224d0a 9415326 f224d0a 9415326 efce4f7 9415326 2e46b64 9415326 311720b 9415326 f224d0a 9415326 8021aca 9415326 8021aca 9415326 2e46b64 9415326 a1864b3 9415326 f224d0a 8021aca 2e46b64 9415326 f224d0a 2e46b64 f224d0a efce4f7 9415326 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | # import base64
import io
import os
import numpy as np
import cv2
import uvicorn
import PIL.Image as Image
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
from pydantic import BaseModel
from typing import List
from ultralytics import YOLO
from weld_tiling import detect_tiled_softnms
# ---- Runtime limits (overridable through environment variables) ----
_MIB = 1 << 20  # bytes per mebibyte

# Largest accepted upload, in bytes (8 MiB unless MAX_UPLOAD_BYTES is set).
MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", 8 * _MIB))
# Images whose longest side exceeds this are downscaled before inference.
MAX_SIDE = int(os.getenv("MAX_SIDE", 2000))

# Point Ultralytics at a writable config dir (silences its config-dir warning).
# NOTE(review): `ultralytics` is imported above, before this line executes. If it
# reads YOLO_CONFIG_DIR at import time, this default is applied too late — consider
# moving this call above `from ultralytics import YOLO`.
os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics")

# Detector class labels, kept in two groups (what distinguishes "high" from "low"
# is not visible in this file).
HIGH_CLASS_NAMES = [
    "Valve",
    "Butterfly Valve",
    "Flange",
    "PRV",
    "Reducer",
    "Union",
    "Weld-o-let",
    "field_sw",
]
LOW_CLASS_NAMES = ["shop_bw", "shop_sw", "field_bw", "Insulation"]
ALL_CLASS_NAMES = [*HIGH_CLASS_NAMES, *LOW_CLASS_NAMES]
# -----------------------------
# App setup
# -----------------------------
app = FastAPI(title="YOLO Weld Type Detector API", version="1.0.0")

# Weights are loaded once at import time; every /predict request uses this instance.
# MODEL_PATH lets a deployment swap checkpoints without editing code (consistent with
# the other env-configurable settings); the default is the original checkpoint.
model = YOLO(os.getenv("MODEL_PATH", "best_7-15-25.pt"))
# -----------------------------
# Schemas
# -----------------------------
# Response body returned by POST /predict.
class PredictResponse(BaseModel):
    # Class name -> number of detections of that class (built by normalize_prediction).
    detections: dict
    # One box per detection, converted from numpy to nested lists for JSON.
    # Presumably [x1, y1, x2, y2] per row (sourced from out['xyxy']) — confirm in weld_tiling.
    bounding_boxes: List[List[float]]
# Request schema with a single string field, apparently meant to carry a
# base64-encoded image (judging by the field name).
# NOTE(review): not referenced by any endpoint visible in this file, and the
# `base64` import at the top is commented out — possibly a leftover JSON-upload path.
class PredictQuery(BaseModel):
    image_base64: str
# -----------------------------
# Utils
# -----------------------------
def pil_to_numpy_rgb(img: Image.Image) -> np.ndarray:
    """Return the pixels of ``img`` as a numpy array, forcing RGB mode first.

    Converting before the array copy normalises palette, grayscale, RGBA, etc.
    into plain three-channel RGB.
    """
    rgb_view = img.convert("RGB")
    return np.array(rgb_view)
def numpy_rgb_to_bgr(img: np.ndarray) -> np.ndarray:
    """Swap an RGB array into BGR channel order using OpenCV.

    The BGR result is what gets handed to ``detect_weld_types`` (``image_bgr``).
    """
    conversion_code = cv2.COLOR_RGB2BGR
    return cv2.cvtColor(img, conversion_code)
def downscale_if_needed(img_rgb: np.ndarray, max_side: "int | None" = None) -> np.ndarray:
    """Shrink an image so that its longest side is at most ``max_side`` pixels.

    The aspect ratio is preserved. Images already within the limit are returned
    unchanged (the very same object, not a copy).

    Args:
        img_rgb: Image array whose first two dimensions are (height, width).
        max_side: Upper bound for the longest side. When omitted, the
            module-level ``MAX_SIDE`` setting is used.

    Returns:
        ``img_rgb`` itself if no resize is needed, otherwise a resized array.
    """
    limit = MAX_SIDE if max_side is None else max_side
    h, w = img_rgb.shape[:2]
    longest = max(h, w)
    if longest <= limit:
        return img_rgb
    scale = limit / longest
    # Clamp to >= 1: for very elongated inputs (e.g. 1 x 4000 px) the short side
    # would truncate to 0, and cv2.resize rejects an empty destination size.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    # cv2.resize takes dsize as (width, height); INTER_AREA is OpenCV's
    # recommended interpolation for decimation.
    return cv2.resize(img_rgb, (new_w, new_h), interpolation=cv2.INTER_AREA)
def normalize_prediction(output):
    """Tally detections per class name.

    Args:
        output: Mapping with ``'cls'`` (iterable of class indices) and
            ``'names'`` (index -> class-name lookup). ``'names'`` is only read
            when there is at least one detection.

    Returns:
        A plain dict of class name -> count, in order of first appearance.
    """
    tally = {}
    for class_idx in output['cls']:
        label = output['names'][class_idx]
        if label in tally:
            tally[label] += 1
        else:
            tally[label] = 1
    return tally
# Tiling / Soft-NMS settings used for every inference call. They are gathered in one
# place so they can be reviewed together and selectively overridden per call.
_TILING_DEFAULTS = dict(
    tile_size=1024, overlap=0.23,
    per_tile_conf=0.2, per_tile_iou=0.7,
    softnms_iou=0.6, softnms_method="hard", softnms_sigma=0.5,
    final_conf=0.38, device=None, imgsz=1280,
)


def detect_weld_types(image_bgr: np.ndarray, model, **tiling_overrides) -> tuple:
    """Run tiled detection on a BGR image and summarise the detections.

    Args:
        image_bgr: Image in BGR channel order.
        model: Detector passed straight through to ``detect_tiled_softnms``.
        **tiling_overrides: Optional keyword arguments for ``detect_tiled_softnms``
            that replace the matching entries in ``_TILING_DEFAULTS``.

    Returns:
        ``(counts, boxes)``: ``counts`` maps class name -> number of detections
        (via ``normalize_prediction``); ``boxes`` is ``out['xyxy']`` exactly as
        returned by ``detect_tiled_softnms`` (may be a numpy array — the caller
        converts it for JSON).
    """
    params = {**_TILING_DEFAULTS, **tiling_overrides}
    out = detect_tiled_softnms(model, image_bgr, **params)
    counts = normalize_prediction(out)
    return counts, out['xyxy']
# -----------------------------
# Endpoints
# -----------------------------
@app.get("/health")
def health():
    # Lightweight health check: returns a fixed payload without touching the model.
    # (Comment rather than docstring so the generated OpenAPI description is unchanged.)
    payload = {"status": "ok"}
    return payload
@app.post("/predict", response_model=PredictResponse)
async def predict_multipart(file: UploadFile = File(default=None)):
    """Detect weld/fitting types in an uploaded image.

    Expects multipart form-data with the image in the ``file`` field. Responds
    with per-class detection counts and the detected bounding boxes.

    Raises:
        HTTPException(400): no file was sent, or it could not be decoded as an image.
        HTTPException(413): the upload is larger than ``MAX_UPLOAD_BYTES``.
    """
    if file is None:
        raise HTTPException(status_code=400, detail="Provide a file via multipart form-data with field name 'file'.")

    # Read at most one byte past the limit: enough to tell that an upload is
    # oversized without pulling an arbitrarily large file into memory.
    raw = await file.read(MAX_UPLOAD_BYTES + 1)
    if len(raw) > MAX_UPLOAD_BYTES:
        # The exact size is only known if the server recorded it on the UploadFile.
        total = getattr(file, "size", None)
        if total:
            size_text = f"{total/1024/1024:.2f} MB"
        else:
            size_text = f"> {MAX_UPLOAD_BYTES/1024/1024:.2f} MB"
        raise HTTPException(
            status_code=413,
            detail=f"File too large ({size_text}). Max is {MAX_UPLOAD_BYTES/1024/1024:.2f} MB."
        )

    try:
        img = Image.open(io.BytesIO(raw))
        # Image.open is lazy (it only identifies the format). Force a full decode
        # and the RGB conversion here so corrupt/truncated data yields 400, not 500.
        img.load()
        img_rgb = pil_to_numpy_rgb(img)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image.") from exc

    img_bgr = numpy_rgb_to_bgr(downscale_if_needed(img_rgb))
    # NOTE(review): inference runs synchronously inside this async handler, so it
    # blocks the event loop (including /health) while it runs. Offloading it to a
    # thread pool would need care if the shared YOLO model is not thread-safe.
    welds, boxes = detect_weld_types(img_bgr, model)

    # numpy arrays are not JSON-serialisable; the response model needs nested lists.
    boxes_list = boxes.tolist() if isinstance(boxes, np.ndarray) else boxes
    return PredictResponse(detections=welds, bounding_boxes=boxes_list)
@app.post("/ping")
async def ping():
    # Trivial POST endpoint that always acknowledges with a constant payload.
    return dict(ok=True)
@app.post("/echo")
async def echo(req: Request):
    # Report back only the request's Content-Type header ("" when absent).
    # The request body is never read here.
    content_type = req.headers.get("content-type", "")
    return {"ok": True, "content_type": content_type}
if __name__ == "__main__":
    # Pass the app object rather than the "app:app" import string. The string form
    # makes uvicorn import this file a second time under the module name `app`,
    # which re-runs module setup (loading the YOLO weights twice), and it fails
    # outright if the file is not named app.py.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|