varox34's picture
Upload 64 files
366b225 verified
raw
history blame contribute delete
651 Bytes
# -*- coding: utf-8 -*-
from parser.modules.dropout import SharedDropout
import torch.nn as nn
class MLP(nn.Module):
    """Single-layer perceptron block: linear projection, LeakyReLU, dropout.

    The input is passed through ``nn.Linear(n_in, n_hidden)``, then through
    ``nn.LeakyReLU`` with a negative slope of 0.1, and finally through the
    project-local ``SharedDropout`` layer.

    Args:
        n_in: Size of the last dimension of the input.
        n_hidden: Size of the last dimension produced by the linear layer.
        dropout: Value forwarded as ``p`` to ``SharedDropout``. Defaults to 0.
    """

    def __init__(self, n_in: int, n_hidden: int, dropout: float = 0) -> None:
        super().__init__()

        self.linear = nn.Linear(n_in, n_hidden)
        self.activation = nn.LeakyReLU(negative_slope=0.1)
        # NOTE(review): SharedDropout is defined in parser.modules.dropout,
        # which is not visible here; presumably it is shape-preserving and
        # inactive when p == 0 -- confirm against that module.
        self.dropout = SharedDropout(p=dropout)

        # Replace nn.Linear's default initialisation with the scheme below.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise the linear layer in place.

        The weight matrix receives an orthogonal initialisation and the bias
        is set to all zeros.
        """
        nn.init.orthogonal_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x):
        """Apply linear -> LeakyReLU -> dropout to ``x`` and return the result.

        Args:
            x: Tensor whose last dimension has size ``n_in`` (as required by
                ``nn.Linear``).

        Returns:
            The transformed tensor; its last dimension has size ``n_hidden``
            provided the dropout layer preserves shape.
        """
        x = self.linear(x)
        x = self.activation(x)
        x = self.dropout(x)

        return x