File size: 1,763 Bytes
4c50abe
 
258e237
 
6f405dd
6d999e4
 
b0e0141
6f405dd
 
 
 
 
 
4c50abe
8f0b92e
6f405dd
 
8f0b92e
b0e0141
 
4d5a396
b0e0141
 
140c0cf
 
 
8f0b92e
 
 
8b5543e
6f405dd
bab1439
6f405dd
 
 
 
 
 
b0e0141
 
 
4c50abe
b0e0141
 
 
4c50abe
6a41ff7
b0e0141
 
 
4c50abe
6a41ff7
b0e0141
4c50abe
b0e0141
8f0b92e
b0e0141
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig



import torch

from util import Config, GetDevice, GetNumParams


class Model:
    """Thin wrapper around `MambaLMHeadModel` for training and text generation.

    Configuration attributes (e.g. `params`, `tokenizer`) are copied wholesale
    from the supplied `Config` object's `__dict__` — NOTE(review): `tokenizer`
    is assumed to be provided by the config; it is never set explicitly here.
    """

    def __init__(self, config: Config):
        # Adopt every attribute of the config as an attribute of this object.
        self.__dict__ = dict(config.__dict__)

        # Build the Mamba LM from the config's `params` namespace and move it
        # to the active device.
        self.model = MambaLMHeadModel(MambaConfig(**self.params.__dict__)).to(GetDevice())
        self.log()


    def log(self):
        """Print the parameter count and vocabulary size of the model."""
        model_size, rounded_model_size = GetNumParams(self.model)
        print(f"Model has {model_size} ({rounded_model_size}) parameters")
        print(f"Model's embedding size is {self.params.vocab_size}")


    def parameters(self): return self.model.parameters()
    # NOTE(review): these toggle train/eval mode only; they do not touch
    # `requires_grad`, so gradients still flow in "frozen" mode.
    def unfreeze(self):   self.model.train()
    def freeze(self):     self.model.eval()


    def compute_loss(self, input_ids, labels=None, criterion=None):
        """Compute the next-token (causal LM) loss.

        Args:
            input_ids: token id tensor of shape (batch, seq).
            labels: optional target ids; defaults to `input_ids` (standard
                language-model self-supervision). BUG FIX: previously this
                argument was accepted but unconditionally overwritten with
                `input_ids`, silently discarding caller-supplied labels.
            criterion: optional loss function; defaults to CrossEntropyLoss.

        Returns:
            Scalar loss tensor.
        """
        lm_logits = self.model(input_ids).logits

        if labels is None:
            labels = input_ids
        labels = labels.to(GetDevice())

        # Shift so position t's logits predict the token at position t+1.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()

        loss_fct = criterion or torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return lm_loss


    def generate_text(self, seed_text, num_predict):
        """Generate `num_predict` new tokens continuing `seed_text`.

        Returns the decoded string (seed included).
        """
        with torch.no_grad():
            encoded_ids = self.tokenizer.encode(seed_text)
            input_ids = torch.tensor(encoded_ids).unsqueeze(0).to(GetDevice())

            # BUG FIX: `max_length` is measured in tokens, but the budget was
            # previously computed from `len(seed_text)` (characters), so the
            # model generated more than `num_predict` tokens. Use the actual
            # prompt token count instead.
            max_len = num_predict + len(encoded_ids)
            output = self.model.generate(input_ids, max_length=max_len)

            # `output[0]` holds generated token ids (previously misnamed
            # `logits`).
            token_ids = output[0].tolist()
            text = self.tokenizer.decode(token_ids)

        return text


    def save_pretrained(self, path='./'):
        """Persist model weights via the underlying model's own saver."""
        self.model.save_pretrained(path)