Upload ConstBERT
Browse files
- modeling.py +5 -3
modeling.py
CHANGED
|
@@ -6,6 +6,8 @@ from tqdm import tqdm
|
|
| 6 |
from .colbert_configuration import ColBERTConfig
|
| 7 |
from .tokenization_utils import QueryTokenizer, DocTokenizer
|
| 8 |
import os
|
|
|
|
|
|
|
| 9 |
class NullContextManager(object):
|
| 10 |
def __init__(self, dummy_resource=None):
|
| 11 |
self.dummy_resource = dummy_resource
|
|
@@ -54,7 +56,7 @@ class ConstBERT(BertPreTrainedModel):
|
|
| 54 |
"""
|
| 55 |
_keys_to_ignore_on_load_unexpected = [r"cls"]
|
| 56 |
|
| 57 |
-
def __init__(self, config, colbert_config, verbose:int =
|
| 58 |
super().__init__(config)
|
| 59 |
|
| 60 |
self.config = config
|
|
@@ -175,7 +177,7 @@ class ConstBERT(BertPreTrainedModel):
|
|
| 175 |
|
| 176 |
return D
|
| 177 |
|
| 178 |
-
def
|
| 179 |
if bsize:
|
| 180 |
batches = self.query_tokenizer.tensorize(queries, context=context, bsize=bsize, full_length_search=full_length_search)
|
| 181 |
batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
|
|
@@ -184,7 +186,7 @@ class ConstBERT(BertPreTrainedModel):
|
|
| 184 |
input_ids, attention_mask = self.query_tokenizer.tensorize(queries, context=context, full_length_search=full_length_search)
|
| 185 |
return self.query(input_ids, attention_mask)
|
| 186 |
|
| 187 |
-
def
|
| 188 |
assert keep_dims in [True, False, 'flatten']
|
| 189 |
|
| 190 |
if bsize:
|
|
|
|
| 6 |
from .colbert_configuration import ColBERTConfig
|
| 7 |
from .tokenization_utils import QueryTokenizer, DocTokenizer
|
| 8 |
import os
|
| 9 |
+
|
| 10 |
+
|
| 11 |
class NullContextManager(object):
|
| 12 |
def __init__(self, dummy_resource=None):
|
| 13 |
self.dummy_resource = dummy_resource
|
|
|
|
| 56 |
"""
|
| 57 |
_keys_to_ignore_on_load_unexpected = [r"cls"]
|
| 58 |
|
| 59 |
+
def __init__(self, config, colbert_config, verbose:int = 0):
|
| 60 |
super().__init__(config)
|
| 61 |
|
| 62 |
self.config = config
|
|
|
|
| 177 |
|
| 178 |
return D
|
| 179 |
|
| 180 |
+
def encode_query(self, queries, bsize=None, to_cpu=False, context=None, full_length_search=False):
|
| 181 |
if bsize:
|
| 182 |
batches = self.query_tokenizer.tensorize(queries, context=context, bsize=bsize, full_length_search=full_length_search)
|
| 183 |
batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
|
|
|
|
| 186 |
input_ids, attention_mask = self.query_tokenizer.tensorize(queries, context=context, full_length_search=full_length_search)
|
| 187 |
return self.query(input_ids, attention_mask)
|
| 188 |
|
| 189 |
+
def encode_document(self, docs, bsize=None, keep_dims=True, to_cpu=False, showprogress=False, return_tokens=False):
|
| 190 |
assert keep_dims in [True, False, 'flatten']
|
| 191 |
|
| 192 |
if bsize:
|