"""Streamlit app: detect regions of interest in an uploaded image and classify them.

Tab 1 runs an object detector, crops each accepted detection from the
full-resolution upload and classifies every crop as pass/fail.
Tab 2 classifies the uploaded image directly.
"""

import io

import cv2
import numpy as np
import streamlit as st
import torch
import torchvision.transforms as transforms
from PIL import Image
from tensorflow.keras.models import load_model

# Set up Streamlit page (must be the first Streamlit command executed).
st.set_page_config(page_title="Object Detection and Classification App", page_icon="🖼️", layout="wide")

# Output classes of the classifier, indexed by argmax of its prediction.
CLASS_LABELS = ['fail', 'pass']
# (width, height) fed to the detector.
DETECTION_INPUT_SIZE = (256, 256)
# Minimum detection score for a box to be kept.
DETECTION_SCORE_THRESHOLD = 0.75
# Raw detector label ids that are treated as the object of interest.
LABELS_OF_INTEREST = (1, 2)


# Load models
@st.cache_resource
def load_models():
    """Load the detector and the classifier once per Streamlit server process.

    Returns:
        Tuple ``(object_detection_model, classification_model, device)``.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # NOTE(review): this unpickles a whole model object, which executes
    # arbitrary code from the file -- only load checkpoints you trust. Newer
    # PyTorch releases default ``weights_only=True`` and will refuse such a
    # file unless ``weights_only=False`` is passed; confirm against the
    # PyTorch version this app is deployed with.
    object_detection_model = torch.load("fasterrcnn_resnet50_fpn_270824.pth", map_location=device)
    object_detection_model.to(device)
    object_detection_model.eval()
    classification_model = load_model('resnet50_6000_2.h5')
    return object_detection_model, classification_model, device


object_detection_model, classification_model, device = load_models()


# Helper functions
def preprocess_image(image, target_size=(256, 256)):
    """Turn a PIL image into a float32 batch of shape (1, H, W, 3) scaled to [0, 1].

    Args:
        image: PIL image in any mode; it is converted to RGB first so the
            array always has three channels.
        target_size: ``(width, height)`` passed to ``Image.resize``.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')
    resized = image.resize(target_size)
    batch = np.array(resized).astype('float32') / 255.0
    return np.expand_dims(batch, axis=0)


def classify_image(image):
    """Classify a PIL image and return ``'fail'`` or ``'pass'``."""
    processed_image = preprocess_image(image)
    prediction = classification_model.predict(processed_image)
    predicted_class = np.argmax(prediction, axis=1)[0]
    return CLASS_LABELS[predicted_class]


def convert_png_to_jpg(image):
    """Return an RGB version of *image*, re-encoding PNG uploads as JPEG.

    PNG inputs are converted to RGB and round-tripped through an in-memory
    JPEG, as before. Any other input that is not already RGB (grayscale,
    CMYK, palette, ...) is converted to RGB so downstream NumPy/OpenCV code
    always receives a 3-channel array. RGB non-PNG images are returned as-is.
    """
    if image.format == 'PNG':
        buffer = io.BytesIO()
        image.convert('RGB').save(buffer, format='JPEG')
        buffer.seek(0)
        return Image.open(buffer)
    if image.mode != 'RGB':
        return image.convert('RGB')
    return image


def resize_to_square(image):
    """Centre-crop a NumPy image to a square whose side is its shorter dimension.

    Despite the name, no resampling happens; pixels are only cropped, so the
    crop keeps the original resolution.
    """
    h, w = image.shape[:2]
    if h > w:  # portrait image: trim top and bottom
        start = (h - w) // 2
        return image[start:start + w, :]
    # landscape or square image: trim left and right
    start = (w - h) // 2
    return image[:, start:start + h]


def perform_object_detection(image, score_threshold=DETECTION_SCORE_THRESHOLD):
    """Run the detector on a PIL image and crop every accepted detection.

    Args:
        image: PIL image (converted to RGB if needed).
        score_threshold: minimum detection score for a box to be used.

    Returns:
        Tuple ``(annotated, crops)`` where ``annotated`` is a PIL image of the
        detector-sized frame with boxes drawn, and ``crops`` is a list of
        square NumPy crops taken from the full-resolution input.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')
    full_res = np.array(image)
    original_h, original_w = full_res.shape[:2]
    target_w, target_h = DETECTION_INPUT_SIZE

    frame_resized = cv2.resize(full_res, dsize=(target_w, target_h), interpolation=cv2.INTER_AREA)
    # NOTE(review): the PIL frame is RGB and is swapped to BGR before being fed
    # to the model. Kept as in the original because it may match how the model
    # was trained -- confirm against the training pipeline.
    model_input = cv2.cvtColor(frame_resized, cv2.COLOR_RGB2BGR).astype(np.float32)
    model_input /= 255.0
    model_input = model_input.transpose(2, 0, 1)  # HWC -> CHW
    model_input = torch.from_numpy(model_input).float().unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = object_detection_model(model_input)

    detections = outputs[0]
    boxes = detections['boxes'].cpu().detach().numpy().astype(np.int32)
    labels = detections['labels'].cpu().detach().numpy().astype(np.int32)
    scores = detections['scores'].cpu().detach().numpy()

    # Factors mapping detector-frame coordinates back to the original image.
    scale_w = original_w / target_w
    scale_h = original_h / target_h

    result_image = frame_resized.copy()
    cropped_images = []  # List to hold multiple cropped images

    for box, label, score in zip(boxes, labels, scores):
        if score < score_threshold:
            continue
        if int(label) not in LABELS_OF_INTEREST:
            st.info("Không nhìn thấy bề mặt đá đốt")
            continue  # Skip objects that aren't of interest

        # Plain ints: safe for both slicing and the cv2 drawing calls.
        x1, y1, x2, y2 = (int(v) for v in box)
        color = (0, 0, 255)
        label_text = 'Flame stone surface'

        # Crop the detected region from the original image, clamped to bounds.
        x1_orig = min(max(int(x1 * scale_w), 0), original_w)
        y1_orig = min(max(int(y1 * scale_h), 0), original_h)
        x2_orig = min(max(int(x2 * scale_w), 0), original_w)
        y2_orig = min(max(int(y2 * scale_h), 0), original_h)
        cropped_image = full_res[y1_orig:y2_orig, x1_orig:x2_orig]
        if cropped_image.size == 0:
            # Degenerate box: an empty crop cannot be turned into an image.
            continue

        # Square the crop while keeping its resolution.
        cropped_images.append(resize_to_square(cropped_image))

        # Draw bounding boxes on the result image
        cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 3)
        cv2.putText(result_image, label_text, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    return Image.fromarray(result_image), cropped_images


def _show_classification(classification_result):
    """Render a classification label as a green (pass) or red (other) banner."""
    message = f"Classification: {classification_result.upper()}"
    if classification_result == 'pass':
        st.success(message)
    else:
        st.error(message)


# Main app
def main():
    """Build the two-tab Streamlit UI and wire it to the models."""
    st.title('🖼️ Object Detection and Classification App')
    st.write("Upload an image for object detection and classification.")

    tab1, tab2 = st.tabs(["🖼️ OB and BC", "BC"])

    with tab1:
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"], key="file_uploader_1")
        col1, col2 = st.columns(2)

        if uploaded_file is not None:
            image = Image.open(uploaded_file)
            image = convert_png_to_jpg(image)
            col1.image(image, caption='Uploaded Image')

            with st.spinner('Processing...'):
                # Perform object detection and get cropped images
                detection_result, cropped_images = perform_object_detection(image)
                col2.image(detection_result, caption='Object Detection Result')

                # If cropped images are detected, classify each
                if cropped_images:
                    st.subheader("Cropped Images and Classification Results")

                    # Iterate over all cropped images
                    for idx, cropped_image in enumerate(cropped_images):
                        cropped_image_pil = Image.fromarray(cropped_image)
                        classification_result = classify_image(cropped_image_pil)

                        # One row per crop: the crop on the left, its result on the right
                        img_col, result_col = st.columns([1, 2])
                        with img_col:
                            st.image(cropped_image_pil, caption=f'Cropped Image {idx + 1}', use_column_width=True)
                        with result_col:
                            _show_classification(classification_result)
                else:
                    st.warning(f"No object detected with a confidence of {DETECTION_SCORE_THRESHOLD} or higher.")

    with tab2:
        st.header('Image Classification')
        uploaded_file_2 = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"], key="file_uploader_2")

        if uploaded_file_2 is not None:
            image = convert_png_to_jpg(Image.open(uploaded_file_2))
            col1, col2 = st.columns(2)
            with col1:
                st.image(image, caption='Uploaded Image', use_column_width=True)
            with col2:
                with st.spinner('Classifying...'):
                    classification_result = classify_image(image)
                _show_classification(classification_result)

    # Sidebar and footer
    st.sidebar.header("About")
    st.sidebar.info(
        "This app performs both object detection and image classification. "
        "Upload an image to see the results!"
    )
    st.markdown(
        """ """,
        unsafe_allow_html=True
    )


if __name__ == "__main__":
    main()