rrivera1849 committed on
Commit
09be69e
·
1 Parent(s): 353504a

Upload LUAR

Browse files
Files changed (1) hide show
  1. model.py +9 -5
model.py CHANGED
@@ -44,12 +44,12 @@ class LUAR(PreTrainedModel):
44
  def mean_pooling(self, token_embeddings, attention_mask):
45
  """Mean Pooling as described in the SBERT paper.
46
  """
47
- input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).float()
48
  sum_embeddings = reduce(token_embeddings * input_mask_expanded, 'b l d -> b d', 'sum')
49
  sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
50
  return sum_embeddings / sum_mask
51
 
52
- def get_episode_embeddings(self, input_ids, attention_mask):
53
  """Computes the Author Embedding.
54
  """
55
  B, E, _ = attention_mask.shape
@@ -61,7 +61,8 @@ class LUAR(PreTrainedModel):
61
  input_ids=input_ids,
62
  attention_mask=attention_mask,
63
  return_dict=True,
64
- output_hidden_states=True
 
65
  )
66
 
67
  # at this point, we're embedding individual "comments"
@@ -74,11 +75,14 @@ class LUAR(PreTrainedModel):
74
 
75
  episode_embeddings = self.linear(episode_embeddings)
76
 
 
 
 
77
  return episode_embeddings
78
 
79
- def forward(self, input_ids, attention_mask):
80
  """Calculates a fixed-length feature vector for a batch of episode samples.
81
  """
82
- output = self.get_episode_embeddings(input_ids, attention_mask)
83
 
84
  return output
 
44
  def mean_pooling(self, token_embeddings, attention_mask):
45
  """Mean Pooling as described in the SBERT paper.
46
  """
47
+ input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).type(token_embeddings.type())
48
  sum_embeddings = reduce(token_embeddings * input_mask_expanded, 'b l d -> b d', 'sum')
49
  sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
50
  return sum_embeddings / sum_mask
51
 
52
+ def get_episode_embeddings(self, input_ids, attention_mask, output_attentions=False):
53
  """Computes the Author Embedding.
54
  """
55
  B, E, _ = attention_mask.shape
 
61
  input_ids=input_ids,
62
  attention_mask=attention_mask,
63
  return_dict=True,
64
+ output_hidden_states=True,
65
+ output_attentions=output_attentions,
66
  )
67
 
68
  # at this point, we're embedding individual "comments"
 
75
 
76
  episode_embeddings = self.linear(episode_embeddings)
77
 
78
+ if output_attentions:
79
+ return episode_embeddings, outputs["attentions"]
80
+
81
  return episode_embeddings
82
 
83
+ def forward(self, input_ids, attention_mask, output_attentions=False):
84
  """Calculates a fixed-length feature vector for a batch of episode samples.
85
  """
86
+ output = self.get_episode_embeddings(input_ids, attention_mask, output_attentions)
87
 
88
  return output