#!/usr/bin/env python3
"""
Gradio app for TrueSat Detection using ultralytics YOLO
"""

import gradio as gr
import numpy as np
import cv2
import yaml
import logging
import os
import zlib
from typing import List, Tuple
from pathlib import Path
from ultralytics import YOLO
from huggingface_hub import hf_hub_download, login

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TrueSatDetector:
    """Satellite-imagery object detector backed by an ONNX YOLO model on the HF Hub.

    Construction authenticates (when ``HF_TOKEN`` is set), resolves the class
    names and downloads the ONNX weights. The YOLO model itself is loaded
    lazily by :meth:`load_model`, or on the first :meth:`detect` call.
    """

    def __init__(self, repo_id: str = "truthdotphd/truesat-detection"):
        """Initialize the TrueSat detector with ultralytics YOLO from Hugging Face Hub.

        Args:
            repo_id: Hub repository holding ``1/model.onnx`` and, optionally,
                ``class_names.yaml``.

        Raises:
            RuntimeError: if the ONNX model cannot be downloaded.
        """
        self.repo_id = repo_id
        self.model = None
        self.class_names = []
        self.onnx_model_path = None

        # Setup HF authentication
        self._setup_hf_auth()

        # Load class information and download model
        self._load_class_info()
        self._download_model()

    def _setup_hf_auth(self):
        """Setup Hugging Face authentication using HF_TOKEN environment variable.

        Failures are logged as warnings only: public repos remain reachable
        without a token.
        """
        hf_token = os.getenv('HF_TOKEN')
        if hf_token:
            try:
                login(token=hf_token)
                logger.info("Successfully authenticated with Hugging Face Hub")
            except Exception as e:
                logger.warning(f"Failed to authenticate with HF Hub: {e}")
        else:
            logger.warning("HF_TOKEN not found in environment variables. Access to private repos may fail.")

    def _download_model(self):
        """Download the ONNX model (``1/model.onnx``) into ``./hf_cache``.

        Sets ``self.onnx_model_path`` to the local file.

        Raises:
            RuntimeError: if the download fails (original error chained as cause).
        """
        try:
            logger.info(f"Downloading model from {self.repo_id}...")
            self.onnx_model_path = hf_hub_download(
                repo_id=self.repo_id,
                filename="model.onnx",
                subfolder="1",
                cache_dir="./hf_cache",
            )
            logger.info(f"Model downloaded to: {self.onnx_model_path}")
        except Exception as e:
            logger.error(f"Failed to download model from HF Hub: {e}")
            raise RuntimeError(f"Could not download model from {self.repo_id}: {e}") from e

    def _load_class_info(self):
        """Load class names from ``class_names.yaml`` in the repo, else use the fallback list.

        The YAML's ``names`` entry may be either a list or an ``{index: name}``
        mapping (the ultralytics dataset format). A missing or empty ``names``
        entry falls back to the hardcoded list rather than leaving zero classes.
        """
        try:
            logger.info("Attempting to download class information from HF repository...")
            class_file_path = hf_hub_download(
                repo_id=self.repo_id,
                filename="class_names.yaml",
                cache_dir="./hf_cache",
            )

            with open(class_file_path, 'r', encoding='utf-8') as f:
                # An empty file parses to None; treat it like a file without 'names'.
                class_config = yaml.safe_load(f) or {}

            names = class_config.get('names', [])
            if isinstance(names, dict):
                # {0: 'a', 1: 'b', ...} -> ['a', 'b', ...], ordered by numeric index.
                names = [names[key] for key in sorted(names, key=int)]
            if not names:
                raise ValueError("'names' is missing or empty in class_names.yaml")

            self.class_names = list(names)
            logger.info(f"Loaded {len(self.class_names)} classes from HF repository")
            logger.info(f"Sample classes: {self.class_names[:5]}...")
        except Exception as e:
            logger.warning(f"Could not load class names from HF repository: {e}")
            logger.info("Falling back to hardcoded class names...")
            self._load_fallback_classes()

    def _load_fallback_classes(self):
        """Load fallback class names if configuration files are not available."""
        self.class_names = [
            'Aircraft Hangar', 'Airplane', 'Airport', 'Barge', 'Baseball Diamond',
            'Basketball Court', 'Bridge', 'Building', 'Bus', 'Cargo Truck',
            'Cargo/Container Railcar', 'Cargo/Passenger Plane', 'Cement Mixer',
            'Construction Site', 'Container Crane', 'Container Ship', 'Crane Truck',
            'Damaged Building', 'Dump Truck', 'Engineering Vehicle', 'Excavator',
            'Facility', 'Ferry', 'Fishing Vessel', 'Flat Railcar',
            'Front Loader/Bulldozer', 'Ground Grader', 'Ground Track Field', 'Harbor',
            'Haul Truck', 'Helicopter', 'Helipad', 'Hut/Tent', 'Large Vehicle',
            'Locomotive', 'Mobile Crane', 'Motorboat', 'Oil Tanker',
            'Passenger Railcar', 'Pylon', 'Railway Vehicle', 'Reach Stacker',
            'Roundabout', 'Sailboat', 'Scraper/Tractor', 'Shed', 'Ship',
            'Shipping Container', 'Shipping Container Lot', 'Small Vehicle',
            'Soccer Field', 'Storage Tank', 'Straddle Carrier', 'Swimming Pool',
            'Tank Railcar', 'Tennis Court', 'Tower', 'Tower Crane', 'Trailer',
            'Truck', 'Truck Tractor', 'Truck Tractor with Box Trailer',
            'Truck Tractor with Flatbed Trailer', 'Truck Tractor with Liquid Tank',
            'Tugboat', 'Utility Truck', 'Vehicle', 'Vehicle Lot', 'Yacht',
        ]
        logger.info(f"Using fallback class names: {len(self.class_names)} classes")

    def load_model(self):
        """Load the YOLO ONNX model using ultralytics.

        Returns:
            True on success, False on any failure (the error is logged).
        """
        try:
            if not self.onnx_model_path or not Path(self.onnx_model_path).exists():
                raise FileNotFoundError(f"ONNX model not found: {self.onnx_model_path}")

            # Load YOLO model from ONNX file
            self.model = YOLO(self.onnx_model_path)
            logger.info(f"Successfully loaded YOLO model from: {self.onnx_model_path}")

            # Override the model's class names with our custom ones.
            # NOTE(review): for ONNX weights ultralytics may keep `model.model` as a
            # plain path, making this a no-op; `detect` maps ids via
            # self.class_names anyway.
            if hasattr(self.model.model, 'names'):
                self.model.model.names = {i: name for i, name in enumerate(self.class_names)}

            return True
        except Exception as e:
            logger.error(f"Failed to load YOLO model: {e}")
            return False

    def detect(self, image: np.ndarray, conf_threshold: float = 0.25) -> Tuple[np.ndarray, np.ndarray, List[str]]:
        """Run detection on an image using ultralytics YOLO.

        Args:
            image: HxWx3 uint8 array in BGR channel order, which is what
                ultralytics expects for numpy sources.
            conf_threshold: minimum confidence for a detection to be kept.

        Returns:
            ``(boxes, scores, class_names)``:
            - ``boxes`` is an (N, 4) array of ``[x1, y1, x2, y2]`` pixel coordinates.
            - ``scores`` is an (N,) array of confidences.
            - ``class_names`` is a list of N names (``Unknown_<id>`` for ids
              outside ``self.class_names``).

        Raises:
            RuntimeError: if the model cannot be loaded.
        """
        if self.model is None:
            if not self.load_model():
                raise RuntimeError("Failed to load YOLO model")

        try:
            # Run YOLO inference - ultralytics handles all preprocessing/postprocessing
            results = self.model.predict(
                source=image,
                conf=conf_threshold,
                verbose=False,
                save=False,
                show=False,
            )

            # Extract results from the first (and only) image
            result = results[0]

            if result.boxes is None or len(result.boxes) == 0:
                # No detections
                return np.array([]).reshape(0, 4), np.array([]), []

            # Extract bounding boxes, confidence scores, and class IDs
            boxes = result.boxes.xyxy.cpu().numpy()  # [x1, y1, x2, y2] format
            scores = result.boxes.conf.cpu().numpy()  # confidence scores
            class_ids = result.boxes.cls.cpu().numpy().astype(int)  # class IDs

            # Convert class IDs to class names
            class_names = [
                self.class_names[class_id] if class_id < len(self.class_names) else f"Unknown_{class_id}"
                for class_id in class_ids
            ]

            logger.info(f"Found {len(boxes)} detections")
            if len(boxes) > 0:
                logger.info(f"Score range: {scores.min():.3f} - {scores.max():.3f}")
                logger.info(f"Classes detected: {set(class_names)}")

            return boxes, scores, class_names

        except Exception as e:
            logger.error(f"Detection failed: {e}")
            raise


def _class_color(class_name: str) -> Tuple[int, int, int]:
    """Return a stable color for ``class_name`` (same class -> same color across calls).

    Channels are scaled to 0..191 so the white label text stays readable.
    """
    digest = zlib.crc32(class_name.encode("utf-8"))
    return tuple(((digest >> shift) & 0xFF) * 3 // 4 for shift in (0, 8, 16))


def draw_detections(image: np.ndarray, boxes: np.ndarray, scores: np.ndarray,
                    classes: List[str]) -> np.ndarray:
    """Draw bounding boxes and labels on image.

    Args:
        image: image to annotate. It is not modified; a copy is drawn on.
        boxes: (N, 4) ``[x1, y1, x2, y2]`` pixel coordinates.
        scores: (N,) confidences shown in the labels.
        classes: N class names.

    Returns:
        The annotated copy, or ``image`` itself when there are no boxes.
    """
    if len(boxes) == 0:
        return image

    annotated = image.copy()
    font = cv2.FONT_HERSHEY_SIMPLEX

    for box, score, cls in zip(boxes, scores, classes):
        x1, y1, x2, y2 = box.astype(int)
        # Colors are used as-is in the image's own channel order; no swap is needed.
        color = _class_color(cls)

        # Draw bounding box
        cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)

        # Draw label above the box, or just inside it when the box touches the
        # top edge (otherwise the label would be drawn off-canvas).
        label = f"{cls}: {score:.2f}"
        (text_w, text_h), _ = cv2.getTextSize(label, font, 0.5, 2)
        label_top = y1 - text_h - 10
        if label_top < 0:
            label_top = y1
        text_baseline_y = label_top + text_h + 5

        # Draw label background, then label text
        cv2.rectangle(annotated, (x1, label_top), (x1 + text_w, label_top + text_h + 10), color, -1)
        cv2.putText(annotated, label, (x1, text_baseline_y), font, 0.5, (255, 255, 255), 2)

    return annotated


# Initialize detector
detector = TrueSatDetector()


def _rgb_to_bgr(image: np.ndarray) -> np.ndarray:
    """Convert a Gradio RGB/RGBA/grayscale array into the 3-channel BGR ultralytics expects."""
    if image.ndim == 2:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)


def detect_objects(image, conf_threshold):
    """Main detection function for Gradio interface.

    Args:
        image: RGB numpy array from ``gr.Image(type="numpy")``, or None when
            nothing was uploaded.
        conf_threshold: confidence threshold from the slider.

    Returns:
        The annotated RGB image. On failure, the original image with an error
        message drawn on it. None when no image was supplied.
    """
    if image is None:
        # Submit pressed without an upload: nothing to detect or annotate.
        return None

    try:
        # The model expects BGR input, but drawing happens on the RGB original
        # because Gradio displays RGB.
        boxes, scores, classes = detector.detect(_rgb_to_bgr(image), conf_threshold)
        return draw_detections(image, boxes, scores, classes)
    except Exception as e:
        logger.error(f"Detection failed: {e}")
        # Return original image with error message (red in RGB)
        error_image = image.copy()
        cv2.putText(error_image, f"Error: {str(e)}", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
        return error_image


# Create Gradio interface
demo = gr.Interface(
    fn=detect_objects,
    inputs=[
        gr.Image(type="numpy", label="Upload Image"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.25,
            step=0.05,
            label="Confidence Threshold",
            info="Minimum confidence score for detections",
        ),
    ],
    outputs=gr.Image(label="Detection Results"),
    title="🛰️ TrueSat Satellite Object Detection",
    description="""
Upload a satellite image to detect various objects including:
- Vessels (ships, boats, barges)
- Aircraft (planes, helicopters)
- Vehicles (trucks, cars)
- Infrastructure (buildings, bridges, airports)
- And 60+ other object classes

**Note:** Uses ultralytics YOLO for accurate detection results.
""",
    article="""
### How to use:
1. Upload a satellite or aerial image
2. Adjust the confidence threshold to filter detections
3. Click Submit to run detection

### Technical Details:
- Model: YOLO11x trained on satellite imagery
- Classes: 69 object categories optimized for satellite/aerial imagery
- Backend: Ultralytics YOLO with ONNX inference
- Features: Automatic NMS, proper preprocessing, accurate confidence scores
""",
    theme=gr.themes.Soft(),
    examples=None,  # Add examples if you have sample images
)


if __name__ == "__main__":
    logger.info("Starting TrueSat Detection App...")
    logger.info("Loading YOLO model...")

    if detector.load_model():
        logger.info("✅ Successfully loaded YOLO model")
        logger.info(f"✅ Loaded {len(detector.class_names)} object classes")
    else:
        logger.error("❌ Failed to load YOLO model")
        logger.error("Please make sure:")
        logger.error("1. HF_TOKEN environment variable is set for private repo access")
        logger.error("2. The 'truthdotphd/truesat-detection' repository is accessible")
        logger.error("3. Ultralytics and huggingface_hub are properly installed")
        logger.error("4. You have sufficient memory/GPU resources")
        # Launch anyway: detect() retries load_model() on each request.

    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)