# Exported file metadata (not code): 5,671 bytes, commit 519a27e.
"""
Inference code for StableResNet Biomass Prediction Model
Provides utility functions for making predictions with the model
Author: najahpokkiri
Date: 2025-05-17
"""
import os
import torch
import numpy as np
import joblib
from model import StableResNet
from huggingface_hub import hf_hub_download
def load_model_from_hub(repo_id="najahpokkiri/biomass-model"):
    """Load the StableResNet model and its metadata package from the HuggingFace Hub.

    Args:
        repo_id: HuggingFace repository id containing ``model.pt`` and
            ``model_package.pkl``.

    Returns:
        (model, package): the model in eval mode, and the metadata dict
        (scaler, n_features, log-transform settings, ...).
    """
    # Download artifacts (huggingface_hub caches them locally).
    model_path = hf_hub_download(repo_id=repo_id, filename="model.pt")
    package_path = hf_hub_download(repo_id=repo_id, filename="model_package.pkl")
    # Delegate to the local loader so the hub and local paths share one
    # code path instead of duplicating the load/eval sequence.
    return load_model_local(model_path, package_path)
def load_model_local(model_path, package_path):
    """Load a StableResNet model and its metadata package from local files.

    Args:
        model_path: path to the saved state dict (``model.pt``).
        package_path: path to the joblib metadata package (``model_package.pkl``).

    Returns:
        (model, package): the model in eval mode, and the metadata dict.
    """
    # The metadata package carries the scaler and the feature count
    # that the network was trained with.
    package = joblib.load(package_path)

    # Build the network sized to the training-time feature count, restore
    # weights on CPU, and switch to inference mode.
    model = StableResNet(n_features=package['n_features'])
    state = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state)
    model.eval()
    return model, package
def predict_biomass(model, features, package):
    """Predict biomass for a batch of feature rows.

    Args:
        model: trained StableResNet (or any callable torch module) in eval mode.
        features: 2-D array-like of shape (n_samples, n_features), unscaled.
        package: metadata dict with 'scaler', 'use_log_transform' and
            optionally 'epsilon' (defaults to 1.0).

    Returns:
        numpy array of biomass predictions (non-negative when the
        log-transform path is active).
    """
    scaler = package['scaler']
    log_scale = package['use_log_transform']
    eps = package.get('epsilon', 1.0)

    # Apply the training-time feature scaling, then hand the batch to
    # the network as a float32 tensor.
    scaled = scaler.transform(features)
    inputs = torch.tensor(scaled, dtype=torch.float32)
    with torch.no_grad():
        preds = model(inputs).numpy()

    # Targets were trained as log(biomass + epsilon): invert that here
    # and clamp at zero so no negative biomass is reported.
    if log_scale:
        preds = np.maximum(np.exp(preds) - eps, 0)
    return preds
def predict_from_geotiff(tiff_path, output_path=None, model=None, package=None, repo_id="najahpokkiri/biomass-model"):
    """Predict biomass for every valid pixel of a multi-band GeoTIFF.

    Bands are used as features, so the band count must match the model's
    ``n_features`` — TODO confirm against the training pipeline. Pixels
    with any non-finite band value are skipped and left at 0.

    Args:
        tiff_path: path to the input GeoTIFF.
        output_path: optional path; when given, a single-band float32
            GeoTIFF of predictions (nodata=0) is written there.
        model, package: preloaded model and metadata dict; downloaded
            from the Hub when either is None.
        repo_id: HuggingFace repo used when the model must be downloaded.

    Returns:
        2-D numpy float32 array (height, width) of biomass predictions.

    Raises:
        ImportError: if rasterio is not installed.
    """
    try:
        import rasterio
    except ImportError:
        raise ImportError("rasterio is required for GeoTIFF processing. Install with 'pip install rasterio'.")

    # Load model lazily if the caller did not supply one.
    if model is None or package is None:
        model, package = load_model_from_hub(repo_id)

    with rasterio.open(tiff_path) as src:
        data = src.read()  # shape: (bands, height, width)
        height, width = data.shape[1], data.shape[2]

        # Process in tiles to bound peak memory on large rasters.
        chunk_size = 1000
        predictions = np.zeros((height, width), dtype=np.float32)

        # A pixel is usable only when every band value is finite.
        valid_mask = np.all(np.isfinite(data), axis=0)

        for y_start in range(0, height, chunk_size):
            y_end = min(y_start + chunk_size, height)
            for x_start in range(0, width, chunk_size):
                x_end = min(x_start + chunk_size, width)

                chunk_mask = valid_mask[y_start:y_end, x_start:x_end]
                if not np.any(chunk_mask):
                    continue

                # Vectorized feature extraction: mask (bands, ny, nx) down
                # to (n_valid, bands) in row-major pixel order — replaces
                # the former per-pixel Python loop.
                chunk = data[:, y_start:y_end, x_start:x_end]
                pixel_features = chunk[:, chunk_mask].T

                batch_predictions = predict_biomass(model, pixel_features, package)

                # Scatter results back through the same boolean mask; the
                # slice of `predictions` is a view, so this writes in place.
                predictions[y_start:y_end, x_start:x_end][chunk_mask] = \
                    np.ravel(batch_predictions)

        # Optionally persist a single-band GeoTIFF alongside the input's
        # georeferencing metadata.
        if output_path:
            meta = src.meta.copy()
            meta.update(
                dtype='float32',
                count=1,
                nodata=0
            )
            with rasterio.open(output_path, 'w', **meta) as dst:
                dst.write(predictions, 1)
            print(f"Saved biomass predictions to: {output_path}")

        return predictions
def example():
    """Demonstrate loading the model and predicting on synthetic features."""
    print("StableResNet Biomass Prediction Example")
    print("-" * 40)

    # Option 1: fetch the published model from the HuggingFace Hub.
    print("Loading model from HuggingFace Hub...")
    model, package = load_model_from_hub("najahpokkiri/biomass-model")

    # Option 2: load from local files instead.
    # model, package = load_model_local("model.pt", "model_package.pkl")

    print(f"Model loaded. Expecting {package['n_features']} features")

    # Random features, purely for demonstration — not real measurements.
    feats = np.random.rand(5, package['n_features'])
    print("\nPredicting biomass for 5 sample points...")
    results = predict_biomass(model, feats, package)
    for idx, value in enumerate(results, start=1):
        print(f"Sample {idx}: {value:.2f} Mg/ha")

    print("\nTo process a GeoTIFF file:")
    print("predictions = predict_from_geotiff('your_image.tif', 'output_biomass.tif')")
if __name__ == "__main__":
    # Run the demo only when executed as a script (stray trailing
    # artifact character removed — it was a syntax error).
    example()