| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class BaseExpert(nn.Module):
    """A single gated feed-forward expert tagged with a role and specialization.

    Implements a SwiGLU-style MLP: the input is projected twice (a SiLU-gated
    path through ``w1`` and a linear value path through ``w3``), the two are
    multiplied elementwise, and ``w2`` projects back to the model dimension.
    A learned scalar ``role_bias`` then adds a contribution proportional to
    the mean of the projected output.

    Attributes:
        role: Free-form role label for this expert.
        specialization: Free-form specialization label for this expert.
        w1: Gate projection, ``dim -> expert_dim`` (no bias).
        w2: Down projection, ``expert_dim -> dim`` (no bias).
        w3: Value projection, ``dim -> expert_dim`` (no bias).
        role_bias: Scalar parameter (initialized to zero) scaling the
            output-mean correction term.
    """

    def __init__(self, dim: int, expert_dim: int, role: str, specialization: str):
        super().__init__()
        self.role = role
        self.specialization = specialization
        # Projection order matters for seeded initialization reproducibility:
        # w1 (gate), w2 (down), w3 (value).
        self.w1 = nn.Linear(dim, expert_dim, bias=False)
        self.w2 = nn.Linear(expert_dim, dim, bias=False)
        self.w3 = nn.Linear(expert_dim if False else dim, expert_dim, bias=False) if False else nn.Linear(dim, expert_dim, bias=False)
        self.role_bias = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP and the role-bias mean correction.

        Args:
            x: Input tensor whose last dimension is ``dim``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Fused SwiGLU: down-project the gated elementwise product.
        projected = self.w2(F.silu(self.w1(x)) * self.w3(x))
        # NOTE(review): .mean() reduces over *all* elements, including the
        # batch axis — confirm cross-sample coupling is intended.
        return projected + self.role_bias * projected.mean()

    def get_role(self) -> str:
        """Return this expert's role label."""
        return self.role

    def get_specialization(self) -> str:
        """Return this expert's specialization label."""
        return self.specialization