devsdevline commited on
Commit
ff32071
·
verified ·
1 Parent(s): ecaedad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -31
app.py CHANGED
@@ -1,31 +1,37 @@
1
- from fastapi import FastAPI, File, UploadFile
2
- from ultralytics import YOLO
3
- from huggingface_hub import hf_hub_download
4
- import io
5
- from PIL import Image
6
-
7
- app = FastAPI()
8
-
9
- # Download your YOLO model from your model repo
10
- model_path = hf_hub_download("devsdevline/soccer-homography-yolo", "best.pt")
11
- model = YOLO(model_path)
12
-
13
- @app.get("/")
14
- def root():
15
- return {"status": "API running", "model": "soccer-homography-yolo"}
16
-
17
- @app.post("/predict")
18
- async def predict(file: UploadFile = File(...)):
19
- # Read uploaded image
20
- image = Image.open(io.BytesIO(await file.read())).convert("RGB")
21
- results = model(image)
22
-
23
- detections = []
24
- for box in results[0].boxes:
25
- detections.append({
26
- "class_id": int(box.cls[0]),
27
- "confidence": float(box.conf[0]),
28
- "bbox": box.xyxy[0].tolist()
29
- })
30
-
31
- return {"detections": detections}
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ from ultralytics import YOLO
3
+ from PIL import Image
4
+ import io
5
+ import torch
6
+
7
+ app = FastAPI(title="YOLO API - Football Model")
8
+
9
+ # Load model from local file
10
+ try:
11
+ model = YOLO("best.pt")
12
+ print("✅ Loaded best.pt model")
13
+ except Exception as e:
14
+ print("⚠️ best.pt not found, trying last.pt:", e)
15
+ model = YOLO("last.pt")
16
+
17
+ @app.get("/")
18
+ def home():
19
+ return {"message": "YOLO Football Model API is running!"}
20
+
21
+ @app.post("/predict")
22
+ async def predict(file: UploadFile = File(...)):
23
+ contents = await file.read()
24
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
25
+
26
+ # Run inference
27
+ results = model(image)
28
+
29
+ detections = []
30
+ for box in results[0].boxes:
31
+ detections.append({
32
+ "class": int(box.cls),
33
+ "confidence": float(box.conf),
34
+ "bbox": box.xyxy[0].tolist()
35
+ })
36
+
37
+ return {"detections": detections}