Spaces:
Running
Running
| """ | |
| Biomass Prediction Gradio App with Exact 99 Features | |
| Author: najahpokkiri | |
| Date: 2025-05-19 | |
| Updated with side-by-side RGB comparison, fixed sample image loading, and corrected biomass calculation. | |
| """ | |
| import os | |
| import sys | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| import joblib | |
| import tempfile | |
| import matplotlib.pyplot as plt | |
| import matplotlib.colors as colors | |
| from PIL import Image | |
| import io | |
| import logging | |
| from huggingface_hub import hf_hub_download | |
# Configure logger
# Logging is configured before the project imports below so that the
# import-status messages emitted in the try/except are formatted consistently.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Import model architecture
from model import StableResNet
# Import feature engineering
from feature_engineering import extract_all_features
# Import config - this must happen before loading model_package.pkl
# NOTE(review): BiomassPipelineConfig is not referenced by name in this file;
# presumably the pickled package refers to it and needs it importable when
# joblib loads the file - confirm against the training pipeline.
try:
    from config import BiomassPipelineConfig
    logger.info("Successfully imported config.BiomassPipelineConfig")
except ImportError as e:
    # A missing config module is logged rather than raised; package loading
    # has its own fallback path later on.
    logger.error(f"Failed to import config.BiomassPipelineConfig: {e}")
    logger.error("This will likely cause errors when loading the model package")
class BiomassPredictorApp:
    """Gradio app for biomass prediction from satellite imagery.

    The app downloads a ``StableResNet`` checkpoint and a preprocessing
    package from the Hugging Face Hub, derives per-pixel features from an
    uploaded GeoTIFF, runs the model on every valid pixel and renders the
    result as a map plus a markdown statistics table.
    """

    # Feature count assumed when the model package cannot be read.
    DEFAULT_N_FEATURES = 99
    # Bands beyond this count are dropped before feature extraction.
    MAX_BANDS = 59
    # Number of pixels sent through the model per forward pass.
    BATCH_SIZE = 10000
    # Zero-based band indices displayed as R, G, B (bands 4, 3, 2 one-based).
    RGB_BAND_INDICES = (3, 2, 1)
    # Bundled demo raster, resolved relative to the working directory.
    SAMPLE_IMAGE_PATH = "input_chip_1.tif"

    def __init__(self, model_repo="pokkiri/biomass-model"):
        """Initialize the app and immediately try to load the model.

        Args:
            model_repo: Hugging Face Hub repository id that hosts
                ``model.pt`` and ``model_package.pkl``.
        """
        self.model = None
        self.package = None
        self.feature_names = []
        self.model_repo = model_repo
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Kept for backward compatibility with callers that register paths
        # and call cleanup(); predict_biomass manages its own temp file.
        self.temp_files = []
        # A failed load is only logged; self.model then stays None and
        # predict_biomass reports the problem to the user.
        self.load_model()

    def load_model(self):
        """Load the model and preprocessing package from the Hugging Face Hub.

        Returns:
            True when the network and its weights were loaded, False otherwise.
            On failure ``self.model`` is left as ``None`` so that a partially
            initialised (randomly weighted) network is never used.
        """
        try:
            logger.info(f"Loading model from {self.model_repo}")
            model_path = hf_hub_download(repo_id=self.model_repo, filename="model.pt")
            package_path = hf_hub_download(repo_id=self.model_repo, filename="model_package.pkl")

            n_features = self._load_package(package_path)

            # Build into a local first: self.model is published only once the
            # weights are in place, so a failed load_state_dict cannot leave
            # an untrained network behind.
            model = StableResNet(n_features=n_features)
            model.load_state_dict(torch.load(model_path, map_location=self.device))
            model.to(self.device)
            model.eval()
            self.model = model

            logger.info(f"Model loaded successfully from {self.model_repo}")
            logger.info(f"Number of features: {n_features}")
            logger.info(f"Using device: {self.device}")
            logger.info(f"Log transform: {self.package.get('use_log_transform', True)}")
            logger.info(f"Epsilon: {self.package.get('epsilon', 1.0)}")
            return True
        except Exception as e:
            self.model = None
            # logger.exception appends the traceback to the record.
            logger.exception(f"Error loading model: {e}")
            return False

    def _load_package(self, package_path):
        """Read the preprocessing package and return the feature count.

        Falls back to a minimal default package (no scaler, log transform on,
        epsilon 1.0) when the file cannot be read or lacks ``n_features``.
        """
        try:
            logger.info(f"Loading package from {package_path}")
            # NOTE(security): joblib.load unpickles the file and can execute
            # arbitrary code; only point model_repo at a trusted repository.
            package = joblib.load(package_path)
            logger.info("Successfully loaded model package")
            n_features = package['n_features']
            feature_names = package.get(
                'feature_names', [f"feature_{i}" for i in range(n_features)]
            )
            logger.info(f"Package keys: {list(package.keys())}")
            logger.info(f"Model expects {n_features} features")
            if n_features != self.DEFAULT_N_FEATURES:
                logger.warning(
                    f"Warning: Model expects {n_features} features, not the expected "
                    f"{self.DEFAULT_N_FEATURES}. This may cause issues."
                )
        except Exception as e:
            logger.error(f"Error loading package file: {e}")
            n_features = self.DEFAULT_N_FEATURES
            feature_names = [f"feature_{i}" for i in range(n_features)]
            package = {
                'n_features': n_features,
                'use_log_transform': True,
                'epsilon': 1.0,
                'scaler': None,  # predict_biomass skips scaling when this is None
            }
        self.package = package
        self.feature_names = feature_names
        return n_features

    @staticmethod
    def _remove_temp_file(tmp_path):
        """Delete one temporary file, logging (not raising) on failure."""
        try:
            os.unlink(tmp_path)
        except FileNotFoundError:
            pass  # already gone - nothing to do
        except (OSError, TypeError, ValueError) as e:
            logger.warning(f"Failed to remove temporary file {tmp_path}: {e}")

    def cleanup(self):
        """Delete every path registered in ``self.temp_files`` and reset the list."""
        for tmp_path in self.temp_files:
            self._remove_temp_file(tmp_path)
        self.temp_files = []

    def load_sample_image(self):
        """Return the path of the bundled sample image, or None if it is missing."""
        sample_path = self.SAMPLE_IMAGE_PATH
        if os.path.exists(sample_path):
            logger.info(f"Loading sample image from {sample_path}")
            return sample_path
        logger.warning(f"Sample image not found at {sample_path}")
        return None

    @staticmethod
    def _invert_log_transform(raw_predictions, use_log_transform, epsilon):
        """Flatten model outputs and undo the training-time log transform.

        Args:
            raw_predictions: Model output holding one value per pixel; any
                shape with that many elements (e.g. ``(N,)``, ``(N, 1)``, 0-d).
            use_log_transform: Whether the values are in log space.
            epsilon: Offset subtracted after exponentiation.

        Returns:
            1-D array. When the log transform is active the result is
            ``max(exp(x) - epsilon, 0)``; otherwise the values are returned
            unchanged.
        """
        values = np.asarray(raw_predictions).reshape(-1)
        if not use_log_transform:
            return values
        # exp(x) - epsilon can dip below zero for very small x, so clamp.
        return np.maximum(np.exp(values) - epsilon, 0)

    @staticmethod
    def _pixel_area_hectares(transform, crs):
        """Return the area of one pixel in hectares, or None if unknown.

        The area is only meaningful when the raster uses a projected CRS;
        with a geographic CRS the transform is in degrees and no area is
        reported. Uses the affine determinant so rotated grids are handled.
        NOTE(review): assumes the projected CRS uses metres as its linear
        unit - confirm for rasters in foot-based projections.
        """
        if transform is None or crs is None:
            return None
        if not getattr(crs, "is_projected", False):
            return None
        pixel_area_m2 = abs(transform[0] * transform[4] - transform[1] * transform[3])
        if pixel_area_m2 == 0:
            return None
        return pixel_area_m2 / 10000

    @staticmethod
    def _compose_rgb(image, band_indices):
        """Build a contrast-stretched ``(H, W, 3)`` float image from three bands.

        Each channel is rescaled between its 2nd and 98th percentile and
        clipped to [0, 1]; a channel with no spread is left as-is.
        """
        rgb = np.stack([image[idx] for idx in band_indices], axis=-1).astype(np.float32)
        rgb = np.nan_to_num(rgb)
        for channel in range(rgb.shape[2]):
            p2, p98 = np.percentile(rgb[:, :, channel], [2, 98])
            if p98 > p2:
                rgb[:, :, channel] = np.clip((rgb[:, :, channel] - p2) / (p98 - p2), 0, 1)
        return rgb

    def _predict_pixels(self, feature_matrix, valid_mask, height, width):
        """Run the model over all valid pixels and return an ``(H, W)`` map.

        Row ``k`` of ``feature_matrix`` is paired with the ``k``-th True entry
        of ``valid_mask`` in row-major order (the pairing the app has always
        relied on). Pixels outside the mask stay at zero.
        """
        predictions = np.zeros((height, width), dtype=np.float32)
        valid_y, valid_x = np.where(valid_mask)
        n_valid = len(valid_y)
        if feature_matrix.shape[0] != n_valid:
            logger.warning(
                f"Feature matrix has {feature_matrix.shape[0]} rows but the mask has "
                f"{n_valid} valid pixels; unmatched pixels are left at zero."
            )

        use_log_transform = self.package.get('use_log_transform', True)
        epsilon = self.package.get('epsilon', 1.0)

        logger.info(f"Running model inference on {n_valid} valid pixels...")
        with torch.no_grad():
            # Batched to bound memory use on large rasters.
            for start in range(0, n_valid, self.BATCH_SIZE):
                stop = min(start + self.BATCH_SIZE, n_valid)
                batch = feature_matrix[start:stop]
                if len(batch) == 0:
                    break  # fewer feature rows than valid pixels

                batch_tensor = torch.tensor(batch, dtype=torch.float32).to(self.device)
                raw = self.model(batch_tensor).cpu().numpy()
                if raw.size != len(batch):
                    raise ValueError(
                        f"Model returned {raw.size} values for {len(batch)} pixels; "
                        "expected one prediction per pixel."
                    )
                raw = raw.reshape(-1)
                if start == 0:
                    logger.info(f"Raw prediction sample: {raw[:5]}")

                values = self._invert_log_transform(raw, use_log_transform, epsilon)
                if start == 0:
                    logger.info(f"Transformed prediction sample: {values[:5]}")
                    logger.info(f"Using log transform: {use_log_transform}, epsilon: {epsilon}")

                # Vectorised scatter back into the image grid.
                count = len(values)
                predictions[valid_y[start:start + count], valid_x[start:start + count]] = values

                if (start // self.BATCH_SIZE) % 5 == 0 or stop == n_valid:
                    logger.info(f"Processed {stop}/{n_valid} pixels")
        return predictions

    def _render_map(self, image, predictions, valid_mask, display_type):
        """Render the prediction map and return it as a PIL image.

        ``"rgb_overlay"`` draws the RGB composite next to the biomass map when
        the raster has enough bands; every other case draws the heatmap alone.
        """
        masked_predictions = np.ma.masked_where(~valid_mask, predictions)
        # Percentile limits keep a few extreme pixels from flattening the colours.
        vmin, vmax = np.percentile(predictions[valid_mask], [1, 99])

        bands_needed = max(self.RGB_BAND_INDICES) + 1
        show_rgb = display_type == "rgb_overlay"
        if show_rgb and image.shape[0] < bands_needed:
            logger.warning(
                f"Not enough bands for RGB display (need {bands_needed}, "
                f"got {image.shape[0]}). Showing biomass only."
            )
            show_rgb = False

        if show_rgb:
            fig, (rgb_ax, biomass_ax) = plt.subplots(1, 2, figsize=(16, 8))
        else:
            fig, biomass_ax = plt.subplots(figsize=(10, 8))

        buf = io.BytesIO()
        try:
            if show_rgb:
                rgb_ax.imshow(self._compose_rgb(image, self.RGB_BAND_INDICES))
                rgb_ax.set_title('RGB Image (Bands 4,3,2)')
                rgb_ax.axis('off')
                biomass_title = 'Predicted Biomass'
            else:
                biomass_title = 'Predicted Above-Ground Biomass'

            im = biomass_ax.imshow(masked_predictions, cmap='viridis', vmin=vmin, vmax=vmax)
            fig.colorbar(im, ax=biomass_ax, label='Biomass (Mg/ha)')
            biomass_ax.set_title(biomass_title)
            biomass_ax.axis('off')

            if show_rgb:
                # Figure-level calls (not plt.*) so that concurrent requests
                # cannot touch each other's "current" figure.
                fig.suptitle('RGB Image and Biomass Prediction', fontsize=16)
                fig.tight_layout()

            fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
        finally:
            # Always release the figure, even if rendering failed.
            plt.close(fig)
        buf.seek(0)
        return Image.open(buf)

    def _format_stats(self, predictions, valid_mask, transform, crs, n_features):
        """Build the markdown statistics table shown next to the map."""
        valid_predictions = predictions[valid_mask]
        n_valid = int(np.count_nonzero(valid_mask))
        stats = {
            'Mean Biomass': f"{np.mean(valid_predictions):.2f} Mg/ha",
            'Median Biomass': f"{np.median(valid_predictions):.2f} Mg/ha",
            'Min Biomass': f"{np.min(valid_predictions):.2f} Mg/ha",
            'Max Biomass': f"{np.max(valid_predictions):.2f} Mg/ha",
        }

        # Area-based totals are only reported when the pixel size is known
        # in ground units (projected CRS).
        pixel_area_ha = self._pixel_area_hectares(transform, crs)
        if pixel_area_ha is not None:
            total_biomass = float(np.sum(valid_predictions, dtype=np.float64)) * pixel_area_ha
            area_hectares = n_valid * pixel_area_ha
            stats['Total Biomass'] = f"{total_biomass:.2f} Mg"
            stats['Area'] = f"{area_hectares:.2f} hectares"

        rows = "".join(f"| {k} | {v} |\n" for k, v in stats.items())
        return (
            "### Biomass Statistics\n\n"
            "| Metric | Value |\n|--------|-------|\n"
            f"{rows}"
            f"\n\n*Processed {n_valid:,} valid pixels with {n_features} features*"
        )

    def predict_biomass(self, image_file, display_type="heatmap"):
        """Predict biomass from a satellite image.

        Args:
            image_file: Either a path string (used in place, never deleted) or
                an upload object exposing ``.name``; uploads are copied to a
                private temporary file that is removed before returning.
            display_type: ``"heatmap"`` or ``"rgb_overlay"``.

        Returns:
            ``(PIL image, markdown statistics)`` on success, or
            ``(None, error message)`` on failure. Exceptions are logged and
            converted into an error message rather than raised.
        """
        if self.model is None:
            return None, "Error: Model not loaded. Please check logs for details."
        if image_file is None:
            return None, "Error: No file uploaded. Please upload a GeoTIFF file or use the sample image."

        tmp_path = None  # set only when this call created a temp copy
        try:
            # Imported lazily so the UI can still start without rasterio.
            try:
                import rasterio
            except ImportError:
                return None, "Error: rasterio is required but not installed. Please install with: pip install rasterio"

            if isinstance(image_file, str):
                logger.info(f"Using sample image: {image_file}")
                raster_path = image_file
            else:
                with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as tmp_file:
                    tmp_path = tmp_file.name
                    with open(image_file.name, 'rb') as upload:
                        tmp_file.write(upload.read())
                raster_path = tmp_path

            with rasterio.open(raster_path) as src:
                image = src.read()
                transform = src.transform
                crs = src.crs
            height, width = image.shape[1], image.shape[2]

            if image.shape[0] > self.MAX_BANDS:
                logger.info(
                    f"Image has {image.shape[0]} bands, selecting first "
                    f"{self.MAX_BANDS} for model compatibility"
                )
                image = image[:self.MAX_BANDS, :, :]
            logger.info(f"Processing image: {height}x{width} pixels, {image.shape[0]} bands")

            if image.shape[0] < 1:
                return None, "Error: Image has no bands. Please use multi-band satellite imagery."

            # The model was built with the package's feature count, so that
            # is the number to validate against (99 only as a last resort).
            expected_features = self.package.get('n_features', self.DEFAULT_N_FEATURES)
            logger.info(f"Generating all {expected_features} features from bands...")
            feature_matrix, valid_mask, _ = extract_all_features(image)

            if feature_matrix.shape[0] == 0 or not np.any(valid_mask):
                return None, "Error: No valid pixels found in the image. Please check that the file contains data."

            if feature_matrix.shape[1] != expected_features:
                message = (
                    f"Error: Generated {feature_matrix.shape[1]} features, "
                    f"but model expects {expected_features}."
                )
                logger.error(message)
                return None, message

            logger.info(
                f"Feature statistics - Min: {np.min(feature_matrix, axis=0)[:5]}, "
                f"Max: {np.max(feature_matrix, axis=0)[:5]}, "
                f"Mean: {np.mean(feature_matrix, axis=0)[:5]}"
            )

            # Best effort: a failing scaler is logged and the raw features
            # are used instead.
            try:
                scaler = self.package.get('scaler')
                if scaler is not None:
                    logger.info("Applying feature scaling...")
                    feature_matrix = scaler.transform(feature_matrix)
                    logger.info("Scaling complete")
                    logger.info(
                        f"After scaling - Min: {np.min(feature_matrix, axis=0)[:5]}, "
                        f"Max: {np.max(feature_matrix, axis=0)[:5]}"
                    )
            except Exception as e:
                logger.warning(f"Error applying scaler: {e}. Using original features.")

            predictions = self._predict_pixels(feature_matrix, valid_mask, height, width)

            valid_predictions = predictions[valid_mask]
            logger.info(
                f"Prediction statistics - Min: {np.min(valid_predictions):.2f}, "
                f"Max: {np.max(valid_predictions):.2f}, "
                f"Mean: {np.mean(valid_predictions):.2f}, "
                f"Median: {np.median(valid_predictions):.2f}"
            )

            logger.info("Creating visualization...")
            figure_image = self._render_map(image, predictions, valid_mask, display_type)
            stats_md = self._format_stats(
                predictions, valid_mask, transform, crs, feature_matrix.shape[1]
            )
            return figure_image, stats_md
        except Exception as e:
            logger.exception(f"Error predicting biomass: {e}")
            return None, f"Error predicting biomass: {str(e)}\n\nPlease check logs for details."
        finally:
            # Remove only the file this call created; other requests may
            # still be reading theirs.
            if tmp_path is not None:
                self._remove_temp_file(tmp_path)

    def create_interface(self):
        """Build and return the Gradio Blocks interface (not yet launched)."""
        with gr.Blocks(title="Biomass Prediction Model") as interface:
            gr.Markdown("# Above-Ground Biomass Prediction")

            # gr.Warning only produces a UI alert inside an event handler, so
            # a load failure is shown as a static banner instead.
            if self.model is None:
                gr.Markdown("⚠️ Model failed to load. The app may not work correctly. Check logs for details.")

            gr.Markdown("""
            Upload a multi-band satellite image to predict above-ground biomass (AGB) across the landscape.

            **Requirements:**
            - Image must be a GeoTIFF with spectral bands
            - For best results, use imagery with at least 59 bands or similar to training data
            """)

            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.File(
                        label="Upload Satellite Image (GeoTIFF)",
                        file_types=[".tif", ".tiff"]
                    )
                    display_type = gr.Radio(
                        choices=["heatmap", "rgb_overlay"],
                        value="heatmap",
                        label="Display Type"
                    )
                    with gr.Row():
                        submit_btn = gr.Button("Generate Biomass Prediction", variant="primary")
                        sample_btn = gr.Button("Use Sample Image")
                with gr.Column(scale=2):
                    output_image = gr.Image(
                        label="Biomass Prediction Map",
                        type="pil"
                    )
                    output_stats = gr.Markdown(
                        label="Statistics"
                    )

            with gr.Accordion("About", open=False):
                gr.Markdown("""
                ## About This Model

                This biomass prediction model uses the StableResNet architecture to predict above-ground biomass from satellite imagery.

                ### Model Details
                - Architecture: StableResNet
                - Input: Multi-spectral satellite imagery
                - Output: Above-ground biomass (Mg/ha)
                - Creator: vertify.earth for GIZ Forest Forward
                - Date: 2025-05-19

                ### How It Works
                1. The model extracts features from each pixel in the satellite image
                2. These features include spectral bands, vegetation indices, texture metrics, and more
                3. The model outputs a biomass prediction for each pixel
                4. Results are visualized as a heatmap or RGB overlay

                ### Updates in This Version
                - Fixed biomass value calculation issue (improved log transform handling)
                - Added detailed diagnostics for troubleshooting
                - Enhanced RGB visualization with band verification
                """)

            submit_btn.click(
                fn=self.predict_biomass,
                inputs=[input_image, display_type],
                outputs=[output_image, output_stats]
            )

            def use_sample_image(display_type):
                """Run a prediction on the bundled sample raster."""
                sample_path = self.load_sample_image()
                if sample_path is None:
                    return None, "Error: Sample image not found. Please make sure 'input_chip_1.tif' exists in the app directory."
                return self.predict_biomass(sample_path, display_type)

            sample_btn.click(
                fn=use_sample_image,
                inputs=[display_type],
                outputs=[output_image, output_stats]
            )
        return interface
def launch_app():
    """Build the predictor app and start its Gradio server.

    Any exception raised while constructing the app, building the interface
    or launching it is logged together with its traceback and not re-raised.
    """
    import traceback

    try:
        # Important: no share=True when running inside Hugging Face Spaces.
        BiomassPredictorApp().create_interface().launch()
    except Exception as exc:
        logger.error(f"Error launching app: {exc}")
        logger.error(traceback.format_exc())
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    launch_app()