Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI application for skill classification service. | |
| Provides REST API endpoints for classifying GitHub issues and pull requests | |
| into skill categories using machine learning models. | |
| Usage: | |
| Development: fastapi dev hopcroft_skill_classification_tool_competition/main.py | |
| Production: fastapi run hopcroft_skill_classification_tool_competition/main.py | |
| Endpoints: | |
| GET / - API information | |
| GET /health - Health check | |
| POST /predict - Single issue classification | |
| POST /predict/batch - Batch classification | |
| """ | |
| from contextlib import asynccontextmanager | |
| from datetime import datetime | |
| import json | |
| import os | |
| import time | |
| from typing import List | |
| from fastapi import FastAPI, HTTPException, Request, Response, status | |
| from fastapi.responses import JSONResponse, RedirectResponse | |
| import mlflow | |
| from prometheus_client import ( | |
| CONTENT_TYPE_LATEST, | |
| Counter, | |
| Gauge, | |
| Histogram, | |
| Summary, | |
| generate_latest, | |
| ) | |
| from pydantic import ValidationError | |
| from hopcroft_skill_classification_tool_competition.api_models import ( | |
| BatchIssueInput, | |
| BatchPredictionResponse, | |
| ErrorResponse, | |
| HealthCheckResponse, | |
| IssueInput, | |
| PredictionRecord, | |
| PredictionResponse, | |
| SkillPrediction, | |
| ) | |
| from hopcroft_skill_classification_tool_competition.config import MLFLOW_CONFIG | |
| from hopcroft_skill_classification_tool_competition.modeling.predict import SkillPredictor | |
# --- Prometheus metrics -------------------------------------------------------
# No registry= argument is passed, so each metric lands in prometheus_client's
# default registry, which is what generate_latest() renders in `metrics`.

# Counter: finished requests, labelled by HTTP method, URL path and status code.
REQUESTS_TOTAL = Counter(
    name="hopcroft_requests_total",
    documentation="Total number of requests",
    labelnames=["method", "endpoint", "http_status"],
)

# Histogram: per-request latency, labelled by HTTP method and URL path.
REQUEST_DURATION_SECONDS = Histogram(
    name="hopcroft_request_duration_seconds",
    documentation="Request duration in seconds",
    labelnames=["method", "endpoint"],
)

# Gauge: requests currently being handled, labelled by HTTP method and URL path.
IN_PROGRESS_REQUESTS = Gauge(
    name="hopcroft_in_progress_requests",
    documentation="Number of requests currently in progress",
    labelnames=["method", "endpoint"],
)

# Summary: time spent inside the model's predict() call (no labels).
MODEL_PREDICTION_SECONDS = Summary(
    name="hopcroft_prediction_processing_seconds",
    documentation="Time spent processing model predictions",
)

# --- Module-level service state -----------------------------------------------
# Set by `lifespan` at startup; remains None when model loading fails, and the
# prediction endpoints check for None before predicting.
predictor = None
# Version string attached to prediction responses and MLflow runs.
model_version = "1.0.0"
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application startup and shutdown.

    FastAPI's ``lifespan=`` parameter expects an async context manager factory,
    hence ``@asynccontextmanager`` (passing a bare async generator function is
    deprecated in Starlette). Code before ``yield`` runs once at startup; code
    after it runs once at shutdown.

    Startup configures the MLflow tracking URI and tries to load the model named
    by the ``MODEL_NAME`` environment variable. A load failure is reported but
    not raised, so the API still starts ("degraded mode") with the module-level
    ``predictor`` left as ``None``.

    Args:
        app: The FastAPI application (not used here; required by the protocol).
    """
    global predictor, model_version
    print("=" * 80)
    print("Starting Skill Classification API")
    print("=" * 80)

    # Configure MLflow before any request handler logs or reads runs.
    mlflow.set_tracking_uri(MLFLOW_CONFIG["uri"])
    print(f"MLflow tracking URI set to: {MLFLOW_CONFIG['uri']}")

    try:
        model_name = os.getenv("MODEL_NAME", "random_forest_tfidf_gridsearch.pkl")
        print(f"Loading model: {model_name}")
        predictor = SkillPredictor(model_name=model_name)
        print("Model and artifacts loaded successfully")
    except Exception as e:
        # Deliberately non-fatal: health/docs endpoints stay available and the
        # prediction endpoints refuse to predict while `predictor` is None.
        print(f"Failed to load model: {e}")
        print("WARNING: API starting in degraded mode (prediction will fail)")

    print(f"Model version {model_version} initialized")
    print("API ready")
    print("=" * 80)

    try:
        yield
    finally:
        # Runs on normal shutdown and also if the server tears down with an error.
        print("Shutting down API")
# The ASGI application. Swagger UI is served at /docs and ReDoc at /redoc;
# startup and shutdown work is delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="Skill Classification API",
    version="1.0.0",
    description="API for classifying GitHub issues and pull requests into skill categories",
    docs_url="/docs",
    redoc_url="/redoc",
)
async def monitor_requests(request: Request, call_next):
    """Middleware to collect Prometheus metrics for each request.

    For every request it:
      * increments/decrements the in-progress gauge around the handler,
      * counts the request under its response status code (500 if the handler
        raised), and
      * observes the handler's duration in the latency histogram.

    Args:
        request: The incoming Starlette/FastAPI request.
        call_next: Callable that forwards the request to the rest of the stack.

    Returns:
        The downstream response, unchanged.

    Raises:
        Re-raises whatever the downstream handler raised, after counting it.
    """
    method = request.method
    # NOTE: the raw URL path is the label value, so parameterised paths (e.g. one
    # per run ID) each create a new time series. Acceptable at the current scale;
    # switch to the route template if cardinality grows.
    endpoint = request.url.path

    in_progress = IN_PROGRESS_REQUESTS.labels(method=method, endpoint=endpoint)
    in_progress.inc()
    # perf_counter is monotonic, so durations are immune to wall-clock adjustments.
    start = time.perf_counter()
    try:
        response = await call_next(request)
    except Exception:
        REQUESTS_TOTAL.labels(method=method, endpoint=endpoint, http_status=500).inc()
        raise
    else:
        REQUESTS_TOTAL.labels(
            method=method, endpoint=endpoint, http_status=response.status_code
        ).inc()
        return response
    finally:
        REQUEST_DURATION_SECONDS.labels(method=method, endpoint=endpoint).observe(
            time.perf_counter() - start
        )
        in_progress.dec()
async def metrics():
    """Return all registered Prometheus metrics in the text exposition format."""
    exposition = generate_latest()
    return Response(content=exposition, media_type=CONTENT_TYPE_LATEST)
async def root():
    """Describe the service: its name, version and the paths of docs, demo and health."""
    info = dict(
        message="Skill Classification API",
        version="1.0.0",
        documentation="/docs",
        demo="/demo",
        health="/health",
    )
    return info
async def health_check():
    """Report service liveness and whether the model was loaded.

    ``status`` is always "healthy" while the process answers; callers that need
    to know whether predictions will work should inspect ``model_loaded``.
    """
    loaded = predictor is not None
    checked_at = datetime.now()
    return HealthCheckResponse(
        status="healthy",
        model_loaded=loaded,
        version="1.0.0",
        timestamp=checked_at,
    )
async def redirect_to_demo():
    """Redirect to the Streamlit demo.

    The target defaults to ``http://localhost:8501``, which only works when the
    client runs on the same machine as the demo. Set the ``DEMO_URL``
    environment variable to point remote clients at the real demo address.
    """
    demo_url = os.getenv("DEMO_URL", "http://localhost:8501")
    return RedirectResponse(url=demo_url)
def _log_prediction_to_mlflow(issue: IssueInput, predictions: List[SkillPrediction]) -> str:
    """Record one prediction as an MLflow run (best effort) and return its run ID.

    Returns ``"local"`` if the run could not be started. If logging fails after
    the run started, the run's ID is still returned. Failures are printed, never
    raised, so a tracking outage does not fail the prediction itself.
    """
    run_id = "local"
    try:
        experiment_name = MLFLOW_CONFIG["experiments"]["baseline"]
        mlflow.set_experiment(experiment_name)
        with mlflow.start_run() as run:
            run_id = run.info.run_id
            # Inputs
            mlflow.log_param("issue_text", issue.issue_text)
            if issue.repo_name:
                mlflow.log_param("repo_name", issue.repo_name)
            # Outputs: the top prediction as param/metric for quick filtering...
            if predictions:
                mlflow.log_param("top_skill", predictions[0].skill_name)
                mlflow.log_metric("top_confidence", predictions[0].confidence)
            # ...and the full list as a JSON tag, which the retrieval endpoints parse.
            predictions_json = json.dumps([p.model_dump() for p in predictions])
            mlflow.set_tag("predictions_json", predictions_json)
            mlflow.set_tag("model_version", model_version)
    except Exception as e:
        print(f"MLflow logging failed: {e}")
    return run_id


async def predict_skills(issue: IssueInput) -> PredictionRecord:
    """
    Classify a single GitHub issue or pull request into skill categories.

    The issue_text, issue_description and repo_name are joined into the single
    string the predictor expects. The result is logged to MLflow on a best-effort
    basis.

    Args:
        issue: IssueInput containing issue text and optional metadata

    Returns:
        PredictionRecord with list of predicted skills, confidence scores, and run_id
        (``"local"`` when no MLflow run could be started)

    Raises:
        HTTPException: 503 if the model is not loaded; 500 if prediction fails
    """
    # Checked outside the try below so the 503 is not re-wrapped as a 500.
    if predictor is None:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Model not loaded")

    start_time = time.perf_counter()
    try:
        # The predictor expects a single string.
        full_text = f"{issue.issue_text} {issue.issue_description or ''} {issue.repo_name or ''}"
        with MODEL_PREDICTION_SECONDS.time():
            predictions_data = predictor.predict(full_text)

        predictions = [
            SkillPrediction(skill_name=p["skill_name"], confidence=p["confidence"])
            for p in predictions_data
        ]
        # Measured before MLflow logging so tracking latency is not counted.
        processing_time = (time.perf_counter() - start_time) * 1000
        timestamp = datetime.now()

        run_id = _log_prediction_to_mlflow(issue, predictions)

        return PredictionRecord(
            predictions=predictions,
            num_predictions=len(predictions),
            model_version=model_version,
            processing_time_ms=round(processing_time, 2),
            run_id=run_id,
            timestamp=timestamp,
            input_text=issue.issue_text,
        )
    except Exception as e:
        import traceback

        traceback.print_exc()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Prediction failed: {str(e)}",
        ) from e
async def predict_skills_batch(batch: BatchIssueInput) -> BatchPredictionResponse:
    """
    Classify multiple GitHub issues or pull requests in batch.

    Each issue is turned into one string (issue_text, issue_description,
    repo_name), the same way the single-issue endpoint does it. Issues are
    predicted one at a time. Unlike the single-issue endpoint, nothing is
    logged to MLflow.

    Args:
        batch: BatchIssueInput containing list of issues (documented as max 100)

    Returns:
        BatchPredictionResponse with one result per input issue, in input order

    Raises:
        HTTPException: 503 if the model is not loaded; 500 if any prediction fails
    """
    # Checked outside the try below so the 503 is not re-wrapped as a 500.
    if predictor is None:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Model not loaded")

    start_time = time.perf_counter()
    try:
        results = []
        for issue in batch.issues:
            full_text = (
                f"{issue.issue_text} {issue.issue_description or ''} {issue.repo_name or ''}"
            )
            # Same model-time metric as the single-issue endpoint, one observation per issue.
            with MODEL_PREDICTION_SECONDS.time():
                predictions_data = predictor.predict(full_text)
            predictions = [
                SkillPrediction(skill_name=p["skill_name"], confidence=p["confidence"])
                for p in predictions_data
            ]
            results.append(
                PredictionResponse(
                    predictions=predictions,
                    num_predictions=len(predictions),
                    model_version=model_version,
                )
            )

        total_processing_time = (time.perf_counter() - start_time) * 1000
        return BatchPredictionResponse(
            results=results,
            total_issues=len(batch.issues),
            total_processing_time_ms=round(total_processing_time, 2),
        )
    except Exception as e:
        import traceback

        traceback.print_exc()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Batch prediction failed: {str(e)}",
        ) from e
async def get_prediction(run_id: str) -> PredictionRecord:
    """
    Retrieve a specific prediction by its MLflow Run ID.

    The prediction list is rebuilt from the run's ``predictions_json`` tag, the
    model version from the ``model_version`` tag, and the input text from the
    ``issue_text`` param. Missing values fall back to ``[]``, ``"unknown"`` and
    ``""`` respectively.

    Args:
        run_id: The MLflow Run ID

    Returns:
        PredictionRecord containing the prediction details (processing time is
        not stored, so ``processing_time_ms`` is None)

    Raises:
        HTTPException: 404 if MLflow reports the run does not exist or the ID is
            invalid; 500 for any other failure (including tracking-server errors)
    """
    try:
        run = mlflow.get_run(run_id)
        data = run.data

        # Reconstruct predictions from tags
        predictions_json = data.tags.get("predictions_json", "[]")
        predictions = [SkillPrediction(**p) for p in json.loads(predictions_json)]

        # MLflow stores start_time in milliseconds since the epoch.
        timestamp = datetime.fromtimestamp(run.info.start_time / 1000.0)

        return PredictionRecord(
            predictions=predictions,
            num_predictions=len(predictions),
            model_version=data.tags.get("model_version", "unknown"),
            processing_time_ms=None,  # Not stored in standard tags, could be added
            run_id=run.info.run_id,
            timestamp=timestamp,
            input_text=data.params.get("issue_text", ""),
        )
    except mlflow.exceptions.MlflowException as err:
        # Only "no such run" / "malformed ID" mean the prediction does not exist;
        # other MLflow errors (e.g. an unreachable tracking server) are server-side.
        if getattr(err, "error_code", None) in (
            "RESOURCE_DOES_NOT_EXIST",
            "INVALID_PARAMETER_VALUE",
        ):
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Prediction with ID {run_id} not found",
            ) from err
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to retrieve prediction: {str(err)}",
        ) from err
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to retrieve prediction: {str(e)}",
        ) from e
def _prediction_record_from_row(row) -> PredictionRecord:
    """Convert one row of ``mlflow.search_runs`` output into a PredictionRecord.

    Tags and params appear as flattened ``tags.<key>`` / ``params.<key>``
    columns. A column may be missing entirely (no run has that key). It may
    also hold a non-string filler (presumably None/NaN) for runs lacking the
    key, so only non-empty strings are trusted and everything else takes the
    default.
    """

    def text_field(column: str, default: str) -> str:
        value = row[column] if column in row else None
        return value if isinstance(value, str) and value != "" else default

    try:
        predictions_data = json.loads(text_field("tags.predictions_json", "[]"))
        predictions = [SkillPrediction(**p) for p in predictions_data]
    except (TypeError, ValueError):
        # Malformed JSON or entries that do not fit SkillPrediction: show the run
        # with no predictions rather than failing the whole listing.
        predictions = []

    return PredictionRecord(
        predictions=predictions,
        num_predictions=len(predictions),
        model_version=text_field("tags.model_version", "unknown"),
        processing_time_ms=None,
        run_id=row.run_id,
        # search_runs returns start_time as a datetime-like value already.
        timestamp=row.start_time,
        input_text=text_field("params.issue_text", ""),
    )


async def list_predictions(skip: int = 0, limit: int = 10) -> List[PredictionRecord]:
    """
    Retrieve a list of recent predictions, newest first.

    Args:
        skip: Number of records to skip (MLflow search has no offset, so
            ``skip + limit`` runs are fetched and the first ``skip`` dropped here)
        limit: Maximum number of records to return

    Returns:
        List of PredictionRecord (empty if the experiment does not exist yet)

    Raises:
        HTTPException: 422 if ``skip`` or ``limit`` is negative; 500 on any
            MLflow or conversion failure
    """
    # Validated up front (outside the try) so the client gets a 422, not a 500.
    if skip < 0 or limit < 0:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="skip and limit must be non-negative",
        )
    if limit == 0:
        return []

    try:
        experiment_name = MLFLOW_CONFIG["experiments"]["baseline"]
        experiment = mlflow.get_experiment_by_name(experiment_name)
        if not experiment:
            return []

        # mlflow.search_runs returns a pandas DataFrame by default.
        runs = mlflow.search_runs(
            experiment_ids=[experiment.experiment_id],
            max_results=limit + skip,
            order_by=["start_time DESC"],
        )
        if runs.empty:
            return []

        return [_prediction_record_from_row(row) for _, row in runs.iloc[skip:].iterrows()]
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to list predictions: {str(e)}",
        ) from e
async def validation_exception_handler(request, exc: ValidationError):
    """Handle Pydantic validation errors as a 422 JSON ErrorResponse.

    ``model_dump(mode="json")`` is required here. A plain ``model_dump()`` keeps
    ``timestamp`` as a ``datetime`` (it is built from ``datetime.now()``), and
    ``JSONResponse`` renders with ``json.dumps``, which cannot serialize that.
    """
    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content=ErrorResponse(
            error="Validation Error", detail=str(exc), timestamp=datetime.now()
        ).model_dump(mode="json"),
    )
async def http_exception_handler(request, exc: HTTPException):
    """Handle HTTP exceptions as a JSON ErrorResponse with the original status code.

    * The exception's ``detail`` becomes the ``error`` field. It is stringified,
      since ``HTTPException.detail`` may be any value.
    * Any headers attached to the exception (e.g. ``Retry-After``) are preserved.
    * ``model_dump(mode="json")`` turns the ``datetime`` timestamp into a JSON
      string so ``JSONResponse`` can serialize it.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(
            error=str(exc.detail), detail=None, timestamp=datetime.now()
        ).model_dump(mode="json"),
        headers=getattr(exc, "headers", None),
    )