import torch

from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

from util import Config, GetDevice, GetNumParams


class Model:
    def __init__(self, config: Config):
        # Adopt every attribute of the Config (params, tokenizer, ...) as our own.
        self.__dict__ = dict(config.__dict__)
        self.model = MambaLMHeadModel(MambaConfig(**self.params.__dict__)).to(GetDevice())
        self.log()

    def log(self):
        model_size, rounded_model_size = GetNumParams(self.model)
        print(f"Model has {model_size} ({rounded_model_size}) parameters")
        print(f"Model's vocabulary size is {self.params.vocab_size}")

    def parameters(self):
        return self.model.parameters()

    def unfreeze(self):
        # Note: train()/eval() toggle layer behaviour (e.g. dropout);
        # they do not change requires_grad on the parameters.
        self.model.train()

    def freeze(self):
        self.model.eval()

    def compute_loss(self, input_ids, labels=None, criterion=None):
        lm_logits = self.model(input_ids).logits
        # Default to causal language modelling: the labels are the inputs
        # themselves, shifted by one position below. Respect an explicit
        # `labels` argument instead of silently overwriting it.
        if labels is None:
            labels = input_ids
        labels = labels.to(GetDevice())
        # Drop the last logit and the first label so position t predicts token t+1.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        loss_fct = criterion or torch.nn.CrossEntropyLoss()
        return loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
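
    # Sketch of how compute_loss slots into a training step; `optimizer` and
    # `batch` are assumptions for illustration, not part of this module:
    #
    #     optimizer.zero_grad()
    #     loss = model.compute_loss(batch["input_ids"])
    #     loss.backward()
    #     optimizer.step()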

    def generate_text(self, seed_text, num_predict):
        with torch.no_grad():
            encoded_ids = self.tokenizer.encode(seed_text)
            input_ids = torch.tensor(encoded_ids).unsqueeze(0).to(GetDevice())
            # max_length is measured in tokens, so budget from the encoded
            # prompt length rather than the character length of seed_text.
            max_len = num_predict + len(encoded_ids)
            output = self.model.generate(input_ids, max_length=max_len)
            # generate() returns token ids, not logits; decode them back to text.
            token_ids = output[0].tolist()
            text = self.tokenizer.decode(token_ids)
        return text

    def save_pretrained(self, path='./'):
        self.model.save_pretrained(path)
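

# Minimal usage sketch. The `Config()` construction is an assumption for
# illustration; its real signature (and the tokenizer it carries) live in util.py:
#
#     config = Config()
#     model = Model(config)
#     print(model.generate_text("Once upon a time", num_predict=50))
#     model.save_pretrained("./checkpoints")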