Charlie81 committed on
Commit
870d3db
·
1 Parent(s): 3e8c5b1

handle JSON serialization

Browse files
Files changed (1) hide show
  1. scripts/evalexperts.py +71 -66
scripts/evalexperts.py CHANGED
@@ -89,61 +89,61 @@ class ExpertTrackingHFLM(HFLM):
89
 
90
  return expert_hook
91
 
92
- def _update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
93
- topk_probs: torch.Tensor, num_regular_experts: int,
94
- num_small_experts: int, batch_size: int, seq_len: int):
95
- """Update expert usage statistics."""
96
- # Flatten the batch and sequence dimensions
97
- topk_experts_flat = topk_experts.view(-1, topk_experts.size(-1))
98
- topk_probs_flat = topk_probs.view(-1, topk_probs.size(-1))
99
-
100
- # Initialize layer stats if not present
101
- if layer_idx not in self.expert_stats['layer_stats']:
102
- self.expert_stats['layer_stats'][layer_idx] = {
103
- 'total_tokens': 0,
104
- 'regular_expert_counts': torch.zeros(num_regular_experts, dtype=torch.long),
105
- 'small_expert_counts': torch.zeros(num_small_experts, dtype=torch.long) if num_small_experts > 0 else None,
106
- 'regular_expert_load': torch.zeros(num_regular_experts, dtype=torch.float),
107
- 'small_expert_load': torch.zeros(num_small_experts, dtype=torch.float) if num_small_experts > 0 else None
108
- }
109
-
110
- layer_stats = self.expert_stats['layer_stats'][layer_idx]
111
- num_tokens = topk_experts_flat.size(0)
112
-
113
- # Update global stats
114
- self.expert_stats['total_tokens'] += num_tokens
 
 
 
 
 
 
 
 
 
115
 
116
- # Update layer stats
117
- layer_stats['total_tokens'] += num_tokens
118
 
119
- # Track regular experts
120
- for expert_idx in range(num_regular_experts):
121
- mask = (topk_experts_flat == expert_idx)
 
 
 
 
 
 
122
  count = mask.sum().item()
123
  load = topk_probs_flat[mask].sum().item()
124
 
125
- layer_stats['regular_expert_counts'][expert_idx] += count
126
- layer_stats['regular_expert_load'][expert_idx] += load
127
 
128
- if expert_idx not in self.expert_stats['regular_expert_usage']:
129
- self.expert_stats['regular_expert_usage'][expert_idx] = 0
130
- self.expert_stats['regular_expert_usage'][expert_idx] += count
131
-
132
- # Track small experts if they exist
133
- if num_small_experts > 0:
134
- for expert_idx in range(num_small_experts):
135
- small_expert_num = expert_idx + num_regular_experts
136
- mask = (topk_experts_flat == small_expert_num)
137
- count = mask.sum().item()
138
- load = topk_probs_flat[mask].sum().item()
139
-
140
- layer_stats['small_expert_counts'][expert_idx] += count
141
- layer_stats['small_expert_load'][expert_idx] += load
142
-
143
- if expert_idx not in self.expert_stats['small_expert_usage']:
144
- self.expert_stats['small_expert_usage'][expert_idx] = 0
145
- self.expert_stats['small_expert_usage'][expert_idx] += count
146
-
147
  def get_expert_stats(self) -> Dict[str, Any]:
148
  """Return expert usage statistics in a serializable format."""
149
  stats = {
@@ -495,15 +495,7 @@ def run_evaluation(args) -> Tuple[Dict[str, Any], Dict[str, Any]]:
495
 
496
  def save_results(results: Dict[str, Any], expert_stats: Dict[str, Any], args) -> str:
497
  """
498
- Save evaluation results and expert statistics to file.
499
-
500
- Args:
501
- results: Evaluation results
502
- expert_stats: Expert usage statistics
503
- args: Parsed command line arguments
504
-
505
- Returns:
506
- str: Path to saved results file
507
  """
508
  os.makedirs(args.output_dir, exist_ok=True)
509
 
@@ -514,10 +506,7 @@ def save_results(results: Dict[str, Any], expert_stats: Dict[str, Any], args) ->
514
  if len(args.tasks) > 3:
515
  tasks_str += f"_and_{len(args.tasks)-3}_more"
516
 
517
- if args.model_type == "custom":
518
- filename = f"{model_name}_custom_{tasks_str}_results_with_expert_stats.json"
519
- else:
520
- filename = f"{model_name}_transformers_{tasks_str}_results_with_expert_stats.json"
521
  else:
522
  filename = args.output_filename
523
 
@@ -534,7 +523,7 @@ def save_results(results: Dict[str, Any], expert_stats: Dict[str, Any], args) ->
534
  "num_fewshot": args.num_fewshot,
535
  "batch_size": args.batch_size,
536
  "device": args.device,
537
- "dtype": args.dtype,
538
  "limit": args.limit,
539
  }
540
 
@@ -542,15 +531,31 @@ def save_results(results: Dict[str, Any], expert_stats: Dict[str, Any], args) ->
542
  if args.model_type == "custom":
543
  metadata["routing_type"] = "top-k (default)"
544
 
545
- combined_results = {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
546
  "metadata": metadata,
547
  "task_results": results,
548
- "expert_statistics": expert_stats
 
 
549
  }
550
 
551
  # Save to file
552
  with open(output_path, 'w') as f:
553
- json.dump(combined_results, f, indent=2)
554
 
555
  logger.info(f"Results saved to {output_path}")
556
  return output_path
 
89
 
90
  return expert_hook
91
 
92
+ def _update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
93
+ topk_probs: torch.Tensor, num_regular_experts: int,
94
+ num_small_experts: int, batch_size: int, seq_len: int):
95
+ """Update expert usage statistics with serializable data types."""
96
+ # Flatten the batch and sequence dimensions
97
+ topk_experts_flat = topk_experts.view(-1, topk_experts.size(-1))
98
+ topk_probs_flat = topk_probs.view(-1, topk_probs.size(-1))
99
+
100
+ # Initialize layer stats if not present
101
+ if layer_idx not in self.expert_stats['layer_stats']:
102
+ self.expert_stats['layer_stats'][layer_idx] = {
103
+ 'total_tokens': 0,
104
+ 'regular_expert_counts': [0] * num_regular_experts, # Use list instead of tensor
105
+ 'small_expert_counts': [0] * num_small_experts if num_small_experts > 0 else None,
106
+ 'regular_expert_load': [0.0] * num_regular_experts,
107
+ 'small_expert_load': [0.0] * num_small_experts if num_small_experts > 0 else None
108
+ }
109
+
110
+ layer_stats = self.expert_stats['layer_stats'][layer_idx]
111
+ num_tokens = topk_experts_flat.size(0)
112
+
113
+ # Update global stats
114
+ self.expert_stats['total_tokens'] += num_tokens
115
+
116
+ # Update layer stats
117
+ layer_stats['total_tokens'] += num_tokens
118
+
119
+ # Track regular experts
120
+ for expert_idx in range(num_regular_experts):
121
+ mask = (topk_experts_flat == expert_idx)
122
+ count = mask.sum().item()
123
+ load = topk_probs_flat[mask].sum().item()
124
 
125
+ layer_stats['regular_expert_counts'][expert_idx] += count
126
+ layer_stats['regular_expert_load'][expert_idx] += load
127
 
128
+ if expert_idx not in self.expert_stats['regular_expert_usage']:
129
+ self.expert_stats['regular_expert_usage'][expert_idx] = 0
130
+ self.expert_stats['regular_expert_usage'][expert_idx] += count
131
+
132
+ # Track small experts if they exist
133
+ if num_small_experts > 0:
134
+ for expert_idx in range(num_small_experts):
135
+ small_expert_num = expert_idx + num_regular_experts
136
+ mask = (topk_experts_flat == small_expert_num)
137
  count = mask.sum().item()
138
  load = topk_probs_flat[mask].sum().item()
139
 
140
+ layer_stats['small_expert_counts'][expert_idx] += count
141
+ layer_stats['small_expert_load'][expert_idx] += load
142
 
143
+ if expert_idx not in self.expert_stats['small_expert_usage']:
144
+ self.expert_stats['small_expert_usage'][expert_idx] = 0
145
+ self.expert_stats['small_expert_usage'][expert_idx] += count
146
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  def get_expert_stats(self) -> Dict[str, Any]:
148
  """Return expert usage statistics in a serializable format."""
149
  stats = {
 
495
 
496
  def save_results(results: Dict[str, Any], expert_stats: Dict[str, Any], args) -> str:
497
  """
498
+ Save evaluation results and expert statistics to file with proper serialization.
 
 
 
 
 
 
 
 
499
  """
500
  os.makedirs(args.output_dir, exist_ok=True)
501
 
 
506
  if len(args.tasks) > 3:
507
  tasks_str += f"_and_{len(args.tasks)-3}_more"
508
 
509
+ filename = f"{model_name}_{args.model_type}_{tasks_str}_results.json"
 
 
 
510
  else:
511
  filename = args.output_filename
512
 
 
523
  "num_fewshot": args.num_fewshot,
524
  "batch_size": args.batch_size,
525
  "device": args.device,
526
+ "dtype": str(args.dtype), # Convert dtype to string
527
  "limit": args.limit,
528
  }
529
 
 
531
  if args.model_type == "custom":
532
  metadata["routing_type"] = "top-k (default)"
533
 
534
+ def convert_for_json(obj):
535
+ """Recursively convert objects to JSON-serializable formats."""
536
+ if isinstance(obj, (np.integer, np.floating)):
537
+ return int(obj) if isinstance(obj, np.integer) else float(obj)
538
+ elif isinstance(obj, np.ndarray):
539
+ return obj.tolist()
540
+ elif isinstance(obj, (torch.Tensor, torch.dtype)):
541
+ return str(obj) if isinstance(obj, torch.dtype) else obj.tolist()
542
+ elif isinstance(obj, (dict, list, tuple, str, int, float, bool, type(None))):
543
+ return obj
544
+ else:
545
+ return str(obj)
546
+
547
+ # Convert all data to JSON-serializable format
548
+ serializable_results = {
549
  "metadata": metadata,
550
  "task_results": results,
551
+ "expert_statistics": {
552
+ k: convert_for_json(v) for k, v in expert_stats.items()
553
+ }
554
  }
555
 
556
  # Save to file
557
  with open(output_path, 'w') as f:
558
+ json.dump(serializable_results, f, indent=2)
559
 
560
  logger.info(f"Results saved to {output_path}")
561
  return output_path