robinfaro committed on
Commit
b5d4337
·
verified ·
1 Parent(s): 7c83237

Adding modeling.py file

Browse files
Files changed (1) hide show
  1. modeling.py +4 -2
modeling.py CHANGED
@@ -8,6 +8,7 @@ from .gpt import GPTBase
8
  from .aux_losses import entropy_reg, load_balancing_loss, router_z_loss
9
  from typing import Optional, List
10
  from dataclasses import dataclass
 
11
 
12
 
13
  @dataclass
@@ -37,8 +38,8 @@ class MoLM(PreTrainedModel):
37
 
38
  # Number of experts
39
  self.num_experts = config.num_experts
40
- #print(f"Number of experts: {self.num_experts}")
41
- #print(f"Expert configurations: {config.expert_configs}")
42
  assert len(config.expert_configs) == self.num_experts, "Number of expert configurations must match num_experts in config."
43
  self.expert_configs = config.expert_configs
44
 
@@ -52,6 +53,7 @@ class MoLM(PreTrainedModel):
52
 
53
  # Initialize experts using the provided configurations
54
  self.experts = nn.ModuleList([GPTBase(config=self.expert_configs[i]) for i in range(self.num_experts)])
 
55
 
56
  # Load pre-trained weights if provided
57
  if expert_weights is not None:
 
8
  from .aux_losses import entropy_reg, load_balancing_loss, router_z_loss
9
  from typing import Optional, List
10
  from dataclasses import dataclass
11
+ import tiktoken
12
 
13
 
14
  @dataclass
 
38
 
39
  # Number of experts
40
  self.num_experts = config.num_experts
41
+ print(f"Number of experts: {self.num_experts}")
42
+ print(f"Expert configurations: {config.expert_configs}")
43
  assert len(config.expert_configs) == self.num_experts, "Number of expert configurations must match num_experts in config."
44
  self.expert_configs = config.expert_configs
45
 
 
53
 
54
  # Initialize experts using the provided configurations
55
  self.experts = nn.ModuleList([GPTBase(config=self.expert_configs[i]) for i in range(self.num_experts)])
56
+ self.tokenizer = tiktoken.get_encoding("gpt2")
57
 
58
  # Load pre-trained weights if provided
59
  if expert_weights is not None: