import torch
import torch.nn as nn


class MultiTheologyModel(nn.Module):
    """Token-level LSTM model: embed token ids, run an LSTM, project to vocab scores.

    For every input position the model emits a score vector of length
    ``vocab_size``. No softmax is applied in ``forward``, so the output is raw
    (unnormalized) scores; apply a softmax / cross-entropy loss downstream.
    """

    def __init__(self, vocab_size, embed_dim=64, hidden_dim=128, num_layers=1):
        """Build the embedding, LSTM and output projection layers.

        Args:
            vocab_size: Number of distinct token ids. Sizes both the embedding
                table and the output projection.
            embed_dim: Length of each token embedding vector.
            hidden_dim: Size of the LSTM hidden state.
            num_layers: Number of stacked LSTM layers. Defaults to 1, which
                reproduces the original single-layer model (same parameter
                names in ``state_dict``).
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # batch_first=True: the LSTM consumes (batch, seq, embed_dim) tensors.
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Return per-position vocabulary scores for a batch of token ids.

        Args:
            x: Integer token-id tensor, expected as (batch, seq) given
                ``batch_first=True``. All ids must be in ``[0, vocab_size)``.

        Returns:
            Float tensor of shape (batch, seq, vocab_size) of raw scores.
        """
        x = self.embed(x)
        # Only the per-step outputs are used. The final (h, c) state is
        # discarded and none is passed in, so every call starts the LSTM
        # from its default zero initial state.
        out, _ = self.lstm(x)
        return self.fc(out)