Charlie81 committed
Commit 84d0925 · Parent(s): 842be01

push debugging info for evalexperts

Files changed (1):
  1. scripts/evalexperts.py +52 -25
scripts/evalexperts.py CHANGED
@@ -25,7 +25,7 @@ from lm_eval.models.huggingface import HFLM
 
 # Set up logging
 logging.basicConfig(
-    level=logging.INFO,
+    level=logging.DEBUG,  # Changed from INFO to DEBUG
     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
 )
 logger = logging.getLogger(__name__)
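
Note on the hunk above: logging.basicConfig(level=logging.DEBUG) raises verbosity on the root logger, so DEBUG output from every imported library (transformers, torch, and friends) comes through as well. A minimal alternative sketch, not part of this commit, keeps third-party loggers at INFO while this script's own logger emits DEBUG:

    import logging

    logging.basicConfig(
        level=logging.INFO,            # root logger stays at INFO
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)     # only this module goes verbose

DEBUG records from this logger still reach the root handler, because propagated records are filtered by handler levels, not by ancestor logger levels.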
@@ -90,15 +90,22 @@ class ExpertTrackingHFLM(HFLM):
         return expert_hook
 
     def update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
-                            topk_probs: torch.Tensor, num_regular_experts: int,
-                            num_small_experts: int, batch_size: int, seq_len: int):
-        """Update expert usage statistics."""
+                            topk_probs: torch.Tensor, num_regular_experts: int,
+                            num_small_experts: int, batch_size: int, seq_len: int):
+        """Update expert usage statistics with debug logging."""
+        # Debug: Print input parameters
+        logger.debug(f"\n{'='*40}")
+        logger.debug(f"Updating stats for layer {layer_idx}")
+        logger.debug(f"Input shapes - experts: {topk_experts.shape}, probs: {topk_probs.shape}")
+        logger.debug(f"Num experts - regular: {num_regular_experts}, small: {num_small_experts}")
+
         # Flatten the batch and sequence dimensions
         topk_experts_flat = topk_experts.view(-1, topk_experts.size(-1))
         topk_probs_flat = topk_probs.view(-1, topk_probs.size(-1))
 
         # Initialize layer stats if not present
         if layer_idx not in self.expert_stats['layer_stats']:
+            logger.debug(f"Initializing new layer stats with {num_regular_experts} regular and {num_small_experts} small experts")
             self.expert_stats['layer_stats'][layer_idx] = {
                 'total_tokens': 0,
                 'regular_expert_counts': [0] * num_regular_experts,
@@ -110,40 +117,60 @@ class ExpertTrackingHFLM(HFLM):
         layer_stats = self.expert_stats['layer_stats'][layer_idx]
         num_tokens = topk_experts_flat.size(0)
 
-        # Update global stats
-        self.expert_stats['total_tokens'] += num_tokens
-
-        # Update layer stats
-        layer_stats['total_tokens'] += num_tokens
+        # Debug: Print current layer stats structure
+        logger.debug(f"Current layer stats structure: {layer_stats.keys()}")
+        if layer_stats['small_expert_counts'] is None:
+            logger.debug("Small expert counts is None - no small experts initialized")
+        else:
+            logger.debug(f"Small expert counts length: {len(layer_stats['small_expert_counts'])}")
 
         # Track regular experts
+        regular_expert_used = False
         for expert_idx in range(num_regular_experts):
             mask = (topk_experts_flat == expert_idx)
             count = mask.sum().item()
-            load = topk_probs_flat[mask].sum().item()
-
-            layer_stats['regular_expert_counts'][expert_idx] += count
-            layer_stats['regular_expert_load'][expert_idx] += load
-
-            if expert_idx not in self.expert_stats['regular_expert_usage']:
-                self.expert_stats['regular_expert_usage'][expert_idx] = 0
-            self.expert_stats['regular_expert_usage'][expert_idx] += count
+            if count > 0:
+                regular_expert_used = True
+            layer_stats['regular_expert_counts'][expert_idx] += count
+            layer_stats['regular_expert_load'][expert_idx] += topk_probs_flat[mask].sum().item()
+
+            if expert_idx not in self.expert_stats['regular_expert_usage']:
+                self.expert_stats['regular_expert_usage'][expert_idx] = 0
+            self.expert_stats['regular_expert_usage'][expert_idx] += count
+
+        # Debug: Regular expert usage
+        logger.debug(f"Regular experts used this batch: {regular_expert_used}")
 
         # Track small experts if they exist
         if num_small_experts > 0:
+            small_expert_used = False
             for expert_idx in range(num_small_experts):
                 small_expert_num = expert_idx + num_regular_experts
                 mask = (topk_experts_flat == small_expert_num)
                 count = mask.sum().item()
-                load = topk_probs_flat[mask].sum().item()
 
-                layer_stats['small_expert_counts'][expert_idx] += count
-                layer_stats['small_expert_load'][expert_idx] += load
-
-                if expert_idx not in self.expert_stats['small_expert_usage']:
-                    self.expert_stats['small_expert_usage'][expert_idx] = 0
-                self.expert_stats['small_expert_usage'][expert_idx] += count
-
+                if count > 0:
+                    small_expert_used = True
+                layer_stats['small_expert_counts'][expert_idx] += count
+                layer_stats['small_expert_load'][expert_idx] += topk_probs_flat[mask].sum().item()
+
+                if expert_idx not in self.expert_stats['small_expert_usage']:
+                    self.expert_stats['small_expert_usage'][expert_idx] = 0
+                self.expert_stats['small_expert_usage'][expert_idx] += count
+
+            # Debug: Small expert usage
+            logger.debug(f"Small experts used this batch: {small_expert_used}")
+            if not small_expert_used:
+                logger.debug(f"Top-k experts sample: {topk_experts_flat[:5].tolist()}")
+                logger.debug(f"Num regular experts: {num_regular_experts}, looking for experts >= this number")
+        else:
+            logger.debug("No small experts configured for this layer")
+
+        # Update token counts
+        self.expert_stats['total_tokens'] += num_tokens
+        layer_stats['total_tokens'] += num_tokens
+        logger.debug(f"Updated token counts - layer: {layer_stats['total_tokens']}, total: {self.expert_stats['total_tokens']}")
+
     def get_expert_stats(self) -> Dict[str, Any]:
         """Return expert usage statistics in a serializable format."""
        def convert(obj):
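
For context, the expert_hook returned at the top of the second hunk is what feeds update_expert_stats during a forward pass. The registration code is outside the changed region; a typical wiring, sketched under the assumption (not confirmed by this diff) that each MoE router module returns its top-k probabilities and expert indices as output, would look like:

    def make_expert_hook(tracker, layer_idx, num_regular, num_small):
        def expert_hook(module, inputs, output):
            topk_probs, topk_experts = output   # assumed router output layout
            batch_size, seq_len, _ = topk_experts.shape
            tracker.update_expert_stats(layer_idx, topk_experts, topk_probs,
                                        num_regular, num_small,
                                        batch_size, seq_len)
        return expert_hook

    # Hypothetical registration over the model's router submodules:
    # for i, layer in enumerate(model.model.layers):
    #     layer.mlp.gate.register_forward_hook(
    #         make_expert_hook(tracker, i, num_regular, num_small))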
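
To sanity-check the new debug output without a full lm_eval run, update_expert_stats can be driven directly with synthetic routing tensors. A sketch, assuming scripts/evalexperts.py is importable as evalexperts, and that the layer-stats initializer (truncated in the hunk above) also creates the regular_expert_load, small_expert_counts, and small_expert_load fields the method reads:

    import torch
    from evalexperts import ExpertTrackingHFLM  # assumed import path

    class FakeTracker:
        # Stand-in carrying only the state update_expert_stats touches.
        def __init__(self):
            self.expert_stats = {
                'total_tokens': 0,
                'layer_stats': {},
                'regular_expert_usage': {},
                'small_expert_usage': {},
            }

    batch_size, seq_len, top_k = 2, 4, 2
    num_regular, num_small = 8, 2
    # Synthetic top-k routing: expert ids in [0, 10) plus matching probabilities.
    topk_experts = torch.randint(0, num_regular + num_small,
                                 (batch_size, seq_len, top_k))
    topk_probs = torch.softmax(torch.rand(batch_size, seq_len, top_k), dim=-1)

    tracker = FakeTracker()
    # Calling the method unbound on the stand-in exercises it in isolation;
    # importing evalexperts already configured DEBUG logging, so the new
    # logger.debug lines print as the stats update.
    ExpertTrackingHFLM.update_expert_stats(
        tracker, 0, topk_experts, topk_probs,
        num_regular, num_small, batch_size, seq_len)
    print(tracker.expert_stats['total_tokens'])  # 2 * 4 = 8 routed tokens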