# FastAPI app: YOLO (ONNX Runtime) object detection with ByteTrack ID tracking.
import base64
import json
import os
import shutil
import time
import uuid

import cv2
import numpy as np
import onnxruntime as ort
from fastapi import FastAPI, UploadFile, File, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles

# Import the tracker (YOLOX-based)
from bytetrack_yolox import ByteTrackWrapper
# ---------------- CONFIGURATION ---------------- #
# Path to the exported YOLO ONNX weights; loaded once at import time below.
YOLO_MODEL_PATH = "best.onnx"

app = FastAPI()

# NOTE(review): wildcard CORS is fine for a demo, but allow_origins should be
# restricted to known frontends before any production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Annotated result images are written into ./static and served back by URL.
os.makedirs("static", exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")
# ---------------- YOLO MODEL ---------------- #
class YOLO:
    """Minimal ONNX Runtime wrapper around a YOLO detector.

    Responsibilities:
      * letterbox preprocessing of BGR frames into the model's input tensor,
      * decoding the raw model output (conf filter + NMS) into detection dicts,
      * drawing (optionally tracked) detections back onto the frame.
    """

    def __init__(self, model_path):
        """Load the ONNX session and cache input metadata / thresholds.

        model_path: filesystem path to the exported .onnx weights.
        """
        self.session = ort.InferenceSession(model_path)
        self.input_name = self.session.get_inputs()[0].name
        # Input is assumed NCHW with a static spatial size — TODO confirm
        # the export does not use dynamic axes (shape entries would be str).
        self.h, self.w = self.session.get_inputs()[0].shape[2:]
        self.conf = 0.50  # minimum class confidence to keep a prediction
        self.iou = 0.45   # NMS IoU threshold
        self.classes = [
            "Zebra", "Lion", "Leopard", "Cheetah", "Tiger", "Bear", "Butterfly",
            "Canary", "Crocodile", "Bull", "Camel", "Centipede", "Caterpillar",
            "Duck", "Squirrel", "Spider", "Ladybug", "Elephant", "Horse", "Fox",
            "Tortoise", "Frog", "Kangaroo", "Deer", "Eagle", "Monkey", "Snake",
            "Owl", "Swan", "Goat", "Rabbit", "Giraffe", "Goose", "PolarBear",
            "Raven", "Hippopotamus", "BrownBear", "Rhinoceros", "Woodpecker",
            "Sheep", "Magpie", "Ostrich", "Jaguar", "Hedgehog", "Turkey",
            "Raccoon", "Worm", "Harbor", "Panda", "RedPanda", "Otter", "Lynx",
            "Scorpion", "Koala"
        ]
        # Fixed seed so the same track/class IDs get the same colors run-to-run.
        np.random.seed(42)
        # Generate a large palette of random colors for Tracks
        self.colors = np.random.randint(0, 255, size=(200, 3)).tolist()

    def preprocess(self, img):
        """Letterbox *img* (BGR, HxWx3) into the model input.

        Returns (tensor, scale): tensor is 1x3xHxW float32 in [0, 1] (RGB),
        scale maps original-image coordinates onto the letterboxed canvas.
        """
        h0, w0 = img.shape[:2]
        scale = min(self.w / w0, self.h / h0)
        nw, nh = int(w0 * scale), int(h0 * scale)
        resized = cv2.resize(img, (nw, nh))
        # Gray (114) padding, image anchored top-left: postprocess can then
        # recover original coordinates by dividing by `scale` with no offset.
        canvas = np.full((self.h, self.w, 3), 114, dtype=np.uint8)
        canvas[:nh, :nw] = resized
        img = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
        img = img.transpose(2, 0, 1).astype(np.float32) / 255.0
        img = np.expand_dims(img, 0)
        return img, scale

    def postprocess(self, output, scale):
        """Decode raw model output into a list of detection dicts.

        output: session.run result; output[0] is assumed to be shaped
                (1, 4 + num_classes, num_preds) with rows (cx, cy, w, h,
                class scores...) — TODO confirm against the export.
        scale:  the letterbox scale returned by preprocess().
        Returns: [{"class", "confidence", "box": [x1,y1,x2,y2], "id"}, ...]
                 in original-image coordinates.
        """
        preds = output[0][0].transpose()
        boxes, scores, ids = [], [], []
        for p in preds:
            x, y, w, h = p[:4]
            cls_scores = p[4:]
            cid = int(np.argmax(cls_scores))
            score = cls_scores[cid]
            if score >= self.conf:
                # Center-size -> corners, undoing the letterbox scale.
                x1 = (x - w / 2) / scale
                y1 = (y - h / 2) / scale
                x2 = (x + w / 2) / scale
                y2 = (y + h / 2) / scale
                boxes.append([float(x1), float(y1), float(x2), float(y2)])
                scores.append(float(score))
                ids.append(cid)
        # Guard: some OpenCV builds raise when NMSBoxes gets empty input.
        if not boxes:
            return []
        # BUGFIX: cv2.dnn.NMSBoxes expects (x, y, width, height) rectangles,
        # not (x1, y1, x2, y2) corners — feeding corners corrupts the IoU
        # computation and therefore the suppression result.
        nms_boxes = [[b[0], b[1], b[2] - b[0], b[3] - b[1]] for b in boxes]
        idxs = cv2.dnn.NMSBoxes(nms_boxes, scores, self.conf, self.iou)
        results = []
        if len(idxs) > 0:
            # np.array(...) normalizes the tuple/ndarray return across
            # OpenCV versions before flattening.
            for i in np.array(idxs).flatten():
                results.append({
                    "class": self.classes[ids[i]],
                    "confidence": scores[i],
                    "box": boxes[i],
                    "id": ids[i]
                })
        return results

    def draw(self, img, detections):
        """Draw boxes and labels in place on *img* (BGR) and return it."""
        for d in detections:
            x1, y1, x2, y2 = map(int, d["box"])
            # Use Track ID for color if available, otherwise Class ID
            track_id = d.get('track_id')
            if track_id is not None:
                # Color based on Track ID (consistent color for same object)
                color_idx = int(track_id) % len(self.colors)
                label = f"{d['class']} #{track_id}"
            else:
                # Fallback to Class ID
                color_idx = int(d["id"]) % len(self.colors)
                label = f"{d['class']} ({d['confidence']:.2f})"
            color = self.colors[color_idx]
            color = (int(color[0]), int(color[1]), int(color[2]))
            cv2.rectangle(img, (x1, y1), (x2, y2), color, 3)
            # Label background
            (w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
            cv2.rectangle(img, (x1, y1 - 25), (x1 + w, y1), color, -1)
            cv2.putText(img, label, (x1, y1 - 8), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
        return img
# Initialize YOLO and Tracker once at import time so the ONNX session and
# track state are shared across all requests (single-process assumption).
yolo = YOLO(YOLO_MODEL_PATH)
# ByteTrack wrapper: fps presumably sizes the track buffer; track_thresh /
# match_thresh are association thresholds — semantics live in bytetrack_yolox.
tracker = ByteTrackWrapper(fps=30, track_thresh=0.5, match_thresh=0.8)
# ---------------- ROUTES ---------------- #
# NOTE(review): the route decorators were missing from the pasted source even
# though this section is labeled ROUTES — without them FastAPI registers no
# endpoints. Paths are assumed; confirm against the frontend form/fetch URLs.
@app.post("/detect-image", response_class=HTMLResponse)
async def detect_image(file: UploadFile = File(...)):
    """Run detection on an uploaded image and return an HTML result page.

    Saves the upload to a temp file, runs YOLO + tracker, writes the
    annotated image under ./static, and returns HTML embedding it.
    """
    start_t = time.time()
    # Use a random temp name instead of the client-supplied filename:
    # file.filename is attacker-controlled (path traversal / collisions).
    temp = f"temp_{uuid.uuid4().hex}"
    try:
        with open(temp, "wb") as f:
            shutil.copyfileobj(file.file, f)
        # cv2.imread sniffs the codec from file content; None on failure.
        img = cv2.imread(temp)
        if img is None:
            return "<h2>Error reading image</h2>"
        # 1. Inference
        tensor, scale = yolo.preprocess(img)
        output = yolo.session.run(None, {yolo.input_name: tensor})
        detections = yolo.postprocess(output, scale)
        # 2. Tracking
        # Even on a static upload, we run the tracker to assign IDs.
        tracker.update(detections)
        # 3. Draw
        img = yolo.draw(img, detections)
        name = f"output_{uuid.uuid4().hex}.jpg"
        path = f"static/{name}"
        cv2.imwrite(path, img)
    finally:
        # BUGFIX: the original only removed the temp file on the success
        # path, leaking it whenever imread failed.
        if os.path.exists(temp):
            os.remove(temp)
    process_ms = (time.time() - start_t) * 1000
    return f"""
    <h2>✅ Detection Result</h2>
    <p>⏱️ Processed in {process_ms:.2f}ms</p>
    <div style="margin-bottom: 20px;">
    <img src="/static/{name}" width="800" style="border-radius: 10px; border: 2px solid #333;"/>
    </div>
    <a href="/">⬅ Upload Another</a>
    """
# NOTE(review): decorator missing in the pasted source; path assumed —
# confirm against the webcam page's fetch() URL.
@app.post("/detect-frame")
async def detect_frame(request: Request):
    """Detect + track objects in a single base64-encoded webcam frame.

    Expects JSON {"image": "data:image/jpeg;base64,..."}; returns JSON with
    the annotated frame (base64), the detection list, and latency in ms.
    """
    start_t = time.time()
    data = await request.json()
    img_data = data.get("image")
    if not img_data:
        return JSONResponse({"error": "No image provided"}, status_code=400)
    # Decode Image
    try:
        # Splits 'data:image/jpeg;base64,...'
        img_bytes = base64.b64decode(img_data.split(',')[1])
        nparr = np.frombuffer(img_bytes, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    except Exception as e:
        return JSONResponse({"error": f"Invalid image data: {str(e)}"}, status_code=400)
    # BUGFIX: cv2.imdecode returns None on undecodable bytes WITHOUT raising,
    # so the except above never fires and preprocess() would crash on None.
    if img is None:
        return JSONResponse({"error": "Invalid image data: could not decode"}, status_code=400)
    # 1. YOLO Inference
    tensor, scale = yolo.preprocess(img)
    output = yolo.session.run(None, {yolo.input_name: tensor})
    detections = yolo.postprocess(output, scale)
    # 2. Update Tracker
    # The tracker modifies 'detections' in-place, adding 'track_id' to objects
    tracker.update(detections)
    # 3. Draw
    img = yolo.draw(img, detections)
    # Encode back to base64
    _, buffer = cv2.imencode('.jpg', img)
    img_base64 = base64.b64encode(buffer).decode('utf-8')
    end_t = time.time()
    latency_ms = (end_t - start_t) * 1000
    return JSONResponse({
        "image": f"data:image/jpeg;base64,{img_base64}",
        "detections": detections,
        "latency_ms": f"{latency_ms:.1f}"
    })
# NOTE(review): decorator missing in the pasted source; "/" assumed because
# detect_image's result page links back to "/" ("Upload Another").
@app.get("/", response_class=HTMLResponse)
def webcam_page():
    """Serve the webcam demo page from webcam.html, or an error placeholder."""
    if os.path.exists("webcam.html"):
        with open("webcam.html", "r", encoding="utf-8") as f:
            return f.read()
    return "<h1>Error: webcam.html not found. Please create it.</h1>"
if __name__ == "__main__":
    # Local development entry point; 0.0.0.0 binds all interfaces so the
    # demo is reachable from other machines on the network.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)