SAR_detection / app.py
hd-hg's picture
Update app.py
7532185 verified
import gradio as gr
from PIL import Image
import numpy as np
import supervision as sv
from rfdetr import RFDETRMedium
# ----------------------------
# Model Initialization
# ----------------------------
# Display names looked up by integer class id when labelling detections
# (predict uses CLASSES[class_id]; ids outside this list are shown as "Obj").
# NOTE(review): the order must match the label order the checkpoint below was
# trained with — confirm against the training dataset config.
CLASSES = [
    "A220",
    "A320-321",
    "A330",
    "ARJ21",
    "Boeing737",
    "Boeing787",
    "Other",
]

# Fine-tuned RF-DETR (medium) detector, constructed once when the module is
# executed so every request reuses the same instance. The weights path is
# relative — presumably resolved against the directory the app is launched from.
model = RFDETRMedium(
    pretrain_weights="checkpoint_best_total.pth",
)
def _to_uint8(image_array):
    """Return *image_array* as ``uint8`` so PIL can build an image from it.

    ``uint8`` input is passed through untouched. Other dtypes are mapped
    linearly onto 0..255 with the array maximum at 255. The low end of the
    range is 0 for non-negative data (the same ``x / x.max() * 255`` scaling
    used before) or the array minimum when the data contains negative values,
    which would otherwise wrap around on the uint8 cast. An empty range (e.g.
    an all-zero image) or non-finite extremes yield an all-black image instead
    of dividing by zero; NaN pixels become 0.
    """
    if image_array.dtype == np.uint8:
        return image_array
    data = image_array.astype(np.float64)
    low = min(float(np.nanmin(data)), 0.0)
    span = float(np.nanmax(data)) - low
    if not np.isfinite(span) or span <= 0:
        # Nothing to stretch: avoid 0/0 and NaN/inf propagating into the cast.
        return np.zeros(data.shape, dtype=np.uint8)
    scaled = np.nan_to_num((data - low) / span * 255.0, nan=0.0)
    return np.clip(scaled, 0, 255).astype(np.uint8)


def predict(image_array, threshold):
    """Run the detector on one image and draw the detections onto a copy.

    Args:
        image_array: Array from the Gradio ``numpy`` image input, or ``None``
            when nothing was uploaded.
        threshold: Detection threshold from the UI slider, forwarded to
            ``model.predict``.

    Returns:
        ``(annotated_image, status_message)`` where the image is a PIL image,
        or ``(None, "SYSTEM: No Input Detected")`` when there is no input.
    """
    if image_array is None:
        return None, "SYSTEM: No Input Detected"

    # PIL needs 8-bit data; convert("RGB") normalises grayscale/RGBA to 3 channels.
    image_pil = Image.fromarray(_to_uint8(image_array)).convert("RGB")
    detections = model.predict(image_pil, threshold=threshold)

    # Start from supervision's resolution-based sizes, then shrink them
    # (half the line thickness, 0.4x the text) with lower bounds so
    # annotations stay unobtrusive on high-resolution imagery.
    thickness = max(1, int(sv.calculate_optimal_line_thickness(resolution_wh=image_pil.size) * 0.5))
    text_scale = max(0.3, sv.calculate_optimal_text_scale(resolution_wh=image_pil.size) * 0.4)

    # Cyan/blue palette shared by box outlines and label backgrounds.
    color = sv.ColorPalette.from_hex(["#00E5FF", "#00B0FF", "#2979FF"])
    bbox_annotator = sv.BoxAnnotator(color=color, thickness=thickness)
    label_annotator = sv.LabelAnnotator(
        color=color,
        text_color=sv.Color.BLACK,
        text_scale=text_scale,
        text_padding=2,
    )

    # Ids outside CLASSES (including negative ones, which would otherwise
    # index from the end of the list) are shown as the generic "Obj".
    labels = [
        f"{CLASSES[class_id] if 0 <= class_id < len(CLASSES) else 'Obj'} {confidence:.2f}"
        for class_id, confidence in zip(detections.class_id, detections.confidence)
    ]

    annotated_image = image_pil.copy()
    annotated_image = bbox_annotator.annotate(annotated_image, detections)
    annotated_image = label_annotator.annotate(annotated_image, detections, labels=labels)
    return annotated_image, f"LOG: {len(detections)} OBJECTS IDENTIFIED"
# ----------------------------
# Enhanced Stealth UI CSS
# ----------------------------
# Stylesheet passed to gr.Blocks(css=...) for the app layout: near-black page
# background, cyan range-slider accent, an enlarged primary button that
# shrinks slightly while pressed, and cyan monospace block labels.
# `!important` is used on every rule, presumably so these styles override the
# Gradio theme's own rules.
# NOTE(review): `.gr-button-primary` and `.block .label-wrap` target class
# names generated by Gradio, which have changed between Gradio releases —
# confirm they still match the installed version.
custom_css = """
/* Deep Black Background */
.gradio-container { background-color: #05070a !important; }
body { background-color: #05070a !important; color: #e0e0e0 !important; }
/* Cyber Blue Slider */
input[type='range'] { accent-color: #00E5FF !important; }
/* Large "Clicky" Button */
.gr-button-primary {
background: #00E5FF !important;
color: #05070a !important;
font-weight: 800 !important;
font-size: 1.1em !important;
padding: 15px !important;
border: none !important;
transition: transform 0.1s ease, background 0.3s ease !important;
box-shadow: 0 4px 15px rgba(0, 229, 255, 0.2) !important;
}
/* Button Click Animation */
.gr-button-primary:active {
transform: scale(0.96) !important;
background: #00B0FF !important;
}
/* Text and Labels */
.block .label-wrap { color: #00E5FF !important; font-family: monospace; }
"""
# Two-column dashboard: upload and controls on the left, annotated result on
# the right; the button runs predict() on the current image and threshold.
with gr.Blocks(theme=gr.themes.Base(), css=custom_css) as demo:
    gr.Markdown("# 🛰️ **SAR** TARGET ACQUISITION SYSTEM")

    with gr.Row():
        # Left: input image, threshold control, trigger button, status log.
        with gr.Column(scale=1):
            image_input = gr.Image(type="numpy", label="RAW SAR FEED")
            threshold_slider = gr.Slider(
                minimum=0.01,
                maximum=1.0,
                value=0.5,
                label="DETECTION THRESHOLD",
            )
            run_button = gr.Button("RUN SYSTEM ANALYSIS", variant="primary")
            status_text = gr.Text(label="ANALYSIS LOG", interactive=False)

        # Right: read-only view of the annotated detections.
        with gr.Column(scale=1):
            image_output = gr.Image(
                label="PROCESSED OUTPUT (SATELLITE VIEW)",
                interactive=False,
            )

    # predict(image, threshold) returns (annotated image, log line), matching
    # the two output components in order.
    run_button.click(
        fn=predict,
        inputs=[image_input, threshold_slider],
        outputs=[image_output, status_text],
    )

demo.launch()