model(data, padding=False, truncation=True)
Browse files- code/inference.py +1 -1
code/inference.py
CHANGED
|
@@ -3,6 +3,6 @@ from typing import List, Union
|
|
| 3 |
|
| 4 |
|
| 5 |
def predict_fn(data: Union[List[str], str], model):
|
| 6 |
-
outputs = model(
|
| 7 |
embeddings = [np.array(r[0]).mean(axis=0).tolist() for r in outputs]
|
| 8 |
return embeddings
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
def predict_fn(data: Union[List[str], str], model):
|
| 6 |
+
outputs = model(data, padding=False, truncation=True)
|
| 7 |
embeddings = [np.array(r[0]).mean(axis=0).tolist() for r in outputs]
|
| 8 |
return embeddings
|