Delete latent_Recurrent.py
Browse files- latent_Recurrent.py +0 -22
latent_Recurrent.py
DELETED
|
@@ -1,22 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
import torch.nn.functional as F
|
| 4 |
-
from typing import Optional, Tuple
|
| 5 |
-
from prelude_Block import PreludeBlock
|
| 6 |
-
from recurrent_Block import RecurrentBlock
|
| 7 |
-
from codaBlock import CodaBlock
|
| 8 |
-
|
| 9 |
-
# Full Latent Recurrent Depth Model
|
| 10 |
-
class LatentRecurrentDepthLM(nn.Module):
|
| 11 |
-
def __init__(self, vocab_size: int, d_model: int, num_heads: int, dropout: float = 0.1):
|
| 12 |
-
super().__init__()
|
| 13 |
-
self.prelude = PreludeBlock(vocab_size, d_model, num_heads, dropout)
|
| 14 |
-
self.recurrent = RecurrentBlock(d_model, num_heads, dropout)
|
| 15 |
-
self.coda = CodaBlock(d_model, vocab_size)
|
| 16 |
-
|
| 17 |
-
def forward(self, x: torch.Tensor, num_iterations: int, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 18 |
-
hidden = self.prelude(x, mask)
|
| 19 |
-
recurrent_state = torch.zeros_like(hidden)
|
| 20 |
-
for _ in range(num_iterations):
|
| 21 |
-
hidden, recurrent_state = self.recurrent(hidden, recurrent_state, mask)
|
| 22 |
-
return self.coda(hidden)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|