Charlie81 commited on
Commit
842be01
·
1 Parent(s): a62f0f3

recursive json convert chatgpt

Browse files
Files changed (1) hide show
  1. scripts/evalexperts.py +17 -12
scripts/evalexperts.py CHANGED
@@ -547,27 +547,32 @@ def save_results(results: Dict[str, Any], expert_stats: Dict[str, Any], args) ->
547
  if args.model_type == "custom":
548
  metadata["routing_type"] = "top-k (default)"
549
 
550
- def convert_for_json(obj):
551
- """Recursively convert objects to JSON-serializable formats."""
552
  if isinstance(obj, (np.integer, np.floating)):
553
  return int(obj) if isinstance(obj, np.integer) else float(obj)
554
  elif isinstance(obj, np.ndarray):
555
  return obj.tolist()
556
- elif isinstance(obj, (torch.Tensor, torch.dtype)):
557
- return str(obj) if isinstance(obj, torch.dtype) else obj.tolist()
558
- elif isinstance(obj, (dict, list, tuple, str, int, float, bool, type(None))):
 
 
 
 
 
 
559
  return obj
560
  else:
561
  return str(obj)
562
-
563
- # Convert all data to JSON-serializable format
564
- serializable_results = {
565
  "metadata": metadata,
566
  "task_results": results,
567
- "expert_statistics": {
568
- k: convert_for_json(v) for k, v in expert_stats.items()
569
- }
570
- }
571
 
572
  # Save to file
573
  with open(output_path, 'w') as f:
 
547
  if args.model_type == "custom":
548
  metadata["routing_type"] = "top-k (default)"
549
 
550
+ # Recursive conversion function
551
+ def recursive_convert(obj):
552
  if isinstance(obj, (np.integer, np.floating)):
553
  return int(obj) if isinstance(obj, np.integer) else float(obj)
554
  elif isinstance(obj, np.ndarray):
555
  return obj.tolist()
556
+ elif isinstance(obj, torch.Tensor):
557
+ return obj.cpu().tolist()
558
+ elif isinstance(obj, torch.dtype):
559
+ return str(obj)
560
+ elif isinstance(obj, dict):
561
+ return {k: recursive_convert(v) for k, v in obj.items()}
562
+ elif isinstance(obj, (list, tuple)):
563
+ return [recursive_convert(v) for v in obj]
564
+ elif isinstance(obj, (int, float, str, bool)) or obj is None:
565
  return obj
566
  else:
567
  return str(obj)
568
+
569
+ # Convert everything
570
+ serializable_results = recursive_convert({
571
  "metadata": metadata,
572
  "task_results": results,
573
+ "expert_statistics": expert_stats
574
+ })
575
+
 
576
 
577
  # Save to file
578
  with open(output_path, 'w') as f: