File size: 630 Bytes
912372b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
@dataclass
class ReLUState:
    """Container for the tensor state that accompanies a ``StatefulReLU`` layer.

    ``StatefulReLU.create_state`` builds an instance of this class and
    ``StatefulReLU.forward_with_state`` reads it back.
    """

    # Tensor used as the second operand of the element-wise maximum in the
    # layer's forward pass; ``create_state`` fills it with zeros.
    zeros: torch.Tensor
class StatefulReLU(nn.Module):
    """ReLU computed against a zeros tensor held in a separate state object.

    The state is not stored on the module: ``create_state`` builds it and
    the caller passes it back into ``forward_with_state`` on every call.
    """

    # Class-level capability flags.
    # NOTE(review): presumably read by the framework hosting this layer; their
    # consumers are not visible in this file -- confirm before changing.
    can_torch_compile = True
    has_backward = True

    # Declared but never assigned in this class. ``create_state`` reads it
    # from the layer it is given, so it must be set before state is created.
    hidden_size: int

    @staticmethod
    def create_state(device: torch.device, layer: nn.Module) -> "ReLUState":
        """Build the state for ``layer`` on ``device``.

        Args:
            device: Device on which the zeros tensor is allocated.
            layer: Module providing a ``hidden_size`` attribute, which sets
                the length of the 1-D zeros tensor.

        Returns:
            A ``ReLUState`` wrapping a zeros tensor of shape
            ``(layer.hidden_size,)`` in torch's default dtype.
        """
        zeros = torch.zeros(layer.hidden_size, device=device)
        return ReLUState(zeros=zeros)

    def forward_with_state(self, state: "ReLUState", input: torch.Tensor) -> torch.Tensor:
        """Return the element-wise maximum of ``input`` and the state's zeros.

        Args:
            state: State previously produced by ``create_state``.
            input: Tensor that must be broadcastable against ``state.zeros``.

        Returns:
            A tensor with the same dtype as ``input`` in which negative
            entries are replaced by zero.
        """
        # torch.maximum type-promotes its operands, so comparing a half or
        # integer input against the default-dtype zeros would silently change
        # the output dtype. Casting the zeros keeps the result in the input's
        # dtype; ``Tensor.to`` returns the same tensor when the dtype already
        # matches, so the common float32 path does no extra work.
        zeros = state.zeros.to(dtype=input.dtype)
        return torch.maximum(input, zeros)
# Names exported by a star-import of this module; only the layer class is
# listed, so ``ReLUState`` has to be imported by name.
__all__ = ["StatefulReLU"]
|