File size: 4,003 Bytes
f224d0a
9415326
f224d0a
9415326
 
 
 
efce4f7
9415326
2e46b64
9415326
311720b
9415326
f224d0a
 
 
 
9415326
 
 
 
 
 
 
8021aca
 
9415326
 
 
 
 
 
8021aca
9415326
 
 
 
 
 
 
2e46b64
9415326
 
a1864b3
9415326
 
 
 
 
 
 
 
 
 
f224d0a
 
 
 
 
 
 
 
 
8021aca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e46b64
9415326
 
 
 
 
 
 
 
 
 
f224d0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e46b64
 
 
 
f224d0a
efce4f7
 
 
 
 
 
 
 
 
 
9415326
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# import base64
import io
import os
import numpy as np
import cv2
import uvicorn
import PIL.Image as Image
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
from pydantic import BaseModel
from typing import List
from ultralytics import YOLO
from weld_tiling import detect_tiled_softnms

# Largest accepted upload, in bytes. Overridable through the MAX_UPLOAD_BYTES
# environment variable; falls back to 8 MB.
MAX_UPLOAD_BYTES = int(os.environ.get("MAX_UPLOAD_BYTES", 8 * 1024 * 1024))

# Longest image side allowed before an image is downscaled. Overridable through
# the MAX_SIDE environment variable; falls back to 2000.
MAX_SIDE = int(os.environ.get("MAX_SIDE", 2000))

# Point Ultralytics at a config directory unless the environment already names
# one (intended to silence its config-dir warning).
# NOTE(review): this runs after `from ultralytics import YOLO` at the top of the
# file; if Ultralytics resolves its config dir at import time, this assignment
# comes too late and should move above that import — confirm.
if "YOLO_CONFIG_DIR" not in os.environ:
    os.environ["YOLO_CONFIG_DIR"] = "/tmp/Ultralytics"


# Class labels, split into a "high" and a "low" group. The criterion behind the
# split is not visible in this file.
HIGH_CLASS_NAMES = [
    "Valve",
    "Butterfly Valve",
    "Flange",
    "PRV",
    "Reducer",
    "Union",
    "Weld-o-let",
    "field_sw",
]

LOW_CLASS_NAMES = [
    "shop_bw",
    "shop_sw",
    "field_bw",
    "Insulation",
]

# Every label: the high group first, then the low group.
ALL_CLASS_NAMES = [*HIGH_CLASS_NAMES, *LOW_CLASS_NAMES]

# -----------------------------
# App setup
# -----------------------------

# ASGI application object; the endpoint decorators below register on it.
app = FastAPI(
    title="YOLO Weld Type Detector API",
    version="1.0.0",
)

# Detector weights are loaded once, at import time, and shared by all requests.
# NOTE(review): the weights path is relative, so it presumably resolves against
# the process working directory — confirm for the deployment layout.
model = YOLO("best_7-15-25.pt")

# -----------------------------
# Schemas
# -----------------------------

class PredictResponse(BaseModel):
    # Response body returned by POST /predict.

    # Detection counts; /predict fills this with the class-name -> count
    # mapping produced by normalize_prediction.
    detections: dict
    # One inner list per detected box; /predict fills this from the detector's
    # 'xyxy' output (presumably [x1, y1, x2, y2] per box — confirm).
    bounding_boxes: List[List[float]]

class PredictQuery(BaseModel):
    # Request schema holding a single string field; going by its name, the
    # field is meant to carry a base64-encoded image.
    # NOTE(review): no endpoint visible in this file references this model —
    # confirm whether it is still needed.
    image_base64: str

# -----------------------------
# Utils
# -----------------------------
def pil_to_numpy_rgb(img: Image.Image) -> np.ndarray:
    """Convert a PIL image to RGB mode and return its pixels as a NumPy array."""
    rgb_image = img.convert("RGB")
    return np.array(rgb_image)

def numpy_rgb_to_bgr(img: np.ndarray) -> np.ndarray:
    """Return *img* after OpenCV's RGB-to-BGR colour conversion."""
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    return bgr

def downscale_if_needed(img_rgb: np.ndarray, max_side: "int | None" = None) -> np.ndarray:
    """Shrink an image so that its longest side does not exceed a limit.

    Args:
        img_rgb: Image array whose first two axes are height and width.
        max_side: Longest side allowed. ``None`` (the default) uses the
            module-level ``MAX_SIDE``.

    Returns:
        ``img_rgb`` itself when it already fits; otherwise a resized array
        from ``cv2.resize`` with ``INTER_AREA`` interpolation, scaled by the
        same factor on both axes (up to integer truncation).
    """
    limit = MAX_SIDE if max_side is None else max_side
    height, width = img_rgb.shape[:2]
    longest = max(height, width)
    if longest <= limit:
        return img_rgb

    scale = limit / longest
    # Clamp to 1 px: truncation would otherwise give a very elongated image
    # (e.g. 1 x 5000) a zero-sized short side, which cv2.resize rejects.
    new_w = max(1, int(width * scale))
    new_h = max(1, int(height * scale))
    return cv2.resize(img_rgb, (new_w, new_h), interpolation=cv2.INTER_AREA)

def normalize_prediction(output):
    """Count detections per class name.

    Args:
        output: Mapping with a ``'cls'`` entry (iterable of class ids) and a
            ``'names'`` entry that is indexed by those ids to get a class name.

    Returns:
        Dict mapping each class name that occurs to its number of occurrences,
        in order of first appearance.
    """
    tally = {}
    for class_id in output['cls']:
        label = output['names'][class_id]
        if label in tally:
            tally[label] += 1
        else:
            tally[label] = 1
    return tally

def detect_weld_types(image_bgr: np.ndarray, model) -> tuple:
    """Run tiled detection on an image and summarise the result.

    Args:
        image_bgr: Image array handed to ``detect_tiled_softnms`` (BGR channel
            order, going by the parameter name).
        model: Detector object passed through to ``detect_tiled_softnms``.

    Returns:
        A ``(counts, boxes)`` pair. ``counts`` maps class name to number of
        detections (see ``normalize_prediction``); ``boxes`` is the ``'xyxy'``
        entry of the detector output, returned unchanged.
    """
    # NOTE(review): softnms_method="hard" is passed together with
    # softnms_sigma; presumably sigma is ignored in hard mode — confirm in
    # weld_tiling.
    out = detect_tiled_softnms(
        model, image_bgr,
        tile_size=1024, overlap=0.23,
        per_tile_conf=0.2, per_tile_iou=0.7,
        softnms_iou=0.6, softnms_method="hard", softnms_sigma=0.5,
        final_conf=0.38, device=None, imgsz=1280
    )
    counts = normalize_prediction(out)
    return counts, out['xyxy']

# -----------------------------
# Endpoints
# -----------------------------
@app.get("/health")
def health():
    # Liveness probe: takes no input and always answers with a fixed payload.
    payload = {"status": "ok"}
    return payload


@app.post("/predict", response_model=PredictResponse)
async def predict_multipart(file: UploadFile = File(default=None)):
    """Detect classes in an uploaded image.

    Expects multipart form-data with the image in the ``file`` field and
    returns per-class detection counts plus the detected bounding boxes.

    Raises:
        HTTPException: 400 when no file is supplied or the bytes cannot be
            decoded as an image; 413 when the upload exceeds
            ``MAX_UPLOAD_BYTES``.
    """
    if file is None:
        raise HTTPException(status_code=400, detail="Provide a file via multipart form-data with field name 'file'.")

    # The limit is enforced on the bytes actually read, not on a
    # client-supplied Content-Length header.
    raw = await file.read()
    if len(raw) > MAX_UPLOAD_BYTES:
        raise HTTPException(
            status_code=413,
            detail=f"File too large ({len(raw)/1024/1024:.2f} MB). Max is {MAX_UPLOAD_BYTES/1024/1024:.2f} MB."
        )

    try:
        img = Image.open(io.BytesIO(raw))
        # Image.open is lazy and does not decode pixel data; load() forces the
        # full decode so truncated or corrupt files are rejected here with a
        # 400 instead of failing later as an unhandled 500.
        img.load()
    except Exception as exc:
        # Deliberately broad at this request boundary: undecodable input can
        # surface through more than one exception type.
        raise HTTPException(status_code=400, detail="Invalid image.") from exc

    img_rgb = downscale_if_needed(pil_to_numpy_rgb(img))
    img_bgr = numpy_rgb_to_bgr(img_rgb)
    # NOTE(review): detection runs on the (possibly downscaled) image and the
    # boxes are returned without rescaling, so they presumably refer to the
    # downscaled image's coordinates — confirm this is what clients expect.
    welds, boxes = detect_weld_types(img_bgr, model)
    # Convert numpy array to list of lists for JSON serialization
    boxes_list = boxes.tolist() if isinstance(boxes, np.ndarray) else boxes
    return PredictResponse(detections=welds, bounding_boxes=boxes_list)

@app.post("/ping")
async def ping():
    # Minimal POST endpoint: takes no input and returns a fixed acknowledgement.
    ack = {"ok": True}
    return ack

@app.post("/echo")
async def echo(req: Request):
    # Reports the request's Content-Type header (empty string when absent).
    # The request body is never read here, so large payloads are not parsed.
    content_type = req.headers.get("content-type", "")
    reply = {"ok": True, "content_type": content_type}
    return reply


if __name__ == "__main__":
    # Script entry point: serve on all interfaces, port 7860.
    # NOTE(review): the "app:app" import string presumes this module is
    # importable as `app` — confirm against the actual filename.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
    )