Spaces:
Running
Running
| # ml_module/tools/model_training_tools.py | |
| import io | |
| from datetime import datetime | |
| from typing import Optional | |
| import joblib | |
| import pandas as pd | |
| from agno.tools import Toolkit, tool | |
| from sklearn.ensemble import RandomForestClassifier | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score | |
| from sklearn.model_selection import train_test_split | |
| from ml_module.services.storage_service import MLStorageService | |
| from ml_module.services.project_service import ProjectService | |
| from ml_module.core.exceptions import FileOperationException | |
| from ml_module.core.constants import ArtifactTypes | |
| from ml_module.core.response_formatter import ( | |
| FormattedResponse, | |
| ProgressBlock, | |
| ProgressStatus, | |
| Severity, | |
| make_text_response, | |
| metric_block, | |
| simple_table, | |
| simple_table_with_types, | |
| visualization_block, | |
| text_block, | |
| ) | |
class ModelTrainingToolkit(Toolkit):
    """Toolkit for training and evaluating pre-approved scikit-learn classifiers.

    Only the model types whitelisted inside the methods ('RandomForest' and
    'LogisticRegression') can be trained. Each training run also produces a
    standalone Python script that mirrors the training steps.
    """

    def __init__(
        self,
        storage_service: MLStorageService,
        user_id: str,
        project_id: str,
        project_service: Optional[ProjectService] = None,
    ):
        """
        Args:
            storage_service (MLStorageService): Backend used to load the dataset
                and to persist the model, metrics and generated code.
            user_id (str): First component of every storage path.
            project_id (str): Second component of every storage path.
            project_service (Optional[ProjectService]): When given, it is used to
                look up the latest model version and to register the saved
                artifacts. When omitted, the version is always 1 and nothing
                is registered.
        """
        super().__init__(name="model_training_tools")
        self.storage = storage_service
        self.project_service = project_service
        self.user_id = user_id
        self.project_id = project_id

    def _get_base_path(self, subfolder: str = "") -> str:
        """Return ``"<user_id>/<project_id>/<subfolder>"``.

        With the default empty ``subfolder`` the result ends with a slash.
        """
        return f"{self.user_id}/{self.project_id}/{subfolder}"

    @staticmethod
    def _escape_for_docstring(value) -> str:
        """Escape backslashes and double quotes so ``value`` cannot terminate
        (or inject escapes into) a triple-double-quoted docstring."""
        return str(value).replace("\\", "\\\\").replace('"', '\\"')

    def generate_training_code(
        self,
        input_path: str,
        target_column: str,
        model_type: str,
        version: int
    ) -> str:
        """
        Generates executable Python code that reproduces the training process.

        Args:
            input_path (str): The path to the processed dataset
            target_column (str): The name of the target column
            model_type (str): The type of model to train
            version (int): The version number for this training code

        Returns:
            str: The generated Python code

        Raises:
            ValueError: If ``model_type`` is not one of the supported types.
        """
        # Model configuration mapping
        model_configs = {
            'RandomForest': {
                'import': 'from sklearn.ensemble import RandomForestClassifier',
                'init': 'RandomForestClassifier(random_state=42)',
                'params': 'random_state=42'
            },
            'LogisticRegression': {
                'import': 'from sklearn.linear_model import LogisticRegression',
                'init': 'LogisticRegression(random_state=42)',
                'params': 'random_state=42'
            }
        }
        if model_type not in model_configs:
            raise ValueError(f"Unsupported model type: {model_type}")

        config = model_configs[model_type]
        import_line = config['import']
        init_expr = config['init']
        params = config['params']
        timestamp = datetime.now().isoformat()

        # Caller-supplied values are emitted as repr() literals so quotes or
        # backslashes in them cannot break (or inject into) the generated code.
        input_literal = repr(str(input_path))
        target_literal = repr(str(target_column))
        # The same values shown in the generated module docstring must not be
        # able to close that docstring early.
        doc_input = self._escape_for_docstring(input_path)
        doc_target = self._escape_for_docstring(target_column)

        # Single braces are filled in now; doubled braces survive as single
        # braces and are evaluated when the generated script itself runs.
        # model_type is interpolated now because the generated script defines
        # no such variable.
        code = f'''#!/usr/bin/env python3
"""
Generated ML Training Code - Version {version}
Generated on: {timestamp}
Model Type: {model_type}
Target Column: {doc_target}
Input Data: {doc_input}

This code reproduces the exact training process used by the ML system.
"""

import pandas as pd
import joblib
from sklearn.model_selection import train_test_split
{import_line}
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import json
import os
from datetime import datetime


def train_model():
    """Main training function that reproduces the ML workflow."""
    try:
        print("Starting training process for {model_type} model...")
        print(f"Timestamp: {{datetime.now().isoformat()}}")

        # 1. Load Data
        print("\\n1. Loading dataset...")
        input_file = {input_literal}
        if not os.path.exists(input_file):
            raise FileNotFoundError(f"Input file not found: {{input_file}}")
        df = pd.read_csv(input_file)
        print(f" Loaded dataset with {{len(df)}} rows and {{len(df.columns)}} columns")

        # 2. Prepare Data (Train-Test Split)
        print("\\n2. Preparing data...")
        target_column = {target_literal}
        if target_column not in df.columns:
            raise ValueError(f"Target column '{{target_column}}' not found in dataset")
        X = df.drop(columns=[target_column])
        y = df[target_column]
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        print(f" Training set: {{len(X_train)}} samples")
        print(f" Testing set: {{len(X_test)}} samples")
        print(f" Features: {{list(X.columns)}}")

        # 3. Initialize and Train Model
        print("\\n3. Training model...")
        model = {init_expr}
        print(f" Model configuration: {params}")
        model.fit(X_train, y_train)
        print(" Model training completed successfully!")

        # 4. Evaluate Model
        print("\\n4. Evaluating model performance...")
        y_pred = model.predict(X_test)
        metrics = {{
            "model_type": "{model_type}",
            "version": {version},
            "timestamp": datetime.now().isoformat(),
            "data_info": {{
                "total_samples": len(df),
                "training_samples": len(X_train),
                "testing_samples": len(X_test),
                "features": list(X.columns),
                "target_column": target_column
            }},
            "performance": {{
                "accuracy": accuracy_score(y_test, y_pred),
                "precision": precision_score(y_test, y_pred, average='weighted'),
                "recall": recall_score(y_test, y_pred, average='weighted'),
                "f1_score": f1_score(y_test, y_pred, average='weighted')
            }}
        }}
        print(" Model Performance:")
        for metric, value in metrics["performance"].items():
            print(f" - {{metric.title()}}: {{value:.4f}}")

        # 5. Save Artifacts
        print("\\n5. Saving model artifacts...")
        model_filename = f"{model_type}_model_v{version}.joblib"
        metrics_filename = f"{model_type}_metrics_v{version}.json"

        # Save model
        joblib.dump(model, model_filename)
        print(f" Model saved: {{model_filename}}")

        # Save metrics
        with open(metrics_filename, 'w') as f:
            json.dump(metrics, f, indent=2)
        print(f" Metrics saved: {{metrics_filename}}")

        print("\\n🎉 Training completed successfully!")
        return metrics
    except Exception as e:
        print(f"\\n❌ Training failed: {{str(e)}}")
        raise


if __name__ == "__main__":
    # Execute training
    results = train_model()
    print("\\n" + "="*50)
    print("TRAINING SUMMARY")
    print("="*50)
    print(f"Model Type: {{results['model_type']}}")
    print(f"Version: {{results['version']}}")
    print(f"Accuracy: {{results['performance']['accuracy']:.4f}}")
    print(f"F1 Score: {{results['performance']['f1_score']:.4f}}")
    print("="*50)
'''
        return code

    def train_sklearn_classifier(
        self,
        input_path: str,
        target_column: str,
        model_type: str
    ) -> FormattedResponse:
        """
        Trains a specified classification model, evaluates its performance, saves
        both the model artifact and metrics, and generates reproducible training code.

        Args:
            input_path (str): The path to the processed dataset (e.g., 'processed/cleaned_data.csv').
            target_column (str): The name of the column to be predicted.
            model_type (str): The type of model to train. Must be one of: 'RandomForest', 'LogisticRegression'.

        Returns:
            FormattedResponse: Structured confirmation with metrics and artifact references.
            An unsupported ``model_type`` yields an error response instead of raising.

        Raises:
            FileOperationException: Wraps any exception raised while loading,
                training, evaluating, saving or registering.
        """
        # Map to classes (not instances) so only the requested estimator is built.
        supported_models = {
            'RandomForest': RandomForestClassifier,
            'LogisticRegression': LogisticRegression,
        }
        if model_type not in supported_models:
            response = make_text_response(
                f"Model type '{model_type}' is not supported. Choose from {list(supported_models.keys())}.",
                severity=Severity.ERROR,
            )
            response.summary = "Unsupported model type"
            response.done = True
            return response

        # Resolved before the try block so the except handler below can always
        # reference it, even when the failure happens before the data is loaded.
        # NOTE(review): _get_base_path() already ends with "/", so this path
        # contains "//" — kept as-is; confirm how the storage layer treats it.
        source_path = self._get_base_path() + "/" + input_path

        try:
            # Get current model version
            current_version = 1
            if self.project_service:
                current_version = self.project_service.get_latest_version(
                    self.user_id, self.project_id, "model"
                ) + 1

            # 1. Load Data
            df = self.storage.load_dataframe(source_path)

            # 2. Prepare Data (Train-Test Split)
            if target_column not in df.columns:
                # Same message as the generated script; clearer than the
                # KeyError pandas would raise from drop().
                raise ValueError(f"Target column '{target_column}' not found in dataset")
            X = df.drop(columns=[target_column])
            y = df[target_column]
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

            # 3. Train Model
            model = supported_models[model_type](random_state=42)
            model.fit(X_train, y_train)

            # 4. Evaluate Model
            y_pred = model.predict(X_test)
            metrics = {
                "model_type": model_type,
                "version": current_version,
                "timestamp": datetime.now().isoformat(),
                "data_info": {
                    "total_samples": len(df),
                    "training_samples": len(X_train),
                    "testing_samples": len(X_test),
                    "features": list(X.columns),
                    "target_column": target_column
                },
                "performance": {
                    "accuracy": accuracy_score(y_test, y_pred),
                    "precision": precision_score(y_test, y_pred, average='weighted'),
                    "recall": recall_score(y_test, y_pred, average='weighted'),
                    "f1_score": f1_score(y_test, y_pred, average='weighted')
                }
            }

            # 5. Generate Training Code
            training_code = self.generate_training_code(
                input_path=input_path,
                target_column=target_column,
                model_type=model_type,
                version=current_version
            )

            # 6. Save Artifacts with versioning
            models_dir = self._get_base_path('models')
            model_path = f"{models_dir}/{model_type}_model_v{current_version}.joblib"
            metrics_path = f"{models_dir}/{model_type}_metrics_v{current_version}.json"
            code_path = f"{models_dir}/training_code_v{current_version}.py"

            # Serialize the fitted model in memory and hand the bytes to storage.
            model_buffer = io.BytesIO()
            joblib.dump(model, model_buffer)
            model_info = self.storage.save_bytes(
                model_buffer.getvalue(),
                model_path,
                content_type="application/octet-stream",
            )
            metrics_info = self.storage.save_json(metrics, metrics_path)
            metrics_info.metadata.update({"model_type": model_type, "version": current_version})
            code_info = self.storage.save_text(training_code, code_path)
            code_info.metadata.update({"model_type": model_type, "version": current_version})

            # Record artifacts in project registry for lifecycle tracking
            if self.project_service:
                data_info = metrics.get("data_info", {})
                common_metadata = {
                    "model_type": model_type,
                    "target_column": target_column,
                    "features": data_info.get("features", []),
                }
                performance = metrics.get("performance", {})
                # Registration order: model, then metrics, then training code.
                registrations = (
                    (ArtifactTypes.MODEL_ARTIFACT, model_info, {"performance": performance}),
                    (ArtifactTypes.MODEL_METRICS, metrics_info, {"performance": performance}),
                    (
                        ArtifactTypes.TRAINING_CODE,
                        code_info,
                        {"lines_of_code": training_code.count("\n") + 1},
                    ),
                )
                for artifact_type, artifact_info, extra in registrations:
                    self.project_service.register_artifact(
                        self.user_id,
                        self.project_id,
                        artifact_type,
                        current_version,
                        artifact_info,
                        version_scope="model",
                        extra_metadata={**common_metadata, **extra},
                    )

            blocks = [
                text_block(
                    f"Trained `{model_type}` model version {current_version}",
                    severity=Severity.SUCCESS,
                ),
                metric_block("Accuracy", metrics["performance"]["accuracy"]),
                metric_block("Precision", metrics["performance"]["precision"]),
                metric_block("Recall", metrics["performance"]["recall"]),
                metric_block("F1 Score", metrics["performance"]["f1_score"]),
                text_block(
                    "**Artifacts saved**\n" + "\n".join(
                        [
                            f"- Model artifact: `{model_path}`",
                            f"- Metrics JSON: `{metrics_path}`",
                            f"- Training code: `{code_path}`",
                        ]
                    ),
                    severity=Severity.INFO,
                    block_id="training_artifacts",
                ),
                simple_table(
                    [
                        {
                            "features": len(metrics["data_info"].get("features", [])),
                            "train_rows": metrics["data_info"].get("training_samples"),
                            "test_rows": metrics["data_info"].get("testing_samples"),
                        }
                    ],
                    caption="Dataset split",
                    block_id="training_dataset_split",
                ),
                text_block(
                    "**Next steps**\n- Review metrics JSON\n- Validate artifacts before deployment",
                    severity=Severity.INFO,
                    block_id="training_next_steps",
                ),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary=f"Trained {model_type} model v{current_version}",
                correlation_id=model_info.path,
                done=True,
            )
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise FileOperationException(f"train model '{model_type}'", source_path, e) from e