Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, HTTPException, Query | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| from typing import List, Optional | |
| import pandas as pd | |
| import io | |
| import pickle | |
| import joblib | |
| import numpy as np | |
| from enum import Enum | |
| import traceback | |
# Application object; metadata below feeds the auto-generated /docs page.
app = FastAPI(
    title="Exoplanet Prediction API",
    description="API for predicting exoplanet candidates using KOI features",
    version="1.0.0"
)
# Add CORS middleware
# NOTE(review): browsers reject the combination of allow_origins=["*"] with
# allow_credentials=True per the CORS spec — if cookie/auth credentials are
# actually needed, list explicit origins here; confirm before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with specific origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Define available models
class ModelType(str, Enum):
    """Closed set of model identifiers accepted by the ?model= query parameter.

    Values must match the keys populated into the ``models`` dict by
    load_models(). The commented-out members are not loaded anywhere in this
    file, so enabling them here alone would 400 at prediction time.
    """
    random_forest = "random_forest"
    xgboost = "xgboost"
    ensemble = "ensemble"
    # logistic_regression = "logistic_regression"
    # svm = "svm"
    # neural_network = "neural_network"
    # gradient_boosting = "gradient_boosting"
# Dictionary to store loaded models
# Maps ModelType value (str) -> fitted estimator; filled by load_models().
models = {}
def load_models():
    """Populate the module-level ``models`` registry from pickled estimators.

    Keys must match the active ModelType enum values; any missing file will
    raise from joblib.load and abort startup loudly rather than serve a
    half-loaded registry.
    """
    model_files = {
        'random_forest': 'models/rf.pkl',
        'xgboost': 'models/xgboost.pkl',
        'ensemble': 'models/xgb_rf.pkl',
    }
    for name, path in model_files.items():
        models[name] = joblib.load(path)
    # models['logistic_regression'] = pickle.load(open('models/logistic_regression.pkl', 'rb'))
    # models['svm'] = pickle.load(open('models/svm.pkl', 'rb'))
    # models['neural_network'] = pickle.load(open('models/neural_network.pkl', 'rb'))
    # models['gradient_boosting'] = pickle.load(open('models/gradient_boosting.pkl', 'rb'))
# Load models when app starts
@app.on_event("startup")
async def startup_event():
    """Load all prediction models into memory once at application startup.

    Fix: the function was defined but never registered with FastAPI (no
    startup decorator visible), so ``models`` stayed empty and every
    prediction would have failed with a 400 "Model ... not found".
    """
    load_models()
    print("Models loaded successfully")
class PredictionInput(BaseModel):
    """Feature payload for a single KOI (Kepler Object of Interest) prediction.

    Field names match the column names expected by prepare_features(); the
    three fpflag fields are binary false-positive vetting flags (0 or 1).
    """
    koi_model_snr: float = Field(..., description="Transit signal-to-noise ratio")
    koi_prad: float = Field(..., description="Planetary radius in Earth radii")
    koi_fpflag_ss: int = Field(..., ge=0, le=1, description="Stellar eclipse flag")
    koi_fpflag_co: int = Field(..., ge=0, le=1, description="Centroid offset flag")
    koi_period: float = Field(..., description="Orbital period in days")
    koi_depth: float = Field(..., description="Transit depth in parts per million")
    koi_fpflag_nt: int = Field(..., ge=0, le=1, description="Not transit-like flag")
    koi_insol: float = Field(..., description="Insolation flux in Earth units")

    class Config:
        # Example shown in the OpenAPI /docs "Try it out" panel.
        json_schema_extra = {
            "example": {
                "koi_model_snr": 15.5,
                "koi_prad": 2.3,
                "koi_fpflag_ss": 0,
                "koi_fpflag_co": 0,
                "koi_period": 10.5,
                "koi_depth": 500.0,
                "koi_fpflag_nt": 0,
                "koi_insol": 1.2
            }
        }
class PredictionOutput(BaseModel):
    """Response for a single prediction.

    prediction: integer class label (0=false positive, 1=confirmed, 2=candidate).
    probability: model score as produced by make_prediction (class-1
        probability when predict_proba is available, else 0.5).
    classification: human-readable label for ``prediction``.
    """
    prediction: int
    probability: float
    classification: str
class BatchPredictionOutput(BaseModel):
    """JSON response for batch prediction: one dict per input row plus a count."""
    predictions: List[dict]
    total_processed: int
def prepare_features(data: dict) -> np.ndarray:
    """Build a (1, 8) feature matrix from *data* in the models' column order.

    Raises KeyError if any expected feature name is absent from *data*.
    """
    ordered_names = (
        'koi_model_snr', 'koi_prad', 'koi_fpflag_ss', 'koi_fpflag_co',
        'koi_period', 'koi_depth', 'koi_fpflag_nt', 'koi_insol',
    )
    row = [data[name] for name in ordered_names]
    return np.array([row])
def make_prediction(features: np.ndarray, model_name: str):
    """Run the named loaded model on *features*.

    Returns (prediction, probability) where probability is the class-1 score
    when the estimator exposes predict_proba, otherwise a neutral 0.5.
    Raises HTTPException(400) when *model_name* is not in the registry.
    """
    if model_name not in models:
        raise HTTPException(status_code=400, detail=f"Model {model_name} not found")
    estimator = models[model_name]
    prediction = estimator.predict(features)[0]
    if hasattr(estimator, 'predict_proba'):
        probability = estimator.predict_proba(features)[0][1]
    else:
        probability = 0.5
    return prediction, probability
@app.get("/")
def read_root():
    """Root endpoint: a small self-describing index of the API.

    Fix: the handler had no route decorator, so "/" was never served;
    the path is taken from this function's own endpoint map.
    """
    return {
        "message": "Exoplanet Prediction API",
        "available_models": [model.value for model in ModelType],
        "endpoints": {
            "/predict": "Single prediction (POST)",
            "/predict/batch": "Batch prediction from CSV (POST)",
            "/models": "List available models (GET)",
            "/health": "Health check (GET)",
            "/docs": "API documentation"
        }
    }
@app.get("/models")
def list_models():
    """List all available models.

    Fix: registered the missing GET /models route (path per read_root's
    endpoint map). "available_models" is the static enum; "loaded_models"
    reflects what load_models() actually put in memory.
    """
    return {
        "available_models": [model.value for model in ModelType],
        "loaded_models": list(models.keys()) if models else []
    }
@app.get("/health")
def health_check():
    """Health/readiness probe reporting how many models are loaded.

    Fix: registered the missing GET /health route (path per read_root's
    endpoint map).
    """
    return {
        "status": "healthy",
        "models_loaded": len(models),
        "available_models": [model.value for model in ModelType]
    }
@app.post("/predict", response_model=PredictionOutput)
def predict_single(
    input_data: PredictionInput,
    model: ModelType = Query(ModelType.ensemble, description="Model to use for prediction")
):
    """
    Make a single prediction for exoplanet classification using the specified model.

    Fixes: registered the missing POST /predict route (path per read_root's
    endpoint map), and stopped the blanket except from re-wrapping the
    deliberate 400 "Model ... not found" raised by make_prediction as a 500.
    """
    try:
        features = prepare_features(input_data.dict())
        prediction, probability = make_prediction(features, model.value)
        # Map the integer class onto a human-readable label
        if prediction == 1:
            classification = "Confirmed Exoplanet"
        elif prediction == 2:
            classification = "Exoplanet Candidate"
        else:
            classification = "False Positive"
        return PredictionOutput(
            prediction=int(prediction),
            probability=float(probability),
            classification=classification
        )
    except HTTPException:
        # Preserve intentional HTTP errors (e.g. unknown model -> 400)
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
@app.post("/predict/batch")
async def predict_batch(
    file: UploadFile = File(...),
    model: ModelType = Query(ModelType.ensemble, description="Model to use for predictions")
):
    """
    Make batch predictions from CSV file using the specified model.
    Returns a CSV file (StreamingResponse) with prediction, probability and
    classification columns appended to the uploaded data.

    Fixes: registered the missing POST /predict/batch route (path per
    read_root's endpoint map); guarded against a missing filename; included
    the failing row index in the per-row 400 detail.

    Raises:
        HTTPException 400: non-CSV upload, empty CSV, missing required
            columns, or a row that fails prediction.
        HTTPException 500: any other processing failure.
    """
    if not file.filename or not file.filename.endswith('.csv'):
        raise HTTPException(status_code=400, detail="File must be a CSV")
    try:
        # Read CSV file
        contents = await file.read()
        df = pd.read_csv(io.BytesIO(contents))
        # Validate required columns
        required_cols = [
            'koi_model_snr', 'koi_prad', 'koi_fpflag_ss', 'koi_fpflag_co',
            'koi_period', 'koi_depth', 'koi_fpflag_nt', 'koi_insol'
        ]
        missing_cols = set(required_cols) - set(df.columns)
        if missing_cols:
            raise HTTPException(
                status_code=400,
                detail=f"Missing required columns: {missing_cols}"
            )
        X = df[required_cols].copy()
        # Median-impute NaNs so a few gaps don't abort the whole batch
        X = X.fillna(X.median())
        # NOTE(review): scaler is re-loaded from disk on every request —
        # consider loading it once alongside the models if this path gets hot.
        scaler = joblib.load('models/scaler.pkl')
        if scaler is not None:
            X_scaled = scaler.transform(X.values)
        else:
            X_scaled = X.values
        # Make predictions row by row so a bad row is reported precisely
        predictions = []
        probabilities = []
        for idx in range(len(X_scaled)):
            try:
                features = X_scaled[idx:idx + 1]
                pred, prob = make_prediction(features, model.value)
                predictions.append(pred)
                probabilities.append(prob)
            except Exception:
                traceback.print_exc()
                # Tell the client which row broke, not just that one did
                raise HTTPException(status_code=400, detail=f"Error predicting row {idx}")
        # Add predictions to dataframe
        df['prediction'] = predictions
        df['probability'] = probabilities
        df['classification'] = df['prediction'].map({
            1: 'Confirmed Exoplanet',
            2: 'Exoplanet Candidate',
            0: 'False Positive'
        })
        # Convert to CSV for download
        output = io.StringIO()
        df.to_csv(output, index=False)
        output.seek(0)
        return StreamingResponse(
            iter([output.getvalue()]),
            media_type="text/csv",
            headers={"Content-Disposition": f"attachment; filename=predictions_{model.value}_{file.filename}"}
        )
    except pd.errors.EmptyDataError:
        raise HTTPException(status_code=400, detail="CSV file is empty")
    except HTTPException:
        # Let deliberate 400s pass through unchanged
        raise
    except Exception as e:
        print('Batch processing error')
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
# NOTE(review): this route is not listed in read_root's endpoint map — the
# path below is a reasonable guess; confirm against the frontend client.
@app.post("/predict/batch/json", response_model=BatchPredictionOutput)
async def predict_batch_json(
    file: UploadFile = File(...),
    model: ModelType = Query(ModelType.ensemble, description="Model to use for predictions")
):
    """
    Make batch predictions from CSV file using the specified model.
    Returns a JSON response with one result dict per row; rows that fail
    prediction are reported as classification "Error" rather than aborting.

    Fixes: registered the missing route decorator; guarded against a missing
    filename; re-raised HTTPException so the deliberate 400 for missing
    columns is no longer converted to a 500 (now consistent with
    predict_batch).
    """
    if not file.filename or not file.filename.endswith('.csv'):
        raise HTTPException(status_code=400, detail="File must be a CSV")
    try:
        contents = await file.read()
        df = pd.read_csv(io.BytesIO(contents))
        print("file received and read")
        required_cols = [
            'koi_model_snr', 'koi_prad', 'koi_fpflag_ss', 'koi_fpflag_co',
            'koi_period', 'koi_depth', 'koi_fpflag_nt', 'koi_insol'
        ]
        missing_cols = set(required_cols) - set(df.columns)
        if missing_cols:
            raise HTTPException(
                status_code=400,
                detail=f"Missing required columns: {missing_cols}"
            )
        X = df[required_cols].copy()
        # Median-impute NaNs so a few gaps don't abort the whole batch
        X = X.fillna(X.median())
        # NOTE(review): scaler is re-loaded from disk on every request —
        # consider loading it once alongside the models if this path gets hot.
        scaler = joblib.load('models/scaler.pkl')
        if scaler is not None:
            X_scaled = scaler.transform(X.values)
        else:
            X_scaled = X.values
        results = []
        for idx in range(len(X_scaled)):
            try:
                features = X_scaled[idx:idx + 1]
                pred, prob = make_prediction(features, model.value)
                if pred == 1:
                    classification = "Confirmed Exoplanet"
                elif pred == 2:
                    classification = "Exoplanet Candidate"
                else:
                    classification = "False Positive"
                results.append({
                    "row_index": int(idx),
                    "prediction": int(pred),
                    "probability": float(prob),
                    "classification": classification
                })
            except Exception as e:
                # Best-effort semantics: record the failure and keep going
                print(f"Error predicting row {idx}: {e}")
                results.append({
                    "row_index": int(idx),
                    "prediction": 0,
                    "probability": 0.0,
                    "classification": "Error",
                    "model_used": model.value
                })
        return BatchPredictionOutput(
            predictions=results,
            total_processed=len(results)
        )
    except HTTPException:
        # Let deliberate 400s (bad file, missing columns) pass through
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run(app, host="0.0.0.0", port=8000) | |