Pull request #5: "Truncate to 8k by default"
Author: Jackmin108 — opened
Files changed: modeling_bert.py (+3 −2)

modeling_bert.py — CHANGED
@@ -1195,7 +1195,9 @@ class JinaBertModel(JinaBertPreTrainedModel):

Before:
 1195        inverse_permutation = np.argsort(permutation)
 1196        sentences = [sentences[idx] for idx in permutation]
 1197
-1198        padding = tokenizer_kwargs.   [removed line; rest truncated in the page extraction — presumably `tokenizer_kwargs.pop('padding', True)`, verify against the repository]
 1199
 1200        all_embeddings = []
|
@@ -1214,7 +1216,6 @@ class JinaBertModel(JinaBertPreTrainedModel):

Before:
 1214            encoded_input = self.tokenizer(
 1215                sentences[i : i + batch_size],
 1216                return_tensors='pt',
-1217                padding=padding,
 1218                **tokenizer_kwargs,
 1219            ).to(self.device)
 1220            token_embs = self.forward(**encoded_input)[0]
|
|
|
|
After:
 1195        inverse_permutation = np.argsort(permutation)
 1196        sentences = [sentences[idx] for idx in permutation]
 1197
+1198        tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
+1199        tokenizer_kwargs['max_length'] = tokenizer_kwargs.get('max_length', 8192)
+1200        tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
 1201
 1202        all_embeddings = []
|
|
|
|
After:
 1216            encoded_input = self.tokenizer(
 1217                sentences[i : i + batch_size],
 1218                return_tensors='pt',
 1219                **tokenizer_kwargs,
 1220            ).to(self.device)
 1221            token_embs = self.forward(**encoded_input)[0]