from typing import Sequence, Type

import numpy as np
import torch.nn as nn
|
|
def mlp(
    layer_sizes: Sequence[int],
    activation: Type[nn.Module],
    output_activation: Type[nn.Module] = nn.Identity,
    init_layers_orthogonal: bool = False,
    final_layer_gain: float = np.sqrt(2),
) -> nn.Module:
    """Build a feedforward network from `layer_sizes`.

    Each hidden Linear layer is followed by an `activation` module; the final
    Linear layer uses `final_layer_gain` for its (optional) orthogonal
    initialization and is followed by `output_activation`.
    """
    layers = []
    # Hidden layers: every consecutive size pair except the last, each
    # followed by the activation module.
    for i in range(len(layer_sizes) - 2):
        layers.append(
            layer_init(
                nn.Linear(layer_sizes[i], layer_sizes[i + 1]), init_layers_orthogonal
            )
        )
        layers.append(activation())
    # Output layer: initialized with its own gain so that, e.g., policy heads
    # can use a smaller std than the hidden layers.
    layers.append(
        layer_init(
            nn.Linear(layer_sizes[-2], layer_sizes[-1]),
            init_layers_orthogonal,
            std=final_layer_gain,
        )
    )
    layers.append(output_activation())
    return nn.Sequential(*layers)
|
|
def layer_init(
    layer: nn.Linear, init_layers_orthogonal: bool, std: float = np.sqrt(2)
) -> nn.Module:
    """Orthogonally initialize `layer`'s weight with gain `std` and zero its
    bias when `init_layers_orthogonal` is set; otherwise keep PyTorch's
    default initialization."""
    if not init_layers_orthogonal:
        return layer
    nn.init.orthogonal_(layer.weight, std)
    nn.init.constant_(layer.bias, 0.0)
    return layer
|
|
|
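# Minimal usage sketch: builds a small MLP and runs a forward pass. The layer
# sizes, Tanh activation, and batch shape below are illustrative choices, not
# part of the original module.
if __name__ == "__main__":
    import torch

    # 4 inputs -> two 64-unit hidden layers with Tanh -> 2 outputs.
    net = mlp([4, 64, 64, 2], nn.Tanh, init_layers_orthogonal=True)
    print(net)
    out = net(torch.zeros(1, 4))
    print(out.shape)  # expected: torch.Size([1, 2])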