lazy loading to try and bypass the 30 min startup limit
Browse files- api/predict.py +8 -5
api/predict.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import torch
|
| 2 |
-
import re
|
| 3 |
import numpy as np
|
| 4 |
-
|
| 5 |
from transformers import (
|
| 6 |
DistilBertTokenizer, DistilBertForSequenceClassification,
|
| 7 |
RobertaTokenizer, RobertaForSequenceClassification,
|
|
@@ -21,6 +20,8 @@ uncertaintyThreshold = 0.67
|
|
| 21 |
|
| 22 |
CLASS_NAMES = ['genuine', 'fake']
|
| 23 |
|
|
|
|
|
|
|
| 24 |
def validateText(text):
|
| 25 |
if not isinstance(text, str):
|
| 26 |
return False
|
|
@@ -37,9 +38,9 @@ def cleanReview(text):
|
|
| 37 |
return text.strip()
|
| 38 |
|
| 39 |
def loadResources():
|
| 40 |
-
global models, tokenizers, maxLengths
|
| 41 |
|
| 42 |
-
if
|
| 43 |
return
|
| 44 |
|
| 45 |
print("loading ensemble models...", flush=True)
|
|
@@ -120,10 +121,12 @@ def loadResources():
|
|
| 120 |
print(f"error loading model {i}: {str(e)}", flush=True)
|
| 121 |
raise
|
| 122 |
|
|
|
|
| 123 |
print("all ensemble models loaded", flush=True)
|
| 124 |
|
| 125 |
def ensemblePredict(text):
|
| 126 |
-
|
|
|
|
| 127 |
|
| 128 |
if not isinstance(text, str):
|
| 129 |
text = str(text)
|
|
|
|
| 1 |
import torch
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
+
import re
|
| 4 |
from transformers import (
|
| 5 |
DistilBertTokenizer, DistilBertForSequenceClassification,
|
| 6 |
RobertaTokenizer, RobertaForSequenceClassification,
|
|
|
|
| 20 |
|
| 21 |
CLASS_NAMES = ['genuine', 'fake']
|
| 22 |
|
| 23 |
+
models_loaded = False
|
| 24 |
+
|
| 25 |
def validateText(text):
|
| 26 |
if not isinstance(text, str):
|
| 27 |
return False
|
|
|
|
| 38 |
return text.strip()
|
| 39 |
|
| 40 |
def loadResources():
|
| 41 |
+
global models, tokenizers, maxLengths, models_loaded
|
| 42 |
|
| 43 |
+
if models_loaded:
|
| 44 |
return
|
| 45 |
|
| 46 |
print("loading ensemble models...", flush=True)
|
|
|
|
| 121 |
print(f"error loading model {i}: {str(e)}", flush=True)
|
| 122 |
raise
|
| 123 |
|
| 124 |
+
models_loaded = True
|
| 125 |
print("all ensemble models loaded", flush=True)
|
| 126 |
|
| 127 |
def ensemblePredict(text):
|
| 128 |
+
if not models_loaded:
|
| 129 |
+
loadResources()
|
| 130 |
|
| 131 |
if not isinstance(text, str):
|
| 132 |
text = str(text)
|