PeteBleackley commited on
Commit
a1e9f64
·
1 Parent(s): 776f717

Factorized the weight matrix in the GlobalAttentionPoolingHead, thus reducing the number of parameters in this layer by a factor of 48

Browse files
qarac/models/layers/FactorizedMatrixMultiplication.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Fri Mar 8 08:08:03 2024
5
+
6
+ @author: peter
7
+ """
8
+
9
+ import torch
10
+
11
+ class FactorizedMatrixMultiplication(torch.nn.Module):
12
+
13
+ def __init__(self,size):
14
+ super(FactorizedMatrixMultiplication,self).__init__()
15
+ self.left = torch.nn.parameter.Parameter(torch.empty((size,8)))
16
+ self.right = torch.nn.parameter.Parameter(torch.empty((8,size)))
17
+ sigma = (3.0/(4.0*size))**0.25
18
+ torch.nn.init.normal_(self.left,0.0,sigma)
19
+ torch.nn.init.normal_(self.right,0.0,sigma)
20
+ self.matrix = torch.tensordot(self.left,self.right,1)
21
+
22
+ def forward(self,X):
23
+ return torch.einsum('ij,klj->kli',self.matrix,X)
qarac/models/layers/GlobalAttentionPoolingHead.py CHANGED
@@ -7,6 +7,7 @@ Created on Tue Sep 5 07:32:55 2023
7
  """
8
 
9
  import torch
 
10
 
11
 
12
  class GlobalAttentionPoolingHead(torch.nn.Module):
@@ -26,8 +27,8 @@ class GlobalAttentionPoolingHead(torch.nn.Module):
26
  """
27
  size = config.hidden_size
28
  super(GlobalAttentionPoolingHead,self).__init__()
29
- self.global_projection = torch.nn.Linear(size,size,bias=False)
30
- self.local_projection = torch.nn.Linear(size,size,bias=False)
31
  self.cosine = torch.nn.CosineSimilarity(dim=2,eps=1.0e-12)
32
 
33
 
 
7
  """
8
 
9
  import torch
10
+ import FactorizedMatrixMultiplication
11
 
12
 
13
  class GlobalAttentionPoolingHead(torch.nn.Module):
 
27
  """
28
  size = config.hidden_size
29
  super(GlobalAttentionPoolingHead,self).__init__()
30
+ self.global_projection = FactorizedMatrixMultiplication.FactorizedMatrixMultiplication(size)
31
+ self.local_projection = FactorizedMatrixMultiplication.FactorizedMatrixMultiplication(size)
32
  self.cosine = torch.nn.CosineSimilarity(dim=2,eps=1.0e-12)
33
 
34