File size: 4,996 Bytes
c04146c
 
 
 
 
 
 
8ba1b51
 
 
c04146c
 
8ba1b51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c04146c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ba1b51
 
 
 
 
 
 
c04146c
8ba1b51
 
 
 
 
c04146c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ba1b51
 
c04146c
 
 
 
 
 
 
 
 
 
 
 
8ba1b51
c04146c
 
8ba1b51
 
 
 
c04146c
 
8ba1b51
c04146c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import gradio as gr
import torch
from PIL import Image, ImageDraw
import numpy as np
import json
import base64
import io
import os
import secrets
from dotenv import load_dotenv
from megadetector.detection import run_detector

# Load environment variables for configuration
# (reads a local .env file, if present, into os.environ)
load_dotenv()

# Access token configuration
# You can set a fixed token in your Space's environment variables
# or generate a random one on startup (less secure)
API_TOKEN = os.getenv("API_TOKEN")
if not API_TOKEN:
    # Generate a random token if not provided - will change on restart!
    # NOTE(review): printing the token exposes it in the Space's logs —
    # consider removing these prints once a permanent token is configured.
    API_TOKEN = secrets.token_hex(16)
    print(f"Generated API token: {API_TOKEN}")
    print("IMPORTANT: This token will change if the space restarts!")
    print("Set a permanent token in the Space's environment variables.")

def validate_token(token):
    """Check a client-supplied access token against the configured API token.

    Uses ``secrets.compare_digest`` for a constant-time comparison so the
    check does not leak information about the token via timing differences
    (a plain ``==`` short-circuits on the first mismatching character).

    Args:
        token: Token string provided by the caller; may be None or any type.

    Returns:
        True if the token matches API_TOKEN, False otherwise.
    """
    # compare_digest requires both arguments to be str (or both bytes);
    # reject non-string input here instead of letting it raise TypeError.
    if not isinstance(token, str):
        return False
    return secrets.compare_digest(token, API_TOKEN)

# Load the MegaDetector v5a detection model once at startup
# (presumably downloads/caches the weights on first call — verify on deploy).
model = run_detector.load_detector('MDV5A')

# CVAT categories - customize based on your model's classes
# MegaDetector reports 1-based category ids; the ids below must stay
# aligned with the detector's own category mapping.
CATEGORIES = [
    {"id": 1, "name": "animal"},
    {"id": 2, "name": "person"},
    {"id": 3, "name": "vehicle"},
    # Add all categories your model supports
]

def process_predictions(outputs, image, confidence_threshold=0.5):
    """Convert MegaDetector output into CVAT-compatible detection records.

    Args:
        outputs: Detector result dict with a 'detections' list; each entry
            carries 'bbox' as normalized [x, y, w, h], a 'conf' score, and
            'category' (a 1-based category id, typically a string).
        image: PIL image the detections refer to; only ``image.size`` is
            used, to scale normalized coordinates to pixels.
        confidence_threshold: Detections scoring below this are dropped.

    Returns:
        List of dicts with keys 'confidence', 'label',
        'points' ([x1, y1, x2, y2] in pixels), and 'type' ("rectangle").
    """
    results = []
    iw, ih = image.size

    for det in outputs['detections']:
        score = det['conf']
        # Filter by confidence before doing any coordinate math.
        if score < confidence_threshold:
            continue

        # Convert normalized [x, y, w, h] to pixel [x1, y1, x2, y2].
        x, y, w, h = det['bbox']
        bbox = [x * iw, y * ih, (x + w) * iw, (y + h) * ih]

        # MegaDetector category ids are 1-based; CATEGORIES is 0-indexed.
        category_id = int(det['category']) - 1
        if 0 <= category_id < len(CATEGORIES):
            category_name = CATEGORIES[category_id]["name"]
        else:
            # Unknown category id: fall back to the raw id rather than
            # raising IndexError and failing the entire request.
            category_name = str(det['category'])

        results.append({
            "confidence": float(score),
            "label": category_name,
            "points": bbox,
            "type": "rectangle",
        })

    return results

def _coerce_to_image(image_data):
    """Normalize the various accepted input formats to a PIL Image."""
    if isinstance(image_data, Image.Image):
        return image_data
    if isinstance(image_data, str) and image_data.startswith("data:image"):
        # Strip the "data:image/...;base64," prefix, then decode the payload.
        encoded = image_data.split(",")[1]
        return Image.open(io.BytesIO(base64.b64decode(encoded)))
    if isinstance(image_data, np.ndarray):
        return Image.fromarray(image_data)
    # Anything else (a file path or file-like object) goes straight to PIL.
    return Image.open(image_data)

def predict(image_data, token=None):
    """Run object detection on an image; main function for the API endpoint.

    Args:
        image_data: Image to process — a PIL Image, a base64 data-URL
            string, a numpy array, or a file path / file-like object.
        token: Access token for authentication.

    Returns:
        ``{"results": [...]}`` in CVAT-compatible format on success, or
        ``{"error": "..."}`` on any failure.
    """
    try:
        # Reject the request up front when the token is absent or wrong.
        if token is None or not validate_token(token):
            return {"error": "Authentication failed. Invalid or missing access token."}

        image = _coerce_to_image(image_data)

        # Run the detector and reshape its output for CVAT.
        outputs = model.generate_detections_one_image(image)
        return {"results": process_predictions(outputs, image)}

    except Exception as exc:
        # Surface any failure as a JSON error payload instead of crashing.
        return {"error": str(exc)}

# Create Gradio interface for testing
def gradio_interface(image):
    """Demo wrapper: detect objects in a PIL image and visualize the boxes.

    Authenticates automatically with the server-side token, then returns
    the annotated image together with the raw JSON payload as a string.
    """
    response = predict(image, API_TOKEN)

    # Render each detection as a labeled red rectangle on a copy of the image.
    annotated = image.copy()
    canvas = ImageDraw.Draw(annotated)
    for obj in response.get("results", []):
        x1, y1, x2, y2 = obj["points"]
        canvas.rectangle([x1, y1, x2, y2], outline="red", width=3)
        canvas.text((x1, y1), f"{obj['label']} {obj['confidence']:.2f}", fill="red")

    return annotated, json.dumps(response, indent=2)

# API endpoint for CVAT: accepts an image plus the access token and
# returns the raw JSON predictions.
app = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="filepath"),
        gr.Textbox(label="Access Token", type="password")
    ],
    outputs="json",
    title="Object Detection API for CVAT",
    # Plain string: the original used an f-string with no placeholders (F541).
    description="Upload an image to get object detection predictions in CVAT-compatible format. Requires access token.",
    flagging_mode="never",
)

# UI for testing: same model, but draws the boxes on the image and
# authenticates itself with the server-side token.
demo = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detection Result"),
        gr.JSON(label="JSON Output")
    ],
    title="Object Detection Demo",
    description="Test your object detection model with this interface",
    flagging_mode="never",
)

# Expose both interfaces as tabs of a single combined app.
combined_demo = gr.TabbedInterface(
    [app, demo],
    ["API Endpoint", "Testing Interface"]
)

if __name__ == "__main__":
    combined_demo.launch()