Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| from PIL import Image | |
| import numpy as np | |
| from sahi import AutoDetectionModel | |
| from sahi.predict import get_sliced_prediction | |
| from sahi.utils.cv import read_image_as_pil | |
| import cv2 | |
| import tempfile | |
| import os | |
| # Set page config | |
| st.set_page_config(page_title="AI Powered Ship Detection using SAR", layout="wide") | |
| def load_model(): | |
| """Load the YOLO model""" | |
| model = AutoDetectionModel.from_pretrained( | |
| model_type='yolov8', | |
| model_path='yolov10x.pt', | |
| confidence_threshold=0.5, | |
| device="cuda" if torch.cuda.is_available() else "cpu" | |
| ) | |
| return model | |
| def process_image(image_path, model, slice_height, slice_width, overlap_height_ratio, overlap_width_ratio): | |
| """Process image using SAHI and YOLO""" | |
| # Read image | |
| image = read_image_as_pil(image_path) | |
| # Get predictions | |
| result = get_sliced_prediction( | |
| image, | |
| model, | |
| slice_height=slice_height, | |
| slice_width=slice_width, | |
| overlap_height_ratio=overlap_height_ratio, | |
| overlap_width_ratio=overlap_width_ratio | |
| ) | |
| # Convert PIL image to numpy array | |
| image_np = np.array(image) | |
| # Draw predictions | |
| for prediction in result.object_prediction_list: | |
| bbox = prediction.bbox | |
| category_name = prediction.category.name | |
| score = prediction.score.value | |
| # Draw rectangle | |
| cv2.rectangle( | |
| image_np, | |
| (int(bbox.minx), int(bbox.miny)), | |
| (int(bbox.maxx), int(bbox.maxy)), | |
| (0, 255, 0), | |
| 2 | |
| ) | |
| # Draw label | |
| label = f"{category_name}: {score:.2f}" | |
| cv2.putText( | |
| image_np, | |
| label, | |
| (int(bbox.minx), int(bbox.miny) - 10), | |
| cv2.FONT_HERSHEY_SIMPLEX, | |
| 0.5, | |
| (0, 255, 0), | |
| 2 | |
| ) | |
| return image_np, result.object_prediction_list | |
| def main(): | |
| st.title("AI Powered Ship Detection using SAR") | |
| # Sidebar configuration | |
| st.sidebar.header("SAHI Configuration") | |
| # Slicing parameters | |
| st.sidebar.subheader("Slicing Parameters") | |
| slice_height = st.sidebar.slider( | |
| "Slice Height", | |
| min_value=100, | |
| max_value=1024, | |
| value=512, | |
| step=64, | |
| help="Height of each slice in pixels" | |
| ) | |
| slice_width = st.sidebar.slider( | |
| "Slice Width", | |
| min_value=100, | |
| max_value=1024, | |
| value=512, | |
| step=64, | |
| help="Width of each slice in pixels" | |
| ) | |
| # Overlap parameters | |
| st.sidebar.subheader("Overlap Parameters") | |
| overlap_height_ratio = st.sidebar.slider( | |
| "Height Overlap Ratio", | |
| min_value=0.1, | |
| max_value=0.9, | |
| value=0.5, | |
| step=0.1, | |
| help="Overlap ratio between consecutive slices in height" | |
| ) | |
| overlap_width_ratio = st.sidebar.slider( | |
| "Width Overlap Ratio", | |
| min_value=0.1, | |
| max_value=0.9, | |
| value=0.5, | |
| step=0.1, | |
| help="Overlap ratio between consecutive slices in width" | |
| ) | |
| # Add information about parameters | |
| st.sidebar.markdown(""" | |
| ### Parameter Guide | |
| - **Slice Size**: Larger slices process more context at once but use more memory | |
| - **Overlap Ratio**: Higher overlap may catch objects at slice boundaries but increases processing time | |
| """) | |
| # Load model | |
| def get_model(): | |
| return load_model() | |
| model = get_model() | |
| # File uploader | |
| uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"]) | |
| if uploaded_file is not None: | |
| # Create columns for before/after comparison | |
| col1, col2 = st.columns(2) | |
| # Save uploaded file temporarily | |
| with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file: | |
| tmp_file.write(uploaded_file.getvalue()) | |
| tmp_path = tmp_file.name | |
| # Display original image | |
| original_image = Image.open(uploaded_file) | |
| col1.header("Original Image") | |
| col1.image(original_image, use_column_width=True) | |
| # Process image and display results | |
| with st.spinner('Processing image...'): | |
| processed_image, predictions = process_image( | |
| tmp_path, | |
| model, | |
| slice_height, | |
| slice_width, | |
| overlap_height_ratio, | |
| overlap_width_ratio | |
| ) | |
| col2.header("Detected Objects") | |
| col2.image(processed_image, use_column_width=True) | |
| # Display detection results | |
| st.header("Detection Results") | |
| # Create a container for detection results | |
| with st.container(): | |
| # Display total count | |
| total_ships = len(predictions) | |
| st.markdown(f"**Total Objects Detected:** {total_ships}") | |
| # Display individual detections | |
| for pred in predictions: | |
| st.write(f"- Found {pred.category.name} with confidence {pred.score.value:.2f}") | |
| # Clean up temporary file | |
| os.unlink(tmp_path) | |
| if __name__ == "__main__": | |
| main() |