KaiquanMah commited on
Commit
5eff4ab
ยท
verified ยท
1 Parent(s): b8bf9dd

added wandb

Browse files
Files changed (1) hide show
  1. main.py +34 -4
main.py CHANGED
@@ -1,10 +1,14 @@
1
  import argparse
2
  import os
3
- from data_loader import load_and_process_data
4
  from model_trainer import train_models
5
  from model_manager import save_models, load_models
6
  from model_predictor import predict
7
- from config import MODEL_DIR
 
 
 
 
8
  ## ===========================
9
  # MAIN FUNCTION
10
  # ===========================
@@ -19,16 +23,42 @@ def main(train=True, retrain=False):
19
 
20
  if train or retrain:
21
  print("\n๐Ÿš€ Training models...")
22
- models = train_models(X_train, y_train)
23
  save_models(models)
24
 
25
  else:
26
  print("\n๐Ÿš€ Loading existing models...")
27
  models = load_models()
28
 
29
- print("\n๐Ÿ” Making predictions...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  predictions = predict(models, test_df)
31
 
 
 
 
 
 
 
 
 
 
32
  # Save final predictions
33
  predictions.to_csv("final_predictions.csv", index=False)
34
  print("\nโœ… Predictions saved successfully as 'final_predictions.csv'!")
 
1
  import argparse
2
  import os
3
+ from data_loader import load_and_process_data, CATEGORICAL_COLUMNS
4
  from model_trainer import train_models
5
  from model_manager import save_models, load_models
6
  from model_predictor import predict
7
+ from config import MODEL_DIR, CATBOOST_PARAMS, XGB_PARAMS, RF_PARAMS
8
+ import wandb
9
+ from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report
10
+ import pandas as pd
11
+
12
  ## ===========================
13
  # MAIN FUNCTION
14
  # ===========================
 
23
 
24
  if train or retrain:
25
  print("\n๐Ÿš€ Training models...")
26
+ models = train_models(X_train, y_train, CATEGORICAL_COLUMNS)
27
  save_models(models)
28
 
29
  else:
30
  print("\n๐Ÿš€ Loading existing models...")
31
  models = load_models()
32
 
33
+
34
+ # add wandb, validation set scoring
35
+ param_grid = {"CATBOOST_PARAMS": CATBOOST_PARAMS,
36
+ "XGB_PARAMS": XGB_PARAMS,
37
+ "RF_PARAMS": RF_PARAMS}
38
+ os.getenv("WANDB_API_KEY")
39
+ run = wandb.init(project="is_click_predictor", config=param_grid)
40
+
41
+ print("\n๐Ÿ” Makings predictions for validation set...")
42
+ predictions_val = predict(models, X_val)
43
+ accuracy_val = accuracy_score(y_val, predictions_val["is_click_predicted"])
44
+ balanced_accuracy_val = balanced_accuracy_score(y_val, predictions_val["is_click_predicted"])
45
+ classification_report_val = classification_report(y_val, predictions_val["is_click_predicted"], output_dict=True)
46
+ classification_report_val = pd.DataFrame(classification_report_val).transpose()
47
+ predictions_val_table = wandb.Table(dataframe=predictions_val)
48
+ classification_report_val_table = wandb.Table(dataframe=classification_report_val)
49
+
50
+ print("\n๐Ÿ” Making predictions for test set...")
51
  predictions = predict(models, test_df)
52
 
53
+ # wandb logging
54
+ run.log({"param_grid": param_grid,
55
+ "accuracy_val": accuracy_val,
56
+ "balanced_accuracy_val": balanced_accuracy_val,
57
+ "classification_report_val_table": classification_report_val_table,
58
+ "predictions_val_table": predictions_val_table,
59
+ "y_val": y_val})
60
+ run.finish()
61
+
62
  # Save final predictions
63
  predictions.to_csv("final_predictions.csv", index=False)
64
  print("\nโœ… Predictions saved successfully as 'final_predictions.csv'!")