# Source: modeling.py, commit d613335 ("Update modeling.py", verified), by Caplin43.
import torch
import torch.nn as nn
class HanHumanoidCNN(nn.Module):
    """Single-layer 1-D CNN: Conv1d -> ReLU -> average over length -> Linear.

    Args:
        config: Subscriptable mapping providing the integer entries
            ``"input_channels"``, ``"num_filters"``, ``"kernel_size"`` and
            ``"output_dim"``.

    Input:
        ``x`` of shape ``(batch, input_channels, length)`` or, unbatched,
        ``(input_channels, length)``. ``length`` must be at least
        ``kernel_size`` because the convolution uses the default
        ``padding=0`` and ``stride=1``.

    Output:
        Tensor of shape ``(batch, output_dim)``, or ``(output_dim,)`` for
        unbatched input.
    """

    def __init__(self, config):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=config["input_channels"],
            out_channels=config["num_filters"],
            kernel_size=config["kernel_size"],
        )
        self.relu = nn.ReLU()
        # After pooling over length, each sample is a vector of num_filters values.
        self.fc = nn.Linear(config["num_filters"], config["output_dim"])

    def forward(self, x):
        """Map ``x`` to ``output_dim`` features per sample (see class docstring)."""
        x = self.relu(self.conv(x))
        # Average over the length (last) axis. Using -1 instead of a hard-coded 2
        # is identical for batched (N, C, L) input and also supports the
        # unbatched (C, L) input that nn.Conv1d accepts, where dim=2 would be
        # out of range.
        x = x.mean(dim=-1)
        return self.fc(x)