from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F


@dataclass
class ReLUState:
    """Per-device state for StatefulReLU: a reusable zeros buffer."""

    zeros: torch.Tensor


class StatefulReLU(nn.Module):
    """ReLU whose zeros buffer lives in an externally managed ReLUState."""

    # Capability flags: the layer is safe under torch.compile and supports a
    # backward pass (torch.maximum is differentiable).
    can_torch_compile = True
    has_backward = True

    # Expected to be set on the module before create_state is called; it
    # determines the size of the per-device zeros buffer.
    hidden_size: int

    @staticmethod
    def create_state(device: torch.device, layer: nn.Module) -> ReLUState:
        # Allocate the zeros buffer once on the target device so forward
        # calls can reuse it instead of materializing a new tensor each time.
        zeros = torch.zeros(layer.hidden_size, device=device)
        return ReLUState(zeros=zeros)

    def forward_with_state(self, state: ReLUState, input: torch.Tensor) -> torch.Tensor:
        # Elementwise max against the cached zeros buffer, i.e. ReLU(input).
        return torch.maximum(input, state.zeros)


__all__ = ["StatefulReLU"]
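

# Minimal usage sketch, assuming the caller attaches `hidden_size` to the
# layer before building state (create_state reads `layer.hidden_size` to
# size the zeros buffer).
if __name__ == "__main__":
    layer = StatefulReLU()
    layer.hidden_size = 8  # hypothetical size; normally supplied by the host layer

    # Build the reusable state once per device, then thread it through calls.
    state = StatefulReLU.create_state(torch.device("cpu"), layer)
    x = torch.randn(2, 8)
    y = layer.forward_with_state(state, x)

    # Sanity check: the state-based forward matches the functional ReLU.
    assert torch.equal(y, F.relu(x))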