| | ---
|
| | license: other
|
| | ---
|
| |
|
| | # xLSTM-7B
|
| | xLSTM-7B was pre-trained on DCLM and selected high-quality data, approx. 2.3 T tokens in total, using the `xlstm-jax` framework.
|
| |
|
| |
|
| | ## How to use it
|
| | First, install `xlstm`, which now uses the `mlstm_kernels` package for its Triton kernels:
|
| |
|
| | ```bash
|
| | pip install xlstm
|
| | pip install mlstm_kernels
|
| | ```
|
| |
|
| | For now, install the NX-AI fork of the `transformers` repository (until it is merged upstream):
|
| | ```bash
|
| | pip install 'transformers @ git+ssh://git@github.com/NX-AI/transformers.git@integrate_xlstm'
|
| | ```
|
| |
|
| | Then load the model and generate text:
|
| | ```python
|
| | from transformers import AutoModelForCausalLM, AutoTokenizer
|
| |
|
| | xlstm = AutoModelForCausalLM.from_pretrained("NX-AI/xLSTM-7b", device_map="auto")
|
| |
|
| | # the tokenizer is a fork of EleutherAI/gpt-neox-20b
|
| | tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")
|
| |
|
| | tokens = tokenizer("Hello xLSTM, how are you doing?", return_tensors='pt')['input_ids'].to(xlstm.device)
|
| |
|
| | out = xlstm.generate(tokens, max_new_tokens=20)
|
| |
|
| | print(tokenizer.decode(out[0]))
|
| | ```
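The snippet above prints the prompt together with the continuation. The following is a minimal sketch of a helper that returns only the newly generated text. It assumes a model and tokenizer loaded as above; `generate_continuation` is a hypothetical name for illustration and is not part of `transformers` or `xlstm`.

```python
def generate_continuation(model, tokenizer, prompt, max_new_tokens=20):
    """Return only the text generated after `prompt` (illustrative helper)."""
    # Tokenize the prompt and move the ids to the device the model lives on.
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(model.device)

    # `generate` returns the prompt ids followed by the newly generated ids.
    output_ids = model.generate(input_ids, max_new_tokens=max_new_tokens)

    # Drop the prompt part so that only the continuation is decoded.
    prompt_length = input_ids.shape[1]
    new_ids = output_ids[0][prompt_length:]
    return tokenizer.decode(new_ids, skip_special_tokens=True)


# Usage, with `xlstm` and `tokenizer` loaded as in the snippet above:
# print(generate_continuation(xlstm, tokenizer, "Hello xLSTM, how are you doing?"))
```

Slicing by the prompt length works here because a single, unpadded prompt is passed; batched prompts with padding would need per-sequence handling.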
|
| |
|
| | ## Speed results
|
| | Generation speed with `torch.cuda.graph` and `torch.compile` optimizations on a single NVIDIA H100:
|
| | 
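To get a rough throughput number on your own hardware, a simple timing loop can be used. This is a minimal sketch and not the benchmark setup behind the results reported here; `measure_tokens_per_second` and its parameters are hypothetical names, and `generate_fn` is assumed to be any callable that generates the requested number of new tokens.

```python
import time


def measure_tokens_per_second(generate_fn, num_new_tokens, warmup_runs=1, timed_runs=3):
    """Estimate generation throughput in tokens per second (illustrative helper).

    `generate_fn(num_new_tokens)` is assumed to block until generation is done.
    On a GPU it should call `torch.cuda.synchronize()` before returning, so that
    pending kernels are included in the measured time.
    """
    # Warm-up runs absorb one-off costs such as compilation or graph capture.
    for _ in range(warmup_runs):
        generate_fn(num_new_tokens)

    # Time several runs and average over all generated tokens.
    start = time.perf_counter()
    for _ in range(timed_runs):
        generate_fn(num_new_tokens)
    elapsed = time.perf_counter() - start

    return timed_runs * num_new_tokens / elapsed
```

With the model from the usage section, `generate_fn` could wrap a call such as `xlstm.generate(tokens, max_new_tokens=n)`. Note that generation may stop early at an end-of-sequence token, in which case fewer tokens are produced than requested and the estimate is too high.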
|
| |
|
| | ## Performance
|
| | 
|
| |
|
| | Using HuggingFace's `lm_eval`:
|
| |
|
| | | BBH   | MMLU-Pro | MATH  | MuSR  | GPQA  | IFEval |
|
| | |-------|----------|-------|-------|-------|--------|
|
| | | 0.381 | 0.242    | 0.036 | 0.379 | 0.280 | 0.244  |
|
| |
|
| | Using HuggingFace's `lighteval` in the Leaderboard-v1 settings:
|
| |
|
| | | ARC-Challenge (25-shot) | MMLU (5-shot) | HellaSwag (10-shot) | WinoGrande (5-shot) | TruthfulQA (0-shot) | GSM8K (5-shot) | OpenBookQA (5-shot) | PIQA (5-shot) |
|
| | |-------------------------|---------------|---------------------|---------------------|---------------------|----------------|---------------------|---------------|
|
| | | 0.584                   | 0.589         | 0.710               | 0.742               | 0.420               | 0.004          | 0.443               | 0.817         |
|
| |
|
| | ## License
|
| | NXAI Community License (see `LICENSE` file)
|
| |
|