|
|
from fastapi import FastAPI, HTTPException |
|
|
import os |
|
|
from starlette.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel |
|
|
import uvicorn |
|
|
from typing import List |
|
|
import torch |
|
|
from fastapi.encoders import jsonable_encoder |
|
|
from fastapi.responses import JSONResponse |
|
|
from simpletransformers.classification import ClassificationModel |
|
|
# ASGI application serving the toxicity-prediction endpoint.
app = FastAPI()

# Cross-origin access is opened to every origin, method and header so that
# browser front-ends hosted anywhere can call this API.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# risky pairing. Per the CORS spec, browsers reject "*" for credentialed
# requests. Starlette's docs say the wildcards must be replaced with explicit
# values for credentials to be honoured. Confirm whether credentials are
# actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
|
|
|
|
|
class Query(BaseModel):
    """Request body for ``/ToxicityPrediction/``: the molecules to classify."""

    # SMILES strings, one per molecule; handed to the classifier as a batch.
    smiles: List[str]
|
|
class PredictionResponse(BaseModel):
    """Response schema for ``/ToxicityPrediction/``."""

    # Integer class labels, one per submitted SMILES string.
    predictions: List[int]
|
|
|
|
|
# Fine-tuned RoBERTa sequence classifier, built once at import time and shared
# by every request handled by this module.
# NOTE(review): inference is pinned to CPU (use_cuda=False) even though torch
# is imported; confirm whether GPU use was intended.
# NOTE(review): the evaluate_* keys look like training-time options and
# presumably have no effect on predict(); confirm before dropping them.
rob_chem_model = ClassificationModel(
    'roberta',
    'BasselAhmed/RobertaChemClinToxTuned',
    use_cuda=False,
    args={
        'evaluate_each_epoch': True,
        'evaluate_during_training_verbose': True,
        'seed': 4,
    },
)
|
|
|
|
|
@app.post("/ToxicityPrediction/")
def c(query: Query) -> PredictionResponse:
    """Predict a toxicity class label for each SMILES string in the request.

    Args:
        query: Request body holding the SMILES strings to classify.

    Returns:
        A ``PredictionResponse`` with one integer label per input string.

    Raises:
        HTTPException: status 500, with the underlying error message as
            ``detail``, if the model fails to produce predictions.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Nothing to classify: answer directly instead of handing an empty batch
    # to the model.
    if not query.smiles:
        return PredictionResponse(predictions=[])

    try:
        predictions, _raw_outputs = rob_chem_model.predict(query.smiles)
        # predict() may yield numpy integer scalars; convert to plain ints so
        # response validation/serialization does not depend on numpy types.
        labels = [int(p) for p in predictions]
    except Exception as e:
        # Request boundary: keep the full traceback server-side and report a
        # 500 to the client (same status/detail contract as before).
        logger.exception("Toxicity prediction failed")
        raise HTTPException(detail=str(e), status_code=500) from e

    return PredictionResponse(predictions=labels)
|
|
|