import streamlit as st
import torch
from PIL import Image
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from sahi.utils.cv import read_image_as_pil
import cv2
import tempfile
import os
# Set page config: browser tab title and full-width ("wide") layout.
st.set_page_config(page_title="AI Powered Ship Detection using SAR", layout="wide")
def load_model():
    """Load the SAHI-wrapped YOLO detection model.

    Returns:
        AutoDetectionModel: detector with a 0.5 confidence threshold,
        placed on GPU when CUDA is available, otherwise on CPU.
    """
    # NOTE(review): model_type is 'yolov8' but the weights file is
    # 'yolov10x.pt' — confirm SAHI accepts YOLOv10 weights under the
    # 'yolov8' model type, or align the two.
    model = AutoDetectionModel.from_pretrained(
        model_type='yolov8',
        model_path='yolov10x.pt',
        confidence_threshold=0.5,
        device="cuda" if torch.cuda.is_available() else "cpu"
    )
    return model
def process_image(image_path, model, slice_height, slice_width, overlap_height_ratio, overlap_width_ratio):
    """Run sliced (SAHI) inference on an image and draw the detections.

    Args:
        image_path: path to the image file on disk.
        model: SAHI AutoDetectionModel used for prediction.
        slice_height: height of each slice in pixels.
        slice_width: width of each slice in pixels.
        overlap_height_ratio: vertical overlap ratio between slices.
        overlap_width_ratio: horizontal overlap ratio between slices.

    Returns:
        tuple: (annotated image as a numpy array, list of object predictions).
    """
    pil_image = read_image_as_pil(image_path)

    prediction_result = get_sliced_prediction(
        pil_image,
        model,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
    )

    annotated = np.array(pil_image)
    box_color = (0, 255, 0)  # green — identical in RGB and BGR, so PIL/cv2 mixing is safe
    thickness = 2

    for detection in prediction_result.object_prediction_list:
        box = detection.bbox
        top_left = (int(box.minx), int(box.miny))
        bottom_right = (int(box.maxx), int(box.maxy))

        # Bounding box around the detection.
        cv2.rectangle(annotated, top_left, bottom_right, box_color, thickness)

        # Class name with confidence, drawn just above the box.
        caption = f"{detection.category.name}: {detection.score.value:.2f}"
        label_origin = (top_left[0], top_left[1] - 10)
        cv2.putText(
            annotated,
            caption,
            label_origin,
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            box_color,
            thickness,
        )

    return annotated, prediction_result.object_prediction_list
def main():
    """Render the Streamlit app: sidebar SAHI parameters, image upload,
    sliced detection, and a results summary."""
    st.title("AI Powered Ship Detection using SAR")

    # Sidebar configuration
    st.sidebar.header("SAHI Configuration")

    # Slicing parameters
    st.sidebar.subheader("Slicing Parameters")
    slice_height = st.sidebar.slider(
        "Slice Height",
        min_value=100,
        max_value=1024,
        value=512,
        step=64,
        help="Height of each slice in pixels"
    )
    slice_width = st.sidebar.slider(
        "Slice Width",
        min_value=100,
        max_value=1024,
        value=512,
        step=64,
        help="Width of each slice in pixels"
    )

    # Overlap parameters
    st.sidebar.subheader("Overlap Parameters")
    overlap_height_ratio = st.sidebar.slider(
        "Height Overlap Ratio",
        min_value=0.1,
        max_value=0.9,
        value=0.5,
        step=0.1,
        help="Overlap ratio between consecutive slices in height"
    )
    overlap_width_ratio = st.sidebar.slider(
        "Width Overlap Ratio",
        min_value=0.1,
        max_value=0.9,
        value=0.5,
        step=0.1,
        help="Overlap ratio between consecutive slices in width"
    )

    # Add information about parameters
    st.sidebar.markdown("""
    ### Parameter Guide
    - **Slice Size**: Larger slices process more context at once but use more memory
    - **Overlap Ratio**: Higher overlap may catch objects at slice boundaries but increases processing time
    """)

    # Cache the model so it is loaded once per server process, not on every rerun.
    @st.cache_resource
    def get_model():
        return load_model()
    model = get_model()

    # File uploader
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        # Create columns for before/after comparison
        col1, col2 = st.columns(2)

        # Persist the upload to disk for SAHI. Keep the uploaded file's real
        # extension (original hard-coded '.jpg' even for PNG uploads).
        suffix = os.path.splitext(uploaded_file.name)[1] or '.jpg'
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            tmp_path = tmp_file.name

        # try/finally so the temp file is removed even when processing raises
        # (the original only cleaned up on the success path, leaking the file).
        try:
            # Display original image
            original_image = Image.open(uploaded_file)
            col1.header("Original Image")
            col1.image(original_image, use_column_width=True)

            # Process image and display results
            with st.spinner('Processing image...'):
                processed_image, predictions = process_image(
                    tmp_path,
                    model,
                    slice_height,
                    slice_width,
                    overlap_height_ratio,
                    overlap_width_ratio
                )
                col2.header("Detected Objects")
                col2.image(processed_image, use_column_width=True)

                # Display detection results
                st.header("Detection Results")
                with st.container():
                    # Total count, then one line per detection.
                    total_ships = len(predictions)
                    st.markdown(f"**Total Objects Detected:** {total_ships}")
                    for pred in predictions:
                        st.write(f"- Found {pred.category.name} with confidence {pred.score.value:.2f}")
        finally:
            # Clean up temporary file
            os.unlink(tmp_path)
# Script entry point: run the Streamlit app.
if __name__ == "__main__":
    main()