import gradio as gr
import torch
from PIL import Image, ImageDraw
import numpy as np
import json
import base64
import io
import os
import secrets
from dotenv import load_dotenv
from megadetector.detection import run_detector

# Load environment variables for configuration
load_dotenv()

# Access token configuration.
# You can set a fixed token in your Space's environment variables
# or generate a random one on startup (less secure).
API_TOKEN = os.getenv("API_TOKEN")
if not API_TOKEN:
    # Generate a random token if not provided - will change on restart!
    API_TOKEN = secrets.token_hex(16)
    print(f"Generated API token: {API_TOKEN}")
    print("IMPORTANT: This token will change if the space restarts!")
    print("Set a permanent token in the Space's environment variables.")


def validate_token(token):
    """Return True if *token* matches the configured API token.

    Uses secrets.compare_digest (constant-time comparison) instead of
    ``==`` so token validation does not leak timing information.
    """
    # compare_digest requires str/bytes of the same type; reject anything else.
    if not isinstance(token, str):
        return False
    return secrets.compare_digest(token, API_TOKEN)


# NOTE(review): loaded once at import time; model load is slow and requires
# the MegaDetector weights to be available in this environment.
model = run_detector.load_detector('MDV5A')

# CVAT categories - customize based on your model's classes.
# Order matters: detection category id N maps to CATEGORIES[N - 1].
CATEGORIES = [
    {"id": 1, "name": "animal"},
    {"id": 2, "name": "person"},
    {"id": 3, "name": "vehicle"},
    # Add all categories your model supports
]


def process_predictions(outputs, image, confidence_threshold=0.5):
    """Convert MegaDetector output into CVAT-compatible shape dicts.

    Args:
        outputs: dict with a 'detections' list; each detection carries a
            normalized [x, y, w, h] 'bbox', a 'conf' score, and a 1-based
            'category' id (MegaDetector convention).
        image: PIL image the detections refer to; its size is used to
            scale normalized coordinates to pixels.
        confidence_threshold: detections scoring below this are dropped.

    Returns:
        List of dicts with 'confidence', 'label', 'points'
        (absolute [x1, y1, x2, y2] pixel coords) and 'type' == 'rectangle'.
    """
    results = []
    iw, ih = image.size
    for det in outputs['detections']:
        score = det['conf']
        # Filter first so we don't scale boxes we are about to discard.
        if score < confidence_threshold:
            continue

        # Convert normalized [x, y, w, h] to absolute [x1, y1, x2, y2]
        x, y, w, h = det['bbox']
        bbox = [x * iw, y * ih, (x + w) * iw, (y + h) * ih]

        # MegaDetector categories are 1-based; CATEGORIES is 0-indexed.
        category_id = int(det['category']) - 1
        if 0 <= category_id < len(CATEGORIES):
            category_name = CATEGORIES[category_id]["name"]
        else:
            # Unknown class id: report the raw id instead of raising
            # IndexError (which previously failed the whole request).
            category_name = str(det['category'])

        results.append({
            "confidence": float(score),
            "label": category_name,
            "points": [bbox[0], bbox[1], bbox[2], bbox[3]],
            "type": "rectangle",
        })
    return results


def predict(image_data, token=None):
    """Main prediction function for API endpoint.

    Args:
        image_data: The image to be processed. Accepts a PIL image, a
            base64 data-URL string, a numpy array, or anything
            PIL.Image.open accepts (file path or file-like object).
        token: Access token for authentication.

    Returns:
        {"results": [...]} on success, {"error": "..."} on any failure
        (including authentication).
    """
    try:
        # Validate access token before doing any work.
        if token is None or not validate_token(token):
            return {"error": "Authentication failed. Invalid or missing access token."}

        # Handle various image input formats
        if isinstance(image_data, Image.Image):
            image = image_data
        elif isinstance(image_data, str) and image_data.startswith("data:image"):
            image_data = image_data.split(",")[1]
            image_bytes = base64.b64decode(image_data)
            image = Image.open(io.BytesIO(image_bytes))
        elif isinstance(image_data, np.ndarray):
            image = Image.fromarray(image_data)
        else:
            image = Image.open(image_data)

        # Process image with model
        # NOTE(review): some megadetector versions take extra args
        # (image_id, detection_threshold) here — confirm against the
        # installed version.
        outputs = model.generate_detections_one_image(image)

        # Process predictions
        results = process_predictions(outputs, image)

        # Return results in CVAT-compatible format
        return {"results": results}
    except Exception as e:
        # Surface the error to the API caller rather than crashing the app.
        return {"error": str(e)}


def gradio_interface(image):
    """Demo UI handler: run prediction and draw boxes on a copy of the image.

    Returns (annotated PIL image, pretty-printed JSON string of results).
    """
    # For the demo interface, we'll automatically pass the token
    results = predict(image, API_TOKEN)

    # Draw bounding boxes on image for visualization
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)
    for obj in results.get("results", []):
        box = obj["points"]
        draw.rectangle([box[0], box[1], box[2], box[3]], outline="red", width=3)
        draw.text((box[0], box[1]),
                  f"{obj['label']} {obj['confidence']:.2f}",
                  fill="red")
    return img_draw, json.dumps(results, indent=2)


# API endpoint for CVAT
app = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="filepath"),
        gr.Textbox(label="Access Token", type="password")
    ],
    outputs="json",
    title="Object Detection API for CVAT",
    description="Upload an image to get object detection predictions in CVAT-compatible format. Requires access token.",
    flagging_mode="never",
)

# UI for testing
demo = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detection Result"),
        gr.JSON(label="JSON Output")
    ],
    title="Object Detection Demo",
    description="Test your object detection model with this interface",
    flagging_mode="never",
)

# Combine both interfaces
combined_demo = gr.TabbedInterface(
    [app, demo],
    ["API Endpoint", "Testing Interface"]
)

if __name__ == "__main__":
    combined_demo.launch()