flpelerin committed on
Commit
b0e0141
·
1 Parent(s): 9224c6a

Update 2 files

Browse files

- /util.py
- /model.py

Files changed (2) hide show
  1. model.py +26 -5
  2. util.py +21 -0
model.py CHANGED
@@ -1,26 +1,47 @@
1
  from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
2
  from mamba_ssm.models.config_mamba import MambaConfig
3
 
4
- from util import Config, GetDevice
5
 
6
 
7
  class Model:
8
  def __init__(self, config: Config):
9
  self.__dict__ = dict(config.__dict__)
10
 
11
- #print(f"params: {params}")
12
-
13
  self.model = MambaLMHeadModel(MambaConfig(**self.params.__dict__)).to(GetDevice())
14
 
15
 
 
 
 
 
 
16
  def AutoRegressiveLossFunction(self, input_ids, labels=None, criterion=None):
17
  lm_logits = self.model(input_ids).logits
18
 
19
- labels = input_ids.to("cuda")
20
  shift_logits = lm_logits[:, :-1, :].contiguous()
21
  labels = labels[:, 1:].contiguous()
22
 
23
  loss_fct = criterion or torch.nn.CrossEntropyLoss()
24
  lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
25
 
26
- return lm_loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig

from util import Config, GetDevice, GetNumParams
5
 
6
 
7
class Model:
    """Wraps a MambaLMHeadModel with loss, logging, generation, and save helpers."""

    def __init__(self, config: Config):
        # Copy every config attribute onto this instance (exposes e.g. self.params).
        self.__dict__ = dict(config.__dict__)

        # assumes config carries a `params` namespace whose fields match MambaConfig — TODO confirm
        self.model = MambaLMHeadModel(MambaConfig(**self.params.__dict__)).to(GetDevice())

    def Log(self):
        """Print the model's parameter count (exact and human-rounded)."""
        model_size, rounded_model_size = GetNumParams(self.model)
        print(f"Model has {model_size} ({rounded_model_size}) parameters")

    def AutoRegressiveLossFunction(self, input_ids, labels=None, criterion=None):
        """Compute the shifted next-token cross-entropy loss.

        Args:
            input_ids: (batch, seq) token-id tensor fed to the model.
            labels: optional target ids; defaults to input_ids (standard LM objective).
            criterion: optional loss module; defaults to torch.nn.CrossEntropyLoss.

        Returns:
            Scalar loss tensor.
        """
        lm_logits = self.model(input_ids).logits

        # Fix: the original unconditionally clobbered `labels` with input_ids,
        # making the `labels` parameter dead. Honor it when the caller provides one.
        if labels is None:
            labels = input_ids
        labels = labels.to(self.model.device)

        # Shift so that position t predicts token t+1.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = criterion or torch.nn.CrossEntropyLoss()
        return loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

    def GenerateText(self, tokenizer, seed_text, num_predict):
        """Generate up to `num_predict` continuation tokens after `seed_text`.

        Returns the full decoded string (seed included).
        """
        # NOTE(review): this budget mixes characters (len(seed_text)) with tokens,
        # so it is only an approximation — confirm intended semantics.
        max_len = num_predict + len(seed_text)

        with torch.no_grad():
            encoded_ids = tokenizer.encode(seed_text)
            input_ids = torch.tensor(encoded_ids).unsqueeze(0).to(self.model.device)
            output = self.model.generate(input_ids, max_length=max_len)

        # Fix: the original named these "logits", but output[0] holds token ids.
        token_ids = output[0].tolist()
        return tokenizer.decode(token_ids)

    def SavePretrained(self, path='./'):
        """Save the wrapped model's weights/config to `path` (save_pretrained layout).

        Fix: this was decorated @staticmethod while still taking `self` and
        reading `self.model`, so every call path crashed (an instance call
        dropped `self`; a class call bound `path` to `self`). It is a plain
        instance method.
        """
        self.model.save_pretrained(path)
util.py CHANGED
@@ -20,6 +20,27 @@ def RandomCode():
20
  return code
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  class Config:
24
  def __init__(self, data):
25
  for key, value in data.items():
 
20
  return code
21
 
22
 
23
def RoundNumber(number):
    """Format a number with a metric suffix: 999 -> '999', 1500 -> '2k', 2.6e9 -> '3b'.

    Supports k/m/b/t; magnitudes past trillions clamp to 't' instead of crashing.
    """
    suffixes = ['', 'k', 'm', 'b', 't']

    # Fix: compare on magnitude so negative values are abbreviated too —
    # the abs() in the loop below shows that was the intent, but the old
    # guard `number < 1000` returned early for every negative input.
    if abs(number) < 1000:
        return str(number)

    magnitude = 0
    # Fix: clamp the magnitude so values >= 10**12 no longer IndexError
    # past the end of the suffix list ('t' added above as well).
    while abs(number) >= 1000 and magnitude < len(suffixes) - 1:
        magnitude += 1
        number /= 1000.0

    return '{:.0f}{}'.format(number, suffixes[magnitude])
35
+
36
+
37
def GetNumParams(model):
    """Return (exact parameter count, human-rounded string) for a torch model."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total, RoundNumber(total)
42
+
43
+
44
  class Config:
45
  def __init__(self, data):
46
  for key, value in data.items():