Commit 09be69e
Parent(s): 353504a
Upload LUAR
model.py CHANGED
@@ -44,12 +44,12 @@ class LUAR(PreTrainedModel):
     def mean_pooling(self, token_embeddings, attention_mask):
         """Mean Pooling as described in the SBERT paper.
         """
-        input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).
+        input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).type(token_embeddings.type())
         sum_embeddings = reduce(token_embeddings * input_mask_expanded, 'b l d -> b d', 'sum')
         sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
         return sum_embeddings / sum_mask
 
-    def get_episode_embeddings(self, input_ids, attention_mask):
+    def get_episode_embeddings(self, input_ids, attention_mask, output_attentions=False):
         """Computes the Author Embedding.
         """
         B, E, _ = attention_mask.shape
@@ -61,7 +61,8 @@ class LUAR(PreTrainedModel):
             input_ids=input_ids,
             attention_mask=attention_mask,
             return_dict=True,
-            output_hidden_states=True
+            output_hidden_states=True,
+            output_attentions=output_attentions,
         )
 
         # at this point, we're embedding individual "comments"
@@ -74,11 +75,14 @@ class LUAR(PreTrainedModel):
 
         episode_embeddings = self.linear(episode_embeddings)
 
+        if output_attentions:
+            return episode_embeddings, outputs["attentions"]
+
         return episode_embeddings
 
-    def forward(self, input_ids, attention_mask):
+    def forward(self, input_ids, attention_mask, output_attentions=False):
         """Calculates a fixed-length feature vector for a batch of episode samples.
         """
-        output = self.get_episode_embeddings(input_ids, attention_mask)
+        output = self.get_episode_embeddings(input_ids, attention_mask, output_attentions)
 
         return output