Update meshrouter.py
Browse files- meshrouter.py +26 -2
meshrouter.py
CHANGED
|
@@ -1,3 +1,27 @@
|
|
| 1 |
-
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM # Import AutoModelForCausalLM
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import math
|
| 6 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast # Import the necessary output class
|
| 7 |
|
| 8 |
+
# Router module: scores tokens against the expert mesh for dynamic routing.
class MeshRouter(nn.Module):
    """Dynamic top-k router over a 2-D mesh of experts.

    Each token's hidden state is projected to one logit per mesh cell,
    converted to a probability distribution, and the k highest-weighted
    cells are kept with their weights renormalized to sum to one.
    """

    def __init__(self, config: MeshConfig):
        super().__init__()
        # One gating logit per mesh cell; the (rows, cols) grid is
        # flattened into a single expert axis.
        num_cells = config.mesh_grid_size[0] * config.mesh_grid_size[1]
        self.gate = nn.Linear(config.hidden_size, num_cells)
        self.softmax = nn.Softmax(dim=-1)
        self.routing_k = config.routing_k

    def forward(self, x):
        """Compute routing weights for hidden states *x*.

        Args:
            x: tensor of shape (batch_size, sequence_length, hidden_size).

        Returns:
            A pair ``(weights, indices)``, each of shape
            (batch_size, sequence_length, k): the renormalized top-k gate
            weights and the indices of the selected experts.
        """
        logits = self.gate(x)          # (batch, seq, num_experts)
        probs = self.softmax(logits)

        # Keep only the k strongest experts per token.
        weights, indices = torch.topk(probs, self.routing_k, dim=-1)

        # Renormalize the surviving weights so they sum to 1;
        # the epsilon guards against division by zero.
        weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-6)

        return weights, indices
|