razmars committed on
Commit
82634ad
·
verified ·
1 Parent(s): ef55773

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +30 -30
modeling_super_linear.py CHANGED
@@ -390,37 +390,37 @@ class superLinear(nn.Module):
390
  cycle = cp.split("/")
391
 
392
  self.experts = {}
393
- # if self.freq_experts is not None:
394
- # for expert_freq in self.freq_experts:
395
- # if expert_freq == "naive" or expert_freq == "Naive":
396
- # self.experts[expert_freq] = Naive(self.seq_len, self.pred_len)
397
- # elif expert_freq == "mean" or expert_freq == "Mean":
398
- # self.experts[expert_freq] = Mean(self.seq_len, self.pred_len)
399
- # else:
400
- # self.experts[expert_freq] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
401
- # if configs.load_linear:
402
- # cycle = self.map_to_cycle(expert_freq)
403
- # cycle_str = f'cycle_{cycle}/'
404
- # cycle_checkpoint_path = [cp for cp in checkpoints_paths if (cycle_str in cp and self.layer_type in cp)]
405
- # if len(cycle_checkpoint_path) > 0:
406
- # print()
407
- # print(cycle_str)
408
- # cycle_checkpoint_path = cycle_checkpoint_path[0]
409
- # #print(f'loading checkpoint with layer type: {self.layer_type} and cycle: {cycle_str}')
410
- # print(cycle_checkpoint_path)
411
- # self.experts[expert_freq].load_state_dict(torch.load(cycle_checkpoint_path))
412
- # else:
413
- # print(f"Checkpoint for {cycle_str} not found in {path}")
414
- # raise ValueError(f"Checkpoint for {cycle_str} not found in {path}")
415
- # if configs.freeze_experts:
416
- # for param in self.experts[expert_freq].parameters():
417
- # param.requires_grad = False
418
 
419
- # self.n_experts = len(self.experts)
420
- # else:
421
- # for i in range(self.n_experts):
422
- # print(f"creating expert {i}")
423
- # self.experts[str(i)] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
424
 
425
 
426
  if configs.misc_moe>0:
 
390
  cycle = cp.split("/")
391
 
392
  self.experts = {}
393
+ if self.freq_experts is not None:
394
+ for expert_freq in self.freq_experts:
395
+ if expert_freq == "naive" or expert_freq == "Naive":
396
+ self.experts[expert_freq] = Naive(self.seq_len, self.pred_len)
397
+ elif expert_freq == "mean" or expert_freq == "Mean":
398
+ self.experts[expert_freq] = Mean(self.seq_len, self.pred_len)
399
+ else:
400
+ self.experts[expert_freq] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
401
+ # if configs.load_linear:
402
+ # cycle = self.map_to_cycle(expert_freq)
403
+ # cycle_str = f'cycle_{cycle}/'
404
+ # cycle_checkpoint_path = [cp for cp in checkpoints_paths if (cycle_str in cp and self.layer_type in cp)]
405
+ # if len(cycle_checkpoint_path) > 0:
406
+ # print()
407
+ # print(cycle_str)
408
+ # cycle_checkpoint_path = cycle_checkpoint_path[0]
409
+ # #print(f'loading checkpoint with layer type: {self.layer_type} and cycle: {cycle_str}')
410
+ # print(cycle_checkpoint_path)
411
+ # self.experts[expert_freq].load_state_dict(torch.load(cycle_checkpoint_path))
412
+ # else:
413
+ # print(f"Checkpoint for {cycle_str} not found in {path}")
414
+ # raise ValueError(f"Checkpoint for {cycle_str} not found in {path}")
415
+ # if configs.freeze_experts:
416
+ # for param in self.experts[expert_freq].parameters():
417
+ # param.requires_grad = False
418
 
419
+ self.n_experts = len(self.experts)
420
+ else:
421
+ for i in range(self.n_experts):
422
+ print(f"creating expert {i}")
423
+ self.experts[str(i)] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
424
 
425
 
426
  if configs.misc_moe>0: