|
|
"""Main API for Code Comment Classification using FastAPI.""" |
|
|
from contextlib import asynccontextmanager |
|
|
from datetime import datetime |
|
|
from functools import lru_cache, wraps |
|
|
from http import HTTPStatus |
|
|
import inspect |
|
|
import logging |
|
|
import os |
|
|
from pathlib import Path |
|
|
|
|
|
from api.schemas import PredictRequest |
|
|
from api.sync_models import sync_best_models_to_disk |
|
|
from fastapi import FastAPI, Request, Response |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.responses import JSONResponse |
|
|
|
|
|
from codecommentclassification import ModelPredictor |
|
|
|
|
|
# Root directory holding the on-disk champion models served by this API.
# Overridable via the MODELS_DIR environment variable; defaults to models/api.
MODELS_DIR = Path(os.getenv("MODELS_DIR", "models/api"))


# Configure root logging once at import time so all modules share the same
# timestamped format.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
# Module-level logger named after this module, per logging convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
@lru_cache(maxsize=3)
def get_predictor(lang: str, model_type: str) -> ModelPredictor:
    """Lazily load and cache a ``ModelPredictor`` for ``(lang, model_type)``.

    The LRU cache keeps at most three predictors alive at once; requesting a
    fourth distinct combination evicts the least recently used one. The heavy
    model files are only read on the first call for a given pair.

    Args:
        lang: Language identifier of the model to load.
        model_type: Model variant identifier.

    Returns:
        A ``ModelPredictor`` rooted at ``MODELS_DIR`` (exact on-disk layout is
        determined by ``ModelPredictor`` itself).
    """
    # Lazy %-style args: formatting is skipped entirely if INFO is disabled.
    logger.info("Loading model for %s - %s...", lang, model_type)
    return ModelPredictor(lang=lang, model_type=model_type, model_root=str(MODELS_DIR))
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: sync models on startup, clear caches on shutdown.

    Startup performs a best-effort sync of the champion models from MLflow
    into MODELS_DIR — a failure is logged but never prevents the API from
    starting. Shutdown clears the predictor LRU cache so loaded models can
    be released.
    """
    try:
        logger.info(f"Syncing champion models from MLflow to {MODELS_DIR}...")
        sync_best_models_to_disk(models_root=MODELS_DIR.parent, api_subdir=MODELS_DIR.name)
    except Exception as e:
        logger.error(f"Failed to sync models at startup: {e}")

    # Report which models directory will actually be used (if any).
    if MODELS_DIR.exists():
        logger.info(f"Using models from: {MODELS_DIR.resolve()}")
    else:
        logger.warning(f"Models directory not found at: {MODELS_DIR.resolve()}")

    yield

    # Shutdown: evict every cached predictor.
    get_predictor.cache_clear()
|
|
|
|
|
|
|
|
# FastAPI application instance; `lifespan` handles the model sync at startup
# and predictor-cache cleanup at shutdown.
app = FastAPI(
    title="Code Comment Classification API",
    description="API for classifying code comments using SetFit models.",
    version="0.1",
    lifespan=lifespan,
)
|
|
|
|
|
# CORS configuration. Allowed origins come from the FRONTEND_ORIGINS env var
# (comma-separated); when unset, fall back to the local development servers.
frontend_origins = os.getenv("FRONTEND_ORIGINS")

if frontend_origins:
    # Split on commas, trimming whitespace and dropping empty entries.
    origins = [entry.strip() for entry in frontend_origins.split(",") if entry.strip()]
else:
    origins = ["http://localhost:5173", "http://127.0.0.1:5173", "http://localhost"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
def _build_response(results: dict, request: Request):
    """Wrap an endpoint's result dict in the standard response envelope.

    Ready-made ``Response``/``JSONResponse`` objects pass through untouched;
    plain dicts are wrapped with request method, status code, timestamp and
    URL metadata, plus the optional ``data`` payload.
    """
    # Endpoints may short-circuit with a pre-built response object.
    if isinstance(results, (Response, JSONResponse)):
        return results

    envelope = {
        "message": results["message"],
        "method": request.method,
        "status-code": results["status-code"],
        "timestamp": datetime.now().isoformat(),
        "url": request.url._url,
    }

    # Membership test (not .get) so an explicit None payload is preserved.
    if "data" in results:
        envelope["data"] = results["data"]

    return envelope
|
|
|
|
|
|
|
|
def construct_response(f):
    """Construct a JSON response for an endpoint's results (sync and async).

    The wrapper mirrors the coroutine-ness of the decorated endpoint so
    FastAPI dispatches it correctly, and funnels every result through
    ``_build_response`` to produce the standard envelope.
    """
    if inspect.iscoroutinefunction(f):

        @wraps(f)
        async def wrap(request: Request, *args, **kwargs):
            return _build_response(await f(request, *args, **kwargs), request)

    else:

        @wraps(f)
        def wrap(request: Request, *args, **kwargs):
            return _build_response(f(request, *args, **kwargs), request)

    return wrap
|
|
|
|
|
|
|
|
@app.get("/", tags=["General"])
@construct_response
def _index(request: Request):
    """Root endpoint: greet the caller and point them at /docs."""
    welcome = {
        "message": "Welcome to the Code Comment Classification API! Please use /docs for API documentation."
    }
    return {
        "message": HTTPStatus.OK.phrase,
        "status-code": HTTPStatus.OK,
        "data": welcome,
    }
|
|
|
|
|
|
|
|
@app.get("/privacy", tags=["General"])
@construct_response
async def get_privacy_notice(request: Request):
    """Return the Privacy Notice for the API."""
    notice = {
        "policy": "This API processes text data for classification purposes only. No data is permanently stored.",
        "compliance_link": "https://behavizapi.peopleware.ai/api/docs#section/Getting-Started/Privacy-Notice",
    }
    return {"message": "Privacy Notice", "status-code": HTTPStatus.OK, "data": notice}
|
|
|
|
|
|
|
|
@app.get("/status")
def get_status():
    """Liveness endpoint: confirms the API process is up and serving."""
    status_payload = {"status": "API is running"}
    return status_payload
|
|
|
|
|
|
|
|
@app.get("/models", tags=["Prediction"])
@construct_response
def _get_models_list(request: Request):
    """Return the list of available languages based on directories found in models/ ."""
    available_languages = []
    if MODELS_DIR.exists():
        # On-disk layout is MODELS_DIR/<language>/<model_type>/.
        for lang_dir in MODELS_DIR.iterdir():
            if not lang_dir.is_dir():
                continue
            for type_dir in lang_dir.iterdir():
                if type_dir.is_dir():
                    available_languages.append(
                        {"language": lang_dir.name, "model_types": type_dir.name}
                    )

    return {
        "message": HTTPStatus.OK.phrase,
        "status-code": HTTPStatus.OK,
        "data": available_languages,
    }
|
|
|
|
|
|
|
|
@app.post("/predict", tags=["Prediction"])
@construct_response
def predict(
    request: Request,
    payload: PredictRequest,
):
    """Inference endpoint.

    Classifies ``payload.text`` with the cached predictor selected by
    ``payload.language`` / ``payload.model_type``. Errors map to envelope
    status codes: 400 for a missing model type or invalid input, 404 when
    the model files are absent, 500 for anything unexpected.

    All error paths are now logged server-side so failures are diagnosable
    from the application logs, not only from client responses.
    """
    if payload.model_type is None:
        return {
            "message": "Model type must be specified.",
            "status-code": HTTPStatus.BAD_REQUEST,
        }

    try:
        predictor = get_predictor(payload.language.value, payload.model_type.value)
        result = predictor.predict(payload.text)
        # numpy-style arrays are not JSON-serializable; fall back to the raw
        # value for predictors that already return plain Python containers.
        predictions_list = result.tolist() if hasattr(result, "tolist") else result

        return {
            "message": HTTPStatus.OK.phrase,
            "status-code": HTTPStatus.OK,
            "data": {
                "language": payload.language,
                "model_type": payload.model_type,
                "predictions": predictions_list,
            },
        }

    except FileNotFoundError:
        # Log server-side so missing model directories are diagnosable.
        logger.warning(
            "Model files not found for %s/%s", payload.language, payload.model_type
        )
        return {
            "message": f"Model for language '{payload.language}' not found.",
            "status-code": HTTPStatus.NOT_FOUND,
        }
    except ValueError as e:
        logger.warning("Invalid prediction request: %s", e)
        return {
            "message": str(e),
            "status-code": HTTPStatus.BAD_REQUEST,
        }
    except Exception as e:
        # logger.exception records the full traceback; the client response
        # carries only the message, as before.
        logger.exception("Unhandled error during prediction")
        return {
            "message": f"Internal Error: {str(e)}",
            "status-code": HTTPStatus.INTERNAL_SERVER_ERROR,
        }
|
|
|