# Complete SIREN super-resolution demo with improvements (commit 691ba3c)
import torch
import torch.nn as nn
import numpy as np
class SineLayer(nn.Module):
    """Sine activation layer for SIREN networks (Sitzmann et al., 2020).

    Computes ``sin(omega_0 * (W x + b))``. Uses the SIREN initialization
    scheme: the first layer draws weights from U(-1/fan_in, 1/fan_in),
    while subsequent layers draw from U(-sqrt(6/fan_in)/omega_0,
    sqrt(6/fan_in)/omega_0) so that pre-activations keep a well-behaved
    distribution through depth.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        bias: Whether the underlying linear layer uses a bias term.
        is_first: Whether this is the first layer (uses the wider
            first-layer initialization).
        omega_0: Frequency parameter that scales the pre-activation
            before the sine.
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        """Apply SIREN weight initialization in place (bias keeps the
        default nn.Linear init, as in the reference implementation)."""
        with torch.no_grad():
            if self.is_first:
                # First layer: U(-1/fan_in, 1/fan_in).
                bound = 1 / self.in_features
            else:
                # Hidden layers: U(-sqrt(6/fan_in)/omega_0, +...), which
                # compensates for the omega_0 scaling in forward().
                bound = np.sqrt(6 / self.in_features) / self.omega_0
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, x):
        """Return sin(omega_0 * linear(x))."""
        return torch.sin(self.omega_0 * self.linear(x))
class SIREN(nn.Module):
"""SIREN network for implicit neural representations.
Args:
in_features: Number of input features (2 for image coordinates)
hidden_features: Number of hidden units in each layer
hidden_layers: Number of hidden layers
out_features: Number of output features (3 for RGB)
outermost_linear: Whether to use linear activation in the last layer
first_omega_0: Frequency parameter for first layer
hidden_omega_0: Frequency parameter for hidden layers
"""
def __init__(self, in_features=2, hidden_features=256, hidden_layers=3,
out_features=3, outermost_linear=True,
first_omega_0=30, hidden_omega_0=30):
super().__init__()
self.net = []
self.net.append(SineLayer(in_features, hidden_features,
is_first=True, omega_0=first_omega_0))
for i in range(hidden_layers):
self.net.append(SineLayer(hidden_features, hidden_features,
is_first=False, omega_0=hidden_omega_0))
if outermost_linear:
final_linear = nn.Linear(hidden_features, out_features)
with torch.no_grad():
final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0,
np.sqrt(6 / hidden_features) / hidden_omega_0)
self.net.append(final_linear)
else:
self.net.append(SineLayer(hidden_features, out_features,
is_first=False, omega_0=hidden_omega_0))
self.net = nn.Sequential(*self.net)
def forward(self, coords):
output = self.net(coords)
return output