"""Streamlit front end for detecting aircraft in uploaded aerial images.

The sidebar lets the user choose the pre-trained weights or trigger training;
the main pane shows the uploaded image next to a copy annotated with the
detector's bounding boxes and confidence scores.
"""

import io

import cv2
import numpy as np
import streamlit as st
from PIL import Image

from model import AircraftDetector
from train import train_model

# Minimum confidence score a detection needs in order to be drawn.
CONFIDENCE_THRESHOLD = 0.25
# Box / label colour; (0, 255, 0) is green in both RGB and BGR order.
BOX_COLOR = (0, 255, 0)


@st.cache_resource
def load_model(pretrained):
    """Build and cache an ``AircraftDetector``.

    Streamlit keys the cache on ``pretrained``, so each choice is constructed
    once and reused across reruns until ``load_model.clear()`` is called.
    """
    return AircraftDetector(use_pretrained=pretrained)


def _to_rgb_array(image):
    """Return the PIL *image* as an ``(H, W, 3)`` RGB numpy array.

    PNG uploads may be RGBA, palette or greyscale; converting first gives the
    model and the OpenCV drawing calls one consistent 3-channel layout.
    """
    return np.array(image.convert("RGB"))


def _detect(detector, image_rgb):
    """Run ``detector.model`` on *image_rgb* and return its box rows as lists."""
    # NOTE(review): the ``[0].boxes.data`` access matches the Ultralytics YOLO
    # Results API, whose numpy-array input convention is BGR (OpenCV order);
    # convert so the model does not see swapped colour channels. Confirm
    # against AircraftDetector.
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
    return detector.model(image_bgr)[0].boxes.data.tolist()


def _draw_detections(image_rgb, detections, threshold=CONFIDENCE_THRESHOLD):
    """Return a copy of *image_rgb* with boxes and labels for confident hits.

    Each row of *detections* is ``[x1, y1, x2, y2, ..., score, class_id]``;
    any columns between the box and the score are ignored. Rows whose score
    is not strictly above *threshold* are skipped. The input is not modified.
    """
    annotated = image_rgb.copy()
    for row in detections:
        *box, score, class_id = row
        if score <= threshold:
            continue
        x1, y1, x2, y2 = (int(v) for v in box[:4])
        label = f"Class {int(class_id)} ({score:.2f})"
        cv2.rectangle(annotated, (x1, y1), (x2, y2), BOX_COLOR, 2)
        # Keep the label inside the image when a box touches the top edge.
        cv2.putText(
            annotated,
            label,
            (x1, max(y1 - 10, 10)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            BOX_COLOR,
            2,
        )
    return annotated


# set_page_config must be the first Streamlit command that renders anything.
st.set_page_config(page_title="Aircraft Detection System", layout="wide")
st.title("Aircraft Detection from Aerial Imagery")

# Sidebar for model selection and training
with st.sidebar:
    st.header("Model Options")
    use_pretrained = st.checkbox("Use Pre-Trained Model", value=True)

    if st.button("Train New Model"):
        with st.spinner("Training model... This may take a while."):
            success = train_model()
        if success:
            # Drop cached detectors so this run rebuilds them instead of
            # reusing an instance created before training finished.
            load_model.clear()
            st.success("Training completed!")
        else:
            # NOTE(review): nothing here forces the pre-trained weights; the
            # model loaded below still follows the checkbox. Confirm intent.
            st.error("Training failed. Using pre-trained model.")

detector = load_model(use_pretrained)

# File uploader
uploaded_file = st.file_uploader("Upload an aerial image", type=['jpg', 'jpeg', 'png'])

if uploaded_file is not None:
    # getvalue() returns the whole upload regardless of the stream position,
    # whereas read() would return b"" if the buffer had already been consumed.
    image = Image.open(io.BytesIO(uploaded_file.getvalue()))
    image_np = _to_rgb_array(image)

    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Original Image")
        st.image(image, use_column_width=True)

    detections = _detect(detector, image_np)
    result_img = _draw_detections(image_np, detections)

    with col2:
        st.subheader("Detection Results")
        st.image(result_img, use_column_width=True)