# enginevibration / train_model.py
# GitHub Actions
# Deploy from GitHub Actions
# a1edd1c
# --- train_model.py (Regression Version) ---
import pandas as pd
import numpy as np
import os
import xgboost as xgb
from sklearn.model_selection import train_test_split
def main():
    """Train an XGBoost regressor that predicts normal vibration.

    Reads ``data/raw_data.csv``, drops rows containing missing values and
    fits the model on the first 60% of the remaining rows (treated as the
    healthy period). The fitted model is written to
    ``models/vibration_regressor.json``. The anomaly threshold -- 1.5 times
    the 99th percentile of the absolute prediction error on those training
    rows -- is written to ``models/regression_anomaly_threshold.npy``.

    Problems with the input (file missing, required columns missing, or too
    few usable rows) are reported on stdout and the function returns without
    training or writing anything.

    Returns:
        None.
    """
    print("--- Starting Regression Model Training ---")

    data_path = os.path.join('data', 'raw_data.csv')
    if not os.path.exists(data_path):
        print(f"Error: {data_path} not found. Please run generate_data.py")
        return

    df = pd.read_csv(data_path)
    df = df.dropna()

    # --- Feature & Target Definition ---
    # The model learns to predict the 'actual' vibration from the conditions.
    features = ['rpm', 'ambient_temp_c', 'fuel_level_percent', 'sea_state']
    target = 'az_vibration_actual'

    # Fail with a readable message instead of a bare KeyError when the CSV
    # does not have the expected schema.
    missing_columns = [col for col in [*features, target] if col not in df.columns]
    if missing_columns:
        print(f"Error: {data_path} is missing required columns: {missing_columns}")
        return

    X = df[features]
    y = df[target]

    # --- Model Training ---
    # Train only on the 'healthy' portion of the data (first 60% of rows) so
    # the model learns what "normal" looks like before the fault gets bad.
    healthy_cutoff = int(len(df) * 0.6)
    if healthy_cutoff == 0:
        # Fitting and the percentile below are undefined on an empty slice.
        print(
            f"Error: {data_path} has only {len(df)} usable row(s) after "
            "dropping missing values; not enough data to train."
        )
        return

    X_train = X.iloc[:healthy_cutoff]
    y_train = y.iloc[:healthy_cutoff]

    # Initialize and train the XGBoost Regressor
    model = xgb.XGBRegressor(
        objective='reg:squarederror',
        n_estimators=1000,
        learning_rate=0.05,
        max_depth=5,
        subsample=0.8,
        colsample_bytree=0.8,
        random_state=42,
        n_jobs=-1,  # Use all available CPU cores
    )

    print("\nStarting model training...")
    model.fit(X_train, y_train, verbose=False)

    # --- Save Model and Calculate Anomaly Threshold ---
    model_dir = 'models'
    os.makedirs(model_dir, exist_ok=True)
    model.save_model(os.path.join(model_dir, 'vibration_regressor.json'))
    print(f"\nModel saved to: {model_dir}/vibration_regressor.json")

    # The key step: use the trained model to predict what the vibration
    # *should* have been for the healthy data, then measure the error.
    # NOTE(review): these errors are computed on the same rows the model was
    # fitted to, so the threshold may be optimistic -- consider calibrating
    # on a held-out healthy slice.
    y_pred_healthy = model.predict(X_train)
    errors = np.abs(y_train - y_pred_healthy)

    # The anomaly threshold is a level of error that is unusual for a healthy
    # machine: the 99th percentile, with a 1.5x buffer.
    anomaly_threshold = np.percentile(errors, 99) * 1.5
    np.save(os.path.join(model_dir, 'regression_anomaly_threshold.npy'), anomaly_threshold)
    print(f"Anomaly threshold ({anomaly_threshold:.4f}) saved to: {model_dir}/regression_anomaly_threshold.npy")
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()