File size: 7,398 Bytes
fc5324e
 
c421799
3b61715
 
65137d3
3b61715
 
 
3143e36
fc5324e
c421799
 
 
 
fc5324e
 
3b61715
fc5324e
 
6e63a4a
fc5324e
503cb09
fc5324e
 
 
503cb09
 
fc5324e
 
 
 
c421799
 
3b61715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c421799
503cb09
 
65137d3
c421799
503cb09
 
3b61715
503cb09
3b61715
fc5324e
503cb09
3b61715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e63a4a
3b61715
 
 
 
 
503cb09
3b61715
503cb09
fc5324e
503cb09
 
 
3b61715
 
 
3143e36
 
 
 
 
 
 
 
 
3b61715
 
3143e36
 
3b61715
 
 
3143e36
 
3b61715
 
 
 
 
 
3143e36
 
3b61715
 
 
 
 
 
 
 
3143e36
 
3b61715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
# yolo_predictor.py
import os
import logging
import tempfile
import numpy as np
import tifffile
from rasterio.transform import from_bounds
from ultralytics import YOLO
from ndvi_predictor import normalize_rgb, predict_ndvi
from resize_image import resize_image_optimized

# Configure logging
# NOTE: basicConfig runs as an import-time side effect and targets the root
# logger (INFO level, timestamp/name/level/message format). Per the standard
# library docs it does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger used by the functions below
logger = logging.getLogger(__name__)

def load_yolo_model(model_path):
    """
    Build a YOLO model object from a weights file

    Args:
        model_path: Path handed to the YOLO constructor (a .pt file)

    Returns:
        The constructed YOLO model
    """
    logger.info(f"Loading YOLO model from: {model_path}")
    model = YOLO(model_path)
    return model

def predict_yolo(yolo_model, image_path, conf=0.001):
    """
    Run a YOLO model on a 4-channel TIFF after sanity-checking the file

    The file is read once with tifffile purely to check its channel count;
    inference itself is run on the path, not on the loaded array.

    Args:
        yolo_model: Loaded YOLO model (called with a list of paths)
        image_path: Path to the 4-channel TIFF image
        conf: Confidence threshold forwarded to the model

    Returns:
        The first element of the model's results

    Raises:
        FileNotFoundError: If image_path does not exist
        ValueError: If the detected channel count is not 4
    """
    logger.info(f"Starting YOLO prediction on: {image_path} with confidence: {conf}")

    # Guard clause: nothing to do without an input file
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")

    try:
        # Load the image only to inspect its layout
        tiff_data = tifffile.imread(image_path)
        dims = tiff_data.shape
        logger.info(f"TIFF file shape: {dims}, dtype: {tiff_data.dtype}")

        # Channel count heuristic: non-3D arrays count as single channel;
        # a leading axis of at most 4 is taken as channel-first, otherwise
        # the trailing axis is taken as the channel axis.
        if len(dims) != 3:
            channels = 1
        elif dims[0] <= 4:
            channels = dims[0]
        else:
            channels = dims[2]

        if channels != 4:
            raise ValueError(f"Expected 4-channel image, got {channels} channels")

    except Exception as e:
        # Log any read/validation failure, then let it propagate unchanged
        logger.error(f"Error validating TIFF file: {e}")
        raise

    logger.info("Running YOLO model inference...")
    # The model receives the file path itself, wrapped in a list
    predictions = yolo_model([image_path], conf=conf)
    first_result = predictions[0]

    logger.info(f"YOLO prediction completed. Results type: {type(first_result)}")
    return first_result

def create_4channel_tiff(rgb_array, ndvi_array, output_path):
    """
    Create a 4-channel TIFF file with RGB channels + NDVI channel

    The image is written channel-first, shape (4, H, W), as uint8. All
    conversions to uint8 saturate: out-of-range values are clipped before the
    cast instead of wrapping around (e.g. an NDVI of 1.01 maps to 255, not 0).

    Args:
        rgb_array: RGB image array (H, W, 3). uint8 input is used as-is. Any
            other dtype is treated as [0, 1] data (scaled by 255) when its
            maximum is <= 1.0, and as [0, 255] data otherwise.
        ndvi_array: NDVI array (H, W) with values in [-1, 1]; values outside
            that range are clipped to it.
        output_path: Path to save the 4-channel TIFF
    """
    logger.info(f"Creating 4-channel TIFF file at: {output_path}")
    logger.info(f"RGB shape: {rgb_array.shape}, NDVI shape: {ndvi_array.shape}")

    # Ensure RGB is in uint8 format. Clip before casting: astype(np.uint8)
    # does not saturate, so a stray 256.0 or -1.0 would otherwise wrap.
    if rgb_array.dtype != np.uint8:
        if rgb_array.max() <= 1.0:
            rgb_uint8 = (np.clip(rgb_array, 0.0, 1.0) * 255).astype(np.uint8)
        else:
            rgb_uint8 = np.clip(rgb_array, 0.0, 255.0).astype(np.uint8)
    else:
        rgb_uint8 = rgb_array

    # Convert NDVI from [-1, 1] to [0, 255] uint8 format. Model outputs can
    # overshoot the nominal range slightly, so clamp first to avoid wraparound.
    ndvi_clipped = np.clip(ndvi_array, -1.0, 1.0)
    ndvi_scaled = (((ndvi_clipped + 1) / 2) * 255).astype(np.uint8)

    logger.info(f"RGB range: [{rgb_uint8.min()}, {rgb_uint8.max()}]")
    logger.info(f"NDVI scaled range: [{ndvi_scaled.min()}, {ndvi_scaled.max()}]")

    # Stack RGB + NDVI to create 4-channel image
    # Format: (channels, height, width) - channel-first format
    four_channel = np.stack([
        rgb_uint8[:, :, 0],  # R channel
        rgb_uint8[:, :, 1],  # G channel
        rgb_uint8[:, :, 2],  # B channel
        ndvi_scaled,         # NDVI channel
    ], axis=0)

    logger.info(f"4-channel array shape: {four_channel.shape}, dtype: {four_channel.dtype}")
    logger.info(f"4-channel range: [{four_channel.min()}, {four_channel.max()}]")

    # Save as TIFF using tifffile
    tifffile.imwrite(output_path, four_channel)
    logger.info(f"Successfully saved 4-channel TIFF (RGB+NDVI format) to: {output_path}")

def predict_pipeline(ndvi_model, yolo_model, rgb_array, conf=0.001, target_size=(640, 640)):
    """
    Full pipeline: RGB -> NDVI -> uint8 4-channel TIFF (RGB+NDVI) -> YOLO prediction

    The 4-channel image is written to a temporary .tiff file, which is removed
    again before this function returns or raises.

    Args:
        ndvi_model: Loaded NDVI prediction model
        yolo_model: Loaded YOLO model
        rgb_array: RGB image as numpy array (H, W, 3)
        conf: Confidence threshold for YOLO
        target_size: Size the RGB image is resized to before NDVI prediction,
            passed unchanged to resize_image_optimized. Defaults to
            (640, 640), the previously hard-coded value.

    Returns:
        results: YOLO results object

    Raises:
        FileNotFoundError: If the temporary TIFF is missing after writing.
        Exception: Any error from writing the TIFF or from predict_yolo is
            logged and re-raised unchanged.
    """
    logger.info("Starting full prediction pipeline")
    logger.info(f"Input RGB array shape: {rgb_array.shape}, dtype: {rgb_array.dtype}")

    # Step 1: Resize RGB image to target size
    logger.info("Step 1: Resizing RGB image to target size")
    rgb_resized = resize_image_optimized(rgb_array, target_size)
    logger.info(f"Resized RGB shape: {rgb_resized.shape}")

    # Step 2: Normalize RGB image
    logger.info("Step 2: Normalizing RGB image")
    normalized_rgb = normalize_rgb(rgb_resized)
    logger.info(f"Normalized RGB shape: {normalized_rgb.shape}, range: [{normalized_rgb.min():.3f}, {normalized_rgb.max():.3f}]")

    # Step 3: Predict NDVI
    logger.info("Step 3: Predicting NDVI from RGB")
    ndvi_prediction = predict_ndvi(ndvi_model, normalized_rgb)
    logger.info(f"NDVI prediction shape: {ndvi_prediction.shape}, range: [{ndvi_prediction.min():.3f}, {ndvi_prediction.max():.3f}]")

    # Step 4: Create 4-channel TIFF file
    logger.info("Step 4: Creating 4-channel TIFF file (RGB+NDVI)")

    # Reserve a temporary path. delete=False plus closing the handle right
    # away lets the TIFF writer and YOLO reopen the file by name; removal is
    # handled explicitly in the finally block below.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
        tiff_path = tmp_file.name

    try:
        # Create the 4-channel TIFF using resized (not normalized) RGB and predicted NDVI
        create_4channel_tiff(rgb_resized, ndvi_prediction, tiff_path)

        # Verify the created file
        if not os.path.exists(tiff_path):
            raise FileNotFoundError(f"Failed to create 4-channel TIFF at: {tiff_path}")

        file_size = os.path.getsize(tiff_path)
        logger.info(f"Created 4-channel TIFF file size: {file_size} bytes")

        # Step 5: Run YOLO prediction on the 4-channel TIFF
        logger.info("Step 5: Running YOLO prediction on 4-channel TIFF")
        results = predict_yolo(yolo_model, tiff_path, conf=conf)

        logger.info("Full pipeline completed successfully")
        return results

    except Exception as e:
        logger.error(f"Error in pipeline: {e}")
        raise
    finally:
        # Clean up temporary file; a failed cleanup is only logged so it
        # never masks the pipeline's own result or exception
        if os.path.exists(tiff_path):
            try:
                os.unlink(tiff_path)
                logger.info(f"Cleaned up temporary file: {tiff_path}")
            except Exception as cleanup_error:
                logger.warning(f"Failed to clean up temporary file: {cleanup_error}")

def validate_4channel_tiff(tiff_path):
    """
    Check whether a TIFF file is detected as having exactly 4 channels

    Non-3D images count as single channel. For 3D images the leading axis is
    taken as the channel axis when its length is at most 4, otherwise the
    trailing axis is used.

    Args:
        tiff_path: Path to TIFF file

    Returns:
        bool: True if 4 channels were detected; False otherwise, including
        when the file cannot be read
    """
    try:
        dims = tifffile.imread(tiff_path).shape

        if len(dims) != 3:
            channels = 1
        elif dims[0] <= 4:
            channels = dims[0]
        else:
            channels = dims[2]

        logger.info(f"TIFF validation - Shape: {dims}, Channels: {channels}")
    except Exception as e:
        # Any failure while reading or inspecting the file means "not valid"
        logger.error(f"Error validating TIFF file: {e}")
        return False
    else:
        return channels == 4