# stoneclass / stoneapp.py
# Author: SonFox2920 — commit 3e59cdd ("Update stoneapp.py", verified).
import streamlit as st
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from tensorflow.keras.models import load_model
from PIL import Image
import io
# Page-level settings: browser-tab title, favicon emoji and the full-width layout.
st.set_page_config(
    page_title="Object Detection and Classification App",
    page_icon="🖼️",
    layout="wide",
)
# Load models
@st.cache_resource
def load_models():
    """Load the object detector and the image classifier.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the returned
    objects instead of reloading both models from disk each time.

    Returns:
        Tuple ``(object_detection_model, classification_model, device)``:
        the PyTorch detector (moved to ``device`` and switched to eval mode),
        the Keras classifier loaded from HDF5, and the ``torch.device`` used.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The checkpoint is a whole pickled nn.Module (we call .to()/.eval() on the
    # result), not a state_dict, so it must be loaded with weights_only=False:
    # PyTorch 2.6+ defaults to weights_only=True and would refuse this file.
    # Unpickling can execute arbitrary code, so only point this at a trusted,
    # locally shipped checkpoint.
    object_detection_model = torch.load(
        "fasterrcnn_resnet50_fpn_270824.pth",
        map_location=device,
        weights_only=False,
    )
    object_detection_model.to(device)
    object_detection_model.eval()
    classification_model = load_model('resnet50_6000_2.h5')
    return object_detection_model, classification_model, device


object_detection_model, classification_model, device = load_models()
# Helper functions
def preprocess_image(image, target_size=(256, 256)):
    """Turn a PIL image into a normalized single-image batch for the classifier.

    Args:
        image: PIL image. Any non-RGB mode (grayscale, RGBA, CMYK, palette) is
            converted to RGB first so the result always has three channels.
        target_size: ``(width, height)`` handed to ``Image.resize``.

    Returns:
        float32 array of shape ``(1, height, width, 3)`` with values in [0, 1].
    """
    # Without this, a grayscale JPEG yields a 2-D (H, W) array and an RGBA/CMYK
    # image a 4-channel one. Assumes the Keras classifier takes 3-channel input
    # (the crops produced by the detector are RGB) — TODO confirm.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    img = image.resize(target_size)
    img_array = np.asarray(img, dtype=np.float32) / 255.0
    # Add the leading batch axis expected by model.predict.
    return img_array[np.newaxis, ...]
def classify_image(image):
    """Classify a PIL image as ``'fail'`` or ``'pass'``.

    The image is preprocessed with ``preprocess_image`` and scored by the
    module-level ``classification_model``. The label whose column has the
    highest score in the first (only) row of the prediction is returned.
    """
    batch = preprocess_image(image)
    scores = classification_model.predict(batch)
    # Column 0 -> 'fail', column 1 -> 'pass'.
    # NOTE(review): assumes a 2-column output; a single sigmoid column would
    # make argmax always pick 'fail' — confirm against the trained model.
    winner = np.argmax(scores, axis=1)[0]
    return ('fail', 'pass')[winner]
def convert_png_to_jpg(image):
    """Re-encode a PNG as an in-memory JPEG; return any other image untouched.

    PNG inputs are flattened to RGB (dropping alpha/palette), saved as JPEG
    into a byte buffer and reopened, so the returned image carries JPEG
    compression. Images whose ``format`` is not ``'PNG'`` are returned as the
    same object.
    """
    if image.format != 'PNG':
        return image
    buffer = io.BytesIO()
    image.convert('RGB').save(buffer, format='JPEG')
    jpeg_bytes = buffer.getvalue()
    return Image.open(io.BytesIO(jpeg_bytes))
def resize_to_square(image):
    """Centre-crop an image array to a square whose side is its shorter dimension.

    Despite the name, nothing is resized: the longer axis is trimmed equally
    from both ends (with integer division, so any odd extra pixel is removed
    from the far end). Works for 2-D (H, W) and 3-D (H, W, C) arrays; trailing
    axes are kept as-is.

    Args:
        image: numpy array indexed as ``image[row, col, ...]``.

    Returns:
        A view of ``image`` with shape ``(s, s, ...)`` where ``s = min(H, W)``.
    """
    h, w = image.shape[:2]
    side = min(h, w)
    # Only one of these offsets is non-zero: the longer axis gets trimmed.
    top = (h - side) // 2
    left = (w - side) // 2
    return image[top:top + side, left:left + side]
def perform_object_detection(image, score_threshold=0.75, target_size=(256, 256)):
    """Detect flame-stone surfaces in a PIL image and crop each detection.

    The image is resized to ``target_size`` and run through the module-level
    ``object_detection_model``. Every detection scoring at least
    ``score_threshold`` whose label is 1 or 2 is drawn on the resized frame and
    cropped at full resolution from the original image, then centre-cropped
    to a square by ``resize_to_square``. If any confident detection has
    another label, a single ``st.info`` notice is shown.

    Args:
        image: PIL image. Converted to RGB first if it is in another mode.
        score_threshold: Minimum detection score for a box to be kept.
        target_size: ``(width, height)`` the image is resized to before
            detection (the ``dsize`` convention of ``cv2.resize``).

    Returns:
        Tuple ``(annotated, crops)``: the resized frame with boxes drawn (PIL
        image) and a list of non-empty square numpy crops from the original.
    """
    # Grayscale/CMYK inputs would make cvtColor fail, and RGBA would yield
    # 4-channel crops; normalise to RGB once up front.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    original = np.array(image)  # computed once, reused for every crop
    original_h, original_w = original.shape[:2]
    target_w, target_h = target_size

    frame_resized = cv2.resize(original, dsize=(target_w, target_h), interpolation=cv2.INTER_AREA)
    # NOTE(review): this swaps RGB -> BGR, so the detector receives BGR channel
    # order. Kept unchanged; confirm it matches how the detector was trained.
    model_input = cv2.cvtColor(frame_resized, cv2.COLOR_RGB2BGR).astype(np.float32)
    model_input /= 255.0
    model_input = torch.from_numpy(model_input.transpose(2, 0, 1)).float().unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = object_detection_model(model_input)
    detections = outputs[0]
    boxes = detections['boxes'].cpu().numpy().astype(np.int32)
    labels = detections['labels'].cpu().numpy().astype(np.int32)
    # Bring scores to the CPU once rather than comparing device tensors per box.
    scores = detections['scores'].cpu().numpy()

    # Map resized-frame coordinates back to the original resolution.
    # target_size is (width, height), so width scales x and height scales y.
    scale_w = original_w / target_w
    scale_h = original_h / target_h

    result_image = frame_resized.copy()
    cropped_images = []  # one entry per accepted detection
    color = (0, 0, 255)
    label_text = 'Flame stone surface'
    notified = False
    for (x1, y1, x2, y2), label, score in zip(boxes, labels, scores):
        if score < score_threshold:
            continue
        # Only labels 1 and 2 count as flame-stone surfaces.
        if label not in (1, 2):
            # Message (Vietnamese): "no flame-stone surface seen". Shown once
            # per call instead of once per rejected box.
            if not notified:
                st.info("Không nhìn thấy bề mặt đá đốt")
                notified = True
            continue

        # Crop the detected region from the full-resolution image; clamp at 0
        # so a negative coordinate cannot wrap around via Python slicing.
        x1_orig = max(0, int(x1 * scale_w))
        y1_orig = max(0, int(y1 * scale_h))
        x2_orig = max(0, int(x2 * scale_w))
        y2_orig = max(0, int(y2 * scale_h))
        cropped_image = original[y1_orig:y2_orig, x1_orig:x2_orig]
        # A degenerate box gives an empty crop that cannot be displayed or
        # classified downstream, so skip it.
        if cropped_image.size:
            cropped_images.append(resize_to_square(cropped_image))

        # Draw the bounding box and its label on the resized frame; keep the
        # label baseline inside the canvas for boxes touching the top edge.
        cv2.rectangle(result_image, (int(x1), int(y1)), (int(x2), int(y2)), color, 3)
        text_y = max(int(y1) - 10, 10)
        cv2.putText(result_image, label_text, (int(x1), text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    return Image.fromarray(result_image), cropped_images
# Main app
def _show_classification(classification_result):
    """Render a classifier verdict: green banner for 'pass', red for anything else."""
    message = f"Classification: {classification_result.upper()}"
    if classification_result == 'pass':
        st.success(message)
    else:
        st.error(message)


def _detection_tab():
    """Tab 1: run object detection on an uploaded image, then classify every crop."""
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"], key="file_uploader_1")
    col1, col2 = st.columns(2)
    if uploaded_file is None:
        return
    image = convert_png_to_jpg(Image.open(uploaded_file))
    col1.image(image, caption='Uploaded Image')
    with st.spinner('Processing...'):
        # Perform object detection and get the cropped detections
        detection_result, cropped_images = perform_object_detection(image)
    col2.image(detection_result, caption='Object Detection Result')

    if not cropped_images:
        st.warning("No object detected with a confidence of 0.75 or higher.")
        return

    st.subheader("Cropped Images and Classification Results")
    # Loop over every cropped image
    for idx, cropped_image in enumerate(cropped_images, start=1):
        cropped_image_pil = Image.fromarray(cropped_image)
        classification_result = classify_image(cropped_image_pil)
        # Two columns per crop: the crop itself and its classification result
        img_col, result_col = st.columns([1, 2])
        with img_col:
            st.image(cropped_image_pil, caption=f'Cropped Image {idx}', use_column_width=True)
        with result_col:
            _show_classification(classification_result)


def _classification_tab():
    """Tab 2: classify a whole uploaded image without running detection."""
    st.header('Image Classification')
    uploaded_file_2 = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"], key="file_uploader_2")
    if uploaded_file_2 is None:
        return
    image = convert_png_to_jpg(Image.open(uploaded_file_2))
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption='Uploaded Image', use_column_width=True)
    with col2:
        with st.spinner('Classifying...'):
            classification_result = classify_image(image)
        _show_classification(classification_result)


def _sidebar_and_footer():
    """Render the 'About' sidebar and the fixed HTML footer."""
    st.sidebar.header("About")
    st.sidebar.info(
        "This app performs both object detection and image classification. "
        "Upload an image to see the results!"
    )
    # The CSS/HTML below is injected verbatim, hence unsafe_allow_html=True.
    st.markdown(
        """
<style>
.footer {
position: fixed;
left: 0;
bottom: 0;
width: 100%;
background-color: #0E1117;
color: #FAFAFA;
text-align: center;
padding: 10px;
font-size: 12px;
}
</style>
<div class="footer">
Developed by Tran Thanh Son | © 2024 Object Detection and Classification App
</div>
""",
        unsafe_allow_html=True
    )


def main():
    """Streamlit entry point: title, the two tabs, then sidebar and footer."""
    st.title('🖼️ Object Detection and Classification App')
    st.write("Upload an image for object detection and classification.")
    tab1, tab2 = st.tabs(["🖼️ OB and BC", "BC"])
    with tab1:
        _detection_tab()
    with tab2:
        _classification_tab()
    _sidebar_and_footer()


if __name__ == "__main__":
    main()