# Source: gpt-3d repository, file gpt3d/model.py (uploaded by Peeble).
# Commit 000b667 "Create gpt3d/model.py" (586 bytes).
import torch
import torch.nn as nn
class GPT3D(nn.Module):
    """Transformer-encoder sequence model mapping per-step features to per-step outputs.

    Each sequence element is linearly embedded to ``hidden`` channels, passed
    through a stack of ``nn.TransformerEncoderLayer``s, and linearly projected
    back out. With the defaults, inputs and outputs carry 3 features per step
    (presumably xyz coordinates -- confirm against the data pipeline).

    NOTE(review): no positional information is added to the embeddings, so
    without an attention mask the encoder cannot distinguish step order.
    For GPT-style autoregressive use pass ``causal=True`` to ``forward``;
    otherwise every position attends to every other position.

    Args:
        hidden: model width (``d_model``); must be divisible by ``heads``.
        layers: number of stacked encoder layers.
        heads: attention heads per layer.
        in_dim: features per input step (default 3, the original hard-coded value).
        out_dim: features per output step; ``None`` means "same as ``in_dim``".
        dropout: dropout inside each encoder layer (0.1 matches PyTorch's default,
            which the original implicitly used).
    """

    def __init__(
        self,
        hidden: int = 256,
        layers: int = 6,
        heads: int = 8,
        in_dim: int = 3,
        out_dim: "int | None" = None,
        dropout: float = 0.1,
    ):
        super().__init__()
        if out_dim is None:
            out_dim = in_dim
        self.embed = nn.Linear(in_dim, hidden)
        enc = nn.TransformerEncoderLayer(
            d_model=hidden,
            nhead=heads,
            dim_feedforward=hidden * 4,
            dropout=dropout,
            batch_first=True,  # tensors are laid out (batch, seq, feature)
        )
        # nn.TransformerEncoder clones ``enc`` per layer; layers do not share weights.
        self.transformer = nn.TransformerEncoder(enc, num_layers=layers)
        self.out = nn.Linear(hidden, out_dim)

    def forward(
        self,
        x: torch.Tensor,
        *,
        mask: "torch.Tensor | None" = None,
        src_key_padding_mask: "torch.Tensor | None" = None,
        causal: bool = False,
    ) -> torch.Tensor:
        """Map ``x`` of shape (batch, seq, in_dim) to (batch, seq, out_dim).

        Args:
            x: input sequence batch (``batch_first`` layout).
            mask: optional (seq, seq) attention mask forwarded unchanged to
                ``nn.TransformerEncoder``.
            src_key_padding_mask: optional (batch, seq) mask forwarded unchanged
                to ``nn.TransformerEncoder``; ``True`` marks padded steps.
            causal: if True, build a causal mask so step ``t`` only attends to
                steps ``<= t``. Mutually exclusive with ``mask``.

        Returns:
            Per-step outputs with ``out_dim`` features.

        Raises:
            ValueError: if ``mask`` is given together with ``causal=True``.
        """
        if causal:
            if mask is not None:
                raise ValueError("pass either `mask` or `causal=True`, not both")
            seq_len = x.size(-2)
            # 0 on/below the diagonal, -inf above it: after softmax the weights
            # on future steps are exactly zero.
            mask = torch.triu(
                torch.full(
                    (seq_len, seq_len), float("-inf"), device=x.device, dtype=x.dtype
                ),
                diagonal=1,
            )
        x = self.embed(x)
        x = self.transformer(x, mask=mask, src_key_padding_mask=src_key_padding_mask)
        return self.out(x)