GirishaBuilds01 commited on
Commit
cd8cded
·
verified ·
1 Parent(s): 6245dc5

Update core/profiler.py

Browse files
Files changed (1) hide show
  1. core/profiler.py +28 -26
core/profiler.py CHANGED
@@ -8,33 +8,35 @@ class ActivationProfiler:
8
  self.hooks = []
9
 
10
 
11
- def hook(module, input, output):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- # 🔥 FIX: Handle tuple outputs safely
14
- if isinstance(output, tuple):
15
- output = output[0]
16
-
17
- if not isinstance(output, torch.Tensor):
18
- return
19
-
20
- act = output.detach().float()
21
-
22
- std = act.std().item()
23
- mean = act.mean().item()
24
- min_val = act.min().item()
25
- max_val = act.max().item()
26
-
27
- outliers = (act.abs() > 3 * std).float().mean().item()
28
-
29
- self.stats[name] = {
30
- "mean": mean,
31
- "std": std,
32
- "min": min_val,
33
- "max": max_val,
34
- "outlier_ratio": outliers
35
- }
36
-
37
- return hook
38
 
39
 
40
  def register(self):
 
8
  self.hooks = []
9
 
10
 
11
+ def hook_fn(self, name):
12
+ def hook(module, input, output):
13
+
14
+ # 🔥 FIX: Handle tuple outputs safely
15
+ if isinstance(output, tuple):
16
+ output = output[0]
17
+
18
+ if not isinstance(output, torch.Tensor):
19
+ return
20
+
21
+ act = output.detach().float()
22
+
23
+ std = act.std().item()
24
+ mean = act.mean().item()
25
+ min_val = act.min().item()
26
+ max_val = act.max().item()
27
+
28
+ outliers = (act.abs() > 3 * std).float().mean().item()
29
+
30
+ self.stats[name] = {
31
+ "mean": mean,
32
+ "std": std,
33
+ "min": min_val,
34
+ "max": max_val,
35
+ "outlier_ratio": outliers
36
+ }
37
+
38
+ return hook
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
 
42
  def register(self):