Commit 82ec677
Parent(s): 9031ba4
Upload LUAR
model.py CHANGED
@@ -146,7 +146,7 @@ class LUAR(PreTrainedModel):
             config.k_bucket_size,
         )
         self.linear = nn.Linear(self.hidden_size, config.embedding_size)
-
+
     def create_transformer(self):
         """Creates the Transformer backbone.
         """
@@ -163,7 +163,7 @@ class LUAR(PreTrainedModel):
         sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
         return sum_embeddings / sum_mask
 
-    def get_episode_embeddings(self, input_ids, attention_mask, output_attentions=False):
+    def get_episode_embeddings(self, input_ids, attention_mask, output_attentions=False, document_batch_size=0):
         """Computes the Author Embedding.
         """
         B, E, _ = attention_mask.shape
@@ -171,14 +171,31 @@ class LUAR(PreTrainedModel):
         input_ids = rearrange(input_ids, 'b e l -> (b e) l')
         attention_mask = rearrange(attention_mask, 'b e l -> (b e) l')
 
-        outputs = self.transformer(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            return_dict=True,
-            output_hidden_states=False,
-            output_attentions=output_attentions,
-        )
-
+        if document_batch_size > 0:
+            outputs = {"last_hidden_state": [], "attentions": []}
+            for i in range(0, len(input_ids), document_batch_size):
+                out = self.transformer(
+                    input_ids=input_ids[i:i+document_batch_size],
+                    attention_mask=attention_mask[i:i+document_batch_size],
+                    return_dict=True,
+                    output_hidden_states=False,
+                    output_attentions=output_attentions,
+                )
+                outputs["last_hidden_state"].append(out["last_hidden_state"])
+                if output_attentions:
+                    outputs["attentions"].append(out["attentions"])
+            outputs["last_hidden_state"] = torch.cat(outputs["last_hidden_state"], dim=0)
+            if output_attentions:
+                outputs["attentions"] = tuple([torch.cat([x[i] for x in outputs["attentions"]], dim=0) for i in range(len(outputs["attentions"][0]))])
+        else:
+            outputs = self.transformer(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                return_dict=True,
+                output_hidden_states=False,
+                output_attentions=output_attentions,
+            )
+
         # at this point, we're embedding individual "comments"
         comment_embeddings = self.mean_pooling(outputs['last_hidden_state'], attention_mask)
         comment_embeddings = rearrange(comment_embeddings, '(b e) l -> b e l', b=B, e=E)
@@ -194,9 +211,9 @@ class LUAR(PreTrainedModel):
 
         return episode_embeddings
 
-    def forward(self, input_ids, attention_mask, output_attentions=False):
+    def forward(self, input_ids, attention_mask, output_attentions=False, document_batch_size=0):
         """Calculates a fixed-length feature vector for a batch of episode samples.
         """
-        output = self.get_episode_embeddings(input_ids, attention_mask, output_attentions)
+        output = self.get_episode_embeddings(input_ids, attention_mask, output_attentions, document_batch_size)
 
         return output
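The new document_batch_size argument (default 0, which keeps the old behavior) makes get_episode_embeddings run the backbone over the flattened (b e) documents in fixed-size slices and concatenate the per-slice outputs along dim=0, so peak activation memory scales with document_batch_size rather than with B * E. A minimal standalone sketch of that slice-and-concatenate pattern, with a toy nn.Embedding standing in for the LUAR transformer backbone (the code below is illustrative and not part of the commit):

import torch
import torch.nn as nn

encoder = nn.Embedding(1000, 32)              # stand-in for self.transformer
input_ids = torch.randint(0, 1000, (32, 64))  # (b*e) flattened documents, 64 tokens each

# Single pass over all documents: the document_batch_size == 0 path.
full = encoder(input_ids)

# Chunked passes, the document_batch_size > 0 path: slice, encode, concatenate.
document_batch_size = 8
chunks = [encoder(input_ids[i:i + document_batch_size])
          for i in range(0, len(input_ids), document_batch_size)]
chunked = torch.cat(chunks, dim=0)

# Row order is preserved, so the downstream mean pooling sees the same tensor.
assert torch.allclose(full, chunked)

With a real transformer the two paths can differ by floating-point noise across batch sizes, but the ordering argument is the same: each slice of documents is encoded and re-stacked in its original position, and only document_batch_size documents' activations are live at any one time.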