File size: 630 Bytes
912372b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

@dataclass
class ReLUState:
    """Per-layer state consumed by ``StatefulReLU.forward_with_state``.

    The state is a single tensor of zeros that is allocated once (see
    ``StatefulReLU.create_state``) and then reused as the lower bound of
    the element-wise maximum on every forward call.
    """

    # Zero tensor; ``StatefulReLU.create_state`` builds it 1-D with
    # ``layer.hidden_size`` elements on the requested device.
    zeros: torch.Tensor


class StatefulReLU(nn.Module):
    """ReLU computed as an element-wise maximum against a cached zero tensor.

    The zero tensor lives in a separate ``ReLUState`` object rather than on
    the module: ``create_state`` allocates it once, and
    ``forward_with_state`` receives it on every call.
    """

    # NOTE(review): presumably capability flags read by the framework that
    # hosts this layer (torch.compile support / backward support) — confirm
    # against the consumer; nothing in this module reads them.
    can_torch_compile = True
    has_backward = True

    # Annotation only: this class never assigns ``hidden_size``. The value is
    # read from the ``layer`` passed to ``create_state``.
    hidden_size: int

    @staticmethod
    def create_state(device: torch.device, layer: nn.Module) -> ReLUState:
        """Allocate the state used by ``forward_with_state``.

        Args:
            device: Device on which the zero tensor is allocated.
            layer: Module providing a ``hidden_size`` attribute; it sets the
                length of the 1-D zero tensor.

        Returns:
            A ``ReLUState`` wrapping ``torch.zeros(layer.hidden_size)`` on
            ``device`` (default dtype).

        Raises:
            AttributeError: If ``layer`` has no ``hidden_size`` attribute.
        """
        zeros = torch.zeros(layer.hidden_size, device=device)
        return ReLUState(zeros=zeros)

    def forward_with_state(self, state: ReLUState, input: torch.Tensor) -> torch.Tensor:
        """Apply ReLU to ``input`` using the zeros stored in ``state``.

        Args:
            state: State produced by ``create_state``.
            input: Tensor to rectify. It must be broadcastable against
                ``state.zeros`` (a 1-D tensor), as ``torch.maximum`` requires.

        Returns:
            ``max(input, 0)`` element-wise, in the dtype of ``input``.
        """
        # The cached zeros are created in the default dtype. Without this
        # cast, torch.maximum would type-promote e.g. float16/bfloat16 or
        # integer inputs and return a tensor of a different dtype than the
        # input. ``Tensor.to`` returns the same tensor when the dtype already
        # matches, so the common float32 path does no extra work.
        zeros = state.zeros.to(dtype=input.dtype)
        return torch.maximum(input, zeros)


# Public API of this module: only the layer class is exported by
# ``from <module> import *``; ``ReLUState`` is not listed here.
__all__ = ["StatefulReLU"]