File size: 1,658 Bytes
c5fad91
c4e0631
 
 
 
d813078
c4e0631
760c66c
 
c4e0631
 
 
 
 
 
 
 
 
 
 
 
d813078
2cea6cc
9b5892c
 
ccd32c7
c4e0631
2cea6cc
9b5892c
f367c81
 
d813078
c4e0631
b3a49b9
c4e0631
2cea6cc
e15c96d
 
d813078
2cea6cc
 
d813078
 
2cea6cc
9b5892c
 
c4e0631
46b0d3e
c4e0631
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import logging
import os
from typing import List

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from simpletransformers.classification import ClassificationModel
from starlette.middleware.cors import CORSMiddleware
# ASGI application object; routes and middleware are registered on it.
app = FastAPI()

# CORS policy: every origin, method and header is accepted, and credentialed
# requests are allowed.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard values with allow_credentials=True — confirm whether credentials
# are actually needed, or list the allowed origins explicitly.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

class Query(BaseModel):
    # Request payload: the list of input strings handed to the model's
    # predict() call (presumably SMILES molecule notations, going by the
    # field name — confirm).
    smiles: List[str]
class PredictionResponse(BaseModel):
    # Response payload: the integer predictions produced by the model
    # (presumably one class label per input string — confirm).
    predictions: List[int]

# Classifier instance created once at import time and reused by the request
# handler. CUDA is explicitly disabled (use_cuda=False).
rob_chem_model = ClassificationModel(
    "roberta",
    "BasselAhmed/RobertaChemClinToxTuned",
    use_cuda=False,
    args={
        "evaluate_each_epoch": True,
        "evaluate_during_training_verbose": True,
        "seed": 4,
    },
)

@app.post("/ToxicityPrediction/")
def c(query: Query) -> PredictionResponse:
    """Run the classifier on the submitted strings and return its predictions.

    Args:
        query: Request body whose ``smiles`` list is passed to the model.

    Returns:
        A mapping ``{"predictions": [...]}`` holding the model's predictions
        converted to built-in ``int`` values.

    Raises:
        HTTPException: With status 500 and the error text as ``detail`` when
            the model call or the conversion of its output fails.
    """
    # Nothing to classify: answer directly instead of sending an empty batch
    # to the model.
    if not query.smiles:
        return {"predictions": []}

    try:
        predictions, _raw_outputs = rob_chem_model.predict(query.smiles)
        # The model may hand back NumPy integer scalars (TODO confirm against
        # the simpletransformers version in use); built-in ints are what the
        # response schema declares and what JSON encoding expects.
        labels = [int(label) for label in predictions]
    except Exception as e:
        # Top-level request boundary: record the full traceback server-side,
        # then report the failure to the client as a 500.
        logging.getLogger(__name__).exception("Toxicity prediction failed")
        raise HTTPException(detail=str(e), status_code=500) from e

    return {"predictions": labels}