File size: 712 Bytes
a1e9f64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar  8 08:08:03 2024

@author: peter
"""

import torch

class FactorizedMatrixMultiplication(torch.nn.Module):
    
    def __init__(self,size):
        super(FactorizedMatrixMultiplication,self).__init__()
        self.left = torch.nn.parameter.Parameter(torch.empty((size,8)))
        self.right = torch.nn.parameter.Parameter(torch.empty((8,size)))
        sigma = (3.0/(4.0*size))**0.25
        torch.nn.init.normal_(self.left,0.0,sigma)
        torch.nn.init.normal_(self.right,0.0,sigma)
        self.matrix = torch.tensordot(self.left,self.right,1)
        
    def forward(self,X):
        return torch.einsum('ij,klj->kli',self.matrix,X)