File size: 2,313 Bytes
a16f583
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from sklearn.ensemble import RandomForestRegressor
from config import bands_list_order, window_size  # For feature size estimation
import numpy as np

def count_rf_nodes(model):
    """Return the sum of ``tree_.node_count`` over every tree in a fitted forest.

    ``model`` must already be fitted so that ``model.estimators_`` exists; each
    estimator's ``tree_.node_count`` includes both split and leaf nodes.
    """
    return sum(tree.tree_.node_count for tree in model.estimators_)


def main(n_samples=100, seed=42):
    """Fit a RandomForestRegressor on synthetic data and print a size summary.

    The forest configuration is meant to mirror the training script. A random
    forest only has a node count after it has been fit, so the model is fit on
    random data with the same flattened input width as the real features
    (``len(bands_list_order) * window_size * window_size`` columns).

    Args:
        n_samples: Number of synthetic rows to fit on. Every leaf holds at
            least one training row, so each tree has at most
            ``2 * n_samples - 1`` nodes (and, with ``max_depth=10``, at most
            ``2**11 - 1``). Small values therefore under-estimate the size of a
            forest trained on real data.
        seed: Seed for the synthetic data, so repeated runs report the same
            node count.
    """
    # Model configuration matching your training script
    model = RandomForestRegressor(
        n_estimators=1000,   # Number of trees
        max_depth=10,        # Maximum depth of each tree
        n_jobs=-1,           # Use all available cores
        random_state=42,     # For reproducibility
    )

    # Flattened feature size: one value per band per pixel of the window.
    n_features = len(bands_list_order) * window_size * window_size

    # Seeded generator: the model's random_state alone does not make the run
    # reproducible if the training data itself changes between runs.
    rng = np.random.default_rng(seed)
    X_dummy = rng.random((n_samples, n_features))
    y_dummy = rng.random(n_samples)
    model.fit(X_dummy, y_dummy)

    # Count total nodes across all trees
    total_nodes = count_rf_nodes(model)

    # Hyperparameters as configured on this model; min_samples_split and
    # min_samples_leaf are not set above, so they report sklearn's defaults.
    hyperparameters = {
        "n_estimators": model.n_estimators,
        "max_depth": model.max_depth,
        "min_samples_split": model.min_samples_split,
        "min_samples_leaf": model.min_samples_leaf,
        "random_state": model.random_state,
    }

    # Print hyperparameter summary
    print("RandomForestRegressor Hyperparameters:")
    for key, value in hyperparameters.items():
        print(f"{key}: {value}")

    # Print feature size (input dimensionality)
    print(f"\nInput Features (Flattened): {n_features} (bands={len(bands_list_order)}, window_size={window_size})")

    # Print total number of nodes, and the sample count it depends on.
    print(f"\nTotal Number of Nodes Across All Trees: {total_nodes:,}")
    print(f"(Measured on {n_samples:,} synthetic samples; the count scales with "
          f"training-set size and will differ on real data.)")

    # Note on "parameters"
    print("\nNote: RandomForest doesn't have 'trainable parameters' like neural networks. "
          "Complexity is reflected in the number of trees, nodes, and feature size.")


if __name__ == "__main__":
    main()