"""Main API for Code Comment Classification using FastAPI.""" from contextlib import asynccontextmanager from datetime import datetime from functools import lru_cache, wraps from http import HTTPStatus import inspect import logging import os from pathlib import Path from api.schemas import PredictRequest from api.sync_models import sync_best_models_to_disk from fastapi import FastAPI, Request, Response from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from codecommentclassification import ModelPredictor MODELS_DIR = Path(os.getenv("MODELS_DIR", "models/api")) logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) @lru_cache(maxsize=3) def get_predictor(lang: str, model_type: str) -> ModelPredictor: """Lazily loads the heavy model only when requested.""" logger.info(f"Loading model for {lang} - {model_type}...") return ModelPredictor(lang=lang, model_type=model_type, model_root=str(MODELS_DIR)) @asynccontextmanager async def lifespan(app: FastAPI): """Lifespan context manager to sync models at startup.""" try: logger.info(f"Syncing champion models from MLflow to {MODELS_DIR}...") sync_best_models_to_disk( models_root=MODELS_DIR.parent, api_subdir=MODELS_DIR.name, ) except Exception as e: logger.error(f"Failed to sync models at startup: {e}") if not MODELS_DIR.exists(): logger.warning(f"Models directory not found at: {MODELS_DIR.resolve()}") else: logger.info(f"Using models from: {MODELS_DIR.resolve()}") yield get_predictor.cache_clear() app = FastAPI( title="Code Comment Classification API", description="API for classifying code comments using SetFit models.", version="0.1", lifespan=lifespan, ) frontend_origins = os.getenv("FRONTEND_ORIGINS") if frontend_origins: origins = [o.strip() for o in frontend_origins.split(",") if o.strip()] else: # default di sviluppo origins = [ "http://localhost:5173", "http://127.0.0.1:5173", "http://localhost", ] app.add_middleware( CORSMiddleware, allow_origins=origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) def _build_response(results: dict, request: Request): if isinstance(results, (Response, JSONResponse)): return results response = { "message": results["message"], "method": request.method, "status-code": results["status-code"], "timestamp": datetime.now().isoformat(), "url": request.url._url, } if "data" in results: response["data"] = results["data"] return response def construct_response(f): """Construct a JSON response for an endpoint's results (sync and async).""" if inspect.iscoroutinefunction(f): @wraps(f) async def wrap(request: Request, *args, **kwargs): results = await f(request, *args, **kwargs) return _build_response(results, request) else: @wraps(f) def wrap(request: Request, *args, **kwargs): results = f(request, *args, **kwargs) return _build_response(results, request) return wrap @app.get("/", tags=["General"]) @construct_response def _index(request: Request): """Root endpoint.""" return { "message": HTTPStatus.OK.phrase, "status-code": HTTPStatus.OK, "data": { "message": "Welcome to the Code Comment Classification API! Please use /docs for API documentation." }, } @app.get("/privacy", tags=["General"]) @construct_response async def get_privacy_notice(request: Request): """Return the Privacy Notice for the API.""" return { "message": "Privacy Notice", "status-code": HTTPStatus.OK, "data": { "policy": "This API processes text data for classification purposes only. 
No data is permanently stored.", "compliance_link": "https://behavizapi.peopleware.ai/api/docs#section/Getting-Started/Privacy-Notice", }, } @app.get("/status") def get_status(): """Endpoint to check if the API is running.""" return {"status": "API is running"} @app.get("/models", tags=["Prediction"]) @construct_response def _get_models_list(request: Request): """Return the list of available languages based on directories found in models/ .""" # Since we aren't pre-loading, we scan the directory to see what IS available if MODELS_DIR.exists(): available_languages = [ {"language": d.name, "model_types": mt.name} for d in MODELS_DIR.iterdir() if d.is_dir() for mt in d.iterdir() if mt.is_dir() ] else: available_languages = [] return { "message": HTTPStatus.OK.phrase, "status-code": HTTPStatus.OK, "data": available_languages, } @app.post("/predict", tags=["Prediction"]) @construct_response def predict( request: Request, payload: PredictRequest, ): """Inference endpoint.""" if payload.model_type is None: return { "message": "Model type must be specified.", "status-code": HTTPStatus.BAD_REQUEST, } try: predictor = get_predictor(payload.language.value, payload.model_type.value) result = predictor.predict(payload.text) predictions_list = result.tolist() if hasattr(result, "tolist") else result return { "message": HTTPStatus.OK.phrase, "status-code": HTTPStatus.OK, "data": { "language": payload.language, "model_type": payload.model_type, "predictions": predictions_list, }, } except FileNotFoundError: return { "message": f"Model for language '{payload.language}' not found.", "status-code": HTTPStatus.NOT_FOUND, } except ValueError as e: return { "message": str(e), "status-code": HTTPStatus.BAD_REQUEST, } except Exception as e: return { "message": f"Internal Error: {str(e)}", "status-code": HTTPStatus.INTERNAL_SERVER_ERROR, }
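

# Optional local entry point: a minimal convenience sketch, not part of the
# original module. It assumes uvicorn is installed; the host and port values
# below are illustrative defaults and can be adjusted for your environment.
if __name__ == "__main__":
    import uvicorn

    # Serve the app directly for local development (production deployments
    # would typically invoke uvicorn/gunicorn externally instead).
    uvicorn.run(app, host="0.0.0.0", port=8000)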