flpelerin committed
Commit 8f0b92e · Parent(s): 4d5a396

Update 2 files

- /model.py
- /trainer.py

Files changed (2):
  1. model.py +17 -5
  2. trainer.py +26 -5
model.py CHANGED
@@ -9,16 +9,28 @@ class Model:
         self.__dict__ = dict(config.__dict__)
 
         self.model = MambaLMHeadModel(MambaConfig(**self.params.__dict__)).to(GetDevice())
-        self.Log()
+        self.log()
 
 
-    def Log(self):
+    def log(self):
         model_size, rounded_model_size = GetNumParams(self.model)
         print(f"Model has {model_size} ({rounded_model_size}) parameters")
         print(f"Model's embedding size is {self.params.vocab_size}")
 
 
-    def AutoRegressiveLossFunction(self, input_ids, labels=None, criterion=None):
+    def parameters(self):
+        return self.model.parameters()
+
+
+    def unfreeze(self):
+        self.model.train()
+
+
+    def freeze(self):
+        self.model.eval()
+
+
+    def compute_loss(self, input_ids, labels=None, criterion=None):
         lm_logits = self.model(input_ids).logits
 
         labels = input_ids.to(self.model.device)
@@ -31,7 +43,7 @@ class Model:
         return lm_loss
 
 
-    def GenerateText(self, tokenizer, seed_text, num_predict):
+    def generate_text(self, tokenizer, seed_text, num_predict):
         max_len = num_predict + len(seed_text)
 
         with torch.no_grad():
@@ -45,5 +57,5 @@ class Model:
 
 
     @staticmethod
-    def SavePretrained(self, path='./'):
+    def save_pretrained(self, path='./'):
         self.model.save_pretrained(path)
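
Taken together, the renamed and added methods give Model a small training-facing API: parameters() hands the wrapped module's parameters to an optimizer, unfreeze()/freeze() switch train/eval mode, and compute_loss() returns the autoregressive LM loss. A minimal usage sketch under those assumptions (the config object, input_ids batch, and learning rate below are illustrative, not part of this commit):

    import torch

    # Hypothetical usage; "config" must provide the fields Model reads from it.
    model = Model(config.model)                  # builds MambaLMHeadModel and prints its size
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # lr is illustrative

    model.unfreeze()                             # train mode (enables dropout etc.)
    loss = model.compute_loss(input_ids)         # input_ids: LongTensor of token ids
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Note that unfreeze()/freeze() only toggle train()/eval() on the wrapped module; they do not change requires_grad, so gradients still flow to a "frozen" model.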
trainer.py CHANGED
@@ -6,13 +6,34 @@ from model import Model
 
 class Trainer:
     def __init__(self, config: Config):
-        self.__dict__ = dict(config.__dict__)
-
-        #self.wandb = Wandb(config.wandb)
+        self.__dict__ = dict(config.trainer.__dict__)
 
 
+        #self.wandb = Wandb(config.wandb)
         self.model = Model(config.model)
 
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
+
+
+    def log(self, loss: float):
+        print(f"Epoch: {self.epoch} / {self.num_epochs}\t\tBatch: {self.batch} / {self.num_batches}\t\tLoss: {round(loss, 4)}")
+
+
+
+    def train(self, batches):
+        #pass
+
+        self.model.unfreeze()
+
+        for self.epoch in range(self.num_epochs):
+            for self.batch in range(self.num_batches):
+                ids = batches[self.batch]
+
+                loss = self.model.compute_loss(ids)
+
+                self.optimizer.zero_grad()
+                loss.backward()
+                self.optimizer.step()
 
-    def train(self, dataset): # TODO: Implement
-        pass
+                self.log(loss.item())
+                #Train.LogStep(infer_config, log_config, epoch, num_epochs, batch, num_batches, loss)
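
With the optimizer built in __init__ and train() looping over self.num_epochs × self.num_batches, the Trainer is driven with a list of token-id batches. A sketch of a driver, assuming config.trainer carries learning_rate, num_epochs, and num_batches (the Config construction and dummy data below are assumptions, not part of this commit):

    import torch

    # Hypothetical driver; Config field names are inferred from this diff.
    config = Config()   # must expose config.trainer.{learning_rate, num_epochs, num_batches} and config.model
    trainer = Trainer(config)

    # One LongTensor of token ids per batch; batch shape and vocab size are illustrative.
    batches = [torch.randint(0, 50257, (8, 128)) for _ in range(config.trainer.num_batches)]
    trainer.train(batches)

Storing the loop counters as self.epoch and self.batch lets log() read them without extra arguments, at the cost of keeping mutable loop state on the Trainer instance.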