| | --- |
| | license: apache-2.0 |
| | metrics: |
| | - perplexity |
| | pipeline_tag: text-generation |
| | --- |
| | |
| Trained on 30B bytes. Model size: 353M parameters. See Table 2 in [MambaByte](https://arxiv.org/abs/2401.13660). |
| |
|
| Usage |
| |
|
| | ``` |
| | import torch |
| | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel |
| | |
| | import numpy as np |
| | |
| | model=MambaLMHeadModel.from_pretrained("JunxiongWang/MambaByte_Code", device='cuda', dtype=torch.float32) |
| | |
| | text = "import torch" |
| | text_byte = np.frombuffer(text.encode('utf-8'), dtype=np.uint8) |
| | input_ids = torch.from_numpy(text_byte[None, :]).long().cuda() |
| | |
| | sample = model.generate( |
| | input_ids=input_ids, |
| | max_length=2048, |
| | cg=True, |
| | return_dict_in_generate=True, |
| | output_scores=True, |
| | enable_timing=True, |
| | temperature=1, |
| | top_k=256, |
| | top_p=0.9, |
| | ) |
| | |
| | print(bytes(sample.sequences[0].tolist()).decode('utf-8')) |
| | ``` |
| |
|
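| Example output (generation is sampled with temperature/top-k/top-p, so your output will differ between runs) |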
| |
|
| | ``` |
| | import torch |
| | import numpy as np |
| | import torch.nn.functional as F |
| | from torch.autograd import Variable |
| | |
| | from networkx.states import TransientState |
| | |
| | def extract_data(num_epochs, epochs, is_last_epoch): |
| | |
| | def get_data(num_features, num_classes): |
| | data_features = num_features |
| | data_classes = num_classes |
| | data_labels = num_epochs |
| | |
| | if num_features == 0 or num_classes == 0: |
| | return data_features, data_classes |
| | if is_last_epoch: |
| | data_features = num_features |
| | data_classes = num_classes |
| | data_labels = num_epochs |
| | return data_features, data_classes |
| | |
| | data_features, data_classes = get_data(num_epochs, epochs, is_last_epoch) |
| | data_labels = num_epochs * 2 |
| | return data_features, data_classes |
| | |
| | |
| | class NumChannel: |
| | def __init__(self, x, y, dx=1, dy=1, idx=1, data_size=2, epoch=None): |
| | """idx is the channel index with data feature in the first epoch. |
| | x is the channel of the input data. |
| | y is the element of the input data. |
| | dx is the element of the data feature of the input data. |
| | data_size is the size of the element of the data. |
| | epoch is the channel of the element of the data. |
| | """ |
| | self.x = x |
| | self.y = y |
| | self.dx = dx |
| | self.data_size = data_size |
| | self.epoch = epoch |
| | self.reference_count = 0 |
| | self.data_features = {} |
| | self.data_classes = {} |
| | |
| | self._initialize() |
| | if idx is not None: |
| | self._start_time = time.time() |
| | |
| | def _initialize(self): |
| | """idx is the channel index with data feature in the first epoch. |
| | x is the channel of the input data. |
| | y is the element of the input data. |
| | dx is the element of the data feature of the input data. |
| | data_size is the size of the element of the data. |
| | epoch is the channel of the element of the data. |
| | """ |
| | self.idx = idx |
| | ``` |