# biomass-model / inference.py
# Uploaded via huggingface_hub by pokkiri (commit 519a27e, verified)
"""
Inference code for StableResNet Biomass Prediction Model
Provides utility functions for making predictions with the model
Author: najahpokkiri
Date: 2025-05-17
"""
import os
import torch
import numpy as np
import joblib
from model import StableResNet
from huggingface_hub import hf_hub_download
def load_model_from_hub(repo_id="najahpokkiri/biomass-model"):
    """Download model weights and metadata from the HuggingFace Hub and load them.

    Parameters
    ----------
    repo_id : str
        HuggingFace repository id containing ``model.pt`` and
        ``model_package.pkl``.

    Returns
    -------
    tuple
        ``(model, package)`` — the StableResNet in eval mode and the
        metadata dict (scaler, feature count, transform flags).
    """
    # Fetch the two artifacts (cached locally by huggingface_hub).
    model_path = hf_hub_download(repo_id=repo_id, filename="model.pt")
    package_path = hf_hub_download(repo_id=repo_id, filename="model_package.pkl")
    # Delegate to the local loader so both entry points share one code path.
    return load_model_local(model_path, package_path)
def load_model_local(model_path, package_path):
    """Load a StableResNet model and its metadata package from local files.

    Parameters
    ----------
    model_path : str
        Path to the saved ``state_dict`` (``model.pt``).
    package_path : str
        Path to the joblib metadata package (``model_package.pkl``).

    Returns
    -------
    tuple
        ``(model, package)`` — model in eval mode, metadata dict.
    """
    # The package carries the feature count needed to size the network.
    meta = joblib.load(package_path)
    net = StableResNet(n_features=meta['n_features'])
    # map_location='cpu' keeps loading device-independent.
    state = torch.load(model_path, map_location='cpu')
    net.load_state_dict(state)
    net.eval()
    return net, meta
def predict_biomass(model, features, package):
    """Predict biomass (Mg/ha) from raw feature vectors.

    Parameters
    ----------
    model : torch.nn.Module
        Trained StableResNet (already in eval mode).
    features : array-like, shape (n_samples, n_features) or (n_features,)
        Unscaled feature values. A single 1-D sample is accepted and
        treated as one row.
    package : dict
        Metadata with keys ``'scaler'`` (fitted sklearn-style scaler),
        ``'use_log_transform'`` (bool) and optional ``'epsilon'``
        (offset used in the training-time log transform, default 1.0).

    Returns
    -------
    np.ndarray
        Model predictions; non-negative when the model was trained on
        log-transformed targets.
    """
    # Generalization: accept a single 1-D sample as well as a 2-D batch.
    features = np.atleast_2d(np.asarray(features))
    scaler = package['scaler']
    use_log_transform = package['use_log_transform']
    epsilon = package.get('epsilon', 1.0)  # offset from log(y + epsilon) at train time
    # Scale with the same scaler fitted during training.
    features_scaled = scaler.transform(features)
    tensor = torch.tensor(features_scaled, dtype=torch.float32)
    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        output = model(tensor).numpy()
    if use_log_transform:
        # Invert the training-time transform log(y + epsilon).
        output = np.exp(output) - epsilon
        # Clamp tiny negatives introduced by the -epsilon shift.
        output = np.maximum(output, 0)
    return output
def predict_from_geotiff(tiff_path, output_path=None, model=None, package=None, repo_id="najahpokkiri/biomass-model"):
    """Predict per-pixel biomass from a multi-band GeoTIFF.

    Parameters
    ----------
    tiff_path : str
        Input GeoTIFF; each band is one model feature.
    output_path : str, optional
        If given, the prediction raster is written here (float32,
        single band, nodata=0) with the input's georeferencing.
    model, package : optional
        Pre-loaded model and metadata; downloaded from ``repo_id``
        when either is ``None``.
    repo_id : str
        HuggingFace repo used when model/package are not supplied.

    Returns
    -------
    np.ndarray
        (height, width) float32 array of biomass predictions; pixels
        with any non-finite band value remain 0.

    Raises
    ------
    ImportError
        If rasterio is not installed.
    """
    try:
        import rasterio
    except ImportError:
        raise ImportError("rasterio is required for GeoTIFF processing. Install with 'pip install rasterio'.")
    # Lazy-load the model so callers can pass in a shared instance.
    if model is None or package is None:
        model, package = load_model_from_hub(repo_id)
    with rasterio.open(tiff_path) as src:
        data = src.read()  # shape: (bands, height, width)
        height, width = data.shape[1], data.shape[2]
        # Process in tiles to bound the per-batch memory footprint.
        chunk_size = 1000
        predictions = np.zeros((height, width), dtype=np.float32)
        # Only predict where every band is finite.
        valid_mask = np.all(np.isfinite(data), axis=0)
        for y_start in range(0, height, chunk_size):
            y_end = min(y_start + chunk_size, height)
            for x_start in range(0, width, chunk_size):
                x_end = min(x_start + chunk_size, width)
                chunk_mask = valid_mask[y_start:y_end, x_start:x_end]
                if not np.any(chunk_mask):
                    continue
                valid_y, valid_x = np.where(chunk_mask)
                # Vectorized gather: (n_valid, n_bands) in one fancy-index
                # instead of a per-pixel Python loop.
                chunk = data[:, y_start:y_end, x_start:x_end]
                pixel_features = chunk[:, valid_y, valid_x].T
                batch_predictions = predict_biomass(model, pixel_features, package)
                # Vectorized scatter back into the output raster;
                # ravel() tolerates an (n, 1) model output as well as (n,).
                predictions[y_start + valid_y, x_start + valid_x] = np.ravel(batch_predictions)
        # Optionally persist with the source's georeferencing metadata.
        if output_path:
            meta = src.meta.copy()
            meta.update(
                dtype='float32',
                count=1,
                nodata=0
            )
            with rasterio.open(output_path, 'w', **meta) as dst:
                dst.write(predictions, 1)
            print(f"Saved biomass predictions to: {output_path}")
        return predictions
def example():
    """Demonstrate loading the model and running sample predictions."""
    print("StableResNet Biomass Prediction Example")
    print("-" * 40)

    # Option 1: pull the published model from the Hub.
    print("Loading model from HuggingFace Hub...")
    model, package = load_model_from_hub("najahpokkiri/biomass-model")
    # Option 2: load from local files instead:
    # model, package = load_model_local("model.pt", "model_package.pkl")

    n_features = package['n_features']
    print(f"Model loaded. Expecting {n_features} features")

    # Synthetic features just to exercise the prediction path.
    sample_features = np.random.rand(5, n_features)
    print("\nPredicting biomass for 5 sample points...")
    results = predict_biomass(model, sample_features, package)
    for idx, value in enumerate(results, start=1):
        print(f"Sample {idx}: {value:.2f} Mg/ha")

    print("\nTo process a GeoTIFF file:")
    print("predictions = predict_from_geotiff('your_image.tif', 'output_biomass.tif')")
# Run the demonstration when executed as a script.
if __name__ == "__main__":
    example()