Update inference.py
Browse files- inference.py +2 -2
inference.py
CHANGED
|
@@ -40,7 +40,7 @@ class DualBertModel: # use this class to encode your document and queries.
|
|
| 40 |
attention_mask = torch.tensor(encoded.get('attention_mask')).to(device)
|
| 41 |
return input_ids, attention_mask
|
| 42 |
|
| 43 |
-
def encode_queries(self, queries, batch_size: int, **kwargs)
|
| 44 |
"""
|
| 45 |
Encodes a list of strings (queries) and returns a list of encodings.
|
| 46 |
"""
|
|
@@ -54,7 +54,7 @@ class DualBertModel: # use this class to encode your document and queries.
|
|
| 54 |
return torch.concat(to_return)
|
| 55 |
|
| 56 |
|
| 57 |
-
def encode_corpus(self, corpus, batch_size: int, **kwargs)
|
| 58 |
"""
|
| 59 |
Encodes a list of strings (documents) and returns a list of encodings.
|
| 60 |
"""
|
|
|
|
| 40 |
attention_mask = torch.tensor(encoded.get('attention_mask')).to(device)
|
| 41 |
return input_ids, attention_mask
|
| 42 |
|
| 43 |
+
def encode_queries(self, queries, batch_size: int, **kwargs):
|
| 44 |
"""
|
| 45 |
Encodes a list of strings (queries) and returns a list of encodings.
|
| 46 |
"""
|
|
|
|
| 54 |
return torch.concat(to_return)
|
| 55 |
|
| 56 |
|
| 57 |
+
def encode_corpus(self, corpus, batch_size: int, **kwargs):
|
| 58 |
"""
|
| 59 |
Encodes a list of strings (documents) and returns a list of encodings.
|
| 60 |
"""
|