edwjin commited on
Commit
28aad06
·
verified ·
1 Parent(s): 11d2c2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -35
app.py CHANGED
@@ -1,15 +1,12 @@
1
- from fastapi import FastAPI, Request
2
  from pydantic import BaseModel
3
- from load_texts import load_texts
4
  from tokenizer import SimpleTokenizer
5
  from transformer import Classifier
6
  from constants import block_size
7
  from fastapi.middleware.cors import CORSMiddleware
8
 
9
  import uvicorn
10
-
11
  import torch
12
-
13
  import pickle
14
 
15
  app = FastAPI()
@@ -23,35 +20,26 @@ app.add_middleware(
23
  )
24
 
25
  model = None
26
- tokenizer = None
27
  pres_dict = {}
28
 
29
- with open('pres_dict.pkl', 'rb') as file:
30
- reversed_dict = pickle.load(file)
 
31
  pres_dict = {value: key for key, value in reversed_dict.items()}
32
 
33
- def initialize():
34
- global model, tokenizer
35
-
36
- if not tokenizer:
37
- tokenizer = SimpleTokenizer()
38
-
39
- print('start tokenizer')
40
-
41
- for text in load_texts('train.tsv'):
42
- tokenizer.update_vocab(text.split('\t', 1)[1])
43
-
44
- print('finish tokenizer, vocab size is: ', tokenizer.vocab_size)
45
-
46
- if not model:
47
- model = Classifier(tokenizer.vocab_size)
48
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
- print('loading model')
50
- model.load_state_dict(torch.load('all_pres_classifier_model_dict.pth', map_location=device))
51
- print('finished loading model')
52
- model.to(device)
53
- model.eval()
54
 
 
 
 
 
 
 
 
 
 
55
  class TextInput(BaseModel):
56
  text: str
57
 
@@ -61,11 +49,6 @@ def home():
61
 
62
  @app.post("/predict")
63
  def predict(request: TextInput):
64
- global model, tokenizer
65
-
66
- if model is None or tokenizer is None:
67
- initialize()
68
-
69
  text = request.text
70
 
71
  # Get the text from the POST request body
@@ -82,5 +65,4 @@ def predict(request: TextInput):
82
 
83
  _, predicted = torch.max(output.data, 1)
84
 
85
- return {"predicted": pres_dict[predicted.tolist()[0]]}
86
-
 
1
+ from fastapi import FastAPI
2
  from pydantic import BaseModel
 
3
  from tokenizer import SimpleTokenizer
4
  from transformer import Classifier
5
  from constants import block_size
6
  from fastapi.middleware.cors import CORSMiddleware
7
 
8
  import uvicorn
 
9
  import torch
 
10
  import pickle
11
 
12
  app = FastAPI()
 
20
  )
21
 
22
  model = None
23
+ tokenizer = SimpleTokenizer()
24
  pres_dict = {}
25
 
26
+ # load in pres dicts
27
+ with open('speechesdataset/pres_dict.pkl', 'rb') as file1:
28
+ reversed_dict = pickle.load(file1)
29
  pres_dict = {value: key for key, value in reversed_dict.items()}
30
 
31
+ with open('speechesdataset/tokenizer.pkl', 'rb') as file:
32
+ tokenizer = pickle.load(file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ # load in model
35
+ model = Classifier(tokenizer.vocab_size)
36
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
+ print('loading model')
38
+ model.load_state_dict(torch.load('speechesdataset/classifier_model_dict.pth', map_location=device))
39
+ print('finished loading model')
40
+ model.to(device)
41
+ model.eval()
42
+
43
  class TextInput(BaseModel):
44
  text: str
45
 
 
49
 
50
  @app.post("/predict")
51
  def predict(request: TextInput):
 
 
 
 
 
52
  text = request.text
53
 
54
  # Get the text from the POST request body
 
65
 
66
  _, predicted = torch.max(output.data, 1)
67
 
68
+ return {"predicted": pres_dict[predicted.tolist()[0]]}