codingcoolfun9ed commited on
Commit
528529c
·
verified ·
1 Parent(s): d49df57

lazy loading to try and bypass the 30 min startup limit

Browse files
Files changed (1) hide show
  1. api/predict.py +8 -5
api/predict.py CHANGED
@@ -1,7 +1,6 @@
1
  import torch
2
- import re
3
  import numpy as np
4
-
5
  from transformers import (
6
  DistilBertTokenizer, DistilBertForSequenceClassification,
7
  RobertaTokenizer, RobertaForSequenceClassification,
@@ -21,6 +20,8 @@ uncertaintyThreshold = 0.67
21
 
22
  CLASS_NAMES = ['genuine', 'fake']
23
 
 
 
24
  def validateText(text):
25
  if not isinstance(text, str):
26
  return False
@@ -37,9 +38,9 @@ def cleanReview(text):
37
  return text.strip()
38
 
39
  def loadResources():
40
- global models, tokenizers, maxLengths
41
 
42
- if len(models) > 0:
43
  return
44
 
45
  print("loading ensemble models...", flush=True)
@@ -120,10 +121,12 @@ def loadResources():
120
  print(f"error loading model {i}: {str(e)}", flush=True)
121
  raise
122
 
 
123
  print("all ensemble models loaded", flush=True)
124
 
125
  def ensemblePredict(text):
126
- loadResources()
 
127
 
128
  if not isinstance(text, str):
129
  text = str(text)
 
1
  import torch
 
2
  import numpy as np
3
+ import re
4
  from transformers import (
5
  DistilBertTokenizer, DistilBertForSequenceClassification,
6
  RobertaTokenizer, RobertaForSequenceClassification,
 
20
 
21
  CLASS_NAMES = ['genuine', 'fake']
22
 
23
+ models_loaded = False
24
+
25
  def validateText(text):
26
  if not isinstance(text, str):
27
  return False
 
38
  return text.strip()
39
 
40
  def loadResources():
41
+ global models, tokenizers, maxLengths, models_loaded
42
 
43
+ if models_loaded:
44
  return
45
 
46
  print("loading ensemble models...", flush=True)
 
121
  print(f"error loading model {i}: {str(e)}", flush=True)
122
  raise
123
 
124
+ models_loaded = True
125
  print("all ensemble models loaded", flush=True)
126
 
127
  def ensemblePredict(text):
128
+ if not models_loaded:
129
+ loadResources()
130
 
131
  if not isinstance(text, str):
132
  text = str(text)