Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torchvision.transforms as transforms | |
| from tensorflow.keras.models import load_model | |
| from PIL import Image | |
| import io | |
# Set up Streamlit page
# Must run before any other Streamlit call; sets the browser tab
# title/icon and switches to the full-width layout.
st.set_page_config(page_title="Object Detection and Classification App", page_icon="🖼️", layout="wide")
# Load models
@st.cache_resource
def load_models():
    """Load the detection and classification models once per server process.

    ``st.cache_resource`` keeps the loaded models across Streamlit reruns;
    without it both models were re-read from disk on every user interaction.

    Returns:
        tuple: (object_detection_model, classification_model, device) —
        a PyTorch detection model in eval mode on *device*, and a Keras
        classification model.
    """
    # Prefer GPU when available; fall back to CPU.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # NOTE(review): torch.load here unpickles a full model object (not a
    # state_dict) — the checkpoint file must come from a trusted source.
    object_detection_model = torch.load("fasterrcnn_resnet50_fpn_270824.pth", map_location=device)
    object_detection_model.to(device)
    object_detection_model.eval()  # inference mode: freeze dropout/batch-norm
    classification_model = load_model('resnet50_6000_2.h5')
    return object_detection_model, classification_model, device

object_detection_model, classification_model, device = load_models()
| # Helper functions | |
| def preprocess_image(image, target_size=(256, 256)): | |
| img = image.resize(target_size) | |
| img_array = np.array(img).astype('float32') / 255.0 | |
| img_array = np.expand_dims(img_array, axis=0) | |
| return img_array | |
def classify_image(image):
    """Classify *image* with the binary surface-quality model.

    Args:
        image: PIL-style image accepted by ``preprocess_image``.

    Returns:
        str: ``'fail'`` or ``'pass'``.
    """
    class_labels = ['fail', 'pass']
    batch = preprocess_image(image)
    probabilities = classification_model.predict(batch)
    winner = np.argmax(probabilities, axis=1)[0]  # index of the top class
    return class_labels[winner]
def convert_png_to_jpg(image):
    """Re-encode a PNG upload as JPEG; any other format passes through.

    The RGB conversion drops a PNG alpha channel so downstream models
    always receive a 3-channel image.
    """
    if image.format != 'PNG':
        return image
    buffer = io.BytesIO()
    image.convert('RGB').save(buffer, format='JPEG')
    buffer.seek(0)
    return Image.open(buffer)
def resize_to_square(image):
    """Center-crop *image* to a square whose side is ``min(height, width)``.

    Despite the name, this function crops — it never rescales — so the
    pixel resolution of the kept region is preserved.

    Args:
        image: np.ndarray of shape (H, W, ...) with at least 2 dimensions.

    Returns:
        np.ndarray: the centered square crop (a view of the input).
    """
    h, w = image.shape[:2]
    if h > w:  # portrait: trim rows equally from top and bottom
        start = (h - w) // 2
        return image[start:start + w, :]
    # landscape or already square: trim columns equally from left and right
    start = (w - h) // 2
    return image[:, start:start + h]
def perform_object_detection(image):
    """Run the Faster R-CNN detector on *image* and collect square crops.

    The image is resized to 256x256 for inference. Detections with score
    >= 0.75 and an accepted label are drawn on the resized frame, and the
    matching regions are cut from the ORIGINAL-resolution image and
    center-cropped to squares.

    Args:
        image: PIL image; ``image.size`` is (width, height).

    Returns:
        tuple: (PIL image — the 256x256 frame with boxes drawn,
                list of np.ndarray square crops at original resolution).
    """
    original_size = image.size
    target_size = (256, 256)
    frame_resized = cv2.resize(np.array(image), dsize=target_size, interpolation=cv2.INTER_AREA)
    # NOTE(review): despite the variable name, this converts RGB -> BGR
    # before normalization — presumably to match the channel order the
    # detector was trained with; confirm against the training pipeline.
    frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_RGB2BGR).astype(np.float32)
    frame_rgb /= 255.0  # scale pixels to [0, 1]
    frame_rgb = frame_rgb.transpose(2, 0, 1)  # HWC -> CHW for PyTorch
    frame_rgb = torch.from_numpy(frame_rgb).float().unsqueeze(0).to(device)
    with torch.no_grad():  # inference only — no gradient bookkeeping
        outputs = object_detection_model(frame_rgb)
    boxes = outputs[0]['boxes'].cpu().detach().numpy().astype(np.int32)
    labels = outputs[0]['labels'].cpu().detach().numpy().astype(np.int32)
    scores = outputs[0]['scores']
    result_image = frame_resized.copy()
    cropped_images = []  # List to hold multiple cropped images
    for i in range(len(boxes)):
        if scores[i] >= 0.75:  # confidence threshold
            x1, y1, x2, y2 = boxes[i]
            # Raw labels 1 and 2 (0 and 1 after the -1 shift) are accepted.
            if (int(labels[i])-1) == 1 or (int(labels[i])-1) == 0:
                color = (0, 0, 255)
                label_text = 'Flame stone surface'
            else:
                # Runtime string kept verbatim (Vietnamese:
                # "Flame stone surface not visible").
                st.info("Không nhìn thấy bề mặt đá đốt")
                continue  # Skip objects that aren't of interest
            # Crop the detected region from the original image.
            # PIL size is (width, height); reversing yields (height, width).
            original_h, original_w = original_size[::-1]
            scale_h, scale_w = original_h / target_size[0], original_w / target_size[1]
            x1_orig, y1_orig = int(x1 * scale_w), int(y1 * scale_h)
            x2_orig, y2_orig = int(x2 * scale_w), int(y2 * scale_h)
            cropped_image = np.array(image)[y1_orig:y2_orig, x1_orig:x2_orig]
            # Center-crop the detection to a square (keeps original resolution)
            resized_crop = resize_to_square(cropped_image)
            cropped_images.append(resized_crop)
            # Draw bounding boxes on the result image
            cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 3)
            cv2.putText(result_image, label_text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return Image.fromarray(result_image), cropped_images
# Main app
def main():
    """Render the UI: a detection+classification tab and a classification-only tab."""

    def _render_verdict(label):
        # Green banner for 'pass', red banner otherwise.
        if label == 'pass':
            st.success(f"Classification: {label.upper()}")
        else:
            st.error(f"Classification: {label.upper()}")

    st.title('🖼️ Object Detection and Classification App')
    st.write("Upload an image for object detection and classification.")
    tab1, tab2 = st.tabs(["🖼️ OB and BC", "BC"])

    with tab1:
        upload = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"], key="file_uploader_1")
        left, right = st.columns(2)
        if upload is not None:
            picture = convert_png_to_jpg(Image.open(upload))
            left.image(picture, caption='Uploaded Image')
            with st.spinner('Processing...'):
                # Detect objects, then classify each returned crop.
                annotated, crops = perform_object_detection(picture)
                right.image(annotated, caption='Object Detection Result')
                if crops:
                    st.subheader("Cropped Images and Classification Results")
                    # One row (thumbnail + verdict) per detected crop.
                    for index, crop in enumerate(crops):
                        crop_pil = Image.fromarray(crop)
                        verdict = classify_image(crop_pil)
                        thumb_col, verdict_col = st.columns([1, 2])
                        with thumb_col:
                            st.image(crop_pil, caption=f'Cropped Image {index + 1}', use_column_width=True)
                        with verdict_col:
                            _render_verdict(verdict)
                else:
                    st.warning("No object detected with a confidence of 0.75 or higher.")

    with tab2:
        st.header('Image Classification')
        upload_bc = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"], key="file_uploader_2")
        if upload_bc is not None:
            picture = convert_png_to_jpg(Image.open(upload_bc))
            left, right = st.columns(2)
            with left:
                st.image(picture, caption='Uploaded Image', use_column_width=True)
            with right:
                with st.spinner('Classifying...'):
                    _render_verdict(classify_image(picture))
# Sidebar and footer
# These run at import time, so they render regardless of which tab is active.
st.sidebar.header("About")
st.sidebar.info(
    "This app performs both object detection and image classification. "
    "Upload an image to see the results!"
)
# Fixed-position footer injected as raw HTML/CSS (hence unsafe_allow_html).
st.markdown(
    """
    <style>
    .footer {
        position: fixed;
        left: 0;
        bottom: 0;
        width: 100%;
        background-color: #0E1117;
        color: #FAFAFA;
        text-align: center;
        padding: 10px;
        font-size: 12px;
    }
    </style>
    <div class="footer">
        Developed by Tran Thanh Son | © 2024 Object Detection and Classification App
    </div>
    """,
    unsafe_allow_html=True
)

# Standard script entry point.
if __name__ == "__main__":
    main()