Charlie81 committed on
Commit
61a401c
·
1 Parent(s): 870d3db

attribute error

Browse files
Files changed (1) hide show
  1. scripts/evalexperts.py +56 -1
scripts/evalexperts.py CHANGED
@@ -77,7 +77,7 @@ class ExpertTrackingHFLM(HFLM):
77
  )
78
 
79
  # Update statistics
80
- self._update_expert_stats(
81
  layer_idx=layer_idx,
82
  topk_experts=topk_experts,
83
  topk_probs=topk_probs,
@@ -89,6 +89,61 @@ class ExpertTrackingHFLM(HFLM):
89
 
90
  return expert_hook
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  def _update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
93
  topk_probs: torch.Tensor, num_regular_experts: int,
94
  num_small_experts: int, batch_size: int, seq_len: int):
 
77
  )
78
 
79
  # Update statistics
80
+ self.update_expert_stats( # Changed from _update_expert_stats to update_expert_stats
81
  layer_idx=layer_idx,
82
  topk_experts=topk_experts,
83
  topk_probs=topk_probs,
 
89
 
90
  return expert_hook
91
 
92
+ def update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor, # Renamed from _update_expert_stats
93
+ topk_probs: torch.Tensor, num_regular_experts: int,
94
+ num_small_experts: int, batch_size: int, seq_len: int):
95
+ """Update expert usage statistics."""
96
+ # Flatten the batch and sequence dimensions
97
+ topk_experts_flat = topk_experts.view(-1, topk_experts.size(-1))
98
+ topk_probs_flat = topk_probs.view(-1, topk_probs.size(-1))
99
+
100
+ # Initialize layer stats if not present
101
+ if layer_idx not in self.expert_stats['layer_stats']:
102
+ self.expert_stats['layer_stats'][layer_idx] = {
103
+ 'total_tokens': 0,
104
+ 'regular_expert_counts': [0] * num_regular_experts,
105
+ 'small_expert_counts': [0] * num_small_experts if num_small_experts > 0 else None,
106
+ 'regular_expert_load': [0.0] * num_regular_experts,
107
+ 'small_expert_load': [0.0] * num_small_experts if num_small_experts > 0 else None
108
+ }
109
+
110
+ layer_stats = self.expert_stats['layer_stats'][layer_idx]
111
+ num_tokens = topk_experts_flat.size(0)
112
+
113
+ # Update global stats
114
+ self.expert_stats['total_tokens'] += num_tokens
115
+
116
+ # Update layer stats
117
+ layer_stats['total_tokens'] += num_tokens
118
+
119
+ # Track regular experts
120
+ for expert_idx in range(num_regular_experts):
121
+ mask = (topk_experts_flat == expert_idx)
122
+ count = mask.sum().item()
123
+ load = topk_probs_flat[mask].sum().item()
124
+
125
+ layer_stats['regular_expert_counts'][expert_idx] += count
126
+ layer_stats['regular_expert_load'][expert_idx] += load
127
+
128
+ if expert_idx not in self.expert_stats['regular_expert_usage']:
129
+ self.expert_stats['regular_expert_usage'][expert_idx] = 0
130
+ self.expert_stats['regular_expert_usage'][expert_idx] += count
131
+
132
+ # Track small experts if they exist
133
+ if num_small_experts > 0:
134
+ for expert_idx in range(num_small_experts):
135
+ small_expert_num = expert_idx + num_regular_experts
136
+ mask = (topk_experts_flat == small_expert_num)
137
+ count = mask.sum().item()
138
+ load = topk_probs_flat[mask].sum().item()
139
+
140
+ layer_stats['small_expert_counts'][expert_idx] += count
141
+ layer_stats['small_expert_load'][expert_idx] += load
142
+
143
+ if expert_idx not in self.expert_stats['small_expert_usage']:
144
+ self.expert_stats['small_expert_usage'][expert_idx] = 0
145
+ self.expert_stats['small_expert_usage'][expert_idx] += count
146
+
147
  def _update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
148
  topk_probs: torch.Tensor, num_regular_experts: int,
149
  num_small_experts: int, batch_size: int, seq_len: int):