flpelerin committed on
Commit
4c50abe
·
1 Parent(s): f2ce311

Update 2 files

Browse files

- /trainer.py
- /model.py

Files changed (2) hide show
  1. model.py +7 -47
  2. trainer.py +2 -2
model.py CHANGED
@@ -1,7 +1,6 @@
1
- #from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
2
- #from mamba_ssm.models.config_mamba import MambaConfig
3
 
4
- from mamba import Mamba, ModelArgs
5
 
6
 
7
  import torch
@@ -13,8 +12,7 @@ class Model:
13
  def __init__(self, config: Config):
14
  self.__dict__ = dict(config.__dict__)
15
 
16
- #self.model = MambaLMHeadModel(MambaConfig(**self.params.__dict__)).to(GetDevice())
17
- self.model = Mamba(ModelArgs(**self.params.__dict__)).to(GetDevice())
18
  self.log()
19
 
20
 
@@ -45,57 +43,19 @@ class Model:
45
 
46
 
47
 
48
- def generate_text(self, model,
49
- tokenizer,
50
- prompt: str,
51
- n_tokens_to_gen: int = 50,
52
- sample: bool = True,
53
- top_k: int = 40):
54
-
55
- model = self.model
56
- model.eval()
57
-
58
- input_ids = tokenizer.encode(prompt)
59
-
60
- for token_n in range(n_tokens_to_gen):
61
- with torch.no_grad():
62
- indices_to_input = input_ids
63
- next_token_logits = model(indices_to_input)[:, -1]
64
-
65
- probs = F.softmax(next_token_logits, dim=-1)
66
- (batch, vocab_size) = probs.shape
67
-
68
- if top_k is not None:
69
- (values, indices) = torch.topk(probs, k=top_k)
70
- probs[probs < values[:, -1, None]] = 0
71
- probs = probs / probs.sum(axis=1, keepdims=True)
72
-
73
- if sample:
74
- next_indices = torch.multinomial(probs, num_samples=1)
75
- else:
76
- next_indices = torch.argmax(probs, dim=-1)[:, None]
77
-
78
- input_ids = torch.cat([input_ids, next_indices], dim=1)
79
-
80
- output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]
81
-
82
- return output_completions
83
-
84
-
85
- """
86
- def generate_text(self, tokenizer, seed_text, num_predict):
87
  max_len = num_predict + len(seed_text)
88
 
89
  with torch.no_grad():
90
- encoded_ids = tokenizer.encode(seed_text)
91
  input_ids = torch.tensor(encoded_ids).unsqueeze(0).to(GetDevice())
92
  output = self.model.generate(input_ids, max_length=max_len)
93
 
94
  logits = output[0].tolist()
95
- text = tokenizer.decode(logits)
96
 
97
  return text
98
- """
99
 
100
 
101
  @staticmethod
 
1
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
2
+ from mamba_ssm.models.config_mamba import MambaConfig
3
 
 
4
 
5
 
6
  import torch
 
12
  def __init__(self, config: Config):
13
  self.__dict__ = dict(config.__dict__)
14
 
15
+ self.model = MambaLMHeadModel(MambaConfig(**self.params.__dict__)).to(GetDevice())
 
16
  self.log()
17
 
18
 
 
43
 
44
 
45
 
46
+ def generate_text(self, seed_text, num_predict):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  max_len = num_predict + len(seed_text)
48
 
49
  with torch.no_grad():
50
+ encoded_ids = self.tokenizer.encode(seed_text)
51
  input_ids = torch.tensor(encoded_ids).unsqueeze(0).to(GetDevice())
52
  output = self.model.generate(input_ids, max_length=max_len)
53
 
54
  logits = output[0].tolist()
55
+ text = self.tokenizer.decode(logits)
56
 
57
  return text
58
+
59
 
60
 
61
  @staticmethod
trainer.py CHANGED
@@ -15,8 +15,8 @@ class Trainer:
15
  args = {'epoch': self.epoch, 'batch': self.batch, 'loss': loss}
16
  self.wandb(args)
17
 
18
- #if self.batch % 200 == 0:
19
- # print(f'{self.model.generate_text(self.model.tokenizer, self.inference.seed_text, self.inference.n_predict)}')
20
 
21
 
22
  def train(self, batches):
 
15
  args = {'epoch': self.epoch, 'batch': self.batch, 'loss': loss}
16
  self.wandb(args)
17
 
18
+ if self.batch % 200 == 0:
19
+ print(f'{self.model.generate_text(self.inference.seed_text, self.inference.n_predict)}')
20
 
21
 
22
  def train(self, batches):