Spaces:
Sleeping
Sleeping
File size: 8,044 Bytes
39dce41 0c98179 39dce41 0c98179 39dce41 0c98179 39dce41 0c98179 39dce41 0c98179 39dce41 0c98179 39dce41 0c98179 39dce41 0c98179 14d5ca7 0c98179 14d5ca7 0c98179 43ab3a0 0c98179 43ab3a0 0c98179 14d5ca7 43ab3a0 0c98179 14d5ca7 0c98179 9440d11 3502260 9440d11 3502260 0c98179 f1c06ad 3502260 feb8679 1be6fac 0c98179 3502260 43ab3a0 403af64 3502260 403af64 3502260 403af64 3502260 9440d11 3502260 3e59cdd 3502260 3e59cdd 6d8b763 8c9d4c0 39dce41 0c98179 39dce41 0c98179 39dce41 0c98179 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 | import streamlit as st
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from tensorflow.keras.models import load_model
from PIL import Image
import io
# Set up Streamlit page (kept as the first Streamlit command the script runs).
st.set_page_config(
    page_title="Object Detection and Classification App",
    page_icon="🖼️",
    layout="wide",
)
# Load models
@st.cache_resource
def load_models():
    """Load the object detector and the pass/fail image classifier.

    Decorated with ``st.cache_resource`` so Streamlit reuses the loaded models
    across reruns instead of reading the weight files on every interaction.

    Returns:
        tuple: ``(object_detection_model, classification_model, device)``.
        The detector is a PyTorch module already moved to ``device`` and put in
        eval mode. The classifier is the Keras model stored in
        ``resnet50_6000_2.h5``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    detector_path = "fasterrcnn_resnet50_fpn_270824.pth"
    # The .pth file stores a whole pickled nn.Module (it is used directly as a
    # model below), not a state_dict. torch>=2.6 defaults to weights_only=True,
    # which refuses to unpickle arbitrary modules, so opt out explicitly.
    # SECURITY: unpickling can execute arbitrary code -- only ever load this
    # bundled, trusted file here, never a user-supplied one.
    try:
        object_detection_model = torch.load(
            detector_path, map_location=device, weights_only=False
        )
    except TypeError:
        # torch < 1.13 has no ``weights_only`` keyword.
        object_detection_model = torch.load(detector_path, map_location=device)
    object_detection_model.to(device)
    object_detection_model.eval()
    classification_model = load_model('resnet50_6000_2.h5')
    return object_detection_model, classification_model, device


object_detection_model, classification_model, device = load_models()
# Helper functions
def preprocess_image(image, target_size=(256, 256)):
    """Turn a PIL image into a normalised batch of one for the Keras classifier.

    Args:
        image: PIL image in any mode; non-RGB modes are converted to RGB.
        target_size: ``(width, height)`` passed to ``Image.resize``.

    Returns:
        float32 array of shape ``(1, height, width, 3)`` with values in [0, 1].
    """
    # Grayscale ('L'), RGBA or CMYK images would otherwise yield a 2-D or
    # 4-channel array. NOTE(review): assumes the ResNet50 classifier expects
    # 3-channel RGB input -- confirm against the training pipeline.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    resized = image.resize(target_size)
    pixels = np.asarray(resized).astype(np.float32) / 255.0
    return np.expand_dims(pixels, axis=0)
def classify_image(image):
    """Label *image* as 'fail' or 'pass' using the Keras classifier.

    The image is turned into a batch of one by ``preprocess_image``. The
    arg-max over the model output for that single sample indexes into
    ``['fail', 'pass']``.
    """
    batch = preprocess_image(image)
    outputs = classification_model.predict(batch)
    winner = np.argmax(outputs, axis=1)[0]
    return ['fail', 'pass'][winner]
def convert_png_to_jpg(image):
    """Re-encode a PNG as an in-memory JPEG; return any other image untouched.

    An image whose ``format`` is 'PNG' is converted to RGB, saved to a JPEG
    byte buffer and reopened from those bytes. Every other image object is
    returned as-is (same object).
    """
    if image.format != 'PNG':
        return image
    jpeg_buffer = io.BytesIO()
    image.convert('RGB').save(jpeg_buffer, format='JPEG')
    jpeg_bytes = jpeg_buffer.getvalue()
    return Image.open(io.BytesIO(jpeg_bytes))
def resize_to_square(image):
    """Center-crop *image* to a square whose side is its shorter dimension.

    Despite the name, no resampling happens: the longer axis is trimmed
    symmetrically (an odd leftover pixel goes to the bottom/right edge), and
    the result is a numpy view into *image*.

    Args:
        image: array indexed ``[row, col, ...]``; trailing axes are preserved.

    Returns:
        The square crop, of shape ``(side, side, ...)``.
    """
    h, w = image.shape[:2]
    side = min(h, w)
    top = (h - side) // 2
    left = (w - side) // 2
    return image[top:top + side, left:left + side]
def perform_object_detection(image, score_threshold=0.75, target_size=(256, 256)):
    """Detect flame-stone surfaces and crop each one from the full-size image.

    The image is resized to ``target_size`` for the detector. The function
    keeps every detection with ``score >= score_threshold`` and label 1 or 2.
    Each kept detection is:

    * drawn on the resized frame;
    * cropped at original resolution from ``image``;
    * center-cropped to a square by ``resize_to_square``.

    Args:
        image: PIL image to analyse. Non-RGB modes are converted to RGB.
        score_threshold: minimum detector score for a box to be kept.
        target_size: ``(width, height)`` passed to ``cv2.resize`` before inference.

    Returns:
        tuple: ``(annotated, crops)``. ``annotated`` is a PIL image of the
        resized frame with the kept boxes drawn on it. ``crops`` is a list of
        square numpy crops and may be empty.
    """
    # Grayscale/RGBA/CMYK uploads would give a 2-D or 4-channel array that
    # cv2.cvtColor(RGB2BGR) and the CHW transpose below cannot handle.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    original = np.array(image)  # computed once, reused for every crop
    original_h, original_w = original.shape[:2]
    target_w, target_h = target_size  # cv2 dsize order is (width, height)

    frame_resized = cv2.resize(original, dsize=target_size, interpolation=cv2.INTER_AREA)
    # NOTE(review): PIL yields RGB and this converts to BGR before inference;
    # presumably the detector was trained on cv2-loaded (BGR) images -- confirm.
    frame = cv2.cvtColor(frame_resized, cv2.COLOR_RGB2BGR).astype(np.float32) / 255.0
    batch = torch.from_numpy(frame.transpose(2, 0, 1)).float().unsqueeze(0).to(device)

    with torch.no_grad():
        detections = object_detection_model(batch)[0]
    boxes = detections['boxes'].cpu().numpy().astype(np.int32)
    labels = detections['labels'].cpu().numpy().astype(np.int32)
    scores = detections['scores'].cpu().numpy()

    # Factors mapping resized-frame coordinates back onto the original image.
    scale_w = original_w / target_w
    scale_h = original_h / target_h

    color = (0, 0, 255)
    label_text = 'Flame stone surface'
    result_image = frame_resized.copy()
    cropped_images = []  # one square crop per kept detection
    saw_other_label = False
    for box, label, score in zip(boxes, labels, scores):
        if score < score_threshold:
            continue
        if int(label) not in (1, 2):
            # Not a class of interest; reported once after the loop rather
            # than once per detection.
            saw_other_label = True
            continue
        x1, y1, x2, y2 = (int(v) for v in box)
        x1_orig, y1_orig = int(x1 * scale_w), int(y1 * scale_h)
        x2_orig, y2_orig = int(x2 * scale_w), int(y2 * scale_h)
        crop = original[y1_orig:y2_orig, x1_orig:x2_orig]
        if crop.size:  # zero-area boxes would yield an empty crop; skip them
            cropped_images.append(resize_to_square(crop))
        cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 3)
        # Keep the caption inside the frame for boxes touching the top edge.
        cv2.putText(result_image, label_text, (x1, max(y1 - 10, 10)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    if saw_other_label:
        # Vietnamese: "Flame stone surface not visible".
        st.info("Không nhìn thấy bề mặt đá đốt")
    return Image.fromarray(result_image), cropped_images
# Main app
def _show_classification(classification_result):
    """Show a 'pass' result as a success box and any other result as an error box."""
    message = f"Classification: {classification_result.upper()}"
    if classification_result == 'pass':
        st.success(message)
    else:
        st.error(message)


def _detection_tab():
    """Render tab 1: object detection, then classification of every crop."""
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"], key="file_uploader_1")
    col1, col2 = st.columns(2)
    if uploaded_file is None:
        return
    image = convert_png_to_jpg(Image.open(uploaded_file))
    col1.image(image, caption='Uploaded Image')
    with st.spinner('Processing...'):
        # Perform object detection and get cropped images
        detection_result, cropped_images = perform_object_detection(image)
    col2.image(detection_result, caption='Object Detection Result')
    if not cropped_images:
        st.warning("No object detected with a confidence of 0.75 or higher.")
        return
    st.subheader("Cropped Images and Classification Results")
    # Loop over every cropped image: one row per crop, with the crop on the
    # left and its classification result on the right.
    for idx, cropped_image in enumerate(cropped_images, start=1):
        cropped_image_pil = Image.fromarray(cropped_image)
        classification_result = classify_image(cropped_image_pil)
        img_col, result_col = st.columns([1, 2])
        with img_col:
            st.image(cropped_image_pil, caption=f'Cropped Image {idx}', use_column_width=True)
        with result_col:
            _show_classification(classification_result)


def _classification_tab():
    """Render tab 2: classify the whole uploaded image, without detection."""
    st.header('Image Classification')
    uploaded_file_2 = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"], key="file_uploader_2")
    if uploaded_file_2 is None:
        return
    image = convert_png_to_jpg(Image.open(uploaded_file_2))
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption='Uploaded Image', use_column_width=True)
    with col2:
        with st.spinner('Classifying...'):
            classification_result = classify_image(image)
        _show_classification(classification_result)


# Raw HTML/CSS for the page footer (rendered with unsafe_allow_html=True).
_FOOTER_HTML = """
<style>
.footer {
position: fixed;
left: 0;
bottom: 0;
width: 100%;
background-color: #0E1117;
color: #FAFAFA;
text-align: center;
padding: 10px;
font-size: 12px;
}
</style>
<div class="footer">
Developed by Tran Thanh Son | © 2024 Object Detection and Classification App
</div>
"""


def _render_sidebar_and_footer():
    """Render the 'About' sidebar and the fixed-position page footer."""
    st.sidebar.header("About")
    st.sidebar.info(
        "This app performs both object detection and image classification. "
        "Upload an image to see the results!"
    )
    st.markdown(_FOOTER_HTML, unsafe_allow_html=True)


def main():
    """Build the page: title, the two tabs, then the sidebar and footer."""
    st.title('🖼️ Object Detection and Classification App')
    st.write("Upload an image for object detection and classification.")
    tab1, tab2 = st.tabs(["🖼️ OB and BC", "BC"])
    with tab1:
        _detection_tab()
    with tab2:
        _classification_tab()
    # Sidebar and footer
    _render_sidebar_and_footer()
if __name__ == '__main__':
    # Build the UI when this file is executed as a script (e.g. `streamlit run`).
    main()
|