Spaces:
Runtime error
Runtime error
Update nygaardcodecommentclassification/api/API.py
Browse files
nygaardcodecommentclassification/api/API.py
CHANGED
|
@@ -27,8 +27,8 @@ from nygaardcodecommentclassification.api.schemas import PredictionRequest
|
|
| 27 |
# ---------------------------------------------------------------------------
|
| 28 |
# Global Resources
|
| 29 |
# ---------------------------------------------------------------------------
|
| 30 |
-
# Initialize the prediction controller
|
| 31 |
-
controller = PredictionController(
|
| 32 |
|
| 33 |
# Thread pool for CPU-bound ML inference tasks
|
| 34 |
# This prevents blocking the async event loop during model predictions
|
|
@@ -56,6 +56,7 @@ async def lifespan(app: FastAPI):
|
|
| 56 |
controller.shutdown()
|
| 57 |
_executor.shutdown(wait=True)
|
| 58 |
|
|
|
|
| 59 |
# ---------------------------------------------------------------------------
|
| 60 |
# FastAPI Application Definition
|
| 61 |
# ---------------------------------------------------------------------------
|
|
@@ -108,6 +109,7 @@ def construct_response(f: Callable) -> Callable:
|
|
| 108 |
"data": "Response payload (if successful)"
|
| 109 |
}
|
| 110 |
"""
|
|
|
|
| 111 |
@wraps(f)
|
| 112 |
async def wrap(request: Request, *args, **kwargs) -> Dict[str, Any]:
|
| 113 |
# Initialize response with request metadata
|
|
@@ -120,7 +122,7 @@ def construct_response(f: Callable) -> Callable:
|
|
| 120 |
try:
|
| 121 |
# Execute the wrapped endpoint function
|
| 122 |
results = await f(request, *args, **kwargs)
|
| 123 |
-
|
| 124 |
# If function returns a dict with status/message/data, use it directly
|
| 125 |
if isinstance(results, dict) and "status-code" in results:
|
| 126 |
response_struct.update(results)
|
|
@@ -144,12 +146,11 @@ def construct_response(f: Callable) -> Callable:
|
|
| 144 |
return wrap
|
| 145 |
|
| 146 |
|
| 147 |
-
|
| 148 |
-
|
| 149 |
# ---------------------------------------------------------------------------
|
| 150 |
# API Endpoints
|
| 151 |
# ---------------------------------------------------------------------------
|
| 152 |
|
|
|
|
| 153 |
@app.get("/models", tags=["Info"])
|
| 154 |
@construct_response
|
| 155 |
async def _get_models(request: Request) -> Dict[str, Any]:
|
|
@@ -169,11 +170,7 @@ async def _get_models(request: Request) -> Dict[str, Any]:
|
|
| 169 |
}
|
| 170 |
"""
|
| 171 |
data = controller.get_models_info()
|
| 172 |
-
return {
|
| 173 |
-
"status-code": HTTPStatus.OK,
|
| 174 |
-
"message": "Available models retrieved",
|
| 175 |
-
"data": data
|
| 176 |
-
}
|
| 177 |
|
| 178 |
|
| 179 |
@app.get("/languages", tags=["Info"])
|
|
@@ -196,11 +193,7 @@ async def _get_languages(request: Request) -> Dict[str, Any]:
|
|
| 196 |
}
|
| 197 |
"""
|
| 198 |
data = {"languages": config.LANGUAGES}
|
| 199 |
-
return {
|
| 200 |
-
"status-code": HTTPStatus.OK,
|
| 201 |
-
"message": "Supported languages retrieved",
|
| 202 |
-
"data": data
|
| 203 |
-
}
|
| 204 |
|
| 205 |
|
| 206 |
@app.post("/predict", tags=["Prediction"])
|
|
@@ -214,7 +207,8 @@ async def _predict(request: Request, payload: PredictionRequest) -> Dict[str, An
|
|
| 214 |
Args:
|
| 215 |
request: The FastAPI request object
|
| 216 |
payload: PredictionRequest containing:
|
| 217 |
-
- texts: List of code comments
|
|
|
|
| 218 |
- language: Programming language ("java", "python", "pharo")
|
| 219 |
- model_type: Model to use (default: "catboost")
|
| 220 |
|
|
@@ -228,6 +222,7 @@ async def _predict(request: Request, payload: PredictionRequest) -> Dict[str, An
|
|
| 228 |
POST /predict
|
| 229 |
{
|
| 230 |
"texts": ["This method calculates fibonacci", "this is a deprecated function"],
|
|
|
|
| 231 |
"language": "java",
|
| 232 |
"model_type": "catboost"
|
| 233 |
}
|
|
@@ -235,8 +230,8 @@ async def _predict(request: Request, payload: PredictionRequest) -> Dict[str, An
|
|
| 235 |
Example Response:
|
| 236 |
{
|
| 237 |
"results": [
|
| 238 |
-
{"text": "This method calculates fibonacci", "labels": ["summary"]},
|
| 239 |
-
{"text": "this is a deprecated function", "labels": ["deprecation"]}
|
| 240 |
]
|
| 241 |
}
|
| 242 |
"""
|
|
@@ -245,8 +240,9 @@ async def _predict(request: Request, payload: PredictionRequest) -> Dict[str, An
|
|
| 245 |
_executor,
|
| 246 |
controller.predict,
|
| 247 |
payload.texts,
|
|
|
|
| 248 |
payload.language,
|
| 249 |
-
payload.model_type
|
| 250 |
)
|
| 251 |
|
| 252 |
return {
|
|
@@ -255,8 +251,8 @@ async def _predict(request: Request, payload: PredictionRequest) -> Dict[str, An
|
|
| 255 |
"data": {
|
| 256 |
"model_used": payload.model_type,
|
| 257 |
"language": payload.language,
|
| 258 |
-
"results": results
|
| 259 |
-
}
|
| 260 |
}
|
| 261 |
|
| 262 |
|
|
@@ -265,4 +261,5 @@ async def _predict(request: Request, payload: PredictionRequest) -> Dict[str, An
|
|
| 265 |
# ---------------------------------------------------------------------------
|
| 266 |
if __name__ == "__main__":
|
| 267 |
import uvicorn
|
| 268 |
-
|
|
|
|
|
|
| 27 |
# ---------------------------------------------------------------------------
|
| 28 |
# Global Resources
|
| 29 |
# ---------------------------------------------------------------------------
|
| 30 |
+
# Initialize the prediction controller (models loaded from MLflow on startup)
|
| 31 |
+
controller = PredictionController()
|
| 32 |
|
| 33 |
# Thread pool for CPU-bound ML inference tasks
|
| 34 |
# This prevents blocking the async event loop during model predictions
|
|
|
|
| 56 |
controller.shutdown()
|
| 57 |
_executor.shutdown(wait=True)
|
| 58 |
|
| 59 |
+
|
| 60 |
# ---------------------------------------------------------------------------
|
| 61 |
# FastAPI Application Definition
|
| 62 |
# ---------------------------------------------------------------------------
|
|
|
|
| 109 |
"data": "Response payload (if successful)"
|
| 110 |
}
|
| 111 |
"""
|
| 112 |
+
|
| 113 |
@wraps(f)
|
| 114 |
async def wrap(request: Request, *args, **kwargs) -> Dict[str, Any]:
|
| 115 |
# Initialize response with request metadata
|
|
|
|
| 122 |
try:
|
| 123 |
# Execute the wrapped endpoint function
|
| 124 |
results = await f(request, *args, **kwargs)
|
| 125 |
+
|
| 126 |
# If function returns a dict with status/message/data, use it directly
|
| 127 |
if isinstance(results, dict) and "status-code" in results:
|
| 128 |
response_struct.update(results)
|
|
|
|
| 146 |
return wrap
|
| 147 |
|
| 148 |
|
|
|
|
|
|
|
| 149 |
# ---------------------------------------------------------------------------
|
| 150 |
# API Endpoints
|
| 151 |
# ---------------------------------------------------------------------------
|
| 152 |
|
| 153 |
+
|
| 154 |
@app.get("/models", tags=["Info"])
|
| 155 |
@construct_response
|
| 156 |
async def _get_models(request: Request) -> Dict[str, Any]:
|
|
|
|
| 170 |
}
|
| 171 |
"""
|
| 172 |
data = controller.get_models_info()
|
| 173 |
+
return {"status-code": HTTPStatus.OK, "message": "Available models retrieved", "data": data}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
|
| 176 |
@app.get("/languages", tags=["Info"])
|
|
|
|
| 193 |
}
|
| 194 |
"""
|
| 195 |
data = {"languages": config.LANGUAGES}
|
| 196 |
+
return {"status-code": HTTPStatus.OK, "message": "Supported languages retrieved", "data": data}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
|
| 199 |
@app.post("/predict", tags=["Prediction"])
|
|
|
|
| 207 |
Args:
|
| 208 |
request: The FastAPI request object
|
| 209 |
payload: PredictionRequest containing:
|
| 210 |
+
- texts: List of code comments
|
| 211 |
+
- class_names: List of class names corresponding to each comment
|
| 212 |
- language: Programming language ("java", "python", "pharo")
|
| 213 |
- model_type: Model to use (default: "catboost")
|
| 214 |
|
|
|
|
| 222 |
POST /predict
|
| 223 |
{
|
| 224 |
"texts": ["This method calculates fibonacci", "this is a deprecated function"],
|
| 225 |
+
"class_names": ["MathUtils", "Utils"],
|
| 226 |
"language": "java",
|
| 227 |
"model_type": "catboost"
|
| 228 |
}
|
|
|
|
| 230 |
Example Response:
|
| 231 |
{
|
| 232 |
"results": [
|
| 233 |
+
{"text": "This method calculates fibonacci", "class_name": "MathUtils", "labels": ["summary"]},
|
| 234 |
+
{"text": "this is a deprecated function", "class_name": "Utils", "labels": ["deprecation"]}
|
| 235 |
]
|
| 236 |
}
|
| 237 |
"""
|
|
|
|
| 240 |
_executor,
|
| 241 |
controller.predict,
|
| 242 |
payload.texts,
|
| 243 |
+
payload.class_names,
|
| 244 |
payload.language,
|
| 245 |
+
payload.model_type,
|
| 246 |
)
|
| 247 |
|
| 248 |
return {
|
|
|
|
| 251 |
"data": {
|
| 252 |
"model_used": payload.model_type,
|
| 253 |
"language": payload.language,
|
| 254 |
+
"results": results,
|
| 255 |
+
},
|
| 256 |
}
|
| 257 |
|
| 258 |
|
|
|
|
| 261 |
# ---------------------------------------------------------------------------
|
| 262 |
if __name__ == "__main__":
|
| 263 |
import uvicorn
|
| 264 |
+
|
| 265 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|