flpelerin committed on
Commit
6f405dd
·
1 Parent(s): 961a0d5

Update 3 files

Browse files

- /util.py
- /trainer.py
- /model.py

Files changed (3) hide show
  1. model.py +24 -0
  2. trainer.py +3 -0
  3. util.py +5 -0
model.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig

from util import Config, GetDevice
5
+
6
+
7
class Model:
    """Thin wrapper around a MambaLMHeadModel configured from a project Config.

    All attributes of `config` are copied onto the instance, and the underlying
    language model is placed on the device reported by `GetDevice()`.
    """

    def __init__(self, config: Config):
        # Copy every config attribute onto this instance (original behavior).
        self.__dict__ = dict(config.__dict__)

        # Bug fix: the original passed the bare name `params`, which is not
        # defined anywhere and raises NameError at construction time.
        # `self.params` — copied from the config just above — is the evident
        # intent. TODO(review): confirm the attribute name against the Config
        # schema used by the trainer.
        self.model = MambaLMHeadModel(MambaConfig(self.params)).to(GetDevice())

    def AutoRegressiveLossFunction(self, input_ids, labels=None, criterion=None):
        """Return the next-token cross-entropy loss for a batch.

        Args:
            input_ids: token-id tensor; assumed shape (batch, seq) — TODO confirm.
            labels: optional target ids; defaults to `input_ids` (standard
                autoregressive LM objective).
            criterion: optional loss callable; defaults to CrossEntropyLoss.

        Returns:
            Scalar loss tensor.
        """
        lm_logits = self.model(input_ids).logits

        # Bug fix: the original unconditionally overwrote `labels` with
        # input_ids.to("cuda"), silently ignoring a caller-supplied `labels`
        # and hard-coding CUDA (crashing on CPU-only hosts). Honor the
        # parameter and follow the logits' device instead.
        if labels is None:
            labels = input_ids
        labels = labels.to(lm_logits.device)

        # Shift so that the logits at position t predict the token at t+1.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()

        loss_fct = criterion or torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )

        return lm_loss
trainer.py CHANGED
@@ -1,5 +1,7 @@
1
  from util import Config
 
2
  from logger import Wandb
 
3
 
4
 
5
  class Trainer:
@@ -7,4 +9,5 @@ class Trainer:
7
  self.__dict__ = dict(config.__dict__)
8
 
9
  self.wandb = Wandb(config.wandb)
 
10
 
 
1
  from util import Config
2
+
3
  from logger import Wandb
4
+ from model import Model
5
 
6
 
7
  class Trainer:
 
9
  self.__dict__ = dict(config.__dict__)
10
 
11
  self.wandb = Wandb(config.wandb)
12
+ self.model = Model(config.model)
13
 
util.py CHANGED
@@ -3,6 +3,11 @@ import math
3
  import random
4
 
5
 
 
 
 
 
 
6
  def RandomCode():
7
  code = '';
8
  chars = '0123456789abcdef'
 
3
  import random
4
 
5
 
6
+
7
def GetDevice():
    """Pick the compute device: CUDA when present, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
9
+
10
+
11
  def RandomCode():
12
  code = '';
13
  chars = '0123456789abcdef'