# File size: 1,906 bytes (revision 519a27e)
"""
StableResNet Model for Biomass Prediction
A numerically stable ResNet architecture for regression tasks
Author: najahpokkiri
Date: 2025-05-17
"""
import torch
import torch.nn as nn
class StableResNet(nn.Module):
    """Numerically stable ResNet for biomass regression.

    Architecture (all layers fully connected):

    * ``input_proj``: Linear(n_features -> 256) -> LayerNorm -> ReLU -> Dropout
    * ``layer1``: 256 -> 256 block of two Linear/BatchNorm1d/ReLU stages,
      wrapped in an identity skip connection in :meth:`forward`
    * ``layer2``: 256 -> 128 single Linear/BatchNorm1d/ReLU stage (no skip)
    * ``layer3``: 128 -> 64 single Linear/BatchNorm1d/ReLU stage (no skip)
    * ``regressor``: Linear(64 -> 32) -> ReLU -> Linear(32 -> 1)

    Args:
        n_features: Number of input features per sample.
        dropout: Dropout probability applied after the input projection.

    Note:
        The ``BatchNorm1d`` layers need more than one sample per batch in
        training mode; single-sample inference works in ``eval()`` mode.
    """

    def __init__(self, n_features, dropout=0.2):
        super().__init__()
        # Project raw features into the 256-wide trunk. LayerNorm here
        # normalizes each sample on its own, independent of the batch.
        self.input_proj = nn.Sequential(
            nn.Linear(n_features, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.layer1 = self._make_simple_resblock(256, 256)
        self.layer2 = self._make_simple_resblock(256, 128)
        self.layer3 = self._make_simple_resblock(128, 64)
        self.regressor = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )
        self._init_weights()

    def _make_simple_resblock(self, in_dim, out_dim):
        """Build the body of one trunk stage mapping ``in_dim -> out_dim``.

        When the widths match, the stage has two Linear/BatchNorm1d/ReLU
        sub-stages. Otherwise it has a single sub-stage that changes the width.
        The returned ``nn.Sequential`` contains no skip connection itself;
        any residual addition happens in :meth:`forward`.

        Args:
            in_dim: Input feature width.
            out_dim: Output feature width.

        Returns:
            An ``nn.Sequential`` of Linear, BatchNorm1d and ReLU layers.
        """
        # Layer order fixes the Sequential indices and therefore the
        # state_dict keys (e.g. ``layer1.4.running_mean``). Keep it stable so
        # existing checkpoints keep loading.
        layers = [
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.ReLU(),
        ]
        if in_dim == out_dim:
            layers += [
                nn.Linear(out_dim, out_dim),
                nn.BatchNorm1d(out_dim),
                nn.ReLU(),
            ]
        return nn.Sequential(*layers)

    def _init_weights(self):
        """Re-initialize every ``nn.Linear``: Kaiming-normal weights, zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                # fan_in / relu gain, applied to every Linear layer including
                # the final 32 -> 1 output layer.
                nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x):
        """Predict one regression value per sample.

        Args:
            x: Float tensor of shape ``(batch, n_features)``.

        Returns:
            Tensor of shape ``(batch,)``. This holds even when ``batch == 1``.
        """
        x = self.input_proj(x)
        # Identity skip connection around the width-preserving first stage.
        x = x + self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.regressor(x)
        # Drop only the trailing size-1 output dimension. A bare ``squeeze()``
        # would also drop the batch dimension when batch == 1 and return a
        # 0-d tensor.
        return x.squeeze(-1)