Spaces:
Sleeping
Sleeping
| # yolo_predictor.py | |
| import os | |
| import logging | |
| import tempfile | |
| import numpy as np | |
| import tifffile | |
| from rasterio.transform import from_bounds | |
| from ultralytics import YOLO | |
| from ndvi_predictor import normalize_rgb, predict_ndvi | |
# Module-wide logging setup: INFO threshold, each record prefixed with
# timestamp, logger name and severity. The module logger is named after
# this module so its records can be told apart from other modules'.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def load_yolo_model(model_path):
    """Instantiate a YOLO model from a weights file.

    Args:
        model_path: Location of the weights file passed straight to
            ``YOLO`` (described by the original author as a .pt file).

    Returns:
        The object produced by ``YOLO(model_path)``.
    """
    logger.info(f"Loading YOLO model from: {model_path}")
    model = YOLO(model_path)
    return model
def predict_yolo(yolo_model, image_path, conf=0.01):
    """Run a YOLO model on a 4-channel TIFF after sanity-checking the file.

    Args:
        yolo_model: Callable YOLO model; invoked as
            ``yolo_model([image_path], conf=conf)``.
        image_path: Path to the TIFF image on disk.
        conf: Confidence threshold forwarded to the model.

    Returns:
        The first element of whatever the model call returns.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
        ValueError: If the inferred channel count is not 4.
        Exception: Any error raised while reading the TIFF is logged and
            re-raised unchanged.
    """
    logger.info(f"Starting YOLO prediction on: {image_path} with confidence: {conf}")

    # Fail fast on a missing file before attempting to decode it.
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")

    try:
        # Decode the image once purely to inspect its shape.
        pixels = tifffile.imread(image_path)
        dims = pixels.shape
        logger.info(f"TIFF file shape: {dims}, dtype: {pixels.dtype}")

        # Channel heuristic: anything that is not 3-D counts as a single
        # channel. For 3-D data a leading axis of at most 4 is taken to be
        # the channel axis (channel-first); otherwise the last axis is used.
        channels = 1
        if len(dims) == 3:
            leading = dims[0]
            channels = leading if leading <= 4 else dims[2]

        if channels != 4:
            raise ValueError(f"Expected 4-channel image, got {channels} channels")
    except Exception as e:
        # Covers both read failures and the channel-count ValueError above.
        logger.error(f"Error validating TIFF file: {e}")
        raise

    logger.info("Running YOLO model inference...")
    # The model is given a one-element list, so only the first result matters.
    results = yolo_model([image_path], conf=conf)
    first = results[0]
    logger.info(f"YOLO prediction completed. Results type: {type(first)}")
    return first
def create_4channel_tiff(rgb_array, ndvi_array, output_path):
    """Write an RGB image plus an NDVI band as a channel-first 4-band TIFF.

    The three colour planes and the rescaled NDVI plane are stacked along a
    new leading axis, so (H, W, 3) / (H, W) inputs yield a (4, H, W) uint8
    array, which is then written with ``tifffile.imwrite``.

    Args:
        rgb_array: RGB image array (H, W, 3). uint8 input is used as-is.
            Any other dtype whose maximum is <= 1.0 is treated as
            normalised and multiplied by 255; otherwise values are taken
            to be on a 0-255 scale already.
        ndvi_array: NDVI array (H, W), nominally in [-1, 1]. It is mapped
            linearly onto [0, 255].
        output_path: Path to save the 4-channel TIFF.

    Raises:
        ValueError: Raised by ``np.stack`` when the colour planes and the
            NDVI plane do not share the same shape.
    """
    logger.info(f"Creating 4-channel TIFF file at: {output_path}")
    logger.info(f"RGB shape: {rgb_array.shape}, NDVI shape: {ndvi_array.shape}")

    # Ensure RGB is in uint8 format
    if rgb_array.dtype == np.uint8:
        rgb_uint8 = rgb_array
    else:
        rgb_scaled = rgb_array * 255 if rgb_array.max() <= 1.0 else rgb_array
        # astype(np.uint8) does not saturate: out-of-range values wrap
        # around (or are platform-dependent for floats). Clamp first so
        # over/under-range pixels become 255/0 instead of garbage. Float
        # bounds avoid integer-overflow issues with narrow integer dtypes.
        rgb_uint8 = np.clip(rgb_scaled, 0.0, 255.0).astype(np.uint8)

    # Convert NDVI from [-1, 1] to [0, 255] uint8 format. A predicted NDVI
    # can overshoot the nominal range slightly; clamp it so e.g. 1.01 maps
    # to 255 rather than wrapping to a near-zero value.
    ndvi_clipped = np.clip(ndvi_array, -1.0, 1.0)
    ndvi_scaled = (((ndvi_clipped + 1) / 2) * 255).astype(np.uint8)

    logger.info(f"RGB range: [{rgb_uint8.min()}, {rgb_uint8.max()}]")
    logger.info(f"NDVI scaled range: [{ndvi_scaled.min()}, {ndvi_scaled.max()}]")

    # Stack RGB + NDVI to create 4-channel image
    # Format: (channels, height, width) - channel-first format
    four_channel = np.stack(
        [
            rgb_uint8[:, :, 0],  # R channel
            rgb_uint8[:, :, 1],  # G channel
            rgb_uint8[:, :, 2],  # B channel
            ndvi_scaled,         # NDVI channel
        ],
        axis=0,
    )

    logger.info(f"4-channel array shape: {four_channel.shape}, dtype: {four_channel.dtype}")
    logger.info(f"4-channel range: [{four_channel.min()}, {four_channel.max()}]")

    # Save as TIFF using tifffile
    tifffile.imwrite(output_path, four_channel)
    logger.info(f"Successfully saved 4-channel TIFF (RGB+NDVI format) to: {output_path}")
def predict_pipeline(ndvi_model, yolo_model, rgb_array, conf=0.01):
    """Run the end-to-end flow: RGB -> NDVI -> 4-channel TIFF -> YOLO.

    The RGB input is normalised and fed to the NDVI model. The original
    (un-normalised) RGB array and the predicted NDVI are then written to a
    temporary TIFF, YOLO is run on that file, and the file is removed again
    whether or not the run succeeded.

    Args:
        ndvi_model: Loaded NDVI prediction model, passed to ``predict_ndvi``.
        yolo_model: Loaded YOLO model, passed to ``predict_yolo``.
        rgb_array: RGB image as numpy array (H, W, 3).
        conf: Confidence threshold for YOLO.

    Returns:
        The result object returned by ``predict_yolo``.

    Raises:
        FileNotFoundError: If the temporary TIFF is missing after writing.
        Exception: Any error from TIFF creation or YOLO prediction is
            logged and re-raised unchanged.
    """
    logger.info("Starting full prediction pipeline")
    logger.info(f"Input RGB array shape: {rgb_array.shape}, dtype: {rgb_array.dtype}")

    # Step 1: normalise the RGB input for the NDVI model.
    logger.info("Step 1: Normalizing RGB image")
    rgb_norm = normalize_rgb(rgb_array)
    logger.info(f"Normalized RGB shape: {rgb_norm.shape}, range: [{rgb_norm.min():.3f}, {rgb_norm.max():.3f}]")

    # Step 2: estimate NDVI from the normalised image.
    logger.info("Step 2: Predicting NDVI from RGB")
    ndvi_map = predict_ndvi(ndvi_model, rgb_norm)
    logger.info(f"NDVI prediction shape: {ndvi_map.shape}, range: [{ndvi_map.min():.3f}, {ndvi_map.max():.3f}]")

    # Step 3: write the 4-channel TIFF to a scratch file.
    # NOTE(review): this message says "BGR+NDVI" while create_4channel_tiff
    # stacks channels 0, 1, 2 of the array it is given - confirm which
    # label is intended.
    logger.info("Step 3: Creating 4-channel TIFF file (BGR+NDVI)")

    # Only the path is needed: the handle is closed on leaving the ``with``
    # block, and delete=False keeps the file on disk so it can be rewritten
    # and read by path below.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as scratch:
        tiff_path = scratch.name

    try:
        create_4channel_tiff(rgb_array, ndvi_map, tiff_path)

        # Confirm the writer actually left a file behind.
        if not os.path.exists(tiff_path):
            raise FileNotFoundError(f"Failed to create 4-channel TIFF at: {tiff_path}")

        file_size = os.path.getsize(tiff_path)
        logger.info(f"Created 4-channel TIFF file size: {file_size} bytes")

        # Step 4: run detection on the file just written.
        logger.info("Step 4: Running YOLO prediction on 4-channel TIFF")
        detections = predict_yolo(yolo_model, tiff_path, conf=conf)

        logger.info("Full pipeline completed successfully")
        return detections
    except Exception as exc:
        logger.error(f"Error in pipeline: {exc}")
        raise
    finally:
        # Best-effort removal of the scratch file; a failure here is only
        # logged so it cannot mask the pipeline's own outcome.
        if os.path.exists(tiff_path):
            try:
                os.unlink(tiff_path)
                logger.info(f"Cleaned up temporary file: {tiff_path}")
            except Exception as cleanup_error:
                logger.warning(f"Failed to clean up temporary file: {cleanup_error}")
def validate_4channel_tiff(tiff_path):
    """Report whether a TIFF file appears to hold exactly four channels.

    Args:
        tiff_path: Path to TIFF file.

    Returns:
        bool: True when the inferred channel count is 4. False when it is
        not, or when reading the file raises any exception (the error is
        logged rather than propagated).
    """
    try:
        data = tifffile.imread(tiff_path)
        dims = data.shape

        # Non-3-D data counts as one channel. For 3-D data a leading axis
        # of at most 4 is taken as the channel axis, else the last axis.
        if len(dims) != 3:
            channels = 1
        elif dims[0] <= 4:
            channels = dims[0]
        else:
            channels = dims[2]

        logger.info(f"TIFF validation - Shape: {dims}, Channels: {channels}")
        return channels == 4
    except Exception as e:
        logger.error(f"Error validating TIFF file: {e}")
        return False