Update modeling_super_linear.py
Browse files
- modeling_super_linear.py +0 -68
modeling_super_linear.py
CHANGED
|
@@ -371,71 +371,6 @@ class SparseNoisyMoE(nn.Module):
|
|
| 371 |
|
| 372 |
load_balancing_loss = self.calculate_load_balancing_loss(self.gate_outputs, batch_size)
|
| 373 |
|
| 374 |
-
expert_probs = F.softmax(self.gate_outputs, dim=1)
|
| 375 |
-
expert_probs = expert_probs[1,:]
|
| 376 |
-
# Plot the expert probabilities
|
| 377 |
-
import matplotlib.pyplot as plt
|
| 378 |
-
|
| 379 |
-
# Get expert probabilities and convert to numpy
|
| 380 |
-
probs_np = expert_probs.detach().cpu().numpy()
|
| 381 |
-
# Create a nicer figure with a modern style
|
| 382 |
-
plt.style.use('ggplot')
|
| 383 |
-
plt.figure(figsize=(12, 8), dpi=120)
|
| 384 |
-
plt.subplot(111)
|
| 385 |
-
ax = plt.subplot(111)
|
| 386 |
-
|
| 387 |
-
# Create color gradient based on probability values
|
| 388 |
-
colors = plt.cm.viridis(probs_np)
|
| 389 |
-
|
| 390 |
-
# Plot bars with more attractive styling
|
| 391 |
-
bars = plt.bar(range(len(probs_np)), probs_np, color=colors, width=0.6,
|
| 392 |
-
edgecolor='black', linewidth=0.5, alpha=0.85)
|
| 393 |
-
# Add value annotations on top of each bar
|
| 394 |
-
for i, bar in enumerate(bars):
|
| 395 |
-
height = bar.get_height()
|
| 396 |
-
plt.text(bar.get_x() + bar.get_width()/2., height + 0.01,
|
| 397 |
-
f'{height:.3f}', ha='center', va='bottom', fontsize=9,
|
| 398 |
-
rotation=0, fontweight='bold')
|
| 399 |
-
|
| 400 |
-
# Add expert names to x-axis
|
| 401 |
-
if hasattr(self, 'experts') and isinstance(getattr(self, 'experts', None), dict):
|
| 402 |
-
# If experts are stored in a dictionary with meaningful keys
|
| 403 |
-
expert_names = list(self.experts.keys())
|
| 404 |
-
plt.xticks(range(len(probs_np)), expert_names, rotation=45, ha='right')
|
| 405 |
-
else:
|
| 406 |
-
# Default numbering if expert names aren't available
|
| 407 |
-
plt.xticks(range(len(probs_np)), [f'Expert {i}' for i in range(len(probs_np))])
|
| 408 |
-
|
| 409 |
-
# Add grid for better readability
|
| 410 |
-
plt.grid(axis='y', linestyle='--', alpha=0.7)
|
| 411 |
-
|
| 412 |
-
# Add timestamp to title
|
| 413 |
-
timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 414 |
-
plt.title(f'Expert Selection Probabilities\n{timestamp}', fontsize=14, fontweight='bold')
|
| 415 |
-
plt.xlabel('Expert Models', fontsize=12)
|
| 416 |
-
plt.ylabel('Selection Probability', fontsize=12)
|
| 417 |
-
|
| 418 |
-
# Highlight the most probable expert
|
| 419 |
-
max_idx = np.argmax(probs_np)
|
| 420 |
-
bars[max_idx].set_color('orangered')
|
| 421 |
-
bars[max_idx].set_edgecolor('black')
|
| 422 |
-
bars[max_idx].set_linewidth(1.5)
|
| 423 |
-
|
| 424 |
-
# Add stats in a text box
|
| 425 |
-
textstr = f'Max: {probs_np.max():.4f} (Expert {max_idx})\n'
|
| 426 |
-
textstr += f'Min: {probs_np.min():.4f}\n'
|
| 427 |
-
textstr += f'Mean: {probs_np.mean():.4f}'
|
| 428 |
-
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
|
| 429 |
-
plt.text(0.02, 0.97, textstr, transform=plt.gca().transAxes, fontsize=10,
|
| 430 |
-
verticalalignment='top', bbox=props)
|
| 431 |
-
|
| 432 |
-
plt.tight_layout()
|
| 433 |
-
plt.savefig(F"expert_probabilities_{self.i}.png", bbox_inches='tight')
|
| 434 |
-
self.i+=1
|
| 435 |
-
plt.close()
|
| 436 |
-
print(expert_probs.shape)
|
| 437 |
-
|
| 438 |
-
|
| 439 |
if get_prob:
|
| 440 |
expert_probs = F.softmax(self.gate_outputs, dim=1)
|
| 441 |
print(expert_probs.shape)
|
|
@@ -684,9 +619,6 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 684 |
# 3. Inverse FFT to the shorter grid
|
| 685 |
y = torch.fft.irfft(X_crop, n=target_len, dim=1)
|
| 686 |
|
| 687 |
-
# 4. Renormalise amplitudes:
|
| 688 |
-
# irfft divides by `target_len`, whereas the forward rfft used length `L`.
|
| 689 |
-
# Multiply by (target_len / L) so DC and low-freq amplitudes match input.
|
| 690 |
|
| 691 |
return y
|
| 692 |
|
|
|
|
| 371 |
|
| 372 |
load_balancing_loss = self.calculate_load_balancing_loss(self.gate_outputs, batch_size)
|
| 373 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
if get_prob:
|
| 375 |
expert_probs = F.softmax(self.gate_outputs, dim=1)
|
| 376 |
print(expert_probs.shape)
|
|
|
|
| 619 |
# 3. Inverse FFT to the shorter grid
|
| 620 |
y = torch.fft.irfft(X_crop, n=target_len, dim=1)
|
| 621 |
|
|
|
|
|
|
|
|
|
|
| 622 |
|
| 623 |
return y
|
| 624 |
|