File size: 12,528 Bytes
7914bb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
#!/usr/bin/env python3
"""
Gradio app for TrueSat Detection using ultralytics YOLO
"""

import gradio as gr
import numpy as np
import cv2
import yaml
import logging
import os
from typing import List, Tuple
from pathlib import Path
from ultralytics import YOLO
from huggingface_hub import hf_hub_download, login

# Configure root logging once at import time (INFO and above) and obtain
# this module's named logger for everything below.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

class TrueSatDetector:
    """Satellite-imagery object detector wrapping an ultralytics YOLO ONNX model.

    Construction authenticates with the Hugging Face Hub when ``HF_TOKEN`` is
    set, resolves the class-name list and downloads ``1/model.onnx`` from
    ``repo_id`` into ``./hf_cache``. The YOLO model itself is loaded lazily by
    :meth:`load_model`, which :meth:`detect` calls on first use.
    """

    def __init__(self, repo_id: str = "truthdotphd/truesat-detection"):
        """Initialize the TrueSat detector with ultralytics YOLO from Hugging Face Hub.

        Args:
            repo_id: Hub repository holding ``class_names.yaml`` (optional)
                and ``1/model.onnx`` (required).

        Raises:
            RuntimeError: if the ONNX model cannot be downloaded.
        """
        self.repo_id = repo_id
        self.model = None  # ultralytics YOLO instance; set by load_model()
        self.class_names: List[str] = []  # list index == model class ID
        self.onnx_model_path = None  # local path of the downloaded model.onnx

        # Authenticate first so the private-repo downloads below can succeed.
        self._setup_hf_auth()

        # Class names fall back to a built-in list; the model download is mandatory.
        self._load_class_info()
        self._download_model()

    def _setup_hf_auth(self):
        """Log in to the Hub with ``HF_TOKEN`` if it is set; failures are only logged."""
        hf_token = os.getenv('HF_TOKEN')
        if not hf_token:
            logger.warning("HF_TOKEN not found in environment variables. Access to private repos may fail.")
            return
        try:
            login(token=hf_token)
        except Exception as e:
            logger.warning(f"Failed to authenticate with HF Hub: {e}")
            return
        logger.info("Successfully authenticated with Hugging Face Hub")

    def _download_model(self):
        """Download ``1/model.onnx`` from the Hub and remember its local path.

        Raises:
            RuntimeError: if the download fails (the original error is chained).
        """
        try:
            logger.info(f"Downloading model from {self.repo_id}...")
            self.onnx_model_path = hf_hub_download(
                repo_id=self.repo_id,
                filename="model.onnx",
                subfolder="1",
                cache_dir="./hf_cache"
            )
        except Exception as e:
            logger.error(f"Failed to download model from HF Hub: {e}")
            raise RuntimeError(f"Could not download model from {self.repo_id}: {e}") from e
        logger.info(f"Model downloaded to: {self.onnx_model_path}")

    def _load_class_info(self):
        """Populate ``self.class_names`` from ``class_names.yaml`` in the repository.

        The ``names`` key may be a list (``[a, b, ...]``) or an ID-keyed mapping
        (``{0: a, 1: b, ...}``). Any download/parse error, or an empty name
        list, falls back to the built-in names.
        """
        try:
            logger.info("Attempting to download class information from HF repository...")
            class_file_path = hf_hub_download(
                repo_id=self.repo_id,
                filename="class_names.yaml",
                cache_dir="./hf_cache"
            )
            with open(class_file_path, 'r', encoding='utf-8') as f:
                # safe_load: the file comes from a remote repository.
                class_config = yaml.safe_load(f) or {}
            names = self._normalize_class_names(class_config.get('names'))
        except Exception as e:
            logger.warning(f"Could not load class names from HF repository: {e}")
            logger.info("Falling back to hardcoded class names...")
            self._load_fallback_classes()
            return

        if not names:
            # An empty list would turn every detection into "Unknown_<id>".
            logger.warning("class_names.yaml does not define any class names")
            logger.info("Falling back to hardcoded class names...")
            self._load_fallback_classes()
            return

        self.class_names = names
        logger.info(f"Loaded {len(self.class_names)} classes from HF repository")
        logger.info(f"Sample classes: {self.class_names[:5]}...")

    @staticmethod
    def _normalize_class_names(raw) -> List[str]:
        """Convert a YAML ``names`` value into a list indexed by class ID.

        Args:
            raw: a list/tuple of names, a mapping of class ID (int or numeric
                string) to name, or a falsy value.

        Returns:
            The names as strings. IDs missing from a mapping become
            ``"Unknown_<id>"`` so list positions keep matching class IDs.
            Falsy input yields an empty list.

        Raises:
            TypeError: if ``raw`` is of any other type.
            ValueError: if a mapping key is not an integer.
        """
        if not raw:
            return []
        if isinstance(raw, dict):
            by_id = {int(key): str(name) for key, name in raw.items()}
            return [by_id.get(i, f"Unknown_{i}") for i in range(max(by_id) + 1)]
        if isinstance(raw, (list, tuple)):
            return [str(name) for name in raw]
        raise TypeError(f"Unsupported 'names' format: {type(raw).__name__}")

    def _load_fallback_classes(self):
        """Load fallback class names if configuration files are not available."""
        self.class_names = [
            'Aircraft Hangar', 'Airplane', 'Airport', 'Barge', 'Baseball Diamond',
            'Basketball Court', 'Bridge', 'Building', 'Bus', 'Cargo Truck',
            'Cargo/Container Railcar', 'Cargo/Passenger Plane', 'Cement Mixer',
            'Construction Site', 'Container Crane', 'Container Ship', 'Crane Truck',
            'Damaged Building', 'Dump Truck', 'Engineering Vehicle', 'Excavator',
            'Facility', 'Ferry', 'Fishing Vessel', 'Flat Railcar', 'Front Loader/Bulldozer',
            'Ground Grader', 'Ground Track Field', 'Harbor', 'Haul Truck', 'Helicopter',
            'Helipad', 'Hut/Tent', 'Large Vehicle', 'Locomotive', 'Mobile Crane',
            'Motorboat', 'Oil Tanker', 'Passenger Railcar', 'Pylon', 'Railway Vehicle',
            'Reach Stacker', 'Roundabout', 'Sailboat', 'Scraper/Tractor', 'Shed', 'Ship',
            'Shipping Container', 'Shipping Container Lot', 'Small Vehicle', 'Soccer Field',
            'Storage Tank', 'Straddle Carrier', 'Swimming Pool', 'Tank Railcar',
            'Tennis Court', 'Tower', 'Tower Crane', 'Trailer', 'Truck', 'Truck Tractor',
            'Truck Tractor with Box Trailer', 'Truck Tractor with Flatbed Trailer',
            'Truck Tractor with Liquid Tank', 'Tugboat', 'Utility Truck', 'Vehicle',
            'Vehicle Lot', 'Yacht'
        ]
        logger.info(f"Using fallback class names: {len(self.class_names)} classes")

    def load_model(self):
        """Load the YOLO ONNX model using ultralytics.

        Returns:
            True on success; False if the model file is missing or loading
            fails (the error is logged together with its traceback).
        """
        try:
            if not self.onnx_model_path or not Path(self.onnx_model_path).exists():
                raise FileNotFoundError(f"ONNX model not found: {self.onnx_model_path}")

            self.model = YOLO(self.onnx_model_path)
            logger.info(f"Successfully loaded YOLO model from: {self.onnx_model_path}")

            # NOTE(review): for exported (non-.pt) weights ultralytics may keep only
            # the file path in ``self.model.model``, in which case this is skipped.
            # detect() maps IDs through self.class_names itself, so its output
            # does not depend on this override.
            if hasattr(self.model.model, 'names'):
                self.model.model.names = {i: name for i, name in enumerate(self.class_names)}

            return True

        except Exception as e:
            logger.exception(f"Failed to load YOLO model: {e}")
            return False

    def detect(self, image: np.ndarray, conf_threshold: float = 0.25) -> Tuple[np.ndarray, np.ndarray, List[str]]:
        """Run detection on an image using ultralytics YOLO.

        Args:
            image: HWC image array passed straight to ``YOLO.predict``
                (ultralytics interprets numpy input as BGR).
            conf_threshold: minimum confidence for a detection to be kept.

        Returns:
            ``(boxes, scores, class_names)``: an ``(N, 4)`` array of
            ``[x1, y1, x2, y2]`` coordinates, ``(N,)`` confidences and N labels.
            With no detections: a ``(0, 4)`` array, an empty array and ``[]``.

        Raises:
            ValueError: if ``image`` is None.
            RuntimeError: if the model cannot be loaded.
        """
        if image is None:
            # ultralytics' predict() substitutes its bundled sample images when the
            # source is None, which would return detections for the wrong picture.
            raise ValueError("No image provided for detection")

        if self.model is None and not self.load_model():
            raise RuntimeError("Failed to load YOLO model")

        try:
            # ultralytics handles all preprocessing, NMS and postprocessing.
            results = self.model.predict(
                source=image,
                conf=conf_threshold,
                verbose=False,
                save=False,
                show=False
            )

            # One image in -> one Results object out.
            result = results[0]

            if result.boxes is None or len(result.boxes) == 0:
                return np.array([]).reshape(0, 4), np.array([]), []

            boxes = result.boxes.xyxy.cpu().numpy()  # [x1, y1, x2, y2] format
            scores = result.boxes.conf.cpu().numpy()
            class_ids = result.boxes.cls.cpu().numpy().astype(int)

            n_known = len(self.class_names)
            class_names = [
                self.class_names[class_id] if 0 <= class_id < n_known else f"Unknown_{class_id}"
                for class_id in class_ids
            ]

            logger.info(f"Found {len(boxes)} detections")
            logger.info(f"Score range: {scores.min():.3f} - {scores.max():.3f}")
            logger.info(f"Classes detected: {set(class_names)}")

            return boxes, scores, class_names

        except Exception as e:
            logger.error(f"Detection failed: {e}")
            raise

def draw_detections(image: np.ndarray, boxes: np.ndarray, scores: np.ndarray,
                    classes: List[str]) -> np.ndarray:
    """Draw bounding boxes and ``"<class>: <score>"`` labels on a copy of ``image``.

    Each class gets a colour derived from a stable hash of its name, so a class
    is drawn in the same colour on every request and across restarts. Colour
    tuples are written to the image channels as-is.

    Args:
        image: HWC image array to annotate; it is not modified.
        boxes: ``(N, 4)`` array of ``[x1, y1, x2, y2]`` pixel coordinates.
        scores: ``(N,)`` confidence scores.
        classes: N class names aligned with ``boxes``.

    Returns:
        An annotated copy of ``image``, or ``image`` itself when there are no boxes.
    """
    import zlib  # stdlib: crc32 is stable across processes, unlike str hash()

    if len(boxes) == 0:
        return image

    annotated = image.copy()

    # Deterministic per-class colour: seed a generator from the name's CRC32.
    class_colors = {}
    for name in set(classes):
        rng = np.random.default_rng(zlib.crc32(str(name).encode("utf-8")))
        class_colors[name] = tuple(int(c) for c in rng.integers(0, 256, size=3))

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    text_thickness = 2

    for box, score, cls in zip(boxes, scores, classes):
        x1, y1, x2, y2 = (int(v) for v in box)
        color = class_colors[cls]

        cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)

        label = f"{cls}: {score:.2f}"
        (text_w, text_h), _baseline = cv2.getTextSize(label, font, font_scale, text_thickness)

        # Prefer a label strip above the box; if it would be clipped by the top
        # edge of the image, draw it just inside the box instead.
        if y1 - text_h - 10 >= 0:
            strip_top, strip_bottom, text_y = y1 - text_h - 10, y1, y1 - 5
        else:
            strip_top, strip_bottom, text_y = y1, y1 + text_h + 10, y1 + text_h + 5

        cv2.rectangle(annotated, (x1, strip_top), (x1 + text_w, strip_bottom), color, -1)
        cv2.putText(annotated, label, (x1, text_y), font, font_scale,
                    (255, 255, 255), text_thickness)

    return annotated

# Single detector instance shared by every Gradio request. Building it may log
# in to the Hub and downloads the class list and ONNX model, so this runs
# network I/O at import time.
detector: TrueSatDetector = TrueSatDetector(repo_id="truthdotphd/truesat-detection")

def detect_objects(image, conf_threshold):
    """Gradio callback: run detection on an uploaded image and return it annotated.

    Args:
        image: RGB numpy image from ``gr.Image(type="numpy")``, or None when
            the form is submitted without an image.
        conf_threshold: minimum confidence score, from the slider.

    Returns:
        The annotated image; on failure, a copy of the input with the error
        message drawn on it; None when no image was supplied.
    """
    if image is None:
        # Nothing uploaded: nothing to detect on and no image to draw an error onto.
        return None

    try:
        # Gradio delivers RGB, while ultralytics interprets numpy input as BGR
        # (OpenCV convention); swap channels so the model sees the true colours.
        if image.ndim == 3 and image.shape[2] == 3:
            model_input = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        else:
            model_input = image

        # detector.detect() already logs the detection count and classes.
        boxes, scores, classes = detector.detect(model_input, conf_threshold)

        # Box coordinates are unaffected by the channel swap, so draw on the RGB original.
        return draw_detections(image, boxes, scores, classes)

    except Exception as e:
        logger.exception(f"Detection failed: {e}")
        error_image = image.copy()
        # The image is RGB, so (255, 0, 0) renders as red.
        cv2.putText(error_image, f"Error: {str(e)}", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
        return error_image

# Gradio UI: an image plus a confidence slider go in, the annotated image
# comes out. Components are created in the same order the Interface call
# originally created them.
_image_input = gr.Image(type="numpy", label="Upload Image")
_threshold_input = gr.Slider(
    minimum=0.1,
    maximum=1.0,
    value=0.25,
    step=0.05,
    label="Confidence Threshold",
    info="Minimum confidence score for detections",
)
_result_output = gr.Image(label="Detection Results")

# Markdown shown above the interface.
_DESCRIPTION = """
    Upload a satellite image to detect various objects including:
    - Vessels (ships, boats, barges)
    - Aircraft (planes, helicopters) 
    - Vehicles (trucks, cars)
    - Infrastructure (buildings, bridges, airports)
    - And 60+ other object classes
    
    **Note:** Uses ultralytics YOLO for accurate detection results.
    """

# Markdown shown below the interface.
_ARTICLE = """
    ### How to use:
    1. Upload a satellite or aerial image
    2. Adjust the confidence threshold to filter detections
    3. Click Submit to run detection
    
    ### Technical Details:
    - Model: YOLO11x trained on satellite imagery
    - Classes: 69 object categories optimized for satellite/aerial imagery
    - Backend: Ultralytics YOLO with ONNX inference
    - Features: Automatic NMS, proper preprocessing, accurate confidence scores
    """

demo = gr.Interface(
    fn=detect_objects,
    inputs=[_image_input, _threshold_input],
    outputs=_result_output,
    title="🛰️ TrueSat Satellite Object Detection",
    description=_DESCRIPTION,
    article=_ARTICLE,
    theme=gr.themes.Soft(),
    examples=None,  # no bundled sample images yet
)

if __name__ == "__main__":
    logger.info("Starting TrueSat Detection App...")
    logger.info("Loading YOLO model...")

    if detector.load_model():
        logger.info("✅ Successfully loaded YOLO model")
        logger.info(f"✅ Loaded {len(detector.class_names)} object classes")
    else:
        logger.error("❌ Failed to load YOLO model")
        logger.error("Please make sure:")
        logger.error("1. HF_TOKEN environment variable is set for private repo access")
        logger.error(f"2. The '{detector.repo_id}' repository is accessible")
        logger.error("3. Ultralytics and huggingface_hub are properly installed")
        logger.error("4. You have sufficient memory/GPU resources")
        # Launch anyway: while the model is still unloaded, detect() retries
        # load_model() per request and detect_objects() reports failures on the image.
        logger.warning("Launching the app anyway; requests will retry loading the model.")

    # Single launch point so both branches always share the same server settings.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)