updated handler.py
Browse files- handler.py +2 -0
handler.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from typing import Dict, List, Any
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
|
|
|
| 4 |
from transformers import pipeline, BertModel, AutoTokenizer, PretrainedConfig
|
| 5 |
|
| 6 |
class EndpointHandler():
|
|
@@ -28,6 +29,7 @@ class EndpointHandler():
|
|
| 28 |
|
| 29 |
# run normal prediction
|
| 30 |
prediction = self.model.classify(inputs)
|
|
|
|
| 31 |
return prediction
|
| 32 |
|
| 33 |
class CustomModel(nn.Module):
|
|
|
|
| 1 |
from typing import Dict, List, Any
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
+
import json
|
| 5 |
from transformers import pipeline, BertModel, AutoTokenizer, PretrainedConfig
|
| 6 |
|
| 7 |
class EndpointHandler():
|
|
|
|
| 29 |
|
| 30 |
# run normal prediction
|
| 31 |
prediction = self.model.classify(inputs)
|
| 32 |
+
prediction = json.dumps(prediction)
|
| 33 |
return prediction
|
| 34 |
|
| 35 |
class CustomModel(nn.Module):
|