# app.py — last commit 9382e9a ("Update app.py") by shaheerawan3, verified.
import streamlit as st
from PIL import Image
import numpy as np
import cv2
import io
from model import AircraftDetector
from train import train_model
st.set_page_config(page_title="Aircraft Detection System", layout="wide")
st.title("Aircraft Detection from Aerial Imagery")

# Sidebar for model selection and training
with st.sidebar:
    st.header("Model Options")
    use_pretrained = st.checkbox("Use Pre-Trained Model", value=True)
    if st.button("Train New Model"):
        with st.spinner("Training model... This may take a while."):
            try:
                success = train_model()
            except Exception as err:  # top-level UI boundary: show the error, keep the page alive
                st.exception(err)
                success = False
        if success:
            # The detector is obtained through an @st.cache_resource function, so
            # without clearing the cache the app would keep serving the model that
            # was loaded before training.
            # NOTE(review): assumes AircraftDetector picks up the newly trained
            # weights when it is rebuilt — confirm in model.py.
            st.cache_resource.clear()
            st.success("Training completed!")
        else:
            st.error("Training failed. Using pre-trained model.")
# Load model
@st.cache_resource
def load_model(pretrained):
    """Build an ``AircraftDetector`` and cache it across Streamlit reruns.

    ``st.cache_resource`` keys the cached object on ``pretrained``, so a rerun
    with the same checkbox value reuses the existing detector instead of
    constructing a new one.

    Args:
        pretrained: Forwarded to ``AircraftDetector`` as ``use_pretrained``.

    Returns:
        The (cached) ``AircraftDetector`` instance.
    """
    model = AircraftDetector(use_pretrained=pretrained)
    return model


detector = load_model(pretrained=use_pretrained)
# Minimum confidence a detection needs before it is drawn.
MIN_DETECTION_SCORE = 0.25
BOX_COLOR = (0, 255, 0)  # green, in the RGB order the annotated image uses


def _draw_detections(image_rgb, detections, min_score=MIN_DETECTION_SCORE):
    """Return a copy of ``image_rgb`` with a labelled box for each confident detection.

    Args:
        image_rgb: 3-channel uint8 image array in RGB order; it is not modified.
        detections: Rows as produced by ``results[0].boxes.data.tolist()``. Each
            row starts with ``x1, y1, x2, y2`` and ends with ``score, class_id``;
            any columns in between (presumably a tracker id when tracking is
            enabled — confirm) are ignored.
        min_score: Only detections with ``score > min_score`` are drawn.

    Returns:
        The annotated copy of ``image_rgb``.
    """
    annotated = image_rgb.copy()
    font = cv2.FONT_HERSHEY_SIMPLEX
    for x1, y1, x2, y2, *_extra, score, class_id in detections:
        if score <= min_score:
            continue
        left, top = int(x1), int(y1)
        label = f"Class {int(class_id)} ({score:.2f})"
        cv2.rectangle(annotated, (left, top), (int(x2), int(y2)), BOX_COLOR, 2)
        # Put the label 10 px above the box, but not above the image's top edge:
        # for boxes touching the top border, y1 - 10 would be off-canvas.
        (_, text_height), _ = cv2.getTextSize(label, font, 0.5, 2)
        cv2.putText(annotated, label, (left, max(top - 10, text_height)),
                    font, 0.5, BOX_COLOR, 2)
    return annotated


# File uploader
uploaded_file = st.file_uploader("Upload an aerial image", type=['jpg', 'jpeg', 'png'])
if uploaded_file is not None:
    try:
        # getvalue() returns the whole upload regardless of the buffer's read
        # position. convert("RGB") forces a full decode (so corrupt files fail
        # here) and normalises RGBA / palette / grayscale PNGs to 3 channels; on a
        # 4-channel array the 3-tuple box colour would be drawn with alpha 0.
        image = Image.open(io.BytesIO(uploaded_file.getvalue())).convert("RGB")
    except OSError:  # PIL.UnidentifiedImageError is an OSError subclass
        st.error("Could not read the uploaded file as an image.")
        st.stop()
    image_np = np.array(image)

    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Original Image")
        st.image(image, use_column_width=True)

    # NOTE(review): `.boxes.data` on the result indicates an Ultralytics YOLO
    # model, which interprets numpy inputs as BGR; convert so the model sees the
    # channel order it expects. Confirm against model.py.
    results = detector.model(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
    detections = results[0].boxes.data.tolist()
    result_img = _draw_detections(image_np, detections)

    with col2:
        st.subheader("Detection Results")
        st.image(result_img, use_column_width=True)