csmith715 commited on
Commit
2e46b64
·
1 Parent(s): c17d855

Adding bounding box data to output

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -7,6 +7,7 @@ import uvicorn
7
  import PIL.Image as Image
8
  from fastapi import FastAPI, UploadFile, File, HTTPException, Request
9
  from pydantic import BaseModel
 
10
  from ultralytics import YOLO
11
  from weld_tiling import detect_tiled_softnms
12
 
@@ -37,6 +38,7 @@ model = YOLO("best_7-15-25.pt")
37
 
38
  class PredictResponse(BaseModel):
39
  detections: dict
 
40
 
41
  class PredictQuery(BaseModel):
42
  image_base64: str
@@ -75,7 +77,7 @@ def detect_weld_types(image_bgr: np.ndarray, model) -> dict:
75
  final_conf=0.38, device=None, imgsz=1280
76
  )
77
  counts = normalize_prediction(out)
78
- return counts
79
 
80
  # -----------------------------
81
  # Endpoints
@@ -105,8 +107,10 @@ async def predict_multipart(file: UploadFile = File(default=None)):
105
 
106
  img_rgb = downscale_if_needed(pil_to_numpy_rgb(img))
107
  img_bgr = numpy_rgb_to_bgr(img_rgb)
108
- welds = detect_weld_types(img_bgr, model)
109
- return PredictResponse(detections=welds)
 
 
110
 
111
  @app.post("/ping")
112
  async def ping():
 
7
  import PIL.Image as Image
8
  from fastapi import FastAPI, UploadFile, File, HTTPException, Request
9
  from pydantic import BaseModel
10
+ from typing import List
11
  from ultralytics import YOLO
12
  from weld_tiling import detect_tiled_softnms
13
 
 
38
 
39
  class PredictResponse(BaseModel):
40
  detections: dict
41
+ bounding_boxes: List[List[float]]
42
 
43
  class PredictQuery(BaseModel):
44
  image_base64: str
 
77
  final_conf=0.38, device=None, imgsz=1280
78
  )
79
  counts = normalize_prediction(out)
80
+ return counts, out['xyxy']
81
 
82
  # -----------------------------
83
  # Endpoints
 
107
 
108
  img_rgb = downscale_if_needed(pil_to_numpy_rgb(img))
109
  img_bgr = numpy_rgb_to_bgr(img_rgb)
110
+ welds, boxes = detect_weld_types(img_bgr, model)
111
+ # Convert numpy array to list of lists for JSON serialization
112
+ boxes_list = boxes.tolist() if isinstance(boxes, np.ndarray) else boxes
113
+ return PredictResponse(detections=welds, bounding_boxes=boxes_list)
114
 
115
  @app.post("/ping")
116
  async def ping():