| | from typing import Optional |
| |
|
| | from fastapi import APIRouter |
| | from fastapi import FastAPI |
| | from schemas import ClassificationResult |
| | from utils import load_image |
| | from utils import load_model |
| |
|
| |
|
| | |
| |
|
# Classifier shared by all requests: loaded once when this module is imported
# and used as the default ``model`` argument of the /classify handler.
model = load_model()

# Application object that the ``__main__`` block hands to uvicorn; the
# title/description/version metadata shows up in the generated API docs.
app = FastAPI(
    title="MosAl",
    version="0.1.0",
    description="Obtain classification predictions for mosquito image",
    openapi_url="/openapi.json",
)

# Router that collects the endpoints before they are mounted on ``app``.
api_router = APIRouter()
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
@api_router.get("/classify", status_code=200, response_model=ClassificationResult)
async def predict_image(image_name: str, model=model):
    """Classify one image and return its predicted label with a confidence score.

    Args:
        image_name: Identifier of the image to classify (sent by the client);
            it is passed unchanged to ``load_image``.
        model: Object whose ``predict(img)`` returns a
            ``(prediction, pred_idx, probs)`` triple. Defaults to the
            module-level model loaded at import time.

    Returns:
        A dict with ``prediction`` (the label returned by ``model.predict``)
        and ``score`` (the entry of ``probs`` at ``pred_idx`` as a Python
        float, rounded to 3 decimal places).

    Raises:
        HTTPException: status 422 when ``model.predict`` returns a falsy
            prediction.
    """
    # NOTE(review): ``image_name`` is client-controlled and goes straight into
    # ``load_image``; confirm that helper rejects path traversal / arbitrary
    # locations before trusting it.
    # NOTE(review): FastAPI presumably also exposes ``model`` as an optional
    # query parameter (every handler parameter becomes a request parameter);
    # consider providing it via ``Depends`` instead.
    img = load_image(image_name)
    # NOTE(review): ``model.predict`` is a synchronous call inside an ``async``
    # handler, so inference blocks the event loop while it runs.
    prediction, pred_idx, probs = model.predict(img)

    if not prediction:
        # A ``{"message": ...}`` payload does not look like a
        # ClassificationResult, so response validation would presumably turn
        # it into an opaque 500; report the failure explicitly instead.
        from fastapi import HTTPException

        raise HTTPException(
            status_code=422,
            detail="The model returned no prediction for this image.",
        )

    # Convert to a plain Python float *before* rounding: rounding a numpy
    # float32 scalar keeps float32 precision, which then serialises with noise
    # (0.123 -> 0.12300000339...). ``float()`` also works whether ``probs`` is
    # a tensor (CPU or GPU) or an ndarray, unlike ``.numpy()``.
    score = round(float(probs[pred_idx]), 3)
    return {"prediction": prediction, "score": score}
| |
|
| |
|
| |
|
# Mount every endpoint registered on the router onto the application.
app.include_router(api_router)


if __name__ == "__main__":
    # uvicorn is imported here so it is only required when this file is run
    # as a script, not when ``app`` is imported by another module.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",  # bind on all network interfaces
        port=7860,
        log_level="debug",
    )
| |
|