Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Test the inference pipeline with sample data | |
| """ | |
| import sys | |
| import os | |
| import json | |
| import pandas as pd | |
| import joblib | |
| # Add src to path | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) | |
def load_model_and_artifacts(artifacts_dir="artifacts"):
    """Load the trained model and the artifacts needed for inference.

    Args:
        artifacts_dir: Directory containing ``model.pkl``,
            ``feature_columns.json`` and ``threshold.json``. Defaults to
            ``"artifacts"``; a relative value is resolved against the
            current working directory.

    Returns:
        A ``(model, feature_columns, threshold)`` tuple: the object
        deserialized from ``model.pkl``, the parsed content of
        ``feature_columns.json``, and the value stored under the
        ``"threshold"`` key of ``threshold.json``.

    Raises:
        FileNotFoundError: If one of the JSON artifact files is missing.
        KeyError: If ``threshold.json`` has no ``"threshold"`` key.
    """
    # NOTE: joblib.load unpickles the file, which can execute arbitrary
    # code; only point this at model artifacts from a trusted source.
    model_path = os.path.join(artifacts_dir, "model.pkl")
    model = joblib.load(model_path)
    print(f"Model loaded from {model_path}")

    # Column names the transformed input is aligned to before prediction.
    feature_columns_path = os.path.join(artifacts_dir, "feature_columns.json")
    with open(feature_columns_path, "r", encoding="utf-8") as f:
        feature_columns = json.load(f)
    print(f"Feature columns loaded: {len(feature_columns)} features")

    # Decision threshold applied to the predicted churn probability.
    threshold_path = os.path.join(artifacts_dir, "threshold.json")
    with open(threshold_path, "r", encoding="utf-8") as f:
        threshold_config = json.load(f)
    threshold = threshold_config["threshold"]
    print(f"Classification threshold: {threshold}")

    return model, feature_columns, threshold
def transform_input_data(data, feature_columns):
    """Convert one raw customer record into a single model-ready row.

    The record is binary-encoded and one-hot encoded, then aligned to
    ``feature_columns`` so the result has exactly those columns in that
    order.

    Args:
        data: Mapping of raw field name to value for one customer. Must
            contain every multi-category field listed in
            ``multi_features`` below; binary fields are encoded only when
            present.
        feature_columns: Ordered list of column names to return
            (expected to be the column order used at training time).

    Returns:
        A one-row ``pandas.DataFrame`` whose columns are exactly
        ``feature_columns``. Columns not produced from ``data`` are filled
        with 0; produced columns not listed in ``feature_columns`` are
        dropped.

    Raises:
        KeyError: If ``data`` lacks one of the multi-category fields.
    """
    df = pd.DataFrame([data])

    # Binary mappings (must match training). Values not listed here map to
    # NaN and are then treated as 0.
    binary_map = {
        "No": 0,
        "Yes": 1,
        "Female": 0,
        "Male": 1,
        "No phone service": 0,
        "No internet service": 0,
    }
    binary_features = [
        "gender", "Partner", "Dependents", "PhoneService", "PaperlessBilling",
    ]
    for feature in binary_features:
        if feature in df.columns:
            df[feature] = df[feature].map(binary_map).fillna(0).astype(int)

    multi_features = [
        "MultipleLines", "InternetService", "OnlineSecurity",
        "OnlineBackup", "DeviceProtection", "TechSupport",
        "StreamingTV", "StreamingMovies", "Contract", "PaymentMethod",
    ]
    # drop_first must be False here: get_dummies drops the first *observed*
    # level of each column, and a one-row frame has exactly one level, so
    # drop_first=True would discard every dummy column and zero out all
    # categorical information. Baseline levels that training dropped are
    # removed by the reindex below because they are not in feature_columns.
    df_encoded = pd.get_dummies(
        df, columns=multi_features, drop_first=False, dtype=int
    )

    # Add missing expected columns as 0, drop extras, and fix the order.
    return df_encoded.reindex(columns=feature_columns, fill_value=0)
def predict_churn(customer_data):
    """Score a single customer record and summarise the churn decision.

    Loads the model artifacts, transforms ``customer_data`` into the
    model's feature layout, and compares the predicted probability of the
    positive class against the loaded threshold.

    Args:
        customer_data: Mapping of raw field name to value for one customer.

    Returns:
        A dict with keys ``"churn_probability"`` (float),
        ``"churn_prediction"`` (``"Yes"``/``"No"``), ``"threshold_used"``
        and ``"confidence"`` (``"High"`` when the probability is above 0.7
        or below 0.3, otherwise ``"Medium"``).
    """
    model, feature_columns, threshold = load_model_and_artifacts()
    features = transform_input_data(customer_data, feature_columns)

    # Probability of the second class column for the single input row.
    churn_proba = model.predict_proba(features)[0, 1]
    churn_flag = (churn_proba >= threshold).astype(int)

    if churn_flag == 1:
        label = "Yes"
    else:
        label = "No"

    if churn_proba > 0.7 or churn_proba < 0.3:
        confidence = "High"
    else:
        confidence = "Medium"

    return {
        "churn_probability": float(churn_proba),
        "churn_prediction": label,
        "threshold_used": threshold,
        "confidence": confidence,
    }
def main():
    """Run the inference smoke test on two hand-written customer records.

    Prints each record and the prediction returned by ``predict_churn``
    for it, framed by a start and a completion banner.
    """
    print("=== Testing Churn Prediction Inference ===\n")

    # Short tenure, month-to-month contract, no add-on services.
    customer_high_risk = {
        "gender": "Female",
        "SeniorCitizen": 0,
        "Partner": "No",
        "Dependents": "No",
        "tenure": 1,
        "PhoneService": "Yes",
        "MultipleLines": "No",
        "InternetService": "Fiber optic",
        "OnlineSecurity": "No",
        "OnlineBackup": "No",
        "DeviceProtection": "No",
        "TechSupport": "No",
        "StreamingTV": "No",
        "StreamingMovies": "No",
        "Contract": "Month-to-month",
        "PaperlessBilling": "Yes",
        "PaymentMethod": "Electronic check",
        "MonthlyCharges": 75.50,
        "TotalCharges": 75.50,
    }

    # Long tenure, two-year contract, all add-on services.
    customer_low_risk = {
        "gender": "Male",
        "SeniorCitizen": 0,
        "Partner": "Yes",
        "Dependents": "Yes",
        "tenure": 60,
        "PhoneService": "Yes",
        "MultipleLines": "Yes",
        "InternetService": "DSL",
        "OnlineSecurity": "Yes",
        "OnlineBackup": "Yes",
        "DeviceProtection": "Yes",
        "TechSupport": "Yes",
        "StreamingTV": "Yes",
        "StreamingMovies": "Yes",
        "Contract": "Two year",
        "PaperlessBilling": "No",
        "PaymentMethod": "Bank transfer (automatic)",
        "MonthlyCharges": 95.00,
        "TotalCharges": 5700.00,
    }

    # Score the samples in order, printing input then prediction for each.
    samples = (
        ("Customer 1 (High Risk Profile):", customer_high_risk),
        ("Customer 2 (Low Risk Profile):", customer_low_risk),
    )
    for heading, customer in samples:
        print(heading)
        print(f"Input: {customer}")
        outcome = predict_churn(customer)
        print(f"Prediction: {outcome}")
        print()

    print("=== Inference Testing Completed Successfully! ===")


# Run the smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    main()