Draco15628 commited on
Commit
aa8c1b9
·
verified ·
1 Parent(s): 083a0c5

Update rasp.py

Browse files
Files changed (1) hide show
  1. rasp.py +76 -0
rasp.py CHANGED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import cv2
3
+ import numpy as np
4
+ import gradio as gr
5
+
6
+ # Load the pre-trained MobileNet SSD model
7
+ model = tf.saved_model.load("http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v2_fpnlite_320x320/saved_model")
8
+
9
+ # Define the label map for the MobileNet SSD model
10
+ category_index = {
11
+ 1: {'id': 1, 'name': 'person'},
12
+ 2: {'id': 2, 'name': 'bicycle'},
13
+ 3: {'id': 3, 'name': 'car'},
14
+ # Add more label mappings as needed
15
+ }
16
+
17
+ # Function to detect objects in the image
18
+ def detect_objects(image):
19
+ # Preprocess the image
20
+ input_tensor = tf.convert_to_tensor(image)
21
+ input_tensor = input_tensor[tf.newaxis,...]
22
+
23
+ # Run the model and get detections
24
+ detections = model(input_tensor)
25
+
26
+ # Process detections and draw bounding boxes
27
+ num_detections = int(detections.pop('num_detections'))
28
+ detections = {key: value[0, :num_detections].numpy() for key, value in detections.items()}
29
+ detection_classes = detections['detection_classes'].astype(np.int64)
30
+ detection_boxes = detections['detection_boxes']
31
+ detection_scores = detections['detection_scores']
32
+
33
+ # Draw boxes on the image
34
+ for i in range(num_detections):
35
+ if detection_scores[i] > 0.5: # Only consider confident detections
36
+ class_name = category_index.get(detection_classes[i], {'name': 'N/A'})['name']
37
+ box = detection_boxes[i]
38
+ height, width, _ = image.shape
39
+ ymin, xmin, ymax, xmax = box
40
+ (startX, startY, endX, endY) = (int(xmin * width), int(ymin * height), int(xmax * width), int(ymax * height))
41
+
42
+ # Draw bounding box and label
43
+ cv2.rectangle(image, (startX, startY), (endX, endY), (0, 255, 0), 2)
44
+ label = f"{class_name}: {detection_scores[i]:.2f}"
45
+ cv2.putText(image, label, (startX, startY - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
46
+
47
+ return image
48
+
49
+ # Function to handle the image from file upload or path
50
+ def gradio_interface(image):
51
+ if isinstance(image, str): # Check if it's a path string
52
+ image = cv2.imread(image)
53
+ else:
54
+ # Convert PIL image (Gradio) to OpenCV format (numpy array)
55
+ image = np.array(image)
56
+
57
+ # Convert to RGB format
58
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
59
+
60
+ # Detect objects in the image
61
+ detected_image = detect_objects(image_rgb)
62
+
63
+ # Convert back to BGR for display in Gradio (OpenCV uses BGR)
64
+ detected_image_bgr = cv2.cvtColor(detected_image, cv2.COLOR_RGB2BGR)
65
+
66
+ return detected_image_bgr
67
+
68
+ # Create Gradio app with image input (supports path or upload)
69
+ iface = gr.Interface(fn=gradio_interface,
70
+ inputs=gr.inputs.Image(type="filepath"), # Use "filepath" to allow local path or upload
71
+ outputs="image",
72
+ title="Object Detection with Bounding Boxes",
73
+ description="Upload an image or provide a file path to detect objects.")
74
+
75
+ # Launch the Gradio app (for local or Hugging Face)
76
+ iface.launch()