Adding modeling.py file
Browse files
- modeling.py +4 -2
modeling.py
CHANGED
|
@@ -8,6 +8,7 @@ from .gpt import GPTBase
|
|
| 8 |
from .aux_losses import entropy_reg, load_balancing_loss, router_z_loss
|
| 9 |
from typing import Optional, List
|
| 10 |
from dataclasses import dataclass
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
@dataclass
|
|
@@ -37,8 +38,8 @@ class MoLM(PreTrainedModel):
|
|
| 37 |
|
| 38 |
# Number of experts
|
| 39 |
self.num_experts = config.num_experts
|
| 40 |
-
|
| 41 |
-
|
| 42 |
assert len(config.expert_configs) == self.num_experts, "Number of expert configurations must match num_experts in config."
|
| 43 |
self.expert_configs = config.expert_configs
|
| 44 |
|
|
@@ -52,6 +53,7 @@ class MoLM(PreTrainedModel):
|
|
| 52 |
|
| 53 |
# Initialize experts using the provided configurations
|
| 54 |
self.experts = nn.ModuleList([GPTBase(config=self.expert_configs[i]) for i in range(self.num_experts)])
|
|
|
|
| 55 |
|
| 56 |
# Load pre-trained weights if provided
|
| 57 |
if expert_weights is not None:
|
|
|
|
| 8 |
from .aux_losses import entropy_reg, load_balancing_loss, router_z_loss
|
| 9 |
from typing import Optional, List
|
| 10 |
from dataclasses import dataclass
|
| 11 |
+
import tiktoken
|
| 12 |
|
| 13 |
|
| 14 |
@dataclass
|
|
|
|
| 38 |
|
| 39 |
# Number of experts
|
| 40 |
self.num_experts = config.num_experts
|
| 41 |
+
print(f"Number of experts: {self.num_experts}")
|
| 42 |
+
print(f"Expert configurations: {config.expert_configs}")
|
| 43 |
assert len(config.expert_configs) == self.num_experts, "Number of expert configurations must match num_experts in config."
|
| 44 |
self.expert_configs = config.expert_configs
|
| 45 |
|
|
|
|
| 53 |
|
| 54 |
# Initialize experts using the provided configurations
|
| 55 |
self.experts = nn.ModuleList([GPTBase(config=self.expert_configs[i]) for i in range(self.num_experts)])
|
| 56 |
+
self.tokenizer = tiktoken.get_encoding("gpt2")
|
| 57 |
|
| 58 |
# Load pre-trained weights if provided
|
| 59 |
if expert_weights is not None:
|