Update modeling_super_linear.py
Browse files- modeling_super_linear.py +50 -12
modeling_super_linear.py
CHANGED
|
@@ -355,28 +355,66 @@ class SparseNoisyMoE(nn.Module):
|
|
| 355 |
|
| 356 |
load_balancing_loss = self.calculate_load_balancing_loss(self.gate_outputs, batch_size)
|
| 357 |
|
| 358 |
-
expert_probs = F.softmax(self.gate_outputs, dim=1)
|
| 359 |
expert_probs = expert_probs[1,:]
|
| 360 |
# Plot the expert probabilities
|
| 361 |
import matplotlib.pyplot as plt
|
| 362 |
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
plt.
|
| 368 |
-
|
| 369 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
if hasattr(self, 'experts') and isinstance(getattr(self, 'experts', None), dict):
|
| 371 |
# If experts are stored in a dictionary with meaningful keys
|
| 372 |
expert_names = list(self.experts.keys())
|
| 373 |
-
plt.xticks(range(len(
|
| 374 |
else:
|
| 375 |
# Default numbering if expert names aren't available
|
| 376 |
-
plt.xticks(range(len(
|
| 377 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
plt.tight_layout()
|
| 379 |
-
plt.savefig('expert_probabilities.png')
|
| 380 |
plt.close()
|
| 381 |
print(expert_probs.shape)
|
| 382 |
|
|
|
|
| 355 |
|
| 356 |
load_balancing_loss = self.calculate_load_balancing_loss(self.gate_outputs, batch_size)
|
| 357 |
|
|
|
|
| 358 |
expert_probs = expert_probs[1,:]
|
| 359 |
# Plot the expert probabilities
|
| 360 |
import matplotlib.pyplot as plt
|
| 361 |
|
| 362 |
+
# Get expert probabilities and convert to numpy
|
| 363 |
+
probs_np = expert_probs.detach().cpu().numpy()
|
| 364 |
+
|
| 365 |
+
# Create a nicer figure with a modern style
|
| 366 |
+
plt.style.use('ggplot')
|
| 367 |
+
fig, ax = plt.figure(figsize=(12, 8), dpi=120)
|
| 368 |
+
ax = plt.subplot(111)
|
| 369 |
+
|
| 370 |
+
# Create color gradient based on probability values
|
| 371 |
+
colors = plt.cm.viridis(probs_np)
|
| 372 |
+
|
| 373 |
+
# Plot bars with more attractive styling
|
| 374 |
+
bars = plt.bar(range(len(probs_np)), probs_np, color=colors, width=0.6,
|
| 375 |
+
edgecolor='black', linewidth=0.5, alpha=0.85)
|
| 376 |
+
|
| 377 |
+
# Add value annotations on top of each bar
|
| 378 |
+
for i, bar in enumerate(bars):
|
| 379 |
+
height = bar.get_height()
|
| 380 |
+
plt.text(bar.get_x() + bar.get_width()/2., height + 0.01,
|
| 381 |
+
f'{height:.3f}', ha='center', va='bottom', fontsize=9,
|
| 382 |
+
rotation=0, fontweight='bold')
|
| 383 |
+
|
| 384 |
+
# Add expert names to x-axis
|
| 385 |
if hasattr(self, 'experts') and isinstance(getattr(self, 'experts', None), dict):
|
| 386 |
# If experts are stored in a dictionary with meaningful keys
|
| 387 |
expert_names = list(self.experts.keys())
|
| 388 |
+
plt.xticks(range(len(probs_np)), expert_names, rotation=45, ha='right')
|
| 389 |
else:
|
| 390 |
# Default numbering if expert names aren't available
|
| 391 |
+
plt.xticks(range(len(probs_np)), [f'Expert {i}' for i in range(len(probs_np))])
|
| 392 |
+
|
| 393 |
+
# Add grid for better readability
|
| 394 |
+
plt.grid(axis='y', linestyle='--', alpha=0.7)
|
| 395 |
+
|
| 396 |
+
# Add timestamp to title
|
| 397 |
+
timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 398 |
+
plt.title(f'Expert Selection Probabilities\n{timestamp}', fontsize=14, fontweight='bold')
|
| 399 |
+
plt.xlabel('Expert Models', fontsize=12)
|
| 400 |
+
plt.ylabel('Selection Probability', fontsize=12)
|
| 401 |
+
|
| 402 |
+
# Highlight the most probable expert
|
| 403 |
+
max_idx = np.argmax(probs_np)
|
| 404 |
+
bars[max_idx].set_color('orangered')
|
| 405 |
+
bars[max_idx].set_edgecolor('black')
|
| 406 |
+
bars[max_idx].set_linewidth(1.5)
|
| 407 |
+
|
| 408 |
+
# Add stats in a text box
|
| 409 |
+
textstr = f'Max: {probs_np.max():.4f} (Expert {max_idx})\n'
|
| 410 |
+
textstr += f'Min: {probs_np.min():.4f}\n'
|
| 411 |
+
textstr += f'Mean: {probs_np.mean():.4f}'
|
| 412 |
+
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
|
| 413 |
+
plt.text(0.02, 0.97, textstr, transform=plt.gca().transAxes, fontsize=10,
|
| 414 |
+
verticalalignment='top', bbox=props)
|
| 415 |
+
|
| 416 |
plt.tight_layout()
|
| 417 |
+
plt.savefig('expert_probabilities.png', bbox_inches='tight')
|
| 418 |
plt.close()
|
| 419 |
print(expert_probs.shape)
|
| 420 |
|