Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| Created on Fri Mar 8 08:08:03 2024 | |
| @author: peter | |
| """ | |
| import torch | |
class FactorizedMatrixMultiplication(torch.nn.Module):
    """Learnable ``size x size`` linear map stored as a low-rank factorization.

    The matrix is represented as ``left @ right`` with ``left`` of shape
    ``(size, rank)`` and ``right`` of shape ``(rank, size)``. Its rank is
    therefore at most ``rank``, and it has ``2 * size * rank`` trainable
    parameters.

    Args:
        size: Dimension of the square matrix. The last axis of inputs to
            :meth:`forward` must have this length.
        rank: Inner dimension of the factorization (default 8).

    Raises:
        ValueError: If ``size`` or ``rank`` is less than 1.
    """

    def __init__(self, size, rank=8):
        super().__init__()
        if size < 1:
            raise ValueError(f"size must be a positive integer, got {size!r}")
        if rank < 1:
            raise ValueError(f"rank must be a positive integer, got {rank!r}")
        self.left = torch.nn.Parameter(torch.empty((size, rank)))
        self.right = torch.nn.Parameter(torch.empty((rank, size)))
        # Both factors are drawn i.i.d. from N(0, sigma**2). Each entry of the
        # initial product then has variance rank * sigma**4 = 3 * rank / (4 * size).
        # sigma intentionally does not depend on ``rank``, matching the
        # original fixed-rank-8 initialisation.
        sigma = (3.0 / (4.0 * size)) ** 0.25
        torch.nn.init.normal_(self.left, 0.0, sigma)
        torch.nn.init.normal_(self.right, 0.0, sigma)

    @property
    def matrix(self):
        """Current ``(size, size)`` product ``left @ right``.

        Recomputed on every access rather than cached in ``__init__``, so it:
        always reflects the latest parameter values; follows the factors
        across ``.to()`` device/dtype moves; and gets a fresh autograd graph
        for each forward pass.
        """
        return self.left @ self.right

    def forward(self, X):
        """Apply the matrix to the last axis of ``X``.

        Computes ``out[..., i] = sum_j matrix[i, j] * X[..., j]``, i.e.
        ``X @ matrix.T``. Any number of leading dimensions is accepted; the
        3-D case ``(k, l, size) -> (k, l, size)`` behaves as before.
        """
        return torch.einsum('ij,...j->...i', self.matrix, X)