sirus / backend /ml_module /tools /model_training_tools.py
ranilmukesh's picture
Deploy SiRUS SQL Agent backend
783a952
# ml_module/tools/model_training_tools.py
import io
from datetime import datetime
from typing import Optional
import joblib
import pandas as pd
from agno.tools import Toolkit, tool
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from ml_module.services.storage_service import MLStorageService
from ml_module.services.project_service import ProjectService
from ml_module.core.exceptions import FileOperationException
from ml_module.core.constants import ArtifactTypes
from ml_module.core.response_formatter import (
FormattedResponse,
ProgressBlock,
ProgressStatus,
Severity,
make_text_response,
metric_block,
simple_table,
simple_table_with_types,
visualization_block,
text_block,
)
class ModelTrainingToolkit(Toolkit):
"""A toolkit for safely training and evaluating pre-approved scikit-learn models with code generation."""
def __init__(
    self,
    storage_service: MLStorageService,
    user_id: str,
    project_id: str,
    project_service: Optional[ProjectService] = None,
):
    """
    Binds the toolkit to a single user/project pair and its backing services.

    Args:
        storage_service (MLStorageService): Storage backend used to load datasets and save artifacts.
        user_id (str): Owner identifier; first segment of every storage path built by this toolkit.
        project_id (str): Project identifier; second segment of every storage path.
        project_service (Optional[ProjectService]): Optional registry used for versioning
            and artifact registration. When omitted, those steps are skipped.
    """
    # The base Toolkit is registered under a fixed name.
    super().__init__(name="model_training_tools")
    # Identity of the workspace this toolkit operates on.
    self.user_id = user_id
    self.project_id = project_id
    # Collaborating services.
    self.storage = storage_service
    self.project_service = project_service
def _get_base_path(self, subfolder: str = "") -> str:
    """
    Builds the storage prefix ``<user_id>/<project_id>/<subfolder>``.

    Args:
        subfolder (str): Folder name appended after the project segment.

    Returns:
        str: The joined prefix. With the default empty ``subfolder`` the result
        ends with a trailing slash.
    """
    return "{}/{}/{}".format(self.user_id, self.project_id, subfolder)
def generate_training_code(
    self,
    input_path: str,
    target_column: str,
    model_type: str,
    version: int
) -> str:
    """
    Generates executable Python code that reproduces the training process.

    The returned text is a standalone script: it loads the CSV at ``input_path``,
    splits it 80/20 with ``random_state=42``, fits the selected estimator, prints
    weighted metrics and writes the model (joblib) and metrics (JSON) next to
    the script.

    Args:
        input_path (str): The path to the processed dataset
        target_column (str): The name of the target column
        model_type (str): The type of model to train
        version (int): The version number for this training code

    Returns:
        str: The generated Python code

    Raises:
        ValueError: If ``model_type`` is not one of the pre-approved models.
    """
    # Model configuration mapping
    model_configs = {
        'RandomForest': {
            'import': 'from sklearn.ensemble import RandomForestClassifier',
            'init': 'RandomForestClassifier(random_state=42)',
            'params': 'random_state=42'
        },
        'LogisticRegression': {
            'import': 'from sklearn.linear_model import LogisticRegression',
            'init': 'LogisticRegression(random_state=42)',
            'params': 'random_state=42'
        }
    }
    if model_type not in model_configs:
        raise ValueError(f"Unsupported model type: {model_type}")

    config = model_configs[model_type]
    model_import = config['import']
    model_init = config['init']
    model_params = config['params']
    timestamp = datetime.now().isoformat()

    # Caller-supplied values are embedded as Python literals via repr() so that
    # quotes or backslashes in them cannot break (or inject into) the script.
    # Inside the module docstring, double quotes are additionally escaped so the
    # value can never terminate the surrounding triple-quoted string.
    header_target = repr(target_column).replace('"', '\\"')
    header_input = repr(input_path).replace('"', '\\"')

    # Generate the training code. ``model_type`` is interpolated here, at
    # generation time: the generated script defines no ``model_type`` variable,
    # so deferring it to the script's own f-strings would raise NameError.
    code = f'''#!/usr/bin/env python3
"""
Generated ML Training Code - Version {version}
Generated on: {timestamp}
Model Type: {model_type}
Target Column: {header_target}
Input Data: {header_input}

This code reproduces the exact training process used by the ML system.
"""

import pandas as pd
import joblib
from sklearn.model_selection import train_test_split
{model_import}
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import json
import os
from datetime import datetime


def train_model():
    """Main training function that reproduces the ML workflow."""
    try:
        print(f"Starting training process for {model_type} model...")
        print(f"Timestamp: {{datetime.now().isoformat()}}")

        # 1. Load Data
        print("\\n1. Loading dataset...")
        input_file = {input_path!r}
        if not os.path.exists(input_file):
            raise FileNotFoundError(f"Input file not found: {{input_file}}")
        df = pd.read_csv(input_file)
        print(f" Loaded dataset with {{len(df)}} rows and {{len(df.columns)}} columns")

        # 2. Prepare Data (Train-Test Split)
        print("\\n2. Preparing data...")
        target_column = {target_column!r}
        if target_column not in df.columns:
            raise ValueError(f"Target column '{{target_column}}' not found in dataset")
        X = df.drop(columns=[target_column])
        y = df[target_column]
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        print(f" Training set: {{len(X_train)}} samples")
        print(f" Testing set: {{len(X_test)}} samples")
        print(f" Features: {{list(X.columns)}}")

        # 3. Initialize and Train Model
        print("\\n3. Training model...")
        model = {model_init}
        print(f" Model configuration: {model_params}")
        model.fit(X_train, y_train)
        print(" Model training completed successfully!")

        # 4. Evaluate Model
        print("\\n4. Evaluating model performance...")
        y_pred = model.predict(X_test)
        metrics = {{
            "model_type": "{model_type}",
            "version": {version},
            "timestamp": datetime.now().isoformat(),
            "data_info": {{
                "total_samples": len(df),
                "training_samples": len(X_train),
                "testing_samples": len(X_test),
                "features": list(X.columns)
            }},
            "performance": {{
                "accuracy": accuracy_score(y_test, y_pred),
                "precision": precision_score(y_test, y_pred, average='weighted'),
                "recall": recall_score(y_test, y_pred, average='weighted'),
                "f1_score": f1_score(y_test, y_pred, average='weighted')
            }}
        }}
        print(" Model Performance:")
        for metric, value in metrics["performance"].items():
            print(f" - {{metric.title()}}: {{value:.4f}}")

        # 5. Save Artifacts
        print("\\n5. Saving model artifacts...")
        model_filename = f"{model_type}_model_v{version}.joblib"
        metrics_filename = f"{model_type}_metrics_v{version}.json"

        # Save model
        joblib.dump(model, model_filename)
        print(f" Model saved: {{model_filename}}")

        # Save metrics
        with open(metrics_filename, 'w') as f:
            json.dump(metrics, f, indent=2)
        print(f" Metrics saved: {{metrics_filename}}")

        print("\\n🎉 Training completed successfully!")
        return metrics

    except Exception as e:
        print(f"\\n❌ Training failed: {{str(e)}}")
        raise e


if __name__ == "__main__":
    # Execute training
    results = train_model()
    print("\\n" + "="*50)
    print("TRAINING SUMMARY")
    print("="*50)
    print(f"Model Type: {{results['model_type']}}")
    print(f"Version: {{results['version']}}")
    print(f"Accuracy: {{results['performance']['accuracy']:.4f}}")
    print(f"F1 Score: {{results['performance']['f1_score']:.4f}}")
    print("="*50)
'''
    return code
@tool
def train_sklearn_classifier(
    self,
    input_path: str,
    target_column: str,
    model_type: str
) -> FormattedResponse:
    """
    Trains a specified classification model, evaluates its performance, saves
    both the model artifact and metrics, and generates reproducible training code.

    Args:
        input_path (str): The path to the processed dataset (e.g., 'processed/cleaned_data.csv').
        target_column (str): The name of the column to be predicted.
        model_type (str): The type of model to train. Must be one of: 'RandomForest', 'LogisticRegression'.

    Returns:
        FormattedResponse: Structured confirmation with metrics and artifact references.

    Raises:
        FileOperationException: If loading the data, training, evaluation, or
            saving/registering the artifacts fails. The original error is chained.
    """
    # Map names to estimator classes so only the requested model is instantiated.
    supported_models = {
        'RandomForest': RandomForestClassifier,
        'LogisticRegression': LogisticRegression,
    }
    if model_type not in supported_models:
        response = make_text_response(
            f"Model type '{model_type}' is not supported. Choose from {list(supported_models.keys())}.",
            severity=Severity.ERROR,
        )
        response.summary = "Unsupported model type"
        response.done = True
        return response

    # Resolved before the try block: the error handler below reports this path,
    # so it must exist even when the very first step inside the try fails.
    # NOTE(review): _get_base_path() already ends with "/", so this contains a
    # double slash. Left unchanged because existing storage keys may rely on
    # it — confirm against MLStorageService before normalising.
    source_path = self._get_base_path() + "/" + input_path

    try:
        # Get current model version
        current_version = 1
        if self.project_service:
            current_version = self.project_service.get_latest_version(
                self.user_id, self.project_id, "model"
            ) + 1

        # 1. Load Data
        df = self.storage.load_dataframe(source_path)

        # 2. Prepare Data (Train-Test Split)
        if target_column not in df.columns:
            # Same message as the generated training script uses.
            raise ValueError(f"Target column '{target_column}' not found in dataset")
        X = df.drop(columns=[target_column])
        y = df[target_column]
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

        # 3. Train Model
        model = supported_models[model_type](random_state=42)
        model.fit(X_train, y_train)

        # 4. Evaluate Model
        y_pred = model.predict(X_test)
        # Scores are cast to plain floats so the metrics dict holds only
        # built-in types when it is serialised and stored as metadata.
        performance = {
            "accuracy": float(accuracy_score(y_test, y_pred)),
            "precision": float(precision_score(y_test, y_pred, average='weighted')),
            "recall": float(recall_score(y_test, y_pred, average='weighted')),
            "f1_score": float(f1_score(y_test, y_pred, average='weighted')),
        }
        data_info = {
            "total_samples": len(df),
            "training_samples": len(X_train),
            "testing_samples": len(X_test),
            "features": list(X.columns),
            "target_column": target_column,
        }
        metrics = {
            "model_type": model_type,
            "version": current_version,
            "timestamp": datetime.now().isoformat(),
            "data_info": data_info,
            "performance": performance,
        }

        # 5. Generate Training Code
        training_code = self.generate_training_code(
            input_path=input_path,
            target_column=target_column,
            model_type=model_type,
            version=current_version
        )

        # 6. Save Artifacts with versioning
        models_dir = self._get_base_path('models')
        model_path = f"{models_dir}/{model_type}_model_v{current_version}.joblib"
        metrics_path = f"{models_dir}/{model_type}_metrics_v{current_version}.json"
        code_path = f"{models_dir}/training_code_v{current_version}.py"

        model_buffer = io.BytesIO()
        joblib.dump(model, model_buffer)
        model_info = self.storage.save_bytes(
            model_buffer.getvalue(),
            model_path,
            content_type="application/octet-stream",
        )
        metrics_info = self.storage.save_json(metrics, metrics_path)
        metrics_info.metadata.update({"model_type": model_type, "version": current_version})
        code_info = self.storage.save_text(training_code, code_path)
        code_info.metadata.update({"model_type": model_type, "version": current_version})

        # Record artifacts in project registry for lifecycle tracking
        if self.project_service:
            common_metadata = {
                "model_type": model_type,
                "target_column": target_column,
                "features": data_info["features"],
            }
            registrations = (
                (
                    ArtifactTypes.MODEL_ARTIFACT,
                    model_info,
                    {**common_metadata, "performance": performance},
                ),
                (
                    ArtifactTypes.MODEL_METRICS,
                    metrics_info,
                    {**common_metadata, "performance": performance},
                ),
                (
                    ArtifactTypes.TRAINING_CODE,
                    code_info,
                    {**common_metadata, "lines_of_code": training_code.count("\n") + 1},
                ),
            )
            for artifact_type, artifact_info, extra_metadata in registrations:
                self.project_service.register_artifact(
                    self.user_id,
                    self.project_id,
                    artifact_type,
                    current_version,
                    artifact_info,
                    version_scope="model",
                    extra_metadata=extra_metadata,
                )

        blocks = [
            text_block(
                f"Trained `{model_type}` model version {current_version}",
                severity=Severity.SUCCESS,
            ),
            metric_block("Accuracy", performance["accuracy"]),
            metric_block("Precision", performance["precision"]),
            metric_block("Recall", performance["recall"]),
            metric_block("F1 Score", performance["f1_score"]),
            text_block(
                "**Artifacts saved**\n" + "\n".join(
                    [
                        f"- Model artifact: `{model_path}`",
                        f"- Metrics JSON: `{metrics_path}`",
                        f"- Training code: `{code_path}`",
                    ]
                ),
                severity=Severity.INFO,
                block_id="training_artifacts",
            ),
            simple_table(
                [
                    {
                        "features": len(data_info["features"]),
                        "train_rows": data_info["training_samples"],
                        "test_rows": data_info["testing_samples"],
                    }
                ],
                caption="Dataset split",
                block_id="training_dataset_split",
            ),
            text_block(
                "**Next steps**\n- Review metrics JSON\n- Validate artifacts before deployment",
                severity=Severity.INFO,
                block_id="training_next_steps",
            ),
        ]
        return FormattedResponse(
            blocks=blocks,
            summary=f"Trained {model_type} model v{current_version}",
            correlation_id=model_info.path,
            done=True,
        )
    except Exception as e:
        # Chain the cause so the original traceback survives the re-wrap.
        raise FileOperationException(f"train model '{model_type}'", source_path, e) from e