GirishaBuilds01 commited on
Commit
6245dc5
·
verified ·
1 Parent(s): 60fe91b

Update core/profiler.py

Browse files
Files changed (1) hide show
  1. core/profiler.py +29 -17
core/profiler.py CHANGED
@@ -7,23 +7,35 @@ class ActivationProfiler:
7
  self.stats = defaultdict(dict)
8
  self.hooks = []
9
 
10
- def hook_fn(self, name):
11
- def hook(module, input, output):
12
- act = output.detach().float()
13
- std = act.std().item()
14
- mean = act.mean().item()
15
- min_val = act.min().item()
16
- max_val = act.max().item()
17
- outliers = (act.abs() > 3 * std).float().mean().item()
18
-
19
- self.stats[name] = {
20
- "mean": mean,
21
- "std": std,
22
- "min": min_val,
23
- "max": max_val,
24
- "outlier_ratio": outliers
25
- }
26
- return hook
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def register(self):
29
  for name, module in self.model.named_modules():
 
7
  self.stats = defaultdict(dict)
8
  self.hooks = []
9
 
10
+
11
+ def hook(module, input, output):
12
+
13
+ # 🔥 FIX: Handle tuple outputs safely
14
+ if isinstance(output, tuple):
15
+ output = output[0]
16
+
17
+ if not isinstance(output, torch.Tensor):
18
+ return
19
+
20
+ act = output.detach().float()
21
+
22
+ std = act.std().item()
23
+ mean = act.mean().item()
24
+ min_val = act.min().item()
25
+ max_val = act.max().item()
26
+
27
+ outliers = (act.abs() > 3 * std).float().mean().item()
28
+
29
+ self.stats[name] = {
30
+ "mean": mean,
31
+ "std": std,
32
+ "min": min_val,
33
+ "max": max_val,
34
+ "outlier_ratio": outliers
35
+ }
36
+
37
+ return hook
38
+
39
 
40
  def register(self):
41
  for name, module in self.model.named_modules():