PeteBleackley committed
Commit 798488e · 1 Parent(s): 1d1a876

Using torch.nn.CosineSimilarity to simplify code

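The commit replaces a hand-rolled clamp-and-normalise cosine (an EPSILON constant, torch.linalg.vector_norm, and an einsum dot product) with torch.nn.CosineSimilarity in three places. A minimal sketch of the equivalence it relies on, assuming 2-D (batch, width) sentence vectors with illustrative sizes; note that the module guards the denominator with max(||a||·||b||, eps) rather than clamping each norm separately, so the two agree except on near-zero vectors:

    import torch

    # EPSILON matches the constant removed by this commit.
    EPSILON = torch.tensor(1.0e-12)

    s0 = torch.randn(4, 768)   # stand-ins for encoder outputs (assumed shapes)
    s1 = torch.randn(4, 768)

    # Old pattern: clamp each norm at EPSILON, normalise, dot product via einsum.
    def manual_cosine(a, b):
        a_norm = torch.maximum(torch.linalg.vector_norm(a, dim=1, keepdim=True), EPSILON)
        b_norm = torch.maximum(torch.linalg.vector_norm(b, dim=1, keepdim=True), EPSILON)
        return torch.einsum('ij,ij->i', a / a_norm, b / b_norm)

    # New pattern: one module call. dim=1 here because these stand-ins are 2-D.
    cosine = torch.nn.CosineSimilarity(dim=1, eps=1.0e-12)

    assert torch.allclose(manual_cosine(s0, s1), cosine(s0, s1), atol=1e-6)
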
qarac/models/QaracTrainerModel.py CHANGED
@@ -10,7 +10,7 @@ import torch
 import qarac.models.QaracEncoderModel
 import qarac.models.QaracDecoderModel
 
-EPSILON=torch.tensor(1.0e-12)
+
 
 class QaracTrainerModel(torch.nn.Module):
 
@@ -39,6 +39,7 @@ class QaracTrainerModel(torch.nn.Module):
         self.decoder = qarac.models.QaracDecoderModel.QaracDecoderModel(base_model_path,
                                                                         config,
                                                                         tokenizer)
+        self.cosine = torch.nn.CosineSimilarity(dim=2,eps=1.0e-12)
 
     def forward(self,
                 all_text,
@@ -92,15 +93,9 @@ class QaracTrainerModel(torch.nn.Module):
         reasoning = self.decoder((self.answer_encoder(proposition0)
                                   +self.answer_encoder(proposition1),
                                   conclusion_offset))
-        s0vec = self.answer_encoder(statement0)
-        s0norm = torch.maximum(torch.linalg.vector_norm(s0vec,
-                                                        dim=1,
-                                                        keepdim=True),EPSILON)
-        s0 = s0vec/s0norm
-        s1vec = self.answer_encoder(statement1)
-        s1norm = torch.maximum(torch.linalg.vector_norm(s1vec,
-                                                        dim=1,
-                                                        keepdim=True),EPSILON)
-        s1 = s1vec/s1norm
-        consistency = torch.einsum('ij,ij->i',s0,s1)
+        s0 = self.answer_encoder(statement0)
+
+        s1 = self.answer_encoder(statement1)
+
+        consistency = self.cosine(s0,s1)
         return (encode_decode,question_answering,reasoning,consistency)
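One thing worth noting: the removed code normalised these encodings along dim=1 and took a per-row dot product with einsum('ij,ij->i'), which implies 2-D (batch, width) tensors, while the new module is constructed with dim=2, a dimension that only exists for 3-D inputs. A small sketch of the dim semantics, with illustrative shapes rather than the repository's real widths; the shape-printing debug line added in scripts.py below may be probing exactly this:

    import torch

    # nn.CosineSimilarity reduces over the dimension named by dim, which must
    # exist in (the broadcast of) both inputs.
    a = torch.randn(4, 768)   # stand-ins for 2-D (batch, width) encodings
    b = torch.randn(4, 768)

    cos1 = torch.nn.CosineSimilarity(dim=1, eps=1.0e-12)
    print(cos1(a, b).shape)   # torch.Size([4]): one score per pair

    cos2 = torch.nn.CosineSimilarity(dim=2, eps=1.0e-12)
    # cos2(a, b) raises IndexError on these 2-D tensors: there is no dim 2.
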
qarac/models/layers/GlobalAttentionPoolingHead.py CHANGED
@@ -8,7 +8,6 @@ Created on Tue Sep 5 07:32:55 2023
 
 import torch
 
-EPSILON = torch.tensor(1.0e-12)
 
 class GlobalAttentionPoolingHead(torch.nn.Module):
 
@@ -29,6 +28,7 @@ class GlobalAttentionPoolingHead(torch.nn.Module):
         super(GlobalAttentionPoolingHead,self).__init__()
         self.global_projection = torch.nn.Linear(size,size,bias=False)
         self.local_projection = torch.nn.Linear(size,size,bias=False)
+        self.cosine = torch.nn.CosineSimilarity(dim=2,eps=1.0e-12)
 
 
 
@@ -55,16 +55,9 @@ class GlobalAttentionPoolingHead(torch.nn.Module):
         else:
             attention_mask = attention_mask.unsqueeze(2)
         Xa = X*attention_mask
-        sigma = torch.sum(Xa,dim=1)
-        psigma = self.global_projection(sigma)
-        nsigma = torch.maximum(torch.linalg.vector_norm(psigma,
-                                                        dim=1,
-                                                        keepdim=True),EPSILON)
-        gp = psigma/nsigma
-        loc = self.local_projection(Xa)
-        nloc = torch.maximum(torch.linalg.vector_norm(loc,
-                                                      dim=2,
-                                                      keepdim=True),EPSILON)
-        lp = loc/nloc
-        attention = torch.einsum('ijk,ik->ij',lp,gp)
+        sigma = torch.sum(Xa,dim=1,keepdim=True)
+        gp = self.global_projection(sigma)
+        lp = self.local_projection(Xa)
+
+        attention = self.cosine(lp,gp)
         return torch.einsum('ij,ijk->ik',attention,Xa)
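Here dim=2 does line up with the shapes: keepdim=True leaves sigma's sequence axis at size 1, so the projected global vector gp broadcasts against the per-token projections lp, and the module reproduces the old normalise-then-einsum('ijk,ik->ij') attention. A minimal sketch with made-up sizes:

    import torch

    batch, seq, width = 2, 5, 8
    Xa = torch.randn(batch, seq, width)   # masked token embeddings

    global_projection = torch.nn.Linear(width, width, bias=False)
    local_projection = torch.nn.Linear(width, width, bias=False)
    cosine = torch.nn.CosineSimilarity(dim=2, eps=1.0e-12)

    sigma = torch.sum(Xa, dim=1, keepdim=True)   # (batch, 1, width)
    gp = global_projection(sigma)                # (batch, 1, width)
    lp = local_projection(Xa)                    # (batch, seq, width)

    # dim=2 reduces over the feature axis; gp's size-1 sequence axis
    # broadcasts against lp, giving one attention weight per token.
    attention = cosine(lp, gp)
    print(attention.shape)                       # torch.Size([2, 5])

    pooled = torch.einsum('ij,ijk->ik', attention, Xa)
    print(pooled.shape)                          # torch.Size([2, 8])
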
scripts.py CHANGED
@@ -149,6 +149,7 @@ def train_models(path):
                                    X['conclusion_offset'],
                                    X['statement0'],
                                    X['statement1'])
+        print([y.shape for y in prediction])
         loss = loss_fn(prediction,Y)
         loss.backward()
         optimizer.step()
@@ -411,16 +412,9 @@ def test_consistency(path):
     s1_attn = torch.not_equal(s1_in,
                               pad_token)
     s0_vec = encoder(s0_in,attention_mask=s0_attn)
-    s0_norm = torch.maximum(torch.linalg.vector_norm(s0_vec,
-                                                     dim=1,
-                                                     keepdim=True),EPSILON)
-    s0 = s0_vec/s0_norm
     s1_vec = encoder(s1_in,attention_mask=s1_attn)
-    s1_norm = torch.maximum(torch.linalg.vector_norm(s1_vec,
-                                                     dim=1,
-                                                     keepdim=True),EPSILON)
-    s1 = s1_vec/s1_norm
-    consistency = torch.einsum('ij,ij->i',s0,s1).numpy()
+    cosine = torch.nn.CosineSimilarity(dim=2,eps=1.0e-12)
+    consistency = cosine(s0_vec,s1_vec).numpy()
     results = pandas.DataFrame({'label':data['gold_label'],
                                 'score':consistency})
     third = 1.0/3.0
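For a one-off score like this, the functional form is an alternative to constructing a module. A sketch assuming 2-D (batch, width) encoder outputs, with random stand-in tensors; torch.nn.functional.cosine_similarity takes dim and eps directly, and .numpy() needs a CPU tensor outside the autograd graph:

    import torch

    # Stand-ins for the encoder outputs in test_consistency (assumed 2-D).
    s0_vec = torch.randn(16, 768)
    s1_vec = torch.randn(16, 768)

    with torch.no_grad():
        consistency = torch.nn.functional.cosine_similarity(
            s0_vec, s1_vec, dim=1, eps=1.0e-12).numpy()

    print(consistency.shape)  # (16,): one score per statement pair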