File size: 5,671 Bytes
519a27e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"""
Inference code for StableResNet Biomass Prediction Model
Provides utility functions for making predictions with the model

Author: najahpokkiri
Date: 2025-05-17
"""
import os
import torch
import numpy as np
import joblib
from model import StableResNet
from huggingface_hub import hf_hub_download

def load_model_from_hub(repo_id="najahpokkiri/biomass-model"):
    """Download the model weights and metadata package from the HuggingFace Hub.

    Args:
        repo_id: Hub repository expected to contain 'model.pt' (state_dict)
            and 'model_package.pkl' (joblib metadata with 'n_features' etc.).

    Returns:
        Tuple of (StableResNet model in eval mode, metadata package dict).
    """
    # Fetch the two artifacts (huggingface_hub caches them locally).
    model_path = hf_hub_download(repo_id=repo_id, filename="model.pt")
    package_path = hf_hub_download(repo_id=repo_id, filename="model_package.pkl")

    # Delegate to the local loader so Hub and local paths share one code path
    # (previously this function duplicated load_model_local line-for-line).
    return load_model_local(model_path, package_path)

def load_model_local(model_path, package_path):
    """Load a StableResNet model and its metadata from local files.

    Args:
        model_path: Path to the saved state_dict file (.pt).
        package_path: Path to the joblib metadata package (.pkl); must
            contain at least 'n_features'.

    Returns:
        Tuple of (model in eval mode, metadata package dict).
    """
    # The metadata package travels with the weights: feature count,
    # fitted scaler, and transform flags used at training time.
    package = joblib.load(package_path)

    # Rebuild the architecture from the stored feature count, then
    # restore the trained weights onto CPU.
    model = StableResNet(n_features=package['n_features'])
    state = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state)
    model.eval()

    return model, package

def predict_biomass(model, features, package):
    """Predict biomass values for an array of raw feature vectors.

    Args:
        model: Trained model (callable mapping a float32 tensor of scaled
            features to a prediction tensor).
        features: 2-D array-like of raw (unscaled) feature rows.
        package: Metadata dict with the fitted 'scaler', the
            'use_log_transform' flag, and optionally 'epsilon'
            (default 1.0) used in the training-time log transform.

    Returns:
        numpy array of biomass predictions; non-negative when the
        log transform is inverted.
    """
    use_log = package['use_log_transform']
    eps = package.get('epsilon', 1.0)

    # Apply the same scaling the model saw during training.
    scaled = package['scaler'].transform(features)
    inputs = torch.as_tensor(scaled, dtype=torch.float32)

    # Inference only — no gradients needed.
    with torch.no_grad():
        preds = model(inputs).numpy()

    if not use_log:
        return preds

    # Invert the log(y + epsilon) transform and clip negatives to zero.
    return np.maximum(np.exp(preds) - eps, 0)

def predict_from_geotiff(tiff_path, output_path=None, model=None, package=None, repo_id="najahpokkiri/biomass-model", chunk_size=1000):
    """Predict biomass for every valid pixel of a multi-band GeoTIFF.

    Each raster band is treated as one input feature. Pixels containing
    any non-finite band value are skipped and left at 0 in the output.

    Args:
        tiff_path: Path to the input GeoTIFF whose bands are model features.
        output_path: Optional path; when given, predictions are written as a
            single-band float32 GeoTIFF.
        model: Preloaded model; downloaded from the Hub when None.
        package: Preloaded metadata package; downloaded when None.
        repo_id: HuggingFace repo used when model/package are not supplied.
        chunk_size: Side length (pixels) of the square tiles processed per
            batch; lower it to bound peak memory. Defaults to the previous
            hard-coded value of 1000.

    Returns:
        (height, width) float32 numpy array of biomass predictions.

    Raises:
        ImportError: If rasterio is not installed.
    """
    try:
        import rasterio
    except ImportError:
        raise ImportError("rasterio is required for GeoTIFF processing. Install with 'pip install rasterio'.")

    # Load model lazily so callers can reuse a preloaded one.
    if model is None or package is None:
        model, package = load_model_from_hub(repo_id)

    with rasterio.open(tiff_path) as src:
        data = src.read()  # shape: (bands, height, width)
        height, width = data.shape[1], data.shape[2]

        predictions = np.zeros((height, width), dtype=np.float32)

        # A pixel is usable only if every band holds a finite value.
        valid_mask = np.all(np.isfinite(data), axis=0)

        # Tile the raster so each feature batch stays bounded in memory.
        for y_start in range(0, height, chunk_size):
            y_end = min(y_start + chunk_size, height)

            for x_start in range(0, width, chunk_size):
                x_end = min(x_start + chunk_size, width)

                chunk_mask = valid_mask[y_start:y_end, x_start:x_end]
                if not np.any(chunk_mask):
                    continue

                # Absolute image coordinates of the valid pixels in this tile.
                valid_y, valid_x = np.where(chunk_mask)
                abs_y = valid_y + y_start
                abs_x = valid_x + x_start

                # Vectorized gather: (n_valid, n_bands) feature matrix.
                # Replaces the original per-pixel Python loop, which was
                # an accidental O(pixels) interpreter-speed bottleneck.
                pixel_features = data[:, abs_y, abs_x].T

                batch_predictions = predict_biomass(model, pixel_features, package)
                # Flatten in case the model emits (n, 1) instead of (n,),
                # then scatter results back in one vectorized assignment.
                predictions[abs_y, abs_x] = np.asarray(batch_predictions).reshape(-1)

        # Save predictions if an output path is provided.
        if output_path:
            meta = src.meta.copy()
            # NOTE(review): nodata=0 makes genuine zero-biomass pixels
            # indistinguishable from masked ones — confirm acceptable.
            meta.update(
                dtype='float32',
                count=1,
                nodata=0
            )

            with rasterio.open(output_path, 'w', **meta) as dst:
                dst.write(predictions, 1)

            print(f"Saved biomass predictions to: {output_path}")

        return predictions

def example():
    """Demonstrate loading the model and predicting on synthetic features."""
    print("StableResNet Biomass Prediction Example")
    print("-" * 40)

    # Option 1: fetch weights and metadata from the HuggingFace Hub.
    print("Loading model from HuggingFace Hub...")
    model, package = load_model_from_hub("najahpokkiri/biomass-model")

    # Option 2: load from local files instead.
    # model, package = load_model_local("model.pt", "model_package.pkl")

    n_features = package['n_features']
    print(f"Model loaded. Expecting {n_features} features")

    # Synthetic feature rows, just to exercise the prediction path.
    demo_inputs = np.random.rand(5, n_features)

    print("\nPredicting biomass for 5 sample points...")
    results = predict_biomass(model, demo_inputs, package)

    for idx, value in enumerate(results, start=1):
        print(f"Sample {idx}: {value:.2f} Mg/ha")

    print("\nTo process a GeoTIFF file:")
    print("predictions = predict_from_geotiff('your_image.tif', 'output_biomass.tif')")

# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    example()