Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| from PIL import Image, ImageDraw | |
| import numpy as np | |
| import json | |
| import base64 | |
| import io | |
| import os | |
| import secrets | |
| from dotenv import load_dotenv | |
| from megadetector.detection import run_detector | |
# Load environment variables for configuration (python-dotenv).
load_dotenv()

# Access token configuration.
# Preferred: define API_TOKEN in the Space's environment variables so the
# value stays the same across restarts.
# Fallback: a random token is generated at startup (less secure).
API_TOKEN = os.environ.get("API_TOKEN")
if API_TOKEN is None or API_TOKEN == "":
    # Nothing configured: mint a throwaway token and announce it on stdout.
    # It will be different after every restart of the Space.
    API_TOKEN = secrets.token_hex(16)
    print(
        f"Generated API token: {API_TOKEN}",
        "IMPORTANT: This token will change if the space restarts!",
        "Set a permanent token in the Space's environment variables.",
        sep="\n",
    )
def validate_token(token):
    """Return True when *token* matches the configured API token.

    The comparison goes through ``secrets.compare_digest`` so that the time it
    takes does not depend on how many leading characters of a guess are right.

    Args:
        token: Candidate access token supplied by the caller.

    Returns:
        bool: True only if ``token`` is a string equal to ``API_TOKEN``.
    """
    if not isinstance(token, str):
        # compare_digest raises TypeError for non-string input, whereas a
        # plain ``==`` just evaluates to False; keep the "False" contract.
        return False
    # Compare as bytes: compare_digest rejects non-ASCII ``str`` arguments.
    return secrets.compare_digest(
        token.encode("utf-8"), API_TOKEN.encode("utf-8")
    )
# Detector instance shared by every request; it is created once, when this
# module is imported, from the 'MDV5A' model identifier.
model = run_detector.load_detector("MDV5A")
# CVAT categories - customize based on your model's classes.
# Ids are 1-based and follow the order of the names in the tuple below;
# append every additional category your model supports.
CATEGORIES = [
    {"id": category_id, "name": category_name}
    for category_id, category_name in enumerate(
        ("animal", "person", "vehicle"), start=1
    )
]
def process_predictions(outputs, image, confidence_threshold=0.5, categories=None):
    """Convert raw detector output into CVAT-style rectangle annotations.

    Args:
        outputs: Mapping returned by the detector. Its ``'detections'`` entry
            is read as an iterable of dicts with ``'bbox'`` (``[x, y, w, h]``
            as fractions of the image size), ``'conf'`` and ``'category'``.
            A missing or ``None`` entry is treated as "no detections".
        image: Object exposing ``size`` as ``(width, height)`` in pixels.
        confidence_threshold: Detections scoring below this are dropped.
        categories: Optional list of ``{"id": ..., "name": ...}`` dicts used
            to resolve labels; defaults to the module-level ``CATEGORIES``.

    Returns:
        list[dict]: One dict per kept detection with keys ``confidence``,
        ``label``, ``points`` (``[x1, y1, x2, y2]`` in pixels) and ``type``.
    """
    if categories is None:
        categories = CATEGORIES
    # Resolve labels through the declared id rather than the list position:
    # positional lookup maps category 0 to the *last* entry (index -1) and
    # raises IndexError for ids beyond the list.
    names_by_id = {int(entry["id"]): entry["name"] for entry in categories}

    width, height = image.size
    results = []
    # NOTE(review): the detector appears to omit 'detections' when inference
    # fails; an absent/None value is handled as an empty list here.
    for det in outputs.get("detections") or []:
        score = float(det["conf"])
        if score < confidence_threshold:
            # Filter first so no box arithmetic is spent on dropped detections.
            continue

        label = names_by_id.get(int(det["category"]))
        if label is None:
            # Category not declared in `categories`: nothing to map it to.
            continue

        # Convert from relative [x, y, w, h] to pixel [x1, y1, x2, y2].
        x, y, w, h = det["bbox"]
        results.append(
            {
                "confidence": score,
                "label": label,
                "points": [
                    float(x * width),
                    float(y * height),
                    float((x + w) * width),
                    float((y + h) * height),
                ],
                "type": "rectangle",
            }
        )
    return results
def _load_image(image_data):
    """Decode any supported image input into an RGB ``PIL.Image.Image``."""
    if isinstance(image_data, Image.Image):
        image = image_data
    elif isinstance(image_data, np.ndarray):
        image = Image.fromarray(image_data)
    elif isinstance(image_data, str) and image_data.startswith("data:image"):
        # Data URI: everything after the first comma is the base64 payload.
        _, _, payload = image_data.partition(",")
        image = Image.open(io.BytesIO(base64.b64decode(payload)))
    else:
        # File path or file-like object; the context manager releases the
        # underlying file handle once the pixels have been copied.
        with Image.open(image_data) as opened:
            return opened.convert("RGB")
    # NOTE(review): assumes the detector expects 3-channel input; RGBA,
    # palette and grayscale images are normalised here - confirm against
    # the detector's documentation.
    return image.convert("RGB")


def predict(image_data, token=None):
    """Main prediction function for the API endpoint.

    Args:
        image_data: The image to be processed. Accepted forms are a PIL
            image, a NumPy array, a ``data:image...`` base64 URI, or anything
            ``PIL.Image.open`` accepts (e.g. a file path).
        token: Access token for authentication.

    Returns:
        dict: ``{"results": [...]}`` on success, or ``{"error": message}``
        when authentication, decoding or inference fails. This function does
        not raise; failures are reported through the ``"error"`` key.
    """
    try:
        # Validate the access token before doing any expensive work.
        if token is None or not validate_token(token):
            return {"error": "Authentication failed. Invalid or missing access token."}

        if image_data is None:
            # Submitting without an image would otherwise surface as an
            # obscure error from PIL.
            return {"error": "No image provided."}

        image = _load_image(image_data)
        # Run the detector, then reshape its output into CVAT's format.
        outputs = model.generate_detections_one_image(image)
        return {"results": process_predictions(outputs, image)}
    except Exception as e:
        # API boundary: report the failure to the client instead of raising.
        return {"error": str(e)}
# Create Gradio interface for testing
def gradio_interface(image):
    """Run detection on *image* and return an annotated copy plus JSON text.

    Args:
        image: PIL image supplied by the demo tab, or ``None`` when the form
            is submitted without an image.

    Returns:
        tuple: ``(annotated_image, json_text)``. ``annotated_image`` is a copy
        of the input with one red rectangle and caption per detection (or
        ``None`` when no image was given); ``json_text`` is the prediction
        result serialised with two-space indentation.
    """
    if image is None:
        # Without this guard `image.copy()` below raises AttributeError.
        return None, json.dumps({"error": "No image provided."}, indent=2)

    # For the demo interface, the token is passed automatically.
    results = predict(image, API_TOKEN)

    # Draw on a copy so the caller's image is left untouched.
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    # An error response carries no "results" key, so nothing is drawn then.
    for obj in results.get("results", []):
        x1, y1, x2, y2 = obj["points"]
        draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
        draw.text((x1, y1), f"{obj['label']} {obj['confidence']:.2f}", fill="red")
    return annotated, json.dumps(results, indent=2)
# API endpoint for CVAT: takes an image file path plus the access token and
# returns the `predict` result as JSON.
app = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="filepath"),
        gr.Textbox(label="Access Token", type="password"),
    ],
    outputs="json",
    title="Object Detection API for CVAT",
    # Plain string literal: the text has no placeholders, so no f-prefix.
    description="Upload an image to get object detection predictions in CVAT-compatible format. Requires access token.",
    flagging_mode="never",
)
# UI for testing: shows the annotated image next to the raw JSON. The token
# is supplied by `gradio_interface`, so this tab has no token field.
demo = gr.Interface(
    title="Object Detection Demo",
    description="Test your object detection model with this interface",
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detection Result"),
        gr.JSON(label="JSON Output"),
    ],
    flagging_mode="never",
)
# Combine both interfaces into one app, one tab per interface.
combined_demo = gr.TabbedInterface(
    interface_list=[app, demo],
    tab_names=["API Endpoint", "Testing Interface"],
)

# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    combined_demo.launch()