|
|
import os |
|
|
import pandas as pd |
|
|
import joblib |
|
|
from sklearn.model_selection import train_test_split |
|
|
from sklearn.ensemble import RandomForestRegressor |
|
|
from sklearn.preprocessing import OneHotEncoder |
|
|
from sklearn.compose import ColumnTransformer |
|
|
from sklearn.pipeline import Pipeline |
|
|
from sklearn.metrics import mean_squared_error |
|
|
|
|
|
|
|
|
# Input CSV and output model locations, both built relative to the directory
# that contains this file (one level up, then into data/ or models/).
_SCRIPT_DIR = os.path.dirname(__file__)
DATA_FILE = os.path.join(
    _SCRIPT_DIR, '..', 'data', 'thunderbird_market_trends.csv'
)
MODEL_OUTPUT_FILE = os.path.join(
    _SCRIPT_DIR, '..', 'models', 'thunderbird_market_predictor_v1.joblib'
)
|
|
|
|
|
def train_model():
    """Train the Thunderbird market predictor and save it to disk.

    Reads the training CSV at ``DATA_FILE``, derives a ``month_of_year``
    feature from the ``month`` column, fits a pipeline (one-hot encoding of
    ``niche`` followed by a random-forest regressor) to predict
    ``successful_campaigns``, prints the MSE measured on the training rows,
    and writes the fitted pipeline to ``MODEL_OUTPUT_FILE`` with joblib.

    Returns:
        None. If the training CSV does not exist, an error message is printed
        and the function returns early without training or writing a model.
    """
    print("--- Starting Thunderbird Market Predictor Training ---")

    # --- 1. Load data ---
    # Keep the try body to the single call that can raise FileNotFoundError.
    try:
        df = pd.read_csv(DATA_FILE)
    except FileNotFoundError:
        print(f"❌ ERROR: Training data not found at {DATA_FILE}. Run the export script first.")
        return
    else:
        print(f"✅ Data loaded successfully. Shape: {df.shape}")

    # --- 2. Feature engineering ---
    # Reduce the timestamp to its calendar month (1-12) so the model can pick
    # up seasonality.
    df['month'] = pd.to_datetime(df['month'])
    df['month_of_year'] = df['month'].dt.month

    X = df[['niche', 'trend_score', 'month_of_year']]
    y = df['successful_campaigns']

    # --- 3. Preprocessing ---
    # Only 'niche' is one-hot encoded; remainder='passthrough' forwards the
    # other selected columns unchanged. handle_unknown='ignore' lets the saved
    # pipeline accept niches at predict time that were not seen during fit.
    categorical_features = ['niche']
    preprocessor = ColumnTransformer(
        transformers=[
            ('cat', OneHotEncoder(handle_unknown='ignore'), categorical_features),
        ],
        remainder='passthrough',
    )

    # --- 4. Model ---
    # Fixed random_state keeps repeated runs on the same data reproducible.
    model = RandomForestRegressor(n_estimators=100, random_state=42, min_samples_leaf=2)

    # --- 5. Pipeline ---
    pipeline = Pipeline(steps=[
        ('preprocessor', preprocessor),
        ('regressor', model),
    ])

    # --- 6. Train ---
    print("🚀 Training the model...")
    pipeline.fit(X, y)
    print("✅ Model training complete.")

    # --- 7. Evaluate ---
    # NOTE: this is an in-sample score (predictions on the same rows used for
    # fitting), so it is optimistic and not an estimate of generalization.
    predictions = pipeline.predict(X)
    mse = mean_squared_error(y, predictions)
    print(f"   - Model Evaluation (MSE on training data): {mse:.2f}")

    # --- 8. Save ---
    # Create the output directory first: joblib.dump does not create missing
    # parent directories and would otherwise fail after training finished.
    output_dir = os.path.dirname(MODEL_OUTPUT_FILE)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    joblib.dump(pipeline, MODEL_OUTPUT_FILE)
    print(f"\n✅ Success! Trained model saved to: {MODEL_OUTPUT_FILE}")
|
|
|
|
|
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train_model()