|
|
"""Main script: REST API initialization and endpoints. |
|
|
|
|
|
This module defines the FastAPI application for: |
|
|
- Retrieving available ML models and supported programming languages |
|
|
- Classifying code comments using multi-label classification |
|
|
|
|
|
The API is designed to be client-agnostic and supports concurrent requests |
|
|
through asynchronous endpoint handlers and background thread execution |
|
|
for CPU-bound ML inference tasks. |
|
|
""" |
|
|
|
|
|
import asyncio
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime
from functools import wraps
from http import HTTPStatus
import json
import logging
from typing import Any, AsyncIterator, Callable, Dict

from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse

from nygaardcodecommentclassification import config
from nygaardcodecommentclassification.api.controllers import PredictionController
from nygaardcodecommentclassification.api.schemas import PredictionRequest
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# A single controller instance shared by every endpoint in this module; lifespan()
# invokes its startup() and shutdown() hooks around the serving period.
controller: PredictionController = PredictionController()

# Worker threads used by /predict to run the blocking controller.predict call,
# keeping the event loop free to serve other requests in the meantime.
_executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Async context manager for application lifecycle events.

    Startup: calls ``controller.startup()`` (expected to load the ML models into
    memory) and creates the thread pool used by /predict for inference.
    Shutdown: calls ``controller.shutdown()`` and then shuts the thread pool down,
    waiting for in-flight inference jobs to finish. Cleanup runs even if the
    application or ``controller.shutdown()`` raises.

    Args:
        app: The FastAPI application instance (not used directly; required by the
            lifespan protocol).

    Yields:
        None: control returns to the application once startup is complete.
    """
    global _executor

    controller.startup()

    # A ThreadPoolExecutor cannot be restarted once shut down. Create a fresh pool on
    # every startup; otherwise a second lifespan cycle in the same process (e.g.
    # several ``with TestClient(app)`` blocks in one test session) would make every
    # /predict call fail with "cannot schedule new futures after shutdown".
    _executor = ThreadPoolExecutor(max_workers=4)

    try:
        yield
    finally:
        try:
            controller.shutdown()
        finally:
            # Release the pool even if controller.shutdown() raised.
            _executor.shutdown(wait=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# The FastAPI application. `lifespan` wraps the serving period with the controller's
# startup/shutdown hooks; title, description and version populate the generated
# OpenAPI schema and the interactive docs that "/" redirects to.
# NOTE: the multi-line `description` literal below is runtime data (shown in the docs
# page) — edit its text/whitespace deliberately.
app = FastAPI(


    title="Nygaard Code Comment Classification API",


    description="""


    Multi-label classification API for code comments.


    """,


    version="1.0",


    lifespan=lifespan,


)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Cross-origin policy: any origin may call the API with any method and any header.
# NOTE(review): allow_credentials=True combined with a wildcard origin is a risky
# pairing. With credentials allowed, Starlette's CORSMiddleware echoes the caller's
# Origin back instead of "*", so any website can make credentialed requests. Nothing
# in this module reads cookies or auth headers. Confirm whether credentials are really
# needed; if not, prefer allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def construct_response(f: Callable) -> Callable:
    """Decorator to enforce a consistent JSON response structure across all endpoints.

    The wrapped endpoint is awaited and its outcome is folded into an envelope:

    - If it returns a dict containing a ``"status-code"`` key, that dict is merged
      into the envelope as-is.
    - Any other return value becomes ``data`` with status 200 / message "OK".
    - ``ValueError`` is treated as a client error (status 400, message = ``str(e)``).
    - Any other ``Exception`` is treated as a server error (status 500) and its
      traceback is logged.

    Args:
        f: The async endpoint function to wrap. Its first argument must be the request.

    Returns:
        An async wrapper (with ``f``'s metadata preserved via ``functools.wraps``)
        that returns the envelope dict.

    Response Structure:
        {
            "timestamp": "ISO 8601 timestamp (local time, no offset)",
            "method": "HTTP method (GET, POST, etc.)",
            "url": "Full request URL",
            "status-code": "HTTP status code",
            "message": "Status message or error description",
            "data": "Response payload (only on success)"
        }

    NOTE(review): "status-code" is only written into the body. The wrapper returns a
    plain dict, so FastAPI sends it with the route's default HTTP status (200 for the
    routes in this file) even for 400/500 envelopes. Clients must inspect the body.
    The 500 message also echoes the raw exception text to the client.
    """
    logger = logging.getLogger(__name__)

    @wraps(f)
    async def wrap(request: Request, *args, **kwargs) -> Dict[str, Any]:
        response_struct: Dict[str, Any] = {
            "timestamp": datetime.now().isoformat(),
            "method": request.method,
            "url": str(request.url),
        }

        try:
            results = await f(request, *args, **kwargs)
        except ValueError as e:
            # Invalid client input: report it in the envelope as a 400.
            logger.info("Rejected %s %s: %s", request.method, request.url, e)
            response_struct["status-code"] = HTTPStatus.BAD_REQUEST
            response_struct["message"] = str(e)
        except Exception as e:
            # Unexpected failure: keep the traceback in the server logs, which would
            # otherwise be lost once the exception is converted into a 500 envelope.
            logger.exception(
                "Unhandled error while handling %s %s", request.method, request.url
            )
            response_struct["status-code"] = HTTPStatus.INTERNAL_SERVER_ERROR
            response_struct["message"] = f"Internal Server Error: {str(e)}"
        else:
            if isinstance(results, dict) and "status-code" in results:
                # The endpoint built its own status/message/data: merge them in.
                response_struct.update(results)
            else:
                response_struct["status-code"] = HTTPStatus.OK
                response_struct["message"] = HTTPStatus.OK.phrase
                response_struct["data"] = results

        return response_struct

    return wrap
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/models", tags=["Info"])
@construct_response
async def _get_models(request: Request) -> Dict[str, Any]:
    """List the ML models the controller reports as available.

    The payload is whatever ``controller.get_models_info()`` returns; it is
    expected to map each language to its available model types.

    Returns:
        Dict with:
            - status-code: HTTP 200
            - message: "Available models retrieved"
            - data: the controller's models info

    Example Response (illustrative ``data``):
        {
            "java": ["catboost"],
            "python": ["catboost"],
            "pharo": ["catboost"]
        }
    """
    models_by_language = controller.get_models_info()

    envelope: Dict[str, Any] = {
        "status-code": HTTPStatus.OK,
        "message": "Available models retrieved",
    }
    envelope["data"] = models_by_language
    return envelope
|
|
|
|
|
|
|
|
@app.get("/languages", tags=["Info"])
@construct_response
async def _get_languages(request: Request) -> Dict[str, Any]:
    """List the programming languages supported for comment classification.

    The list comes straight from ``config.LANGUAGES``.

    Returns:
        Dict with:
            - status-code: HTTP 200
            - message: "Supported languages retrieved"
            - data: {"languages": <config.LANGUAGES>}

    Example Response (illustrative ``data``):
        {
            "languages": ["java", "python", "pharo"]
        }
    """
    supported = config.LANGUAGES
    return {
        "status-code": HTTPStatus.OK,
        "message": "Supported languages retrieved",
        "data": {"languages": supported},
    }
|
|
|
|
|
|
|
|
@app.post("/predict", tags=["Prediction"])
@construct_response
async def _predict(
    request: Request, response: Response, payload: PredictionRequest
) -> Dict[str, Any]:
    """Classify code comments using multi-label classification.

    The blocking ``controller.predict`` call runs on the module's thread pool,
    not on the event loop. The model, the language and the flattened list of
    predicted labels are also echoed in the ``X-model``, ``X-language`` and
    ``X-predicted-labels`` response headers.

    Args:
        request: The incoming request (used by ``construct_response`` for the envelope).
        response: The FastAPI response object; the headers above are added to it.
        payload: PredictionRequest containing:
            - texts: list of code comments
            - class_names: list of class names, one per comment
            - language: programming language (e.g. "java", "python", "pharo")
            - model_type: model to use (documented default: "catboost")

    Returns:
        Dict with:
            - status-code: HTTP 200 on success
            - message: "Prediction successful"
            - data: {"model_used", "language", "results"}
        If ``controller.predict`` raises ValueError, ``construct_response`` reports
        400. Any other exception is reported as 500.

    Example Request:
        POST /predict
        {
            "texts": ["This method calculates fibonacci", "this is a deprecated function"],
            "class_names": ["MathUtils", "Utils"],
            "language": "java",
            "model_type": "catboost"
        }

    Example Response (illustrative ``data.results``):
        {
            "results": [
                {"text": "This method calculates fibonacci", "class_name": "MathUtils", "labels": ["summary"]},
                {"text": "this is a deprecated function", "class_name": "Utils", "labels": ["deprecation"]}
            ]
        }
    """
    # get_running_loop() is the documented way to obtain the loop inside a coroutine.
    # get_event_loop() carries extra policy-dependent behaviour.
    loop = asyncio.get_running_loop()
    results = await loop.run_in_executor(
        _executor,
        controller.predict,
        payload.texts,
        payload.class_names,
        payload.language,
        payload.model_type,
    )

    response.headers["X-model"] = payload.model_type
    response.headers["X-language"] = payload.language

    # Flatten every comment's labels into a single JSON list for the header.
    # NOTE(review): this header grows with the batch size. Very large batches may
    # exceed proxy/server header-size limits.
    all_labels = [label for result in results for label in result["labels"]]
    response.headers["X-predicted-labels"] = json.dumps(all_labels)

    return {
        "status-code": HTTPStatus.OK,
        "message": "Prediction successful",
        "data": {
            "model_used": payload.model_type,
            "language": payload.language,
            "results": results,
        },
    }
|
|
|
|
|
|
|
|
@app.get("/", tags=["Info"])
async def _root(request: Request) -> RedirectResponse:
    """Send clients that hit the bare root path to the interactive API docs.

    Returns:
        A RedirectResponse pointing at "/docs".
    """
    docs_location = "/docs"
    return RedirectResponse(url=docs_location)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this module is run as a script.
    import uvicorn

    # Listen on all interfaces, port 7860.
    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)
|
|
|