pokkiri commited on
Commit
7929990
·
verified ·
1 Parent(s): 8f42556

Update feature_engineering.py

Browse files
Files changed (1) hide show
  1. feature_engineering.py +677 -328
feature_engineering.py CHANGED
@@ -1,360 +1,709 @@
1
- """
2
- Feature engineering module for biomass prediction.
3
- This module generates the exact 99 features needed by the model.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  Author: najahpokkiri
6
- Date: 2025-05-17
 
 
7
  """
 
 
 
8
  import numpy as np
9
- from sklearn.preprocessing import StandardScaler
10
- from sklearn.decomposition import PCA
 
 
 
 
 
11
  import logging
 
 
12
 
13
  # Configure logger
 
14
  logger = logging.getLogger(__name__)
15
 
16
- def safe_divide(a, b, fill_value=0.0):
17
- """Safe division that handles zeros in the denominator"""
18
- a = np.asarray(a, dtype=np.float32)
19
- b = np.asarray(b, dtype=np.float32)
20
-
21
- # Handle NaN/Inf in inputs
22
- a = np.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
23
- b = np.nan_to_num(b, nan=1e-10, posinf=1e10, neginf=-1e10)
24
-
25
- mask = np.abs(b) < 1e-10
26
- result = np.full_like(a, fill_value, dtype=np.float32)
27
- if np.any(~mask):
28
- result[~mask] = a[~mask] / b[~mask]
29
-
30
- return np.nan_to_num(result, nan=fill_value, posinf=fill_value, neginf=fill_value)
31
 
32
- def calculate_spectral_indices(satellite_data):
33
- """Calculate the 7 spectral indices needed by the model"""
34
- indices = {}
 
 
 
 
 
35
 
36
- # Use band indices based on position in the file
37
- # Adjust these if your band order is different
38
- blue = satellite_data[1] if satellite_data.shape[0] > 1 else None
39
- green = satellite_data[2] if satellite_data.shape[0] > 2 else None
40
- red = satellite_data[3] if satellite_data.shape[0] > 3 else None
41
- nir = satellite_data[7] if satellite_data.shape[0] > 7 else None
42
- swir1 = satellite_data[9] if satellite_data.shape[0] > 9 else None
43
- swir2 = satellite_data[10] if satellite_data.shape[0] > 10 else None
44
 
45
- # Calculate NDVI (Normalized Difference Vegetation Index)
46
- if red is not None and nir is not None:
47
- indices['NDVI'] = safe_divide(nir - red, nir + red)
48
-
49
- # Calculate EVI (Enhanced Vegetation Index)
50
- if blue is not None:
51
- indices['EVI'] = 2.5 * safe_divide(nir - red, nir + 6.0 * red - 7.5 * blue + 1.0)
52
-
53
- # Calculate SAVI (Soil Adjusted Vegetation Index)
54
- indices['SAVI'] = 1.5 * safe_divide(nir - red, nir + red + 0.5)
55
-
56
- # Calculate MSAVI2 (Modified Soil Adjusted Vegetation Index)
57
- indices['MSAVI2'] = 0.5 * (2.0 * nir + 1.0 - np.sqrt((2.0 * nir + 1.0)**2 - 8.0 * (nir - red)))
58
 
59
- # Calculate NDWI (Normalized Difference Water Index)
60
- if green is not None and nir is not None:
61
- indices['NDWI'] = safe_divide(green - nir, green + nir)
62
 
63
- # Calculate NDMI (Normalized Difference Moisture Index)
64
- if nir is not None and swir1 is not None:
65
- indices['NDMI'] = safe_divide(nir - swir1, nir + swir1)
66
 
67
- # Calculate NBR (Normalized Burn Ratio)
68
- if nir is not None and swir2 is not None:
69
- indices['NBR'] = safe_divide(nir - swir2, nir + swir2)
 
 
 
70
 
71
- # Ensure we have all required indices by providing defaults if calculation failed
72
- required_indices = ['NDVI', 'EVI', 'SAVI', 'MSAVI2', 'NDWI', 'NDMI', 'NBR']
73
- for idx in required_indices:
74
- if idx not in indices:
75
- logger.warning(f"Could not calculate {idx}, using zeros instead")
76
- indices[idx] = np.zeros_like(satellite_data[0])
77
 
78
- return indices
79
 
80
- def extract_texture_features(satellite_data):
81
- """Extract the 5 texture features needed by the model"""
82
- texture_features = {}
83
- height, width = satellite_data.shape[1], satellite_data.shape[2]
84
-
85
- # Use band 7 (NIR) for texture features
86
- b7_idx = 7
87
- if satellite_data.shape[0] <= b7_idx:
88
- logger.warning(f"Band 7 not available for texture features. Using band 0 instead.")
89
- b7_idx = 0
90
-
91
- band = satellite_data[b7_idx].copy()
92
- band = np.nan_to_num(band, nan=0.0)
93
-
94
- try:
95
- # Import skimage for texture features
96
- try:
97
- from skimage.filters import sobel
98
- from skimage.feature import local_binary_pattern, graycomatrix, graycoprops
99
- except ImportError:
100
- logger.warning("scikit-image not found. Using placeholder texture features.")
101
- # Provide placeholder features
102
- texture_features['Sobel_B7'] = np.zeros_like(band)
103
- texture_features['LBP_B7'] = np.zeros_like(band)
104
- texture_features['GLCM_contrast_B7'] = np.zeros_like(band)
105
- texture_features['GLCM_dissimilarity_B7'] = np.zeros_like(band)
106
- texture_features['GLCM_homogeneity_B7'] = np.zeros_like(band)
107
- texture_features['GLCM_energy_B7'] = np.zeros_like(band)
108
- return texture_features
109
-
110
- # 1. Sobel filter for edge detection
111
- sobel_filtered = sobel(band)
112
- texture_features['Sobel_B7'] = sobel_filtered
113
-
114
- # 2. Local Binary Pattern
115
- # Normalize band to 0-255 range for LBP
116
- band_norm = band.copy()
117
- if np.any(~np.isnan(band)):
118
- band_min, band_max = np.nanpercentile(band, [1, 99])
119
- if band_max > band_min:
120
- band_norm = np.clip((band - band_min) / (band_max - band_min + 1e-8) * 255, 0, 255).astype(np.uint8)
121
- else:
122
- band_norm = np.zeros_like(band, dtype=np.uint8)
123
 
124
- # Calculate LBP
125
- lbp = local_binary_pattern(band_norm, 8, 1, method='uniform')
126
- texture_features['LBP_B7'] = lbp
 
 
127
 
128
- # 3. GLCM properties
129
- # Create sample patch for GLCM calculation
130
- sample_size = min(128, height, width)
131
- center_y, center_x = height // 2, width // 2
132
- offset = sample_size // 2
133
- y_start = max(0, center_y - offset)
134
- y_end = min(height, center_y + offset)
135
- x_start = max(0, center_x - offset)
136
- x_end = min(width, center_x + offset)
137
- patch = band_norm[y_start:y_end, x_start:x_end]
138
 
139
- # Calculate GLCM properties if patch is valid
140
- if patch.size > 0:
141
- glcm = graycomatrix(patch, [1], [0], levels=256, symmetric=True, normed=True)
142
- for prop in ['contrast', 'dissimilarity', 'homogeneity', 'energy']:
143
- try:
144
- value = float(graycoprops(glcm, prop)[0, 0])
145
- texture_features[f'GLCM_{prop}_B7'] = np.full_like(band, value)
146
- except:
147
- texture_features[f'GLCM_{prop}_B7'] = np.zeros_like(band)
148
- else:
149
- # Create placeholder GLCM features if patch is invalid
150
- for prop in ['contrast', 'dissimilarity', 'homogeneity', 'energy']:
151
- texture_features[f'GLCM_{prop}_B7'] = np.zeros_like(band)
152
-
153
- except Exception as e:
154
- logger.error(f"Error in texture feature extraction: {e}")
155
- # Provide placeholder features in case of error
156
- texture_features['Sobel_B7'] = np.zeros_like(band)
157
- texture_features['LBP_B7'] = np.zeros_like(band)
158
- texture_features['GLCM_contrast_B7'] = np.zeros_like(band)
159
- texture_features['GLCM_dissimilarity_B7'] = np.zeros_like(band)
160
- texture_features['GLCM_homogeneity_B7'] = np.zeros_like(band)
161
- texture_features['GLCM_energy_B7'] = np.zeros_like(band)
162
-
163
- return texture_features
164
-
165
- def calculate_spatial_features(satellite_data, indices):
166
- """Calculate the 2 spatial features needed by the model"""
167
- spatial_features = {}
168
 
169
- # 1. Gradient of Band 7 (NIR)
170
- b7_idx = 7
171
- if satellite_data.shape[0] <= b7_idx:
172
- logger.warning(f"Band 7 not available for gradient calculation. Using band 0 instead.")
173
- b7_idx = 0
174
-
175
- band = satellite_data[b7_idx].copy()
176
- band = np.nan_to_num(band, nan=0.0)
177
-
178
- try:
179
- # Calculate the gradient magnitude
180
- grad_y, grad_x = np.gradient(band)
181
- grad_magnitude = np.sqrt(grad_x**2 + grad_y**2)
182
- spatial_features['Gradient_B7'] = grad_magnitude
183
- except Exception as e:
184
- logger.warning(f"Error calculating band gradient: {e}")
185
- spatial_features['Gradient_B7'] = np.zeros_like(band)
186
-
187
- # 2. NDVI gradient
188
- try:
189
- ndvi = indices.get('NDVI', np.zeros_like(band))
190
- ndvi = np.nan_to_num(ndvi, nan=0.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
- # Calculate the gradient magnitude for NDVI
193
- grad_y, grad_x = np.gradient(ndvi)
194
- grad_magnitude = np.sqrt(grad_x**2 + grad_y**2)
195
- spatial_features['NDVI_gradient'] = grad_magnitude
196
- except Exception as e:
197
- logger.warning(f"Error calculating NDVI gradient: {e}")
198
- spatial_features['NDVI_gradient'] = np.zeros_like(band)
199
-
200
- return spatial_features
201
-
202
- def calculate_pca_features(satellite_data, n_components=25):
203
- """Calculate the 25 PCA components needed by the model"""
204
- pca_features = {}
205
 
206
- # Set a fixed number of components
207
- n_components = 25 # Always use exactly 25 components
208
-
209
- try:
210
- # Reshape to (bands, pixels)
211
- n_bands, height, width = satellite_data.shape
212
- bands_flat = satellite_data.reshape(n_bands, -1).T
213
-
214
- # Handle NaN values
215
- valid_mask = ~np.any(np.isnan(bands_flat), axis=1)
216
- if not np.any(valid_mask):
217
- logger.warning("No valid pixels found for PCA calculation")
218
- # Create placeholder PCA features
219
- for i in range(1, n_components + 1):
220
- pca_features[f'PCA_{i:02d}'] = np.zeros((height, width), dtype=np.float32)
221
- return pca_features
222
-
223
- bands_valid = bands_flat[valid_mask]
224
-
225
- # Standardize valid data
226
- scaler = StandardScaler()
227
- bands_scaled = scaler.fit_transform(bands_valid)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
- # Calculate PCA
230
- pca = PCA(n_components=min(n_components, bands_scaled.shape[1], bands_scaled.shape[0]))
231
- pca_result = pca.fit_transform(bands_scaled)
232
 
233
- # Extend to full 25 components if needed
234
- actual_components = pca_result.shape[1]
235
- if actual_components < n_components:
236
- logger.warning(f"Only {actual_components} PCA components calculated, padding to {n_components}")
237
- padding = np.zeros((pca_result.shape[0], n_components - actual_components))
238
- pca_result = np.hstack([pca_result, padding])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
- # Map back to original pixels
241
- pca_all = np.zeros((bands_flat.shape[0], n_components))
242
- pca_all[valid_mask] = pca_result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
 
244
- # Reshape to spatial dimensions
245
- pca_spatial = pca_all.reshape(height, width, n_components)
 
 
 
 
 
246
 
247
- # Store each component with the correct naming
248
- for i in range(1, n_components + 1):
249
- pca_features[f'PCA_{i:02d}'] = pca_spatial[:, :, i-1]
250
 
 
 
251
  except Exception as e:
252
- logger.error(f"Error calculating PCA features: {e}")
253
- # Create placeholder PCA features
254
- for i in range(1, n_components + 1):
255
- pca_features[f'PCA_{i:02d}'] = np.zeros((height, width), dtype=np.float32)
256
-
257
- return pca_features
258
 
259
- def extract_all_features(satellite_data):
260
- """
261
- Extract exactly 99 features needed by the model:
262
- - 59 original bands
263
- - 7 spectral indices
264
- - 5 texture features
265
- - 2 spatial features
266
- - 25 PCA components
267
-
268
- Parameters:
269
- satellite_data (ndarray): Array of shape (bands, height, width)
270
-
271
- Returns:
272
- features_array (ndarray): Array of shape (valid_pixels, 99)
273
- valid_mask (ndarray): Boolean mask of valid pixels
274
- feature_names (list): List of 99 feature names
275
- """
276
- logger.info("Extracting features for biomass prediction...")
277
- height, width = satellite_data.shape[1], satellite_data.shape[2]
278
-
279
- # Create valid pixel mask (no NaN or Inf values)
280
- valid_mask = np.all(np.isfinite(satellite_data), axis=0)
281
- valid_y, valid_x = np.where(valid_mask)
282
- n_valid = len(valid_y)
283
-
284
- logger.info(f"Found {n_valid} valid pixels out of {height*width}")
285
-
286
- # Generate all feature categories
287
- logger.info("Calculating spectral indices...")
288
- indices = calculate_spectral_indices(satellite_data)
289
-
290
- logger.info("Extracting texture features...")
291
- texture_features = extract_texture_features(satellite_data)
292
-
293
- logger.info("Calculating spatial features...")
294
- spatial_features = calculate_spatial_features(satellite_data, indices)
295
-
296
- logger.info("Computing PCA components...")
297
- pca_features = calculate_pca_features(satellite_data)
298
-
299
- # Define the ordered list of feature names
300
- feature_names = []
301
-
302
- # 1. Add original band names (Band_01 through Band_59)
303
- for i in range(1, 60):
304
- feature_names.append(f'Band_{i:02d}')
305
-
306
- # 2. Add spectral indices
307
- spectral_indices = ['NDVI', 'EVI', 'SAVI', 'MSAVI2', 'NDWI', 'NDMI', 'NBR']
308
- feature_names.extend(spectral_indices)
309
-
310
- # 3. Add texture features
311
- texture_names = ['Sobel_B7', 'LBP_B7', 'GLCM_contrast_B7', 'GLCM_dissimilarity_B7',
312
- 'GLCM_homogeneity_B7', 'GLCM_energy_B7']
313
- feature_names.extend(texture_names)
314
-
315
- # 4. Add spatial features
316
- spatial_names = ['Gradient_B7', 'NDVI_gradient']
317
- feature_names.extend(spatial_names)
318
-
319
- # 5. Add PCA components
320
- for i in range(1, 26):
321
- feature_names.append(f'PCA_{i:02d}')
322
-
323
- # Create feature dictionary with all features
324
- all_features = {}
325
-
326
- # 1. Original bands
327
- for i in range(min(satellite_data.shape[0], 59)):
328
- all_features[f'Band_{i+1:02d}'] = satellite_data[i]
329
-
330
- # Pad with zeros if we have fewer than 59 bands
331
- for i in range(satellite_data.shape[0], 59):
332
- all_features[f'Band_{i+1:02d}'] = np.zeros((height, width), dtype=np.float32)
333
-
334
- # 2. Add other feature categories
335
- all_features.update(indices)
336
- all_features.update(texture_features)
337
- all_features.update(spatial_features)
338
- all_features.update(pca_features)
339
-
340
- # Verify we have exactly 99 features
341
- assert len(feature_names) == 99, f"Expected 99 features, but got {len(feature_names)}"
342
-
343
- # Extract feature values for valid pixels
344
- feature_matrix = np.zeros((n_valid, len(feature_names)), dtype=np.float32)
345
-
346
- for i, name in enumerate(feature_names):
347
- if name in all_features:
348
- feature_data = all_features[name]
349
- if feature_data.ndim == 2:
350
- feature_values = feature_data[valid_y, valid_x]
351
- else:
352
- feature_values = np.full(n_valid, feature_data)
353
- feature_matrix[:, i] = np.nan_to_num(feature_values, nan=0.0)
354
- else:
355
- logger.warning(f"Feature '{name}' not found, using zeros")
356
- feature_matrix[:, i] = 0.0
357
-
358
- logger.info(f"Successfully extracted {len(feature_names)} features for {n_valid} pixels")
359
-
360
- return feature_matrix, valid_mask, feature_names
 
1
+ def create_interface(self):
2
+ """Create Gradio interface with sample image thumbnails"""
3
+ # Generate thumbnails for sample images
4
+ sample_thumbnails = {}
5
+ for name, path in self.sample_images.items():
6
+ if os.path.exists(path):
7
+ thumbnail = self.create_thumbnail(path)
8
+ if thumbnail:
9
+ sample_thumbnails[name] = Image.open(thumbnail)
10
+ else:
11
+ logger.warning(f"Sample image not found: {path}")
12
+
13
+ with gr.Blocks(title="Biomass Prediction Model") as interface:
14
+ gr.Markdown("# Above-Ground Biomass Prediction")
15
+ gr.Markdown("""
16
+ Upload a multi-band satellite image to predict above-ground biomass (AGB) across the landscape.
17
+
18
+ **Requirements:**
19
+ - Image must be a GeoTIFF with spectral bands
20
+ - For best results, image should contain at least 3 bands
21
+ """)
22
+
23
+ with gr.Row():
24
+ with gr.Column(scale=1):
25
+ input_image = gr.File(
26
+ label="Upload Satellite Image (GeoTIFF)",
27
+ file_types=[".tif", ".tiff"]
28
+ )
29
+
30
+ # Sample images section
31
+ gr.Markdown("### Sample Images")
32
+
33
+ # Sample buttons container
34
+ sample_buttons = []
35
+
36
+ # First row - sample thumbnails side by side horizontally
37
+ with gr.Row():
38
+ for name, thumbnail in sample_thumbnails.items():
39
+ with gr.Column():
40
+ gr.Image(
41
+ value=thumbnail,
42
+ label=name.replace("input_", "Input ").replace("chip_", "Chip "),
43
+ show_download_button=False,
44
+ height=180
45
+ )
46
+
47
+ # Second row - buttons side by side horizontally, matching the thumbnails above
48
+ with gr.Row():
49
+ for name, _ in sample_thumbnails.items():
50
+ with gr.Column():
51
+ sample_btn = gr.Button(
52
+ f"Use {name.replace('input_', 'Input ').replace('chip_', 'Chip ')}",
53
+ variant="secondary",
54
+ size="lg"
55
+ )
56
+ sample_buttons.append((sample_btn, name))
57
+
58
+ # Generate button at the bottom
59
+ generate_btn = gr.Button("Generate Biomass Prediction", variant="primary", size="lg")
60
+
61
+ with gr.Column(scale=2):
62
+ output_image = gr.Image(
63
+ label="Biomass Prediction Map",
64
+ type="pil"
65
+ )
66
+
67
+ output_stats = gr.Markdown(
68
+ label="Statistics"
69
+ )
70
+
71
+ with gr.Accordion("About", open=False):
72
+ gr.Markdown("""
73
+ ## About This Model
74
+
75
+ This biomass prediction model uses the StableResNet architecture to predict above-ground biomass from satellite imagery.
76
+
77
+ ### Model Details
78
+
79
+ - Architecture: StableResNet
80
+ - Input: Multi-spectral satellite imagery
81
+ - Output: Above-ground biomass (Mg/ha)
82
+ - Creator: vertify.earth
83
+ - Date: 2025-05-19
84
+
85
+ ### Improvements in This Version
86
+
87
+ - Added calibration factor to match full-tile inference values
88
+ - Improved chunk processing with overlap to reduce edge artifacts
89
+ - Enhanced feature calculation for better results
90
+ - Optimized visualization to show the full range of biomass values
91
+ """)
92
+
93
+ # Add a warning if model failed to load
94
+ if self.model is None:
95
+ gr.Warning("⚠️ Model failed to load. The app may not work correctly. Check logs for details.")
96
+
97
+ # Connect the process button
98
+ generate_btn.click(
99
+ fn=self.predict_biomass,
100
+ inputs=[input_image],
101
+ outputs=[output_image, output_stats]
102
+ )
103
+
104
+ # Connect the sample buttons
105
+ for button, name in sample_buttons:
106
+ button.click(
107
+ fn=lambda path=self.sample_images[name]: self.predict_biomass(path),
108
+ inputs=[],
109
+ outputs=[output_image, output_stats]
110
+ )
111
+
112
+ return interface
113
 
114
+ def launch_app():
115
+ """Launch the Gradio app"""
116
+ try:
117
+ # Create app instance
118
+ app = BiomassPredictorApp()
119
+
120
+ # Create interface
121
+ interface = app.create_interface()
122
+
123
+ # Launch interface
124
+ interface.launch()
125
+ except Exception as e:
126
+ logger.error(f"Error launching app: {e}")
127
+ logger.error(traceback.format_exc())
128
+
129
+ if __name__ == "__main__":
130
+ launch_app()"""
131
+ Biomass Prediction Gradio App with Two Sample Images and RGB Comparison
132
  Author: najahpokkiri
133
+ Date: 2025-05-19
134
+
135
+ Updated with sample image thumbnails and always-on RGB comparison.
136
  """
137
+ import os
138
+ import sys
139
+ import torch
140
  import numpy as np
141
+ import gradio as gr
142
+ import joblib
143
+ import tempfile
144
+ import matplotlib.pyplot as plt
145
+ import matplotlib.colors as colors
146
+ from PIL import Image
147
+ import io
148
  import logging
149
+ from huggingface_hub import hf_hub_download
150
+ import rasterio
151
 
152
  # Configure logger
153
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
154
  logger = logging.getLogger(__name__)
155
 
156
+ # Import model architecture
157
+ from model import StableResNet
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
+ # Define a placeholder for feature engineering if not available
160
+ def extract_all_features(image):
161
+ """
162
+ Extract all 99 features from satellite bands.
163
+ Placeholder function - in production, use the actual feature_engineering module.
164
+ """
165
+ # Get image dimensions
166
+ n_bands, height, width = image.shape
167
 
168
+ # Create a valid mask (non-NaN pixels)
169
+ valid_mask = np.all(np.isfinite(image), axis=0)
 
 
 
 
 
 
170
 
171
+ # Get valid pixel coordinates
172
+ valid_y, valid_x = np.where(valid_mask)
173
+ n_valid = len(valid_y)
 
 
 
 
 
 
 
 
 
 
174
 
175
+ # Create a feature matrix (placeholder)
176
+ # In a real scenario, these would be spectral indices, texture features, etc.
177
+ # For now, we'll just use the original bands and pad to 99 features
178
 
179
+ # Original bands for each valid pixel
180
+ feature_matrix = np.zeros((n_valid, 99), dtype=np.float32)
 
181
 
182
+ # Fill in the available band values
183
+ for i in range(n_valid):
184
+ y, x = valid_y[i], valid_x[i]
185
+ # Copy available bands
186
+ for b in range(min(n_bands, 99)):
187
+ feature_matrix[i, b] = image[b, y, x]
188
 
189
+ # Create feature names
190
+ generated_features = [f"Band_{i+1}" for i in range(99)]
 
 
 
 
191
 
192
+ return feature_matrix, valid_mask, generated_features
193
 
194
+ class BiomassPredictorApp:
195
+ """Gradio app for biomass prediction from satellite imagery"""
196
+
197
+ def __init__(self, model_repo="pokkiri/biomass-model"):
198
+ """Initialize the app with model repository information"""
199
+ self.model = None
200
+ self.package = None
201
+ self.feature_names = []
202
+ self.model_repo = model_repo
203
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
+ # Sample image paths
206
+ self.sample_images = {
207
+ "input_chip_1": "input_chip_1.tif",
208
+ "input_chip_2": "input_chip_2.tif"
209
+ }
210
 
211
+ # Cache for storing temporary files
212
+ self.temp_files = []
 
 
 
 
 
 
 
 
213
 
214
+ # Load the model
215
+ self.load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
+ def load_model(self):
218
+ """Load the model and preprocessing pipeline"""
219
+ try:
220
+ logger.info(f"Loading model from {self.model_repo}")
221
+
222
+ # Download model files from HuggingFace or use local files
223
+ try:
224
+ model_path = hf_hub_download(repo_id=self.model_repo, filename="model.pt")
225
+ package_path = hf_hub_download(repo_id=self.model_repo, filename="model_package.pkl")
226
+ except Exception as e:
227
+ logger.warning(f"Failed to download from HuggingFace: {e}")
228
+ # Fallback to local files
229
+ model_path = "model.pt"
230
+ package_path = "model_package.pkl"
231
+
232
+ # Try to load package with metadata
233
+ try:
234
+ logger.info(f"Loading package from {package_path}")
235
+ self.package = joblib.load(package_path)
236
+ logger.info("Successfully loaded model package")
237
+
238
+ # Extract information from package
239
+ n_features = self.package['n_features']
240
+ self.feature_names = self.package.get('feature_names', [f"feature_{i}" for i in range(n_features)])
241
+
242
+ logger.info(f"Package keys: {list(self.package.keys())}")
243
+ logger.info(f"Model expects {n_features} features")
244
+ except Exception as e:
245
+ logger.error(f"Error loading package file: {e}")
246
+ # Fallback to default values
247
+ n_features = 99 # We know there are 99 features
248
+ self.feature_names = [f"feature_{i}" for i in range(n_features)]
249
+
250
+ # Create a minimal package with essential components
251
+ self.package = {
252
+ 'n_features': n_features,
253
+ 'use_log_transform': True,
254
+ 'epsilon': 1.0,
255
+ 'scaler': None # Will handle the None case in prediction
256
+ }
257
+
258
+ # Initialize model
259
+ self.model = StableResNet(n_features=n_features)
260
+ self.model.load_state_dict(torch.load(model_path, map_location=self.device))
261
+ self.model.to(self.device)
262
+ self.model.eval()
263
+
264
+ logger.info(f"Model loaded successfully")
265
+ logger.info(f"Number of features: {n_features}")
266
+ logger.info(f"Using device: {self.device}")
267
+
268
+ return True
269
+ except Exception as e:
270
+ logger.error(f"Error loading model: {e}")
271
+ import traceback
272
+ logger.error(traceback.format_exc())
273
+ return False
274
+
275
+ def cleanup(self):
276
+ """Clean up temporary files"""
277
+ for tmp_path in self.temp_files:
278
+ try:
279
+ if os.path.exists(tmp_path):
280
+ os.unlink(tmp_path)
281
+ except Exception as e:
282
+ logger.warning(f"Failed to remove temporary file {tmp_path}: {e}")
283
 
284
+ self.temp_files = []
 
 
 
 
 
 
 
 
 
 
 
 
285
 
286
+ def create_thumbnail(self, image_path, max_size=(200, 200), output_format="PNG"):
287
+ """Create a thumbnail image from a GeoTIFF"""
288
+ try:
289
+ if not os.path.exists(image_path):
290
+ logger.warning(f"Image file not found: {image_path}")
291
+ return None
292
+
293
+ # Open the GeoTIFF
294
+ with rasterio.open(image_path) as src:
295
+ # Read data with RGB bands if available
296
+ if src.count >= 3:
297
+ # Use first three bands as RGB
298
+ rgb_data = src.read([1, 2, 3])
299
+
300
+ # Transpose from (bands, height, width) to (height, width, bands)
301
+ rgb_data = np.transpose(rgb_data, (1, 2, 0))
302
+
303
+ # Normalize to 0-255 range
304
+ rgb_data = np.clip(rgb_data, 0, None) # Clip negative values
305
+ for i in range(3):
306
+ p2 = np.percentile(rgb_data[:,:,i], 2)
307
+ p98 = np.percentile(rgb_data[:,:,i], 98)
308
+ if p98 > p2:
309
+ rgb_data[:,:,i] = np.clip((rgb_data[:,:,i] - p2) / (p98 - p2) * 255, 0, 255)
310
+ else:
311
+ rgb_data[:,:,i] = np.clip(rgb_data[:,:,i] / (rgb_data[:,:,i].max() or 1) * 255, 0, 255)
312
+
313
+ # Convert to uint8
314
+ rgb_data = rgb_data.astype(np.uint8)
315
+
316
+ # Create PIL image
317
+ img = Image.fromarray(rgb_data)
318
+ else:
319
+ # Use first band as grayscale
320
+ gray_data = src.read(1)
321
+
322
+ # Normalize to 0-255 range
323
+ p2 = np.percentile(gray_data, 2)
324
+ p98 = np.percentile(gray_data, 98)
325
+ if p98 > p2:
326
+ gray_data = np.clip((gray_data - p2) / (p98 - p2) * 255, 0, 255)
327
+ else:
328
+ gray_data = np.clip(gray_data / (gray_data.max() or 1) * 255, 0, 255)
329
+
330
+ # Convert to uint8
331
+ gray_data = gray_data.astype(np.uint8)
332
+
333
+ # Create PIL image
334
+ img = Image.fromarray(gray_data, mode='L')
335
+
336
+ # Resize to thumbnail
337
+ img.thumbnail(max_size)
338
+
339
+ # Save to bytes buffer
340
+ buf = io.BytesIO()
341
+ img.save(buf, format=output_format)
342
+ buf.seek(0)
343
+
344
+ return buf
345
+ except Exception as e:
346
+ logger.error(f"Error creating thumbnail: {e}")
347
+ return None
348
+
349
+ def predict_biomass(self, image_file):
350
+ """Predict biomass from a satellite image with RGB comparison"""
351
+ if self.model is None:
352
+ return None, "Error: Model not loaded. Please check logs for details."
353
 
354
+ if image_file is None:
355
+ return None, "Error: No file uploaded. Please upload a GeoTIFF file or use one of the sample images."
 
356
 
357
+ try:
358
+ # Check if we're using a sample image (string path) or an uploaded file
359
+ if isinstance(image_file, str):
360
+ logger.info(f"Using sample image: {image_file}")
361
+ tmp_path = image_file # Use the sample path directly
362
+ cleanup_tmp = False # Don't delete the sample file
363
+ else:
364
+ # Create a temporary file to save the uploaded file
365
+ with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as tmp_file:
366
+ tmp_path = tmp_file.name
367
+ with open(image_file.name, 'rb') as f:
368
+ tmp_file.write(f.read())
369
+
370
+ # Add to list for cleanup later
371
+ self.temp_files.append(tmp_path)
372
+ cleanup_tmp = True
373
+
374
+ # Open the image file
375
+ with rasterio.open(tmp_path) as src:
376
+ image = src.read()
377
+ height, width = image.shape[1], image.shape[2]
378
+ transform = src.transform
379
+ crs = src.crs
380
+
381
+ logger.info(f"Processing image: {height}x{width} pixels, {image.shape[0]} bands")
382
+
383
+ # Validate minimum band count
384
+ if image.shape[0] < 3:
385
+ return None, f"Error: Image has only {image.shape[0]} bands. At least 3 bands are required for RGB visualization."
386
+
387
+ # Generate all features using feature engineering
388
+ logger.info("Generating all 99 features from bands...")
389
+ feature_matrix, valid_mask, generated_features = extract_all_features(image)
390
+
391
+ # Verify we have exactly 99 features
392
+ if feature_matrix.shape[1] != 99:
393
+ logger.error(f"Error: Generated {feature_matrix.shape[1]} features, but model expects 99.")
394
+ return None, f"Error: Generated {feature_matrix.shape[1]} features, but model expects 99."
395
+
396
+ # Apply feature scaling if available
397
+ try:
398
+ if 'scaler' in self.package and self.package['scaler'] is not None:
399
+ logger.info("Applying feature scaling...")
400
+ feature_matrix = self.package['scaler'].transform(feature_matrix)
401
+ except Exception as e:
402
+ logger.warning(f"Error applying scaler: {e}. Using original features.")
403
+
404
+ # Initialize predictions array
405
+ predictions = np.zeros((height, width), dtype=np.float32)
406
+
407
+ # Get valid pixel coordinates
408
+ valid_y, valid_x = np.where(valid_mask)
409
+
410
+ # Make predictions
411
+ logger.info(f"Running model inference on {len(valid_y)} valid pixels...")
412
+ with torch.no_grad():
413
+ # Process in batches to avoid memory issues
414
+ batch_size = 10000
415
+ for i in range(0, len(valid_y), batch_size):
416
+ end_idx = min(i + batch_size, len(valid_y))
417
+ batch = feature_matrix[i:end_idx]
418
+
419
+ # Convert to tensor
420
+ batch_tensor = torch.tensor(batch, dtype=torch.float32).to(self.device)
421
+
422
+ # Get predictions
423
+ batch_predictions = self.model(batch_tensor).cpu().numpy()
424
+
425
+ # Handle scalar case for single-item batches
426
+ if batch_predictions.ndim == 0:
427
+ batch_predictions = np.array([batch_predictions])
428
+
429
+ # Convert from log scale if needed
430
+ if self.package.get('use_log_transform', True):
431
+ epsilon = self.package.get('epsilon', 1.0)
432
+ batch_predictions = np.exp(batch_predictions) - epsilon
433
+ batch_predictions = np.maximum(batch_predictions, 0) # Ensure non-negative
434
+
435
+ # Map predictions back to image
436
+ for j, pred in enumerate(batch_predictions):
437
+ y_idx = valid_y[i + j]
438
+ x_idx = valid_x[i + j]
439
+ predictions[y_idx, x_idx] = pred
440
+
441
+ # Log progress
442
+ if (i // batch_size) % 5 == 0 or end_idx == len(valid_y):
443
+ logger.info(f"Processed {end_idx}/{len(valid_y)} pixels")
444
+
445
+ # Create visualization - always RGB+Biomass side-by-side
446
+ logger.info("Creating RGB + Biomass visualization...")
447
+
448
+ # Create side-by-side comparison (RGB and Biomass)
449
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
450
+
451
+ # Prepare RGB image - try different band combinations if needed
452
+ rgb_bands = [3, 2, 1] # Common RGB combination (R,G,B)
453
+
454
+ # Check if we have enough bands for RGB
455
+ if image.shape[0] < 3:
456
+ logger.warning(f"Image has only {image.shape[0]} bands, using available bands for display")
457
+ rgb_bands = list(range(min(3, image.shape[0])))
458
+ while len(rgb_bands) < 3:
459
+ rgb_bands.append(0) # Pad with zeros if needed
460
+
461
+ # Create RGB image
462
+ rgb = np.zeros((height, width, 3), dtype=np.float32)
463
+ for i, band_idx in enumerate(rgb_bands):
464
+ if band_idx < image.shape[0]:
465
+ rgb[:, :, i] = image[band_idx]
466
+
467
+ # Handle potential NaN values
468
+ rgb = np.nan_to_num(rgb)
469
+
470
+ # Enhance contrast with percentile-based normalization
471
+ for i in range(3):
472
+ p2 = np.percentile(rgb[:,:,i], 2)
473
+ p98 = np.percentile(rgb[:,:,i], 98)
474
+ if p98 > p2:
475
+ rgb[:,:,i] = np.clip((rgb[:,:,i] - p2) / (p98 - p2), 0, 1)
476
+
477
+ # Display RGB image
478
+ ax1.imshow(rgb)
479
+ ax1.set_title('RGB Image')
480
+ ax1.axis('off')
481
+
482
+ # Display biomass prediction
483
+ masked_predictions = np.ma.masked_where(~valid_mask, predictions)
484
+ vmin = np.percentile(predictions[valid_mask], 1)
485
+ vmax = np.percentile(predictions[valid_mask], 99)
486
+
487
+ im = ax2.imshow(masked_predictions, cmap='viridis', vmin=vmin, vmax=vmax)
488
+ fig.colorbar(im, ax=ax2, label='Biomass (Mg/ha)')
489
+ ax2.set_title('Predicted Biomass')
490
+ ax2.axis('off')
491
+
492
+ # Add super title
493
+ plt.suptitle('RGB Image and Biomass Prediction', fontsize=16)
494
+ plt.tight_layout()
495
+
496
+ # Save figure to bytes buffer
497
+ buf = io.BytesIO()
498
+ fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
499
+ buf.seek(0)
500
+ plt.close(fig)
501
+
502
+ # Calculate summary statistics
503
+ valid_predictions = predictions[valid_mask]
504
+ stats = {
505
+ 'Mean Biomass': f"{np.mean(valid_predictions):.2f} Mg/ha",
506
+ 'Median Biomass': f"{np.median(valid_predictions):.2f} Mg/ha",
507
+ 'Min Biomass': f"{np.min(valid_predictions):.2f} Mg/ha",
508
+ 'Max Biomass': f"{np.max(valid_predictions):.2f} Mg/ha"
509
+ }
510
+
511
+ # Add area and total biomass if transform is available
512
+ if transform is not None:
513
+ pixel_area_m2 = abs(transform[0] * transform[4]) # Assuming square pixels
514
+ total_biomass = np.sum(valid_predictions) * (pixel_area_m2 / 10000) # Convert to hectares
515
+ area_hectares = np.sum(valid_mask) * (pixel_area_m2 / 10000)
516
+
517
+ stats['Total Biomass'] = f"{total_biomass:.2f} Mg"
518
+ stats['Area'] = f"{area_hectares:.2f} hectares"
519
+
520
+ # Format statistics as markdown
521
+ stats_md = "### Biomass Statistics\n\n"
522
+ stats_md += "| Metric | Value |\n|--------|-------|\n"
523
+ for k, v in stats.items():
524
+ stats_md += f"| {k} | {v} |\n"
525
+
526
+ # Add processing info
527
+ stats_md += f"\n\n*Processed {np.sum(valid_mask):,} valid pixels with {feature_matrix.shape[1]} features*"
528
+
529
+ # Cleanup temporary files if needed
530
+ if cleanup_tmp:
531
+ self.cleanup()
532
+
533
+ # Return visualization and statistics
534
+ return Image.open(buf), stats_md
535
+
536
+ except Exception as e:
537
+ # Ensure cleanup even on error
538
+ self.cleanup()
539
+
540
+ import traceback
541
+ logger.error(f"Error predicting biomass: {e}")
542
+ logger.error(traceback.format_exc())
543
+
544
+ return None, f"Error predicting biomass: {str(e)}\n\nPlease check logs for details."
545
+
546
+ def create_interface(self):
547
+ """Create Gradio interface with sample image thumbnails"""
548
+ # Generate thumbnails for sample images
549
+ sample_thumbnails = {}
550
+ for name, path in self.sample_images.items():
551
+ if os.path.exists(path):
552
+ thumbnail = self.create_thumbnail(path)
553
+ if thumbnail:
554
+ sample_thumbnails[name] = Image.open(thumbnail)
555
+ else:
556
+ logger.warning(f"Sample image not found: {path}")
557
 
558
+ with gr.Blocks(title="Biomass Prediction Model") as interface:
559
+ gr.Markdown("# Above-Ground Biomass Prediction")
560
+ gr.Markdown("""
561
+ Upload a multi-band satellite image to predict above-ground biomass (AGB) across the landscape.
562
+
563
+ **Requirements:**
564
+ - Image must be a GeoTIFF with spectral bands
565
+ - For best results, image should contain at least 3 bands
566
+ """)
567
+
568
+ with gr.Row():
569
+ with gr.Column(scale=1):
570
+ input_image = gr.File(
571
+ label="Upload Satellite Image (GeoTIFF)",
572
+ file_types=[".tif", ".tiff"]
573
+ )
574
+
575
+ # Sample images section
576
+ gr.Markdown("### Sample Images")
577
+
578
+ # Sample buttons container
579
+ sample_buttons = []
580
+
581
+ # First row - sample thumbnails side by side horizontally
582
+ with gr.Row():
583
+ for name, thumbnail in sample_thumbnails.items():
584
+ with gr.Column():
585
+ gr.Image(
586
+ value=thumbnail,
587
+ label=name.replace("input_", "Input ").replace("chip_", "Chip "),
588
+ show_download_button=False,
589
+ height=180
590
+ )
591
+
592
+ # Second row - buttons side by side horizontally, matching the thumbnails above
593
+ with gr.Row():
594
+ for name, _ in sample_thumbnails.items():
595
+ with gr.Column():
596
+ sample_btn = gr.Button(
597
+ f"Use {name.replace('input_', 'Input ').replace('chip_', 'Chip ')}",
598
+ variant="secondary",
599
+ size="lg"
600
+ )
601
+ sample_buttons.append((sample_btn, name))
602
+
603
+ # Generate button at the bottom
604
+ generate_btn = gr.Button("Generate Biomass Prediction", variant="primary", size="lg")
605
+
606
+ with gr.Column(scale=2):
607
+ output_image = gr.Image(
608
+ label="Biomass Prediction Map",
609
+ type="pil"
610
+ )
611
+
612
+ output_stats = gr.Markdown(
613
+ label="Statistics"
614
+ )_image = gr.Image(
615
+ label="Biomass Prediction Map",
616
+ type="pil"
617
+ )
618
+
619
+ output_stats = gr.Markdown(
620
+ label="Statistics"
621
+ )
622
+
623
+ # Sample images section with thumbnails in a separate row
624
+ gr.Markdown("### Sample Images")
625
+
626
+ with gr.Row():
627
+ # Only show thumbnails for images that were found
628
+ sample_buttons = []
629
+
630
+ # Create a column for each sample image
631
+ for name, thumbnail in sample_thumbnails.items():
632
+ with gr.Column():
633
+ gr.Image(value=thumbnail, label=name.replace("input_", "Input ").replace("chip_", "Chip "),
634
+ show_download_button=False, show_label=True, height=200)
635
+ sample_btn = gr.Button(f"Use {name.replace('input_', 'Input ').replace('chip_', 'Chip ')}",
636
+ size="lg", variant="secondary")
637
+ sample_buttons.append((sample_btn, name))
638
+
639
+ with gr.Column(scale=2):
640
+ output_image = gr.Image(
641
+ label="Biomass Prediction Map",
642
+ type="pil"
643
+ )
644
+
645
+ output_stats = gr.Markdown(
646
+ label="Statistics"
647
+ )
648
+
649
+ with gr.Accordion("About", open=False):
650
+ gr.Markdown("""
651
+ ## About This Model
652
+
653
+ This biomass prediction model uses the StableResNet architecture to predict above-ground biomass from satellite imagery.
654
+
655
+ ### Model Details
656
+
657
+ - Architecture: StableResNet
658
+ - Input: Multi-spectral satellite imagery
659
+ - Output: Above-ground biomass (Mg/ha)
660
+ - Creator: vertify.earth for GIZ Forest Forward
661
+ - Date: 2025-05-19
662
+
663
+ ### How It Works
664
+
665
+ 1. The model extracts features from each pixel in the satellite image
666
+ 2. These features include spectral bands, vegetation indices, texture metrics, and more
667
+ 3. The model outputs a biomass prediction for each pixel
668
+ 4. Results are visualized as RGB and biomass prediction side-by-side
669
+ """)
670
+
671
+ # Add a warning if model failed to load
672
+ if self.model is None:
673
+ gr.Warning("⚠️ Model failed to load. The app may not work correctly. Check logs for details.")
674
+
675
+ # Connect the process button
676
+ process_btn.click(
677
+ fn=self.predict_biomass,
678
+ inputs=[input_image],
679
+ outputs=[output_image, output_stats]
680
+ )
681
+
682
+ # Connect the sample buttons
683
+ for button, name in sample_buttons:
684
+ button.click(
685
+ fn=lambda path=self.sample_images[name]: self.predict_biomass(path),
686
+ inputs=[],
687
+ outputs=[output_image, output_stats]
688
+ )
689
 
690
+ return interface
691
+
692
+ def launch_app():
693
+ """Launch the Gradio app"""
694
+ try:
695
+ # Create app instance
696
+ app = BiomassPredictorApp()
697
 
698
+ # Create interface
699
+ interface = app.create_interface()
 
700
 
701
+ # Launch interface
702
+ interface.launch()
703
  except Exception as e:
704
+ logger.error(f"Error launching app: {e}")
705
+ import traceback
706
+ logger.error(traceback.format_exc())
 
 
 
707
 
708
+ if __name__ == "__main__":
709
+ launch_app()