# SImAge / models / tabular / base.py
# Author: AaronBlare
# (Initial commit, 87da3f2)
from typing import Any, List, Dict
from torchmetrics import MetricCollection
import wandb
import pytorch_lightning as pl
import torch
from models.metrics import get_cls_pred_metrics, get_cls_prob_metrics, get_reg_metrics
import numpy as np
import shap
import pandas as pd
def get_model_framework_dict():
    """Return a mapping from every supported model name to its framework.

    The framework is either ``"stand_alone"`` (sklearn-style estimators)
    or ``"pytorch"`` (Lightning-trained neural models).
    """
    # Insertion order is preserved deliberately: stand-alone models first,
    # then the pytorch families, with "coxnet" appended last.
    sklearn_like = (
        "elastic_net",
        "logistic_regression",
        "svm",
        "xgboost",
        "catboost",
        "lightgbm",
    )
    neural = (
        "widedeep_tab_mlp",
        "widedeep_tab_resnet",
        "widedeep_tab_net",
        "widedeep_tab_transformer",
        "widedeep_ft_transformer",
        "widedeep_saint",
        "widedeep_tab_fastformer",
        "widedeep_tab_perceiver",
        "pytorch_tabular_autoint",
        "pytorch_tabular_tabnet",
        "pytorch_tabular_node",
        "pytorch_tabular_category_embedding",
        "pytorch_tabular_ft_transformer",
        "pytorch_tabular_tab_transformer",
        "nbm_spam_spam",
        "nbm_spam_nam",
        "nbm_spam_nbm",
        "arm_net_models",
        "danet",
        "nam",
        "stg",
    )
    model_framework = dict.fromkeys(sklearn_like, "stand_alone")
    model_framework.update(dict.fromkeys(neural, "pytorch"))
    model_framework["coxnet"] = "stand_alone"
    return model_framework
class BaseModel(pl.LightningModule):
    """Base LightningModule shared by the tabular models.

    Responsibilities:
      * choose the loss function from ``hparams.task`` / ``hparams.loss_type``;
      * maintain separate train/val/test torchmetrics collections;
      * register metric summaries with wandb at fit start;
      * provide SHAP-based feature importance.

    Expected hyperparameters (captured via ``save_hyperparameters``):
        task: "classification" or "regression".
        output_dim: number of model outputs (must be >= 2 for classification).
        loss_type: "MSE" or "L1Loss" (regression only).
        optimizer_lr, optimizer_weight_decay: Adam settings.
        scheduler_step_size, scheduler_gamma: StepLR settings.

    Subclasses must implement :meth:`forward` accepting a batch dict with at
    least the keys used here ("target", and for SHAP "all"/"continuous"/
    "categorical").
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters(logger=False)
        # Runtime switches inspected by subclasses' forward() implementations.
        self.produce_probabilities = False
        self.produce_importance = False
        if self.hparams.task == "classification":
            self.loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
            if self.hparams.output_dim < 2:
                raise ValueError(f"Classification with {self.hparams.output_dim} classes")
            # get_cls_*_metrics return {name: (torchmetrics.Metric, wandb_summary)};
            # the "_pl" suffix marks metrics logged through Lightning.
            self.metrics = get_cls_pred_metrics(self.hparams.output_dim)
            self.metrics = {f'{k}_pl': v for k, v in self.metrics.items()}
            self.metrics_dict = {k: v[0] for k, v in self.metrics.items()}
            self.metrics_prob = get_cls_prob_metrics(self.hparams.output_dim)
            self.metrics_prob = {f'{k}_pl': v for k, v in self.metrics_prob.items()}
            self.metrics_prob_dict = {k: v[0] for k, v in self.metrics_prob.items()}
        elif self.hparams.task == "regression":
            if self.hparams.loss_type == "MSE":
                self.loss_fn = torch.nn.MSELoss(reduction='mean')
            elif self.hparams.loss_type == "L1Loss":
                self.loss_fn = torch.nn.L1Loss(reduction='mean')
            else:
                raise ValueError("Unsupported loss_type")
            self.metrics = get_reg_metrics()
            self.metrics = {f'{k}_pl': v for k, v in self.metrics.items()}
            self.metrics_dict = {k: v[0] for k, v in self.metrics.items()}
            self.metrics_prob_dict = {}
        else:
            # Fail fast instead of crashing later with AttributeError on
            # self.metrics_dict (the original silently fell through here).
            raise ValueError("Unsupported task")
        # One collection per stage so running state never mixes across stages.
        self.metrics_trn = MetricCollection(self.metrics_dict)
        self.metrics_trn_prob = MetricCollection(self.metrics_prob_dict)
        self.metrics_val = self.metrics_trn.clone()
        self.metrics_val_prob = self.metrics_trn_prob.clone()
        self.metrics_tst = self.metrics_trn.clone()
        self.metrics_tst_prob = self.metrics_trn_prob.clone()

    def on_train_start(self):
        # By default Lightning executes validation-step sanity checks before
        # training starts; a reset here would discard their metric state.
        # Currently a no-op.
        pass

    def on_fit_start(self) -> None:
        """Register per-stage metric summaries with wandb (if a run is active)."""
        if wandb.run is not None:
            for stage_type in ['trn', 'val', 'tst']:
                for m in self.metrics:
                    # self.metrics[m][1] is the wandb summary mode ('min'/'max'/...).
                    wandb.define_metric(f"{stage_type}/{m}", summary=self.metrics[m][1])
                if self.hparams.task == "classification":
                    for m in self.metrics_prob:
                        wandb.define_metric(f"{stage_type}/{m}", summary=self.metrics_prob[m][1])
                wandb.define_metric(f"{stage_type}/loss", summary='min')

    def calc_out_and_loss(self, out, y, stage):
        """Return (output, loss). Subclasses may override to add stage-specific terms."""
        loss = self.loss_fn(out, y)
        return out, loss

    def forward(self, batch):
        """Abstract forward pass; must be implemented by subclasses."""
        pass

    def forward_train(self, batch):
        """Forward pass used at train time; defaults to :meth:`forward`."""
        return self.forward(batch)

    def forward_eval(self, batch):
        """Forward pass used at eval time; defaults to :meth:`forward`."""
        return self.forward(batch)

    def step(self, batch: Dict, stage: str):
        """Shared train/val/test step.

        Args:
            batch: dict with at least a "target" tensor.
            stage: one of "trn", "val", "tst".

        Returns:
            (loss, logs, non_logs) where ``logs`` holds scalars to log and
            ``non_logs`` holds tensors (preds/targets) passed through to hooks.
        """
        y = batch["target"]
        batch_size = y.size(0)
        if self.hparams.task == "regression":
            y = y.view(batch_size, -1)
        if stage == "trn":
            out = self.forward_train(batch)
        else:
            out = self.forward_eval(batch)
        out, loss = self.calc_out_and_loss(out, y, stage)
        logs = {"loss": loss}
        non_logs = {}
        if self.hparams.task == "classification":
            probs = torch.softmax(out, dim=1)
            preds = torch.argmax(out, dim=1)
            non_logs["preds"] = preds
            non_logs["targets"] = y
            if stage == "trn":
                logs.update(self.metrics_trn(preds, y))
                # Probability metrics (e.g. AUROC) can fail on degenerate
                # batches (single class present); skip them in that case.
                try:
                    logs.update(self.metrics_trn_prob(probs, y))
                except ValueError:
                    pass
            elif stage == "val":
                logs.update(self.metrics_val(preds, y))
                try:
                    logs.update(self.metrics_val_prob(probs, y))
                except ValueError:
                    pass
            elif stage == "tst":
                logs.update(self.metrics_tst(preds, y))
                try:
                    logs.update(self.metrics_tst_prob(probs, y))
                except ValueError:
                    pass
        elif self.hparams.task == "regression":
            if stage == "trn":
                logs.update(self.metrics_trn(out, y))
            elif stage == "val":
                logs.update(self.metrics_val(out, y))
            elif stage == "tst":
                logs.update(self.metrics_tst(out, y))
        return loss, logs, non_logs

    def training_step(self, batch: Dict, batch_idx: int):
        loss, logs, non_logs = self.step(batch, "trn")
        d = {f"trn/{k}": v for k, v in logs.items()}
        self.log_dict(d, on_step=False, on_epoch=True, logger=True)
        logs.update(non_logs)
        return logs

    def training_epoch_end(self, outputs: List[Any]):
        pass

    def validation_step(self, batch: Dict, batch_idx: int):
        loss, logs, non_logs = self.step(batch, "val")
        d = {f"val/{k}": v for k, v in logs.items()}
        self.log_dict(d, on_step=False, on_epoch=True, logger=True)
        logs.update(non_logs)
        return logs

    def validation_epoch_end(self, outputs: List[Any]):
        pass

    def test_step(self, batch: Dict, batch_idx: int):
        loss, logs, non_logs = self.step(batch, "tst")
        d = {f"tst/{k}": v for k, v in logs.items()}
        self.log_dict(d, on_step=False, on_epoch=True, logger=True)
        logs.update(non_logs)
        return logs

    def test_epoch_end(self, outputs: List[Any]):
        pass

    def predict_step(self, batch: Dict, batch_idx):
        out = self.forward(batch)
        return out

    def on_epoch_end(self):
        """Reset all per-stage metric state at the end of every epoch."""
        for m in self.metrics_dict:
            self.metrics_trn[m].reset()
            self.metrics_val[m].reset()
            self.metrics_tst[m].reset()
        for m in self.metrics_prob_dict:
            self.metrics_trn_prob[m].reset()
            self.metrics_val_prob[m].reset()
            self.metrics_tst_prob[m].reset()

    def get_feature_importance(self, data, features, method="shap_kernel"):
        """Compute per-feature importance values.

        Args:
            data: 2D numpy array, samples (rows) x features (columns).
            features: dict with keys 'all' (feature names) and index lists
                'all_ids', 'con_ids', 'cat_ids'.
            method: "shap_kernel", "shap_sampling", "shap_deep", or "none".

        Returns:
            pandas.DataFrame with columns 'feature' and 'importance'.

        Raises:
            ValueError: on an unsupported method or task.
        """
        if method.startswith("shap"):
            def predict_func(X):
                # For classification, ask forward() for class probabilities
                # so SHAP explains the probability outputs.
                if self.hparams.task == "classification":
                    self.produce_probabilities = True
                batch = {
                    'all': torch.from_numpy(np.float32(X[:, features['all_ids']])),
                    'continuous': torch.from_numpy(np.float32(X[:, features['con_ids']])),
                    'categorical': torch.from_numpy(np.int32(X[:, features['cat_ids']])),
                }
                tmp = self.forward(batch)
                return tmp.cpu().detach().numpy()

            # BUGFIX: these branches were `if`/`if`/`elif`/`else`, so
            # method == "shap_kernel" fell through to the final `else`
            # and raised "Unsupported feature importance method".
            if method == "shap_kernel":
                explainer = shap.KernelExplainer(predict_func, data)
                shap_values = explainer.shap_values(data)
            elif method == "shap_sampling":
                explainer = shap.SamplingExplainer(predict_func, data)
                shap_values = explainer.shap_values(data)
            elif method == "shap_deep":
                explainer = shap.DeepExplainer(self, torch.from_numpy(data))
                shap_values = explainer.shap_values(torch.from_numpy(data))
            else:
                raise ValueError(f"Unsupported feature importance method: {method}")
            # BUGFIX: reset the probability switch that predict_func may have
            # set, so later forward() calls behave normally again.
            self.produce_probabilities = False
            if self.hparams.task == "regression":
                # Some explainers return a one-element list for single-output models.
                if isinstance(shap_values, list):
                    shap_values = shap_values[0]
                importance_values = np.mean(np.abs(shap_values), axis=0)
            elif self.hparams.task == "classification":
                # Sum the mean |SHAP| contributions over all classes.
                importance_values = np.zeros(len(features['all']))
                for cl_id in range(len(shap_values)):
                    importance_values += np.mean(np.abs(shap_values[cl_id]), axis=0)
            else:
                raise ValueError("Unsupported task")
        elif method == "none":
            importance_values = np.zeros(len(features['all']))
        else:
            raise ValueError(f"Unsupported feature importance method: {method}")
        feature_importances = pd.DataFrame.from_dict(
            {
                'feature': features['all'],
                'importance': importance_values
            }
        )
        return feature_importances

    def configure_optimizers(self):
        """Choose what optimizers and learning-rate schedulers to use in your optimization.
        Normally you'd need one. But in the case of GANs or similar you might have multiple.
        See examples here:
        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        optimizer = torch.optim.Adam(
            params=self.parameters(),
            lr=self.hparams.optimizer_lr,
            weight_decay=self.hparams.optimizer_weight_decay
        )
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer,
            step_size=self.hparams.scheduler_step_size,
            gamma=self.hparams.scheduler_gamma
        )
        return (
            {
                "optimizer": optimizer,
                "lr_scheduler": scheduler
            }
        )