"""Main script: REST API initialization and endpoints.

This module defines the FastAPI application for:
- Retrieving available ML models and supported programming languages
- Classifying code comments using multi-label classification

The API is designed to be client-agnostic and supports concurrent requests
through asynchronous endpoint handlers and background thread execution
for CPU-bound ML inference tasks.
"""
| |
|
| | import asyncio |
| | from concurrent.futures import ThreadPoolExecutor |
| | from contextlib import asynccontextmanager |
| | from datetime import datetime |
| | from functools import wraps |
| | from http import HTTPStatus |
| | from typing import Any, Callable, Dict |
| |
|
| | from fastapi import FastAPI, Request |
| | from fastapi.middleware.cors import CORSMiddleware |
| |
|
| | from nygaardcodecommentclassification import config |
| | from nygaardcodecommentclassification.api.controllers import PredictionController |
| | from nygaardcodecommentclassification.api.schemas import PredictionRequest |
| |
|
| | |
| | |
| | |
| | |
# Module-level singleton that owns model loading and inference; its
# lifecycle is driven by the `lifespan` handler below.
controller = PredictionController()
| |
|
| | |
| | |
# Thread pool used to run CPU-bound ML inference off the event loop so
# the async endpoints stay responsive under concurrent requests.
_executor = ThreadPoolExecutor(max_workers=4)
| |
|
| |
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application startup and shutdown.

    Startup loads the ML models into memory via the controller so the
    first requests are served quickly; shutdown releases the controller's
    resources and drains the inference thread pool, waiting for any
    in-flight tasks to finish.

    Args:
        app: The FastAPI application instance (required by the lifespan
            protocol; not used directly here).

    Yields:
        Control back to the application after startup is complete.
    """
    # --- startup ---
    controller.startup()
    yield
    # --- shutdown ---
    controller.shutdown()
    _executor.shutdown(wait=True)
| |
|
| |
|
| | |
| | |
| | |
# FastAPI application instance; `lifespan` wires model loading/teardown
# into the server's startup and shutdown events.
app = FastAPI(
    title="Nygaard Code Comment Classification API",
    description="""
    Multi-label classification API for code comments.
    """,
    version="1.0",
    lifespan=lifespan,
)
| |
|
| | |
| | |
| | |
| | |
# Allow browser clients from any origin, keeping the API client-agnostic.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# very permissive — consider restricting origins before public deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| |
|
| |
|
| | |
| | |
| | |
def construct_response(f: Callable) -> Callable:
    """Decorator enforcing a uniform JSON envelope for every endpoint.

    Wrapped endpoints always return a dict carrying request metadata
    (timestamp, HTTP method, URL) plus a status code, message, and — on
    success — the payload. Errors are mapped centrally: ``ValueError``
    becomes HTTP 400, any other exception becomes HTTP 500.

    Args:
        f: The async endpoint function to wrap. Its first positional
            argument must be the FastAPI ``Request``.

    Returns:
        An async wrapper returning the standardized response dict.

    Response Structure:
        {
            "timestamp": "ISO 8601 timestamp",
            "method": "HTTP method (GET, POST, etc.)",
            "url": "Full request URL",
            "status-code": "HTTP status code",
            "message": "Status message or error description",
            "data": "Response payload (if successful)"
        }
    """

    @wraps(f)
    async def wrap(request: Request, *args, **kwargs) -> Dict[str, Any]:
        # Metadata recorded for every response, success or failure.
        envelope: Dict[str, Any] = {
            "timestamp": datetime.now().isoformat(),
            "method": request.method,
            "url": str(request.url),
        }

        try:
            payload = await f(request, *args, **kwargs)
        except ValueError as exc:
            # Client-side error: invalid input reported by the endpoint.
            envelope["status-code"] = HTTPStatus.BAD_REQUEST
            envelope["message"] = str(exc)
        except Exception as exc:
            # Boundary catch-all: surface unexpected failures as HTTP 500.
            envelope["status-code"] = HTTPStatus.INTERNAL_SERVER_ERROR
            envelope["message"] = f"Internal Server Error: {str(exc)}"
        else:
            if isinstance(payload, dict) and "status-code" in payload:
                # Endpoint supplied its own status/message/data; merge as-is.
                envelope.update(payload)
            else:
                # Bare payload: wrap it in a default 200 OK envelope.
                envelope["status-code"] = HTTPStatus.OK
                envelope["message"] = HTTPStatus.OK.phrase
                envelope["data"] = payload

        return envelope

    return wrap
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@app.get("/models", tags=["Info"])
@construct_response
async def _get_models(request: Request) -> Dict[str, Any]:
    """Retrieve the available ML models, grouped by language.

    Returns:
        Dict containing:
            - status-code: HTTP 200 on success
            - message: Status description
            - data: Dict mapping languages to available model types

    Example Response:
        {
            "java": ["catboost"],
            "python": ["catboost"],
            "pharo": ["catboost"]
        }
    """
    models_by_language = controller.get_models_info()
    return {
        "status-code": HTTPStatus.OK,
        "message": "Available models retrieved",
        "data": models_by_language,
    }
| |
|
| |
|
@app.get("/languages", tags=["Info"])
@construct_response
async def _get_languages(request: Request) -> Dict[str, Any]:
    """Retrieve the supported programming languages.

    Each listed language has its own trained classification model.

    Returns:
        Dict containing:
            - status-code: HTTP 200 on success
            - message: Status description
            - data: Dict with "languages" key listing supported languages

    Example Response:
        {
            "languages": ["java", "python", "pharo"]
        }
    """
    supported = {"languages": config.LANGUAGES}
    return {
        "status-code": HTTPStatus.OK,
        "message": "Supported languages retrieved",
        "data": supported,
    }
| |
|
| |
|
@app.post("/predict", tags=["Prediction"])
@construct_response
async def _predict(request: Request, payload: PredictionRequest) -> Dict[str, Any]:
    """Classify code comments using multi-label classification.

    Performs ML inference to classify each submitted code comment into
    zero or more categories. Inference runs in the module-level thread
    pool so the event loop is not blocked by CPU-bound work.

    Args:
        request: The FastAPI request object.
        payload: PredictionRequest containing:
            - texts: List of code comments to classify
            - language: Programming language ("java", "python", "pharo")
            - model_type: Model to use (default: "catboost")

    Returns:
        Dict containing:
            - status-code: HTTP 200 on success, 400 on invalid input, 500 on error
            - message: Status description
            - data: Dict with model_used, language, and results list

    Example Request:
        POST /predict
        {
            "texts": ["This method calculates fibonacci", "this is a deprecated function"],
            "language": "java",
            "model_type": "catboost"
        }

    Example Response:
        {
            "results": [
                {"text": "This method calculates fibonacci", "labels": ["summary"]},
                {"text": "this is a deprecated function", "labels": ["deprecation"]}
            ]
        }
    """
    # get_running_loop() is the supported call inside a coroutine;
    # get_event_loop() is deprecated here since Python 3.10.
    loop = asyncio.get_running_loop()
    # Offload CPU-bound inference to the worker pool so concurrent
    # requests can still be served while the model runs.
    results = await loop.run_in_executor(
        _executor, controller.predict, payload.texts, payload.language, payload.model_type
    )

    return {
        "status-code": HTTPStatus.OK,
        "message": "Prediction successful",
        "data": {
            "model_used": payload.model_type,
            "language": payload.language,
            "results": results,
        },
    }
| |
|
| |
|
| | |
| | |
| | |
if __name__ == "__main__":
    # Launch a development server when this module is executed directly.
    # Import locally so uvicorn is only required for standalone runs.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
| |
|