Update helpers/required_classes.py
Browse files
helpers/required_classes.py
CHANGED
|
@@ -24,14 +24,15 @@ class BertEmbedder:
|
|
| 24 |
self.embedder.to(self.device)
|
| 25 |
|
| 26 |
def __call__(self, text: str):
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
| 35 |
|
| 36 |
def batch_predict(self, texts: List[str]):
|
| 37 |
encoded_input = self.tokenizer(texts,
|
|
|
|
| 24 |
self.embedder.to(self.device)
|
| 25 |
|
| 26 |
def __call__(self, text: str):
|
| 27 |
+
with torch.no_grad():
|
| 28 |
+
encoded_input = self.tokenizer(text,
|
| 29 |
+
return_tensors='pt',
|
| 30 |
+
max_length=self.max_length,
|
| 31 |
+
padding=True,
|
| 32 |
+
truncation=True).to(self.device)
|
| 33 |
+
model_output = self.embedder(**encoded_input)
|
| 34 |
+
text_embed = model_output.pooler_output[0].cpu()
|
| 35 |
+
return text_embed
|
| 36 |
|
| 37 |
def batch_predict(self, texts: List[str]):
|
| 38 |
encoded_input = self.tokenizer(texts,
|