Spaces:
Sleeping
Sleeping
| import subprocess | |
| from pathlib import Path | |
| from typing import List | |
| import matplotlib.pyplot as plt | |
| import seaborn as sn | |
| import torch | |
| import wandb | |
| from pytorch_lightning import Callback, Trainer | |
| from pytorch_lightning.loggers import LoggerCollection, WandbLogger | |
| from pytorch_lightning.utilities import rank_zero_only | |
| from sklearn import metrics | |
| from sklearn.metrics import f1_score, precision_score, recall_score | |
def get_wandb_logger(trainer: Trainer) -> WandbLogger:
    """Return the Weights&Biases logger attached to ``trainer``.

    ``trainer.logger`` is returned when it is itself a ``WandbLogger``; when it
    is a ``LoggerCollection``, the first ``WandbLogger`` inside it is returned.

    Raises:
        Exception: if ``trainer.fast_dev_run`` is truthy, or if no
            ``WandbLogger`` can be found on the trainer.
    """
    if trainer.fast_dev_run:
        raise Exception(
            "Cannot use wandb callbacks since pytorch lightning disables loggers in `fast_dev_run=true` mode."
        )

    attached = trainer.logger

    # Single logger case.
    if isinstance(attached, WandbLogger):
        return attached

    # Several loggers bundled together: pick the first wandb one, if any.
    if isinstance(attached, LoggerCollection):
        candidates = (item for item in attached if isinstance(item, WandbLogger))
        found = next(candidates, None)
        if found is not None:
            return found

    raise Exception(
        "You are using wandb related callback, but WandbLogger was not found for some reason..."
    )
class WatchModel(Callback):
    """Callback that tells wandb to watch the model when training starts."""

    def __init__(self, log: str = "gradients", log_freq: int = 100):
        """
        Args:
            log: forwarded as the ``log`` argument of the logger's ``watch`` call.
            log_freq: forwarded as the ``log_freq`` argument of the same call.
        """
        self.log_freq = log_freq
        self.log = log

    def on_train_start(self, trainer, pl_module):
        """Look up the wandb logger and make it watch ``trainer.model``."""
        wandb_logger = get_wandb_logger(trainer=trainer)
        watch_options = {
            "model": trainer.model,
            "log": self.log,
            "log_freq": self.log_freq,
        }
        wandb_logger.watch(**watch_options)
class UploadCodeAsArtifact(Callback):
    """Upload all code files to wandb as an artifact, at the beginning of the run."""

    def __init__(self, code_dir: str, use_git: bool = True):
        """
        Args:
            code_dir: the code directory
            use_git: if using git, then upload all files that are not ignored by git.
                if not using git, then upload all '*.py' file
        """
        self.code_dir = code_dir
        self.use_git = use_git

    def on_train_start(self, trainer, pl_module):
        """Collect the source files under ``code_dir`` and log them as one artifact.

        Files are stored in the artifact under their path relative to
        ``code_dir``. With ``use_git`` enabled, ``git`` is invoked from the
        current working directory to locate the git directory and to test each
        file against the ignore rules.

        Raises:
            subprocess.CalledProcessError: if ``git rev-parse --git-dir`` fails
                (for example when not run inside a git repository).
        """
        logger = get_wandb_logger(trainer=trainer)
        experiment = logger.experiment

        code = wandb.Artifact("project-source", type="code")

        # Resolve once: rglob() below yields children of this resolved root, so
        # relative_to() must be computed against the same resolved path. Using
        # the raw ``code_dir`` string raises ValueError when it is relative or
        # goes through a symlink.
        code_root = Path(self.code_dir).resolve()

        if self.use_git:
            # get .git folder
            # https://alexwlchan.net/2020/11/a-python-function-to-ignore-a-path-with-git-info-exclude/
            git_dir_path = Path(
                subprocess.check_output(["git", "rev-parse", "--git-dir"]).strip().decode("utf8")
            ).resolve()

            for path in code_root.rglob("*"):
                if not path.is_file():
                    continue

                # Skip files inside the git directory. Compare path components
                # rather than string prefixes, so that siblings such as
                # ".github/" or ".gitignore" are not dropped by accident.
                if git_dir_path in path.parents:
                    continue

                # `git check-ignore -q` exits with 1 when the path is NOT ignored.
                check = subprocess.run(["git", "check-ignore", "-q", str(path)])
                if check.returncode == 1:
                    code.add_file(str(path), name=str(path.relative_to(code_root)))
        else:
            for path in code_root.rglob("*.py"):
                code.add_file(str(path), name=str(path.relative_to(code_root)))

        experiment.log_artifact(code)
class UploadCheckpointsAsArtifact(Callback):
    """Callback that logs checkpoint files to wandb as one artifact when training ends."""

    def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False):
        """
        Args:
            ckpt_dir: directory searched recursively for ``*.ckpt`` files.
            upload_best_only: when true, only the path reported by
                ``trainer.checkpoint_callback.best_model_path`` is uploaded.
        """
        self.upload_best_only = upload_best_only
        self.ckpt_dir = ckpt_dir

    def on_keyboard_interrupt(self, trainer, pl_module):
        """Run the same upload as ``on_train_end`` on keyboard interrupt."""
        self.on_train_end(trainer, pl_module)

    def on_train_end(self, trainer, pl_module):
        """Build the checkpoints artifact and log it to the wandb run."""
        run = get_wandb_logger(trainer=trainer).experiment

        artifact = wandb.Artifact("experiment-ckpts", type="checkpoints")

        if not self.upload_best_only:
            # Every checkpoint found anywhere below ckpt_dir.
            for ckpt_file in Path(self.ckpt_dir).rglob("*.ckpt"):
                artifact.add_file(str(ckpt_file))
        else:
            artifact.add_file(trainer.checkpoint_callback.best_model_path)

        run.log_artifact(artifact)
class LogConfusionMatrix(Callback):
    """Generate confusion matrix every epoch and send it to wandb.

    Expects validation step to return predictions and targets, i.e. a mapping
    with the keys ``"preds"`` and ``"targets"``.
    """

    def __init__(self):
        self.preds = []
        self.targets = []
        self.ready = True

    def on_sanity_check_start(self, trainer, pl_module) -> None:
        """Pause collection while the validation sanity check runs."""
        self.ready = False

    def on_sanity_check_end(self, trainer, pl_module):
        """Start executing this callback only after all validation sanity checks end."""
        self.ready = True

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        """Gather data from single batch."""
        if self.ready:
            self.preds.append(outputs["preds"])
            self.targets.append(outputs["targets"])

    def on_validation_epoch_end(self, trainer, pl_module):
        """Generate confusion matrix and log it to wandb as an image."""
        if not self.ready:
            return

        # Nothing was collected this epoch: torch.cat([]) would raise.
        if not self.preds:
            return

        logger = get_wandb_logger(trainer)
        experiment = logger.experiment

        preds = torch.cat(self.preds).cpu().numpy()
        targets = torch.cat(self.targets).cpu().numpy()

        confusion_matrix = metrics.confusion_matrix(y_true=targets, y_pred=preds)

        # set figure size
        fig = plt.figure(figsize=(14, 8))
        try:
            # set labels size
            sn.set(font_scale=1.4)

            # set font size
            sn.heatmap(confusion_matrix, annot=True, annot_kws={"size": 8}, fmt="g")

            # names should be unique or else charts from different experiments in wandb will overlap
            # NOTE: the figure is wrapped in wandb.Image; the original author reported
            # that passing `plt` to experiment.log directly crashes.
            experiment.log({f"confusion_matrix/{experiment.name}": wandb.Image(plt)}, commit=False)
        finally:
            # Close (not just clear) the figure: a new one is created every
            # epoch, so clearing alone leaves them all open in pyplot.
            plt.close(fig)

        self.preds.clear()
        self.targets.clear()
class LogF1PrecRecHeatmap(Callback):
    """Generate f1, precision, recall heatmap every epoch and send it to wandb.

    Expects validation step to return predictions and targets, i.e. a mapping
    with the keys ``"preds"`` and ``"targets"``.
    """

    def __init__(self, class_names: List[str] = None):
        """
        Args:
            class_names: optional labels for the heatmap columns. They are only
                used when their count equals the number of classes scored in
                the epoch; otherwise the default tick labels are kept.
                NOTE(review): assumes the names are ordered like the sorted
                class labels found in the data; confirm against callers.
        """
        # Previously accepted but silently discarded.
        self.class_names = class_names
        self.preds = []
        self.targets = []
        self.ready = True

    def on_sanity_check_start(self, trainer, pl_module):
        """Pause collection while the validation sanity check runs."""
        self.ready = False

    def on_sanity_check_end(self, trainer, pl_module):
        """Start executing this callback only after all validation sanity checks end."""
        self.ready = True

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        """Gather data from single batch."""
        if self.ready:
            self.preds.append(outputs["preds"])
            self.targets.append(outputs["targets"])

    def on_validation_epoch_end(self, trainer, pl_module):
        """Generate f1, precision and recall heatmap and log it to wandb as an image."""
        if not self.ready:
            return

        # Nothing was collected this epoch: torch.cat([]) would raise.
        if not self.preds:
            return

        logger = get_wandb_logger(trainer=trainer)
        experiment = logger.experiment

        preds = torch.cat(self.preds).cpu().numpy()
        targets = torch.cat(self.targets).cpu().numpy()

        # average=None yields one score per class instead of a single aggregate
        f1 = f1_score(targets, preds, average=None)
        r = recall_score(targets, preds, average=None)
        p = precision_score(targets, preds, average=None)
        data = [f1, p, r]

        # Use the supplied names only when there is exactly one per column.
        column_labels = "auto"
        if self.class_names is not None and len(self.class_names) == len(f1):
            column_labels = list(self.class_names)

        # set figure size
        fig = plt.figure(figsize=(14, 3))
        try:
            # set labels size
            sn.set(font_scale=1.2)

            # set font size
            sn.heatmap(
                data,
                annot=True,
                annot_kws={"size": 10},
                fmt=".3f",
                xticklabels=column_labels,
                yticklabels=["F1", "Precision", "Recall"],
            )

            # names should be unique or else charts from different experiments in wandb will overlap
            experiment.log({f"f1_p_r_heatmap/{experiment.name}": wandb.Image(plt)}, commit=False)
        finally:
            # Close (not just clear) the figure: a new one is created every
            # epoch, so clearing alone leaves them all open in pyplot.
            plt.close(fig)

        self.preds.clear()
        self.targets.clear()
class LogImagePredictions(Callback):
    """Log a few validation images together with the model's predictions to wandb.

    Example adapted from:
        https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY
    """

    def __init__(self, num_samples: int = 8):
        """
        Args:
            num_samples: upper bound used when slicing the batch for logging.
        """
        super().__init__()
        self.ready = True
        self.num_samples = num_samples

    def on_sanity_check_start(self, trainer, pl_module):
        """Disable logging while the validation sanity check runs."""
        self.ready = False

    def on_sanity_check_end(self, trainer, pl_module):
        """Re-enable logging once the validation sanity check is over."""
        self.ready = True

    def on_validation_epoch_end(self, trainer, pl_module):
        """Run the first validation batch through the module and log the samples."""
        if not self.ready:
            return

        run = get_wandb_logger(trainer=trainer).experiment

        # First batch of a fresh iterator over the datamodule's validation loader.
        images, labels = next(iter(trainer.datamodule.val_dataloader()))

        # Forward pass on the module's device; class index = argmax over the last dim.
        images = images.to(device=pl_module.device)
        predicted = torch.argmax(pl_module(images), dim=-1)

        chart_key = f"Images/{run.name}"
        limit = self.num_samples
        samples = zip(images[:limit], predicted[:limit], labels[:limit])

        # One wandb image per sample, captioned with prediction and label.
        captioned = [
            wandb.Image(x, caption=f"Pred:{pred}, Label:{y}") for x, pred, y in samples
        ]
        run.log({chart_key: captioned})