Logistikon committed on
Commit
c04146c
·
1 Parent(s): 5c09794

Init space

Browse files
Files changed (2) hide show
  1. app.py +126 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image, ImageDraw
4
+ import numpy as np
5
+ import json
6
+ import base64
7
+ import io
8
+ from megadetector.detection import run_detector
9
+
10
# Load the detector once at import time so every request reuses the same
# instance instead of paying the model start-up cost per call.
model = run_detector.load_detector("MDV5A")
11
+
12
# Labels reported to CVAT. Ids are 1-based and follow the order of the names
# below; extend the tuple with every class your model supports.
_CATEGORY_NAMES = ("animal", "person", "vehicle")
CATEGORIES = [
    {"id": position, "name": label}
    for position, label in enumerate(_CATEGORY_NAMES, start=1)
]
19
+
20
def process_predictions(outputs, image, confidence_threshold=0.5):
    """Convert raw detector output into CVAT-style rectangle annotations.

    Args:
        outputs: Mapping returned by the detector. Its ``"detections"`` entry
            is expected to be a list of dicts with ``"bbox"`` (``[x, y, w, h]``,
            multiplied here by the image size, so presumably normalised to
            0..1 -- confirm against the detector docs), ``"conf"`` and
            ``"category"`` (a value ``int()`` can parse). A missing or ``None``
            ``"detections"`` entry is treated as "no detections".
        image: Object exposing ``size`` as ``(width, height)``, e.g. a PIL image.
        confidence_threshold: Detections scoring below this value are dropped.

    Returns:
        A list of dicts with keys ``"confidence"``, ``"label"``, ``"points"``
        (``[x1, y1, x2, y2]`` in pixels) and ``"type"`` (always
        ``"rectangle"``).

    Raises:
        ValueError: If a kept detection carries a category id that has no
            entry in ``CATEGORIES``.
    """
    width, height = image.size
    results = []

    for det in outputs.get("detections") or []:
        score = float(det["conf"])
        # Filter first so rejected detections cost nothing further.
        if score < confidence_threshold:
            continue

        # Resolve the label by id rather than by list position: positional
        # indexing would let id 0 wrap around to the last category.
        category_id = int(det["category"])
        label = next(
            (entry["name"] for entry in CATEGORIES if entry["id"] == category_id),
            None,
        )
        if label is None:
            raise ValueError(f"Unknown detection category id: {category_id}")

        # [x, y, w, h] -> [x1, y1, x2, y2], scaled to pixel coordinates.
        x, y, w, h = det["bbox"]
        points = [
            float(x * width),
            float(y * height),
            float((x + w) * width),
            float((y + h) * height),
        ]

        results.append(
            {
                "confidence": score,
                "label": label,
                "points": points,
                "type": "rectangle",
            }
        )

    return results
48
+
49
def predict(image_data):
    """Run the detector on one image and return CVAT-compatible results.

    Args:
        image_data: One of: a PIL image, a ``data:image...`` base64 data URI
            string, a NumPy array accepted by ``Image.fromarray``, or anything
            ``Image.open`` accepts (e.g. a file path).

    Returns:
        ``{"results": [...]}`` on success, where the list comes from
        ``process_predictions``; ``{"error": "<message>"}`` if no image was
        supplied or if decoding/inference raised.
    """
    # An empty submission arrives as None; report it clearly instead of
    # letting Image.open fail with an unrelated-looking message.
    if image_data is None:
        return {"error": "No image provided"}

    try:
        if isinstance(image_data, Image.Image):
            image = image_data
        elif isinstance(image_data, str) and image_data.startswith("data:image"):
            # Strip the "data:image/...;base64," header and decode the payload.
            encoded = image_data.split(",", 1)[1]
            image = Image.open(io.BytesIO(base64.b64decode(encoded)))
        elif isinstance(image_data, np.ndarray):
            image = Image.fromarray(image_data)
        else:
            # convert() yields a fully loaded copy, so the underlying file can
            # be closed as soon as the with-block exits.
            with Image.open(image_data) as opened:
                image = opened.convert("RGB")

        # Normalise RGBA / palette / greyscale inputs to three channels.
        # NOTE(review): assumes the detector expects RGB input -- confirm
        # against the MegaDetector documentation.
        if image.mode != "RGB":
            image = image.convert("RGB")

        outputs = model.generate_detections_one_image(image)
        results = process_predictions(outputs, image)

        return {"results": results}

    except Exception as e:
        # Top-level API boundary: always answer with JSON rather than a trace.
        return {"error": str(e)}
76
+
77
# Gradio callback for the interactive testing tab
def gradio_interface(image):
    """Run detection on a PIL image and return a preview plus the JSON payload.

    Args:
        image: PIL image from the Gradio input component, or ``None`` when the
            form is submitted without an image.

    Returns:
        A ``(annotated_image, json_text)`` tuple. ``annotated_image`` is a copy
        of the input with one red box and caption per detection (``None`` when
        no image was supplied); ``json_text`` is the ``predict`` response
        serialised with two-space indentation.
    """
    # Submitting an empty form must not crash on image.copy().
    if image is None:
        return None, json.dumps({"error": "No image provided"}, indent=2)

    results = predict(image)

    # Draw on a copy so the caller's image is left untouched.
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)

    # An error response has no "results" key, so nothing is drawn in that case.
    for obj in results.get("results", []):
        x1, y1, x2, y2 = obj["points"]
        draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
        draw.text((x1, y1), f"{obj['label']} {obj['confidence']:.2f}", fill="red")

    return img_draw, json.dumps(results, indent=2)
91
+
92
# The Space serves two tabs backed by the same detector:
#   * "API Endpoint"      - takes an image file path and returns the raw JSON (for CVAT)
#   * "Testing Interface" - takes a PIL image and shows the annotated preview plus JSON

# JSON-only endpoint intended for CVAT
api_image_input = gr.Image(type="filepath")
app = gr.Interface(
    fn=predict,
    inputs=api_image_input,
    outputs="json",
    title="Object Detection API for CVAT",
    description="Upload an image to get object detection predictions in CVAT-compatible format",
    flagging_mode="never",
)

# Visual tab for manual checks
demo_image_input = gr.Image(type="pil")
demo_outputs = [
    gr.Image(type="pil", label="Detection Result"),
    gr.JSON(label="JSON Output"),
]
demo = gr.Interface(
    fn=gradio_interface,
    inputs=demo_image_input,
    outputs=demo_outputs,
    title="Object Detection Demo",
    description="Test your object detection model with this interface",
    flagging_mode="never",
)

# Expose both interfaces side by side as tabs
tab_interfaces = [app, demo]
tab_titles = ["API Endpoint", "Testing Interface"]
combined_demo = gr.TabbedInterface(tab_interfaces, tab_titles)

# Start the server only when executed as a script
if __name__ == "__main__":
    combined_demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ megadetector
2
+ pillow
3
+ gradio
4
+ numpy