stevafernandes committed on
Commit
d46c0bc
·
verified ·
1 Parent(s): 337d8c9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -0
app.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image, ImageDraw, ImageFont
4
+ import keras_cv
5
+ import keras
6
+
7
# Names of the 80 COCO object categories; a label is looked up by its list index.
# NOTE(review): the detector preset loaded in this file is "yolo_v8_m_pascalvoc";
# confirm that its predicted class ids really follow this COCO ordering.
COCO_CLASSES = [
    # people
    "person",
    # vehicles
    "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
    # street objects
    "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
    # animals
    "bird", "cat", "dog", "horse", "sheep",
    "cow", "elephant", "bear", "zebra", "giraffe",
    # accessories
    "backpack", "umbrella", "handbag", "tie", "suitcase",
    # sports equipment
    "frisbee", "skis", "snowboard", "sports ball", "kite",
    "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
    # kitchenware
    "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
    # food
    "banana", "apple", "sandwich", "orange", "broccoli",
    "carrot", "hot dog", "pizza", "donut", "cake",
    # furniture
    "chair", "couch", "potted plant", "bed", "dining table", "toilet",
    # electronics
    "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
    # appliances
    "microwave", "oven", "toaster", "sink", "refrigerator",
    # household items
    "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]
23
+
24
# Hex colours used for box outlines and label backgrounds; the drawing code
# picks one per class id and wraps around when ids exceed the palette size.
COLORS = [
    "#FF6B6B", "#4ECDC4", "#45B7D1",
    "#96CEB4", "#FFEAA7", "#DDA0DD",
    "#98D8C8", "#F7DC6F", "#BB8FCE",
    "#85C1E9", "#F8C471", "#82E0AA",
    "#F1948A", "#AED6F1", "#D7BDE2",
]
30
+
31
+
32
def load_model():
    """Build the KerasCV YOLOv8 detector from its pretrained preset.

    Returns:
        The ``YOLOV8Detector`` created by ``from_preset`` for the
        ``"yolo_v8_m_pascalvoc"`` preset, configured to report boxes in
        ``"xyxy"`` format.
    """
    detector_cls = keras_cv.models.YOLOV8Detector
    return detector_cls.from_preset(
        "yolo_v8_m_pascalvoc",
        bounding_box_format="xyxy",
    )
39
+
40
+
41
# Build the detector once, when the module is first executed, and keep it in a
# module-level name so every detection call reuses the same instance.
print("Loading model...")
model = load_model()
print("Model loaded!")
44
+
45
+
46
# Class names for the Pascal VOC checkpoint ("yolo_v8_m_pascalvoc") that
# load_model() loads, indexed by the integer class id the detector predicts.
# The COCO list must not be used with this checkpoint: its ids differ
# (e.g. id 0 is "aeroplane" here but "person" in COCO).
PASCAL_VOC_CLASSES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "dining table", "dog", "horse", "motorbike", "person",
    "potted plant", "sheep", "sofa", "train", "tv monitor",
]


def _to_numpy(value):
    """Return *value* as a NumPy array, unwrapping tensors that expose ``.numpy()``."""
    return value.numpy() if hasattr(value, "numpy") else np.asarray(value)


def _class_label(cls_id):
    """Return the Pascal VOC name for *cls_id*, or ``class_<id>`` if out of range."""
    if 0 <= cls_id < len(PASCAL_VOC_CLASSES):
        return PASCAL_VOC_CLASSES[cls_id]
    return f"class_{cls_id}"


def _load_label_font():
    """Return the bold DejaVu label font, or Pillow's default if it is missing."""
    try:
        return ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16
        )
    except OSError:
        return ImageFont.load_default()


def detect_objects(image, confidence_threshold=0.5):
    """Run object detection on a single image and draw the results on it.

    Args:
        image: Image as a NumPy array (as delivered by the Gradio image
            input), or ``None`` when nothing was uploaded.
        confidence_threshold: Detections scoring below this value are skipped.

    Returns:
        A ``(annotated_image, status)`` pair: the PIL image with boxes and
        labels drawn on it, and a short status string. When ``image`` is
        ``None`` the pair is ``(None, "No image provided")`` so that the
        number of returned values always matches the two UI outputs.
    """
    if image is None:
        return None, "No image provided"

    # Normalise to 3 channels so grayscale/RGBA uploads still fit the model input.
    orig_image = Image.fromarray(image).convert("RGB")
    orig_w, orig_h = orig_image.size

    # Resize for model input (aspect ratio is not preserved; boxes are
    # rescaled per axis further down).
    input_size = 640
    resized = orig_image.resize((input_size, input_size))
    input_batch = np.expand_dims(np.array(resized, dtype="float32"), axis=0)

    # Run prediction and take the results for the single image in the batch.
    predictions = model.predict(input_batch)
    boxes = _to_numpy(predictions["boxes"][0])
    classes = _to_numpy(predictions["classes"][0])
    confidence = _to_numpy(predictions["confidence"][0])

    # Draw results on the original-resolution image.
    draw = ImageDraw.Draw(orig_image)
    font = _load_label_font()

    # Per-axis factors that map 640x640 model coordinates back to the original.
    scale_x = orig_w / input_size
    scale_y = orig_h / input_size

    detections_found = 0
    for box, raw_cls, raw_score in zip(boxes, classes, confidence):
        score = float(raw_score)
        if score < confidence_threshold:
            continue

        cls_id = int(raw_cls)
        label = _class_label(cls_id)

        x1 = float(box[0]) * scale_x
        y1 = float(box[1]) * scale_y
        x2 = float(box[2]) * scale_x
        y2 = float(box[3]) * scale_y

        color = COLORS[cls_id % len(COLORS)]

        # Draw bounding box
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

        # Draw label background + text
        text = f"{label} {score:.0%}"
        left, top, right, bottom = draw.textbbox((x1, y1), text, font=font)
        text_w = right - left
        text_h = bottom - top

        # Prefer the strip above the box; if that would fall off the top of
        # the image, place it just inside the box so the label stays visible.
        label_top = y1 - text_h - 6
        if label_top < 0:
            label_top = max(y1, 0.0)
        draw.rectangle(
            [x1, label_top, x1 + text_w + 8, label_top + text_h + 6],
            fill=color,
        )
        draw.text((x1 + 4, label_top + 2), text, fill="white", font=font)

        detections_found += 1

    status = (
        f"Found {detections_found} object(s)"
        if detections_found
        else "No objects detected"
    )
    return orig_image, status
119
+
120
+
121
# Gradio UI: inputs (image, threshold, button) in the left column, results
# (annotated image, status line) in the right column.
with gr.Blocks(title="Keras Object Detection") as demo:
    gr.Markdown("# Object Detection with KerasCV YOLOv8")
    gr.Markdown("Upload an image to detect objects using a pretrained YOLOv8 model.")

    with gr.Row():
        # Left column: what the user provides.
        with gr.Column():
            input_image = gr.Image(label="Upload Image", type="numpy")
            threshold = gr.Slider(
                label="Confidence Threshold",
                minimum=0.1,
                maximum=0.95,
                step=0.05,
                value=0.5,
            )
            run_btn = gr.Button(value="Detect Objects", variant="primary")

        # Right column: what the detector produces.
        with gr.Column():
            output_image = gr.Image(label="Detections")
            status_text = gr.Textbox(label="Status", interactive=False)

    # Clicking the button sends (image, threshold) to detect_objects and
    # routes its two return values to the image and status components.
    run_btn.click(
        detect_objects,
        [input_image, threshold],
        [output_image, status_text],
    )

# Start the web server for the app.
demo.launch()