Sky-Blue-da-ba-dee's picture
added files
ac9ddbb
raw
history blame
6.38 kB
"""Main API for Code Comment Classification using FastAPI."""
from contextlib import asynccontextmanager
from datetime import datetime
from functools import lru_cache, wraps
from http import HTTPStatus
import inspect
import logging
import os
from pathlib import Path
from api.schemas import PredictRequest
from api.sync_models import sync_best_models_to_disk
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from codecommentclassification import ModelPredictor
# Root directory holding per-language model artifacts; override via the
# MODELS_DIR environment variable (defaults to the repo-local models/api).
MODELS_DIR = Path(os.getenv("MODELS_DIR", "models/api"))

# One-time process-wide logging configuration for the API.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
@lru_cache(maxsize=3)
def get_predictor(lang: str, model_type: str) -> ModelPredictor:
    """Lazily load and cache a predictor for a (lang, model_type) pair.

    At most 3 heavy models are kept resident at once; the least-recently-used
    entry is evicted when a fourth combination is requested.

    Args:
        lang: Language identifier matching a directory under ``MODELS_DIR``.
        model_type: Model-type identifier under the language directory.

    Returns:
        A ``ModelPredictor`` bound to the on-disk model artifacts.
    """
    # Lazy %-style args avoid building the message when INFO is disabled.
    logger.info("Loading model for %s - %s...", lang, model_type)
    return ModelPredictor(lang=lang, model_type=model_type, model_root=str(MODELS_DIR))
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager: sync models at startup, clear cache at shutdown."""
    try:
        logger.info("Syncing champion models from MLflow to %s...", MODELS_DIR)
        sync_best_models_to_disk(
            models_root=MODELS_DIR.parent,
            api_subdir=MODELS_DIR.name,
        )
    except Exception:
        # Best-effort sync: the API can still serve models already on disk.
        # logger.exception records the full traceback, unlike logger.error(f"{e}").
        logger.exception("Failed to sync models at startup")
    if not MODELS_DIR.exists():
        logger.warning("Models directory not found at: %s", MODELS_DIR.resolve())
    else:
        logger.info("Using models from: %s", MODELS_DIR.resolve())
    yield
    # Shutdown: drop cached predictors so model memory is released.
    get_predictor.cache_clear()
# FastAPI application instance; `lifespan` syncs models at startup and clears
# the predictor cache at shutdown.
app = FastAPI(
    title="Code Comment Classification API",
    description="API for classifying code comments using SetFit models.",
    version="0.1",
    lifespan=lifespan,
)
# Allowed CORS origins: comma-separated FRONTEND_ORIGINS env var when set,
# otherwise local development defaults (Vite dev server and localhost).
frontend_origins = os.getenv("FRONTEND_ORIGINS")
if frontend_origins:
    origins = [o.strip() for o in frontend_origins.split(",") if o.strip()]
else:
    # Development defaults
    origins = [
        "http://localhost:5173",
        "http://127.0.0.1:5173",
        "http://localhost",
    ]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def _build_response(results: dict, request: Request):
    """Wrap an endpoint's result dict in the API's standard JSON envelope.

    If the endpoint already returned a ``Response``/``JSONResponse``, it is
    passed through unchanged so handlers can fully control the reply.

    Args:
        results: Endpoint result containing at least "message" and
            "status-code", optionally "data".
        request: The incoming request, used for method/URL metadata.

    Returns:
        The envelope dict, or the pass-through ``Response`` object.
    """
    if isinstance(results, (Response, JSONResponse)):
        return results
    response = {
        "message": results["message"],
        "method": request.method,
        "status-code": results["status-code"],
        # NOTE(review): naive local timestamp — a timezone-aware UTC stamp
        # would be safer, but is kept as-is to preserve the response format.
        "timestamp": datetime.now().isoformat(),
        # str(URL) is the public Starlette API; the original used the
        # private `request.url._url`, which yields the same string.
        "url": str(request.url),
    }
    if "data" in results:
        response["data"] = results["data"]
    return response
def construct_response(f):
    """Decorate an endpoint so its result dict gets the standard JSON envelope.

    Handles both plain and coroutine endpoint functions transparently.
    """
    if not inspect.iscoroutinefunction(f):
        @wraps(f)
        def wrap(request: Request, *args, **kwargs):
            return _build_response(f(request, *args, **kwargs), request)

        return wrap

    @wraps(f)
    async def wrap(request: Request, *args, **kwargs):
        return _build_response(await f(request, *args, **kwargs), request)

    return wrap
@app.get("/", tags=["General"])
@construct_response
def _index(request: Request):
    """Root endpoint: points callers at the interactive API docs."""
    welcome = (
        "Welcome to the Code Comment Classification API! "
        "Please use /docs for API documentation."
    )
    return {
        "message": HTTPStatus.OK.phrase,
        "status-code": HTTPStatus.OK,
        "data": {"message": welcome},
    }
@app.get("/privacy", tags=["General"])
@construct_response
async def get_privacy_notice(request: Request):
    """Return the Privacy Notice for the API."""
    notice = {
        "policy": (
            "This API processes text data for classification purposes only. "
            "No data is permanently stored."
        ),
        "compliance_link": (
            "https://behavizapi.peopleware.ai/api/docs"
            "#section/Getting-Started/Privacy-Notice"
        ),
    }
    return {
        "message": "Privacy Notice",
        "status-code": HTTPStatus.OK,
        "data": notice,
    }
@app.get("/status")
def get_status():
    """Liveness endpoint: confirms the API process is up."""
    status_payload = {"status": "API is running"}
    return status_payload
@app.get("/models", tags=["Prediction"])
@construct_response
def _get_models_list(request: Request):
    """Return the list of available languages based on directories found in models/ ."""
    # Models load lazily, so availability is discovered by scanning the
    # on-disk layout: MODELS_DIR/<language>/<model_type>/.
    available_languages = []
    if MODELS_DIR.exists():
        for lang_dir in MODELS_DIR.iterdir():
            if not lang_dir.is_dir():
                continue
            for type_dir in lang_dir.iterdir():
                if type_dir.is_dir():
                    available_languages.append(
                        {"language": lang_dir.name, "model_types": type_dir.name}
                    )
    return {
        "message": HTTPStatus.OK.phrase,
        "status-code": HTTPStatus.OK,
        "data": available_languages,
    }
@app.post("/predict", tags=["Prediction"])
@construct_response
def predict(
    request: Request,
    payload: PredictRequest,
):
    """Inference endpoint.

    Loads (or reuses from cache) the predictor for the requested
    language/model type and classifies the submitted text. Errors are
    returned as envelope dicts rather than raised, so the response
    decorator formats them uniformly.
    """
    if payload.model_type is None:
        return {
            "message": "Model type must be specified.",
            "status-code": HTTPStatus.BAD_REQUEST,
        }
    try:
        predictor = get_predictor(payload.language.value, payload.model_type.value)
        result = predictor.predict(payload.text)
        # Array-like results (e.g. numpy) are not JSON-serializable; convert.
        predictions_list = result.tolist() if hasattr(result, "tolist") else result
        return {
            "message": HTTPStatus.OK.phrase,
            "status-code": HTTPStatus.OK,
            "data": {
                "language": payload.language,
                "model_type": payload.model_type,
                "predictions": predictions_list,
            },
        }
    except FileNotFoundError:
        return {
            "message": f"Model for language '{payload.language}' not found.",
            "status-code": HTTPStatus.NOT_FOUND,
        }
    except ValueError as e:
        # Presumably raised by the predictor for invalid input — confirm
        # against ModelPredictor before relying on this mapping.
        return {
            "message": str(e),
            "status-code": HTTPStatus.BAD_REQUEST,
        }
    except Exception as e:
        # Log the full traceback server-side; the original silently
        # swallowed it, leaving only the client-facing message.
        logger.exception("Unhandled error during prediction")
        return {
            "message": f"Internal Error: {str(e)}",
            "status-code": HTTPStatus.INTERNAL_SERVER_ERROR,
        }