"""FastAPI service exposing a RoBERTa-based classifier over SMILES strings.

POST /ToxicityPrediction/ with {"smiles": [...]} returns
{"predictions": [...]}, one integer class label per input string.
"""

import logging
import os
from typing import List

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from simpletransformers.classification import ClassificationModel
from starlette.middleware.cors import CORSMiddleware

logger = logging.getLogger(__name__)

app = FastAPI()

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True lets
# any origin issue credentialed requests; restrict the origin list if this
# service ever relies on cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class Query(BaseModel):
    """Request body: a batch of SMILES strings to classify."""

    smiles: List[str]


class PredictionResponse(BaseModel):
    """Response body: one predicted class label per input SMILES string."""

    predictions: List[int]


# Loaded once at import time so every request reuses the same model instance.
# use_cuda=False forces CPU inference.
rob_chem_model = ClassificationModel(
    "roberta",
    "BasselAhmed/RobertaChemClinToxTuned",
    use_cuda=False,
    args={
        "evaluate_each_epoch": True,
        "evaluate_during_training_verbose": True,
        "seed": 4,
    },
)


def _to_int_list(predictions) -> List[int]:
    """Convert the model's predictions (e.g. a NumPy array) to plain Python ints."""
    # NumPy integer scalars are not built-in ints; convert explicitly so the
    # List[int] response model always validates and serializes cleanly.
    return [int(label) for label in predictions]


@app.post("/ToxicityPrediction/")
def c(query: Query) -> PredictionResponse:
    """Predict a class label for every SMILES string in the request.

    Args:
        query: Request body holding the SMILES strings to classify.

    Returns:
        A PredictionResponse whose ``predictions`` list is aligned with
        ``query.smiles`` (empty when no strings were supplied).

    Raises:
        HTTPException: status 500, with the underlying error message as
            detail, if the model raises during prediction.
    """
    if not query.smiles:
        # Nothing to classify; avoid invoking the model on an empty batch.
        return PredictionResponse(predictions=[])

    try:
        predictions, _raw_outputs = rob_chem_model.predict(query.smiles)
    except Exception as e:  # top-level request boundary: log and surface as 500
        logger.exception(
            "Toxicity prediction failed for %d input(s)", len(query.smiles)
        )
        raise HTTPException(detail=str(e), status_code=500) from e

    return PredictionResponse(predictions=_to_int_list(predictions))