|
|
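"""
Monkey-patch torch's GELU activation for compatibility.

Importing this module:

* replaces ``forward`` on torch's original ``nn.GELU`` class;
* rebinds ``nn.GELU`` to ``PatchedGELU``, whose instances fall back to
  ``approximate='none'`` when that attribute is missing;
* prints a confirmation message.
"""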
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
def patched_gelu_forward(self, input):
    """GELU forward that tolerates modules lacking an ``approximate`` attribute.

    Installed as ``nn.GELU.forward`` further down in this module.

    Args:
        self: the GELU module instance. If it has no ``approximate``
            attribute (presumably modules restored from checkpoints written
            by an older torch -- confirm), exact GELU is used.
        input: tensor to apply the activation to.

    Returns:
        ``F.gelu(input)`` when the mode is ``'none'``, otherwise
        ``F.gelu(input, approximate=<mode>)``. The previous version ignored
        the mode, so ``approximate='tanh'`` silently computed exact GELU.
    """
    approximate = getattr(self, "approximate", "none")
    if approximate == "none":
        # Omit the keyword so F.gelu builds without an ``approximate``
        # parameter keep working for the default mode.
        return F.gelu(input)
    return F.gelu(input, approximate=approximate)
|
|
|
|
|
|
|
|
# Keep a reference to torch's stock ``GELU.forward``, captured before the
# patch below is applied. Nothing in this module reads it; it allows the
# patch to be undone with ``nn.GELU.forward = _original_gelu_forward``.
_original_gelu_forward = nn.GELU.forward

# Replace ``forward`` on torch's own GELU class. Because this mutates the
# class object itself, it affects every instance of that class, including
# instances created before this import and references to the class held
# under other names. It must run before ``nn.GELU`` is rebound further down;
# otherwise it would patch the replacement class instead.
# NOTE(review): presumably this is what lets GELU modules unpickled from
# older checkpoints (no ``approximate`` attribute) run -- confirm.
nn.GELU.forward = patched_gelu_forward
|
|
|
|
|
|
|
|
class PatchedGELU(nn.Module):
    """Drop-in replacement for ``nn.GELU`` that tolerates a missing ``approximate``.

    Args:
        approximate: ``'none'`` for exact GELU, or ``'tanh'`` for the tanh
            approximation (forwarded to ``F.gelu``). The previous version
            discarded this argument and always computed exact GELU.
    """

    def __init__(self, approximate='none'):
        super().__init__()
        # A plain attribute, so it is stored in the instance ``__dict__``
        # like ``nn.GELU``'s own field.
        self.approximate = approximate

    def forward(self, input):
        """Apply GELU to *input* using this module's ``approximate`` mode."""
        approximate = self.approximate
        if approximate == 'none':
            # Omit the keyword so F.gelu builds without an ``approximate``
            # parameter keep working for the default mode.
            return F.gelu(input)
        return F.gelu(input, approximate=approximate)

    def extra_repr(self):
        """Show the mode in ``repr()``, mirroring ``nn.GELU``."""
        return f"approximate={self.approximate!r}"

    def __getattr__(self, name):
        """Fallback lookup, reached only when normal attribute lookup fails.

        Returns ``'none'`` for an instance without an ``approximate``
        attribute. Every other name is delegated to ``nn.Module.__getattr__``,
        so registered parameters, buffers and submodules stay reachable. The
        previous override raised for them unconditionally.
        """
        if name == 'approximate':
            return 'none'
        return super().__getattr__(name)
|
|
|
|
|
|
|
|
# Point the public ``torch.nn.GELU`` name at the replacement class, so that
# code constructing ``nn.GELU(...)`` after this import gets ``PatchedGELU``.
# Existing references to torch's original class are unaffected by the
# rebinding. ``nn`` is ``torch.nn``, so this is the same assignment as
# ``nn.GELU = PatchedGELU``.
torch.nn.GELU = PatchedGELU

# Import-time notice that the patch has been applied.
print("GELU patched for compatibility")