# MegaDetector / app.py
# Hugging Face Space by Logistikon — "Add token authorisation" (commit 8ba1b51)
import gradio as gr
import torch
from PIL import Image, ImageDraw
import numpy as np
import json
import base64
import io
import os
import secrets
from dotenv import load_dotenv
from megadetector.detection import run_detector
# Load configuration from a local .env file (no-op if the file is absent).
load_dotenv()
# Access token configuration.
# Set API_TOKEN in the Space's environment variables for a stable token;
# if unset, a random token is generated at startup and printed to the log,
# which is less secure and changes on every restart.
API_TOKEN = os.getenv("API_TOKEN")
if not API_TOKEN:
    # Fallback: ephemeral random token — will change if the Space restarts!
    API_TOKEN = secrets.token_hex(16)
    print(f"Generated API token: {API_TOKEN}")
    print("IMPORTANT: This token will change if the space restarts!")
    print("Set a permanent token in the Space's environment variables.")
def validate_token(token):
    """Check *token* against the configured API_TOKEN.

    Uses a constant-time comparison (secrets.compare_digest) so that the
    check does not leak token contents through response-timing differences,
    unlike a plain ``==`` comparison.

    Args:
        token: Candidate token; may be None or a non-string value.

    Returns:
        bool: True only when *token* is a string equal to API_TOKEN.
    """
    # compare_digest requires two str/bytes of the same kind; reject
    # None and other types up front instead of raising TypeError.
    if not isinstance(token, str):
        return False
    return secrets.compare_digest(token, API_TOKEN)
model = run_detector.load_detector('MDV5A')
# CVAT label set. MegaDetector reports 1-indexed category ids as strings
# ("1" animal, "2" person, "3" vehicle); the ids here mirror that scheme.
CATEGORIES = [
    {"id": 1, "name": "animal"},
    {"id": 2, "name": "person"},
    {"id": 3, "name": "vehicle"},
    # Add all categories your model supports
]


def process_predictions(outputs, image, confidence_threshold=0.5):
    """Convert MegaDetector output into CVAT-style rectangle annotations.

    Args:
        outputs: Dict from ``generate_detections_one_image()``; its
            'detections' list holds entries of the form
            {'bbox': [x, y, w, h] (normalized 0-1), 'conf': float,
            'category': str of a 1-indexed class id}.
        image: Object exposing a PIL-style ``.size == (width, height)``,
            used to scale normalized boxes to pixel coordinates.
        confidence_threshold: Detections scoring below this are dropped.

    Returns:
        list[dict]: One {'confidence', 'label', 'points', 'type'} dict per
        kept detection, with 'points' as [x1, y1, x2, y2] in pixels.
    """
    width, height = image.size
    results = []
    for det in outputs['detections']:
        score = det['conf']
        # Cheap test first: skip low-confidence hits before any box math.
        if score < confidence_threshold:
            continue
        # MegaDetector boxes are normalized [x, y, w, h]; convert to
        # absolute [x1, y1, x2, y2] pixel corners.
        x, y, w, h = det['bbox']
        x1, y1 = x * width, y * height
        x2, y2 = (x + w) * width, (y + h) * height
        # Category ids are 1-indexed strings; map onto the 0-indexed list.
        index = int(det['category']) - 1
        if 0 <= index < len(CATEGORIES):
            label = CATEGORIES[index]["name"]
        else:
            # Unknown class id: keep the detection but expose the raw id
            # instead of raising IndexError and losing the whole request.
            label = str(det['category'])
        results.append({
            "confidence": float(score),
            "label": label,
            "points": [x1, y1, x2, y2],
            "type": "rectangle",
        })
    return results
def predict(image_data, token=None):
    """Authenticated detection endpoint returning CVAT-compatible JSON.

    Args:
        image_data: A PIL.Image, a numpy array, a "data:image/..."
            base64 URL string, or anything Image.open() accepts
            (e.g. a file path or binary file object).
        token: Access token checked against the configured API token.

    Returns:
        dict: {"results": [...]} on success, {"error": message} on
        authentication failure or any processing error.
    """
    try:
        # Authenticate before doing any work on the payload.
        if token is None or not validate_token(token):
            return {"error": "Authentication failed. Invalid or missing access token."}

        # Normalise the various accepted input formats to a PIL image.
        if isinstance(image_data, Image.Image):
            image = image_data
        elif isinstance(image_data, np.ndarray):
            image = Image.fromarray(image_data)
        elif isinstance(image_data, str) and image_data.startswith("data:image"):
            # Strip the "data:image/...;base64," prefix, then decode.
            encoded = image_data.split(",")[1]
            image = Image.open(io.BytesIO(base64.b64decode(encoded)))
        else:
            # Fall back to whatever Image.open understands (path, file).
            image = Image.open(image_data)

        detections = model.generate_detections_one_image(image)
        return {"results": process_predictions(detections, image)}
    except Exception as exc:
        # API boundary: surface any failure as a JSON error payload
        # rather than a 500 with a traceback.
        return {"error": str(exc)}
# Create Gradio interface for testing
def gradio_interface(image):
    """Demo wrapper: run detection on *image* and visualise the result.

    Authenticates itself with the server's own API_TOKEN, draws each
    detected box and caption onto a copy of the input, and returns the
    annotated image together with the raw response pretty-printed as JSON.
    """
    outcome = predict(image, API_TOKEN)

    annotated = image.copy()
    canvas = ImageDraw.Draw(annotated)
    # On error responses there is no "results" key, so nothing is drawn.
    for detection in outcome.get("results", []):
        x1, y1, x2, y2 = detection["points"]
        canvas.rectangle([x1, y1, x2, y2], outline="red", width=3)
        caption = f"{detection['label']} {detection['confidence']:.2f}"
        canvas.text((x1, y1), caption, fill="red")

    return annotated, json.dumps(outcome, indent=2)
# API endpoint for CVAT: takes an image plus token, returns raw JSON.
app = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="filepath"),
        gr.Textbox(label="Access Token", type="password")
    ],
    outputs="json",
    title="Object Detection API for CVAT",
    # Plain string (was an f-string with no placeholders).
    description="Upload an image to get object detection predictions in CVAT-compatible format. Requires access token.",
    flagging_mode="never",
)

# Human-friendly testing UI; the token is supplied internally.
demo = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detection Result"),
        gr.JSON(label="JSON Output")
    ],
    title="Object Detection Demo",
    description="Test your object detection model with this interface",
    flagging_mode="never",
)

# Expose both interfaces as tabs of a single app.
combined_demo = gr.TabbedInterface(
    [app, demo],
    ["API Endpoint", "Testing Interface"]
)

if __name__ == "__main__":
    combined_demo.launch()