# Spaces: Runtime error
# (Status text that appears to come from the hosting page; it is not part of the module.)
| from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| from typing import Optional, Dict, Any, List | |
| import uvicorn | |
| import torch | |
| import logging | |
| import os | |
| import asyncio | |
| import pandas as pd | |
| from datetime import datetime | |
| import shutil | |
| from pathlib import Path | |
| import numpy as np | |
| import sys | |
| # Add parent directory to Python path | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) | |
| from voting import perform_voting_ensemble, save_predictions | |
| from config import LABEL_COLUMNS, PREDICTIONS_SAVE_DIR | |
| from dataset_utils import load_label_encoders | |
# Module-level setup: logging, the FastAPI application object, and the
# directories this service works with. Everything here runs at import time.

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Ensemble Voting API")

# Create necessary directories
UPLOAD_DIR = Path("uploads")  # relative to the process working directory
PREDICTIONS_DIR = Path(PREDICTIONS_SAVE_DIR)  # location comes from config
# exist_ok=True keeps repeated imports/restarts from failing when the
# directories are already present.
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
PREDICTIONS_DIR.mkdir(parents=True, exist_ok=True)
class EnsembleConfig(BaseModel):
    """Request body describing which models take part in an ensemble vote."""

    # Names of the models to combine; perform_ensemble forwards this list
    # unchanged to perform_voting_ensemble.
    model_names: List[str]
    # Optional weights, presumably keyed by model name (TODO confirm).
    # NOTE(review): nothing in this file reads this field.
    weights: Optional[Dict[str, float]] = None
class EnsembleResponse(BaseModel):
    """Response body built by perform_ensemble after a voting run."""

    # Human-readable status text.
    message: str
    # The reports value returned by perform_voting_ensemble, passed through
    # as-is.
    metrics: Dict[str, Any]
    # One dict per (true, predicted) label pair with the keys "field",
    # "true_label" and "predicted_label".
    predictions: List[Dict[str, Any]]
class PredictionData(BaseModel):
    """Request body carrying one model's predictions to be stored."""

    # Name under which save_model_predictions stores the predictions.
    model_name: str
    # Nested list of probabilities; save_model_predictions converts each
    # inner list to its own NumPy array.
    probabilities: List[List[float]]
    # Optional integer labels; a missing or empty list is stored as None by
    # save_model_predictions.
    true_labels: Optional[List[int]] = None
async def root():
    """Return a static payload that identifies this service.

    NOTE(review): no route decorator is visible on this handler in this
    file; confirm how it is registered on ``app``.

    Returns:
        A dict with the single key ``"message"``.
    """
    return {"message": "Ensemble Voting API"}
async def health_check():
    """Return a static status payload reporting the service as healthy.

    The payload is constant; no dependency (model files, directories) is
    probed here.

    NOTE(review): no route decorator is visible on this handler in this
    file; confirm how it is registered on ``app``.

    Returns:
        A dict with the single key ``"status"``.
    """
    return {"status": "healthy"}
async def perform_ensemble(
    config: EnsembleConfig
):
    """Perform ensemble voting using specified models.

    Calls ``perform_voting_ensemble`` with ``config.model_names``, then, for
    every column in ``LABEL_COLUMNS`` that has true labels, decodes both the
    true labels and the ensemble predictions through the matching encoder
    from ``load_label_encoders``.

    NOTE(review): no route decorator is visible on this handler in this
    file; confirm how it is registered on ``app``.

    Args:
        config: Request body carrying the model names to combine.

    Returns:
        EnsembleResponse whose ``metrics`` is the reports value returned by
        ``perform_voting_ensemble`` and whose ``predictions`` holds one
        ``{"field", "true_label", "predicted_label"}`` dict per
        (true, predicted) label pair.

    Raises:
        HTTPException: status 500 carrying the underlying error text when
            any step fails.
    """
    if config.weights is not None:
        # The request model accepts weights, but only model_names is passed
        # on below; say so instead of silently ignoring the client's input.
        logger.warning(
            "Ensemble weights were supplied but are not used by this endpoint"
        )

    try:
        # Perform ensemble voting
        ensemble_reports, true_labels, ensemble_predictions = perform_voting_ensemble(
            config.model_names
        )

        # Load label encoders for decoding predictions
        label_encoders = load_label_encoders()

        # Format predictions with original labels
        formatted_predictions = []
        for i, (col, preds) in enumerate(zip(LABEL_COLUMNS, ensemble_predictions)):
            if true_labels[i] is None:
                # Nothing to pair the predictions with for this column.
                continue
            label_encoder = label_encoders[col]
            # .tolist() converts NumPy scalars to native Python values so the
            # entries can be JSON-encoded in the response.
            true_labels_orig = np.asarray(
                label_encoder.inverse_transform(true_labels[i])
            ).tolist()
            pred_labels_orig = np.asarray(
                label_encoder.inverse_transform(preds)
            ).tolist()
            formatted_predictions.extend(
                {
                    "field": col,
                    "true_label": true,
                    "predicted_label": pred,
                }
                for true, pred in zip(true_labels_orig, pred_labels_orig)
            )

        return EnsembleResponse(
            message="Ensemble voting completed successfully",
            metrics=ensemble_reports,
            predictions=formatted_predictions
        )
    except Exception as e:
        # logger.exception keeps the traceback, which the plain error log lost.
        logger.exception("Ensemble voting failed: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Ensemble voting failed: {str(e)}"
        ) from e
async def save_model_predictions(
    prediction_data: PredictionData
):
    """Save predictions from a model for later ensemble voting.

    Converts each inner list of ``prediction_data.probabilities`` to a NumPy
    array, wraps the true labels (or ``None`` when they are missing or empty)
    in a one-element list, and hands both to ``save_predictions`` together
    with the model name.

    NOTE(review): no route decorator is visible on this handler in this
    file; confirm how it is registered on ``app``.

    NOTE(review): ``model_name`` comes straight from the request body.
    ``get_available_models`` treats directory names under ``PREDICTIONS_DIR``
    as model names, so ``save_predictions`` presumably uses it as a path
    component -- confirm it rejects values such as ``".."`` or names
    containing path separators.

    NOTE(review): ``true_labels`` is always a one-element list while
    ``all_probs`` has one array per inner list; confirm this matches what
    ``save_predictions`` expects.

    Args:
        prediction_data: Model name, probabilities and optional true labels.

    Returns:
        A dict with a confirmation ``"message"`` and the ``"model_name"``.

    Raises:
        HTTPException: status 500 carrying the underlying error text when
            conversion or saving fails.
    """
    try:
        # Convert probabilities to numpy arrays
        all_probs = [np.array(probs) for probs in prediction_data.probabilities]
        # Truthiness check on purpose: an empty label list is stored as None.
        true_labels = [
            np.array(prediction_data.true_labels)
            if prediction_data.true_labels
            else None
        ]

        # Save predictions
        save_predictions(
            prediction_data.model_name,
            all_probs,
            true_labels
        )

        return {
            "message": f"Predictions saved successfully for model {prediction_data.model_name}",
            "model_name": prediction_data.model_name
        }
    except Exception as e:
        # logger.exception keeps the traceback, which the plain error log lost.
        logger.exception("Failed to save predictions: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Failed to save predictions: {str(e)}"
        ) from e
async def get_available_models():
    """Get list of models with saved predictions.

    A model counts as available when ``PREDICTIONS_DIR`` contains a
    directory named after it and that directory holds a
    ``<column>_probs.pkl`` file for every column in ``LABEL_COLUMNS``.
    Only file existence is checked; the files are not opened.

    NOTE(review): no route decorator is visible on this handler in this
    file; confirm how it is registered on ``app``.

    Returns:
        A dict with ``"available_models"`` (model names, sorted so the
        response does not depend on directory listing order) and
        ``"total_models"`` (their count).

    Raises:
        HTTPException: status 500 carrying the underlying error text when
            the predictions directory cannot be inspected.
    """
    try:
        model_dirs = sorted(
            entry for entry in PREDICTIONS_DIR.iterdir() if entry.is_dir()
        )
        available_models = [
            model_dir.name
            for model_dir in model_dirs
            if all(
                (model_dir / f"{col}_probs.pkl").exists()
                for col in LABEL_COLUMNS
            )
        ]
        return {
            "available_models": available_models,
            "total_models": len(available_models)
        }
    except Exception as e:
        # logger.exception keeps the traceback, which the plain error log lost.
        logger.exception("Failed to get available models: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Failed to get available models: {str(e)}"
        ) from e
# Script entry point: serve the app with uvicorn when the file is executed
# directly (not when it is imported).
if __name__ == "__main__":
    # PORT from the environment wins; 7861 is the fallback.
    port = int(os.environ.get("PORT", 7861))
    # 0.0.0.0 binds every network interface, not just localhost.
    uvicorn.run(app, host="0.0.0.0", port=port)