import torch
import torch.nn as nn
import numpy as np


class SineLayer(nn.Module):
    """Linear layer followed by a sine activation, as used in SIREN.

    Computes ``sin(omega_0 * linear(x))``.

    Args:
        in_features: Number of input features (size of the last input dim).
        out_features: Number of output features.
        bias: Whether the underlying ``nn.Linear`` has a bias term.
        is_first: Whether this is the first layer of the network; selects
            the first-layer weight initialization in ``init_weights``.
        omega_0: Frequency multiplier applied to the linear output before
            the sine.
    """

    def __init__(self, in_features, out_features, bias=True, is_first=False,
                 omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        """Re-draw ``self.linear.weight`` uniformly in place.

        - First layer: ``U(-1/in_features, 1/in_features)``.
        - Other layers: ``U(-b, b)`` with ``b = sqrt(6/in_features)/omega_0``.

        The bias is left at PyTorch's default ``nn.Linear`` initialization.
        """
        if self.is_first:
            bound = 1 / self.in_features
        else:
            # Divided by omega_0 because forward() multiplies the linear
            # output by omega_0 before the sine.
            bound = np.sqrt(6 / self.in_features) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, x):
        """Return ``sin(omega_0 * linear(x))``; values lie in [-1, 1]."""
        return torch.sin(self.omega_0 * self.linear(x))

    def extra_repr(self):
        # Surfaces the sine-specific settings in repr(model); the wrapped
        # nn.Linear already prints its own sizes.
        return f"omega_0={self.omega_0}, is_first={self.is_first}"


class SIREN(nn.Module):
    """SIREN network for implicit neural representations.

    Layer stack (wrapped in ``self.net``, an ``nn.Sequential``):
    one first ``SineLayer``, then ``hidden_layers`` hidden-to-hidden
    ``SineLayer``s, then either a plain ``nn.Linear`` (``outermost_linear``)
    or a final ``SineLayer``.

    Args:
        in_features: Number of input features (e.g. 2 for image coordinates).
        hidden_features: Number of hidden units in each layer.
        hidden_layers: Number of hidden-to-hidden layers after the first
            layer (0 is allowed).
        out_features: Number of output features (e.g. 3 for RGB).
        outermost_linear: If True the last layer is a linear map with no
            activation; otherwise it is a ``SineLayer``.
        first_omega_0: Frequency parameter for the first layer.
        hidden_omega_0: Frequency parameter for hidden layers; also scales
            the init bound of the final linear layer.
    """

    def __init__(self, in_features=2, hidden_features=256, hidden_layers=3,
                 out_features=3, outermost_linear=True, first_omega_0=30,
                 hidden_omega_0=30):
        super().__init__()
        # Build in a local list; assigning a plain list to self.net and then
        # rebinding it to a Module only works by accident of __setattr__.
        layers = [
            SineLayer(in_features, hidden_features, is_first=True,
                      omega_0=first_omega_0)
        ]
        layers.extend(
            SineLayer(hidden_features, hidden_features, is_first=False,
                      omega_0=hidden_omega_0)
            for _ in range(hidden_layers)
        )

        if outermost_linear:
            final_linear = nn.Linear(hidden_features, out_features)
            # Same bound as a hidden SineLayer's weights.
            bound = np.sqrt(6 / hidden_features) / hidden_omega_0
            with torch.no_grad():
                final_linear.weight.uniform_(-bound, bound)
            layers.append(final_linear)
        else:
            layers.append(
                SineLayer(hidden_features, out_features, is_first=False,
                          omega_0=hidden_omega_0)
            )

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Map input coordinates to outputs through the layer stack."""
        return self.net(coords)