Spaces:
Running
Running
| import pandas as pd | |
| import json | |
| import mlflow | |
| import os | |
| from sklearn.metrics import accuracy_score, f1_score, classification_report | |
| from mlpipeline.entity import ModelEvaluationConfig, ModelEvaluationArtifact | |
| from autogluon.tabular import TabularPredictor | |
| from mlpipeline.logging.logger import get_logger | |
| from mlpipeline.exception import ModelEvaluationException | |
| import sys | |
# Module-level logger, obtained from the project's get_logger helper and keyed by this module's name.
| logger = get_logger(__name__) | |
class ModelEvaluation:
    """Score a trained AutoGluon ``TabularPredictor`` on the test split.

    Metrics are written to ``config.metrics_file`` as JSON and, when a
    DagsHub token is available, also logged to an MLflow run.
    """

    def __init__(self, config: ModelEvaluationConfig):
        """Store the evaluation configuration.

        Args:
            config: Supplies ``test_data_path``, ``model_path``,
                ``target_column``, ``root_dir`` and ``metrics_file``.
        """
        self.config = config

    def _configure_mlflow(self) -> bool:
        """Point MLflow at the DagsHub tracking server; return True on success."""
        tracking_uri = os.getenv(
            "MLFLOW_TRACKING_URI",
            "https://dagshub.com/abheshith7/AutoML-MLOps-PipeLine.mlflow/",
        )
        dagshub_token = os.getenv("DAGSHUB_TOKEN")

        # NOTE(review): this branch is also taken when a token is set but the
        # tracking URI is not a dagshub.com URI; the message only mentions the
        # token. Left unchanged to keep log output stable.
        if not (dagshub_token and "dagshub.com" in tracking_uri):
            logger.warning("MLflow tracking disabled: DAGSHUB_TOKEN not set")
            return False

        try:
            # MLflow reads basic-auth credentials from these variables.
            os.environ["MLFLOW_TRACKING_USERNAME"] = os.getenv("DAGSHUB_USERNAME", "abheshith7")
            os.environ["MLFLOW_TRACKING_PASSWORD"] = dagshub_token
            mlflow.set_tracking_uri(tracking_uri)
            mlflow.set_experiment("automl_experiment")
        except Exception as e:
            # Tracking is best-effort: evaluation continues without it.
            logger.warning(f"MLflow tracking disabled: {str(e)}")
            return False

        logger.info(f"MLflow tracking enabled: {tracking_uri}")
        return True

    @staticmethod
    def _compute_metrics(predictor, test_df, y_test_binary, predictions_binary):
        """Return ``(metrics, confusion_counts)`` dicts for binarized labels."""
        from sklearn.metrics import precision_score, recall_score, roc_auc_score, confusion_matrix

        accuracy = float(accuracy_score(y_test_binary, predictions_binary))
        f1 = float(f1_score(y_test_binary, predictions_binary, average='weighted'))
        precision = float(precision_score(y_test_binary, predictions_binary, average='weighted'))
        recall = float(recall_score(y_test_binary, predictions_binary, average='weighted'))

        # ROC-AUC needs class probabilities; fall back to 0.0 when they cannot
        # be produced or scored, but record why instead of hiding the error.
        try:
            pred_proba = predictor.predict_proba(test_df)
            # Column 1 is used as the positive-class probability -- assumes the
            # probability columns are ordered [negative, positive]; TODO confirm.
            if hasattr(pred_proba, 'iloc'):
                pred_proba_positive = pred_proba.iloc[:, 1]
            else:
                pred_proba_positive = pred_proba[:, 1]
            auc = float(roc_auc_score(y_test_binary, pred_proba_positive))
        except Exception as e:
            logger.warning(f"ROC-AUC could not be computed, defaulting to 0.0: {str(e)}")
            auc = 0.0

        metrics = {
            "test_accuracy": accuracy,
            "test_f1_score": f1,
            "test_precision": precision,
            "test_recall": recall,
            "test_roc_auc": auc,
        }

        # Fixing the label set guarantees a 2x2 matrix even when the test data
        # or the predictions contain only one class.
        cm = confusion_matrix(y_test_binary, predictions_binary, labels=[0, 1])
        cm_dict = {
            "true_negatives": int(cm[0][0]),
            "false_positives": int(cm[0][1]),
            "false_negatives": int(cm[1][0]),
            "true_positives": int(cm[1][1]),
        }
        return metrics, cm_dict

    def evaluate(self) -> ModelEvaluationArtifact:
        """Evaluate the saved model on the test CSV and persist the metrics.

        Returns:
            A ``ModelEvaluationArtifact`` carrying accuracy and weighted F1;
            ``is_model_accepted`` is always True.

        Raises:
            ModelEvaluationException: Wraps any error raised during evaluation.
        """
        try:
            logger.info("Starting model evaluation")
            mlflow_enabled = self._configure_mlflow()

            if mlflow_enabled:
                mlflow.start_run()
            # Assume failure until every step has completed, so an aborted
            # evaluation is not recorded as a finished run.
            run_status = "FAILED"
            try:
                test_df = pd.read_csv(self.config.test_data_path)
                if mlflow_enabled:
                    mlflow.log_param("test_samples", len(test_df))

                predictor = TabularPredictor.load(str(self.config.model_path))
                predictions = predictor.predict(test_df)
                y_test = test_df[self.config.target_column]

                # Collapse to binary labels: any value > 0 counts as positive.
                predictions_binary = (predictions > 0).astype(int)
                y_test_binary = (y_test > 0).astype(int)

                metrics, cm_dict = self._compute_metrics(
                    predictor, test_df, y_test_binary, predictions_binary
                )

                if mlflow_enabled:
                    mlflow.log_metrics(metrics)
                    mlflow.log_metrics(cm_dict)
                    mlflow.set_tag("evaluation_stage", "test")

                os.makedirs(self.config.root_dir, exist_ok=True)
                with open(self.config.metrics_file, "w", encoding="utf-8") as f:
                    json.dump(metrics, f, indent=2)

                if mlflow_enabled:
                    mlflow.log_artifact(str(self.config.metrics_file))
                    run_id = mlflow.active_run().info.run_id
                    logger.info(f"MLflow run logged: {run_id}")
                run_status = "FINISHED"
            finally:
                # Always close the run so a failure does not leave it active
                # for whatever uses MLflow next in this process.
                if mlflow_enabled:
                    mlflow.end_run(status=run_status)

            logger.info(f"Evaluation metrics: {metrics}")
            return ModelEvaluationArtifact(
                is_model_accepted=True,
                evaluation_metrics={
                    "accuracy": metrics["test_accuracy"],
                    "f1_score": metrics["test_f1_score"],
                },
                message=f"Model evaluation completed with accuracy: {metrics['test_accuracy']:.4f}"
            )
        except Exception as e:
            raise ModelEvaluationException(str(e), sys) from e