erfaneshrati's picture
truesat is up
7914bb9
#!/usr/bin/env python3
"""
Gradio app for TrueSat Detection using ultralytics YOLO
"""
import gradio as gr
import numpy as np
import cv2
import yaml
import logging
import os
from typing import List, Tuple
from pathlib import Path
from ultralytics import YOLO
from huggingface_hub import hf_hub_download, login
# Setup logging
# Configure the root logger at INFO level and create the module-level logger
# that every function and method in this file writes to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TrueSatDetector:
    """Object detector backed by an ultralytics YOLO ONNX model from the HF Hub.

    Constructing an instance authenticates with the Hugging Face Hub (when
    ``HF_TOKEN`` is set), resolves the list of class names and downloads
    ``1/model.onnx`` from ``repo_id``. The YOLO model itself is loaded lazily,
    either by an explicit :meth:`load_model` call or by the first
    :meth:`detect` call.
    """

    def __init__(self, repo_id: str = "truthdotphd/truesat-detection"):
        """Initialize the TrueSat detector with ultralytics YOLO from Hugging Face Hub.

        Args:
            repo_id: Hugging Face repository holding the model and class names.

        Raises:
            RuntimeError: If the ONNX model cannot be downloaded.
        """
        self.repo_id = repo_id
        self.model = None
        self.class_names = []
        self.onnx_model_path = None

        # Setup HF authentication
        self._setup_hf_auth()

        # Load class information and download model
        self._load_class_info()
        self._download_model()

    def _setup_hf_auth(self):
        """Setup Hugging Face authentication using HF_TOKEN environment variable.

        Authentication is best-effort: a missing token or a failed login is
        logged as a warning and never raised.
        """
        hf_token = os.getenv('HF_TOKEN')
        if not hf_token:
            logger.warning("HF_TOKEN not found in environment variables. Access to private repos may fail.")
            return

        try:
            login(token=hf_token)
            logger.info("Successfully authenticated with Hugging Face Hub")
        except Exception as e:
            logger.warning(f"Failed to authenticate with HF Hub: {e}")

    def _download_model(self):
        """Download the ONNX model from Hugging Face Hub.

        Stores the local file path in ``self.onnx_model_path``.

        Raises:
            RuntimeError: If the download fails; the original error is chained
                as the cause.
        """
        try:
            logger.info(f"Downloading model from {self.repo_id}...")
            self.onnx_model_path = hf_hub_download(
                repo_id=self.repo_id,
                filename="model.onnx",
                subfolder="1",
                cache_dir="./hf_cache"
            )
            logger.info(f"Model downloaded to: {self.onnx_model_path}")
        except Exception as e:
            logger.error(f"Failed to download model from HF Hub: {e}")
            raise RuntimeError(f"Could not download model from {self.repo_id}: {e}") from e

    @staticmethod
    def _normalize_class_names(names) -> List[str]:
        """Convert the ``names`` entry of a class YAML into an index-ordered list.

        Args:
            names: ``None``, a list/tuple of names, or a mapping of class id to
                name (ids may be ints or numeric strings).

        Returns:
            A list where position ``i`` holds the name of class ``i``. Ids
            missing from a mapping are filled with ``"Unknown_<id>"``. Returns
            an empty list when ``names`` is empty or ``None``.

        Raises:
            TypeError: If ``names`` is neither a mapping nor a list/tuple.
            ValueError: If a mapping key cannot be converted to an int.
        """
        if not names:
            return []
        if isinstance(names, dict):
            by_id = {int(key): str(value) for key, value in names.items()}
            return [by_id.get(i, f"Unknown_{i}") for i in range(max(by_id) + 1)]
        if isinstance(names, (list, tuple)):
            return [str(name) for name in names]
        raise TypeError(f"Unsupported 'names' type: {type(names).__name__}")

    def _load_class_info(self):
        """Load class information from Hugging Face repository or use fallback.

        Tries ``class_names.yaml`` from the repository first. Any failure
        (download error, unreadable YAML, missing or empty ``names``) is
        logged and the hardcoded fallback list is used instead, so this method
        never raises and never leaves ``self.class_names`` empty.
        """
        try:
            logger.info("Attempting to download class information from HF repository...")
            class_file_path = hf_hub_download(
                repo_id=self.repo_id,
                filename="class_names.yaml",
                cache_dir="./hf_cache"
            )

            # Read class names from downloaded file
            with open(class_file_path, 'r', encoding='utf-8') as f:
                class_config = yaml.safe_load(f)

            # An empty YAML document parses to None; treat it like "no names".
            names = self._normalize_class_names((class_config or {}).get('names'))
            if not names:
                raise ValueError("class_names.yaml does not define any class names")
        except Exception as e:
            logger.warning(f"Could not load class names from HF repository: {e}")
            logger.info("Falling back to hardcoded class names...")
            self._load_fallback_classes()
            return

        self.class_names = names
        logger.info(f"Loaded {len(self.class_names)} classes from HF repository")
        logger.info(f"Sample classes: {self.class_names[:5]}...")

    def _load_fallback_classes(self):
        """Load fallback class names if configuration files are not available."""
        self.class_names = [
            'Aircraft Hangar', 'Airplane', 'Airport', 'Barge', 'Baseball Diamond',
            'Basketball Court', 'Bridge', 'Building', 'Bus', 'Cargo Truck',
            'Cargo/Container Railcar', 'Cargo/Passenger Plane', 'Cement Mixer',
            'Construction Site', 'Container Crane', 'Container Ship', 'Crane Truck',
            'Damaged Building', 'Dump Truck', 'Engineering Vehicle', 'Excavator',
            'Facility', 'Ferry', 'Fishing Vessel', 'Flat Railcar', 'Front Loader/Bulldozer',
            'Ground Grader', 'Ground Track Field', 'Harbor', 'Haul Truck', 'Helicopter',
            'Helipad', 'Hut/Tent', 'Large Vehicle', 'Locomotive', 'Mobile Crane',
            'Motorboat', 'Oil Tanker', 'Passenger Railcar', 'Pylon', 'Railway Vehicle',
            'Reach Stacker', 'Roundabout', 'Sailboat', 'Scraper/Tractor', 'Shed', 'Ship',
            'Shipping Container', 'Shipping Container Lot', 'Small Vehicle', 'Soccer Field',
            'Storage Tank', 'Straddle Carrier', 'Swimming Pool', 'Tank Railcar',
            'Tennis Court', 'Tower', 'Tower Crane', 'Trailer', 'Truck', 'Truck Tractor',
            'Truck Tractor with Box Trailer', 'Truck Tractor with Flatbed Trailer',
            'Truck Tractor with Liquid Tank', 'Tugboat', 'Utility Truck', 'Vehicle',
            'Vehicle Lot', 'Yacht'
        ]
        logger.info(f"Using fallback class names: {len(self.class_names)} classes")

    def load_model(self):
        """Load the YOLO ONNX model using ultralytics.

        Returns:
            ``True`` when the model was loaded into ``self.model``; ``False``
            when the ONNX file is missing or loading failed (the error is
            logged, not raised).
        """
        try:
            if not self.onnx_model_path or not Path(self.onnx_model_path).exists():
                raise FileNotFoundError(f"ONNX model not found: {self.onnx_model_path}")

            # Load YOLO model from ONNX file
            self.model = YOLO(self.onnx_model_path)
            logger.info(f"Successfully loaded YOLO model from: {self.onnx_model_path}")

            # Override the model's class names with our custom ones. This is a
            # no-op when the wrapped object does not expose a `names` attribute.
            if hasattr(self.model.model, 'names'):
                self.model.model.names = dict(enumerate(self.class_names))

            return True
        except Exception as e:
            logger.error(f"Failed to load YOLO model: {e}")
            return False

    def detect(self, image: np.ndarray, conf_threshold: float = 0.25) -> Tuple[np.ndarray, np.ndarray, List[str]]:
        """Run detection on an image using ultralytics YOLO.

        Args:
            image: Image array handed to ``YOLO.predict`` as ``source``.
            conf_threshold: Minimum confidence passed to ``YOLO.predict``.

        Returns:
            A ``(boxes, scores, class_names)`` tuple: ``boxes`` holds one
            ``[x1, y1, x2, y2]`` row per detection, ``scores`` the matching
            confidences and ``class_names`` the matching labels. With no
            detections the result is a ``(0, 4)`` array, an empty array and an
            empty list.

        Raises:
            RuntimeError: If the model is not loaded and cannot be loaded.
            Exception: Any inference error is logged and re-raised.
        """
        if self.model is None and not self.load_model():
            raise RuntimeError("Failed to load YOLO model")

        try:
            # Run YOLO inference - ultralytics handles all preprocessing/postprocessing
            results = self.model.predict(
                source=image,
                conf=conf_threshold,
                verbose=False,
                save=False,
                show=False
            )

            # Extract results from the first (and only) image
            result = results[0]

            if result.boxes is None or len(result.boxes) == 0:
                # No detections
                return np.array([]).reshape(0, 4), np.array([]), []

            # Extract bounding boxes, confidence scores, and class IDs
            boxes = result.boxes.xyxy.cpu().numpy()  # [x1, y1, x2, y2] format
            scores = result.boxes.conf.cpu().numpy()  # confidence scores
            class_ids = result.boxes.cls.cpu().numpy().astype(int)  # class IDs

            # Convert class IDs to class names. The lower bound matters: a
            # negative id would otherwise silently index from the end.
            known = len(self.class_names)
            class_names = [
                self.class_names[class_id] if 0 <= class_id < known else f"Unknown_{class_id}"
                for class_id in class_ids
            ]

            # `boxes` is known to be non-empty here (empty case returned above).
            logger.info(f"Found {len(boxes)} detections")
            logger.info(f"Score range: {scores.min():.3f} - {scores.max():.3f}")
            logger.info(f"Classes detected: {set(class_names)}")

            return boxes, scores, class_names
        except Exception as e:
            logger.error(f"Detection failed: {e}")
            raise
def draw_detections(image: np.ndarray, boxes: np.ndarray, scores: np.ndarray,
                    classes: List[str]) -> np.ndarray:
    """Draw bounding boxes and labels on image.

    Args:
        image: Image to annotate; it is never modified.
        boxes: One ``[x1, y1, x2, y2]`` row per detection.
        scores: Confidence per detection, parallel to ``boxes``.
        classes: Class label per detection, parallel to ``boxes``.

    Returns:
        ``image`` itself when there are no boxes, otherwise an annotated copy.
        Colors are drawn at random per call, one per distinct class label.
    """
    if len(boxes) == 0:
        return image

    # Create a copy of the image
    annotated = image.copy()

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 2
    label_padding = 10  # extra height of the label background
    text_offset = 5     # gap between the text baseline and the background bottom

    # Generate colors for different classes. Components are converted to plain
    # Python ints because cv2 does not reliably accept NumPy scalars; the
    # reversed channel order is kept from the original (colors are random, so
    # the order is cosmetic).
    unique_classes = list(set(classes))
    colors = np.random.randint(0, 255, size=(len(unique_classes), 3), dtype=np.uint8)
    class_colors = {
        cls: (int(color[2]), int(color[1]), int(color[0]))
        for cls, color in zip(unique_classes, colors)
    }

    for box, score, cls in zip(boxes, scores, classes):
        # Plain ints (truncated toward zero, like astype(int)) for cv2.
        x1, y1, x2, y2 = (int(value) for value in box)
        color = class_colors[cls]

        # Draw bounding box
        cv2.rectangle(annotated, (x1, y1), (x2, y2), color, thickness)

        # Draw label
        label = f"{cls}: {score:.2f}"
        text_w, text_h = cv2.getTextSize(label, font, font_scale, thickness)[0]
        label_h = text_h + label_padding

        # Put the label above the box when it fits; for boxes touching the top
        # edge place it just inside the box so it is not clipped off-image.
        label_top = y1 - label_h if y1 - label_h >= 0 else y1
        label_bottom = label_top + label_h

        # Draw label background
        cv2.rectangle(annotated, (x1, label_top), (x1 + text_w, label_bottom), color, -1)

        # Draw label text
        cv2.putText(annotated, label, (x1, label_bottom - text_offset),
                    font, font_scale, (255, 255, 255), thickness)

    return annotated
# Initialize detector
# NOTE: this runs at import time. The constructor attempts Hugging Face
# authentication, resolves the class names and downloads the ONNX model, so
# importing this module performs network I/O and raises RuntimeError if the
# model download fails.
detector = TrueSatDetector()
def detect_objects(image, conf_threshold):
    """Main detection function for Gradio interface.

    Args:
        image: Uploaded image as a NumPy array, or ``None`` when the user
            submitted without uploading anything.
        conf_threshold: Minimum confidence score for detections.

    Returns:
        The annotated image; ``None`` when no image was provided; or, when
        detection fails, a copy of the input with the error message drawn on it.
    """
    # Nothing to process. Without this guard the error handler below would
    # itself crash on `image.copy()`.
    if image is None:
        return None

    try:
        # Run detection using ultralytics
        boxes, scores, classes = detector.detect(image, conf_threshold)

        # Draw results
        annotated_image = draw_detections(image, boxes, scores, classes)

        # Log results
        logger.info(f"Found {len(boxes)} detections")
        if len(boxes) > 0:
            logger.info(f"Classes detected: {set(classes)}")

        return annotated_image
    except Exception as e:
        logger.error(f"Detection failed: {e}")
        # Return original image with error message. The interface supplies
        # RGB arrays, so red is (255, 0, 0); the BGR triple (0, 0, 255) would
        # render blue.
        error_image = image.copy()
        cv2.putText(error_image, f"Error: {str(e)}", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
        return error_image
# Create Gradio interface
# Wires `detect_objects` to a web UI: the two inputs below are passed to it
# positionally (image, then confidence threshold) and the image it returns is
# shown in the single output component.
demo = gr.Interface(
    fn=detect_objects,
    inputs=[
        # type="numpy" makes the upload arrive in `detect_objects` as an array.
        gr.Image(type="numpy", label="Upload Image"),
        # The default of 0.25 equals the default `conf_threshold` of
        # `TrueSatDetector.detect`.
        gr.Slider(minimum=0.1, maximum=1.0, value=0.25, step=0.05,
                  label="Confidence Threshold",
                  info="Minimum confidence score for detections")
    ],
    outputs=gr.Image(label="Detection Results"),
    title="🛰️ TrueSat Satellite Object Detection",
    # Shown with the title. The string body is kept at column 0 so its
    # Markdown content is unchanged.
    description="""
Upload a satellite image to detect various objects including:
- Vessels (ships, boats, barges)
- Aircraft (planes, helicopters)
- Vehicles (trucks, cars)
- Infrastructure (buildings, bridges, airports)
- And 60+ other object classes
**Note:** Uses ultralytics YOLO for accurate detection results.
""",
    # Usage notes and technical details, also Markdown.
    article="""
### How to use:
1. Upload a satellite or aerial image
2. Adjust the confidence threshold to filter detections
3. Click Submit to run detection
### Technical Details:
- Model: YOLO11x trained on satellite imagery
- Classes: 69 object categories optimized for satellite/aerial imagery
- Backend: Ultralytics YOLO with ONNX inference
- Features: Automatic NMS, proper preprocessing, accurate confidence scores
""",
    theme=gr.themes.Soft(),
    examples=None # Add examples if you have sample images
)
if __name__ == "__main__":
    # Script entry point: try to load the model up front, report the outcome,
    # then serve the UI on all interfaces at port 7860.
    logger.info("Starting TrueSat Detection App...")
    logger.info("Loading YOLO model...")

    model_ready = detector.load_model()
    if model_ready:
        startup_messages = (
            "✅ Successfully loaded YOLO model",
            f"✅ Loaded {len(detector.class_names)} object classes",
        )
        for message in startup_messages:
            logger.info(message)
    else:
        troubleshooting = (
            "❌ Failed to load YOLO model",
            "Please make sure:",
            "1. HF_TOKEN environment variable is set for private repo access",
            "2. The 'truthdotphd/truesat-detection' repository is accessible",
            "3. Ultralytics and huggingface_hub are properly installed",
            "4. You have sufficient memory/GPU resources",
        )
        for message in troubleshooting:
            logger.error(message)

    # The app is launched in both cases; after a failed load the errors above
    # serve as the warning.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)