Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """ | |
| Gradio app for TrueSat Detection using ultralytics YOLO | |
| """ | |
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| import yaml | |
| import logging | |
| import os | |
| from typing import List, Tuple | |
| from pathlib import Path | |
| from ultralytics import YOLO | |
| from huggingface_hub import hf_hub_download, login | |
# Module-level logger; the root logger is configured to emit INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level="INFO")
class TrueSatDetector:
    """Object detector for satellite imagery backed by an ONNX YOLO model on the HF Hub.

    Construction authenticates with the Hugging Face Hub (when ``HF_TOKEN`` is
    set), resolves the class-name list and downloads ``1/model.onnx`` from
    ``repo_id``. The model itself is loaded lazily by :meth:`load_model`,
    which :meth:`detect` calls automatically when needed.
    """

    def __init__(self, repo_id: str = "truthdotphd/truesat-detection"):
        """Initialize the TrueSat detector with ultralytics YOLO from Hugging Face Hub.

        Args:
            repo_id: Hub repository holding ``class_names.yaml`` and ``1/model.onnx``.

        Raises:
            RuntimeError: If the ONNX model cannot be downloaded.
        """
        self.repo_id = repo_id
        self.model = None  # ultralytics YOLO instance, set by load_model()
        self.class_names: List[str] = []  # index == class id
        self.onnx_model_path = None  # local path of the downloaded model.onnx
        # Setup HF authentication
        self._setup_hf_auth()
        # Load class information and download model
        self._load_class_info()
        self._download_model()

    def _setup_hf_auth(self):
        """Log in to the Hugging Face Hub using the HF_TOKEN environment variable, if set.

        Failures are logged as warnings rather than raised, so public repos still work.
        """
        hf_token = os.getenv('HF_TOKEN')
        if hf_token:
            try:
                login(token=hf_token)
                logger.info("Successfully authenticated with Hugging Face Hub")
            except Exception as e:
                logger.warning(f"Failed to authenticate with HF Hub: {e}")
        else:
            logger.warning("HF_TOKEN not found in environment variables. Access to private repos may fail.")

    def _download_model(self):
        """Download ``1/model.onnx`` from the Hub into ``./hf_cache``.

        Raises:
            RuntimeError: If the download fails; the original error is chained.
        """
        try:
            logger.info(f"Downloading model from {self.repo_id}...")
            self.onnx_model_path = hf_hub_download(
                repo_id=self.repo_id,
                filename="model.onnx",
                subfolder="1",
                cache_dir="./hf_cache",
            )
            logger.info(f"Model downloaded to: {self.onnx_model_path}")
        except Exception as e:
            logger.error(f"Failed to download model from HF Hub: {e}")
            raise RuntimeError(f"Could not download model from {self.repo_id}: {e}") from e

    @staticmethod
    def _normalize_class_names(names) -> List[str]:
        """Convert a ``names`` entry into a list whose positions are class ids.

        Accepts a list/tuple of names, or an ``{id: name}`` mapping whose ids
        are ints or numeric strings. Gaps in a mapping are filled with
        ``"Unknown_<id>"`` so that positions still equal class ids. Any other
        value (``None``, a bare string, ...) yields ``[]``.
        """
        if isinstance(names, dict):
            if not names:
                return []
            by_id = {int(k): str(v) for k, v in names.items()}
            return [by_id.get(i, f"Unknown_{i}") for i in range(max(by_id) + 1)]
        if isinstance(names, (list, tuple)):
            return [str(n) for n in names]
        return []

    def _load_class_info(self):
        """Load class names from ``class_names.yaml`` in the repo, else use the fallback list.

        A missing file, unreadable YAML or an empty/absent ``names`` entry all
        fall back to the hardcoded list. Otherwise every detection would be
        labelled ``Unknown_<id>``.
        """
        try:
            logger.info("Attempting to download class information from HF repository...")
            class_file_path = hf_hub_download(
                repo_id=self.repo_id,
                filename="class_names.yaml",
                cache_dir="./hf_cache",
            )
            with open(class_file_path, 'r', encoding='utf-8') as f:
                # safe_load: never construct arbitrary objects from a downloaded file
                class_config = yaml.safe_load(f) or {}
            names = self._normalize_class_names(class_config.get('names'))
            if not names:
                raise ValueError("'names' is missing or empty in class_names.yaml")
        except Exception as e:
            logger.warning(f"Could not load class names from HF repository: {e}")
            logger.info("Falling back to hardcoded class names...")
            self._load_fallback_classes()
            return
        self.class_names = names
        logger.info(f"Loaded {len(self.class_names)} classes from HF repository")
        logger.info(f"Sample classes: {self.class_names[:5]}...")

    def _load_fallback_classes(self):
        """Load fallback class names if configuration files are not available."""
        self.class_names = [
            'Aircraft Hangar', 'Airplane', 'Airport', 'Barge', 'Baseball Diamond',
            'Basketball Court', 'Bridge', 'Building', 'Bus', 'Cargo Truck',
            'Cargo/Container Railcar', 'Cargo/Passenger Plane', 'Cement Mixer',
            'Construction Site', 'Container Crane', 'Container Ship', 'Crane Truck',
            'Damaged Building', 'Dump Truck', 'Engineering Vehicle', 'Excavator',
            'Facility', 'Ferry', 'Fishing Vessel', 'Flat Railcar', 'Front Loader/Bulldozer',
            'Ground Grader', 'Ground Track Field', 'Harbor', 'Haul Truck', 'Helicopter',
            'Helipad', 'Hut/Tent', 'Large Vehicle', 'Locomotive', 'Mobile Crane',
            'Motorboat', 'Oil Tanker', 'Passenger Railcar', 'Pylon', 'Railway Vehicle',
            'Reach Stacker', 'Roundabout', 'Sailboat', 'Scraper/Tractor', 'Shed', 'Ship',
            'Shipping Container', 'Shipping Container Lot', 'Small Vehicle', 'Soccer Field',
            'Storage Tank', 'Straddle Carrier', 'Swimming Pool', 'Tank Railcar',
            'Tennis Court', 'Tower', 'Tower Crane', 'Trailer', 'Truck', 'Truck Tractor',
            'Truck Tractor with Box Trailer', 'Truck Tractor with Flatbed Trailer',
            'Truck Tractor with Liquid Tank', 'Tugboat', 'Utility Truck', 'Vehicle',
            'Vehicle Lot', 'Yacht'
        ]
        logger.info(f"Using fallback class names: {len(self.class_names)} classes")

    def load_model(self) -> bool:
        """Load the downloaded ONNX model with ultralytics ``YOLO``.

        Returns:
            True on success; False if the file is missing or loading fails
            (the error is logged, not raised).
        """
        try:
            if not self.onnx_model_path or not Path(self.onnx_model_path).exists():
                raise FileNotFoundError(f"ONNX model not found: {self.onnx_model_path}")
            self.model = YOLO(self.onnx_model_path)
            logger.info(f"Successfully loaded YOLO model from: {self.onnx_model_path}")
            # Best-effort: attach our names to the wrapped model when it exposes a
            # `names` attribute. detect() maps ids through self.class_names either
            # way, so the returned labels do not depend on this.
            if hasattr(self.model.model, 'names'):
                self.model.model.names = dict(enumerate(self.class_names))
            return True
        except Exception as e:
            logger.error(f"Failed to load YOLO model: {e}")
            return False

    @staticmethod
    def _to_bgr(image: np.ndarray) -> np.ndarray:
        """Return an RGB(A) HxWxC array reordered to BGR; other layouts pass through unchanged.

        The Gradio input is RGB, while ultralytics documents numpy sources as
        BGR (OpenCV order). Without this the model sees red and blue swapped.
        An alpha channel, if present, is dropped.
        """
        if image.ndim == 3 and image.shape[2] in (3, 4):
            # Channel indices 2, 1, 0 -> B, G, R (skips alpha for RGBA).
            return np.ascontiguousarray(image[..., 2::-1])
        return image

    def detect(self, image: np.ndarray, conf_threshold: float = 0.25) -> Tuple[np.ndarray, np.ndarray, List[str]]:
        """Run detection on an RGB image using ultralytics YOLO.

        Args:
            image: RGB image array (as produced by ``gr.Image(type="numpy")``).
            conf_threshold: Minimum confidence for a detection to be kept.

        Returns:
            ``(boxes, scores, labels)``: boxes as an (N, 4) array in
            ``[x1, y1, x2, y2]`` format, scores as an (N,) array, and one class
            name per box. Class ids outside the known list become ``Unknown_<id>``.

        Raises:
            ValueError: If ``image`` is None.
            RuntimeError: If the model cannot be loaded.
        """
        if image is None:
            # ultralytics would otherwise fall back to its own default source.
            raise ValueError("No image provided for detection")
        if self.model is None and not self.load_model():
            raise RuntimeError("Failed to load YOLO model")
        try:
            # ultralytics handles resizing, normalisation and NMS itself.
            results = self.model.predict(
                source=self._to_bgr(image),
                conf=conf_threshold,
                verbose=False,
                save=False,
                show=False,
            )
            result = results[0]  # a single image in, so a single result out
            if result.boxes is None or len(result.boxes) == 0:
                return np.array([]).reshape(0, 4), np.array([]), []
            boxes = result.boxes.xyxy.cpu().numpy()  # [x1, y1, x2, y2] format
            scores = result.boxes.conf.cpu().numpy()
            class_ids = result.boxes.cls.cpu().numpy().astype(int)
            n_known = len(self.class_names)
            labels = [self.class_names[cid] if cid < n_known else f"Unknown_{cid}"
                      for cid in class_ids]
            logger.info(f"Found {len(boxes)} detections")
            logger.info(f"Score range: {scores.min():.3f} - {scores.max():.3f}")
            logger.info(f"Classes detected: {set(labels)}")
            return boxes, scores, labels
        except Exception as e:
            logger.error(f"Detection failed: {e}")
            raise
def _class_color(cls: str) -> Tuple[int, int, int]:
    """Return a colour tuple for class ``cls`` that is stable across calls and processes.

    The RNG is seeded from a CRC32 of the name. Python's built-in ``hash`` of
    str is randomised per process, so it is not used. A class therefore keeps
    its colour when the user re-runs detection, e.g. after moving the slider.
    """
    import zlib

    rng = np.random.default_rng(zlib.crc32(cls.encode("utf-8")))
    return tuple(int(c) for c in rng.integers(0, 256, size=3))


def draw_detections(image: np.ndarray, boxes: np.ndarray, scores: np.ndarray,
                    classes: List[str]) -> np.ndarray:
    """Draw bounding boxes and ``"<class>: <score>"`` labels on a copy of ``image``.

    Args:
        image: Image array to annotate; it is not modified.
        boxes: (N, 4) array of ``[x1, y1, x2, y2]`` boxes.
        scores: (N,) confidence scores, one per box.
        classes: N class names, one per box.

    Returns:
        An annotated copy, or ``image`` itself when there are no boxes.
    """
    if len(boxes) == 0:
        return image
    annotated = image.copy()
    font, scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2
    class_colors = {cls: _class_color(cls) for cls in set(classes)}
    for box, score, cls in zip(boxes, scores, classes):
        # Plain Python ints: cv2 point arguments do not reliably accept numpy scalars.
        x1, y1, x2, y2 = map(int, box)
        color = class_colors[cls]
        cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)

        label = f"{cls}: {score:.2f}"
        (text_w, text_h), _baseline = cv2.getTextSize(label, font, scale, thickness)
        if y1 - text_h - 10 >= 0:
            # Room above the box: the label sits on top of it.
            bg_top, bg_bottom, text_y = y1 - text_h - 10, y1, y1 - 5
        else:
            # Box touches the top edge: put the label just inside the box instead
            # of drawing it off-image.
            bg_top, bg_bottom, text_y = y1, y1 + text_h + 10, y1 + text_h + 5
        cv2.rectangle(annotated, (x1, bg_top), (x1 + text_w, bg_bottom), color, -1)
        cv2.putText(annotated, label, (x1, text_y), font, scale, (255, 255, 255), thickness)
    return annotated
# Shared detector instance, built once at import time. Construction does network
# I/O (HF Hub login when HF_TOKEN is set, plus the class-name and ONNX downloads)
# and raises RuntimeError if the model cannot be downloaded.
detector = TrueSatDetector(repo_id="truthdotphd/truesat-detection")
def detect_objects(image, conf_threshold):
    """Gradio callback: run detection on ``image`` and return an annotated copy.

    Args:
        image: RGB array from ``gr.Image(type="numpy")``, or None when nothing
            was uploaded.
        conf_threshold: Minimum confidence score taken from the slider.

    Returns:
        The annotated image. If detection fails, a copy of the input with the
        error message drawn on it. If no image was given, None, which leaves
        the output empty.
    """
    if image is None:
        # Guard before anything else: the error path below calls image.copy().
        return None
    try:
        boxes, scores, classes = detector.detect(image, conf_threshold)
        # detector.detect() already logs the detection count and classes.
        return draw_detections(image, boxes, scores, classes)
    except Exception as e:
        logger.exception(f"Detection failed: {e}")
        error_image = image.copy()
        # The array is RGB (not OpenCV's BGR), so (255, 0, 0) renders red.
        cv2.putText(error_image, f"Error: {str(e)}", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
        return error_image
# Create Gradio interface.
# Wires detect_objects(image, conf_threshold) to an image upload plus a
# confidence slider, and displays the image it returns.
demo = gr.Interface(
    fn=detect_objects,
    inputs=[
        # Delivered to detect_objects as a numpy array (type="numpy").
        gr.Image(type="numpy", label="Upload Image"),
        # Passed through as detect_objects' conf_threshold argument.
        gr.Slider(minimum=0.1, maximum=1.0, value=0.25, step=0.05,
                  label="Confidence Threshold",
                  info="Minimum confidence score for detections")
    ],
    outputs=gr.Image(label="Detection Results"),
    title="🛰️ TrueSat Satellite Object Detection",
    description="""
Upload a satellite image to detect various objects including:
- Vessels (ships, boats, barges)
- Aircraft (planes, helicopters)
- Vehicles (trucks, cars)
- Infrastructure (buildings, bridges, airports)
- And 60+ other object classes
**Note:** Uses ultralytics YOLO for accurate detection results.
""",
    article="""
### How to use:
1. Upload a satellite or aerial image
2. Adjust the confidence threshold to filter detections
3. Click Submit to run detection
### Technical Details:
- Model: YOLO11x trained on satellite imagery
- Classes: 69 object categories optimized for satellite/aerial imagery
- Backend: Ultralytics YOLO with ONNX inference
- Features: Automatic NMS, proper preprocessing, accurate confidence scores
""",
    theme=gr.themes.Soft(),
    examples=None  # Add examples if you have sample images
)
if __name__ == "__main__":
    # Load the model eagerly so startup logs show whether it works; the UI is
    # launched either way.
    logger.info("Starting TrueSat Detection App...")
    logger.info("Loading YOLO model...")
    if detector.load_model():
        logger.info("✅ Successfully loaded YOLO model")
        logger.info(f"✅ Loaded {len(detector.class_names)} object classes")
    else:
        logger.error("❌ Failed to load YOLO model")
        logger.error("Please make sure:")
        logger.error("1. HF_TOKEN environment variable is set for private repo access")
        logger.error(f"2. The '{detector.repo_id}' repository is accessible")
        logger.error("3. Ultralytics and huggingface_hub are properly installed")
        logger.error("4. You have sufficient memory/GPU resources")
        # Launch anyway: detect() retries load_model() while no model is set,
        # and detect_objects() draws any error onto the returned image.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)