File size: 5,327 Bytes
5a10009
 
 
 
 
 
 
 
 
 
 
 
9339350
5a10009
 
 
 
 
 
9339350
5a10009
 
 
 
6eea11b
5a10009
 
 
 
 
 
 
 
6eea11b
 
 
 
5a10009
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295f539
5a10009
6eea11b
 
 
 
 
 
 
261bd14
6eea11b
 
 
 
 
 
 
 
261bd14
6eea11b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a10009
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6eea11b
 
 
 
 
 
 
 
5a10009
 
 
 
 
 
6eea11b
 
 
 
 
 
 
 
 
 
5a10009
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import streamlit as st
import torch
from PIL import Image
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from sahi.utils.cv import read_image_as_pil
import cv2
import tempfile
import os

# Page-level Streamlit settings: browser-tab title plus a wide layout.
# Kept at module level so it runs before any other Streamlit command here.
st.set_page_config(layout="wide", page_title="AI Powered Ship Detection using SAR")

def load_model(model_path="yolov10x.pt", confidence_threshold=0.5, device=None):
    """Load the YOLO detector wrapped as a SAHI detection model.

    All arguments are optional; calling ``load_model()`` reproduces the
    app's original configuration.

    Args:
        model_path: Path to the weights file handed to SAHI.
        confidence_threshold: Minimum score a detection must reach to be kept.
        device: Torch device string such as ``"cuda"`` or ``"cpu"``. When
            ``None``, CUDA is used if ``torch.cuda.is_available()`` reports
            it, otherwise the CPU.

    Returns:
        The model object produced by ``AutoDetectionModel.from_pretrained``,
        suitable for passing to ``get_sliced_prediction``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): model_type 'yolov8' is paired with YOLOv10 weights; this
    # presumably relies on SAHI's Ultralytics-backed loader accepting any
    # Ultralytics checkpoint -- confirm against the installed SAHI version.
    return AutoDetectionModel.from_pretrained(
        model_type="yolov8",
        model_path=model_path,
        confidence_threshold=confidence_threshold,
        device=device,
    )

def _draw_predictions(image_np, object_predictions, color=(0, 255, 0)):
    """Draw every prediction's box and "name: score" label onto ``image_np`` in place.

    Args:
        image_np: Writable image array that OpenCV can draw on.
        object_predictions: SAHI object predictions exposing ``bbox``,
            ``category.name`` and ``score.value``.
        color: Colour tuple used for both the boxes and the labels.

    Returns:
        ``image_np`` (the same array, now annotated).
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 2
    for prediction in object_predictions:
        bbox = prediction.bbox
        x1, y1 = int(bbox.minx), int(bbox.miny)
        x2, y2 = int(bbox.maxx), int(bbox.maxy)
        cv2.rectangle(image_np, (x1, y1), (x2, y2), color, thickness)

        label = f"{prediction.category.name}: {prediction.score.value:.2f}"
        (_, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
        # putText anchors the text at its bottom-left corner. Placing it 10px
        # above a box that touches the top edge would push the label outside
        # the image, so keep the baseline at least one text-height down.
        text_y = max(y1 - 10, text_height)
        cv2.putText(image_np, label, (x1, text_y), font, font_scale, color, thickness)
    return image_np


def process_image(image_path, model, slice_height, slice_width, overlap_height_ratio, overlap_width_ratio):
    """Run SAHI sliced inference on an image and draw the detections on a copy.

    Args:
        image_path: Image location (or any input ``read_image_as_pil`` accepts).
        model: SAHI detection model, e.g. the one returned by ``load_model``.
        slice_height: Height of each inference slice, in pixels.
        slice_width: Width of each inference slice, in pixels.
        overlap_height_ratio: Fractional vertical overlap between slices.
        overlap_width_ratio: Fractional horizontal overlap between slices.

    Returns:
        Tuple ``(annotated_image, object_prediction_list)``. The first item is
        a numpy array of the image with boxes and labels drawn on it. The
        second is SAHI's list of object predictions.
    """
    image = read_image_as_pil(image_path)

    result = get_sliced_prediction(
        image,
        model,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
    )

    # np.array copies the pixel data, so drawing never touches the PIL image.
    annotated = np.array(image)
    _draw_predictions(annotated, result.object_prediction_list)
    return annotated, result.object_prediction_list

@st.cache_resource
def _get_model():
    """Return the detection model, cached by Streamlit so reruns reuse it."""
    return load_model()


def _sidebar_sahi_params():
    """Render the SAHI controls in the sidebar and return their current values.

    Returns:
        Tuple ``(slice_height, slice_width, overlap_height_ratio,
        overlap_width_ratio)``.
    """
    st.sidebar.header("SAHI Configuration")

    # Slicing parameters. min_value is a multiple of step so that both the
    # default (512) and the maximum (1024) lie on the slider's step grid.
    # With min_value=100 and step=64 only 100, 164, ..., 996 were reachable.
    st.sidebar.subheader("Slicing Parameters")
    slice_height = st.sidebar.slider(
        "Slice Height",
        min_value=128,
        max_value=1024,
        value=512,
        step=64,
        help="Height of each slice in pixels"
    )

    slice_width = st.sidebar.slider(
        "Slice Width",
        min_value=128,
        max_value=1024,
        value=512,
        step=64,
        help="Width of each slice in pixels"
    )

    # Overlap parameters
    st.sidebar.subheader("Overlap Parameters")
    overlap_height_ratio = st.sidebar.slider(
        "Height Overlap Ratio",
        min_value=0.1,
        max_value=0.9,
        value=0.5,
        step=0.1,
        help="Overlap ratio between consecutive slices in height"
    )

    overlap_width_ratio = st.sidebar.slider(
        "Width Overlap Ratio",
        min_value=0.1,
        max_value=0.9,
        value=0.5,
        step=0.1,
        help="Overlap ratio between consecutive slices in width"
    )

    st.sidebar.markdown("""
    ### Parameter Guide
    - **Slice Size**: Larger slices process more context at once but use more memory
    - **Overlap Ratio**: Higher overlap may catch objects at slice boundaries but increases processing time
    """)

    return slice_height, slice_width, overlap_height_ratio, overlap_width_ratio


def _render_detections(predictions):
    """Show the number of detections and one line per detection with its confidence."""
    st.header("Detection Results")
    with st.container():
        st.markdown(f"**Total Objects Detected:** {len(predictions)}")
        for pred in predictions:
            st.write(f"- Found {pred.category.name} with confidence {pred.score.value:.2f}")


def main():
    """Streamlit entry point.

    Builds the sidebar controls and loads the cached model. It then shows the
    uploaded image beside its annotated detections, followed by a list of
    the detections.
    """
    st.title("AI Powered Ship Detection using SAR")

    slice_height, slice_width, overlap_height_ratio, overlap_width_ratio = _sidebar_sahi_params()

    model = _get_model()

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # Columns for the before/after comparison.
    col1, col2 = st.columns(2)

    # Keep the upload's real extension instead of labelling every file .jpg.
    suffix = os.path.splitext(uploaded_file.name)[1] or ".jpg"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
        tmp_file.write(uploaded_file.getvalue())
        tmp_path = tmp_file.name

    # try/finally so the temp file is removed even if display or inference
    # raises (or the script run is interrupted); previously it leaked.
    try:
        original_image = Image.open(uploaded_file)
        col1.header("Original Image")
        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width -- confirm the pinned version first.
        col1.image(original_image, use_column_width=True)

        with st.spinner('Processing image...'):
            processed_image, predictions = process_image(
                tmp_path,
                model,
                slice_height,
                slice_width,
                overlap_height_ratio,
                overlap_width_ratio
            )

            col2.header("Detected Objects")
            col2.image(processed_image, use_column_width=True)

            _render_detections(predictions)
    finally:
        os.unlink(tmp_path)

# Run the app when this file is executed as the main module.
if __name__ == "__main__":
    main()