razmars committed on
Commit
4b63879
·
verified ·
1 Parent(s): 0898dee

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +0 -68
modeling_super_linear.py CHANGED
@@ -371,71 +371,6 @@ class SparseNoisyMoE(nn.Module):
371
 
372
  load_balancing_loss = self.calculate_load_balancing_loss(self.gate_outputs, batch_size)
373
 
374
- expert_probs = F.softmax(self.gate_outputs, dim=1)
375
- expert_probs = expert_probs[1,:]
376
- # Plot the expert probabilities
377
- import matplotlib.pyplot as plt
378
-
379
- # Get expert probabilities and convert to numpy
380
- probs_np = expert_probs.detach().cpu().numpy()
381
- # Create a nicer figure with a modern style
382
- plt.style.use('ggplot')
383
- plt.figure(figsize=(12, 8), dpi=120)
384
- plt.subplot(111)
385
- ax = plt.subplot(111)
386
-
387
- # Create color gradient based on probability values
388
- colors = plt.cm.viridis(probs_np)
389
-
390
- # Plot bars with more attractive styling
391
- bars = plt.bar(range(len(probs_np)), probs_np, color=colors, width=0.6,
392
- edgecolor='black', linewidth=0.5, alpha=0.85)
393
- # Add value annotations on top of each bar
394
- for i, bar in enumerate(bars):
395
- height = bar.get_height()
396
- plt.text(bar.get_x() + bar.get_width()/2., height + 0.01,
397
- f'{height:.3f}', ha='center', va='bottom', fontsize=9,
398
- rotation=0, fontweight='bold')
399
-
400
- # Add expert names to x-axis
401
- if hasattr(self, 'experts') and isinstance(getattr(self, 'experts', None), dict):
402
- # If experts are stored in a dictionary with meaningful keys
403
- expert_names = list(self.experts.keys())
404
- plt.xticks(range(len(probs_np)), expert_names, rotation=45, ha='right')
405
- else:
406
- # Default numbering if expert names aren't available
407
- plt.xticks(range(len(probs_np)), [f'Expert {i}' for i in range(len(probs_np))])
408
-
409
- # Add grid for better readability
410
- plt.grid(axis='y', linestyle='--', alpha=0.7)
411
-
412
- # Add timestamp to title
413
- timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
414
- plt.title(f'Expert Selection Probabilities\n{timestamp}', fontsize=14, fontweight='bold')
415
- plt.xlabel('Expert Models', fontsize=12)
416
- plt.ylabel('Selection Probability', fontsize=12)
417
-
418
- # Highlight the most probable expert
419
- max_idx = np.argmax(probs_np)
420
- bars[max_idx].set_color('orangered')
421
- bars[max_idx].set_edgecolor('black')
422
- bars[max_idx].set_linewidth(1.5)
423
-
424
- # Add stats in a text box
425
- textstr = f'Max: {probs_np.max():.4f} (Expert {max_idx})\n'
426
- textstr += f'Min: {probs_np.min():.4f}\n'
427
- textstr += f'Mean: {probs_np.mean():.4f}'
428
- props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
429
- plt.text(0.02, 0.97, textstr, transform=plt.gca().transAxes, fontsize=10,
430
- verticalalignment='top', bbox=props)
431
-
432
- plt.tight_layout()
433
- plt.savefig(F"expert_probabilities_{self.i}.png", bbox_inches='tight')
434
- self.i+=1
435
- plt.close()
436
- print(expert_probs.shape)
437
-
438
-
439
  if get_prob:
440
  expert_probs = F.softmax(self.gate_outputs, dim=1)
441
  print(expert_probs.shape)
@@ -684,9 +619,6 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
684
  # 3. Inverse FFT to the shorter grid
685
  y = torch.fft.irfft(X_crop, n=target_len, dim=1)
686
 
687
- # 4. Renormalise amplitudes:
688
- # irfft divides by `target_len`, whereas the forward rfft used length `L`.
689
- # Multiply by (target_len / L) so DC and low-freq amplitudes match input.
690
 
691
  return y
692
 
 
371
 
372
  load_balancing_loss = self.calculate_load_balancing_loss(self.gate_outputs, batch_size)
373
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
  if get_prob:
375
  expert_probs = F.softmax(self.gate_outputs, dim=1)
376
  print(expert_probs.shape)
 
619
  # 3. Inverse FFT to the shorter grid
620
  y = torch.fft.irfft(X_crop, n=target_len, dim=1)
621
 
 
 
 
622
 
623
  return y
624