from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class ReLUState:
    """Per-device state for StatefulReLU: a zero tensor used as the ReLU threshold."""

    zeros: torch.Tensor


class StatefulReLU(nn.Module):
    # Capability flags: the layer works under torch.compile and has a backward pass.
    can_torch_compile = True
    has_backward = True

    # Attribute expected on the layer passed to `create_state`.
    hidden_size: int

    @staticmethod
    def create_state(device: torch.device, layer: nn.Module) -> ReLUState:
        """Allocate the zero threshold tensor on the target device."""
        zeros = torch.zeros(layer.hidden_size, device=device)
        return ReLUState(zeros=zeros)

    def forward_with_state(self, state: ReLUState, input: torch.Tensor) -> torch.Tensor:
        # Elementwise maximum against zeros is equivalent to ReLU.
        return torch.maximum(input, state.zeros)


__all__ = ["StatefulReLU"]
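

# Minimal usage sketch (illustrative only, not part of the layer itself; it
# assumes a host layer carrying a `hidden_size` attribute, which
# `create_state` reads to size the state tensor):
if __name__ == "__main__":
    demo = StatefulReLU()
    demo.hidden_size = 4  # hypothetical size, chosen for this demo
    state = StatefulReLU.create_state(torch.device("cpu"), demo)
    x = torch.randn(2, 4)
    out = demo.forward_with_state(state, x)
    assert torch.equal(out, torch.relu(x))  # matches a plain ReLU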