File size: 1,257 Bytes
15063d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
from typing import Any, Dict, Optional

import torch
import torch.nn
class LayerWithVisualization(torch.nn.Module):
    """Base class for layers that can produce visualizations.

    Subclasses override ``plot`` and, optionally, the ``prepare`` hook.
    The ``visualization_enabled`` flag starts out False; this class never
    changes it itself (``LayerVisualizer`` toggles it around prepare/plot).
    """

    def __init__(self) -> None:
        super().__init__()
        # Visualization is off until some external driver switches it on.
        self.visualization_enabled: bool = False

    def prepare(self) -> None:
        """Hook meant to be called before the training step; does nothing by default."""

    def plot(self, options: Dict[str, Any]) -> Dict[str, Any]:
        """Return a mapping from visualization name to value.

        Args:
            options: plotting options supplied by the caller.

        Raises:
            NotImplementedError: always, in this base class; subclasses must
                override it.
        """
        raise NotImplementedError
class LayerVisualizer:
    """Drive every ``LayerWithVisualization`` found inside a module tree.

    Intended cycle: call ``prepare`` before the step to be visualized (it
    enables visualization on each registered layer), run the step, then call
    ``plot`` to collect the layers' outputs and disable visualization again.
    """

    def __init__(self, module: torch.nn.Module,
                 options: Optional[Dict[str, Any]] = None):
        """
        Args:
            module: root module; it and all of its submodules (as yielded by
                ``named_modules``) are scanned for ``LayerWithVisualization``
                instances.
            options: base plotting options used by every ``prepare`` call.
                Stored by reference, not copied. Defaults to a new empty dict.
        """
        # A None sentinel replaces the old ``{}`` default. That single default
        # dict was shared by all instances and stored on ``self``, so mutating
        # one visualizer's options silently changed every other one.
        self.options = {} if options is None else options
        self.curr_options = None
        # (name, layer) pairs in ``named_modules`` traversal order.
        self.modules = [
            (name, sub)
            for name, sub in module.named_modules()
            if isinstance(sub, LayerWithVisualization)
        ]

    def plot(self) -> Dict[str, Any]:
        """Collect visualizations from all registered layers and disable them.

        Returns:
            A flat dict keyed ``"<module name>/<plot key>"``.

        NOTE(review): expected to follow ``prepare``. Without it,
        ``curr_options`` is None and that None is passed to each layer's
        ``plot``. If the root module is itself a ``LayerWithVisualization``,
        its name is ``""`` and its keys come out as ``"/<plot key>"``.
        """
        res = {}
        for n, m in self.modules:
            res.update({f"{n}/{k}": v for k, v in m.plot(self.curr_options).items()})
            m.visualization_enabled = False
        # End of the cycle: drop the per-call options.
        self.curr_options = None
        return res

    def prepare(self, options: Optional[Dict[str, Any]] = None):
        """Enable visualization on all registered layers for the next step.

        Args:
            options: per-call overrides merged on top of the base options for
                this cycle only. The base ``self.options`` is not modified.
        """
        # Copy first so per-call overrides never leak into the base options.
        self.curr_options = self.options.copy()
        if options is not None:
            self.curr_options.update(options)
        for _, m in self.modules:
            m.prepare()
            m.visualization_enabled = True
|