flpelerin commited on
Commit
be26ed0
·
1 Parent(s): b0f1591

Update 2 files

Browse files

- /model.py
- /trainer.py

Files changed (2) hide show
  1. model.py +2 -2
  2. trainer.py +5 -5
model.py CHANGED
@@ -8,9 +8,9 @@ class Model:
8
  def __init__(self, config: Config):
9
  self.__dict__ = dict(config.__dict__)
10
 
11
- print(f"params: {params}")
12
 
13
- #self.model = MambaLMHeadModel(MambaConfig(params)).to(GetDevice())
14
 
15
 
16
  def AutoRegressiveLossFunction(self, input_ids, labels=None, criterion=None):
 
8
  def __init__(self, config: Config):
9
  self.__dict__ = dict(config.__dict__)
10
 
11
+ #print(f"params: {params}")
12
 
13
+ self.model = MambaLMHeadModel(MambaConfig(self.params)).to(GetDevice())
14
 
15
 
16
  def AutoRegressiveLossFunction(self, input_ids, labels=None, criterion=None):
trainer.py CHANGED
@@ -8,12 +8,12 @@ class Trainer:
8
  def __init__(self, config: Config):
9
  self.__dict__ = dict(config.__dict__)
10
 
11
- print(f"self.dict: {self.__dict__}")
12
- print(f"locals: {locals()}")
13
 
14
  #self.wandb = Wandb(config.wandb)
15
 
16
- print(f"model config: {model}")
17
- print(f"config.params: {config.model.params}")
18
- #self.model = Model(config.model)
19
 
 
8
  def __init__(self, config: Config):
9
  self.__dict__ = dict(config.__dict__)
10
 
11
+ #print(f"self.dict: {self.__dict__}")
12
+ #print(f"locals: {locals()}")
13
 
14
  #self.wandb = Wandb(config.wandb)
15
 
16
+ #print(f"model config: {self.model}")
17
+ #print(f"config.params: {self.model.params}")
18
+ self.model = Model(config.model)
19