output_fn
Browse files- code/inference.py +8 -0
code/inference.py
CHANGED
|
@@ -12,3 +12,11 @@ def predict_fn(data: Union[List[str], str], model):
|
|
| 12 |
outputs = model(data, padding=False, truncation=True)
|
| 13 |
embeddings = [np.array(r[0]).mean(axis=0).tolist() for r in outputs]
|
| 14 |
return embeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
outputs = model(data, padding=False, truncation=True)
|
| 13 |
embeddings = [np.array(r[0]).mean(axis=0).tolist() for r in outputs]
|
| 14 |
return embeddings
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def output_fn(prediction, accept):
|
| 18 |
+
return json.dumps(
|
| 19 |
+
obj={
|
| 20 |
+
"outputs": prediction
|
| 21 |
+
}
|
| 22 |
+
)
|