File size: 1,868 Bytes
a21b7eb
 
9382e9a
 
a21b7eb
9382e9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a21b7eb
9382e9a
 
a21b7eb
9382e9a
 
 
 
6dc5f6d
9382e9a
6dc5f6d
9382e9a
 
 
6dc5f6d
9382e9a
a21b7eb
9382e9a
 
 
 
 
 
 
 
8eb7f25
9382e9a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import streamlit as st
from PIL import Image
import numpy as np
import cv2
import io
from model import AircraftDetector
from train import train_model

# Page-level configuration and heading.
st.set_page_config(page_title="Aircraft Detection System", layout="wide")
st.title("Aircraft Detection from Aerial Imagery")

# Sidebar for model selection and training
with st.sidebar:
    st.header("Model Options")
    # This flag is read further down the script when the detector is loaded.
    use_pretrained = st.checkbox("Use Pre-Trained Model", value=True)

    if st.button("Train New Model"):
        with st.spinner("Training model... This may take a while."):
            # A truthy return value from train_model() is treated as success.
            success = train_model()
            if success:
                # Detectors are memoised with st.cache_resource, so an instance
                # built before this training run would keep being served.
                # Drop cached resources so the next load builds a fresh one.
                # NOTE(review): presumably AircraftDetector picks up the newly
                # trained weights when re-created -- confirm in model.py.
                st.cache_resource.clear()
                st.success("Training completed!")
            else:
                st.error("Training failed. Using pre-trained model.")

# Load model
@st.cache_resource
def load_model(pretrained):
    """Build an ``AircraftDetector``; the result is cached by Streamlit.

    Args:
        pretrained: Forwarded to ``AircraftDetector`` as ``use_pretrained``.

    Returns:
        The constructed ``AircraftDetector`` instance.
    """
    detector_instance = AircraftDetector(use_pretrained=pretrained)
    return detector_instance

# Fetch the (cached) detector matching the sidebar selection.
detector = load_model(use_pretrained)

# File uploader
uploaded_file = st.file_uploader("Upload an aerial image", type=['jpg', 'jpeg', 'png'])

if uploaded_file:
    image_bytes = uploaded_file.read()
    try:
        image = Image.open(io.BytesIO(image_bytes))
        # Normalise to 3-channel RGB. PNG uploads can be RGBA, palette or
        # grayscale, which would otherwise yield a 4-channel or 2-D array;
        # on RGBA the (0, 255, 0) draw colour would also leave alpha at 0.
        rgb_image = image.convert("RGB")
    except OSError:
        # Pillow reports unreadable or truncated files via OSError subclasses.
        st.error("Could not read the uploaded file as an image.")
        st.stop()
    image_np = np.array(rgb_image)

    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Original Image")
        st.image(image, use_column_width=True)

    # Each detection row is unpacked below as (x1, y1, x2, y2, score, class_id).
    # NOTE(review): if detector.model is an Ultralytics YOLO model, NumPy
    # inputs are interpreted as BGR while image_np is RGB -- confirm and
    # convert if that is the case.
    detections = detector.model(image_np)[0].boxes.data.tolist()

    # Draw on a copy so image_np stays unmodified.
    result_img = image_np.copy()
    box_color = (0, 255, 0)
    min_label_y = 15  # keeps the label visible for boxes touching the top edge
    for x1, y1, x2, y2, score, class_id in detections:
        if score > 0.25:
            left, top = int(x1), int(y1)
            right, bottom = int(x2), int(y2)
            label = f"Class {int(class_id)} ({score:.2f})"
            cv2.rectangle(result_img, (left, top), (right, bottom), box_color, 2)
            # Clamp the text baseline; top - 10 can go negative near the top.
            cv2.putText(result_img, label, (left, max(top - 10, min_label_y)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color, 2)

    with col2:
        st.subheader("Detection Results")
        st.image(result_img, use_column_width=True)