BasselAhmed's picture
Update app.py
9b5892c verified
import logging
import os
from typing import List

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from simpletransformers.classification import ClassificationModel
from starlette.middleware.cors import CORSMiddleware
# The single FastAPI application object; route decorators in this module
# attach their handlers to it.
app = FastAPI()

# Register the CORS middleware with a fully permissive configuration: the
# wildcard is passed for origins, methods and headers, and credentialed
# requests are enabled.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# very open policy — confirm it is intended before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
class Query(BaseModel):
    # Request body accepted by the prediction endpoint.
    #
    # smiles: the list of input strings handed to the classifier as one batch
    # (presumably SMILES molecule notations, going by the field name — verify
    # against the client).
    smiles: List[str]
class PredictionResponse(BaseModel):
    # Response body produced by the prediction endpoint.
    #
    # predictions: integer labels returned by the classifier for the
    # submitted inputs.
    predictions: List[int]
# Classifier instance created once at import time and reused by the request
# handler. It is built as a 'roberta' model from the
# 'BasselAhmed/RobertaChemClinToxTuned' checkpoint with CUDA disabled.
# NOTE(review): the evaluate_* options look training-oriented and may be
# unnecessary for a predict-only service — confirm before removing them.
rob_chem_model = ClassificationModel(
    'roberta',
    'BasselAhmed/RobertaChemClinToxTuned',
    use_cuda=False,
    args={
        'evaluate_each_epoch': True,
        'evaluate_during_training_verbose': True,
        'seed': 4,
    },
)
@app.post("/ToxicityPrediction/")
def c(query: Query) -> PredictionResponse:
    """Run the toxicity classifier on the submitted inputs.

    Args:
        query: Request body; its ``smiles`` list is passed to the model as a
            single batch.

    Returns:
        A ``PredictionResponse`` whose ``predictions`` list holds the integer
        labels produced by the model. An empty ``smiles`` list yields an
        empty ``predictions`` list without invoking the model.

    Raises:
        HTTPException: Status 500, with the underlying error message as
            ``detail``, when the model call or label conversion fails.
    """
    # Nothing to classify: answer directly instead of handing the model an
    # empty batch.
    if not query.smiles:
        return PredictionResponse(predictions=[])

    try:
        predictions, _raw_outputs = rob_chem_model.predict(query.smiles)
        # predict() may hand back NumPy integer scalars; convert each label to
        # a built-in int so the payload contains only native JSON-friendly
        # values regardless of the container type returned.
        labels = [int(label) for label in predictions]
    except Exception as e:
        # Endpoint boundary: record the full traceback server-side, then
        # surface the failure to the client as a 500 with the cause chained.
        logging.getLogger(__name__).exception("Toxicity prediction failed")
        raise HTTPException(detail=str(e), status_code=500) from e

    return PredictionResponse(predictions=labels)