# tracking/main.py
# Uploaded by shivamsshhiivvaamm ("Upload 3 files", commit 28cfaab, verified).
import cv2
import numpy as np
import onnxruntime as ort
import shutil
import os
import uuid
import base64
import time
import json
from fastapi import FastAPI, UploadFile, File, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
# Import the tracker (YOLOX-based)
from bytetrack_yolox import ByteTrackWrapper
# ---------------- CONFIGURATION ---------------- #
# Path to the ONNX detector weights, relative to the working directory.
YOLO_MODEL_PATH = "best.onnx"

app = FastAPI()

# Wide-open CORS: every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Annotated images written by the /detect handler are served from here.
# The folder is created before StaticFiles(directory=...) is built
# (presumably because StaticFiles validates the directory up front).
os.makedirs("static", exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")
# ---------------- YOLO MODEL ---------------- #
class YOLO:
    """ONNX Runtime wrapper for a YOLO-style detector, plus drawing helpers.

    ``postprocess`` indexes the first model output as
    ``output[0][0]`` with shape ``(4 + num_classes, num_predictions)``: rows
    0-3 are centre-x, centre-y, width, height in letterboxed input pixels and
    the remaining rows are per-class scores (layout implied by the indexing;
    confirm against the exported model).
    """

    def __init__(self, model_path):
        """Load the ONNX model and set default thresholds and colours.

        Args:
            model_path: Path passed to ``onnxruntime.InferenceSession``.
        """
        self.session = ort.InferenceSession(model_path)
        model_input = self.session.get_inputs()[0]
        self.input_name = model_input.name
        # Assumes a static NCHW input; with dynamic axes these would be
        # symbolic names instead of ints (TODO confirm for the exported model).
        self.h, self.w = model_input.shape[2:]
        self.conf = 0.50  # minimum class score kept (also the NMS score threshold)
        self.iou = 0.45  # NMS IoU threshold
        self.classes = [
            "Zebra", "Lion", "Leopard", "Cheetah", "Tiger", "Bear", "Butterfly",
            "Canary", "Crocodile", "Bull", "Camel", "Centipede", "Caterpillar",
            "Duck", "Squirrel", "Spider", "Ladybug", "Elephant", "Horse", "Fox",
            "Tortoise", "Frog", "Kangaroo", "Deer", "Eagle", "Monkey", "Snake",
            "Owl", "Swan", "Goat", "Rabbit", "Giraffe", "Goose", "PolarBear",
            "Raven", "Hippopotamus", "BrownBear", "Rhinoceros", "Woodpecker",
            "Sheep", "Magpie", "Ostrich", "Jaguar", "Hedgehog", "Turkey",
            "Raccoon", "Worm", "Harbor", "Panda", "RedPanda", "Otter", "Lynx",
            "Scorpion", "Koala"
        ]
        # A private seeded generator keeps the palette reproducible without
        # reseeding NumPy's process-global RNG (np.random.seed would). A fresh
        # RandomState(42) produces the same values as seed(42) + randint.
        rng = np.random.RandomState(42)
        # Palette indexed by track ID (or class ID) modulo its length.
        self.colors = rng.randint(0, 255, size=(200, 3)).tolist()

    def preprocess(self, img):
        """Letterbox a BGR image into the model's NCHW float32 input tensor.

        The image is resized with its aspect ratio preserved, pasted into the
        top-left corner of a gray (114) canvas of size ``(self.h, self.w)``,
        converted BGR->RGB and scaled to [0, 1].

        Args:
            img: 3-channel BGR image array of shape (H, W, 3).

        Returns:
            ``(tensor, scale)``. ``tensor`` has shape ``(1, 3, self.h, self.w)``.
            ``scale`` is the resize factor; divide model coordinates by it to
            map them back onto ``img``.
        """
        h0, w0 = img.shape[:2]
        scale = min(self.w / w0, self.h / h0)
        nw, nh = int(w0 * scale), int(h0 * scale)
        resized = cv2.resize(img, (nw, nh))
        canvas = np.full((self.h, self.w, 3), 114, dtype=np.uint8)
        canvas[:nh, :nw] = resized
        rgb = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
        tensor = rgb.transpose(2, 0, 1).astype(np.float32) / 255.0
        return np.expand_dims(tensor, 0), scale

    def _decode_candidates(self, output, scale):
        """Return ``(boxes, scores, class_ids)`` for predictions scoring >= ``self.conf``.

        Boxes are ``[x1, y1, x2, y2]`` float lists in original-image pixels.
        Scores are floats, and class IDs are ints indexing ``self.classes``.
        """
        preds = np.asarray(output[0][0]).T  # (num_predictions, 4 + num_classes)
        cls_scores = preds[:, 4:]
        class_ids = cls_scores.argmax(axis=1)
        best = cls_scores[np.arange(len(preds)), class_ids]
        keep = best >= self.conf
        cx, cy, bw, bh = preds[keep, :4].T
        xyxy = np.stack(
            [cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2], axis=1
        ) / scale
        return xyxy.tolist(), best[keep].tolist(), class_ids[keep].tolist()

    def postprocess(self, output, scale):
        """Decode raw model output into NMS-filtered detection dicts.

        Args:
            output: Result of ``session.run``; only ``output[0][0]`` is read.
            scale: The factor returned by ``preprocess``.

        Returns:
            A list of dicts with keys ``class`` (name), ``confidence`` (float),
            ``box`` (``[x1, y1, x2, y2]`` in original-image pixels) and
            ``id`` (class index).
        """
        boxes, scores, ids = self._decode_candidates(output, scale)
        if not boxes:
            return []
        # cv2.dnn.NMSBoxes expects rectangles as [x, y, width, height]. Passing
        # corner coordinates would make it compute IoU on the wrong rectangles.
        # NMS is run once over all boxes, i.e. class-agnostically.
        nms_boxes = [[x1, y1, x2 - x1, y2 - y1] for x1, y1, x2, y2 in boxes]
        kept = cv2.dnn.NMSBoxes(nms_boxes, scores, self.conf, self.iou)
        # Depending on the OpenCV version the indices come back as an (N, 1)
        # array, a flat array or an empty tuple; normalise to a flat array.
        return [
            {
                "class": self.classes[ids[i]],
                "confidence": scores[i],
                "box": boxes[i],
                "id": ids[i],
            }
            for i in np.asarray(kept, dtype=int).reshape(-1)
        ]

    def draw(self, img, detections):
        """Draw each detection's box and label onto ``img`` and return it.

        Detections carrying a non-None ``track_id`` are coloured by track ID
        and labelled ``"<class> #<track_id>"``. Others are coloured by class
        ID and labelled with their confidence.
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        for det in detections:
            x1, y1, x2, y2 = map(int, det["box"])
            track_id = det.get("track_id")
            if track_id is not None:
                # Same track -> same colour across frames.
                color_idx = int(track_id) % len(self.colors)
                label = f"{det['class']} #{track_id}"
            else:
                color_idx = int(det["id"]) % len(self.colors)
                label = f"{det['class']} ({det['confidence']:.2f})"
            color = tuple(int(c) for c in self.colors[color_idx])
            cv2.rectangle(img, (x1, y1), (x2, y2), color, 3)
            (text_w, _), _ = cv2.getTextSize(label, font, 0.6, 2)
            # The label sits in a 25 px band above the box. Push it down for
            # boxes touching the top edge so it does not fall off-screen.
            label_bottom = max(y1, 25)
            cv2.rectangle(
                img, (x1, label_bottom - 25), (x1 + text_w, label_bottom), color, -1
            )
            cv2.putText(
                img, label, (x1, label_bottom - 8), font, 0.6, (255, 255, 255), 2
            )
        return img
# Initialize YOLO and Tracker
# Module-level singletons used by the /detect and /detect-frame handlers.
yolo = YOLO(YOLO_MODEL_PATH)
tracker = ByteTrackWrapper(
    fps=30,
    track_thresh=0.5,
    match_thresh=0.8,
)
# ---------------- ROUTES ---------------- #
@app.post("/detect", response_class=HTMLResponse)
async def detect_image(file: UploadFile = File(...)):
    """Run detection and tracking on an uploaded image and return an HTML page.

    The annotated image is written to ``static/output_<uuid>.jpg`` and embedded
    in the response. An unreadable upload yields a 400 HTML error page.
    """
    start_t = time.perf_counter()
    # Use a server-generated temp name: the client-supplied filename may
    # contain path separators ("../") and concurrent uploads sharing a name
    # would clobber each other. cv2.imread detects the format from the file
    # contents, so no extension is needed.
    temp = f"temp_{uuid.uuid4().hex}"
    try:
        with open(temp, "wb") as f:
            shutil.copyfileobj(file.file, f)
        img = cv2.imread(temp)
    finally:
        # Clean up on every path, including unreadable uploads.
        if os.path.exists(temp):
            os.remove(temp)
    if img is None:
        return HTMLResponse("<h2>Error reading image</h2>", status_code=400)

    # 1. Inference
    # NOTE(review): inference runs synchronously inside an async handler and
    # blocks the event loop; left as-is since the shared tracker is not known
    # to be thread-safe.
    tensor, scale = yolo.preprocess(img)
    output = yolo.session.run(None, {yolo.input_name: tensor})
    detections = yolo.postprocess(output, scale)

    # 2. Tracking
    # Even on a static upload, we run the tracker to assign IDs.
    # NOTE(review): this is the same global tracker /detect-frame uses, so
    # still images are mixed into the webcam stream's track state.
    tracker.update(detections)

    # 3. Draw and save
    # NOTE(review): output files accumulate in static/ and are never deleted.
    img = yolo.draw(img, detections)
    name = f"output_{uuid.uuid4().hex}.jpg"
    cv2.imwrite(f"static/{name}", img)

    process_ms = (time.perf_counter() - start_t) * 1000
    return f"""
    <h2>✅ Detection Result</h2>
    <p>⏱️ Processed in {process_ms:.2f}ms</p>
    <div style="margin-bottom: 20px;">
        <img src="/static/{name}" width="800" style="border-radius: 10px; border: 2px solid #333;"/>
    </div>
    <a href="/">⬅ Upload Another</a>
    """
@app.post("/detect-frame")
async def detect_frame(request: Request):
    """Detect and track objects in one base64-encoded frame.

    Expects a JSON object ``{"image": "<data URL or bare base64 payload>"}``.
    Returns JSON with the annotated frame as a JPEG data URL, the detection
    dicts, and the server-side latency as a string in milliseconds. Malformed
    requests get a 400 with an ``error`` message.
    """
    start_t = time.perf_counter()
    try:
        data = await request.json()
    except ValueError:
        # Malformed JSON / bad encoding (JSONDecodeError subclasses ValueError).
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
    img_data = data.get("image") if isinstance(data, dict) else None
    if not img_data:
        return JSONResponse({"error": "No image provided"}, status_code=400)

    # Decode Image
    try:
        # Accept both 'data:image/jpeg;base64,<payload>' and a bare payload.
        img_bytes = base64.b64decode(img_data.split(",", 1)[-1])
        nparr = np.frombuffer(img_bytes, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    except Exception as e:
        return JSONResponse({"error": f"Invalid image data: {str(e)}"}, status_code=400)
    if img is None:
        # imdecode reports undecodable bytes by returning None, not by raising.
        return JSONResponse(
            {"error": "Invalid image data: could not decode image"}, status_code=400
        )

    # 1. YOLO Inference
    tensor, scale = yolo.preprocess(img)
    output = yolo.session.run(None, {yolo.input_name: tensor})
    detections = yolo.postprocess(output, scale)

    # 2. Update Tracker
    # Presumably annotates 'detections' in place with a 'track_id' key, which
    # draw() reads (tracker implementation is not visible here).
    tracker.update(detections)

    # 3. Draw, then encode back to a base64 JPEG
    img = yolo.draw(img, detections)
    _, buffer = cv2.imencode(".jpg", img)
    img_base64 = base64.b64encode(buffer).decode("utf-8")

    latency_ms = (time.perf_counter() - start_t) * 1000
    return JSONResponse({
        "image": f"data:image/jpeg;base64,{img_base64}",
        "detections": detections,
        "latency_ms": f"{latency_ms:.1f}",
    })
@app.get("/", response_class=HTMLResponse)
def webcam_page():
    """Serve ``webcam.html`` (resolved against the working directory).

    If the file is missing, an HTML error page is returned with status 500:
    the page is part of the deployment, so its absence is a server-side error.
    """
    # EAFP: open directly instead of checking os.path.exists first, which
    # avoids a race between the check and the open.
    try:
        with open("webcam.html", "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return HTMLResponse(
            "<h1>Error: webcam.html not found. Please create it.</h1>",
            status_code=500,
        )
if __name__ == "__main__":
    # uvicorn is imported here, so it is only needed when this file is run
    # directly. The server listens on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)