razmars committed on
Commit
525f838
·
verified ·
1 Parent(s): f773080

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +50 -12
modeling_super_linear.py CHANGED
@@ -355,28 +355,66 @@ class SparseNoisyMoE(nn.Module):
355
 
356
  load_balancing_loss = self.calculate_load_balancing_loss(self.gate_outputs, batch_size)
357
 
358
- expert_probs = F.softmax(self.gate_outputs, dim=1)
359
  expert_probs = expert_probs[1,:]
360
  # Plot the expert probabilities
361
  import matplotlib.pyplot as plt
362
 
363
- plt.figure(figsize=(10, 6))
364
- plt.bar(range(len(expert_probs)), expert_probs.detach().cpu().numpy())
365
- plt.xlabel('Expert Index')
366
- plt.ylabel('Probability')
367
- plt.title('Expert Selection Probabilities')
368
-
369
- # Create more descriptive x-axis labels
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  if hasattr(self, 'experts') and isinstance(getattr(self, 'experts', None), dict):
371
  # If experts are stored in a dictionary with meaningful keys
372
  expert_names = list(self.experts.keys())
373
- plt.xticks(range(len(expert_probs)), expert_names, rotation=45)
374
  else:
375
  # Default numbering if expert names aren't available
376
- plt.xticks(range(len(expert_probs)), [f'Expert {i}' for i in range(len(expert_probs))])
377
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
  plt.tight_layout()
379
- plt.savefig('expert_probabilities.png')
380
  plt.close()
381
  print(expert_probs.shape)
382
 
 
355
 
356
  load_balancing_loss = self.calculate_load_balancing_loss(self.gate_outputs, batch_size)
357
 
 
358
  expert_probs = expert_probs[1,:]
359
  # Plot the expert probabilities
360
  import matplotlib.pyplot as plt
361
 
362
+ # Get expert probabilities and convert to numpy
363
+ probs_np = expert_probs.detach().cpu().numpy()
364
+
365
+ # Create a nicer figure with a modern style
366
+ plt.style.use('ggplot')
367
+ fig, ax = plt.figure(figsize=(12, 8), dpi=120)
368
+ ax = plt.subplot(111)
369
+
370
+ # Create color gradient based on probability values
371
+ colors = plt.cm.viridis(probs_np)
372
+
373
+ # Plot bars with more attractive styling
374
+ bars = plt.bar(range(len(probs_np)), probs_np, color=colors, width=0.6,
375
+ edgecolor='black', linewidth=0.5, alpha=0.85)
376
+
377
+ # Add value annotations on top of each bar
378
+ for i, bar in enumerate(bars):
379
+ height = bar.get_height()
380
+ plt.text(bar.get_x() + bar.get_width()/2., height + 0.01,
381
+ f'{height:.3f}', ha='center', va='bottom', fontsize=9,
382
+ rotation=0, fontweight='bold')
383
+
384
+ # Add expert names to x-axis
385
  if hasattr(self, 'experts') and isinstance(getattr(self, 'experts', None), dict):
386
  # If experts are stored in a dictionary with meaningful keys
387
  expert_names = list(self.experts.keys())
388
+ plt.xticks(range(len(probs_np)), expert_names, rotation=45, ha='right')
389
  else:
390
  # Default numbering if expert names aren't available
391
+ plt.xticks(range(len(probs_np)), [f'Expert {i}' for i in range(len(probs_np))])
392
+
393
+ # Add grid for better readability
394
+ plt.grid(axis='y', linestyle='--', alpha=0.7)
395
+
396
+ # Add timestamp to title
397
+ timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
398
+ plt.title(f'Expert Selection Probabilities\n{timestamp}', fontsize=14, fontweight='bold')
399
+ plt.xlabel('Expert Models', fontsize=12)
400
+ plt.ylabel('Selection Probability', fontsize=12)
401
+
402
+ # Highlight the most probable expert
403
+ max_idx = np.argmax(probs_np)
404
+ bars[max_idx].set_color('orangered')
405
+ bars[max_idx].set_edgecolor('black')
406
+ bars[max_idx].set_linewidth(1.5)
407
+
408
+ # Add stats in a text box
409
+ textstr = f'Max: {probs_np.max():.4f} (Expert {max_idx})\n'
410
+ textstr += f'Min: {probs_np.min():.4f}\n'
411
+ textstr += f'Mean: {probs_np.mean():.4f}'
412
+ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
413
+ plt.text(0.02, 0.97, textstr, transform=plt.gca().transAxes, fontsize=10,
414
+ verticalalignment='top', bbox=props)
415
+
416
  plt.tight_layout()
417
+ plt.savefig('expert_probabilities.png', bbox_inches='tight')
418
  plt.close()
419
  print(expert_probs.shape)
420