from typing import Any, List, Dict

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import shap
import torch
import wandb
from torchmetrics import MetricCollection

from models.metrics import get_cls_pred_metrics, get_cls_prob_metrics, get_reg_metrics


def get_model_framework_dict():
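    """Return a mapping from model name to the framework ("stand_alone" or "pytorch") used to train it."""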
    model_framework = {
        "elastic_net": "stand_alone",
        "logistic_regression": "stand_alone",
        "svm": "stand_alone",
        "xgboost": "stand_alone",
        "catboost": "stand_alone",
        "lightgbm": "stand_alone",
        "widedeep_tab_mlp": "pytorch",
        "widedeep_tab_resnet": "pytorch",
        "widedeep_tab_net": "pytorch",
        "widedeep_tab_transformer": "pytorch",
        "widedeep_ft_transformer": "pytorch",
        "widedeep_saint": "pytorch",
        "widedeep_tab_fastformer": "pytorch",
        "widedeep_tab_perceiver": "pytorch",
        "pytorch_tabular_autoint": "pytorch",
        "pytorch_tabular_tabnet": "pytorch",
        "pytorch_tabular_node": "pytorch",
        "pytorch_tabular_category_embedding": "pytorch",
        "pytorch_tabular_ft_transformer": "pytorch",
        "pytorch_tabular_tab_transformer": "pytorch",
        "nbm_spam_spam": "pytorch",
        "nbm_spam_nam": "pytorch",
        "nbm_spam_nbm": "pytorch",
        "arm_net_models": "pytorch",
        "danet": "pytorch",
        "nam": "pytorch",
        "stg": "pytorch",
        "coxnet": "stand_alone"
    }
    return model_framework


class BaseModel(pl.LightningModule):
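    """Base LightningModule for the tabular models: sets up the loss function, per-stage metric
    collections, Weights & Biases logging, and SHAP-based feature importance."""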

    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters(logger=False)

        self.produce_probabilities = False
        self.produce_importance = False

        if self.hparams.task == "classification":
            if self.hparams.output_dim < 2:
                raise ValueError(f"Classification requires at least 2 classes, got output_dim={self.hparams.output_dim}")
            self.loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
            self.metrics = get_cls_pred_metrics(self.hparams.output_dim)
            self.metrics = {f'{k}_pl': v for k, v in self.metrics.items()}
            self.metrics_dict = {k: v[0] for k, v in self.metrics.items()}
            self.metrics_prob = get_cls_prob_metrics(self.hparams.output_dim)
            self.metrics_prob = {f'{k}_pl': v for k, v in self.metrics_prob.items()}
            self.metrics_prob_dict = {k: v[0] for k, v in self.metrics_prob.items()}
        elif self.hparams.task == "regression":
            if self.hparams.loss_type == "MSE":
                self.loss_fn = torch.nn.MSELoss(reduction='mean')
            elif self.hparams.loss_type == "L1Loss":
                self.loss_fn = torch.nn.L1Loss(reduction='mean')
            else:
                raise ValueError(f"Unsupported loss_type: {self.hparams.loss_type}")
            self.metrics = get_reg_metrics()
            self.metrics = {f'{k}_pl': v for k, v in self.metrics.items()}
            self.metrics_dict = {k: v[0] for k, v in self.metrics.items()}
            self.metrics_prob_dict = {}
        else:
            raise ValueError(f"Unsupported task: {self.hparams.task}")

        # One metric collection per stage, cloned so their internal states stay independent.
        self.metrics_trn = MetricCollection(self.metrics_dict)
        self.metrics_trn_prob = MetricCollection(self.metrics_prob_dict)
        self.metrics_val = self.metrics_trn.clone()
        self.metrics_val_prob = self.metrics_trn_prob.clone()
        self.metrics_tst = self.metrics_trn.clone()
        self.metrics_tst_prob = self.metrics_trn_prob.clone()

    def on_train_start(self):
        pass

    def on_fit_start(self) -> None:
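        """Register per-metric summary rules with Weights & Biases when a run is active."""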
        if wandb.run is not None:
            for stage_type in ['trn', 'val', 'tst']:
                for m in self.metrics:
                    wandb.define_metric(f"{stage_type}/{m}", summary=self.metrics[m][1])
                if self.hparams.task == "classification":
                    for m in self.metrics_prob:
                        wandb.define_metric(f"{stage_type}/{m}", summary=self.metrics_prob[m][1])
                wandb.define_metric(f"{stage_type}/loss", summary='min')

    def calc_out_and_loss(self, out, y, stage):
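        """Compute the loss for a batch; subclasses may override to post-process the output or the loss per stage."""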
        loss = self.loss_fn(out, y)
        return out, loss

    def forward(self, batch):
        raise NotImplementedError("Subclasses must implement forward")

    def forward_train(self, batch):
        return self.forward(batch)

    def forward_eval(self, batch):
        return self.forward(batch)

    def step(self, batch: Dict, stage: str):
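        """Shared step logic: run the forward pass, compute the loss, and update the metrics for the given stage."""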
        y = batch["target"]
        batch_size = y.size(0)
        if self.hparams.task == "regression":
            y = y.view(batch_size, -1)

        if stage == "trn":
            out = self.forward_train(batch)
        else:
            out = self.forward_eval(batch)
        out, loss = self.calc_out_and_loss(out, y, stage)

        logs = {"loss": loss}
        non_logs = {}
        if self.hparams.task == "classification":
            probs = torch.softmax(out, dim=1)
            preds = torch.argmax(out, dim=1)
            non_logs["preds"] = preds
            non_logs["targets"] = y
            if stage == "trn":
                logs.update(self.metrics_trn(preds, y))
                try:
                    logs.update(self.metrics_trn_prob(probs, y))
                except ValueError:
                    pass
            elif stage == "val":
                logs.update(self.metrics_val(preds, y))
                try:
                    logs.update(self.metrics_val_prob(probs, y))
                except ValueError:
                    pass
            elif stage == "tst":
                logs.update(self.metrics_tst(preds, y))
                try:
                    logs.update(self.metrics_tst_prob(probs, y))
                except ValueError:
                    pass
        elif self.hparams.task == "regression":
            if stage == "trn":
                logs.update(self.metrics_trn(out, y))
            elif stage == "val":
                logs.update(self.metrics_val(out, y))
            elif stage == "tst":
                logs.update(self.metrics_tst(out, y))

        return loss, logs, non_logs

    def training_step(self, batch: Dict, batch_idx: int):
        loss, logs, non_logs = self.step(batch, "trn")
        d = {f"trn/{k}": v for k, v in logs.items()}
        self.log_dict(d, on_step=False, on_epoch=True, logger=True)
        logs.update(non_logs)
        return logs

    def training_epoch_end(self, outputs: List[Any]):
        pass

    def validation_step(self, batch: Dict, batch_idx: int):
        loss, logs, non_logs = self.step(batch, "val")
        d = {f"val/{k}": v for k, v in logs.items()}
        self.log_dict(d, on_step=False, on_epoch=True, logger=True)
        logs.update(non_logs)
        return logs

    def validation_epoch_end(self, outputs: List[Any]):
        pass

    def test_step(self, batch: Dict, batch_idx: int):
        loss, logs, non_logs = self.step(batch, "tst")
        d = {f"tst/{k}": v for k, v in logs.items()}
        self.log_dict(d, on_step=False, on_epoch=True, logger=True)
        logs.update(non_logs)
        return logs

    def test_epoch_end(self, outputs: List[Any]):
        pass

    def predict_step(self, batch: Dict, batch_idx: int):
        out = self.forward(batch)
        return out

    def on_epoch_end(self):
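        """Reset all metric states at the end of every epoch."""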
        for m in self.metrics_dict:
            self.metrics_trn[m].reset()
            self.metrics_val[m].reset()
            self.metrics_tst[m].reset()
        for m in self.metrics_prob_dict:
            self.metrics_trn_prob[m].reset()
            self.metrics_val_prob[m].reset()
            self.metrics_tst_prob[m].reset()

    def get_feature_importance(self, data, features, method="shap_kernel"):
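        """Estimate per-feature importance with SHAP (kernel, sampling, or deep explainer) and return a DataFrame with 'feature' and 'importance' columns."""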
        if method.startswith("shap"):

            if self.hparams.task == "regression":

                def predict_func(X):
                    batch = {
                        'all': torch.from_numpy(np.float32(X[:, features['all_ids']])),
                        'continuous': torch.from_numpy(np.float32(X[:, features['con_ids']])),
                        'categorical': torch.from_numpy(np.int32(X[:, features['cat_ids']])),
                    }
                    tmp = self.forward(batch)
                    return tmp.cpu().detach().numpy()

                if method == "shap_kernel":
                    explainer = shap.KernelExplainer(predict_func, data)
                    shap_values = explainer.shap_values(data)
                    if isinstance(shap_values, list):
                        shap_values = shap_values[0]
                elif method == "shap_sampling":
                    explainer = shap.SamplingExplainer(predict_func, data)
                    shap_values = explainer.shap_values(data)
                    if isinstance(shap_values, list):
                        shap_values = shap_values[0]
                elif method == "shap_deep":
                    explainer = shap.DeepExplainer(self, torch.from_numpy(data))
                    shap_values = explainer.shap_values(torch.from_numpy(data))
                else:
                    raise ValueError(f"Unsupported feature importance method: {method}")

                importance_values = np.mean(np.abs(shap_values), axis=0)

            elif self.hparams.task == "classification":

                def predict_func(X):
                    self.produce_probabilities = True
                    batch = {
                        'all': torch.from_numpy(np.float32(X[:, features['all_ids']])),
                        'continuous': torch.from_numpy(np.float32(X[:, features['con_ids']])),
                        'categorical': torch.from_numpy(np.int32(X[:, features['cat_ids']])),
                    }
                    tmp = self.forward(batch)
                    # Reset the flag so calls outside this explainer are unaffected.
                    self.produce_probabilities = False
                    return tmp.cpu().detach().numpy()

                if method == "shap_kernel":
                    explainer = shap.KernelExplainer(predict_func, data)
                    shap_values = explainer.shap_values(data)
                elif method == "shap_sampling":
                    explainer = shap.SamplingExplainer(predict_func, data)
                    shap_values = explainer.shap_values(data)
                elif method == "shap_deep":
                    explainer = shap.DeepExplainer(self, torch.from_numpy(data))
                    shap_values = explainer.shap_values(torch.from_numpy(data))
                else:
                    raise ValueError(f"Unsupported feature importance method: {method}")

                # Aggregate importance over classes: sum of mean |SHAP| per class.
                importance_values = np.zeros(len(features['all']))
                for cl_id in range(len(shap_values)):
                    importance_values += np.mean(np.abs(shap_values[cl_id]), axis=0)

            else:
                raise ValueError(f"Unsupported task: {self.hparams.task}")

        elif method == "none":
            importance_values = np.zeros(len(features['all']))
        else:
            raise ValueError(f"Unsupported feature importance method: {method}")

        feature_importances = pd.DataFrame.from_dict(
            {
                'feature': features['all'],
                'importance': importance_values
            }
        )

        return feature_importances

    def configure_optimizers(self):
        """Choose what optimizers and learning-rate schedulers to use in your optimization.
        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        See examples here:
        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        optimizer = torch.optim.Adam(
            params=self.parameters(),
            lr=self.hparams.optimizer_lr,
            weight_decay=self.hparams.optimizer_weight_decay
        )
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer,
            step_size=self.hparams.scheduler_step_size,
            gamma=self.hparams.scheduler_gamma
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler
        }