| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from typing import Optional, Tuple
|
| from Model.prelude_Block import PreludeBlock
|
| from Model.recurrent_Block import RecurrentBlock
|
| from Model.codaBlock import CodaBlock
|
|
|
|
|
class LatentRecurrentDepthLM(nn.Module):
    """Language model built from a prelude, a reusable recurrent core, and a coda.

    The prelude is run once on the input, the recurrent block is then applied
    a caller-chosen number of times to the resulting hidden tensor, and the
    coda turns the final hidden tensor into the model output.
    """

    def __init__(self, vocab_size: int, d_model: int, num_heads: int, dropout: float = 0.1):
        """Build the three sub-blocks.

        Args:
            vocab_size: Forwarded to ``PreludeBlock`` and ``CodaBlock``.
            d_model: Forwarded to all three blocks.
            num_heads: Forwarded to ``PreludeBlock`` and ``RecurrentBlock``.
            dropout: Forwarded to ``PreludeBlock`` and ``RecurrentBlock``.
        """
        super().__init__()

        # Construction order is kept as-is: it determines the order in which
        # submodules (and their parameters) are registered on this module.
        self.prelude = PreludeBlock(
            vocab_size,
            d_model,
            num_heads,
            dropout,
        )
        self.recurrent = RecurrentBlock(
            d_model,
            num_heads,
            dropout,
        )
        self.coda = CodaBlock(
            d_model,
            vocab_size,
        )

    def forward(self, x: torch.Tensor, num_iterations: int, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the prelude, iterate the recurrent block, then apply the coda.

        Args:
            x: Input tensor handed to the prelude (presumably token ids --
                TODO confirm against ``PreludeBlock``).
            num_iterations: How many times the recurrent block is applied.
                Zero (or a negative value) skips the recurrent block, so the
                coda receives the prelude output directly.
            mask: Optional tensor passed unchanged to the prelude and to
                every recurrent step.

        Returns:
            Whatever ``self.coda`` returns for the final hidden tensor.
        """
        latent = self.prelude(x, mask)

        # The recurrent state starts at zero, matching the prelude output in
        # shape, dtype and device.
        state = torch.zeros_like(latent)

        # The same recurrent block (same weights) is reused on every step;
        # each step consumes and replaces both the hidden tensor and the state.
        for _step in range(num_iterations):
            latent, state = self.recurrent(latent, state, mask)

        output = self.coda(latent)
        return output