Spaces:
Running
Running
"""
StableResNet Model for Biomass Prediction

A numerically stable ResNet architecture for regression tasks.

This module defines the StableResNet architecture used for biomass prediction.
The model is designed for numerical stability with batch normalization and
residual connections.

Author: najahpokkiri
Date: 2025-05-17
"""
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class StableResNet(nn.Module):
    """Numerically stable ResNet for biomass regression.

    The network is a stack of fully connected layers:

    * an input projection (Linear -> LayerNorm -> ReLU -> Dropout) to 256 units,
    * one 256 -> 256 block whose output is added to its input (skip connection),
    * two narrowing blocks (256 -> 128 -> 64) that have no skip connection,
    * a regression head (64 -> 32 -> 1).

    Args:
        n_features: Size of the last dimension of the input tensor.
        dropout: Dropout probability used after the input projection.
    """

    def __init__(self, n_features: int, dropout: float = 0.2) -> None:
        super().__init__()

        self.input_proj = nn.Sequential(
            nn.Linear(n_features, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        self.layer1 = self._make_simple_resblock(256, 256)
        self.layer2 = self._make_simple_resblock(256, 128)
        self.layer3 = self._make_simple_resblock(128, 64)

        self.regressor = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

        self._init_weights()

    def _make_simple_resblock(self, in_dim: int, out_dim: int) -> nn.Sequential:
        """Build the layer stack for one block.

        When ``in_dim == out_dim`` the stack is two Linear -> BatchNorm1d ->
        ReLU stages, so its output can be added to its input by the caller.
        Otherwise it is a single Linear -> BatchNorm1d -> ReLU stage that
        changes the width from ``in_dim`` to ``out_dim``.

        The skip connection itself is NOT part of the returned module; it is
        applied in ``forward``.
        """
        if in_dim == out_dim:
            # Residual block (the addition happens in forward()).
            return nn.Sequential(
                nn.Linear(in_dim, out_dim),
                nn.BatchNorm1d(out_dim),
                nn.ReLU(),
                nn.Linear(out_dim, out_dim),
                nn.BatchNorm1d(out_dim),
                nn.ReLU(),
            )

        # Downsampling block.
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.ReLU(),
        )

    def _init_weights(self) -> None:
        """Apply Kaiming-normal init to every Linear weight and zero its bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x``.

        Args:
            x: Input tensor whose last dimension is ``n_features``; the
                BatchNorm1d layers expect a leading batch dimension, i.e.
                shape ``(batch, n_features)``.

        Returns:
            Tensor of predictions with the trailing size-1 output dimension
            removed, i.e. shape ``(batch,)``.
        """
        x = self.input_proj(x)

        # First block is the only one with a skip connection.
        identity = x
        out = self.layer1(x)
        x = out + identity

        # Remaining blocks only narrow the representation.
        x = self.layer2(x)
        x = self.layer3(x)

        # Regression output, shape (batch, 1).
        x = self.regressor(x)

        # Drop only the output-feature axis. A bare squeeze() would also
        # remove the batch axis when batch == 1 and return a 0-d tensor.
        return x.squeeze(-1)
def get_model_info():
    """Describe the StableResNet model as a plain dictionary.

    Returns:
        A dict with the keys ``'name'``, ``'description'``, ``'parameters'``
        (constructor argument name -> human-readable explanation) and
        ``'architecture'`` (ordered list of the main building blocks).
    """
    # Constructor arguments and what they mean.
    param_help = {}
    param_help['n_features'] = 'Number of input features'
    param_help['dropout'] = 'Dropout rate (default: 0.2)'

    # High-level stages, in the order they are listed to the user.
    stages = [
        'Input projection with layer normalization',
        'Residual blocks with batch normalization',
        'Downsampling blocks',
        'Regression head',
    ]

    info = dict(
        name='StableResNet',
        description='Numerically stable ResNet for biomass regression',
        parameters=param_help,
        architecture=stages,
    )
    return info