devsdevline commited on
Commit
86a5c6a
·
verified ·
1 Parent(s): 8b0eaf6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -45
app.py CHANGED
@@ -1,54 +1,47 @@
1
  from fastapi import FastAPI, UploadFile, File
2
  from fastapi.responses import JSONResponse
3
- from fastapi.responses import StreamingResponse
4
  from ultralytics import YOLO
5
- from PIL import Image
6
- import io
7
 
8
- app = FastAPI(title="YOLO API - Football Model")
9
 
10
- # Load model
11
- try:
12
- model = YOLO("best.pt")
13
- print("✅ Loaded best.pt")
14
- except:
15
- model = YOLO("last.pt")
16
- print("⚠️ Loaded last.pt")
17
 
18
- @app.get("/")
19
- def home():
20
- return {"message": "YOLO Football Model API is running!"}
21
 
22
 
23
- @app.post("/predict")
24
- async def predict(file: UploadFile = File(...)):
25
- contents = await file.read()
26
- image = Image.open(io.BytesIO(contents)).convert("RGB")
27
-
28
- # Run inference
29
- results = model(image)
30
-
31
- # Get JSON detections
32
  detections = []
33
- for box in results[0].boxes:
34
- detections.append({
35
- "class": int(box.cls),
36
- "confidence": float(box.conf),
37
- "bbox": box.xyxy[0].tolist()
38
- })
39
-
40
- # Save image with bounding boxes drawn by YOLO
41
- annotated_image = results[0].plot() # numpy array (BGR)
42
- annotated_image = Image.fromarray(annotated_image[..., ::-1]) # convert BGR→RGB
43
-
44
- # Convert to bytes for response
45
- img_bytes = io.BytesIO()
46
- annotated_image.save(img_bytes, format="JPEG")
47
- img_bytes.seek(0)
48
-
49
- # Return both JSON and image
50
- return StreamingResponse(
51
- img_bytes,
52
- media_type="image/jpeg",
53
- headers={"detections": str(detections)}
54
- )
 
 
 
 
 
 
 
1
  from fastapi import FastAPI, UploadFile, File
2
  from fastapi.responses import JSONResponse
 
3
  from ultralytics import YOLO
4
+ import os
5
+ import shutil
6
 
7
+ app = FastAPI()
8
 
9
+ # Load both YOLO models
10
+ FIELD_MODEL_PATH = "models/field/best.pt"
11
+ PLAYER_MODEL_PATH = "models/player/last.pt"
 
 
 
 
12
 
13
+ field_model = YOLO(FIELD_MODEL_PATH)
14
+ player_model = YOLO(PLAYER_MODEL_PATH)
 
15
 
16
 
17
+ def run_detection(model, file_path):
18
+ results = model(file_path)
 
 
 
 
 
 
 
19
  detections = []
20
+ for r in results:
21
+ for box in r.boxes:
22
+ detections.append({
23
+ "class": int(box.cls),
24
+ "confidence": float(box.conf),
25
+ "bbox": box.xyxy[0].tolist()
26
+ })
27
+ return detections
28
+
29
+
30
+ @app.post("/predict/field")
31
+ async def predict_field(file: UploadFile = File(...)):
32
+ temp_path = f"/tmp/{file.filename}"
33
+ with open(temp_path, "wb") as buffer:
34
+ shutil.copyfileobj(file.file, buffer)
35
+ detections = run_detection(field_model, temp_path)
36
+ os.remove(temp_path)
37
+ return JSONResponse({"model": "field", "detections": detections})
38
+
39
+
40
+ @app.post("/predict/player")
41
+ async def predict_player(file: UploadFile = File(...)):
42
+ temp_path = f"/tmp/{file.filename}"
43
+ with open(temp_path, "wb") as buffer:
44
+ shutil.copyfileobj(file.file, buffer)
45
+ detections = run_detection(player_model, temp_path)
46
+ os.remove(temp_path)
47
+ return JSONResponse({"model": "player", "detections": detections})