Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -50,7 +50,7 @@ def get_splade_representation(text):
|
|
| 50 |
# output.logits is typically [batch_size, sequence_length, vocab_size]
|
| 51 |
# We need to take the max over the sequence_length dimension to get a [batch_size, vocab_size] vector.
|
| 52 |
# inputs.attention_mask.unsqueeze(-1) expands the mask to match vocab_size for element-wise multiplication.
|
| 53 |
-
splade_vector = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs
|
| 54 |
else:
|
| 55 |
# Fallback/error message if the output structure is unexpected
|
| 56 |
return "Model output structure not as expected for SPLADE. 'logits' not found."
|
|
|
|
| 50 |
# output.logits is typically [batch_size, sequence_length, vocab_size]
|
| 51 |
# We need to take the max over the sequence_length dimension to get a [batch_size, vocab_size] vector.
|
| 52 |
# inputs.attention_mask.unsqueeze(-1) expands the mask to match vocab_size for element-wise multiplication.
|
| 53 |
+
splade_vector = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0].squeeze()
|
| 54 |
else:
|
| 55 |
# Fallback/error message if the output structure is unexpected
|
| 56 |
return "Model output structure not as expected for SPLADE. 'logits' not found."
|