File size: 1,658 Bytes
c5fad91 c4e0631 d813078 c4e0631 760c66c c4e0631 d813078 2cea6cc 9b5892c ccd32c7 c4e0631 2cea6cc 9b5892c f367c81 d813078 c4e0631 b3a49b9 c4e0631 2cea6cc e15c96d d813078 2cea6cc d813078 2cea6cc 9b5892c c4e0631 46b0d3e c4e0631 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import logging
import os
from typing import List

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from simpletransformers.classification import ClassificationModel
from starlette.middleware.cors import CORSMiddleware
# FastAPI application that hosts the toxicity-prediction endpoint.
app = FastAPI()

# Cross-origin policy: any origin, any method and any header is accepted, and
# credentialed requests are allowed.
# NOTE(review): FastAPI's CORS documentation states that allow_origins /
# allow_methods / allow_headers must not be ["*"] when allow_credentials=True.
# Nothing in this module reads cookies or auth headers, so allow_credentials=False
# is probably what was intended — confirm with API consumers before changing it.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
class Query(BaseModel):
    """Request body for the toxicity endpoint.

    ``smiles`` is a batch of inputs; each string is handed to the classifier
    as-is (presumably SMILES notation — no format validation happens here).
    """

    smiles: List[str]
class PredictionResponse(BaseModel):
    """Response body for the toxicity endpoint.

    ``predictions`` holds integer class labels, presumably one per submitted
    string and in the same order (as produced by the classifier's ``predict``).
    """

    predictions: List[int]
# Options forwarded to simpletransformers. The two evaluation flags look like
# training-time settings and presumably have no effect on predict() — TODO confirm.
_MODEL_ARGS = {
    "evaluate_each_epoch": True,
    "evaluate_during_training_verbose": True,
    "seed": 4,
}

# RoBERTa classifier loaded once at import time and shared by all requests.
# Inference is pinned to CPU (use_cuda=False). The checkpoint name suggests a
# ClinTox fine-tune — verify against the model card.
rob_chem_model = ClassificationModel(
    "roberta",
    "BasselAhmed/RobertaChemClinToxTuned",
    use_cuda=False,
    args=_MODEL_ARGS,
)
_logger = logging.getLogger(__name__)


@app.post("/ToxicityPrediction/")
def c(query: Query) -> PredictionResponse:
    """Classify a batch of SMILES strings with the shared RoBERTa model.

    Args:
        query: Request body; ``query.smiles`` is forwarded unchanged to
            ``rob_chem_model.predict``.

    Returns:
        A ``PredictionResponse`` whose ``predictions`` are the model's labels
        converted to built-in ``int``. An empty ``smiles`` list returns an
        empty ``predictions`` list without invoking the model.

    Raises:
        HTTPException: status 500, carrying the underlying error message, when
            prediction or label conversion fails.
    """
    if not query.smiles:
        # Nothing to classify; avoid calling the model with an empty batch.
        return PredictionResponse(predictions=[])

    try:
        # The second return value (raw model outputs) is not exposed.
        predictions, _raw_outputs = rob_chem_model.predict(query.smiles)
        # predict() presumably yields a NumPy array; numpy.int64 scalars are not
        # JSON-serializable by jsonable_encoder, so convert to built-in int here.
        labels = [int(label) for label in predictions]
    except Exception as e:
        # Request boundary: keep the traceback in the server log and report a
        # 500 to the client, chaining so the original cause is preserved.
        _logger.exception("Toxicity prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return PredictionResponse(predictions=labels)
|