flpelerin commited on
Commit
7b037bb
·
1 Parent(s): 920fe2d

Update 2 files

Browse files

- /model.py
- /trainer.py

Files changed (2) hide show
  1. model.py +3 -1
  2. trainer.py +1 -1
model.py CHANGED
@@ -8,7 +8,9 @@ class Model:
8
  def __init__(self, config: Config):
9
  self.__dict__ = dict(config.__dict__)
10
 
11
- self.model = MambaLMHeadModel(MambaConfig(params)).to(GetDevice())
 
 
12
 
13
 
14
  def AutoRegressiveLossFunction(self, input_ids, labels=None, criterion=None):
 
8
  def __init__(self, config: Config):
9
  self.__dict__ = dict(config.__dict__)
10
 
11
+ print(f"params: {params}")
12
+
13
+ #self.model = MambaLMHeadModel(MambaConfig(params)).to(GetDevice())
14
 
15
 
16
  def AutoRegressiveLossFunction(self, input_ids, labels=None, criterion=None):
trainer.py CHANGED
@@ -10,7 +10,7 @@ class Trainer:
10
 
11
  #self.wandb = Wandb(config.wandb)
12
 
13
- print(f"model config: {config.model}")
14
  print(f"config.params: {config.model.params}")
15
  #self.model = Model(config.model)
16
 
 
10
 
11
  #self.wandb = Wandb(config.wandb)
12
 
13
+ print(f"model config: {model}")
14
  print(f"config.params: {config.model.params}")
15
  #self.model = Model(config.model)
16