Fonty02 committed on
Commit
3a0873f
·
verified ·
1 Parent(s): 7bdb4b9

Update nygaardcodecommentclassification/api/API.py

Browse files
nygaardcodecommentclassification/api/API.py CHANGED
@@ -27,8 +27,8 @@ from nygaardcodecommentclassification.api.schemas import PredictionRequest
27
  # ---------------------------------------------------------------------------
28
  # Global Resources
29
  # ---------------------------------------------------------------------------
30
- # Initialize the prediction controller with the models directory
31
- controller = PredictionController(config.MODELS_DIR)
32
 
33
  # Thread pool for CPU-bound ML inference tasks
34
  # This prevents blocking the async event loop during model predictions
@@ -56,6 +56,7 @@ async def lifespan(app: FastAPI):
56
  controller.shutdown()
57
  _executor.shutdown(wait=True)
58
 
 
59
  # ---------------------------------------------------------------------------
60
  # FastAPI Application Definition
61
  # ---------------------------------------------------------------------------
@@ -108,6 +109,7 @@ def construct_response(f: Callable) -> Callable:
108
  "data": "Response payload (if successful)"
109
  }
110
  """
 
111
  @wraps(f)
112
  async def wrap(request: Request, *args, **kwargs) -> Dict[str, Any]:
113
  # Initialize response with request metadata
@@ -120,7 +122,7 @@ def construct_response(f: Callable) -> Callable:
120
  try:
121
  # Execute the wrapped endpoint function
122
  results = await f(request, *args, **kwargs)
123
-
124
  # If function returns a dict with status/message/data, use it directly
125
  if isinstance(results, dict) and "status-code" in results:
126
  response_struct.update(results)
@@ -144,12 +146,11 @@ def construct_response(f: Callable) -> Callable:
144
  return wrap
145
 
146
 
147
-
148
-
149
  # ---------------------------------------------------------------------------
150
  # API Endpoints
151
  # ---------------------------------------------------------------------------
152
 
 
153
  @app.get("/models", tags=["Info"])
154
  @construct_response
155
  async def _get_models(request: Request) -> Dict[str, Any]:
@@ -169,11 +170,7 @@ async def _get_models(request: Request) -> Dict[str, Any]:
169
  }
170
  """
171
  data = controller.get_models_info()
172
- return {
173
- "status-code": HTTPStatus.OK,
174
- "message": "Available models retrieved",
175
- "data": data
176
- }
177
 
178
 
179
  @app.get("/languages", tags=["Info"])
@@ -196,11 +193,7 @@ async def _get_languages(request: Request) -> Dict[str, Any]:
196
  }
197
  """
198
  data = {"languages": config.LANGUAGES}
199
- return {
200
- "status-code": HTTPStatus.OK,
201
- "message": "Supported languages retrieved",
202
- "data": data
203
- }
204
 
205
 
206
  @app.post("/predict", tags=["Prediction"])
@@ -214,7 +207,8 @@ async def _predict(request: Request, payload: PredictionRequest) -> Dict[str, An
214
  Args:
215
  request: The FastAPI request object
216
  payload: PredictionRequest containing:
217
- - texts: List of code comments to classify
 
218
  - language: Programming language ("java", "python", "pharo")
219
  - model_type: Model to use (default: "catboost")
220
 
@@ -228,6 +222,7 @@ async def _predict(request: Request, payload: PredictionRequest) -> Dict[str, An
228
  POST /predict
229
  {
230
  "texts": ["This method calculates fibonacci", "this is a deprecated function"],
 
231
  "language": "java",
232
  "model_type": "catboost"
233
  }
@@ -235,8 +230,8 @@ async def _predict(request: Request, payload: PredictionRequest) -> Dict[str, An
235
  Example Response:
236
  {
237
  "results": [
238
- {"text": "This method calculates fibonacci", "labels": ["summary"]},
239
- {"text": "this is a deprecated function", "labels": ["deprecation"]}
240
  ]
241
  }
242
  """
@@ -245,8 +240,9 @@ async def _predict(request: Request, payload: PredictionRequest) -> Dict[str, An
245
  _executor,
246
  controller.predict,
247
  payload.texts,
 
248
  payload.language,
249
- payload.model_type
250
  )
251
 
252
  return {
@@ -255,8 +251,8 @@ async def _predict(request: Request, payload: PredictionRequest) -> Dict[str, An
255
  "data": {
256
  "model_used": payload.model_type,
257
  "language": payload.language,
258
- "results": results
259
- }
260
  }
261
 
262
 
@@ -265,4 +261,5 @@ async def _predict(request: Request, payload: PredictionRequest) -> Dict[str, An
265
  # ---------------------------------------------------------------------------
266
  if __name__ == "__main__":
267
  import uvicorn
268
- uvicorn.run(app, host="0.0.0.0", port=7680)
 
 
27
  # ---------------------------------------------------------------------------
28
  # Global Resources
29
  # ---------------------------------------------------------------------------
30
+ # Initialize the prediction controller (models loaded from MLflow on startup)
31
+ controller = PredictionController()
32
 
33
  # Thread pool for CPU-bound ML inference tasks
34
  # This prevents blocking the async event loop during model predictions
 
56
  controller.shutdown()
57
  _executor.shutdown(wait=True)
58
 
59
+
60
  # ---------------------------------------------------------------------------
61
  # FastAPI Application Definition
62
  # ---------------------------------------------------------------------------
 
109
  "data": "Response payload (if successful)"
110
  }
111
  """
112
+
113
  @wraps(f)
114
  async def wrap(request: Request, *args, **kwargs) -> Dict[str, Any]:
115
  # Initialize response with request metadata
 
122
  try:
123
  # Execute the wrapped endpoint function
124
  results = await f(request, *args, **kwargs)
125
+
126
  # If function returns a dict with status/message/data, use it directly
127
  if isinstance(results, dict) and "status-code" in results:
128
  response_struct.update(results)
 
146
  return wrap
147
 
148
 
 
 
149
  # ---------------------------------------------------------------------------
150
  # API Endpoints
151
  # ---------------------------------------------------------------------------
152
 
153
+
154
  @app.get("/models", tags=["Info"])
155
  @construct_response
156
  async def _get_models(request: Request) -> Dict[str, Any]:
 
170
  }
171
  """
172
  data = controller.get_models_info()
173
+ return {"status-code": HTTPStatus.OK, "message": "Available models retrieved", "data": data}
 
 
 
 
174
 
175
 
176
  @app.get("/languages", tags=["Info"])
 
193
  }
194
  """
195
  data = {"languages": config.LANGUAGES}
196
+ return {"status-code": HTTPStatus.OK, "message": "Supported languages retrieved", "data": data}
 
 
 
 
197
 
198
 
199
  @app.post("/predict", tags=["Prediction"])
 
207
  Args:
208
  request: The FastAPI request object
209
  payload: PredictionRequest containing:
210
+ - texts: List of code comments
211
+ - class_names: List of class names corresponding to each comment
212
  - language: Programming language ("java", "python", "pharo")
213
  - model_type: Model to use (default: "catboost")
214
 
 
222
  POST /predict
223
  {
224
  "texts": ["This method calculates fibonacci", "this is a deprecated function"],
225
+ "class_names": ["MathUtils", "Utils"],
226
  "language": "java",
227
  "model_type": "catboost"
228
  }
 
230
  Example Response:
231
  {
232
  "results": [
233
+ {"text": "This method calculates fibonacci", "class_name": "MathUtils", "labels": ["summary"]},
234
+ {"text": "this is a deprecated function", "class_name": "Utils", "labels": ["deprecation"]}
235
  ]
236
  }
237
  """
 
240
  _executor,
241
  controller.predict,
242
  payload.texts,
243
+ payload.class_names,
244
  payload.language,
245
+ payload.model_type,
246
  )
247
 
248
  return {
 
251
  "data": {
252
  "model_used": payload.model_type,
253
  "language": payload.language,
254
+ "results": results,
255
+ },
256
  }
257
 
258
 
 
261
  # ---------------------------------------------------------------------------
262
  if __name__ == "__main__":
263
  import uvicorn
264
+
265
+ uvicorn.run(app, host="0.0.0.0", port=7860)