# ToxicityPredictionApp / fastapi_app.py
# BasselAhmed's picture
# Update fastapi_app.py
# 619c025 verified
from fastapi import FastAPI, HTTPException
import os
from starlette.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
import torch
from simpletransformers.classification import ClassificationModel
# Application instance that the route handler in this file is registered on.
app = FastAPI()

# CORS policy: any origin, any HTTP method and any request header is accepted,
# and credentialed requests are allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive -- confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Request body schema for the prediction endpoint: a single string field.
class Query(BaseModel):
    # Text forwarded to the model's predict() call by the route handler.
    # NOTE(review): presumably a SMILES string, judging by the model
    # checkpoint name -- confirm with the API's consumers.
    query: str
# Classifier built once at import time and shared by every request; CUDA is
# explicitly disabled, so inference runs on the CPU.
# NOTE(review): the checkpoint name suggests a SMILES-pretrained RoBERTa --
# confirm it actually carries a classification head trained for toxicity.
# NOTE(review): the evaluate_* options look training-related while this file
# only calls predict(); also confirm 'seed' is the option name the library
# reads for seeding.
rob_chem_model = ClassificationModel(
    'roberta',
    'seyonec/SMILES_tokenized_PubChem_shard00_160k',
    use_cuda=False,
    args={
        'evaluate_each_epoch': True,
        'evaluate_during_training_verbose': True,
        'seed': 4,
    },
)
@app.post("/ToxicityPrediction")
async def c(query: Query):
    """Classify the submitted string and return the predicted label.

    Args:
        query: Request body; its ``query`` field is the text given to the model.

    Returns:
        A dict ``{"prediction": <label>}`` holding the label predicted for the
        single submitted input.

    Raises:
        HTTPException: Status 500, with the original error text as detail, if
            prediction fails for any reason.
    """
    # NOTE(review): predict() is a synchronous call inside an async handler, so
    # it presumably blocks the event loop while it runs -- consider offloading.
    try:
        # predict() takes a batch, so the single input is wrapped in a list.
        # The field is already typed as str, so no extra conversion is needed.
        predictions, _raw_outputs = rob_chem_model.predict([query.query])
        print(predictions)
        label = predictions[0]
    except Exception as e:
        # API boundary: report any failure as a 500 and keep the cause chained.
        raise HTTPException(detail=str(e), status_code=500) from e

    # The library may hand back a NumPy scalar, which the JSON response encoder
    # cannot serialize; unwrap it to the equivalent builtin Python value.
    # Labels that are already builtin values (int, str) have no .item() and
    # pass through untouched.
    if hasattr(label, "item"):
        label = label.item()
    return {"prediction": label}
# Script entry point: serve the app with uvicorn only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    # "0.0.0.0" listens on every network interface; the port is fixed at 5566.
    uvicorn.run(app, host="0.0.0.0", port=5566)