BasselAhmed commited on
Commit
e79e84e
·
verified ·
1 Parent(s): 41dfcd3

Adding FastapiEndpoint

Browse files
Files changed (1) hide show
  1. app.py +37 -4
app.py CHANGED
@@ -5,11 +5,43 @@ import torch
5
  import random
6
  import os
7
  import numpy as np
 
 
8
  random.seed(4)
9
  np.random.seed(4)
10
  torch.manual_seed(4)
11
  np.random.seed(4)
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # Set the Streamlit app title
14
  st.title("Molecule Toxicity Predictions")
15
 
@@ -18,7 +50,7 @@ path = 'ToxicityPrediction/Models/transformers/checkpoint-149-epoch-1'
18
 
19
  # Load the model from the stage
20
  #loaded_model = ClassificationModel('roberta', path, use_cuda = False)
21
- rob_chem_model = ClassificationModel('roberta', 'seyonec/SMILES_tokenized_PubChem_shard00_160k',use_cuda=False ,args={'evaluate_each_epoch':True , 'evaluate_during_training_verbose':True, 'seed':4})
22
  # Predict based on the input
23
  rob_chem_model.model.eval()
24
  #target_name= st.text_input('Enter a SMILES string:')
@@ -28,6 +60,7 @@ target_name_list = target_name.splitlines()
28
  target_name_list = [x.strip() for x in target_name_list]
29
  predict_toxicity = st.button('Predict Toxicity')
30
  if predict_toxicity:
31
- predictions, raw_outputs = rob_chem_model.predict(target_name_list)
32
- df_pred = pd.DataFrame({'Smiles':target_name_list,'predictions': predictions})
33
- st.dataframe(df_pred)
 
 
5
  import random
6
  import os
7
  import numpy as np
8
+
9
+
10
  random.seed(4)
11
  np.random.seed(4)
12
  torch.manual_seed(4)
13
  np.random.seed(4)
14
 
15
+
16
+ from fastapi import FastAPI, HTTPException
17
+ import os
18
+ from starlette.middleware.cors import CORSMiddleware
19
+ from pydantic import BaseModel
20
+ import uvicorn
21
+ app = FastAPI()
22
+ app.add_middleware(
23
+ CORSMiddleware,
24
+ allow_origins=["*"],
25
+ allow_credentials=True,
26
+ allow_methods=["*"],
27
+ allow_headers=["*"],
28
+ )
29
+ rob_chem_model = ClassificationModel('roberta', 'seyonec/SMILES_tokenized_PubChem_shard00_160k',use_cuda=False ,args={'evaluate_each_epoch':True , 'evaluate_during_training_verbose':True, 'seed':4})
30
+
31
+ class Query(BaseModel):
32
+ query :str
33
+
34
+ @app.post("/ToxicityPrediction")
35
+ async def c(query:Query):
36
+ try:
37
+ predictions, raw_outputs = rob_chem_model.predict([str(query.query)])
38
+ st.write("Received request")
39
+ return {"prediction":predictions[0]}
40
+ except Exception as e:
41
+ raise HTTPException(detail = str(e) , status_code = 500)
42
+ if __name__ == "__main__":
43
+ uvicorn.run(app, host="0.0.0.0", port=5566)
44
+
45
  # Set the Streamlit app title
46
  st.title("Molecule Toxicity Predictions")
47
 
 
50
 
51
  # Load the model from the stage
52
  #loaded_model = ClassificationModel('roberta', path, use_cuda = False)
53
+ #rob_chem_model = ClassificationModel('roberta', 'seyonec/SMILES_tokenized_PubChem_shard00_160k',use_cuda=False ,args={'evaluate_each_epoch':True , 'evaluate_during_training_verbose':True, 'seed':4})
54
  # Predict based on the input
55
  rob_chem_model.model.eval()
56
  #target_name= st.text_input('Enter a SMILES string:')
 
60
  target_name_list = [x.strip() for x in target_name_list]
61
  predict_toxicity = st.button('Predict Toxicity')
62
  if predict_toxicity:
63
+ #predictions, raw_outputs = rob_chem_model.predict(target_name_list)
64
+ #df_pred = pd.DataFrame({'Smiles':target_name_list,'predictions': predictions})
65
+ #st.dataframe(df_pred)
66
+ pass