dynamic_alibi_pile_2layer / ops /layer_with_visualization.py
Lanni-ni's picture
add remote code + model files
c67d7da verified
import torch
import torch.nn
from typing import Dict, Any, Optional
class LayerWithVisualization(torch.nn.Module):
    """Base class for layers that can produce visualizations on request.

    Instances start with ``visualization_enabled`` set to False. The flag is a
    plain attribute that is switched on and off from outside the layer.
    Subclasses override :meth:`plot` and may override :meth:`prepare`.
    """

    def __init__(self):
        super(LayerWithVisualization, self).__init__()
        self.visualization_enabled: bool = False

    def prepare(self):
        """Hook to be called before the training step; the default does nothing."""

    def plot(self, options: Dict[str, Any]) -> Dict[str, Any]:
        """Return a dict of visualizations built according to ``options``.

        Subclasses must override this; the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError
class LayerVisualizer:
    """Drive every ``LayerWithVisualization`` found inside a module tree.

    Typical cycle:
    1. Call :meth:`prepare`. It runs each collected layer's ``prepare()`` hook
       and sets its ``visualization_enabled`` flag to True.
    2. Run the step.
    3. Call :meth:`plot`. It gathers each layer's plots under a
       ``"<module name>/<plot key>"`` key and resets the flag to False.
    """

    def __init__(self, module: torch.nn.Module, options: Optional[Dict[str, Any]] = None):
        """Collect all visualizable submodules of ``module``.

        Args:
            module: Root module to search. ``named_modules()`` includes the
                root itself, under the name ``""``.
            options: Base options used by every prepare()/plot() cycle.
                Defaults to an empty dict.
        """
        # (name, layer) pairs in named_modules() order; the name prefixes
        # the keys returned by plot().
        self.modules = []
        # A fresh dict per instance. A shared `{}` default would alias the
        # same dict across every visualizer created without options.
        self.options = {} if options is None else options
        # Options for the current prepare()/plot() cycle; None when idle.
        self.curr_options = None
        for n, m in module.named_modules():
            if isinstance(m, LayerWithVisualization):
                self.modules.append((n, m))

    def plot(self) -> Dict[str, Any]:
        """Gather plots from every collected layer and end the current cycle.

        Returns:
            A dict mapping ``f"{module_name}/{plot_key}"`` to whatever each
            layer's ``plot`` returned for that key.
        """
        # If prepare() was never called, hand layers the base options
        # rather than None (their plot() signature expects a dict).
        if self.curr_options is not None:
            options = self.curr_options
        else:
            options = self.options.copy()
        res = {}
        for n, m in self.modules:
            res.update({f"{n}/{k}": v for k, v in m.plot(options).items()})
            m.visualization_enabled = False
        self.curr_options = None
        return res

    def prepare(self, options: Optional[Dict[str, Any]] = None):
        """Start a visualization cycle.

        Args:
            options: Per-cycle overrides merged on top of the base options.
                The caller's dict is not modified.
        """
        # Copy so that per-cycle overrides never leak into the base options.
        self.curr_options = self.options.copy()
        if options is not None:
            self.curr_options.update(options)
        for _, m in self.modules:
            m.prepare()
            m.visualization_enabled = True