"""Inference code for the StableResNet biomass prediction model.

Provides utility functions for loading the model (from the HuggingFace Hub
or from local files) and for making predictions with it, either on feature
arrays or on every pixel of a GeoTIFF raster.

Author: najahpokkiri
Date: 2025-05-17
"""

import os

import joblib
import numpy as np
import torch
from huggingface_hub import hf_hub_download

from model import StableResNet

# HuggingFace Hub repository that hosts ``model.pt`` and ``model_package.pkl``.
DEFAULT_REPO_ID = "najahpokkiri/biomass-model"


def load_model_from_hub(repo_id=DEFAULT_REPO_ID):
    """Load the model and its metadata package from the HuggingFace Hub.

    Downloads ``model.pt`` and ``model_package.pkl`` from ``repo_id`` and
    builds the model from them via :func:`load_model_local`.

    Args:
        repo_id: HuggingFace Hub repository id.

    Returns:
        ``(model, package)`` exactly as returned by :func:`load_model_local`.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename="model.pt")
    package_path = hf_hub_download(repo_id=repo_id, filename="model_package.pkl")
    return load_model_local(model_path, package_path)


def load_model_local(model_path, package_path):
    """Load the model and its metadata package from local files.

    Args:
        model_path: Path to the saved model ``state_dict`` (``model.pt``).
        package_path: Path to the joblib metadata package. It must contain
            ``'n_features'``; :func:`predict_biomass` additionally reads
            ``'scaler'``, ``'use_log_transform'`` and, optionally,
            ``'epsilon'``.

    Returns:
        ``(model, package)``: a ``StableResNet`` with the saved weights, put
        in eval mode, and the loaded metadata package.
    """
    # NOTE(review): joblib.load and torch.load both unpickle their input, which
    # can execute arbitrary code. Only load files from sources you trust.
    package = joblib.load(package_path)

    model = StableResNet(n_features=package['n_features'])
    # map_location='cpu' lets weights saved from a GPU load on CPU-only hosts.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()
    return model, package


def predict_biomass(model, features, package):
    """Predict biomass for a batch of feature vectors.

    Args:
        model: A loaded ``StableResNet`` (see :func:`load_model_local`).
        features: 2-D array-like of raw (unscaled) features, one row per
            sample. Column order presumably must match the order used in
            training -- TODO confirm.
        package: Metadata package with ``'scaler'`` (an object with a
            ``transform`` method), ``'use_log_transform'`` (bool) and an
            optional ``'epsilon'`` (default 1.0).

    Returns:
        NumPy array of model outputs, one entry per sample. When
        ``use_log_transform`` is set, outputs are mapped through
        ``exp(y) - epsilon`` and clipped at 0.
    """
    scaler = package['scaler']
    use_log_transform = package['use_log_transform']
    epsilon = package.get('epsilon', 1.0)

    features_scaled = scaler.transform(features)
    tensor = torch.tensor(features_scaled, dtype=torch.float32)

    with torch.no_grad():
        output = model(tensor).numpy()

    if use_log_transform:
        # Presumably undoes a training-time log(y + epsilon) target transform.
        output = np.exp(output) - epsilon
        # exp(y) - epsilon is negative whenever y < log(epsilon); clip so the
        # returned biomass is never negative.
        output = np.maximum(output, 0)
    return output


def predict_from_geotiff(tiff_path, output_path=None, model=None, package=None,
                         repo_id=DEFAULT_REPO_ID, chunk_size=1000):
    """Predict biomass for every pixel of a multi-band GeoTIFF.

    Every band of the raster is used as one feature, so the raster is
    expected to have ``package['n_features']`` bands in the training
    feature order (TODO confirm the expected band order).

    Args:
        tiff_path: Path of the input raster.
        output_path: If given, the predictions are also written there as a
            single-band float32 raster with ``nodata=0``.
        model, package: A pre-loaded model and metadata package. If either
            is ``None``, both are loaded from ``repo_id``.
        repo_id: HuggingFace Hub repository used when loading the model.
        chunk_size: Side length, in pixels, of the square tiles predicted
            one batch at a time.

    Returns:
        float32 array of shape ``(height, width)``. Pixels that are not
        valid (any band non-finite or equal to the source's nodata value)
        are left at 0.

    Raises:
        ImportError: If ``rasterio`` is not installed.
    """
    try:
        import rasterio
    except ImportError as err:
        raise ImportError(
            "rasterio is required for GeoTIFF processing. "
            "Install with 'pip install rasterio'."
        ) from err

    if model is None or package is None:
        model, package = load_model_from_hub(repo_id)

    with rasterio.open(tiff_path) as src:
        data = src.read()  # (bands, height, width)
        nodata = src.nodata
        # Copy the metadata while the dataset is still open; it seeds the
        # profile of the output raster below.
        meta = src.meta.copy()

    _, height, width = data.shape
    predictions = np.zeros((height, width), dtype=np.float32)

    # A pixel is valid only if every band holds a finite value ...
    valid_mask = np.all(np.isfinite(data), axis=0)
    # ... that is not the source's declared fill value (NaN is already
    # excluded by the isfinite check above).
    if nodata is not None and not np.isnan(nodata):
        valid_mask &= np.all(data != nodata, axis=0)

    for y_start in range(0, height, chunk_size):
        y_end = min(y_start + chunk_size, height)
        for x_start in range(0, width, chunk_size):
            x_end = min(x_start + chunk_size, width)

            chunk_mask = valid_mask[y_start:y_end, x_start:x_end]
            if not chunk_mask.any():
                continue

            # Boolean-mask the two spatial axes: (bands, h, w) -> (bands, n)
            # in row-major pixel order, then transpose to (n, bands).
            pixel_features = data[:, y_start:y_end, x_start:x_end][:, chunk_mask].T
            batch_predictions = predict_biomass(model, pixel_features, package)

            # The basic slice is a view of `predictions`, so the masked
            # assignment writes straight into it. reshape(-1) accepts model
            # outputs of shape (n,) or (n, 1).
            predictions[y_start:y_end, x_start:x_end][chunk_mask] = np.reshape(
                batch_predictions, -1
            )

    if output_path:
        # NOTE(review): a genuine prediction of exactly 0 (possible after the
        # clip in predict_biomass) is indistinguishable from nodata here.
        meta.update(dtype='float32', count=1, nodata=0)
        with rasterio.open(output_path, 'w', **meta) as dst:
            dst.write(predictions, 1)
        print(f"Saved biomass predictions to: {output_path}")

    return predictions


def example():
    """Demonstrate model loading and prediction on random synthetic features."""
    print("StableResNet Biomass Prediction Example")
    print("-" * 40)

    # Option 1: Load from HuggingFace Hub
    print("Loading model from HuggingFace Hub...")
    model, package = load_model_from_hub(DEFAULT_REPO_ID)

    # Option 2: Load from local files
    # model, package = load_model_local("model.pt", "model_package.pkl")

    n_features = package['n_features']
    print(f"Model loaded. Expecting {n_features} features")

    # Random features for demonstration only; the outputs are not meaningful.
    example_features = np.random.rand(5, n_features)

    print("\nPredicting biomass for 5 sample points...")
    predictions = predict_biomass(model, example_features, package)

    # ravel() so formatting works whether the model returns shape (N,) or
    # (N, 1); a length-1 array would reject the :.2f format spec.
    for i, pred in enumerate(np.ravel(predictions), start=1):
        print(f"Sample {i}: {pred:.2f} Mg/ha")

    print("\nTo process a GeoTIFF file:")
    print("predictions = predict_from_geotiff('your_image.tif', 'output_biomass.tif')")


if __name__ == "__main__":
    example()