chengyanwu
commited on
Commit
·
ccda2ec
1
Parent(s):
e6dee89
stuff
Browse files- .gitignore +1 -0
- README.md +75 -73
- __pycache__/configuration_olmoe.cpython-311.pyc +0 -0
- __pycache__/modeling_kvlatent.cpython-311.pyc +0 -0
- __pycache__/modeling_latent_attention.cpython-311.pyc +0 -0
- __pycache__/modeling_olmoe.cpython-311.pyc +0 -0
- __pycache__/random.cpython-311.pyc +0 -0
- __pycache__/train.cpython-311.pyc +0 -0
- config.json +20 -24
- generate.py +87 -0
- model-00001-of-00003.safetensors +0 -3
- model-00002-of-00003.safetensors +0 -3
- model-00003-of-00003.safetensors +0 -3
- modeling_olmoe.py +822 -0
- oldcmds.txt +3 -0
- output.txt +0 -0
- randommoe.py +1047 -0
- requirements.txt +9 -0
- shellcommands.txt +3 -0
- train.py +130 -0
- train_olmoe_adapter.py +404 -0
.gitignore
CHANGED
|
@@ -1 +1,2 @@
|
|
| 1 |
upload.py
|
|
|
|
|
|
| 1 |
upload.py
|
| 2 |
+
tester.py
|
README.md
CHANGED
|
@@ -12,86 +12,88 @@ datasets:
|
|
| 12 |
library_name: transformers
|
| 13 |
---
|
| 14 |
|
| 15 |
-
<img alt="OLMoE Logo." src="olmoe-logo.png" width="250px">
|
| 16 |
-
|
| 17 |
|
| 18 |
# Model Summary
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
This information and more can also be found on the [**OLMoE GitHub repository**](https://github.com/allenai/OLMoE).
|
| 23 |
-
- **Paper**: https://arxiv.org/abs/2409.02060
|
| 24 |
-
- **Pretraining** [Checkpoints](https://hf.co/allenai/OLMoE-1B-7B-0924), [Code](https://github.com/allenai/OLMo/tree/Muennighoff/MoE), [Data](https://huggingface.co/datasets/allenai/OLMoE-mix-0924) and [Logs](https://wandb.ai/ai2-llm/olmoe/reports/OLMoE-1B-7B-0924--Vmlldzo4OTcyMjU3).
|
| 25 |
-
- **SFT (Supervised Fine-Tuning)** [Checkpoints](https://huggingface.co/allenai/OLMoE-1B-7B-0924-SFT), [Code](https://github.com/allenai/open-instruct/tree/olmoe-sft), [Data](https://hf.co/datasets/allenai/tulu-v3.1-mix-preview-4096-OLMoE) and [Logs](https://github.com/allenai/OLMoE/blob/main/logs/olmoe-sft-logs.txt).
|
| 26 |
-
- **DPO/KTO (Direct Preference Optimization/Kahneman-Tversky Optimization)**, [Checkpoints](https://huggingface.co/allenai/OLMoE-1B-7B-0924-Instruct), [Preference Data](https://hf.co/datasets/allenai/ultrafeedback_binarized_cleaned), [DPO code](https://github.com/allenai/open-instruct/tree/olmoe-sft), [KTO code](https://github.com/Muennighoff/kto/blob/master/kto.py) and [Logs](https://github.com/allenai/OLMoE/blob/main/logs/olmoe-dpo-logs.txt).
|
| 27 |
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
|
| 31 |
|
| 32 |
```python
|
| 33 |
-
from transformers import
|
| 34 |
-
import
|
| 35 |
-
|
| 36 |
-
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 37 |
-
|
| 38 |
-
# Load different ckpts via passing e.g. `revision=step10000-tokens41B`
|
| 39 |
-
model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924").to(DEVICE)
|
| 40 |
-
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
|
| 41 |
-
inputs = tokenizer("Bitcoin is", return_tensors="pt")
|
| 42 |
-
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
|
| 43 |
-
out = model.generate(**inputs, max_length=64)
|
| 44 |
-
print(tokenizer.decode(out[0]))
|
| 45 |
-
# > # Bitcoin is a digital currency that is created and held electronically. No one controls it. Bitcoins aren’t printed, like dollars or euros – they’re produced by people and businesses running computers all around the world, using software that solves mathematical
|
| 46 |
-
```
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
| 53 |
```
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
| **LMs with ~1B active parameters** | | | | | | | | |
|
| 65 |
-
| **OLMoE-1B-7B** | **1.3B** | **✅** | **54.1** | **80.0** | **62.1** | **84.2** | **79.8** | **70.2** |
|
| 66 |
-
| DCLM-1B | 1.4B | ✅ | 48.5 | 75.1 | 57.6 | 79.5 | 76.6 | 68.1 |
|
| 67 |
-
| TinyLlama-1B | 1.1B | ✅ | 33.6 | 60.8 | 38.1 | 69.5 | 71.7 | 60.1 |
|
| 68 |
-
| OLMo-1B (0724) | 1.3B | ✅ | 32.1 | 67.5 | 36.4 | 53.5 | 74.0 | 62.9 |
|
| 69 |
-
| Pythia-1B | 1.1B | ✅ | 31.1 | 48.0 | 31.4 | 63.4 | 68.9 | 52.7 |
|
| 70 |
-
| **LMs with ~2-3B active parameters** | | | | | | | | |
|
| 71 |
-
| Qwen1.5-3B-14B | 2.7B | ❌ | **62.4** | 80.0 | **77.4** | **91.6** | **81.0** | 72.3 |
|
| 72 |
-
| Gemma2-3B | 2.6B | ❌ | 53.3 | 74.6 | 67.5 | 84.3 | 78.5 | 71.8 |
|
| 73 |
-
| JetMoE-2B-9B | 2.2B | ❌ | 49.1 | **81.7** | 61.4 | 81.9 | 80.3 | 70.7 |
|
| 74 |
-
| DeepSeek-3B-16B | 2.9B | ❌ | 45.5 | 80.4 | 53.4 | 82.7 | 80.1 | **73.2** |
|
| 75 |
-
| StableLM-2B | 1.6B | ❌ | 40.4 | 70.3 | 50.6 | 75.3 | 75.6 | 65.8 |
|
| 76 |
-
| OpenMoE-3B-9B | 2.9B | ✅ | 27.4 | 44.4 | 29.3 | 50.6 | 63.3 | 51.9 |
|
| 77 |
-
| **LMs with ~7-9B active parameters** | | | | | | | | |
|
| 78 |
-
| Gemma2-9B | 9.2B | ❌ | **70.6** | **87.3** | **89.5** | **95.5** | **86.1** | **78.8** |
|
| 79 |
-
| Llama3.1-8B | 8.0B | ❌ | 66.9 | 81.6 | 79.5 | 91.7 | 81.1 | 76.6 |
|
| 80 |
-
| DCLM-7B | 6.9B | ✅ | 64.4 | 82.3 | 79.8 | 92.3 | 80.1 | 77.3 |
|
| 81 |
-
| Mistral-7B | 7.3B | ❌ | 64.0 | 83.0 | 78.6 | 90.8 | 82.8 | 77.9 |
|
| 82 |
-
| OLMo-7B (0724) | 6.9B | ✅ | 54.9 | 80.5 | 68.0 | 85.7 | 79.3 | 73.2 |
|
| 83 |
-
| Llama2-7B | 6.7B | ❌ | 46.2 | 78.9 | 54.2 | 84.0 | 77.5 | 71.7 |
|
| 84 |
-
|
| 85 |
-
# Citation
|
| 86 |
-
|
| 87 |
-
```bibtex
|
| 88 |
-
@misc{muennighoff2024olmoeopenmixtureofexpertslanguage,
|
| 89 |
-
title={OLMoE: Open Mixture-of-Experts Language Models},
|
| 90 |
-
author={Niklas Muennighoff and Luca Soldaini and Dirk Groeneveld and Kyle Lo and Jacob Morrison and Sewon Min and Weijia Shi and Pete Walsh and Oyvind Tafjord and Nathan Lambert and Yuling Gu and Shane Arora and Akshita Bhagia and Dustin Schwenk and David Wadden and Alexander Wettig and Binyuan Hui and Tim Dettmers and Douwe Kiela and Ali Farhadi and Noah A. Smith and Pang Wei Koh and Amanpreet Singh and Hannaneh Hajishirzi},
|
| 91 |
-
year={2024},
|
| 92 |
-
eprint={2409.02060},
|
| 93 |
-
archivePrefix={arXiv},
|
| 94 |
-
primaryClass={cs.CL},
|
| 95 |
-
url={https://arxiv.org/abs/2409.02060},
|
| 96 |
-
}
|
| 97 |
-
```
|
|
|
|
| 12 |
library_name: transformers
|
| 13 |
---
|
| 14 |
|
|
|
|
|
|
|
| 15 |
|
| 16 |
# Model Summary
|
| 17 |
+
# OLMoE with Adapters
|
| 18 |
+
|
| 19 |
+
This repository contains an extension of the OLMo model with adapter layers for parameter-efficient fine-tuning. By adding small adapter modules to the model, we can fine-tune it on downstream tasks while freezing most of the original parameters, resulting in much more efficient training.
|
| 20 |
+
|
| 21 |
+
## Model Architecture
|
| 22 |
+
|
| 23 |
+
The `OlmoEWithAdaptersForCausalLM` model extends the original OLMo architecture by:
|
| 24 |
+
|
| 25 |
+
1. Adding small adapter layers (bottleneck layers) to each MLP block
|
| 26 |
+
2. Allowing selective freezing of the base model's parameters
|
| 27 |
+
3. Training only the adapter parameters (~0.1-1% of total parameters)
|
| 28 |
+
|
| 29 |
+
Key components:
|
| 30 |
+
- `OlmoEWithAdaptersMLP`: MLP layer with additional adapter modules
|
| 31 |
+
- `OlmoEWithAdaptersDecoderLayer`: Decoder layer incorporating adapter MLPs
|
| 32 |
+
- `OlmoEWithAdaptersModel`: Full model with adapter-based decoder layers
|
| 33 |
+
- `OlmoEWithAdaptersForCausalLM`: Causal language model with adapters
|
| 34 |
+
|
| 35 |
+
## Training Script
|
| 36 |
+
|
| 37 |
+
The `train_olmoe_adapters.py` script provides a complete workflow for fine-tuning the model:
|
| 38 |
+
|
| 39 |
+
### Features:
|
| 40 |
+
- Parameter-efficient fine-tuning using adapters
|
| 41 |
+
- Support for various datasets through Hugging Face datasets library
|
| 42 |
+
- Customizable adapter size
|
| 43 |
+
- Option to freeze/unfreeze different components
|
| 44 |
+
- Training with AdamW optimizer and learning rate scheduling
|
| 45 |
+
- Evaluation with perplexity metrics
|
| 46 |
+
- Model checkpointing and saving
|
| 47 |
+
|
| 48 |
+
### Usage:
|
| 49 |
+
|
| 50 |
+
```bash
|
| 51 |
+
python train.py \
|
| 52 |
+
--model_name_or_path allenai/OLMo-7B \
|
| 53 |
+
--adapter_size 64 \
|
| 54 |
+
--freeze_base_model True \
|
| 55 |
+
--dataset_name wikitext \
|
| 56 |
+
--dataset_config_name wikitext-2-raw-v1 \
|
| 57 |
+
--output_dir ./olmoe-adapter-finetuned \
|
| 58 |
+
--num_train_epochs 3 \
|
| 59 |
+
--per_device_train_batch_size 4 \
|
| 60 |
+
--per_device_eval_batch_size 4 \
|
| 61 |
+
--learning_rate 5e-5 \
|
| 62 |
+
--warmup_steps 100 \
|
| 63 |
+
--logging_steps 100 \
|
| 64 |
+
--save_steps 1000 \
|
| 65 |
+
--seed 42
|
| 66 |
+
```
|
| 67 |
|
| 68 |
+
## Benefits of Adapter-Based Fine-Tuning
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
+
1. **Efficiency**: Train only ~0.1-1% of the parameters, dramatically reducing GPU memory requirements
|
| 71 |
+
2. **Storage**: Store only adapter weights rather than full fine-tuned models
|
| 72 |
+
3. **Composability**: Multiple adapters can be trained for different tasks and swapped at inference time
|
| 73 |
+
4. **Reduced Overfitting**: Lower parameter count helps prevent overfitting on small datasets
|
| 74 |
|
| 75 |
+
## How to Use the Fine-Tuned Model
|
| 76 |
|
| 77 |
```python
|
| 78 |
+
from transformers import OlmoTokenizer
|
| 79 |
+
from modeling_olmoe import OlmoEWithAdaptersForCausalLM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
+
# Load the fine-tuned model
|
| 82 |
+
model = OlmoEWithAdaptersForCausalLM.from_pretrained("./olmoe-adapter-finetuned")
|
| 83 |
+
tokenizer = OlmoTokenizer.from_pretrained("./olmoe-adapter-finetuned")
|
| 84 |
+
|
| 85 |
+
# Generate text
|
| 86 |
+
inputs = tokenizer("Once upon a time", return_tensors="pt")
|
| 87 |
+
outputs = model.generate(**inputs, max_length=50)
|
| 88 |
+
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
| 89 |
```
|
| 90 |
|
| 91 |
+
## Adapter Size Recommendations
|
| 92 |
+
|
| 93 |
+
The adapter size determines the parameter efficiency vs. performance trade-off:
|
| 94 |
+
|
| 95 |
+
- **Small datasets**: 16-32 dimensions
|
| 96 |
+
- **Medium datasets**: 64-128 dimensions
|
| 97 |
+
- **Large datasets**: 128-256 dimensions
|
| 98 |
+
|
| 99 |
+
For most fine-tuning scenarios, an adapter size of 64 provides a good balance between efficiency and performance.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__pycache__/configuration_olmoe.cpython-311.pyc
ADDED
|
Binary file (2.35 kB). View file
|
|
|
__pycache__/modeling_kvlatent.cpython-311.pyc
ADDED
|
Binary file (33.9 kB). View file
|
|
|
__pycache__/modeling_latent_attention.cpython-311.pyc
ADDED
|
Binary file (9.57 kB). View file
|
|
|
__pycache__/modeling_olmoe.cpython-311.pyc
ADDED
|
Binary file (43.7 kB). View file
|
|
|
__pycache__/random.cpython-311.pyc
ADDED
|
Binary file (54.4 kB). View file
|
|
|
__pycache__/train.cpython-311.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
config.json
CHANGED
|
@@ -1,31 +1,27 @@
|
|
| 1 |
{
|
| 2 |
"architectures": [
|
| 3 |
-
"
|
| 4 |
],
|
| 5 |
-
"
|
| 6 |
-
"attention_dropout": 0.0,
|
| 7 |
-
"clip_qkv": null,
|
| 8 |
-
"eos_token_id": 50279,
|
| 9 |
-
"hidden_act": "silu",
|
| 10 |
"hidden_size": 2048,
|
| 11 |
-
"
|
| 12 |
-
"intermediate_size": 1024,
|
| 13 |
-
"max_position_embeddings": 4096,
|
| 14 |
-
"model_type": "olmoe",
|
| 15 |
-
"norm_topk_prob": false,
|
| 16 |
"num_attention_heads": 16,
|
| 17 |
-
"
|
| 18 |
-
"
|
| 19 |
-
"
|
| 20 |
-
"
|
| 21 |
-
"
|
|
|
|
|
|
|
| 22 |
"pad_token_id": 1,
|
| 23 |
-
"
|
| 24 |
-
"
|
| 25 |
-
"
|
| 26 |
-
"
|
| 27 |
-
"torch_dtype": "bfloat16",
|
| 28 |
-
"transformers_version": "4.43.0.dev0",
|
| 29 |
"use_cache": true,
|
| 30 |
-
"
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
{
|
| 2 |
"architectures": [
|
| 3 |
+
"KVLatentForCausalLM"
|
| 4 |
],
|
| 5 |
+
"model_type": "kvlatent",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
"hidden_size": 2048,
|
| 7 |
+
"num_hidden_layers": 24,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
"num_attention_heads": 16,
|
| 9 |
+
"num_key_value_heads": 8,
|
| 10 |
+
"num_latents": 64,
|
| 11 |
+
"intermediate_size": 8192,
|
| 12 |
+
"hidden_act": "gelu",
|
| 13 |
+
"initializer_range": 0.02,
|
| 14 |
+
"rms_norm_eps": 1e-5,
|
| 15 |
+
"vocab_size": 50304,
|
| 16 |
"pad_token_id": 1,
|
| 17 |
+
"bos_token_id": 50256,
|
| 18 |
+
"eos_token_id": 50256,
|
| 19 |
+
"attention_dropout": 0.0,
|
| 20 |
+
"attention_bias": false,
|
|
|
|
|
|
|
| 21 |
"use_cache": true,
|
| 22 |
+
"tie_word_embeddings": false,
|
| 23 |
+
"rope_theta": 10000.0,
|
| 24 |
+
"rope_scaling": null,
|
| 25 |
+
"max_position_embeddings": 4096,
|
| 26 |
+
"torch_dtype": "bfloat16"
|
| 27 |
+
}
|
generate.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
"""
|
| 3 |
+
Example usage script to evaluate a fine-tuned OlmoE adapter model
|
| 4 |
+
and demonstrate generation with adapters.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import torch
|
| 9 |
+
from transformers import AutoTokenizer
|
| 10 |
+
from modeling_olmoe import OlmoEWithAdaptersForCausalLM, OlmoConfig
|
| 11 |
+
|
| 12 |
+
def generate_text(
|
| 13 |
+
model_path: str,
|
| 14 |
+
prompt: str,
|
| 15 |
+
max_new_tokens: int = 128,
|
| 16 |
+
temperature: float = 0.7,
|
| 17 |
+
top_p: float = 0.9,
|
| 18 |
+
device: str = "auto",
|
| 19 |
+
):
|
| 20 |
+
"""Generate text using a fine-tuned OlmoE adapter model."""
|
| 21 |
+
# Determine device
|
| 22 |
+
if device == "auto":
|
| 23 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 24 |
+
print(f"Using device: {device}")
|
| 25 |
+
|
| 26 |
+
# Load tokenizer and model
|
| 27 |
+
print(f"Loading model from {model_path}")
|
| 28 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 29 |
+
|
| 30 |
+
# Load config and update with adapter settings if needed
|
| 31 |
+
config = OlmoConfig.from_pretrained(model_path)
|
| 32 |
+
|
| 33 |
+
# Load adapter model
|
| 34 |
+
model = OlmoEWithAdaptersForCausalLM.from_pretrained(
|
| 35 |
+
model_path,
|
| 36 |
+
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
| 37 |
+
)
|
| 38 |
+
model = model.to(device)
|
| 39 |
+
model.eval()
|
| 40 |
+
|
| 41 |
+
# Tokenize input
|
| 42 |
+
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
| 43 |
+
|
| 44 |
+
# Generate text
|
| 45 |
+
print("\nGenerating text...\n")
|
| 46 |
+
with torch.no_grad():
|
| 47 |
+
outputs = model.generate(
|
| 48 |
+
input_ids,
|
| 49 |
+
max_new_tokens=max_new_tokens,
|
| 50 |
+
do_sample=True,
|
| 51 |
+
temperature=temperature,
|
| 52 |
+
top_p=top_p,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
# Decode the generated text
|
| 56 |
+
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 57 |
+
|
| 58 |
+
print(f"Prompt: {prompt}")
|
| 59 |
+
print("\nGenerated text:")
|
| 60 |
+
print("=" * 40)
|
| 61 |
+
print(generated_text)
|
| 62 |
+
print("=" * 40)
|
| 63 |
+
|
| 64 |
+
return generated_text
|
| 65 |
+
|
| 66 |
+
def main():
|
| 67 |
+
parser = argparse.ArgumentParser(description="Generate text with OlmoE adapter model")
|
| 68 |
+
parser.add_argument("--model_path", type=str, required=True, help="Path to the fine-tuned model")
|
| 69 |
+
parser.add_argument("--prompt", type=str, default="This is an example of", help="Prompt for text generation")
|
| 70 |
+
parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of new tokens to generate")
|
| 71 |
+
parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
|
| 72 |
+
parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling parameter")
|
| 73 |
+
parser.add_argument("--device", type=str, default="auto", help="Device to use (cuda, cpu, or auto)")
|
| 74 |
+
|
| 75 |
+
args = parser.parse_args()
|
| 76 |
+
|
| 77 |
+
generate_text(
|
| 78 |
+
model_path=args.model_path,
|
| 79 |
+
prompt=args.prompt,
|
| 80 |
+
max_new_tokens=args.max_new_tokens,
|
| 81 |
+
temperature=args.temperature,
|
| 82 |
+
top_p=args.top_p,
|
| 83 |
+
device=args.device,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
if __name__ == "__main__":
|
| 87 |
+
main()
|
model-00001-of-00003.safetensors
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:5e3cff7e367794685c241169072c940d200918617d5e2813f1c387dff52d845e
|
| 3 |
-
size 4997744872
|
|
|
|
|
|
|
|
|
|
|
|
model-00002-of-00003.safetensors
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:15ef5c730ee3cfed7199498788cd2faf337203fc74b529625e7502cdd759f4a7
|
| 3 |
-
size 4997235176
|
|
|
|
|
|
|
|
|
|
|
|
model-00003-of-00003.safetensors
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:a9abac4ac1b55c9adabac721a02fa39971f103eea9a65c310972b1246de76e04
|
| 3 |
-
size 3843741912
|
|
|
|
|
|
|
|
|
|
|
|
modeling_olmoe.py
ADDED
|
@@ -0,0 +1,822 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# modeling_olmoe.py - Extended version of OLMo for custom training
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from typing import Callable, Dict, Optional, Tuple, Union, Any
|
| 7 |
+
# Import necessary components from transformers
|
| 8 |
+
from transformers.activations import ACT2FN
|
| 9 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 10 |
+
from transformers.generation import GenerationMixin
|
| 11 |
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
| 12 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 13 |
+
# from transformers.modeling_layers import GradientCheckpointingLayer
|
| 14 |
+
from torch.utils.checkpoint import checkpoint
|
| 15 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 16 |
+
# from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 17 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 18 |
+
from transformers.processing_utils import Unpack
|
| 19 |
+
from transformers.utils import LossKwargs, is_torch_flex_attn_available, logging
|
| 20 |
+
from transformers import OlmoConfig
|
| 21 |
+
|
| 22 |
+
# Import flex attention components if available
|
| 23 |
+
if is_torch_flex_attn_available():
|
| 24 |
+
from torch.nn.attention.flex_attention import BlockMask
|
| 25 |
+
# from transformers.integrations.flex_attention import make_flex_block_causal_mask
|
| 26 |
+
|
| 27 |
+
from functools import partial
|
| 28 |
+
# Define GradientCheckpointingLayer since it's missing
|
| 29 |
+
class GradientCheckpointingLayer(nn.Module):
|
| 30 |
+
gradient_checkpointing = False
|
| 31 |
+
def __call__(self, *args, **kwargs):
|
| 32 |
+
# Use checkpoint on `forward` when enabled
|
| 33 |
+
if self.gradient_checkpointing and self.training:
|
| 34 |
+
return checkpoint(self.forward, *args, **kwargs)
|
| 35 |
+
return super().__call__(*args, **kwargs)
|
| 36 |
+
|
| 37 |
+
def forward(self, *args, **kwargs):
|
| 38 |
+
# To be implemented by subclasses
|
| 39 |
+
raise NotImplementedError("Subclasses must implement `forward`")
|
| 40 |
+
|
| 41 |
+
import math
|
| 42 |
+
import functools
|
| 43 |
+
|
| 44 |
+
# Define our own dynamic_rope_update decorator and ROPE_INIT_FUNCTIONS
|
| 45 |
+
def dynamic_rope_update(func):
|
| 46 |
+
"""
|
| 47 |
+
Decorator for updating RoPE embeddings when using RoPE scaling strategies.
|
| 48 |
+
"""
|
| 49 |
+
@functools.wraps(func)
|
| 50 |
+
def wrapper(self, *args, **kwargs):
|
| 51 |
+
# Only dynamic scaling needs to modify the positional encodings
|
| 52 |
+
if self.rope_type == "dynamic" and hasattr(self, "original_max_seq_len"):
|
| 53 |
+
if self.config.rope_scaling is None:
|
| 54 |
+
return func(self, *args, **kwargs)
|
| 55 |
+
# Extract max_position_embeddings from the actual model
|
| 56 |
+
current_ctx_len = kwargs.get("position_ids", None)
|
| 57 |
+
if current_ctx_len is not None:
|
| 58 |
+
# position_ids shape is [batch_size, seq_len]
|
| 59 |
+
current_ctx_len = current_ctx_len.shape[-1]
|
| 60 |
+
|
| 61 |
+
# If we're inside a context window we've seen before, we don't have to change anything
|
| 62 |
+
if current_ctx_len is not None and current_ctx_len <= self.max_seq_len_cached:
|
| 63 |
+
return func(self, *args, **kwargs)
|
| 64 |
+
|
| 65 |
+
current_ctx_len = self.config.max_position_embeddings if current_ctx_len is None else current_ctx_len
|
| 66 |
+
scaling_factor = self.config.rope_scaling["factor"]
|
| 67 |
+
|
| 68 |
+
self.max_seq_len_cached = min(
|
| 69 |
+
int(self.original_max_seq_len * scaling_factor),
|
| 70 |
+
self.config.rope_scaling.get("max_position_embeddings", float("inf"))
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
# Reset the cached maximum position embeddings to the new value
|
| 74 |
+
power = 0.0 if scaling_factor <= 1.0 else -0.5
|
| 75 |
+
self.inv_freq = self.original_inv_freq * (scaling_factor ** power)
|
| 76 |
+
|
| 77 |
+
return func(self, *args, **kwargs)
|
| 78 |
+
|
| 79 |
+
return wrapper
|
| 80 |
+
|
| 81 |
+
def get_default_rope_init(config, device=None):
|
| 82 |
+
"""
|
| 83 |
+
Default initialization for rotary position embeddings.
|
| 84 |
+
"""
|
| 85 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 86 |
+
inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, head_dim, 2).float().to(device) / head_dim))
|
| 87 |
+
return inv_freq, None
|
| 88 |
+
|
| 89 |
+
def get_linear_rope_init(config, device=None):
|
| 90 |
+
"""
|
| 91 |
+
Linear initialization for dynamic scaling rotary position embeddings.
|
| 92 |
+
"""
|
| 93 |
+
base = get_default_rope_init(config, device)[0]
|
| 94 |
+
scaling_factor = config.rope_scaling["factor"]
|
| 95 |
+
|
| 96 |
+
# Scale the base frequencies
|
| 97 |
+
return base / scaling_factor, scaling_factor
|
| 98 |
+
|
| 99 |
+
def get_dynamic_rope_init(config, device=None):
|
| 100 |
+
"""
|
| 101 |
+
Dynamic initialization for dynamic scaling rotary position embeddings (NTK approach).
|
| 102 |
+
"""
|
| 103 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 104 |
+
scaling_factor = config.rope_scaling["factor"]
|
| 105 |
+
|
| 106 |
+
# Adjust the base frequencies by a power of the scaling factor
|
| 107 |
+
power = 0.0 if scaling_factor <= 1.0 else -0.5
|
| 108 |
+
inv_freq = 1.0 / (config.rope_theta **
|
| 109 |
+
(torch.arange(0, head_dim, 2).float().to(device) / head_dim))
|
| 110 |
+
inv_freq = inv_freq * (scaling_factor ** power)
|
| 111 |
+
|
| 112 |
+
return inv_freq, scaling_factor
|
| 113 |
+
|
| 114 |
+
# Define the dictionary of RoPE initialization functions
|
| 115 |
+
ROPE_INIT_FUNCTIONS = {
|
| 116 |
+
"default": get_default_rope_init,
|
| 117 |
+
"linear": get_linear_rope_init,
|
| 118 |
+
"dynamic": get_dynamic_rope_init,
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
def can_return_tuple(inputs):
|
| 122 |
+
# Copied logic from the original source
|
| 123 |
+
return getattr(inputs, "return_tuple", False) if hasattr(inputs, "return_tuple") else False
|
| 124 |
+
|
| 125 |
+
# Start Modeling Code
|
| 126 |
+
logger = logging.get_logger(__name__)
|
| 127 |
+
|
| 128 |
+
# Core OLMo components (reused from original implementation)
|
| 129 |
+
class OlmoLayerNorm(nn.Module):
|
| 130 |
+
"""LayerNorm but with no learnable weight or bias."""
|
| 131 |
+
|
| 132 |
+
def __init__(self, hidden_size: int) -> None:
|
| 133 |
+
super().__init__()
|
| 134 |
+
self.normalized_shape = (hidden_size,)
|
| 135 |
+
|
| 136 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 137 |
+
orig_dtype = hidden_states.dtype
|
| 138 |
+
return F.layer_norm(hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=1e-5).to(
|
| 139 |
+
orig_dtype
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class OlmoMLP(nn.Module):
|
| 144 |
+
def __init__(self, config):
|
| 145 |
+
super().__init__()
|
| 146 |
+
self.config = config
|
| 147 |
+
self.hidden_size = config.hidden_size
|
| 148 |
+
self.intermediate_size = config.intermediate_size
|
| 149 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 150 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 151 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 152 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 153 |
+
|
| 154 |
+
def forward(self, x):
|
| 155 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 156 |
+
return down_proj
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# Helper functions for rotary position embeddings
|
| 160 |
+
def rotate_half(x):
|
| 161 |
+
"""Rotates half the hidden dims of the input."""
|
| 162 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 163 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 164 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 168 |
+
"""Applies Rotary Position Embedding to the query and key tensors."""
|
| 169 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 170 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 171 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 172 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 173 |
+
return q_embed, k_embed
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 177 |
+
"""
|
| 178 |
+
Repeats key/value states for grouped queries attention.
|
| 179 |
+
"""
|
| 180 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 181 |
+
if n_rep == 1:
|
| 182 |
+
return hidden_states
|
| 183 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 184 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def eager_attention_forward(
|
| 188 |
+
module: nn.Module,
|
| 189 |
+
query: torch.Tensor,
|
| 190 |
+
key: torch.Tensor,
|
| 191 |
+
value: torch.Tensor,
|
| 192 |
+
attention_mask: Optional[torch.Tensor],
|
| 193 |
+
scaling: float,
|
| 194 |
+
dropout: float = 0.0,
|
| 195 |
+
**kwargs,
|
| 196 |
+
):
|
| 197 |
+
"""Default eager implementation of multi-head attention"""
|
| 198 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 199 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 200 |
+
|
| 201 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 202 |
+
if attention_mask is not None:
|
| 203 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 204 |
+
attn_weights = attn_weights + causal_mask
|
| 205 |
+
|
| 206 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 207 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 208 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 209 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 210 |
+
|
| 211 |
+
return attn_output, attn_weights
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class OlmoAttention(nn.Module):
|
| 215 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 216 |
+
|
| 217 |
+
def __init__(self, config: OlmoConfig, layer_idx: int):
|
| 218 |
+
super().__init__()
|
| 219 |
+
self.config = config
|
| 220 |
+
self.layer_idx = layer_idx
|
| 221 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 222 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 223 |
+
self.scaling = self.head_dim**-0.5
|
| 224 |
+
self.attention_dropout = config.attention_dropout
|
| 225 |
+
self.is_causal = True
|
| 226 |
+
|
| 227 |
+
self.q_proj = nn.Linear(
|
| 228 |
+
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
| 229 |
+
)
|
| 230 |
+
self.k_proj = nn.Linear(
|
| 231 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 232 |
+
)
|
| 233 |
+
self.v_proj = nn.Linear(
|
| 234 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 235 |
+
)
|
| 236 |
+
self.o_proj = nn.Linear(
|
| 237 |
+
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
def forward(
|
| 241 |
+
self,
|
| 242 |
+
hidden_states: torch.Tensor,
|
| 243 |
+
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
| 244 |
+
attention_mask: Optional[torch.Tensor],
|
| 245 |
+
past_key_value: Optional[Cache] = None,
|
| 246 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 247 |
+
**kwargs,
|
| 248 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 249 |
+
input_shape = hidden_states.shape[:-1]
|
| 250 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 251 |
+
|
| 252 |
+
query_states = self.q_proj(hidden_states)
|
| 253 |
+
key_states = self.k_proj(hidden_states)
|
| 254 |
+
value_states = self.v_proj(hidden_states)
|
| 255 |
+
|
| 256 |
+
if self.config.clip_qkv is not None:
|
| 257 |
+
query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 258 |
+
key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 259 |
+
value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 260 |
+
|
| 261 |
+
query_states = query_states.view(hidden_shape).transpose(1, 2)
|
| 262 |
+
key_states = key_states.view(hidden_shape).transpose(1, 2)
|
| 263 |
+
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
| 264 |
+
|
| 265 |
+
cos, sin = position_embeddings
|
| 266 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 267 |
+
|
| 268 |
+
if past_key_value is not None:
|
| 269 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 270 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 271 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 272 |
+
|
| 273 |
+
attention_interface: Callable = eager_attention_forward
|
| 274 |
+
if self.config._attn_implementation != "eager":
|
| 275 |
+
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
| 276 |
+
logger.warning_once(
|
| 277 |
+
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
| 278 |
+
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 279 |
+
)
|
| 280 |
+
else:
|
| 281 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 282 |
+
|
| 283 |
+
attn_output, attn_weights = attention_interface(
|
| 284 |
+
self,
|
| 285 |
+
query_states,
|
| 286 |
+
key_states,
|
| 287 |
+
value_states,
|
| 288 |
+
attention_mask,
|
| 289 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 290 |
+
scaling=self.scaling,
|
| 291 |
+
**kwargs,
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 295 |
+
attn_output = self.o_proj(attn_output)
|
| 296 |
+
return attn_output, attn_weights
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class OlmoDecoderLayer(GradientCheckpointingLayer):
|
| 300 |
+
def __init__(self, config: OlmoConfig, layer_idx: int):
|
| 301 |
+
super().__init__()
|
| 302 |
+
self.hidden_size = config.hidden_size
|
| 303 |
+
self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx)
|
| 304 |
+
|
| 305 |
+
self.mlp = OlmoMLP(config)
|
| 306 |
+
self.input_layernorm = OlmoLayerNorm(config.hidden_size)
|
| 307 |
+
self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
|
| 308 |
+
|
| 309 |
+
def forward(
|
| 310 |
+
self,
|
| 311 |
+
hidden_states: torch.Tensor,
|
| 312 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 313 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 314 |
+
past_key_value: Optional[Cache] = None,
|
| 315 |
+
output_attentions: Optional[bool] = False,
|
| 316 |
+
use_cache: Optional[bool] = False,
|
| 317 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 318 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 319 |
+
**kwargs,
|
| 320 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 321 |
+
residual = hidden_states
|
| 322 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 323 |
+
|
| 324 |
+
# Self Attention
|
| 325 |
+
hidden_states, self_attn_weights = self.self_attn(
|
| 326 |
+
hidden_states=hidden_states,
|
| 327 |
+
attention_mask=attention_mask,
|
| 328 |
+
position_ids=position_ids,
|
| 329 |
+
past_key_value=past_key_value,
|
| 330 |
+
output_attentions=output_attentions,
|
| 331 |
+
use_cache=use_cache,
|
| 332 |
+
cache_position=cache_position,
|
| 333 |
+
position_embeddings=position_embeddings,
|
| 334 |
+
**kwargs,
|
| 335 |
+
)
|
| 336 |
+
hidden_states = residual + hidden_states
|
| 337 |
+
|
| 338 |
+
# Fully Connected
|
| 339 |
+
residual = hidden_states
|
| 340 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 341 |
+
hidden_states = self.mlp(hidden_states)
|
| 342 |
+
hidden_states = residual + hidden_states
|
| 343 |
+
|
| 344 |
+
outputs = (hidden_states,)
|
| 345 |
+
if output_attentions:
|
| 346 |
+
outputs += (self_attn_weights,)
|
| 347 |
+
|
| 348 |
+
return outputs
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
class OlmoRotaryEmbedding(nn.Module):
|
| 352 |
+
def __init__(self, config: OlmoConfig, device=None):
|
| 353 |
+
super().__init__()
|
| 354 |
+
# BC: "rope_type" was originally "type"
|
| 355 |
+
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
| 356 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 357 |
+
else:
|
| 358 |
+
self.rope_type = "default"
|
| 359 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 360 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 361 |
+
|
| 362 |
+
self.config = config
|
| 363 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 364 |
+
|
| 365 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 366 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 367 |
+
self.original_inv_freq = self.inv_freq
|
| 368 |
+
|
| 369 |
+
@torch.no_grad()
|
| 370 |
+
@dynamic_rope_update
|
| 371 |
+
def forward(self, x, position_ids):
|
| 372 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 373 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 374 |
+
|
| 375 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 376 |
+
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 377 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 378 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 379 |
+
cos = emb.cos() * self.attention_scaling
|
| 380 |
+
sin = emb.sin() * self.attention_scaling
|
| 381 |
+
|
| 382 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
# Base model classes
|
| 386 |
+
class OlmoEPreTrainedModel(PreTrainedModel):
|
| 387 |
+
"""Base class for OlmoE models with additional extensibility features"""
|
| 388 |
+
|
| 389 |
+
config_class = OlmoConfig
|
| 390 |
+
base_model_prefix = "model"
|
| 391 |
+
supports_gradient_checkpointing = True
|
| 392 |
+
_no_split_modules = ["OlmoDecoderLayer"]
|
| 393 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 394 |
+
_supports_flash_attn_2 = True
|
| 395 |
+
_supports_sdpa = True
|
| 396 |
+
_supports_flex_attn = True
|
| 397 |
+
_supports_cache_class = True
|
| 398 |
+
_supports_quantized_cache = True
|
| 399 |
+
_supports_static_cache = True
|
| 400 |
+
_supports_attention_backend = True
|
| 401 |
+
|
| 402 |
+
def _init_weights(self, module):
|
| 403 |
+
std = self.config.initializer_range
|
| 404 |
+
if isinstance(module, nn.Linear):
|
| 405 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 406 |
+
if module.bias is not None:
|
| 407 |
+
module.bias.data.zero_()
|
| 408 |
+
elif isinstance(module, nn.Embedding):
|
| 409 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 410 |
+
if module.padding_idx is not None:
|
| 411 |
+
module.weight.data[module.padding_idx].zero_()
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
class OlmoEModel(OlmoEPreTrainedModel):
|
| 415 |
+
"""Extended OLMo base model with additional customization points"""
|
| 416 |
+
|
| 417 |
+
def __init__(self, config: OlmoConfig):
|
| 418 |
+
super().__init__(config)
|
| 419 |
+
self.padding_idx = config.pad_token_id
|
| 420 |
+
self.vocab_size = config.vocab_size
|
| 421 |
+
|
| 422 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 423 |
+
self.layers = nn.ModuleList(
|
| 424 |
+
[OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 425 |
+
)
|
| 426 |
+
self.norm = OlmoLayerNorm(config.hidden_size)
|
| 427 |
+
self.rotary_emb = OlmoRotaryEmbedding(config=config)
|
| 428 |
+
self.gradient_checkpointing = False
|
| 429 |
+
|
| 430 |
+
# Initialize weights and apply final processing
|
| 431 |
+
self.post_init()
|
| 432 |
+
|
| 433 |
+
def get_input_embeddings(self):
|
| 434 |
+
return self.embed_tokens
|
| 435 |
+
|
| 436 |
+
def set_input_embeddings(self, value):
|
| 437 |
+
self.embed_tokens = value
|
| 438 |
+
|
| 439 |
+
def _update_causal_mask(
|
| 440 |
+
self,
|
| 441 |
+
attention_mask: Union[torch.Tensor, "BlockMask"],
|
| 442 |
+
input_tensor: torch.Tensor,
|
| 443 |
+
cache_position: torch.Tensor,
|
| 444 |
+
past_key_values: Cache,
|
| 445 |
+
output_attentions: bool = False,
|
| 446 |
+
):
|
| 447 |
+
if self.config._attn_implementation == "flash_attention_2":
|
| 448 |
+
if attention_mask is not None and (attention_mask == 0.0).any():
|
| 449 |
+
return attention_mask
|
| 450 |
+
return None
|
| 451 |
+
# if self.config._attn_implementation == "flex_attention":
|
| 452 |
+
# if isinstance(attention_mask, torch.Tensor):
|
| 453 |
+
# attention_mask = make_flex_block_causal_mask(attention_mask)
|
| 454 |
+
# return attention_mask
|
| 455 |
+
|
| 456 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 457 |
+
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
| 458 |
+
|
| 459 |
+
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
| 460 |
+
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
| 461 |
+
attention_mask,
|
| 462 |
+
inputs_embeds=input_tensor,
|
| 463 |
+
past_key_values_length=past_seen_tokens,
|
| 464 |
+
is_training=self.training,
|
| 465 |
+
):
|
| 466 |
+
return None
|
| 467 |
+
|
| 468 |
+
dtype = input_tensor.dtype
|
| 469 |
+
sequence_length = input_tensor.shape[1]
|
| 470 |
+
if using_compilable_cache:
|
| 471 |
+
target_length = past_key_values.get_max_cache_shape()
|
| 472 |
+
else:
|
| 473 |
+
target_length = (
|
| 474 |
+
attention_mask.shape[-1]
|
| 475 |
+
if isinstance(attention_mask, torch.Tensor)
|
| 476 |
+
else past_seen_tokens + sequence_length + 1
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
| 480 |
+
attention_mask,
|
| 481 |
+
sequence_length=sequence_length,
|
| 482 |
+
target_length=target_length,
|
| 483 |
+
dtype=dtype,
|
| 484 |
+
cache_position=cache_position,
|
| 485 |
+
batch_size=input_tensor.shape[0],
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
if (
|
| 489 |
+
self.config._attn_implementation == "sdpa"
|
| 490 |
+
and attention_mask is not None
|
| 491 |
+
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
| 492 |
+
and not output_attentions
|
| 493 |
+
):
|
| 494 |
+
min_dtype = torch.finfo(dtype).min
|
| 495 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
| 496 |
+
|
| 497 |
+
return causal_mask
|
| 498 |
+
|
| 499 |
+
@staticmethod
|
| 500 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
| 501 |
+
attention_mask: torch.Tensor,
|
| 502 |
+
sequence_length: int,
|
| 503 |
+
target_length: int,
|
| 504 |
+
dtype: torch.dtype,
|
| 505 |
+
cache_position: torch.Tensor,
|
| 506 |
+
batch_size: int,
|
| 507 |
+
**kwargs,
|
| 508 |
+
):
|
| 509 |
+
"""Creates a causal 4D mask."""
|
| 510 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 511 |
+
causal_mask = attention_mask
|
| 512 |
+
else:
|
| 513 |
+
min_dtype = torch.finfo(dtype).min
|
| 514 |
+
causal_mask = torch.full(
|
| 515 |
+
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
| 516 |
+
)
|
| 517 |
+
if sequence_length != 1:
|
| 518 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
| 519 |
+
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
| 520 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
| 521 |
+
if attention_mask is not None:
|
| 522 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
| 523 |
+
mask_length = attention_mask.shape[-1]
|
| 524 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
| 525 |
+
causal_mask.device
|
| 526 |
+
)
|
| 527 |
+
padding_mask = padding_mask == 0
|
| 528 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
| 529 |
+
padding_mask, min_dtype
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
return causal_mask
|
| 533 |
+
|
| 534 |
+
@can_return_tuple
|
| 535 |
+
def forward(
|
| 536 |
+
self,
|
| 537 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 538 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 539 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 540 |
+
past_key_values: Optional[Cache] = None,
|
| 541 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 542 |
+
use_cache: Optional[bool] = None,
|
| 543 |
+
output_attentions: Optional[bool] = None,
|
| 544 |
+
output_hidden_states: Optional[bool] = None,
|
| 545 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 546 |
+
**flash_attn_kwargs,
|
| 547 |
+
) -> BaseModelOutputWithPast:
|
| 548 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 549 |
+
output_hidden_states = (
|
| 550 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 551 |
+
)
|
| 552 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 553 |
+
|
| 554 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 555 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 556 |
+
|
| 557 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
| 558 |
+
logger.warning_once(
|
| 559 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
| 560 |
+
)
|
| 561 |
+
use_cache = False
|
| 562 |
+
|
| 563 |
+
if not isinstance(past_key_values, (type(None), Cache)):
|
| 564 |
+
raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
|
| 565 |
+
|
| 566 |
+
if inputs_embeds is None:
|
| 567 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 568 |
+
|
| 569 |
+
if use_cache and past_key_values is None:
|
| 570 |
+
past_key_values = DynamicCache()
|
| 571 |
+
|
| 572 |
+
if cache_position is None:
|
| 573 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 574 |
+
cache_position = torch.arange(
|
| 575 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 576 |
+
)
|
| 577 |
+
|
| 578 |
+
if position_ids is None:
|
| 579 |
+
position_ids = cache_position.unsqueeze(0)
|
| 580 |
+
|
| 581 |
+
causal_mask = self._update_causal_mask(
|
| 582 |
+
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
hidden_states = inputs_embeds
|
| 586 |
+
|
| 587 |
+
# create position embeddings to be shared across the decoder layers
|
| 588 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 589 |
+
|
| 590 |
+
# decoder layers
|
| 591 |
+
all_hidden_states = () if output_hidden_states else None
|
| 592 |
+
all_self_attns = () if output_attentions else None
|
| 593 |
+
|
| 594 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 595 |
+
if output_hidden_states:
|
| 596 |
+
all_hidden_states += (hidden_states,)
|
| 597 |
+
|
| 598 |
+
layer_outputs = decoder_layer(
|
| 599 |
+
hidden_states,
|
| 600 |
+
attention_mask=causal_mask,
|
| 601 |
+
position_ids=position_ids,
|
| 602 |
+
past_key_value=past_key_values,
|
| 603 |
+
output_attentions=output_attentions,
|
| 604 |
+
use_cache=use_cache,
|
| 605 |
+
cache_position=cache_position,
|
| 606 |
+
position_embeddings=position_embeddings,
|
| 607 |
+
**flash_attn_kwargs,
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
hidden_states = layer_outputs[0]
|
| 611 |
+
|
| 612 |
+
if output_attentions:
|
| 613 |
+
all_self_attns += (layer_outputs[1],)
|
| 614 |
+
|
| 615 |
+
hidden_states = self.norm(hidden_states)
|
| 616 |
+
|
| 617 |
+
# add hidden states from the last decoder layer
|
| 618 |
+
if output_hidden_states:
|
| 619 |
+
all_hidden_states += (hidden_states,)
|
| 620 |
+
|
| 621 |
+
return BaseModelOutputWithPast(
|
| 622 |
+
last_hidden_state=hidden_states,
|
| 623 |
+
past_key_values=past_key_values if use_cache else None,
|
| 624 |
+
hidden_states=all_hidden_states,
|
| 625 |
+
attentions=all_self_attns,
|
| 626 |
+
)
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
class OlmoEForCausalLM(OlmoEPreTrainedModel, GenerationMixin):
|
| 633 |
+
"""OLMo Causal Language Model with extensions for custom training"""
|
| 634 |
+
|
| 635 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 636 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 637 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 638 |
+
|
| 639 |
+
def __init__(self, config):
|
| 640 |
+
super().__init__(config)
|
| 641 |
+
self.model = OlmoEModel(config)
|
| 642 |
+
self.vocab_size = config.vocab_size
|
| 643 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 644 |
+
|
| 645 |
+
# Initialize weights and apply final processing
|
| 646 |
+
self.post_init()
|
| 647 |
+
|
| 648 |
+
def get_input_embeddings(self):
|
| 649 |
+
return self.model.embed_tokens
|
| 650 |
+
|
| 651 |
+
def set_input_embeddings(self, value):
|
| 652 |
+
self.model.embed_tokens = value
|
| 653 |
+
|
| 654 |
+
def get_output_embeddings(self):
|
| 655 |
+
return self.lm_head
|
| 656 |
+
|
| 657 |
+
def set_output_embeddings(self, new_embeddings):
|
| 658 |
+
self.lm_head = new_embeddings
|
| 659 |
+
|
| 660 |
+
def set_decoder(self, decoder):
|
| 661 |
+
self.model = decoder
|
| 662 |
+
|
| 663 |
+
def get_decoder(self):
|
| 664 |
+
return self.model
|
| 665 |
+
|
| 666 |
+
@can_return_tuple
|
| 667 |
+
def forward(
|
| 668 |
+
self,
|
| 669 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 670 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 671 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 672 |
+
past_key_values: Optional[Cache] = None,
|
| 673 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 674 |
+
labels: Optional[torch.LongTensor] = None,
|
| 675 |
+
use_cache: Optional[bool] = None,
|
| 676 |
+
output_attentions: Optional[bool] = None,
|
| 677 |
+
output_hidden_states: Optional[bool] = None,
|
| 678 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 679 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 680 |
+
**kwargs,
|
| 681 |
+
) -> CausalLMOutputWithPast:
|
| 682 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 683 |
+
output_hidden_states = (
|
| 684 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
# Get model outputs
|
| 688 |
+
outputs = self.model(
|
| 689 |
+
input_ids=input_ids,
|
| 690 |
+
attention_mask=attention_mask,
|
| 691 |
+
position_ids=position_ids,
|
| 692 |
+
past_key_values=past_key_values,
|
| 693 |
+
inputs_embeds=inputs_embeds,
|
| 694 |
+
use_cache=use_cache,
|
| 695 |
+
output_attentions=output_attentions,
|
| 696 |
+
output_hidden_states=output_hidden_states,
|
| 697 |
+
cache_position=cache_position,
|
| 698 |
+
**kwargs,
|
| 699 |
+
)
|
| 700 |
+
|
| 701 |
+
hidden_states = outputs.last_hidden_state
|
| 702 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 703 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 704 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 705 |
+
|
| 706 |
+
loss = None
|
| 707 |
+
if labels is not None:
|
| 708 |
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
| 709 |
+
|
| 710 |
+
return CausalLMOutputWithPast(
|
| 711 |
+
loss=loss,
|
| 712 |
+
logits=logits,
|
| 713 |
+
past_key_values=outputs.past_key_values,
|
| 714 |
+
hidden_states=outputs.hidden_states,
|
| 715 |
+
attentions=outputs.attentions,
|
| 716 |
+
)
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
# Example of custom model extensions you can create:
|
| 720 |
+
|
| 721 |
+
class OlmoEWithAdaptersMLP(OlmoMLP):
|
| 722 |
+
"""An extended MLP with adapters for parameter-efficient fine-tuning"""
|
| 723 |
+
|
| 724 |
+
def __init__(self, config):
|
| 725 |
+
super().__init__(config)
|
| 726 |
+
# Example adapter dimensions (typically much smaller than original dims)
|
| 727 |
+
adapter_size = getattr(config, "adapter_size", 64)
|
| 728 |
+
|
| 729 |
+
# Add adapter layers
|
| 730 |
+
self.down_adapter = nn.Sequential(
|
| 731 |
+
nn.Linear(self.hidden_size, adapter_size, bias=False),
|
| 732 |
+
nn.ReLU(),
|
| 733 |
+
nn.Linear(adapter_size, self.hidden_size, bias=False),
|
| 734 |
+
)
|
| 735 |
+
|
| 736 |
+
# Initialize adapter layers with small weights
|
| 737 |
+
self.down_adapter[0].weight.data.normal_(mean=0.0, std=0.01)
|
| 738 |
+
self.down_adapter[2].weight.data.normal_(mean=0.0, std=0.01)
|
| 739 |
+
|
| 740 |
+
def forward(self, x):
|
| 741 |
+
# Original MLP computation
|
| 742 |
+
mlp_output = super().forward(x)
|
| 743 |
+
|
| 744 |
+
# Add adapter path with residual connection
|
| 745 |
+
adapter_output = self.down_adapter(x)
|
| 746 |
+
return mlp_output + adapter_output
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
class OlmoEWithAdaptersDecoderLayer(OlmoDecoderLayer):
|
| 750 |
+
"""OLMo decoder layer with adapters for efficient fine-tuning"""
|
| 751 |
+
|
| 752 |
+
def __init__(self, config, layer_idx):
|
| 753 |
+
# Replace the standard MLP with an adapter-based MLP
|
| 754 |
+
super().__init__(config, layer_idx)
|
| 755 |
+
self.mlp = OlmoEWithAdaptersMLP(config)
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
class OlmoEWithAdaptersModel(OlmoEModel):
|
| 759 |
+
"""OLMo model with adapter layers"""
|
| 760 |
+
|
| 761 |
+
def __init__(self, config):
|
| 762 |
+
super().__init__(config)
|
| 763 |
+
# Replace all layers with adapter-based layers
|
| 764 |
+
self.layers = nn.ModuleList(
|
| 765 |
+
[OlmoEWithAdaptersDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 766 |
+
)
|
| 767 |
+
|
| 768 |
+
# Initialize weights
|
| 769 |
+
self.post_init()
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
class OlmoEWithAdaptersForCausalLM(OlmoEForCausalLM):
|
| 773 |
+
"""OLMo for causal language modeling with adapters"""
|
| 774 |
+
|
| 775 |
+
def __init__(self, config, adapters_config: Optional[Dict[str, Any]] = None):
|
| 776 |
+
super().__init__(config)
|
| 777 |
+
self.adapters_config = adapters_config
|
| 778 |
+
|
| 779 |
+
# Initialize the model with adapters using the config
|
| 780 |
+
self.model = OlmoEWithAdaptersModel(config)
|
| 781 |
+
|
| 782 |
+
# Initialize weights
|
| 783 |
+
self.post_init()
|
| 784 |
+
|
| 785 |
+
def freeze_base_model(self):
|
| 786 |
+
"""Freeze all parameters except adapters for efficient fine-tuning"""
|
| 787 |
+
for param in self.model.embed_tokens.parameters():
|
| 788 |
+
param.requires_grad = False
|
| 789 |
+
|
| 790 |
+
for layer in self.model.layers:
|
| 791 |
+
for name, param in layer.self_attn.named_parameters():
|
| 792 |
+
param.requires_grad = False
|
| 793 |
+
|
| 794 |
+
for name, param in layer.mlp.named_parameters():
|
| 795 |
+
if "down_adapter" not in name:
|
| 796 |
+
param.requires_grad = False
|
| 797 |
+
|
| 798 |
+
for param in layer.input_layernorm.parameters():
|
| 799 |
+
param.requires_grad = False
|
| 800 |
+
for param in layer.post_attention_layernorm.parameters():
|
| 801 |
+
param.requires_grad = False
|
| 802 |
+
|
| 803 |
+
for param in self.model.norm.parameters():
|
| 804 |
+
param.requires_grad = False
|
| 805 |
+
|
| 806 |
+
# Uncomment to freeze LM head
|
| 807 |
+
# for param in self.lm_head.parameters():
|
| 808 |
+
# param.requires_grad = False
|
| 809 |
+
|
| 810 |
+
def get_trainable_parameters(self):
|
| 811 |
+
"""Return only trainable parameters for optimizer"""
|
| 812 |
+
return [p for p in self.parameters() if p.requires_grad]
|
| 813 |
+
|
| 814 |
+
@classmethod
|
| 815 |
+
def from_config_and_adapters(
|
| 816 |
+
cls,
|
| 817 |
+
config,
|
| 818 |
+
adapters_config: Optional[Dict[str, Any]] = None,
|
| 819 |
+
) -> "OlmoEWithAdaptersForCausalLM":
|
| 820 |
+
"""Optional factory method, if you want to keep this pattern."""
|
| 821 |
+
return cls(config=config, adapters_config=adapters_config)
|
| 822 |
+
|
oldcmds.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export CUDA_HOME=$(dirname $(dirname $(which nvcc)))
|
| 2 |
+
export PATH=$CUDA_HOME/bin:$PATH
|
| 3 |
+
export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
|
output.txt
ADDED
|
File without changes
|
randommoe.py
ADDED
|
@@ -0,0 +1,1047 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from ...activations import ACT2FN
|
| 8 |
+
from ...cache_utils import Cache, DynamicCache
|
| 9 |
+
from ...generation import GenerationMixin
|
| 10 |
+
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
| 11 |
+
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
| 12 |
+
from ...modeling_layers import GradientCheckpointingLayer
|
| 13 |
+
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 14 |
+
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 15 |
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 16 |
+
from ...processing_utils import Unpack
|
| 17 |
+
from ...utils import (
|
| 18 |
+
LossKwargs,
|
| 19 |
+
add_start_docstrings,
|
| 20 |
+
add_start_docstrings_to_model_forward,
|
| 21 |
+
can_return_tuple,
|
| 22 |
+
is_torch_flex_attn_available,
|
| 23 |
+
logging,
|
| 24 |
+
replace_return_docstrings,
|
| 25 |
+
)
|
| 26 |
+
from .configuration_olmo import OlmoConfig
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
if is_torch_flex_attn_available():
|
| 30 |
+
from torch.nn.attention.flex_attention import BlockMask
|
| 31 |
+
|
| 32 |
+
from ...integrations.flex_attention import make_flex_block_causal_mask
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
logger = logging.get_logger(__name__)
|
| 36 |
+
_CONFIG_FOR_DOC = "OlmoConfig"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class OlmoLayerNorm(nn.Module):
|
| 40 |
+
"""LayerNorm but with no learnable weight or bias."""
|
| 41 |
+
|
| 42 |
+
def __init__(self, hidden_size: int) -> None:
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.normalized_shape = (hidden_size,)
|
| 45 |
+
|
| 46 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 47 |
+
orig_dtype = hidden_states.dtype
|
| 48 |
+
return F.layer_norm(hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=1e-5).to(
|
| 49 |
+
orig_dtype
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class OlmoMLP(nn.Module):
|
| 54 |
+
def __init__(self, config):
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.config = config
|
| 57 |
+
self.hidden_size = config.hidden_size
|
| 58 |
+
self.intermediate_size = config.intermediate_size
|
| 59 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 60 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 61 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 62 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 63 |
+
|
| 64 |
+
def forward(self, x):
|
| 65 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 66 |
+
return down_proj
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def rotate_half(x):
|
| 70 |
+
"""Rotates half the hidden dims of the input."""
|
| 71 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 72 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 73 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 77 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
q (`torch.Tensor`): The query tensor.
|
| 81 |
+
k (`torch.Tensor`): The key tensor.
|
| 82 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 83 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 84 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 85 |
+
Deprecated and unused.
|
| 86 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 87 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 88 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 89 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 90 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 91 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 92 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 93 |
+
Returns:
|
| 94 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 95 |
+
"""
|
| 96 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 97 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 98 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 99 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 100 |
+
return q_embed, k_embed
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 104 |
+
"""
|
| 105 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 106 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 107 |
+
"""
|
| 108 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 109 |
+
if n_rep == 1:
|
| 110 |
+
return hidden_states
|
| 111 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 112 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def eager_attention_forward(
|
| 116 |
+
module: nn.Module,
|
| 117 |
+
query: torch.Tensor,
|
| 118 |
+
key: torch.Tensor,
|
| 119 |
+
value: torch.Tensor,
|
| 120 |
+
attention_mask: Optional[torch.Tensor],
|
| 121 |
+
scaling: float,
|
| 122 |
+
dropout: float = 0.0,
|
| 123 |
+
**kwargs,
|
| 124 |
+
):
|
| 125 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 126 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 127 |
+
|
| 128 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 129 |
+
if attention_mask is not None:
|
| 130 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 131 |
+
attn_weights = attn_weights + causal_mask
|
| 132 |
+
|
| 133 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 134 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 135 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 136 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 137 |
+
|
| 138 |
+
return attn_output, attn_weights
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class OlmoAttention(nn.Module):
|
| 142 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 143 |
+
|
| 144 |
+
def __init__(self, config: OlmoConfig, layer_idx: int):
|
| 145 |
+
super().__init__()
|
| 146 |
+
self.config = config
|
| 147 |
+
self.layer_idx = layer_idx
|
| 148 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 149 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 150 |
+
self.scaling = self.head_dim**-0.5
|
| 151 |
+
self.attention_dropout = config.attention_dropout
|
| 152 |
+
self.is_causal = True
|
| 153 |
+
|
| 154 |
+
self.q_proj = nn.Linear(
|
| 155 |
+
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
| 156 |
+
)
|
| 157 |
+
self.k_proj = nn.Linear(
|
| 158 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 159 |
+
)
|
| 160 |
+
self.v_proj = nn.Linear(
|
| 161 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 162 |
+
)
|
| 163 |
+
self.o_proj = nn.Linear(
|
| 164 |
+
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
def forward(
|
| 168 |
+
self,
|
| 169 |
+
hidden_states: torch.Tensor,
|
| 170 |
+
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
| 171 |
+
attention_mask: Optional[torch.Tensor],
|
| 172 |
+
past_key_value: Optional[Cache] = None,
|
| 173 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 174 |
+
**kwargs,
|
| 175 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 176 |
+
input_shape = hidden_states.shape[:-1]
|
| 177 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 178 |
+
|
| 179 |
+
query_states = self.q_proj(hidden_states)
|
| 180 |
+
key_states = self.k_proj(hidden_states)
|
| 181 |
+
value_states = self.v_proj(hidden_states)
|
| 182 |
+
|
| 183 |
+
if self.config.clip_qkv is not None:
|
| 184 |
+
query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 185 |
+
key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 186 |
+
value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 187 |
+
|
| 188 |
+
query_states = query_states.view(hidden_shape).transpose(1, 2)
|
| 189 |
+
key_states = key_states.view(hidden_shape).transpose(1, 2)
|
| 190 |
+
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
| 191 |
+
|
| 192 |
+
cos, sin = position_embeddings
|
| 193 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 194 |
+
|
| 195 |
+
if past_key_value is not None:
|
| 196 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 197 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 198 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 199 |
+
|
| 200 |
+
attention_interface: Callable = eager_attention_forward
|
| 201 |
+
if self.config._attn_implementation != "eager":
|
| 202 |
+
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
| 203 |
+
logger.warning_once(
|
| 204 |
+
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
| 205 |
+
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 206 |
+
)
|
| 207 |
+
else:
|
| 208 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 209 |
+
|
| 210 |
+
attn_output, attn_weights = attention_interface(
|
| 211 |
+
self,
|
| 212 |
+
query_states,
|
| 213 |
+
key_states,
|
| 214 |
+
value_states,
|
| 215 |
+
attention_mask,
|
| 216 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 217 |
+
scaling=self.scaling,
|
| 218 |
+
**kwargs,
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 222 |
+
attn_output = self.o_proj(attn_output)
|
| 223 |
+
return attn_output, attn_weights
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class OlmoDecoderLayer(GradientCheckpointingLayer):
|
| 227 |
+
def __init__(self, config: OlmoConfig, layer_idx: int):
|
| 228 |
+
super().__init__()
|
| 229 |
+
self.hidden_size = config.hidden_size
|
| 230 |
+
self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx)
|
| 231 |
+
|
| 232 |
+
self.mlp = OlmoMLP(config)
|
| 233 |
+
self.input_layernorm = OlmoLayerNorm(config.hidden_size)
|
| 234 |
+
self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
|
| 235 |
+
|
| 236 |
+
def forward(
|
| 237 |
+
self,
|
| 238 |
+
hidden_states: torch.Tensor,
|
| 239 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 240 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 241 |
+
past_key_value: Optional[Cache] = None,
|
| 242 |
+
output_attentions: Optional[bool] = False,
|
| 243 |
+
use_cache: Optional[bool] = False,
|
| 244 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 245 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
| 246 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 247 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 248 |
+
residual = hidden_states
|
| 249 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 250 |
+
|
| 251 |
+
# Self Attention
|
| 252 |
+
hidden_states, self_attn_weights = self.self_attn(
|
| 253 |
+
hidden_states=hidden_states,
|
| 254 |
+
attention_mask=attention_mask,
|
| 255 |
+
position_ids=position_ids,
|
| 256 |
+
past_key_value=past_key_value,
|
| 257 |
+
output_attentions=output_attentions,
|
| 258 |
+
use_cache=use_cache,
|
| 259 |
+
cache_position=cache_position,
|
| 260 |
+
position_embeddings=position_embeddings,
|
| 261 |
+
**kwargs,
|
| 262 |
+
)
|
| 263 |
+
hidden_states = residual + hidden_states
|
| 264 |
+
|
| 265 |
+
# Fully Connected
|
| 266 |
+
residual = hidden_states
|
| 267 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 268 |
+
hidden_states = self.mlp(hidden_states)
|
| 269 |
+
hidden_states = residual + hidden_states
|
| 270 |
+
|
| 271 |
+
outputs = (hidden_states,)
|
| 272 |
+
if output_attentions:
|
| 273 |
+
outputs += (self_attn_weights,)
|
| 274 |
+
|
| 275 |
+
return outputs
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class OlmoRotaryEmbedding(nn.Module):
|
| 279 |
+
def __init__(self, config: OlmoConfig, device=None):
|
| 280 |
+
super().__init__()
|
| 281 |
+
# BC: "rope_type" was originally "type"
|
| 282 |
+
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
| 283 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 284 |
+
else:
|
| 285 |
+
self.rope_type = "default"
|
| 286 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 287 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 288 |
+
|
| 289 |
+
self.config = config
|
| 290 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 291 |
+
|
| 292 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 293 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 294 |
+
self.original_inv_freq = self.inv_freq
|
| 295 |
+
|
| 296 |
+
@torch.no_grad()
|
| 297 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 298 |
+
def forward(self, x, position_ids):
|
| 299 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 300 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 301 |
+
|
| 302 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 303 |
+
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 304 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 305 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 306 |
+
cos = emb.cos() * self.attention_scaling
|
| 307 |
+
sin = emb.sin() * self.attention_scaling
|
| 308 |
+
|
| 309 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
OLMO_START_DOCSTRING = r"""
|
| 313 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 314 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 315 |
+
etc.)
|
| 316 |
+
|
| 317 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 318 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 319 |
+
and behavior.
|
| 320 |
+
|
| 321 |
+
Parameters:
|
| 322 |
+
config ([`OlmoConfig`]):
|
| 323 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 324 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 325 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 326 |
+
"""
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
@add_start_docstrings(
|
| 330 |
+
"The bare Olmo Model outputting raw hidden-states without any specific head on top.",
|
| 331 |
+
OLMO_START_DOCSTRING,
|
| 332 |
+
)
|
| 333 |
+
class OlmoPreTrainedModel(PreTrainedModel):
|
| 334 |
+
config_class = OlmoConfig
|
| 335 |
+
base_model_prefix = "model"
|
| 336 |
+
supports_gradient_checkpointing = True
|
| 337 |
+
_no_split_modules = ["OlmoDecoderLayer"]
|
| 338 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 339 |
+
_supports_flash_attn_2 = True
|
| 340 |
+
_supports_sdpa = True
|
| 341 |
+
_supports_flex_attn = True
|
| 342 |
+
_supports_cache_class = True
|
| 343 |
+
_supports_quantized_cache = True
|
| 344 |
+
_supports_static_cache = True
|
| 345 |
+
_supports_attention_backend = True
|
| 346 |
+
|
| 347 |
+
def _init_weights(self, module):
|
| 348 |
+
std = self.config.initializer_range
|
| 349 |
+
if isinstance(module, nn.Linear):
|
| 350 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 351 |
+
if module.bias is not None:
|
| 352 |
+
module.bias.data.zero_()
|
| 353 |
+
elif isinstance(module, nn.Embedding):
|
| 354 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 355 |
+
if module.padding_idx is not None:
|
| 356 |
+
module.weight.data[module.padding_idx].zero_()
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
OLMO_INPUTS_DOCSTRING = r"""
|
| 360 |
+
Args:
|
| 361 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 362 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 363 |
+
it.
|
| 364 |
+
|
| 365 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 366 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 367 |
+
|
| 368 |
+
[What are input IDs?](../glossary#input-ids)
|
| 369 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*):
|
| 370 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 371 |
+
|
| 372 |
+
- 1 for tokens that are **not masked**,
|
| 373 |
+
- 0 for tokens that are **masked**.
|
| 374 |
+
|
| 375 |
+
If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask,
|
| 376 |
+
but you can also pass a `BlockMask` object directly here.
|
| 377 |
+
|
| 378 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 379 |
+
|
| 380 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 381 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 382 |
+
|
| 383 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
| 384 |
+
`past_key_values`).
|
| 385 |
+
|
| 386 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
| 387 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
| 388 |
+
information on the default strategy.
|
| 389 |
+
|
| 390 |
+
- 1 indicates the head is **not masked**,
|
| 391 |
+
- 0 indicates the head is **masked**.
|
| 392 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 393 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 394 |
+
config.n_positions - 1]`.
|
| 395 |
+
|
| 396 |
+
[What are position IDs?](../glossary#position-ids)
|
| 397 |
+
past_key_values (`Cache`, *optional*):
|
| 398 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 399 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
| 400 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
| 401 |
+
|
| 402 |
+
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
| 403 |
+
|
| 404 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
| 405 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
| 406 |
+
of shape `(batch_size, sequence_length)`.
|
| 407 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 408 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 409 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 410 |
+
model's internal embedding lookup matrix.
|
| 411 |
+
use_cache (`bool`, *optional*):
|
| 412 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 413 |
+
`past_key_values`).
|
| 414 |
+
output_attentions (`bool`, *optional*):
|
| 415 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 416 |
+
tensors for more detail.
|
| 417 |
+
output_hidden_states (`bool`, *optional*):
|
| 418 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 419 |
+
more detail.
|
| 420 |
+
return_dict (`bool`, *optional*):
|
| 421 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 422 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 423 |
+
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
| 424 |
+
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
| 425 |
+
the complete sequence length.
|
| 426 |
+
"""
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
@add_start_docstrings(
|
| 430 |
+
"The bare Olmo Model outputting raw hidden-states without any specific head on top.",
|
| 431 |
+
OLMO_START_DOCSTRING,
|
| 432 |
+
)
|
| 433 |
+
class OlmoModel(OlmoPreTrainedModel):
|
| 434 |
+
"""
|
| 435 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OlmoDecoderLayer`]
|
| 436 |
+
olmo's mapping in https://github.com/huggingface/transformers/blob/main/src/transformers/models/auto/modeling_auto.py
|
| 437 |
+
|
| 438 |
+
Args:
|
| 439 |
+
config: OlmoConfig
|
| 440 |
+
"""
|
| 441 |
+
|
| 442 |
+
def __init__(self, config: OlmoConfig):
|
| 443 |
+
super().__init__(config)
|
| 444 |
+
self.padding_idx = config.pad_token_id
|
| 445 |
+
self.vocab_size = config.vocab_size
|
| 446 |
+
|
| 447 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 448 |
+
self.layers = nn.ModuleList(
|
| 449 |
+
[OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 450 |
+
)
|
| 451 |
+
self.norm = OlmoLayerNorm(config.hidden_size)
|
| 452 |
+
self.rotary_emb = OlmoRotaryEmbedding(config=config)
|
| 453 |
+
self.gradient_checkpointing = False
|
| 454 |
+
|
| 455 |
+
# Initialize weights and apply final processing
|
| 456 |
+
self.post_init()
|
| 457 |
+
|
| 458 |
+
def get_input_embeddings(self):
|
| 459 |
+
return self.embed_tokens
|
| 460 |
+
|
| 461 |
+
def set_input_embeddings(self, value):
|
| 462 |
+
self.embed_tokens = value
|
| 463 |
+
|
| 464 |
+
@can_return_tuple
|
| 465 |
+
@add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING)
|
| 466 |
+
def forward(
|
| 467 |
+
self,
|
| 468 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 469 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 470 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 471 |
+
past_key_values: Optional[Cache] = None,
|
| 472 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 473 |
+
use_cache: Optional[bool] = None,
|
| 474 |
+
output_attentions: Optional[bool] = None,
|
| 475 |
+
output_hidden_states: Optional[bool] = None,
|
| 476 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 477 |
+
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
| 478 |
+
) -> BaseModelOutputWithPast:
|
| 479 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 480 |
+
output_hidden_states = (
|
| 481 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 482 |
+
)
|
| 483 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 484 |
+
|
| 485 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 486 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 487 |
+
|
| 488 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
| 489 |
+
logger.warning_once(
|
| 490 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
| 491 |
+
)
|
| 492 |
+
use_cache = False
|
| 493 |
+
|
| 494 |
+
# TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
|
| 495 |
+
if not isinstance(past_key_values, (type(None), Cache)):
|
| 496 |
+
raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
|
| 497 |
+
|
| 498 |
+
if inputs_embeds is None:
|
| 499 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 500 |
+
|
| 501 |
+
if use_cache and past_key_values is None:
|
| 502 |
+
past_key_values = DynamicCache()
|
| 503 |
+
|
| 504 |
+
if cache_position is None:
|
| 505 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 506 |
+
cache_position = torch.arange(
|
| 507 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
if position_ids is None:
|
| 511 |
+
position_ids = cache_position.unsqueeze(0)
|
| 512 |
+
|
| 513 |
+
causal_mask = self._update_causal_mask(
|
| 514 |
+
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
hidden_states = inputs_embeds
|
| 518 |
+
|
| 519 |
+
# create position embeddings to be shared across the decoder layers
|
| 520 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 521 |
+
|
| 522 |
+
# decoder layers
|
| 523 |
+
all_hidden_states = () if output_hidden_states else None
|
| 524 |
+
all_self_attns = () if output_attentions else None
|
| 525 |
+
|
| 526 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 527 |
+
if output_hidden_states:
|
| 528 |
+
all_hidden_states += (hidden_states,)
|
| 529 |
+
|
| 530 |
+
layer_outputs = decoder_layer(
|
| 531 |
+
hidden_states,
|
| 532 |
+
attention_mask=causal_mask,
|
| 533 |
+
position_ids=position_ids,
|
| 534 |
+
past_key_value=past_key_values,
|
| 535 |
+
output_attentions=output_attentions,
|
| 536 |
+
use_cache=use_cache,
|
| 537 |
+
cache_position=cache_position,
|
| 538 |
+
position_embeddings=position_embeddings,
|
| 539 |
+
**flash_attn_kwargs,
|
| 540 |
+
)
|
| 541 |
+
|
| 542 |
+
hidden_states = layer_outputs[0]
|
| 543 |
+
|
| 544 |
+
if output_attentions:
|
| 545 |
+
all_self_attns += (layer_outputs[1],)
|
| 546 |
+
|
| 547 |
+
hidden_states = self.norm(hidden_states)
|
| 548 |
+
|
| 549 |
+
# add hidden states from the last decoder layer
|
| 550 |
+
if output_hidden_states:
|
| 551 |
+
all_hidden_states += (hidden_states,)
|
| 552 |
+
|
| 553 |
+
return BaseModelOutputWithPast(
|
| 554 |
+
last_hidden_state=hidden_states,
|
| 555 |
+
past_key_values=past_key_values if use_cache else None,
|
| 556 |
+
hidden_states=all_hidden_states,
|
| 557 |
+
attentions=all_self_attns,
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
def _update_causal_mask(
|
| 561 |
+
self,
|
| 562 |
+
attention_mask: Union[torch.Tensor, "BlockMask"],
|
| 563 |
+
input_tensor: torch.Tensor,
|
| 564 |
+
cache_position: torch.Tensor,
|
| 565 |
+
past_key_values: Cache,
|
| 566 |
+
output_attentions: bool = False,
|
| 567 |
+
):
|
| 568 |
+
if self.config._attn_implementation == "flash_attention_2":
|
| 569 |
+
if attention_mask is not None and (attention_mask == 0.0).any():
|
| 570 |
+
return attention_mask
|
| 571 |
+
return None
|
| 572 |
+
if self.config._attn_implementation == "flex_attention":
|
| 573 |
+
if isinstance(attention_mask, torch.Tensor):
|
| 574 |
+
attention_mask = make_flex_block_causal_mask(attention_mask)
|
| 575 |
+
return attention_mask
|
| 576 |
+
|
| 577 |
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
| 578 |
+
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
| 579 |
+
# to infer the attention mask.
|
| 580 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 581 |
+
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
| 582 |
+
|
| 583 |
+
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
| 584 |
+
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
| 585 |
+
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
| 586 |
+
attention_mask,
|
| 587 |
+
inputs_embeds=input_tensor,
|
| 588 |
+
past_key_values_length=past_seen_tokens,
|
| 589 |
+
is_training=self.training,
|
| 590 |
+
):
|
| 591 |
+
return None
|
| 592 |
+
|
| 593 |
+
dtype = input_tensor.dtype
|
| 594 |
+
sequence_length = input_tensor.shape[1]
|
| 595 |
+
if using_compilable_cache:
|
| 596 |
+
target_length = past_key_values.get_max_cache_shape()
|
| 597 |
+
else:
|
| 598 |
+
target_length = (
|
| 599 |
+
attention_mask.shape[-1]
|
| 600 |
+
if isinstance(attention_mask, torch.Tensor)
|
| 601 |
+
else past_seen_tokens + sequence_length + 1
|
| 602 |
+
)
|
| 603 |
+
|
| 604 |
+
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
| 605 |
+
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
| 606 |
+
attention_mask,
|
| 607 |
+
sequence_length=sequence_length,
|
| 608 |
+
target_length=target_length,
|
| 609 |
+
dtype=dtype,
|
| 610 |
+
cache_position=cache_position,
|
| 611 |
+
batch_size=input_tensor.shape[0],
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
if (
|
| 615 |
+
self.config._attn_implementation == "sdpa"
|
| 616 |
+
and attention_mask is not None
|
| 617 |
+
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
| 618 |
+
and not output_attentions
|
| 619 |
+
):
|
| 620 |
+
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
| 621 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 622 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
| 623 |
+
min_dtype = torch.finfo(dtype).min
|
| 624 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
| 625 |
+
|
| 626 |
+
return causal_mask
|
| 627 |
+
|
| 628 |
+
@staticmethod
|
| 629 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
| 630 |
+
attention_mask: torch.Tensor,
|
| 631 |
+
sequence_length: int,
|
| 632 |
+
target_length: int,
|
| 633 |
+
dtype: torch.dtype,
|
| 634 |
+
cache_position: torch.Tensor,
|
| 635 |
+
batch_size: int,
|
| 636 |
+
**kwargs,
|
| 637 |
+
):
|
| 638 |
+
"""
|
| 639 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 640 |
+
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
| 641 |
+
|
| 642 |
+
Args:
|
| 643 |
+
attention_mask (`torch.Tensor`):
|
| 644 |
+
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
| 645 |
+
`(batch_size, 1, query_length, key_value_length)`.
|
| 646 |
+
sequence_length (`int`):
|
| 647 |
+
The sequence length being processed.
|
| 648 |
+
target_length (`int`):
|
| 649 |
+
The target length: when generating with static cache, the mask should be as long as the static cache,
|
| 650 |
+
to account for the 0 padding, the part of the cache that is not filled yet.
|
| 651 |
+
dtype (`torch.dtype`):
|
| 652 |
+
The dtype to use for the 4D attention mask.
|
| 653 |
+
cache_position (`torch.Tensor`):
|
| 654 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 655 |
+
batch_size (`torch.Tensor`):
|
| 656 |
+
Batch size.
|
| 657 |
+
"""
|
| 658 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 659 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
| 660 |
+
causal_mask = attention_mask
|
| 661 |
+
else:
|
| 662 |
+
min_dtype = torch.finfo(dtype).min
|
| 663 |
+
causal_mask = torch.full(
|
| 664 |
+
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
| 665 |
+
)
|
| 666 |
+
if sequence_length != 1:
|
| 667 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
| 668 |
+
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
| 669 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
| 670 |
+
if attention_mask is not None:
|
| 671 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
| 672 |
+
mask_length = attention_mask.shape[-1]
|
| 673 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
| 674 |
+
causal_mask.device
|
| 675 |
+
)
|
| 676 |
+
padding_mask = padding_mask == 0
|
| 677 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
| 678 |
+
padding_mask, min_dtype
|
| 679 |
+
)
|
| 680 |
+
|
| 681 |
+
return causal_mask
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
|
| 688 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 689 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 690 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 691 |
+
|
| 692 |
+
def __init__(self, config):
|
| 693 |
+
super().__init__(config)
|
| 694 |
+
self.model = OlmoModel(config)
|
| 695 |
+
self.vocab_size = config.vocab_size
|
| 696 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 697 |
+
|
| 698 |
+
# Initialize weights and apply final processing
|
| 699 |
+
self.post_init()
|
| 700 |
+
|
| 701 |
+
def get_input_embeddings(self):
|
| 702 |
+
return self.model.embed_tokens
|
| 703 |
+
|
| 704 |
+
def set_input_embeddings(self, value):
|
| 705 |
+
self.model.embed_tokens = value
|
| 706 |
+
|
| 707 |
+
def get_output_embeddings(self):
|
| 708 |
+
return self.lm_head
|
| 709 |
+
|
| 710 |
+
def set_output_embeddings(self, new_embeddings):
|
| 711 |
+
self.lm_head = new_embeddings
|
| 712 |
+
|
| 713 |
+
def set_decoder(self, decoder):
|
| 714 |
+
self.model = decoder
|
| 715 |
+
|
| 716 |
+
def get_decoder(self):
|
| 717 |
+
return self.model
|
| 718 |
+
|
| 719 |
+
@can_return_tuple
|
| 720 |
+
@add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING)
|
| 721 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
| 722 |
+
def forward(
|
| 723 |
+
self,
|
| 724 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 725 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 726 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 727 |
+
past_key_values: Optional[Cache] = None,
|
| 728 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 729 |
+
labels: Optional[torch.LongTensor] = None,
|
| 730 |
+
use_cache: Optional[bool] = None,
|
| 731 |
+
output_attentions: Optional[bool] = None,
|
| 732 |
+
output_hidden_states: Optional[bool] = None,
|
| 733 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 734 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 735 |
+
**kwargs: Unpack[KwargsForCausalLM],
|
| 736 |
+
) -> CausalLMOutputWithPast:
|
| 737 |
+
r"""
|
| 738 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 739 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 740 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 741 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 742 |
+
|
| 743 |
+
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
| 744 |
+
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
| 745 |
+
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
| 746 |
+
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
| 747 |
+
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
| 748 |
+
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
| 749 |
+
|
| 750 |
+
Returns:
|
| 751 |
+
|
| 752 |
+
Example:
|
| 753 |
+
|
| 754 |
+
```python
|
| 755 |
+
>>> from transformers import AutoTokenizer, OlmoForCausalLM
|
| 756 |
+
|
| 757 |
+
>>> model = OlmoForCausalLM.from_pretrained("meta-olmo/Olmo-2-7b-hf")
|
| 758 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("meta-olmo/Olmo-2-7b-hf")
|
| 759 |
+
|
| 760 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 761 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 762 |
+
|
| 763 |
+
>>> # Generate
|
| 764 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 765 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 766 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 767 |
+
```"""
|
| 768 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 769 |
+
output_hidden_states = (
|
| 770 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 771 |
+
)
|
| 772 |
+
|
| 773 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 774 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 775 |
+
input_ids=input_ids,
|
| 776 |
+
attention_mask=attention_mask,
|
| 777 |
+
position_ids=position_ids,
|
| 778 |
+
past_key_values=past_key_values,
|
| 779 |
+
inputs_embeds=inputs_embeds,
|
| 780 |
+
use_cache=use_cache,
|
| 781 |
+
output_attentions=output_attentions,
|
| 782 |
+
output_hidden_states=output_hidden_states,
|
| 783 |
+
cache_position=cache_position,
|
| 784 |
+
**kwargs,
|
| 785 |
+
)
|
| 786 |
+
|
| 787 |
+
hidden_states = outputs.last_hidden_state
|
| 788 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 789 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 790 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 791 |
+
|
| 792 |
+
loss = None
|
| 793 |
+
if labels is not None:
|
| 794 |
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
| 795 |
+
|
| 796 |
+
return CausalLMOutputWithPast(
|
| 797 |
+
loss=loss,
|
| 798 |
+
logits=logits,
|
| 799 |
+
past_key_values=outputs.past_key_values,
|
| 800 |
+
hidden_states=outputs.hidden_states,
|
| 801 |
+
attentions=outputs.attentions,
|
| 802 |
+
)
|
| 803 |
+
|
| 804 |
+
import torch
|
| 805 |
+
import torch.nn as nn
|
| 806 |
+
import torch.nn.functional as F
|
| 807 |
+
|
| 808 |
+
|
| 809 |
+
class OlmoMoERouter(nn.Module):
|
| 810 |
+
"""
|
| 811 |
+
Router module that uses random importance sampling instead of deterministic top-k.
|
| 812 |
+
|
| 813 |
+
This router computes logits for each expert, converts them to probabilities,
|
| 814 |
+
and then randomly samples experts based on these probabilities.
|
| 815 |
+
"""
|
| 816 |
+
def __init__(self, config):
|
| 817 |
+
super().__init__()
|
| 818 |
+
self.hidden_size = config.hidden_size
|
| 819 |
+
self.num_experts = config.num_experts
|
| 820 |
+
self.router = nn.Linear(self.hidden_size, self.num_experts, bias=False)
|
| 821 |
+
self.top_k = config.num_selected_experts
|
| 822 |
+
self.temperature = config.router_temperature if hasattr(config, "router_temperature") else 1.0
|
| 823 |
+
|
| 824 |
+
def forward(self, hidden_states):
|
| 825 |
+
"""
|
| 826 |
+
Args:
|
| 827 |
+
hidden_states: [batch_size, sequence_length, hidden_size]
|
| 828 |
+
|
| 829 |
+
Returns:
|
| 830 |
+
routing_weights: [batch_size, sequence_length, top_k]
|
| 831 |
+
routing_indices: [batch_size, sequence_length, top_k]
|
| 832 |
+
"""
|
| 833 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
| 834 |
+
|
| 835 |
+
# Compute router logits and apply temperature
|
| 836 |
+
router_logits = self.router(hidden_states) / self.temperature # [batch_size, sequence_length, num_experts]
|
| 837 |
+
|
| 838 |
+
# Convert to probabilities using softmax
|
| 839 |
+
router_probs = F.softmax(router_logits, dim=-1) # [batch_size, sequence_length, num_experts]
|
| 840 |
+
|
| 841 |
+
# For random importance sampling, we'll:
|
| 842 |
+
# 1. Add Gumbel noise to the log probabilities to induce randomness
|
| 843 |
+
# 2. Sample top-k experts using the perturbed probabilities
|
| 844 |
+
|
| 845 |
+
# Add Gumbel noise
|
| 846 |
+
gumbel_noise = -torch.log(-torch.log(torch.rand_like(router_probs) + 1e-10) + 1e-10)
|
| 847 |
+
perturbed_logits = torch.log(router_probs + 1e-10) + gumbel_noise
|
| 848 |
+
|
| 849 |
+
# Sample top-k experts based on perturbed probabilities
|
| 850 |
+
routing_weights, routing_indices = torch.topk(perturbed_logits, self.top_k, dim=-1)
|
| 851 |
+
|
| 852 |
+
# Re-normalize the selected probabilities
|
| 853 |
+
routing_weights = router_probs.gather(-1, routing_indices)
|
| 854 |
+
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
|
| 855 |
+
|
| 856 |
+
return routing_weights, routing_indices
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
class OlmoExpertMLP(nn.Module):
|
| 860 |
+
"""
|
| 861 |
+
Expert MLP module similar to OlmoMLP but used in the MoE architecture.
|
| 862 |
+
"""
|
| 863 |
+
def __init__(self, config):
|
| 864 |
+
super().__init__()
|
| 865 |
+
self.config = config
|
| 866 |
+
self.hidden_size = config.hidden_size
|
| 867 |
+
self.intermediate_size = config.intermediate_size
|
| 868 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 869 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 870 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 871 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 872 |
+
|
| 873 |
+
def forward(self, x):
|
| 874 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 875 |
+
return down_proj
|
| 876 |
+
|
| 877 |
+
|
| 878 |
+
class OlmoMixtureOfExperts(nn.Module):
|
| 879 |
+
"""
|
| 880 |
+
Mixture of Experts layer that replaces the standard MLP in OLMo.
|
| 881 |
+
"""
|
| 882 |
+
def __init__(self, config):
|
| 883 |
+
super().__init__()
|
| 884 |
+
self.config = config
|
| 885 |
+
self.num_experts = config.num_experts
|
| 886 |
+
self.num_selected_experts = config.num_selected_experts # top_k
|
| 887 |
+
|
| 888 |
+
# Create router
|
| 889 |
+
self.router = OlmoMoERouter(config)
|
| 890 |
+
|
| 891 |
+
# Create experts
|
| 892 |
+
self.experts = nn.ModuleList([OlmoExpertMLP(config) for _ in range(self.num_experts)])
|
| 893 |
+
|
| 894 |
+
# Expert capacity factor (to avoid load balancing issues)
|
| 895 |
+
self.capacity_factor = config.expert_capacity_factor if hasattr(config, "expert_capacity_factor") else 1.0
|
| 896 |
+
|
| 897 |
+
def forward(self, hidden_states):
|
| 898 |
+
"""
|
| 899 |
+
Args:
|
| 900 |
+
hidden_states: [batch_size, sequence_length, hidden_size]
|
| 901 |
+
|
| 902 |
+
Returns:
|
| 903 |
+
outputs: [batch_size, sequence_length, hidden_size]
|
| 904 |
+
"""
|
| 905 |
+
batch_size, sequence_length, hidden_size = hidden_states.shape
|
| 906 |
+
|
| 907 |
+
# Get routing weights and indices
|
| 908 |
+
routing_weights, routing_indices = self.router(hidden_states)
|
| 909 |
+
|
| 910 |
+
# Reshape tensors for processing
|
| 911 |
+
flat_hidden_states = hidden_states.reshape(-1, hidden_size) # [batch_size * sequence_length, hidden_size]
|
| 912 |
+
|
| 913 |
+
# Initialize expert outputs
|
| 914 |
+
final_output = torch.zeros_like(flat_hidden_states)
|
| 915 |
+
|
| 916 |
+
# For each expert, compute its contribution
|
| 917 |
+
for expert_idx in range(self.num_experts):
|
| 918 |
+
# Create a mask to identify which tokens use this expert
|
| 919 |
+
expert_mask = (routing_indices == expert_idx).any(dim=-1).reshape(-1)
|
| 920 |
+
|
| 921 |
+
if not expert_mask.any():
|
| 922 |
+
continue # Skip if no token routes to this expert
|
| 923 |
+
|
| 924 |
+
# Get the hidden states for tokens routed to this expert
|
| 925 |
+
expert_inputs = flat_hidden_states[expert_mask]
|
| 926 |
+
|
| 927 |
+
# Process these hidden states through the expert
|
| 928 |
+
expert_outputs = self.experts[expert_idx](expert_inputs)
|
| 929 |
+
|
| 930 |
+
# Find weights for this expert
|
| 931 |
+
expert_weights = routing_weights[routing_indices == expert_idx].reshape(-1, 1)
|
| 932 |
+
|
| 933 |
+
# Multiply outputs by the routing weights
|
| 934 |
+
weighted_outputs = expert_outputs * expert_weights
|
| 935 |
+
|
| 936 |
+
# Combine the expert outputs into the final output tensor
|
| 937 |
+
final_output[expert_mask] += weighted_outputs
|
| 938 |
+
|
| 939 |
+
# Reshape back to original dimensions
|
| 940 |
+
final_output = final_output.reshape(batch_size, sequence_length, hidden_size)
|
| 941 |
+
|
| 942 |
+
return final_output
|
| 943 |
+
|
| 944 |
+
|
| 945 |
+
# Modified OlmoDecoderLayer to use MoE instead of standard MLP
|
| 946 |
+
class OlmoMoEDecoderLayer(GradientCheckpointingLayer):
|
| 947 |
+
def __init__(self, config: OlmoConfig, layer_idx: int):
|
| 948 |
+
super().__init__()
|
| 949 |
+
self.hidden_size = config.hidden_size
|
| 950 |
+
self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx)
|
| 951 |
+
|
| 952 |
+
# Use MoE instead of standard MLP
|
| 953 |
+
self.mlp = OlmoMixtureOfExperts(config)
|
| 954 |
+
self.input_layernorm = OlmoLayerNorm(config.hidden_size)
|
| 955 |
+
self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
|
| 956 |
+
|
| 957 |
+
def forward(
|
| 958 |
+
self,
|
| 959 |
+
hidden_states: torch.Tensor,
|
| 960 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 961 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 962 |
+
past_key_value: Optional[Cache] = None,
|
| 963 |
+
output_attentions: Optional[bool] = False,
|
| 964 |
+
use_cache: Optional[bool] = False,
|
| 965 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 966 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 967 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 968 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 969 |
+
residual = hidden_states
|
| 970 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 971 |
+
|
| 972 |
+
# Self Attention
|
| 973 |
+
hidden_states, self_attn_weights = self.self_attn(
|
| 974 |
+
hidden_states=hidden_states,
|
| 975 |
+
attention_mask=attention_mask,
|
| 976 |
+
position_ids=position_ids,
|
| 977 |
+
past_key_value=past_key_value,
|
| 978 |
+
output_attentions=output_attentions,
|
| 979 |
+
use_cache=use_cache,
|
| 980 |
+
cache_position=cache_position,
|
| 981 |
+
position_embeddings=position_embeddings,
|
| 982 |
+
**kwargs,
|
| 983 |
+
)
|
| 984 |
+
hidden_states = residual + hidden_states
|
| 985 |
+
|
| 986 |
+
# MoE instead of Fully Connected
|
| 987 |
+
residual = hidden_states
|
| 988 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 989 |
+
hidden_states = self.mlp(hidden_states)
|
| 990 |
+
hidden_states = residual + hidden_states
|
| 991 |
+
|
| 992 |
+
outputs = (hidden_states,)
|
| 993 |
+
if output_attentions:
|
| 994 |
+
outputs += (self_attn_weights,)
|
| 995 |
+
|
| 996 |
+
return outputs
|
| 997 |
+
|
| 998 |
+
|
| 999 |
+
# Modified OlmoConfig to include MoE-specific parameters
|
| 1000 |
+
class OlmoMoEConfig(OlmoConfig):
|
| 1001 |
+
def __init__(
|
| 1002 |
+
self,
|
| 1003 |
+
num_experts=8,
|
| 1004 |
+
num_selected_experts=2,
|
| 1005 |
+
expert_capacity_factor=1.0,
|
| 1006 |
+
router_temperature=0.1,
|
| 1007 |
+
**kwargs
|
| 1008 |
+
):
|
| 1009 |
+
super().__init__(**kwargs)
|
| 1010 |
+
self.num_experts = num_experts
|
| 1011 |
+
self.num_selected_experts = num_selected_experts
|
| 1012 |
+
self.expert_capacity_factor = expert_capacity_factor
|
| 1013 |
+
self.router_temperature = router_temperature
|
| 1014 |
+
|
| 1015 |
+
|
| 1016 |
+
# Modified OlmoModel to use MoE decoder layers
|
| 1017 |
+
class OlmoMoEModel(OlmoModel):
|
| 1018 |
+
def __init__(self, config: OlmoMoEConfig):
|
| 1019 |
+
OlmoPreTrainedModel.__init__(self, config)
|
| 1020 |
+
self.padding_idx = config.pad_token_id
|
| 1021 |
+
self.vocab_size = config.vocab_size
|
| 1022 |
+
|
| 1023 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 1024 |
+
# Use MoE decoder layers
|
| 1025 |
+
self.layers = nn.ModuleList(
|
| 1026 |
+
[OlmoMoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 1027 |
+
)
|
| 1028 |
+
self.norm = OlmoLayerNorm(config.hidden_size)
|
| 1029 |
+
self.rotary_emb = OlmoRotaryEmbedding(config=config)
|
| 1030 |
+
self.gradient_checkpointing = False
|
| 1031 |
+
|
| 1032 |
+
# Initialize weights and apply final processing
|
| 1033 |
+
self.post_init()
|
| 1034 |
+
|
| 1035 |
+
|
| 1036 |
+
# Modified OlmoForCausalLM to use MoE model
|
| 1037 |
+
class OlmoMoEForCausalLM(OlmoForCausalLM):
|
| 1038 |
+
def __init__(self, config):
|
| 1039 |
+
OlmoPreTrainedModel.__init__(self, config)
|
| 1040 |
+
self.model = OlmoMoEModel(config)
|
| 1041 |
+
self.vocab_size = config.vocab_size
|
| 1042 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1043 |
+
|
| 1044 |
+
# Initialize weights and apply final processing
|
| 1045 |
+
self.post_init()
|
| 1046 |
+
|
| 1047 |
+
__all__ = ["OlmoForCausalLM", "OlmoModel", "OlmoPreTrainedModel"]
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
transformers>=4.34.0
|
| 3 |
+
accelerate>=0.25.0
|
| 4 |
+
datasets>=2.14.0
|
| 5 |
+
tqdm>=4.66.0
|
| 6 |
+
bitsandbytes>=0.41.0 # For 8-bit training if needed
|
| 7 |
+
sentencepiece>=0.1.99 # For tokenization
|
| 8 |
+
protobuf>=4.23.4 # For datasets loading
|
| 9 |
+
tensorboard>=2.13.0 # For training monitoring
|
shellcommands.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
conda activate rlmoe
|
| 2 |
+
cd SkipMoE
|
| 3 |
+
python train.py
|
train.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# train.py
|
| 2 |
+
# runs train_olmoe_adapter.py with parameters when called
|
| 3 |
+
# #!/usr/bin/env python
|
| 4 |
+
"""
|
| 5 |
+
Run script for fine-tuning OlmoE with adapters on specific text domains.
|
| 6 |
+
Handles argument parsing and configuration.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
from dataclasses import dataclass, field
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
from transformers import (
|
| 16 |
+
HfArgumentParser,
|
| 17 |
+
TrainingArguments,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class ScriptArguments:
|
| 23 |
+
"""
|
| 24 |
+
Arguments for the run script that aren't covered by TrainingArguments.
|
| 25 |
+
"""
|
| 26 |
+
model_path: str = field(
|
| 27 |
+
default="allenai/OLMo-7B-Instruct",
|
| 28 |
+
metadata={"help": "Path to the model to fine-tune"}
|
| 29 |
+
)
|
| 30 |
+
output_dir: str = field(
|
| 31 |
+
default="./output_olmoe_adapter",
|
| 32 |
+
metadata={"help": "Directory to save the model and logs"}
|
| 33 |
+
)
|
| 34 |
+
adapter_size: int = field(
|
| 35 |
+
default=64,
|
| 36 |
+
metadata={"help": "Size of the adapter layers"}
|
| 37 |
+
)
|
| 38 |
+
dataset_name: str = field(
|
| 39 |
+
default="mlfoundations/dclm-baseline-1.0",
|
| 40 |
+
metadata={"help": "Name of the dataset to use"}
|
| 41 |
+
)
|
| 42 |
+
max_steps: int = field(
|
| 43 |
+
default=10000,
|
| 44 |
+
metadata={"help": "Maximum number of training steps"}
|
| 45 |
+
)
|
| 46 |
+
learning_rate: float = field(
|
| 47 |
+
default=5e-5,
|
| 48 |
+
metadata={"help": "Learning rate for fine-tuning"}
|
| 49 |
+
)
|
| 50 |
+
per_device_batch_size: int = field(
|
| 51 |
+
default=8,
|
| 52 |
+
metadata={"help": "Batch size per device"}
|
| 53 |
+
)
|
| 54 |
+
gradient_accumulation_steps: int = field(
|
| 55 |
+
default=1,
|
| 56 |
+
metadata={"help": "Number of steps to accumulate gradients"}
|
| 57 |
+
)
|
| 58 |
+
# use_8bit: bool = field(
|
| 59 |
+
# default=False,
|
| 60 |
+
# metadata={"help": "Whether to use 8-bit precision"}
|
| 61 |
+
# )
|
| 62 |
+
# use_4bit: bool = field(
|
| 63 |
+
# default=False,
|
| 64 |
+
# metadata={"help": "Whether to use 4-bit precision"}
|
| 65 |
+
# )
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def main():
|
| 69 |
+
# Parse command-line arguments
|
| 70 |
+
parser = HfArgumentParser(ScriptArguments)
|
| 71 |
+
args = parser.parse_args_into_dataclasses()[0]
|
| 72 |
+
|
| 73 |
+
# Create output directory
|
| 74 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 75 |
+
|
| 76 |
+
# Prepare command for training
|
| 77 |
+
cmd = [
|
| 78 |
+
"python",
|
| 79 |
+
"train_olmoe_adapter.py",
|
| 80 |
+
|
| 81 |
+
# Model arguments
|
| 82 |
+
f"--model_name_or_path={args.model_path}",
|
| 83 |
+
f"--adapter_size={args.adapter_size}",
|
| 84 |
+
"--freeze_base_model=True", # Always freeze the base model
|
| 85 |
+
f"--checkpoint_dir={args.output_dir}",
|
| 86 |
+
|
| 87 |
+
# Data arguments
|
| 88 |
+
f"--dataset_name={args.dataset_name}",
|
| 89 |
+
"--streaming=True", # Always stream for large datasets
|
| 90 |
+
"--streaming_buffer_size=8192",
|
| 91 |
+
"--max_seq_length=1024",
|
| 92 |
+
|
| 93 |
+
# Training arguments
|
| 94 |
+
f"--output_dir={args.output_dir}",
|
| 95 |
+
f"--per_device_train_batch_size={args.per_device_batch_size}",
|
| 96 |
+
f"--gradient_accumulation_steps={args.gradient_accumulation_steps}",
|
| 97 |
+
f"--learning_rate={args.learning_rate}",
|
| 98 |
+
f"--max_steps={args.max_steps}",
|
| 99 |
+
"--warmup_steps=500",
|
| 100 |
+
"--logging_steps=10",
|
| 101 |
+
"--save_steps=1000",
|
| 102 |
+
"--save_total_limit=2",
|
| 103 |
+
"--dataloader_num_workers=4",
|
| 104 |
+
"--seed=42",
|
| 105 |
+
]
|
| 106 |
+
|
| 107 |
+
# Add precision flags if needed
|
| 108 |
+
# if args.use_8bit:
|
| 109 |
+
# cmd.append("--load_in_8bit")
|
| 110 |
+
|
| 111 |
+
# if args.use_4bit:
|
| 112 |
+
# cmd.append("--load_in_4bit")
|
| 113 |
+
|
| 114 |
+
# Print the command for logging
|
| 115 |
+
cmd_str = " ".join(cmd)
|
| 116 |
+
print(f"Running command: {cmd_str}")
|
| 117 |
+
|
| 118 |
+
# Execute the training script
|
| 119 |
+
os.environ["PYTHONPATH"] = os.getcwd()
|
| 120 |
+
ret = os.system(cmd_str)
|
| 121 |
+
|
| 122 |
+
if ret != 0:
|
| 123 |
+
print(f"Training failed with exit code {ret}")
|
| 124 |
+
sys.exit(ret)
|
| 125 |
+
|
| 126 |
+
print("Training completed successfully!")
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
if __name__ == "__main__":
|
| 130 |
+
main()
|
train_olmoe_adapter.py
ADDED
|
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#train_olmoe_adapter.py
|
| 2 |
+
#!/usr/bin/env python
|
| 3 |
+
"""
|
| 4 |
+
Training script for OlmoE model with adapters on the mlfoundations/dclm-baseline-1.0 dataset.
|
| 5 |
+
This script demonstrates parameter-efficient fine-tuning using adapters.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import math
|
| 10 |
+
import logging
|
| 11 |
+
import argparse
|
| 12 |
+
from dataclasses import dataclass, field
|
| 13 |
+
from typing import Dict, List, Optional, Tuple, Any, Union
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from torch.utils.data import DataLoader, IterableDataset
|
| 19 |
+
from torch.optim import AdamW
|
| 20 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 21 |
+
|
| 22 |
+
from datasets import load_dataset
|
| 23 |
+
from transformers import (
|
| 24 |
+
OlmoConfig,
|
| 25 |
+
OlmoForCausalLM,
|
| 26 |
+
AutoTokenizer,
|
| 27 |
+
DataCollatorForLanguageModeling,
|
| 28 |
+
HfArgumentParser,
|
| 29 |
+
TrainingArguments,
|
| 30 |
+
set_seed,
|
| 31 |
+
get_scheduler,
|
| 32 |
+
)
|
| 33 |
+
from tqdm import tqdm
|
| 34 |
+
from accelerate import Accelerator, DistributedType
|
| 35 |
+
from accelerate.utils import find_batch_size
|
| 36 |
+
|
| 37 |
+
from modeling_olmoe import (
|
| 38 |
+
OlmoEWithAdaptersForCausalLM,
|
| 39 |
+
OlmoEForCausalLM,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Set up logging
|
| 43 |
+
logger = logging.getLogger(__name__)
|
| 44 |
+
logging.basicConfig(
|
| 45 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 46 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 47 |
+
level=logging.INFO,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
@dataclass
|
| 51 |
+
class ModelArguments:
|
| 52 |
+
"""Arguments for model configuration."""
|
| 53 |
+
model_name_or_path: str = field(
|
| 54 |
+
default="allenai/OLMo-7B-Instruct",
|
| 55 |
+
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
| 56 |
+
)
|
| 57 |
+
adapter_size: int = field(
|
| 58 |
+
default=64,
|
| 59 |
+
metadata={"help": "Size of the adapter layers"}
|
| 60 |
+
)
|
| 61 |
+
freeze_base_model: bool = field(
|
| 62 |
+
default=True,
|
| 63 |
+
metadata={"help": "Whether to freeze all parameters except the adapters"}
|
| 64 |
+
)
|
| 65 |
+
checkpoint_dir: Optional[str] = field(
|
| 66 |
+
default=None,
|
| 67 |
+
metadata={"help": "Path to save model checkpoints"}
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
|
| 72 |
+
class DataArguments:
|
| 73 |
+
"""Arguments for dataset configuration."""
|
| 74 |
+
dataset_name: str = field(
|
| 75 |
+
default="mlfoundations/dclm-baseline-1.0",
|
| 76 |
+
metadata={"help": "Dataset name or path for training"}
|
| 77 |
+
)
|
| 78 |
+
streaming: bool = field(
|
| 79 |
+
default=True,
|
| 80 |
+
metadata={"help": "Whether to stream the dataset"}
|
| 81 |
+
)
|
| 82 |
+
streaming_buffer_size: int = field(
|
| 83 |
+
default=8192,
|
| 84 |
+
metadata={"help": "Buffer size for streaming dataset"}
|
| 85 |
+
)
|
| 86 |
+
max_seq_length: int = field(
|
| 87 |
+
default=1024,
|
| 88 |
+
metadata={"help": "Maximum sequence length for training"}
|
| 89 |
+
)
|
| 90 |
+
preprocessing_num_workers: Optional[int] = field(
|
| 91 |
+
default=None,
|
| 92 |
+
metadata={"help": "Number of workers for preprocessing"}
|
| 93 |
+
)
|
| 94 |
+
text_column_name: str = field(
|
| 95 |
+
default="text",
|
| 96 |
+
metadata={"help": "Column name for text data"}
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class StreamingTextDataset(IterableDataset):
|
| 101 |
+
"""Dataset for streaming text data."""
|
| 102 |
+
|
| 103 |
+
def __init__(
|
| 104 |
+
self,
|
| 105 |
+
dataset_name: str,
|
| 106 |
+
tokenizer,
|
| 107 |
+
max_seq_length: int,
|
| 108 |
+
streaming: bool = True,
|
| 109 |
+
text_column_name: str = "text",
|
| 110 |
+
buffer_size: int = 8192,
|
| 111 |
+
split: str = "train",
|
| 112 |
+
):
|
| 113 |
+
self.tokenizer = tokenizer
|
| 114 |
+
self.max_seq_length = max_seq_length
|
| 115 |
+
self.text_column_name = text_column_name
|
| 116 |
+
|
| 117 |
+
# Load dataset in streaming mode
|
| 118 |
+
self.dataset = load_dataset(
|
| 119 |
+
dataset_name,
|
| 120 |
+
split=split,
|
| 121 |
+
streaming=streaming,
|
| 122 |
+
)
|
| 123 |
+
if streaming:
|
| 124 |
+
# Buffer for streaming
|
| 125 |
+
self.dataset = self.dataset.shuffle(buffer_size=buffer_size)
|
| 126 |
+
|
| 127 |
+
def __iter__(self):
|
| 128 |
+
buffer = []
|
| 129 |
+
current_length = 0
|
| 130 |
+
|
| 131 |
+
for example in self.dataset:
|
| 132 |
+
text = example[self.text_column_name]
|
| 133 |
+
if not text or len(text.strip()) == 0:
|
| 134 |
+
continue
|
| 135 |
+
|
| 136 |
+
tokenized = self.tokenizer(
|
| 137 |
+
text,
|
| 138 |
+
truncation=False,
|
| 139 |
+
return_attention_mask=False,
|
| 140 |
+
return_token_type_ids=False,
|
| 141 |
+
add_special_tokens=False,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
ids = tokenized["input_ids"]
|
| 145 |
+
buffer.extend(ids)
|
| 146 |
+
|
| 147 |
+
# Yield complete sequences and update buffer
|
| 148 |
+
while len(buffer) >= self.max_seq_length:
|
| 149 |
+
yield {
|
| 150 |
+
"input_ids": torch.tensor(buffer[:self.max_seq_length], dtype=torch.long),
|
| 151 |
+
"labels": torch.tensor(buffer[:self.max_seq_length], dtype=torch.long),
|
| 152 |
+
}
|
| 153 |
+
buffer = buffer[self.max_seq_length:]
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def create_optimizer_and_scheduler(
|
| 157 |
+
model: nn.Module,
|
| 158 |
+
args: TrainingArguments,
|
| 159 |
+
num_training_steps: int
|
| 160 |
+
) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
|
| 161 |
+
"""Create optimizer and learning rate scheduler."""
|
| 162 |
+
|
| 163 |
+
# Get only trainable parameters if using adapters with frozen base model
|
| 164 |
+
if hasattr(model, "get_trainable_parameters"):
|
| 165 |
+
optimizer_params = model.get_trainable_parameters()
|
| 166 |
+
logger.info(f"Training with {len(optimizer_params)} trainable parameters")
|
| 167 |
+
else:
|
| 168 |
+
# No parameter filtering - get all parameters that require grad
|
| 169 |
+
optimizer_params = [p for p in model.parameters() if p.requires_grad]
|
| 170 |
+
logger.info(f"Training with {len(optimizer_params)} parameters")
|
| 171 |
+
|
| 172 |
+
# Create optimizer
|
| 173 |
+
optimizer = AdamW(
|
| 174 |
+
optimizer_params,
|
| 175 |
+
lr=args.learning_rate,
|
| 176 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 177 |
+
eps=args.adam_epsilon,
|
| 178 |
+
weight_decay=args.weight_decay,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# Create scheduler
|
| 182 |
+
scheduler = get_scheduler(
|
| 183 |
+
name=args.lr_scheduler_type,
|
| 184 |
+
optimizer=optimizer,
|
| 185 |
+
num_warmup_steps=args.warmup_steps,
|
| 186 |
+
num_training_steps=num_training_steps,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
return optimizer, scheduler
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def train(
|
| 193 |
+
model_args: ModelArguments,
|
| 194 |
+
data_args: DataArguments,
|
| 195 |
+
training_args: TrainingArguments,
|
| 196 |
+
):
|
| 197 |
+
"""Main training function."""
|
| 198 |
+
|
| 199 |
+
# Set up accelerator
|
| 200 |
+
accelerator = Accelerator(
|
| 201 |
+
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
|
| 202 |
+
mixed_precision=training_args.fp16 and "fp16" or training_args.bf16 and "bf16" or "no",
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
# Log information about the training setup
|
| 206 |
+
logger.info(accelerator.state)
|
| 207 |
+
if accelerator.is_local_main_process:
|
| 208 |
+
logger.info(f"Model arguments: {model_args}")
|
| 209 |
+
logger.info(f"Data arguments: {data_args}")
|
| 210 |
+
logger.info(f"Training arguments: {training_args}")
|
| 211 |
+
|
| 212 |
+
# Set seed for reproducibility
|
| 213 |
+
set_seed(training_args.seed)
|
| 214 |
+
|
| 215 |
+
# Load tokenizer and model
|
| 216 |
+
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
|
| 217 |
+
|
| 218 |
+
# Ensure the tokenizer has padding token and EOS token set
|
| 219 |
+
if tokenizer.pad_token is None:
|
| 220 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 221 |
+
|
| 222 |
+
# Load model config and update with adapter size
|
| 223 |
+
config = OlmoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
|
| 224 |
+
config.adapter_size = model_args.adapter_size
|
| 225 |
+
|
| 226 |
+
# Load model with adapters
|
| 227 |
+
logger.info(f"Loading OlmoE model with adapters from {model_args.model_name_or_path}")
|
| 228 |
+
base_model = OlmoForCausalLM.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
|
| 229 |
+
|
| 230 |
+
# Create adapter model from base model weights
|
| 231 |
+
model = OlmoEWithAdaptersForCausalLM(config)
|
| 232 |
+
|
| 233 |
+
# Copy weights from base model to adapter model
|
| 234 |
+
# This is needed because we're using a custom model class
|
| 235 |
+
model.load_state_dict(base_model.state_dict(), strict=False)
|
| 236 |
+
|
| 237 |
+
# Freeze base model parameters if requested
|
| 238 |
+
if model_args.freeze_base_model:
|
| 239 |
+
logger.info("Freezing base model parameters")
|
| 240 |
+
model.freeze_base_model()
|
| 241 |
+
|
| 242 |
+
# Set up streaming dataset
|
| 243 |
+
logger.info(f"Loading dataset: {data_args.dataset_name}")
|
| 244 |
+
train_dataset = StreamingTextDataset(
|
| 245 |
+
dataset_name=data_args.dataset_name,
|
| 246 |
+
tokenizer=tokenizer,
|
| 247 |
+
max_seq_length=data_args.max_seq_length,
|
| 248 |
+
streaming=data_args.streaming,
|
| 249 |
+
buffer_size=data_args.streaming_buffer_size,
|
| 250 |
+
text_column_name=data_args.text_column_name,
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
# Data collator to handle batching
|
| 254 |
+
data_collator = DataCollatorForLanguageModeling(
|
| 255 |
+
tokenizer=tokenizer,
|
| 256 |
+
mlm=False,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
# Create data loader
|
| 260 |
+
train_dataloader = DataLoader(
|
| 261 |
+
train_dataset,
|
| 262 |
+
batch_size=training_args.per_device_train_batch_size,
|
| 263 |
+
collate_fn=data_collator,
|
| 264 |
+
num_workers=data_args.preprocessing_num_workers or 0,
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
# Estimate number of update steps
|
| 268 |
+
# For streaming datasets, we'll use a fixed number of steps
|
| 269 |
+
num_update_steps_per_epoch = training_args.max_steps
|
| 270 |
+
num_training_steps = training_args.max_steps
|
| 271 |
+
|
| 272 |
+
# Create optimizer and scheduler
|
| 273 |
+
optimizer, lr_scheduler = create_optimizer_and_scheduler(
|
| 274 |
+
model=model,
|
| 275 |
+
args=training_args,
|
| 276 |
+
num_training_steps=num_training_steps,
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
# Prepare for distributed training with accelerator
|
| 280 |
+
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 281 |
+
model, optimizer, train_dataloader, lr_scheduler
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
# Get total batch size for logging
|
| 285 |
+
total_batch_size = (
|
| 286 |
+
training_args.per_device_train_batch_size
|
| 287 |
+
* accelerator.num_processes
|
| 288 |
+
* training_args.gradient_accumulation_steps
|
| 289 |
+
)
|
| 290 |
+
logger.info(f"Total batch size (with parallel & accumulation): {total_batch_size}")
|
| 291 |
+
|
| 292 |
+
# Log estimated number of steps
|
| 293 |
+
logger.info(f"Number of training steps: {num_training_steps}")
|
| 294 |
+
logger.info(f"Number of warmup steps: {training_args.warmup_steps}")
|
| 295 |
+
|
| 296 |
+
# Keep track of training progress
|
| 297 |
+
progress_bar = tqdm(
|
| 298 |
+
range(num_training_steps),
|
| 299 |
+
disable=not accelerator.is_local_main_process,
|
| 300 |
+
desc="Training",
|
| 301 |
+
)
|
| 302 |
+
completed_steps = 0
|
| 303 |
+
starting_epoch = 0
|
| 304 |
+
global_step = 0
|
| 305 |
+
|
| 306 |
+
# Training loop
|
| 307 |
+
logger.info("Starting training...")
|
| 308 |
+
model.train()
|
| 309 |
+
|
| 310 |
+
for step, batch in enumerate(train_dataloader):
|
| 311 |
+
# Skip steps for resuming
|
| 312 |
+
if starting_epoch > 0 and step < starting_epoch * num_update_steps_per_epoch:
|
| 313 |
+
progress_bar.update(1)
|
| 314 |
+
continue
|
| 315 |
+
|
| 316 |
+
with accelerator.accumulate(model):
|
| 317 |
+
# Forward pass
|
| 318 |
+
outputs = model(**batch)
|
| 319 |
+
loss = outputs.loss
|
| 320 |
+
|
| 321 |
+
# Backward pass
|
| 322 |
+
accelerator.backward(loss)
|
| 323 |
+
|
| 324 |
+
# Update weights
|
| 325 |
+
optimizer.step()
|
| 326 |
+
lr_scheduler.step()
|
| 327 |
+
optimizer.zero_grad()
|
| 328 |
+
|
| 329 |
+
# Update progress bar
|
| 330 |
+
progress_bar.update(1)
|
| 331 |
+
completed_steps += 1
|
| 332 |
+
global_step += 1
|
| 333 |
+
|
| 334 |
+
# Log metrics
|
| 335 |
+
if global_step % training_args.logging_steps == 0:
|
| 336 |
+
# Gather loss from all processes
|
| 337 |
+
loss_value = accelerator.gather(loss).mean().item()
|
| 338 |
+
logger.info(f"Step {global_step}: loss = {loss_value:.4f}, lr = {lr_scheduler.get_last_lr()[0]:.8f}")
|
| 339 |
+
|
| 340 |
+
# Log to tensorboard if available
|
| 341 |
+
if hasattr(accelerator.trackers[0], "store"):
|
| 342 |
+
accelerator.trackers[0].store({
|
| 343 |
+
"loss": loss_value,
|
| 344 |
+
"learning_rate": lr_scheduler.get_last_lr()[0],
|
| 345 |
+
"step": global_step,
|
| 346 |
+
})
|
| 347 |
+
|
| 348 |
+
# Save checkpoint
|
| 349 |
+
if training_args.save_steps > 0 and global_step % training_args.save_steps == 0:
|
| 350 |
+
if model_args.checkpoint_dir is not None:
|
| 351 |
+
output_dir = os.path.join(model_args.checkpoint_dir, f"checkpoint-{global_step}")
|
| 352 |
+
accelerator.save_state(output_dir)
|
| 353 |
+
logger.info(f"Saved checkpoint to {output_dir}")
|
| 354 |
+
|
| 355 |
+
# Save the model separately
|
| 356 |
+
if accelerator.is_main_process:
|
| 357 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
| 358 |
+
unwrapped_model.save_pretrained(
|
| 359 |
+
output_dir,
|
| 360 |
+
is_main_process=accelerator.is_main_process,
|
| 361 |
+
save_function=accelerator.save,
|
| 362 |
+
)
|
| 363 |
+
tokenizer.save_pretrained(output_dir)
|
| 364 |
+
|
| 365 |
+
# Check if we've reached max steps
|
| 366 |
+
if completed_steps >= num_training_steps:
|
| 367 |
+
break
|
| 368 |
+
|
| 369 |
+
# Save final model
|
| 370 |
+
if model_args.checkpoint_dir is not None:
|
| 371 |
+
output_dir = os.path.join(model_args.checkpoint_dir, "final-model")
|
| 372 |
+
accelerator.save_state(output_dir)
|
| 373 |
+
|
| 374 |
+
# Save the model separately
|
| 375 |
+
if accelerator.is_main_process:
|
| 376 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
| 377 |
+
unwrapped_model.save_pretrained(
|
| 378 |
+
output_dir,
|
| 379 |
+
is_main_process=accelerator.is_main_process,
|
| 380 |
+
save_function=accelerator.save,
|
| 381 |
+
)
|
| 382 |
+
tokenizer.save_pretrained(output_dir)
|
| 383 |
+
|
| 384 |
+
logger.info(f"Saved final model to {output_dir}")
|
| 385 |
+
|
| 386 |
+
logger.info("Training complete!")
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def main():
|
| 390 |
+
"""Main entry point."""
|
| 391 |
+
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
|
| 392 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
| 393 |
+
|
| 394 |
+
# Set up output directory
|
| 395 |
+
if model_args.checkpoint_dir is None:
|
| 396 |
+
model_args.checkpoint_dir = training_args.output_dir
|
| 397 |
+
os.makedirs(model_args.checkpoint_dir, exist_ok=True)
|
| 398 |
+
|
| 399 |
+
# Run training
|
| 400 |
+
train(model_args, data_args, training_args)
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
if __name__ == "__main__":
|
| 404 |
+
main()
|