Spaces:
Build error
Build error
| """ | |
| Model training script with MLflow tracking | |
| Trains strategy recommendation models | |
| """ | |
| import os | |
| import sys | |
| import pandas as pd | |
| import numpy as np | |
| import pickle | |
| import json | |
| from sklearn.ensemble import RandomForestClassifier | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score | |
| import mlflow | |
| import mlflow.sklearn | |
| import yaml | |
def load_params(path: str = "params.yaml"):
    """Load the pipeline configuration from a YAML file.

    Args:
        path: Location of the YAML file. Defaults to ``params.yaml`` in the
            current working directory, which is what the training script uses.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document (a nested dict
        for a mapping-style config; ``None`` for an empty file).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding (e.g. cp1252 on Windows).
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
def prepare_features(df: pd.DataFrame):
    """Build the feature matrix used for model training.

    Picks the indicator columns the classifier consumes, in a fixed order,
    and replaces missing values with 0.

    Args:
        df: Frame containing at least the columns ``sma_10``, ``sma_20``,
            ``rsi``, ``volatility`` and ``price_position``.

    Returns:
        A new DataFrame holding only those five columns with NaNs filled by 0.
        The input frame is not modified.

    Raises:
        KeyError: If any of the required columns is missing from ``df``.
    """
    feature_columns = [
        "sma_10",
        "sma_20",
        "rsi",
        "volatility",
        "price_position",
    ]
    selected = df[feature_columns]
    return selected.fillna(0)
def create_labels(df: pd.DataFrame, strategy_type: str = "TOP"):
    """Create binary buy/no-buy labels from the strategy rules.

    Rules (all bounds are strict):
        * ``"TOP"``: ``price_position > 70`` and ``50 < rsi < 70``.
        * ``"BOTTOM"``: ``price_position < 30`` and ``rsi < 30``.

    Rows where a compared value is NaN fail the comparison and therefore
    receive label 0.

    Args:
        df: Frame with ``price_position`` and ``rsi`` columns.
        strategy_type: Either ``"TOP"`` (default) or ``"BOTTOM"``
            (case-sensitive).

    Returns:
        An integer Series aligned with ``df.index``; 1 where the rule holds,
        0 otherwise.

    Raises:
        ValueError: If ``strategy_type`` is not ``"TOP"`` or ``"BOTTOM"``.
        KeyError: If a required column is missing from ``df``.
    """
    # Reject unknown strategies explicitly: previously anything that was not
    # exactly "TOP" (a typo, "top", ...) was silently labelled with the
    # BOTTOM rule, producing a wrongly trained model without any warning.
    if strategy_type not in ("TOP", "BOTTOM"):
        raise ValueError(
            f"Unknown strategy_type {strategy_type!r}; expected 'TOP' or 'BOTTOM'"
        )

    price_position = df["price_position"]
    rsi = df["rsi"]

    if strategy_type == "TOP":
        # TOP strategy: buy when price position > 70, RSI 50-70
        signal = (price_position > 70) & (rsi > 50) & (rsi < 70)
    else:
        # BOTTOM strategy: buy when price position < 30, RSI < 30
        signal = (price_position < 30) & (rsi < 30)

    return signal.astype(int)
def _score_predictions(y_true, y_pred):
    """Return accuracy, precision, recall and F1 as plain floats keyed by name."""
    # zero_division=0 keeps the scores defined (as 0) when a class is never
    # predicted, instead of emitting warnings / undefined values.
    return {
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "precision": float(precision_score(y_true, y_pred, zero_division=0)),
        "recall": float(recall_score(y_true, y_pred, zero_division=0)),
        "f1_score": float(f1_score(y_true, y_pred, zero_division=0)),
    }


def main():
    """Train, evaluate and persist one classifier per strategy.

    Steps:
        1. Load the configuration via ``load_params()``. Keys read:
           ``model.params.n_estimators``, ``model.params.max_depth``,
           ``model.random_state``, ``mlops.mlflow.tracking_uri`` and
           ``mlops.mlflow.experiment_name``.
        2. Read ``data/processed/indicators.parquet`` and drop rows missing
           ``rsi``, ``sma_10`` or ``sma_20``.
        3. For each of the ``TOP`` and ``BOTTOM`` strategies: build labels,
           skip the strategy if it has fewer than 10 positive samples,
           otherwise fit a ``RandomForestClassifier`` on an 80/20 split, log
           parameters, metrics and the model to MLflow, and pickle the model
           to ``models/<strategy>_strategy_model.pkl``.
        4. Write ``models/model_metadata.json`` and
           ``metrics/model_metrics.json``.
    """
    params = load_params()
    model_params = params["model"]["params"]
    random_state = params["model"]["random_state"]

    # Load data
    df = pd.read_parquet("data/processed/indicators.parquet")
    df = df.dropna(subset=["rsi", "sma_10", "sma_20"])

    # Prepare features
    X = prepare_features(df)

    # Create output directories
    os.makedirs("models", exist_ok=True)
    os.makedirs("metrics", exist_ok=True)

    # MLflow setup
    mlflow.set_tracking_uri(params["mlops"]["mlflow"]["tracking_uri"])
    mlflow.set_experiment(params["mlops"]["mlflow"]["experiment_name"])

    results = {}

    # Train models for both strategies
    for strategy_type in ("TOP", "BOTTOM"):
        y = create_labels(df, strategy_type)

        # Check the sample count BEFORE opening the MLflow run, so a skipped
        # strategy does not leave an empty run in the tracking store.
        if y.sum() < 10:  # Need minimum positive samples
            print(f"Not enough samples for {strategy_type} strategy")
            continue

        with mlflow.start_run(run_name=f"{strategy_type}_strategy"):
            # Split data
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=random_state
            )

            # Train model
            model = RandomForestClassifier(
                n_estimators=model_params["n_estimators"],
                max_depth=model_params["max_depth"],
                random_state=random_state,
            )
            model.fit(X_train, y_train)

            # Evaluate
            metrics = _score_predictions(y_test, model.predict(X_test))

            # Log to MLflow
            mlflow.log_params(model_params)
            mlflow.log_param("strategy_type", strategy_type)
            mlflow.log_metrics(metrics)
            mlflow.sklearn.log_model(model, f"{strategy_type.lower()}_model")

            # Save model
            model_path = f"models/{strategy_type.lower()}_strategy_model.pkl"
            with open(model_path, "wb") as f:
                pickle.dump(model, f)

        results[strategy_type] = metrics
        print(
            f"{strategy_type} Strategy - Accuracy: {metrics['accuracy']:.3f}, "
            f"F1: {metrics['f1_score']:.3f}"
        )

    # Save metadata
    metadata = {
        "models": list(results.keys()),
        "metrics": results,
        "training_date": pd.Timestamp.now().isoformat(),
    }
    with open("models/model_metadata.json", "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    # Save metrics for DVC
    with open("metrics/model_metrics.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print("Training complete!")
# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()