Spaces:
Sleeping
Sleeping
added object detection to the space UI
Browse files
app.py
CHANGED
|
@@ -6,7 +6,7 @@ import scipy
|
|
| 6 |
from PIL import Image
|
| 7 |
import torch.nn as nn
|
| 8 |
from transformers import Blip2Processor, Blip2ForConditionalGeneration, InstructBlipProcessor, InstructBlipForConditionalGeneration
|
| 9 |
-
from
|
| 10 |
|
| 11 |
def load_caption_model(blip2=False, instructblip=True):
|
| 12 |
|
|
@@ -65,3 +65,56 @@ if st.button("Get Answer"):
|
|
| 65 |
st.write(answer)
|
| 66 |
else:
|
| 67 |
st.write("Please upload an image and enter a question.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from PIL import Image
|
| 7 |
import torch.nn as nn
|
| 8 |
from transformers import Blip2Processor, Blip2ForConditionalGeneration, InstructBlipProcessor, InstructBlipForConditionalGeneration
|
| 9 |
+
from my_model.object_detection import ObjectDetector
|
| 10 |
|
| 11 |
def load_caption_model(blip2=False, instructblip=True):
|
| 12 |
|
|
|
|
| 65 |
st.write(answer)
|
| 66 |
else:
|
| 67 |
st.write("Please upload an image and enter a question.")
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# Object Detection
|
| 75 |
+
|
| 76 |
+
# Object Detection UI in the sidebar
|
| 77 |
+
st.sidebar.title("Object Detection")
|
| 78 |
+
# Dropdown to select the model
|
| 79 |
+
detect_model = st.sidebar.selectbox("Choose a model for object detection:", ["detic", "yolov5"])
|
| 80 |
+
# Slider for threshold with default values based on the model
|
| 81 |
+
threshold = st.sidebar.slider("Select Detection Threshold", 0.1, 0.9, 0.2 if detect_model == "yolov5" else 0.4)
|
| 82 |
+
# Button to trigger object detection
|
| 83 |
+
detect_button = st.sidebar.button("Detect Objects")
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def perform_object_detection(image, model_name, threshold):
|
| 87 |
+
"""
|
| 88 |
+
Perform object detection on the given image using the specified model and threshold.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
image (PIL.Image): The image on which to perform object detection.
|
| 92 |
+
model_name (str): The name of the object detection model to use.
|
| 93 |
+
threshold (float): The threshold for object detection.
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
PIL.Image, str: The image with drawn bounding boxes and a string of detected objects.
|
| 97 |
+
"""
|
| 98 |
+
# Initialize the ObjectDetector
|
| 99 |
+
detector = ObjectDetector()
|
| 100 |
+
# Load the specified model
|
| 101 |
+
detector.load_model(model_name)
|
| 102 |
+
# Perform object detection
|
| 103 |
+
processed_image, detected_objects = detector.detect_objects(image, threshold)
|
| 104 |
+
return processed_image, detected_objects
|
| 105 |
+
|
| 106 |
+
# Check if the 'Detect Objects' button was clicked
|
| 107 |
+
if detect_button:
|
| 108 |
+
if image is not None:
|
| 109 |
+
# Open the uploaded image
|
| 110 |
+
image = Image.open(image)
|
| 111 |
+
# Display the original image
|
| 112 |
+
st.image(image, use_column_width=True, caption="Original Image")
|
| 113 |
+
# Perform object detection
|
| 114 |
+
processed_image, detected_objects = perform_object_detection(image, detect_model, threshold)
|
| 115 |
+
# Display the image with detected objects
|
| 116 |
+
st.image(processed_image, use_column_width=True, caption="Image with Detected Objects")
|
| 117 |
+
# Display the detected objects
|
| 118 |
+
st.write(detected_objects)
|
| 119 |
+
else:
|
| 120 |
+
st.write("Please upload an image for object detection.")
|