devsdevline commited on
Commit
992bcd9
·
verified ·
1 Parent(s): 9cd40e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -36
app.py CHANGED
@@ -1,52 +1,61 @@
1
- from fastapi import FastAPI, UploadFile, File
2
  from fastapi.responses import JSONResponse
3
  from ultralytics import YOLO
4
- import os
5
- import shutil
 
 
 
6
 
7
  app = FastAPI()
8
 
9
- # Load both YOLO models
10
- FIELD_MODEL_PATH = "model/field/best.pt"
11
- PLAYER_MODEL_PATH = "model/player/last.pt"
12
 
13
- field_model = YOLO(FIELD_MODEL_PATH)
14
- player_model = YOLO(PLAYER_MODEL_PATH)
15
 
 
 
 
16
 
17
- def run_detection(model, file_path):
18
- results = model(file_path)
19
- detections = []
20
- for r in results:
21
- for box in r.boxes:
22
- detections.append({
23
- "class": int(box.cls),
24
- "confidence": float(box.conf),
25
- "bbox": box.xyxy[0].tolist()
26
- })
27
- return detections
28
 
 
 
 
 
29
 
30
- @app.post("/predict/field")
31
- async def predict_field(file: UploadFile = File(...)):
32
- temp_path = f"/tmp/{file.filename}"
33
- with open(temp_path, "wb") as buffer:
34
- shutil.copyfileobj(file.file, buffer)
35
- detections = run_detection(field_model, temp_path)
36
- os.remove(temp_path)
37
- return JSONResponse({"model": "field", "detections": detections})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
 
40
  @app.post("/predict/player")
41
  async def predict_player(file: UploadFile = File(...)):
42
- temp_path = f"/tmp/{file.filename}"
43
- with open(temp_path, "wb") as buffer:
44
- shutil.copyfileobj(file.file, buffer)
45
- detections = run_detection(player_model, temp_path)
46
- os.remove(temp_path)
47
- return JSONResponse({"model": "player", "detections": detections})
48
 
49
 
50
- @app.get("/")
51
- def home():
52
- return JSONResponse({"message": "Server running. Use /predict/player or /predict/field endpoints."})
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
  from fastapi.responses import JSONResponse
3
  from ultralytics import YOLO
4
+ from PIL import Image
5
+ from io import BytesIO
6
+ import base64
7
+ import cv2
8
+ import numpy as np
9
 
10
  app = FastAPI()
11
 
12
+ # Load models once at startup
13
+ player_model = YOLO("models/player/best.pt")
14
+ field_model = YOLO("models/field/best.pt")
15
 
 
 
16
 
17
+ @app.get("/")
18
+ def home():
19
+ return {"message": "Server running ✅ Use /predict/player or /predict/field"}
20
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ def process_image(file, model):
23
+ # Load uploaded image
24
+ image = Image.open(file.file).convert("RGB")
25
+ image_np = np.array(image)
26
 
27
+ # Run inference
28
+ results = model(image_np)
29
+
30
+ # Draw detections
31
+ annotated_frame = results[0].plot()
32
+
33
+ # Convert annotated image to bytes
34
+ _, buffer = cv2.imencode(".jpg", annotated_frame)
35
+ img_bytes = buffer.tobytes()
36
+
37
+ # Encode image as base64 to include in JSON response
38
+ img_base64 = base64.b64encode(img_bytes).decode("utf-8")
39
+
40
+ # Prepare JSON result
41
+ detections = []
42
+ for box in results[0].boxes:
43
+ detections.append({
44
+ "class": int(box.cls[0]),
45
+ "confidence": float(box.conf[0]),
46
+ "bbox": [float(x) for x in box.xyxy[0].tolist()]
47
+ })
48
+
49
+ return {"detections": detections, "image_base64": img_base64}
50
 
51
 
52
  @app.post("/predict/player")
53
  async def predict_player(file: UploadFile = File(...)):
54
+ result = process_image(file, player_model)
55
+ return JSONResponse(result)
 
 
 
 
56
 
57
 
58
+ @app.post("/predict/field")
59
+ async def predict_field(file: UploadFile = File(...)):
60
+ result = process_image(file, field_model)
61
+ return JSONResponse(result)