|
|
""" |
|
|
Inference code for StableResNet Biomass Prediction Model |
|
|
Provides utility functions for making predictions with the model |
|
|
|
|
|
Author: najahpokkiri |
|
|
Date: 2025-05-17 |
|
|
""" |
|
|
import os |
|
|
import torch |
|
|
import numpy as np |
|
|
import joblib |
|
|
from model import StableResNet |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
def load_model_from_hub(repo_id="najahpokkiri/biomass-model", revision=None):
    """Download the model files from the HuggingFace Hub and load them.

    Args:
        repo_id: Hub repository that contains ``model.pt`` (the network's
            state dict) and ``model_package.pkl`` (the preprocessing package).
        revision: Optional branch name, tag or commit hash passed to
            ``hf_hub_download``. Pinning a commit hash guarantees exactly which
            files are loaded; ``None`` keeps the Hub's default revision.

    Returns:
        Tuple ``(model, package)`` exactly as produced by ``load_model_local``.

    Note:
        Both files are deserialized with pickle-based loaders (``joblib.load``
        and ``torch.load``), so only load repositories you trust.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename="model.pt", revision=revision)
    package_path = hf_hub_download(repo_id=repo_id, filename="model_package.pkl", revision=revision)

    # The downloaded files are ordinary local paths from here on; reuse the
    # local loader so the construction/loading steps live in one place.
    return load_model_local(model_path, package_path)
|
|
|
|
|
def load_model_local(model_path, package_path):
    """Build a StableResNet from a weights file and a package file on disk.

    Args:
        model_path: Path to the file holding the network's ``state_dict``;
            tensors are mapped to CPU while loading.
        package_path: Path to the joblib-serialized package dict; its
            ``'n_features'`` entry sizes the network's input layer.

    Returns:
        Tuple ``(model, package)``: the network switched to eval mode, and the
        package dict as loaded.

    Note:
        ``joblib.load`` and ``torch.load`` unpickle their input; only use files
        from a trusted source.
    """
    package = joblib.load(package_path)
    net = StableResNet(n_features=package['n_features'])

    weights = torch.load(model_path, map_location='cpu')
    net.load_state_dict(weights)
    net.eval()

    return net, package
|
|
|
|
|
def _model_device(model):
    """Return the device of *model*'s first parameter, or CPU if it has none."""
    parameters = getattr(model, 'parameters', None)
    if parameters is None:
        # Plain callables (not nn.Module) are assumed to accept CPU tensors.
        return torch.device('cpu')
    first = next(iter(parameters()), None)
    return first.device if first is not None else torch.device('cpu')


def predict_biomass(model, features, package):
    """Predict biomass for one or more feature vectors.

    Args:
        model: Callable network (normally the ``StableResNet`` returned by the
            loaders in this module) mapping a float32 tensor of scaled
            features to predictions. It may live on any device.
        features: Array-like of shape ``(n_samples, n_features)``. A single
            1-D feature vector is also accepted and treated as one sample.
        package: Dict holding ``'scaler'`` (an object with ``transform``),
            ``'use_log_transform'`` (bool) and, optionally, ``'epsilon'``
            (default 1.0).

    Returns:
        ``np.ndarray`` of non-negative predictions, one entry per sample. A
        model output of shape ``(n_samples, 1)`` is flattened to
        ``(n_samples,)``; an empty input yields an empty array.
    """
    scaler = package['scaler']
    use_log_transform = package['use_log_transform']
    epsilon = package.get('epsilon', 1.0)

    # The scaler expects 2-D input; promote a single sample to one row.
    features = np.atleast_2d(np.asarray(features))
    if features.shape[0] == 0:
        # Scalers such as scikit-learn's reject zero-row input; nothing to do.
        return np.zeros(0, dtype=np.float32)

    features_scaled = scaler.transform(features)

    # Build the input on the model's own device so a GPU model works too.
    tensor = torch.as_tensor(features_scaled, dtype=torch.float32, device=_model_device(model))

    with torch.no_grad():
        output = model(tensor).cpu().numpy()

    # Callers treat each prediction as a scalar: drop a trailing singleton dim.
    if output.ndim == 2 and output.shape[1] == 1:
        output = output[:, 0]

    if use_log_transform:
        # Inverse of y -> log(y + epsilon); presumably the target transform
        # used in training (NOTE(review): confirm against the training code).
        output = np.exp(output) - epsilon

    # Biomass cannot be negative; this also absorbs the -epsilon offset above.
    return np.maximum(output, 0)
|
|
|
|
|
def predict_from_geotiff(tiff_path, output_path=None, model=None, package=None,
                         repo_id="najahpokkiri/biomass-model", chunk_size=1000):
    """Predict biomass for every pixel of a multi-band GeoTIFF.

    Each pixel's band values form one feature vector. A pixel is skipped (left
    at 0 in the result) when any band is non-finite there, or when any band
    equals that band's declared nodata value.

    Args:
        tiff_path: Path to the input raster; its band count must match the
            model's ``n_features``.
        output_path: If given, predictions are also written there as a
            single-band float32 raster using the input's metadata and
            ``nodata=0``.
        model: Loaded model; if it or ``package`` is ``None``, both are loaded
            from the Hub via ``repo_id``.
        package: Package dict matching ``model``.
        repo_id: HuggingFace Hub repository used when loading is needed.
        chunk_size: Side length, in pixels, of the square tiles sent to the
            model in one call; bounds the size of each prediction batch.

    Returns:
        ``np.ndarray`` of shape ``(height, width)``, dtype float32.

    Raises:
        ImportError: If ``rasterio`` is not installed.
        ValueError: If ``chunk_size`` is not positive, or the band count
            differs from ``package['n_features']``.
    """
    try:
        import rasterio
    except ImportError as err:
        raise ImportError("rasterio is required for GeoTIFF processing. Install with 'pip install rasterio'.") from err

    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    if model is None or package is None:
        model, package = load_model_from_hub(repo_id)

    with rasterio.open(tiff_path) as src:
        data = src.read()
        nodatavals = src.nodatavals
        # Copy metadata while the dataset is still open; it is needed for the
        # output file after the input has been closed.
        meta = src.meta.copy()

    n_bands, height, width = data.shape
    expected = package.get('n_features')
    if expected is not None and n_bands != expected:
        raise ValueError(
            f"Raster has {n_bands} bands but the model expects {expected} features"
        )

    # Valid pixels: finite in every band and not equal to any band's nodata value.
    valid_mask = np.all(np.isfinite(data), axis=0)
    for band, nodata in enumerate(nodatavals):
        if nodata is not None and not np.isnan(nodata):
            valid_mask &= data[band] != nodata

    predictions = np.zeros((height, width), dtype=np.float32)

    for y_start in range(0, height, chunk_size):
        y_end = min(y_start + chunk_size, height)
        for x_start in range(0, width, chunk_size):
            x_end = min(x_start + chunk_size, width)

            chunk_mask = valid_mask[y_start:y_end, x_start:x_end]
            if not chunk_mask.any():
                continue

            # (bands, h, w)[:, mask] -> (bands, n_valid); transpose to one row
            # per pixel. Boolean indexing visits pixels in row-major order.
            pixel_features = data[:, y_start:y_end, x_start:x_end][:, chunk_mask].T
            batch_predictions = predict_biomass(model, pixel_features, package)

            # Basic slicing yields a view, so this writes into `predictions`
            # in the same row-major order; ravel tolerates (n, 1) outputs.
            predictions[y_start:y_end, x_start:x_end][chunk_mask] = np.ravel(batch_predictions)

    if output_path:
        # NOTE(review): nodata=0 also masks valid pixels whose prediction was
        # clipped to 0; consider a distinct nodata value if that matters.
        meta.update(
            dtype='float32',
            count=1,
            nodata=0
        )
        with rasterio.open(output_path, 'w', **meta) as dst:
            dst.write(predictions, 1)
        print(f"Saved biomass predictions to: {output_path}")

    return predictions
|
|
|
|
|
def example():
    """Demonstrate the API: load the Hub model and predict on synthetic inputs.

    The five feature vectors are drawn uniformly from [0, 1) with a fixed
    seed. They are placeholders rather than realistic inputs, so the printed
    values only show that the pipeline runs end to end.
    """
    print("StableResNet Biomass Prediction Example")
    print("-" * 40)

    print("Loading model from HuggingFace Hub...")
    model, package = load_model_from_hub("najahpokkiri/biomass-model")

    n_features = package['n_features']
    print(f"Model loaded. Expecting {n_features} features")

    # Seeded so repeated runs print the same numbers.
    rng = np.random.default_rng(0)
    example_features = rng.random((5, n_features))

    print("\nPredicting biomass for 5 sample points...")
    predictions = predict_biomass(model, example_features, package)

    # ravel so each item is a scalar even if predictions come back as (n, 1);
    # formatting a 1-element array with ':.2f' raises TypeError.
    for i, pred in enumerate(np.ravel(predictions), start=1):
        print(f"Sample {i}: {pred:.2f} Mg/ha")

    print("\nTo process a GeoTIFF file:")
    print("predictions = predict_from_geotiff('your_image.tif', 'output_biomass.tif')")
|
|
|
|
|
if __name__ == "__main__":
    # Only run the demo when executed as a script, not when imported.
    example()