|
|
""" |
|
|
FastAPI application for skill classification service.

Provides REST API endpoints for classifying GitHub issues and pull requests
into skill categories using machine learning models.

Usage:
    Development: fastapi dev hopcroft_skill_classification_tool_competition/main.py
    Production:  fastapi run hopcroft_skill_classification_tool_competition/main.py

Endpoints:
    GET  /                      - API information
    GET  /health                - Health check
    GET  /metrics               - Prometheus metrics
    GET  /demo                  - Redirect to the Streamlit demo
    POST /predict               - Single issue classification
    POST /predict/batch         - Batch classification
    GET  /predictions           - List recent predictions
    GET  /predictions/{run_id}  - Retrieve a stored prediction by MLflow run ID
|
|
""" |
|
|
|
|
|
from contextlib import asynccontextmanager |
|
|
from datetime import datetime |
|
|
import json |
|
|
import os |
|
|
import time |
|
|
from typing import List |
|
|
|
|
|
from fastapi import FastAPI, HTTPException, Request, Response, status |
|
|
from fastapi.responses import JSONResponse, RedirectResponse |
|
|
import mlflow |
|
|
from prometheus_client import ( |
|
|
CONTENT_TYPE_LATEST, |
|
|
Counter, |
|
|
Gauge, |
|
|
Histogram, |
|
|
Summary, |
|
|
generate_latest, |
|
|
) |
|
|
from pydantic import ValidationError |
|
|
|
|
|
from hopcroft_skill_classification_tool_competition.api_models import ( |
|
|
BatchIssueInput, |
|
|
BatchPredictionResponse, |
|
|
ErrorResponse, |
|
|
HealthCheckResponse, |
|
|
IssueInput, |
|
|
PredictionRecord, |
|
|
PredictionResponse, |
|
|
SkillPrediction, |
|
|
) |
|
|
from hopcroft_skill_classification_tool_competition.config import MLFLOW_CONFIG |
|
|
from hopcroft_skill_classification_tool_competition.modeling.predict import SkillPredictor |
|
|
|
|
|
|
|
|
|
|
|
# --- Prometheus metrics -----------------------------------------------------

# Incremented once per request by the HTTP middleware, labelled with the
# request method, URL path and response status code.
REQUESTS_TOTAL = Counter(
    name="hopcroft_requests_total",
    documentation="Total number of requests",
    labelnames=["method", "endpoint", "http_status"],
)

# Observed once per request by the HTTP middleware with the elapsed seconds.
REQUEST_DURATION_SECONDS = Histogram(
    name="hopcroft_request_duration_seconds",
    documentation="Request duration in seconds",
    labelnames=["method", "endpoint"],
)

# Raised when a request enters the middleware and lowered when it leaves.
IN_PROGRESS_REQUESTS = Gauge(
    name="hopcroft_in_progress_requests",
    documentation="Number of requests currently in progress",
    labelnames=["method", "endpoint"],
)

# Unlabelled timer wrapped around the model's predict() call.
MODEL_PREDICTION_SECONDS = Summary(
    name="hopcroft_prediction_processing_seconds",
    documentation="Time spent processing model predictions",
)

# --- Shared application state ------------------------------------------------

# Populated by lifespan() at startup; remains None when model loading fails,
# which the prediction endpoints report as "Model not loaded".
predictor = None

# Version string attached to prediction responses and MLflow run tags.
model_version = "1.0.0"
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the app serves traffic and log on shutdown.

    Startup points MLflow at ``MLFLOW_CONFIG["uri"]`` and tries to construct
    the module-level ``predictor`` from the model named by the ``MODEL_NAME``
    environment variable. A loading failure is printed, not re-raised, so the
    application still starts with ``predictor`` left as ``None``.
    """
    global predictor, model_version

    banner = "=" * 80
    for line in (banner, "Starting Skill Classification API", banner):
        print(line)

    mlflow.set_tracking_uri(MLFLOW_CONFIG["uri"])
    print(f"MLflow tracking URI set to: {MLFLOW_CONFIG['uri']}")

    try:
        model_name = os.getenv("MODEL_NAME", "random_forest_tfidf_gridsearch.pkl")
        print(f"Loading model: {model_name}")
        predictor = SkillPredictor(model_name=model_name)
        print("Model and artifacts loaded successfully")
    except Exception as e:
        # Deliberately best-effort: keep the API up even without a model.
        print(f"Failed to load model: {e}")
        print("WARNING: API starting in degraded mode (prediction will fail)")

    for line in (f"Model version {model_version} initialized", "API ready", banner):
        print(line)

    # Everything above runs at startup; everything below runs at shutdown.
    yield

    print("Shutting down API")
|
|
|
|
|
|
|
|
# The ASGI application object. `lifespan` handles model loading at startup;
# the interactive docs are served at /docs (Swagger UI) and /redoc.
app = FastAPI(
    lifespan=lifespan,
    title="Skill Classification API",
    version="1.0.0",
    description="API for classifying GitHub issues and pull requests into skill categories",
    docs_url="/docs",
    redoc_url="/redoc",
)
|
|
|
|
|
|
|
|
@app.middleware("http")
async def monitor_requests(request: Request, call_next):
    """Collect Prometheus metrics for every HTTP request.

    For each request this middleware:
      * raises and later lowers ``IN_PROGRESS_REQUESTS``,
      * increments ``REQUESTS_TOTAL`` with the response status code (or 500
        when the downstream handler raises), and
      * observes the elapsed time in ``REQUEST_DURATION_SECONDS``.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request to the next handler.

    Returns:
        The response produced downstream, unmodified.

    Raises:
        Exception: Whatever the downstream handler raised, re-raised as-is
            after the failure has been counted.
    """
    method = request.method

    # NOTE(review): the raw URL path is used as a label value, so every
    # distinct path (e.g. each /predictions/<run_id>) creates its own time
    # series. Consider labelling with the matched route template instead.
    endpoint = request.url.path

    in_progress = IN_PROGRESS_REQUESTS.labels(method=method, endpoint=endpoint)
    in_progress.inc()

    # perf_counter() is monotonic; time.time() can jump when the system clock
    # is adjusted and would then record negative or inflated durations.
    started = time.perf_counter()

    try:
        response = await call_next(request)
    except Exception:
        REQUESTS_TOTAL.labels(method=method, endpoint=endpoint, http_status=500).inc()
        # Bare raise keeps the original traceback intact.
        raise
    else:
        REQUESTS_TOTAL.labels(
            method=method, endpoint=endpoint, http_status=response.status_code
        ).inc()
        return response
    finally:
        # Runs on both the success and the failure path.
        elapsed = time.perf_counter() - started
        REQUEST_DURATION_SECONDS.labels(method=method, endpoint=endpoint).observe(elapsed)
        in_progress.dec()
|
|
|
|
|
|
|
|
@app.get("/metrics", tags=["Observability"])
async def metrics():
    """Return the output of ``generate_latest()`` with its matching content type."""
    payload = generate_latest()
    return Response(content=payload, media_type=CONTENT_TYPE_LATEST)
|
|
|
|
|
|
|
|
@app.get("/", tags=["Root"])
async def root():
    """Describe the API: its name, version and the paths of related pages."""
    api_info = {
        "message": "Skill Classification API",
        "version": "1.0.0",
        "documentation": "/docs",
        "demo": "/demo",
        "health": "/health",
    }
    return api_info
|
|
|
|
|
|
|
|
@app.get("/health", response_model=HealthCheckResponse, tags=["Health"])
async def health_check():
    """Report service status and whether the model has been loaded.

    ``status`` is always ``"healthy"``; ``model_loaded`` reflects whether the
    module-level ``predictor`` has been set.
    """
    model_is_loaded = predictor is not None
    checked_at = datetime.now()
    return HealthCheckResponse(
        status="healthy",
        model_loaded=model_is_loaded,
        version="1.0.0",
        timestamp=checked_at,
    )
|
|
|
|
|
|
|
|
@app.get("/demo")
async def redirect_to_demo():
    """Redirect to the Streamlit demo.

    The target defaults to ``http://localhost:8501`` and can be overridden
    with the ``DEMO_URL`` environment variable, so the redirect also works
    when the demo is not running on the client's own machine.
    """
    demo_url = os.getenv("DEMO_URL", "http://localhost:8501")
    return RedirectResponse(url=demo_url)
|
|
|
|
|
|
|
|
def _log_prediction_to_mlflow(issue: IssueInput, predictions: List[SkillPrediction]) -> str:
    """Best-effort: record one prediction as an MLflow run and return its run ID.

    Logs the issue text (and repository name when present) as params, the top
    prediction as a param/metric pair, and the full prediction list plus the
    model version as tags.

    Returns:
        The MLflow run ID, or ``"local"`` when no run could be started. If the
        run starts but a later logging call fails, the ID of that partially
        logged run is still returned.
    """
    run_id = "local"
    try:
        mlflow.set_experiment(MLFLOW_CONFIG["experiments"]["baseline"])

        with mlflow.start_run() as run:
            run_id = run.info.run_id

            mlflow.log_param("issue_text", issue.issue_text)
            if issue.repo_name:
                mlflow.log_param("repo_name", issue.repo_name)

            if predictions:
                mlflow.log_param("top_skill", predictions[0].skill_name)
                mlflow.log_metric("top_confidence", predictions[0].confidence)

            # Stored as a tag so GET /predictions can rebuild the full list.
            predictions_json = json.dumps([p.model_dump() for p in predictions])
            mlflow.set_tag("predictions_json", predictions_json)
            mlflow.set_tag("model_version", model_version)
    except Exception as e:
        # Tracking is optional: a logging failure must not fail the request.
        print(f"MLflow logging failed: {e}")
    return run_id


@app.post(
    "/predict",
    response_model=PredictionRecord,
    status_code=status.HTTP_201_CREATED,
    tags=["Prediction"],
    summary="Classify a single issue",
    response_description="Skill predictions with confidence scores",
)
async def predict_skills(issue: IssueInput) -> PredictionRecord:
    """
    Classify a single GitHub issue or pull request into skill categories.

    The issue text, description and repository name are joined into one string,
    passed to the loaded predictor, and the result is logged to MLflow on a
    best-effort basis.

    Args:
        issue: IssueInput containing issue text and optional metadata

    Returns:
        PredictionRecord with list of predicted skills, confidence scores, and run_id

    Raises:
        HTTPException: 503 if the model is not loaded, 500 if prediction fails
    """
    # Monotonic clock: immune to wall-clock adjustments during the request.
    start_time = time.perf_counter()

    try:
        if predictor is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        full_text = f"{issue.issue_text} {issue.issue_description or ''} {issue.repo_name or ''}"

        # NOTE(review): predict() is called synchronously inside an async
        # handler; if it is CPU-heavy it will block the event loop.
        with MODEL_PREDICTION_SECONDS.time():
            predictions_data = predictor.predict(full_text)

        predictions = [
            SkillPrediction(skill_name=p["skill_name"], confidence=p["confidence"])
            for p in predictions_data
        ]

        # Measured before MLflow logging so tracking latency is not included.
        processing_time = (time.perf_counter() - start_time) * 1000
        timestamp = datetime.now()

        run_id = _log_prediction_to_mlflow(issue, predictions)

        return PredictionRecord(
            predictions=predictions,
            num_predictions=len(predictions),
            model_version=model_version,
            processing_time_ms=round(processing_time, 2),
            run_id=run_id,
            timestamp=timestamp,
            input_text=issue.issue_text,
        )

    except HTTPException:
        # Deliberate HTTP errors (e.g. the 503 above) must reach the client
        # unchanged instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        import traceback

        traceback.print_exc()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Prediction failed: {str(e)}",
        ) from e
|
|
|
|
|
|
|
|
@app.post(
    "/predict/batch",
    response_model=BatchPredictionResponse,
    status_code=status.HTTP_200_OK,
    tags=["Prediction"],
    summary="Classify multiple issues",
    response_description="Batch skill predictions",
)
async def predict_skills_batch(batch: BatchIssueInput) -> BatchPredictionResponse:
    """
    Classify multiple GitHub issues or pull requests in batch.

    Issues are classified one after another with the loaded predictor. Unlike
    the single-issue endpoint, results are not logged to MLflow.

    Args:
        batch: BatchIssueInput containing the list of issues (any size limit is
            enforced by the BatchIssueInput model, not by this function)

    Returns:
        BatchPredictionResponse with prediction results for each issue

    Raises:
        HTTPException: 503 if the model is not loaded, 500 if batch prediction fails
    """
    # Monotonic clock: immune to wall-clock adjustments during the request.
    start_time = time.perf_counter()

    try:
        if predictor is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        results = []
        for issue in batch.issues:
            full_text = (
                f"{issue.issue_text} {issue.issue_description or ''} {issue.repo_name or ''}"
            )

            # Timed the same way as the single-issue endpoint, so the summary
            # metric covers every model invocation.
            with MODEL_PREDICTION_SECONDS.time():
                predictions_data = predictor.predict(full_text)

            predictions = [
                SkillPrediction(skill_name=p["skill_name"], confidence=p["confidence"])
                for p in predictions_data
            ]

            results.append(
                PredictionResponse(
                    predictions=predictions,
                    num_predictions=len(predictions),
                    model_version=model_version,
                )
            )

        total_processing_time = (time.perf_counter() - start_time) * 1000

        return BatchPredictionResponse(
            results=results,
            total_issues=len(batch.issues),
            total_processing_time_ms=round(total_processing_time, 2),
        )

    except HTTPException:
        # Deliberate HTTP errors (e.g. the 503 above) must reach the client
        # unchanged instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Batch prediction failed: {str(e)}",
        ) from e
|
|
|
|
|
|
|
|
@app.get(
    "/predictions/{run_id}",
    response_model=PredictionRecord,
    status_code=status.HTTP_200_OK,
    tags=["Prediction"],
    summary="Get a prediction by ID",
    response_description="Prediction details",
)
async def get_prediction(run_id: str) -> PredictionRecord:
    """
    Look up one stored prediction via its MLflow run.

    The prediction list is rebuilt from the run's ``predictions_json`` tag, the
    input text from its ``issue_text`` param, and the timestamp from the run's
    start time (milliseconds since the epoch).

    Args:
        run_id: The MLflow Run ID

    Returns:
        PredictionRecord for that run; ``processing_time_ms`` is always None
        because it is not read back from MLflow

    Raises:
        HTTPException: 404 on any MlflowException, 500 on any other error
    """
    try:
        mlflow_run = mlflow.get_run(run_id)
        run_data = mlflow_run.data
        run_tags = run_data.tags

        # A run without the tag yields an empty prediction list.
        stored_predictions = json.loads(run_tags.get("predictions_json", "[]"))
        skills = [SkillPrediction(**entry) for entry in stored_predictions]

        started_at = datetime.fromtimestamp(mlflow_run.info.start_time / 1000.0)

        record = PredictionRecord(
            predictions=skills,
            num_predictions=len(skills),
            model_version=run_tags.get("model_version", "unknown"),
            processing_time_ms=None,
            run_id=mlflow_run.info.run_id,
            timestamp=started_at,
            input_text=run_data.params.get("issue_text", ""),
        )
        return record

    except mlflow.exceptions.MlflowException:
        # Every MLflow error is reported as "not found", whatever its cause.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=f"Prediction with ID {run_id} not found"
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to retrieve prediction: {str(e)}",
        )
|
|
|
|
|
|
|
|
def _row_value(row, prefix, key, default=None):
    """Return ``row["<prefix>.<key>"]``, or ``default`` when that column is absent."""
    col = f"{prefix}.{key}"
    return row[col] if col in row else default


@app.get(
    "/predictions",
    response_model=List[PredictionRecord],
    status_code=status.HTTP_200_OK,
    tags=["Prediction"],
    summary="List predictions",
    response_description="List of recent predictions",
)
async def list_predictions(skip: int = 0, limit: int = 10) -> List[PredictionRecord]:
    """
    Retrieve a list of recent predictions, newest first.

    Runs are read from the MLflow experiment named by
    ``MLFLOW_CONFIG["experiments"]["baseline"]``. Pagination is done
    client-side: ``skip + limit`` runs are fetched and the first ``skip``
    are dropped.

    Args:
        skip: Number of most-recent records to skip (must be >= 0)
        limit: Maximum number of records to return (must be >= 0)

    Returns:
        List of PredictionRecord; empty when the experiment does not exist or
        has no runs

    Raises:
        HTTPException: 422 if skip or limit is negative, 500 on any other error
    """
    # Reject negatives up front: a negative skip would make iloc[skip:] return
    # the tail of the frame, and a negative limit would fail inside MLflow.
    if skip < 0 or limit < 0:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="skip and limit must be non-negative",
        )

    try:
        experiment_name = MLFLOW_CONFIG["experiments"]["baseline"]
        experiment = mlflow.get_experiment_by_name(experiment_name)
        if not experiment:
            return []

        runs = mlflow.search_runs(
            experiment_ids=[experiment.experiment_id],
            max_results=limit + skip,
            order_by=["start_time DESC"],
        )
        if runs.empty:
            return []

        results = []
        for _, row in runs.iloc[skip:].iterrows():
            # Best-effort: a missing or malformed tag yields no predictions
            # for this record instead of failing the whole listing.
            try:
                predictions_data = json.loads(
                    _row_value(row, "tags", "predictions_json", "[]")
                )
                predictions = [SkillPrediction(**p) for p in predictions_data]
            except Exception:
                predictions = []

            # Named differently from the module-level model_version so that
            # the global is not shadowed.
            record_version = _row_value(row, "tags", "model_version")
            if record_version is None or record_version == "":
                record_version = "unknown"

            input_text = _row_value(row, "params", "issue_text")
            if input_text is None:
                input_text = ""

            results.append(
                PredictionRecord(
                    predictions=predictions,
                    num_predictions=len(predictions),
                    model_version=record_version,
                    processing_time_ms=None,
                    run_id=row.run_id,
                    timestamp=row.start_time,
                    input_text=input_text,
                )
            )

        return results

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to list predictions: {str(e)}",
        ) from e
|
|
|
|
|
|
|
|
@app.exception_handler(ValidationError)
async def validation_exception_handler(request, exc: ValidationError):
    """Turn a Pydantic ``ValidationError`` into a 422 JSON error response.

    Args:
        request: The request being handled (unused).
        exc: The validation error; its string form becomes the ``detail``.

    Returns:
        JSONResponse with status 422 and an ``ErrorResponse`` body.
    """
    error_body = ErrorResponse(
        error="Validation Error", detail=str(exc), timestamp=datetime.now()
    )
    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        # mode="json" converts the timestamp to a JSON-compatible value; a plain
        # model_dump() would hand JSONResponse a datetime it cannot encode
        # (assumes ErrorResponse.timestamp is datetime-typed — model not shown).
        content=error_body.model_dump(mode="json"),
    )
|
|
|
|
|
|
|
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc: HTTPException):
    """Render an ``HTTPException`` as an ``ErrorResponse`` JSON body.

    Args:
        request: The request being handled (unused).
        exc: The raised exception; ``exc.detail`` becomes the ``error`` field
            and ``exc.status_code`` the response status.

    Returns:
        JSONResponse carrying the exception's status code, its headers (if
        any) and an ``ErrorResponse`` body with ``detail`` set to None.
    """
    error_body = ErrorResponse(error=exc.detail, detail=None, timestamp=datetime.now())
    return JSONResponse(
        status_code=exc.status_code,
        # mode="json" converts the timestamp to a JSON-compatible value; a plain
        # model_dump() would hand JSONResponse a datetime it cannot encode
        # (assumes ErrorResponse.timestamp is datetime-typed — model not shown).
        content=error_body.model_dump(mode="json"),
        # Forward headers attached to the exception rather than dropping them.
        headers=getattr(exc, "headers", None),
    )
|
|
|