chengyanwu committed
Commit ccda2ec · 1 Parent(s): e6dee89
.gitignore CHANGED
@@ -1 +1,2 @@
1
  upload.py
 
 
1
  upload.py
2
+ tester.py
README.md CHANGED
@@ -12,86 +12,88 @@ datasets:
12
  library_name: transformers
13
  ---
14
 
15
- <img alt="OLMoE Logo." src="olmoe-logo.png" width="250px">
16
-
17
 
18
  # Model Summary
19
 
20
- > OLMoE-1B-7B is a Mixture-of-Experts LLM with 1B active and 7B total parameters released in September 2024 (0924). It yields state-of-the-art performance among models with a similar cost (1B) and is competitive with much larger models like Llama2-13B. OLMoE is 100% open-source.
21
-
22
- This information and more can also be found on the [**OLMoE GitHub repository**](https://github.com/allenai/OLMoE).
23
- - **Paper**: https://arxiv.org/abs/2409.02060
24
- - **Pretraining** [Checkpoints](https://hf.co/allenai/OLMoE-1B-7B-0924), [Code](https://github.com/allenai/OLMo/tree/Muennighoff/MoE), [Data](https://huggingface.co/datasets/allenai/OLMoE-mix-0924) and [Logs](https://wandb.ai/ai2-llm/olmoe/reports/OLMoE-1B-7B-0924--Vmlldzo4OTcyMjU3).
25
- - **SFT (Supervised Fine-Tuning)** [Checkpoints](https://huggingface.co/allenai/OLMoE-1B-7B-0924-SFT), [Code](https://github.com/allenai/open-instruct/tree/olmoe-sft), [Data](https://hf.co/datasets/allenai/tulu-v3.1-mix-preview-4096-OLMoE) and [Logs](https://github.com/allenai/OLMoE/blob/main/logs/olmoe-sft-logs.txt).
26
- - **DPO/KTO (Direct Preference Optimization/Kahneman-Tversky Optimization)**, [Checkpoints](https://huggingface.co/allenai/OLMoE-1B-7B-0924-Instruct), [Preference Data](https://hf.co/datasets/allenai/ultrafeedback_binarized_cleaned), [DPO code](https://github.com/allenai/open-instruct/tree/olmoe-sft), [KTO code](https://github.com/Muennighoff/kto/blob/master/kto.py) and [Logs](https://github.com/allenai/OLMoE/blob/main/logs/olmoe-dpo-logs.txt).
27
 
28
- # Use
 
 
 
29
 
30
- Install `transformers` **from source** until a release after [this PR](https://github.com/huggingface/transformers/pull/32406) & `torch` and run:
31
 
32
  ```python
33
- from transformers import OlmoeForCausalLM, AutoTokenizer
34
- import torch
35
-
36
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
37
-
38
- # Load different ckpts via passing e.g. `revision=step10000-tokens41B`
39
- model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924").to(DEVICE)
40
- tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
41
- inputs = tokenizer("Bitcoin is", return_tensors="pt")
42
- inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
43
- out = model.generate(**inputs, max_length=64)
44
- print(tokenizer.decode(out[0]))
45
- # > # Bitcoin is a digital currency that is created and held electronically. No one controls it. Bitcoins aren’t printed, like dollars or euros – they’re produced by people and businesses running computers all around the world, using software that solves mathematical
46
- ```
47
 
48
- You can list all revisions/branches by installing `huggingface-hub` & running:
49
- ```python
50
- from huggingface_hub import list_repo_refs
51
- out = list_repo_refs("allenai/OLMoE-1B-7B-0924")
52
- branches = [b.name for b in out.branches]
 
 
 
53
  ```
54
 
55
- Important branches:
56
- - `step1200000-tokens5033B`: Pretraining checkpoint used for annealing. There are a few more checkpoints after this one but we did not use them.
57
- - `main`: Checkpoint annealed from `step1200000-tokens5033B` for an additional 100B tokens (23,842 steps). We use this checkpoint for our adaptation (https://huggingface.co/allenai/OLMoE-1B-7B-0924-SFT & https://huggingface.co/allenai/OLMoE-1B-7B-0924-Instruct).
58
- - `fp32`: FP32 version of `main`. The model weights were stored in FP32 during training but we did not observe any performance drop from casting them to BF16 after training so we upload all weights in BF16. If you want the original FP32 checkpoint for `main` you can use this one. You will find that it yields slightly different results but should perform around the same on benchmarks.
59
-
60
- # Evaluation Snapshot
61
-
62
- | Model | Active Params | Open Data | MMLU | HellaSwag | ARC-Chall. | ARC-Easy | PIQA | WinoGrande |
63
- |-----------------------------|---------------|-----------|------|-----------|------------|----------|------|------------|
64
- | **LMs with ~1B active parameters** | | | | | | | | |
65
- | **OLMoE-1B-7B** | **1.3B** | **✅** | **54.1** | **80.0** | **62.1** | **84.2** | **79.8** | **70.2** |
66
- | DCLM-1B | 1.4B | ✅ | 48.5 | 75.1 | 57.6 | 79.5 | 76.6 | 68.1 |
67
- | TinyLlama-1B | 1.1B | ✅ | 33.6 | 60.8 | 38.1 | 69.5 | 71.7 | 60.1 |
68
- | OLMo-1B (0724) | 1.3B | ✅ | 32.1 | 67.5 | 36.4 | 53.5 | 74.0 | 62.9 |
69
- | Pythia-1B | 1.1B | ✅ | 31.1 | 48.0 | 31.4 | 63.4 | 68.9 | 52.7 |
70
- | **LMs with ~2-3B active parameters** | | | | | | | | |
71
- | Qwen1.5-3B-14B | 2.7B | ❌ | **62.4** | 80.0 | **77.4** | **91.6** | **81.0** | 72.3 |
72
- | Gemma2-3B | 2.6B | ❌ | 53.3 | 74.6 | 67.5 | 84.3 | 78.5 | 71.8 |
73
- | JetMoE-2B-9B | 2.2B | ❌ | 49.1 | **81.7** | 61.4 | 81.9 | 80.3 | 70.7 |
74
- | DeepSeek-3B-16B | 2.9B | ❌ | 45.5 | 80.4 | 53.4 | 82.7 | 80.1 | **73.2** |
75
- | StableLM-2B | 1.6B | ❌ | 40.4 | 70.3 | 50.6 | 75.3 | 75.6 | 65.8 |
76
- | OpenMoE-3B-9B | 2.9B | ✅ | 27.4 | 44.4 | 29.3 | 50.6 | 63.3 | 51.9 |
77
- | **LMs with ~7-9B active parameters** | | | | | | | | |
78
- | Gemma2-9B | 9.2B | ❌ | **70.6** | **87.3** | **89.5** | **95.5** | **86.1** | **78.8** |
79
- | Llama3.1-8B | 8.0B | ❌ | 66.9 | 81.6 | 79.5 | 91.7 | 81.1 | 76.6 |
80
- | DCLM-7B | 6.9B | ✅ | 64.4 | 82.3 | 79.8 | 92.3 | 80.1 | 77.3 |
81
- | Mistral-7B | 7.3B | ❌ | 64.0 | 83.0 | 78.6 | 90.8 | 82.8 | 77.9 |
82
- | OLMo-7B (0724) | 6.9B | ✅ | 54.9 | 80.5 | 68.0 | 85.7 | 79.3 | 73.2 |
83
- | Llama2-7B | 6.7B | ❌ | 46.2 | 78.9 | 54.2 | 84.0 | 77.5 | 71.7 |
84
-
85
- # Citation
86
-
87
- ```bibtex
88
- @misc{muennighoff2024olmoeopenmixtureofexpertslanguage,
89
- title={OLMoE: Open Mixture-of-Experts Language Models},
90
- author={Niklas Muennighoff and Luca Soldaini and Dirk Groeneveld and Kyle Lo and Jacob Morrison and Sewon Min and Weijia Shi and Pete Walsh and Oyvind Tafjord and Nathan Lambert and Yuling Gu and Shane Arora and Akshita Bhagia and Dustin Schwenk and David Wadden and Alexander Wettig and Binyuan Hui and Tim Dettmers and Douwe Kiela and Ali Farhadi and Noah A. Smith and Pang Wei Koh and Amanpreet Singh and Hannaneh Hajishirzi},
91
- year={2024},
92
- eprint={2409.02060},
93
- archivePrefix={arXiv},
94
- primaryClass={cs.CL},
95
- url={https://arxiv.org/abs/2409.02060},
96
- }
97
- ```
 
12
  library_name: transformers
13
  ---
14
 
 
 
15
 
16
  # Model Summary
17
+ # OLMoE with Adapters
18
+
19
+ This repository contains an extension of the OLMo model with adapter layers for parameter-efficient fine-tuning. By adding small adapter modules to the model, we can fine-tune it on downstream tasks while freezing most of the original parameters, resulting in much more efficient training.
20
+
21
+ ## Model Architecture
22
+
23
+ The `OlmoEWithAdaptersForCausalLM` model extends the original OLMo architecture by:
24
+
25
+ 1. Adding small adapter layers (bottleneck layers) to each MLP block
26
+ 2. Allowing selective freezing of the base model's parameters
27
+ 3. Training only the adapter parameters (~0.1-1% of total parameters)
28
+
29
+ Key components:
30
+ - `OlmoEWithAdaptersMLP`: MLP layer with additional adapter modules
31
+ - `OlmoEWithAdaptersDecoderLayer`: Decoder layer incorporating adapter MLPs
32
+ - `OlmoEWithAdaptersModel`: Full model with adapter-based decoder layers
33
+ - `OlmoEWithAdaptersForCausalLM`: Causal language model with adapters
34
+
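The bottleneck adapter described above can be sketched as a small standalone module. This is an illustrative sketch, not the repository's actual `OlmoEWithAdaptersMLP`; the class name `BottleneckAdapter` and the zero-initialized up-projection are assumptions:

```python
import torch
import torch.nn as nn

class BottleneckAdapter(nn.Module):
    """Minimal bottleneck adapter: down-project, non-linearity, up-project,
    with a residual connection around the whole block."""

    def __init__(self, hidden_size: int, adapter_size: int = 64):
        super().__init__()
        self.down = nn.Linear(hidden_size, adapter_size)
        self.act = nn.GELU()
        self.up = nn.Linear(adapter_size, hidden_size)
        # Zero-init the up-projection so the adapter starts as the identity
        # and training begins from the frozen base model's behavior.
        nn.init.zeros_(self.up.weight)
        nn.init.zeros_(self.up.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.up(self.act(self.down(x)))
```

Because only `down` and `up` are trained, each adapter adds roughly `2 * hidden_size * adapter_size` parameters per MLP block.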
35
+ ## Training Script
36
+
37
+ The `train.py` script provides a complete workflow for fine-tuning the model:
38
+
39
+ ### Features:
40
+ - Parameter-efficient fine-tuning using adapters
41
+ - Support for various datasets through Hugging Face datasets library
42
+ - Customizable adapter size
43
+ - Option to freeze/unfreeze different components
44
+ - Training with AdamW optimizer and learning rate scheduling
45
+ - Evaluation with perplexity metrics
46
+ - Model checkpointing and saving
47
+
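The perplexity metric used for evaluation is just the exponential of the mean per-token negative log-likelihood (the cross-entropy loss). A minimal sketch, with an illustrative helper name:

```python
import math

def perplexity(nll_per_token):
    """Perplexity from a sequence of per-token negative log-likelihoods
    (natural log), e.g. the cross-entropy losses collected during eval."""
    return math.exp(sum(nll_per_token) / len(nll_per_token))
```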
48
+ ### Usage:
49
+
50
+ ```bash
51
+ python train.py \
52
+ --model_name_or_path allenai/OLMo-7B \
53
+ --adapter_size 64 \
54
+ --freeze_base_model True \
55
+ --dataset_name wikitext \
56
+ --dataset_config_name wikitext-2-raw-v1 \
57
+ --output_dir ./olmoe-adapter-finetuned \
58
+ --num_train_epochs 3 \
59
+ --per_device_train_batch_size 4 \
60
+ --per_device_eval_batch_size 4 \
61
+ --learning_rate 5e-5 \
62
+ --warmup_steps 100 \
63
+ --logging_steps 100 \
64
+ --save_steps 1000 \
65
+ --seed 42
66
+ ```
67
 
68
+ ## Benefits of Adapter-Based Fine-Tuning
69
 
70
+ 1. **Efficiency**: Train only ~0.1-1% of the parameters, dramatically reducing GPU memory requirements
71
+ 2. **Storage**: Store only adapter weights rather than full fine-tuned models
72
+ 3. **Composability**: Multiple adapters can be trained for different tasks and swapped at inference time
73
+ 4. **Reduced Overfitting**: Lower parameter count helps prevent overfitting on small datasets
74
 
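The efficiency point can be made concrete with a small helper that freezes everything except adapter parameters and reports the trainable fraction. This is a sketch under stated assumptions: the helper name and the convention of matching parameter names on the substring "adapter" are illustrative, not the training script's actual API:

```python
import torch.nn as nn

def freeze_base_and_count(model: nn.Module, adapter_keyword: str = "adapter"):
    """Freeze every parameter whose name does not contain `adapter_keyword`;
    return (trainable, total) parameter counts."""
    trainable = total = 0
    for name, param in model.named_parameters():
        param.requires_grad = adapter_keyword in name
        total += param.numel()
        if param.requires_grad:
            trainable += param.numel()
    return trainable, total
```

The ratio `trainable / total` is what ends up in the ~0.1-1% range for typical adapter sizes.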
75
+ ## How to Use the Fine-Tuned Model
76
 
77
  ```python
78
+ from transformers import AutoTokenizer
79
+ from modeling_olmoe import OlmoEWithAdaptersForCausalLM
80
+
81
+ # Load the fine-tuned model
82
+ model = OlmoEWithAdaptersForCausalLM.from_pretrained("./olmoe-adapter-finetuned")
83
+ tokenizer = AutoTokenizer.from_pretrained("./olmoe-adapter-finetuned")
84
+
85
+ # Generate text
86
+ inputs = tokenizer("Once upon a time", return_tensors="pt")
87
+ outputs = model.generate(**inputs, max_length=50)
88
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
89
  ```
90
 
91
+ ## Adapter Size Recommendations
92
+
93
+ The adapter size determines the parameter efficiency vs. performance trade-off:
94
+
95
+ - **Small datasets**: 16-32 dimensions
96
+ - **Medium datasets**: 64-128 dimensions
97
+ - **Large datasets**: 128-256 dimensions
98
+
99
+ For most fine-tuning scenarios, an adapter size of 64 provides a good balance between efficiency and performance.
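For reference, the parameter cost of one adapter grows linearly with the adapter size. This sketch assumes a two-linear-layer bottleneck with biases (the actual adapter layout may differ), using the hidden size of 2048 from the config:

```python
def adapter_params(hidden_size: int, adapter_size: int) -> int:
    """Parameters in one bottleneck adapter: down (h -> a) and up (a -> h)
    projections, both with biases."""
    down = hidden_size * adapter_size + adapter_size
    up = adapter_size * hidden_size + hidden_size
    return down + up

for a in (16, 64, 256):
    print(f"adapter_size={a}: {adapter_params(2048, a):,} params per adapter")
```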
__pycache__/configuration_olmoe.cpython-311.pyc ADDED
Binary file (2.35 kB). View file
 
__pycache__/modeling_kvlatent.cpython-311.pyc ADDED
Binary file (33.9 kB). View file
 
__pycache__/modeling_latent_attention.cpython-311.pyc ADDED
Binary file (9.57 kB). View file
 
__pycache__/modeling_olmoe.cpython-311.pyc ADDED
Binary file (43.7 kB). View file
 
__pycache__/random.cpython-311.pyc ADDED
Binary file (54.4 kB). View file
 
__pycache__/train.cpython-311.pyc ADDED
Binary file (13.6 kB). View file
 
config.json CHANGED
@@ -1,31 +1,27 @@
1
  {
2
  "architectures": [
3
- "OlmoeForCausalLM"
4
  ],
5
- "attention_bias": false,
6
- "attention_dropout": 0.0,
7
- "clip_qkv": null,
8
- "eos_token_id": 50279,
9
- "hidden_act": "silu",
10
  "hidden_size": 2048,
11
- "initializer_range": 0.02,
12
- "intermediate_size": 1024,
13
- "max_position_embeddings": 4096,
14
- "model_type": "olmoe",
15
- "norm_topk_prob": false,
16
  "num_attention_heads": 16,
17
- "num_experts": 64,
18
- "num_experts_per_tok": 8,
19
- "num_hidden_layers": 16,
20
- "num_key_value_heads": 16,
21
- "output_router_logits": false,
 
 
22
  "pad_token_id": 1,
23
- "rope_scaling": null,
24
- "rope_theta": 10000.0,
25
- "router_aux_loss_coef": 0.01,
26
- "tie_word_embeddings": false,
27
- "torch_dtype": "bfloat16",
28
- "transformers_version": "4.43.0.dev0",
29
  "use_cache": true,
30
- "vocab_size": 50304
31
- }
1
  {
2
  "architectures": [
3
+ "KVLatentForCausalLM"
4
  ],
5
+ "model_type": "kvlatent",
6
  "hidden_size": 2048,
7
+ "num_hidden_layers": 24,
8
  "num_attention_heads": 16,
9
+ "num_key_value_heads": 8,
10
+ "num_latents": 64,
11
+ "intermediate_size": 8192,
12
+ "hidden_act": "gelu",
13
+ "initializer_range": 0.02,
14
+ "rms_norm_eps": 1e-5,
15
+ "vocab_size": 50304,
16
  "pad_token_id": 1,
17
+ "bos_token_id": 50256,
18
+ "eos_token_id": 50256,
19
+ "attention_dropout": 0.0,
20
+ "attention_bias": false,
 
 
21
  "use_cache": true,
22
+ "tie_word_embeddings": false,
23
+ "rope_theta": 10000.0,
24
+ "rope_scaling": null,
25
+ "max_position_embeddings": 4096,
26
+ "torch_dtype": "bfloat16"
27
+ }
generate.py ADDED
@@ -0,0 +1,87 @@
1
+ #!/usr/bin/env python
2
+ """
3
+ Example usage script to evaluate a fine-tuned OlmoE adapter model
4
+ and demonstrate generation with adapters.
5
+ """
6
+
7
+ import argparse
8
+ import torch
9
+ from transformers import AutoTokenizer
10
+ from modeling_olmoe import OlmoEWithAdaptersForCausalLM, OlmoConfig
11
+
12
+ def generate_text(
13
+ model_path: str,
14
+ prompt: str,
15
+ max_new_tokens: int = 128,
16
+ temperature: float = 0.7,
17
+ top_p: float = 0.9,
18
+ device: str = "auto",
19
+ ):
20
+ """Generate text using a fine-tuned OlmoE adapter model."""
21
+ # Determine device
22
+ if device == "auto":
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+ print(f"Using device: {device}")
25
+
26
+ # Load tokenizer and model
27
+ print(f"Loading model from {model_path}")
28
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
29
+
30
+ # Load config and update with adapter settings if needed
31
+ config = OlmoConfig.from_pretrained(model_path)
32
+
33
+ # Load adapter model
34
+ model = OlmoEWithAdaptersForCausalLM.from_pretrained(
35
+ model_path,
36
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
37
+ )
38
+ model = model.to(device)
39
+ model.eval()
40
+
41
+ # Tokenize input
42
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
43
+
44
+ # Generate text
45
+ print("\nGenerating text...\n")
46
+ with torch.no_grad():
47
+ outputs = model.generate(
48
+ input_ids,
49
+ max_new_tokens=max_new_tokens,
50
+ do_sample=True,
51
+ temperature=temperature,
52
+ top_p=top_p,
53
+ )
54
+
55
+ # Decode the generated text
56
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
57
+
58
+ print(f"Prompt: {prompt}")
59
+ print("\nGenerated text:")
60
+ print("=" * 40)
61
+ print(generated_text)
62
+ print("=" * 40)
63
+
64
+ return generated_text
65
+
66
+ def main():
67
+ parser = argparse.ArgumentParser(description="Generate text with OlmoE adapter model")
68
+ parser.add_argument("--model_path", type=str, required=True, help="Path to the fine-tuned model")
69
+ parser.add_argument("--prompt", type=str, default="This is an example of", help="Prompt for text generation")
70
+ parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of new tokens to generate")
71
+ parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
72
+ parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling parameter")
73
+ parser.add_argument("--device", type=str, default="auto", help="Device to use (cuda, cpu, or auto)")
74
+
75
+ args = parser.parse_args()
76
+
77
+ generate_text(
78
+ model_path=args.model_path,
79
+ prompt=args.prompt,
80
+ max_new_tokens=args.max_new_tokens,
81
+ temperature=args.temperature,
82
+ top_p=args.top_p,
83
+ device=args.device,
84
+ )
85
+
86
+ if __name__ == "__main__":
87
+ main()
model-00001-of-00003.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5e3cff7e367794685c241169072c940d200918617d5e2813f1c387dff52d845e
3
- size 4997744872
model-00002-of-00003.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:15ef5c730ee3cfed7199498788cd2faf337203fc74b529625e7502cdd759f4a7
3
- size 4997235176
model-00003-of-00003.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a9abac4ac1b55c9adabac721a02fa39971f103eea9a65c310972b1246de76e04
3
- size 3843741912
modeling_olmoe.py ADDED
@@ -0,0 +1,822 @@
1
+ # modeling_olmoe.py - Extended version of OLMo for custom training
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from typing import Callable, Dict, Optional, Tuple, Union, Any
7
+ # Import necessary components from transformers
8
+ from transformers.activations import ACT2FN
9
+ from transformers.cache_utils import Cache, DynamicCache
10
+ from transformers.generation import GenerationMixin
11
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
12
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
13
+ # from transformers.modeling_layers import GradientCheckpointingLayer
14
+ from torch.utils.checkpoint import checkpoint
15
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
16
+ # from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
17
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
18
+ from transformers.processing_utils import Unpack
19
+ from transformers.utils import LossKwargs, is_torch_flex_attn_available, logging
20
+ from transformers import OlmoConfig
21
+
22
+ # Import flex attention components if available
23
+ if is_torch_flex_attn_available():
24
+ from torch.nn.attention.flex_attention import BlockMask
25
+ # from transformers.integrations.flex_attention import make_flex_block_causal_mask
26
+
27
+ from functools import partial
28
+ # Define GradientCheckpointingLayer since it's missing
29
+ class GradientCheckpointingLayer(nn.Module):
30
+ gradient_checkpointing = False
31
+ def __call__(self, *args, **kwargs):
32
+ # Use checkpoint on `forward` when enabled
33
+ if self.gradient_checkpointing and self.training:
34
+ return checkpoint(self.forward, *args, use_reentrant=False, **kwargs)
35
+ return super().__call__(*args, **kwargs)
36
+
37
+ def forward(self, *args, **kwargs):
38
+ # To be implemented by subclasses
39
+ raise NotImplementedError("Subclasses must implement `forward`")
40
+
41
+ import math
42
+ import functools
43
+
44
+ # Define our own dynamic_rope_update decorator and ROPE_INIT_FUNCTIONS
45
+ def dynamic_rope_update(func):
46
+ """
47
+ Decorator for updating RoPE embeddings when using RoPE scaling strategies.
48
+ """
49
+ @functools.wraps(func)
50
+ def wrapper(self, *args, **kwargs):
51
+ # Only dynamic scaling needs to modify the positional encodings
52
+ if self.rope_type == "dynamic" and hasattr(self, "original_max_seq_len"):
53
+ if self.config.rope_scaling is None:
54
+ return func(self, *args, **kwargs)
55
+ # Extract max_position_embeddings from the actual model
56
+ current_ctx_len = kwargs.get("position_ids", None)
57
+ if current_ctx_len is not None:
58
+ # position_ids shape is [batch_size, seq_len]
59
+ current_ctx_len = current_ctx_len.shape[-1]
60
+
61
+ # If we're inside a context window we've seen before, we don't have to change anything
62
+ if current_ctx_len is not None and current_ctx_len <= self.max_seq_len_cached:
63
+ return func(self, *args, **kwargs)
64
+
65
+ current_ctx_len = self.config.max_position_embeddings if current_ctx_len is None else current_ctx_len
66
+ scaling_factor = self.config.rope_scaling["factor"]
67
+
68
+ self.max_seq_len_cached = min(
69
+ int(self.original_max_seq_len * scaling_factor),
70
+ self.config.rope_scaling.get("max_position_embeddings", float("inf"))
71
+ )
72
+
73
+ # Reset the cached maximum position embeddings to the new value
74
+ power = 0.0 if scaling_factor <= 1.0 else -0.5
75
+ self.inv_freq = self.original_inv_freq * (scaling_factor ** power)
76
+
77
+ return func(self, *args, **kwargs)
78
+
79
+ return wrapper
80
+
81
+ def get_default_rope_init(config, device=None):
82
+ """
83
+ Default initialization for rotary position embeddings.
84
+ """
85
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
86
+ inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, head_dim, 2).float().to(device) / head_dim))
87
+ return inv_freq, None
88
+
89
+ def get_linear_rope_init(config, device=None):
90
+ """
91
+ Linear initialization for dynamic scaling rotary position embeddings.
92
+ """
93
+ base = get_default_rope_init(config, device)[0]
94
+ scaling_factor = config.rope_scaling["factor"]
95
+
96
+ # Scale the base frequencies
97
+ return base / scaling_factor, scaling_factor
98
+
99
+ def get_dynamic_rope_init(config, device=None):
100
+ """
101
+ Dynamic initialization for dynamic scaling rotary position embeddings (NTK approach).
102
+ """
103
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
104
+ scaling_factor = config.rope_scaling["factor"]
105
+
106
+ # Adjust the base frequencies by a power of the scaling factor
107
+ power = 0.0 if scaling_factor <= 1.0 else -0.5
108
+ inv_freq = 1.0 / (config.rope_theta **
109
+ (torch.arange(0, head_dim, 2).float().to(device) / head_dim))
110
+ inv_freq = inv_freq * (scaling_factor ** power)
111
+
112
+ return inv_freq, scaling_factor
113
+
114
+ # Define the dictionary of RoPE initialization functions
115
+ ROPE_INIT_FUNCTIONS = {
116
+ "default": get_default_rope_init,
117
+ "linear": get_linear_rope_init,
118
+ "dynamic": get_dynamic_rope_init,
119
+ }
120
+
121
+ def can_return_tuple(inputs):
122
+ # Copied logic from the original source
123
+ return getattr(inputs, "return_tuple", False) if hasattr(inputs, "return_tuple") else False
124
+
125
+ # Start Modeling Code
126
+ logger = logging.get_logger(__name__)
127
+
128
+ # Core OLMo components (reused from original implementation)
129
+ class OlmoLayerNorm(nn.Module):
130
+ """LayerNorm but with no learnable weight or bias."""
131
+
132
+ def __init__(self, hidden_size: int) -> None:
133
+ super().__init__()
134
+ self.normalized_shape = (hidden_size,)
135
+
136
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
137
+ orig_dtype = hidden_states.dtype
138
+ return F.layer_norm(hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=1e-5).to(
139
+ orig_dtype
140
+ )
141
+
142
+
143
+ class OlmoMLP(nn.Module):
144
+ def __init__(self, config):
145
+ super().__init__()
146
+ self.config = config
147
+ self.hidden_size = config.hidden_size
148
+ self.intermediate_size = config.intermediate_size
149
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
150
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
151
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
152
+ self.act_fn = ACT2FN[config.hidden_act]
153
+
154
+ def forward(self, x):
155
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
156
+ return down_proj
157
+
158
+
159
+ # Helper functions for rotary position embeddings
160
+ def rotate_half(x):
161
+ """Rotates half the hidden dims of the input."""
162
+ x1 = x[..., : x.shape[-1] // 2]
163
+ x2 = x[..., x.shape[-1] // 2 :]
164
+ return torch.cat((-x2, x1), dim=-1)
165
+
166
+
167
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
168
+ """Applies Rotary Position Embedding to the query and key tensors."""
169
+ cos = cos.unsqueeze(unsqueeze_dim)
170
+ sin = sin.unsqueeze(unsqueeze_dim)
171
+ q_embed = (q * cos) + (rotate_half(q) * sin)
172
+ k_embed = (k * cos) + (rotate_half(k) * sin)
173
+ return q_embed, k_embed
174
+
175
+
176
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
177
+ """
178
+ Repeats key/value states for grouped queries attention.
179
+ """
180
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
181
+ if n_rep == 1:
182
+ return hidden_states
183
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
184
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
185
+
186
+
187
+ def eager_attention_forward(
188
+ module: nn.Module,
189
+ query: torch.Tensor,
190
+ key: torch.Tensor,
191
+ value: torch.Tensor,
192
+ attention_mask: Optional[torch.Tensor],
193
+ scaling: float,
194
+ dropout: float = 0.0,
195
+ **kwargs,
196
+ ):
197
+ """Default eager implementation of multi-head attention"""
198
+ key_states = repeat_kv(key, module.num_key_value_groups)
199
+ value_states = repeat_kv(value, module.num_key_value_groups)
200
+
201
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
202
+ if attention_mask is not None:
203
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
204
+ attn_weights = attn_weights + causal_mask
205
+
206
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
207
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
208
+ attn_output = torch.matmul(attn_weights, value_states)
209
+ attn_output = attn_output.transpose(1, 2).contiguous()
210
+
211
+ return attn_output, attn_weights
212
+
213
+
214
+ class OlmoAttention(nn.Module):
215
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
216
+
217
+ def __init__(self, config: OlmoConfig, layer_idx: int):
218
+ super().__init__()
219
+ self.config = config
220
+ self.layer_idx = layer_idx
221
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
222
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
223
+ self.scaling = self.head_dim**-0.5
224
+ self.attention_dropout = config.attention_dropout
225
+ self.is_causal = True
226
+
227
+ self.q_proj = nn.Linear(
228
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
229
+ )
230
+ self.k_proj = nn.Linear(
231
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
232
+ )
233
+ self.v_proj = nn.Linear(
234
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
235
+ )
236
+ self.o_proj = nn.Linear(
237
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
238
+ )
239
+
240
+ def forward(
241
+ self,
242
+ hidden_states: torch.Tensor,
243
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
244
+ attention_mask: Optional[torch.Tensor],
245
+ past_key_value: Optional[Cache] = None,
246
+ cache_position: Optional[torch.LongTensor] = None,
247
+ **kwargs,
248
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
249
+ input_shape = hidden_states.shape[:-1]
250
+ hidden_shape = (*input_shape, -1, self.head_dim)
251
+
252
+ query_states = self.q_proj(hidden_states)
253
+ key_states = self.k_proj(hidden_states)
254
+ value_states = self.v_proj(hidden_states)
255
+
256
+ if self.config.clip_qkv is not None:
257
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
258
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
259
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
260
+
261
+ query_states = query_states.view(hidden_shape).transpose(1, 2)
262
+ key_states = key_states.view(hidden_shape).transpose(1, 2)
263
+ value_states = value_states.view(hidden_shape).transpose(1, 2)
264
+
265
+ cos, sin = position_embeddings
266
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
267
+
268
+ if past_key_value is not None:
269
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
270
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
271
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
272
+
273
+ attention_interface: Callable = eager_attention_forward
274
+ if self.config._attn_implementation != "eager":
275
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
276
+ logger.warning_once(
277
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
278
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
279
+ )
280
+ else:
281
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
282
+
283
+ attn_output, attn_weights = attention_interface(
284
+ self,
285
+ query_states,
286
+ key_states,
287
+ value_states,
288
+ attention_mask,
289
+ dropout=0.0 if not self.training else self.attention_dropout,
290
+ scaling=self.scaling,
291
+ **kwargs,
292
+ )
293
+
294
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
295
+ attn_output = self.o_proj(attn_output)
296
+ return attn_output, attn_weights
297
+
298
+
299
+ class OlmoDecoderLayer(GradientCheckpointingLayer):
+     def __init__(self, config: OlmoConfig, layer_idx: int):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx)
+
+         self.mlp = OlmoMLP(config)
+         self.input_layernorm = OlmoLayerNorm(config.hidden_size)
+         self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Cache] = None,
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+         cache_position: Optional[torch.LongTensor] = None,
+         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         **kwargs,
+     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+
+         # Self Attention
+         hidden_states, self_attn_weights = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+             cache_position=cache_position,
+             position_embeddings=position_embeddings,
+             **kwargs,
+         )
+         hidden_states = residual + hidden_states
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+         if output_attentions:
+             outputs += (self_attn_weights,)
+
+         return outputs
+
+
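The layer follows the standard pre-norm residual pattern: `x = x + attn(norm(x))` followed by `x = x + mlp(norm(x))`. A toy scalar sketch of just that control flow, with stub sublayers standing in for the real modules:

```python
def decoder_layer(x, norm, attn, mlp):
    # Self-attention block with residual connection
    x = x + attn(norm(x))
    # Feed-forward block with residual connection
    x = x + mlp(norm(x))
    return x

# Stub sublayers: norm is identity, attn adds 1, mlp doubles
out = decoder_layer(1.0, norm=lambda v: v, attn=lambda v: v + 1, mlp=lambda v: v * 2)
```

With these stubs the attention block yields 1 + 2 = 3, and the MLP block yields 3 + 6 = 9.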
+ class OlmoRotaryEmbedding(nn.Module):
+     def __init__(self, config: OlmoConfig, device=None):
+         super().__init__()
+         # BC: "rope_type" was originally "type"
+         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+         else:
+             self.rope_type = "default"
+         self.max_seq_len_cached = config.max_position_embeddings
+         self.original_max_seq_len = config.max_position_embeddings
+
+         self.config = config
+         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self.original_inv_freq = self.inv_freq
+
+     @torch.no_grad()
+     @dynamic_rope_update
+     def forward(self, x, position_ids):
+         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+         position_ids_expanded = position_ids[:, None, :].float()
+
+         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+             emb = torch.cat((freqs, freqs), dim=-1)
+             cos = emb.cos() * self.attention_scaling
+             sin = emb.sin() * self.attention_scaling
+
+         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
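For the "default" rope_type, the init function produces one inverse frequency per dimension pair, `inv_freq[i] = theta^(-2i/head_dim)`, and the forward pass turns positions into cos/sin tables. A pure-Python sketch of that math (assuming the usual base theta = 10000 and a toy head_dim of 4; not the model's actual sizes):

```python
import math

def inv_frequencies(head_dim, theta=10000.0):
    # inv_freq[i] = theta^(-2i / head_dim), one frequency per dimension pair
    return [theta ** (-2 * i / head_dim) for i in range(head_dim // 2)]

def cos_sin(position, inv_freq):
    # Angle for each frequency at this absolute position
    angles = [position * f for f in inv_freq]
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]

inv_freq = inv_frequencies(head_dim=4)
cos0, sin0 = cos_sin(0, inv_freq)
```

Position 0 always yields cos = 1 and sin = 0, so RoPE leaves the first token's query/key unrotated.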
+ # Base model classes
+ class OlmoEPreTrainedModel(PreTrainedModel):
+     """Base class for OlmoE models with additional extensibility features"""
+
+     config_class = OlmoConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["OlmoDecoderLayer"]
+     _skip_keys_device_placement = ["past_key_values"]
+     _supports_flash_attn_2 = True
+     _supports_sdpa = True
+     _supports_flex_attn = True
+     _supports_cache_class = True
+     _supports_quantized_cache = True
+     _supports_static_cache = True
+     _supports_attention_backend = True
+
+     def _init_weights(self, module):
+         std = self.config.initializer_range
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+
+
+ class OlmoEModel(OlmoEPreTrainedModel):
+     """Extended OLMo base model with additional customization points"""
+
+     def __init__(self, config: OlmoConfig):
+         super().__init__(config)
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+         self.layers = nn.ModuleList(
+             [OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+         )
+         self.norm = OlmoLayerNorm(config.hidden_size)
+         self.rotary_emb = OlmoRotaryEmbedding(config=config)
+         self.gradient_checkpointing = False
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.embed_tokens = value
+
+     def _update_causal_mask(
+         self,
+         attention_mask: Union[torch.Tensor, "BlockMask"],
+         input_tensor: torch.Tensor,
+         cache_position: torch.Tensor,
+         past_key_values: Cache,
+         output_attentions: bool = False,
+     ):
+         if self.config._attn_implementation == "flash_attention_2":
+             if attention_mask is not None and (attention_mask == 0.0).any():
+                 return attention_mask
+             return None
+         # if self.config._attn_implementation == "flex_attention":
+         #     if isinstance(attention_mask, torch.Tensor):
+         #         attention_mask = make_flex_block_causal_mask(attention_mask)
+         #     return attention_mask
+
+         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+         using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+         if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+             if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                 attention_mask,
+                 inputs_embeds=input_tensor,
+                 past_key_values_length=past_seen_tokens,
+                 is_training=self.training,
+             ):
+                 return None
+
+         dtype = input_tensor.dtype
+         sequence_length = input_tensor.shape[1]
+         if using_compilable_cache:
+             target_length = past_key_values.get_max_cache_shape()
+         else:
+             target_length = (
+                 attention_mask.shape[-1]
+                 if isinstance(attention_mask, torch.Tensor)
+                 else past_seen_tokens + sequence_length + 1
+             )
+
+         causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+             attention_mask,
+             sequence_length=sequence_length,
+             target_length=target_length,
+             dtype=dtype,
+             cache_position=cache_position,
+             batch_size=input_tensor.shape[0],
+         )
+
+         if (
+             self.config._attn_implementation == "sdpa"
+             and attention_mask is not None
+             and attention_mask.device.type in ["cuda", "xpu", "npu"]
+             and not output_attentions
+         ):
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+         return causal_mask
+
+     @staticmethod
+     def _prepare_4d_causal_attention_mask_with_cache_position(
+         attention_mask: torch.Tensor,
+         sequence_length: int,
+         target_length: int,
+         dtype: torch.dtype,
+         cache_position: torch.Tensor,
+         batch_size: int,
+         **kwargs,
+     ):
+         """Creates a causal 4D mask."""
+         if attention_mask is not None and attention_mask.dim() == 4:
+             causal_mask = attention_mask
+         else:
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = torch.full(
+                 (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+             )
+             if sequence_length != 1:
+                 causal_mask = torch.triu(causal_mask, diagonal=1)
+             causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+             if attention_mask is not None:
+                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                 mask_length = attention_mask.shape[-1]
+                 padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                     causal_mask.device
+                 )
+                 padding_mask = padding_mask == 0
+                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                     padding_mask, min_dtype
+                 )
+
+         return causal_mask
+
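The helper above materializes an additive mask: key positions a query may attend to get 0, everything after it gets a large negative value (the real code uses `torch.finfo(dtype).min`; `NEG` stands in for it here). A list-based sketch of the causal part, including the KV-cache offset:

```python
NEG = float("-inf")  # stand-in for torch.finfo(dtype).min

def causal_mask(sequence_length, target_length, past_seen_tokens=0):
    # Row q may attend to absolute key positions 0 .. past_seen_tokens + q
    mask = []
    for q in range(sequence_length):
        limit = past_seen_tokens + q
        mask.append([0.0 if k <= limit else NEG for k in range(target_length)])
    return mask

m = causal_mask(sequence_length=3, target_length=3)
```

Adding this matrix to the raw attention scores before softmax drives the masked positions' weights to zero.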
+     @can_return_tuple
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Cache] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **flash_attn_kwargs,
+     ) -> BaseModelOutputWithPast:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         if self.gradient_checkpointing and self.training and use_cache:
+             logger.warning_once(
+                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+             )
+             use_cache = False
+
+         if not isinstance(past_key_values, (type(None), Cache)):
+             raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+
+         if use_cache and past_key_values is None:
+             past_key_values = DynamicCache()
+
+         if cache_position is None:
+             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+             cache_position = torch.arange(
+                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+             )
+
+         if position_ids is None:
+             position_ids = cache_position.unsqueeze(0)
+
+         causal_mask = self._update_causal_mask(
+             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+         )
+
+         hidden_states = inputs_embeds
+
+         # create position embeddings to be shared across the decoder layers
+         position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+
+         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             layer_outputs = decoder_layer(
+                 hidden_states,
+                 attention_mask=causal_mask,
+                 position_ids=position_ids,
+                 past_key_value=past_key_values,
+                 output_attentions=output_attentions,
+                 use_cache=use_cache,
+                 cache_position=cache_position,
+                 position_embeddings=position_embeddings,
+                 **flash_attn_kwargs,
+             )
+
+             hidden_states = layer_outputs[0]
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+         hidden_states = self.norm(hidden_states)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=past_key_values if use_cache else None,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+
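When `cache_position` is not passed in, the forward pass derives it from how many tokens the KV cache has already seen, which is what keeps absolute positions correct during incremental decoding. The arithmetic, in plain Python:

```python
def make_cache_position(past_seen_tokens, num_new_tokens):
    # Absolute positions of the tokens fed in this forward pass
    return list(range(past_seen_tokens, past_seen_tokens + num_new_tokens))

# Prefill of a 5-token prompt, then a single decode step:
prefill = make_cache_position(0, 5)
decode = make_cache_position(5, 1)
```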
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
+ class OlmoEForCausalLM(OlmoEPreTrainedModel, GenerationMixin):
+     """OLMo Causal Language Model with extensions for custom training"""
+
+     _tied_weights_keys = ["lm_head.weight"]
+     _tp_plan = {"lm_head": "colwise_rep"}
+     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = OlmoEModel(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     @can_return_tuple
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Cache] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         logits_to_keep: Union[int, torch.Tensor] = 0,
+         **kwargs,
+     ) -> CausalLMOutputWithPast:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+
+         # Get model outputs
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             cache_position=cache_position,
+             **kwargs,
+         )
+
+         hidden_states = outputs.last_hidden_state
+         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+         logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+         loss = None
+         if labels is not None:
+             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
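`logits_to_keep` avoids projecting hidden states that won't be used: during generation only the last position's logits matter. The slicing rule relies on `slice(-0, None)` covering the whole sequence, shown here on a plain list standing in for the sequence dimension:

```python
def keep_logits(hidden, logits_to_keep):
    # int -> keep only the last k positions; 0 means keep everything,
    # because slice(-0, None) == slice(0, None)
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return hidden[slice_indices]

hidden = ["h0", "h1", "h2", "h3"]
last_only = keep_logits(hidden, 1)
everything = keep_logits(hidden, 0)
```

For a 50k-token vocabulary, skipping the projection for all but one position saves a `seq_len × hidden × vocab` matmul per step.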
+ # Example of custom model extensions you can create:
+
+ class OlmoEWithAdaptersMLP(OlmoMLP):
+     """An extended MLP with adapters for parameter-efficient fine-tuning"""
+
+     def __init__(self, config):
+         super().__init__(config)
+         # Example adapter dimensions (typically much smaller than original dims)
+         adapter_size = getattr(config, "adapter_size", 64)
+
+         # Add adapter layers
+         self.down_adapter = nn.Sequential(
+             nn.Linear(self.hidden_size, adapter_size, bias=False),
+             nn.ReLU(),
+             nn.Linear(adapter_size, self.hidden_size, bias=False),
+         )
+
+         # Initialize adapter layers with small weights
+         self.down_adapter[0].weight.data.normal_(mean=0.0, std=0.01)
+         self.down_adapter[2].weight.data.normal_(mean=0.0, std=0.01)
+
+     def forward(self, x):
+         # Original MLP computation
+         mlp_output = super().forward(x)
+
+         # Add adapter path with residual connection
+         adapter_output = self.down_adapter(x)
+         return mlp_output + adapter_output
+
+
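The bottleneck adapter adds only `2 · hidden_size · adapter_size` weights per MLP, tiny next to the three `hidden_size × intermediate_size` projections it sits beside. Using illustrative sizes (not this model's actual config values):

```python
def adapter_params(hidden_size, adapter_size):
    # down-projection + up-projection, both bias-free
    return 2 * hidden_size * adapter_size

def mlp_params(hidden_size, intermediate_size):
    # gate_proj + up_proj + down_proj, all bias-free
    return 3 * hidden_size * intermediate_size

h, i, a = 2048, 8192, 64  # hypothetical dims for illustration
ratio = adapter_params(h, a) / mlp_params(h, i)
```

With these sizes the adapter is roughly half a percent of the MLP's parameter count, which is why freezing everything else makes fine-tuning cheap.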
+ class OlmoEWithAdaptersDecoderLayer(OlmoDecoderLayer):
+     """OLMo decoder layer with adapters for efficient fine-tuning"""
+
+     def __init__(self, config, layer_idx):
+         # Replace the standard MLP with an adapter-based MLP
+         super().__init__(config, layer_idx)
+         self.mlp = OlmoEWithAdaptersMLP(config)
+
+
+ class OlmoEWithAdaptersModel(OlmoEModel):
+     """OLMo model with adapter layers"""
+
+     def __init__(self, config):
+         super().__init__(config)
+         # Replace all layers with adapter-based layers
+         self.layers = nn.ModuleList(
+             [OlmoEWithAdaptersDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+         )
+
+         # Initialize weights
+         self.post_init()
+
+
+ class OlmoEWithAdaptersForCausalLM(OlmoEForCausalLM):
+     """OLMo for causal language modeling with adapters"""
+
+     def __init__(self, config, adapters_config: Optional[Dict[str, Any]] = None):
+         super().__init__(config)
+         self.adapters_config = adapters_config
+
+         # Initialize the model with adapters using the config
+         self.model = OlmoEWithAdaptersModel(config)
+
+         # Initialize weights
+         self.post_init()
+
+     def freeze_base_model(self):
+         """Freeze all parameters except adapters for efficient fine-tuning"""
+         for param in self.model.embed_tokens.parameters():
+             param.requires_grad = False
+
+         for layer in self.model.layers:
+             for name, param in layer.self_attn.named_parameters():
+                 param.requires_grad = False
+
+             for name, param in layer.mlp.named_parameters():
+                 if "down_adapter" not in name:
+                     param.requires_grad = False
+
+             for param in layer.input_layernorm.parameters():
+                 param.requires_grad = False
+             for param in layer.post_attention_layernorm.parameters():
+                 param.requires_grad = False
+
+         for param in self.model.norm.parameters():
+             param.requires_grad = False
+
+         # Uncomment to freeze LM head
+         # for param in self.lm_head.parameters():
+         #     param.requires_grad = False
+
+     def get_trainable_parameters(self):
+         """Return only trainable parameters for optimizer"""
+         return [p for p in self.parameters() if p.requires_grad]
+
+     @classmethod
+     def from_config_and_adapters(
+         cls,
+         config,
+         adapters_config: Optional[Dict[str, Any]] = None,
+     ) -> "OlmoEWithAdaptersForCausalLM":
+         """Optional factory method, if you want to keep this pattern."""
+         return cls(config=config, adapters_config=adapters_config)
+
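`freeze_base_model` leaves `requires_grad=True` only on MLP parameters whose name contains `"down_adapter"` (plus the LM head, unless the commented block is enabled). The same name-based filtering, sketched without torch using a dict of name → trainable flags:

```python
def freeze_except_adapters(named_params):
    # named_params: dict of parameter name -> requires_grad flag.
    # Freeze everything whose name does not mention the adapter.
    for name in named_params:
        if "down_adapter" not in name:
            named_params[name] = False
    return [n for n, trainable in named_params.items() if trainable]

params = {
    "mlp.gate_proj.weight": True,
    "mlp.down_adapter.0.weight": True,
    "self_attn.q_proj.weight": True,
}
trainable = freeze_except_adapters(params)
```

The surviving list mirrors what `get_trainable_parameters` would hand to the optimizer.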
oldcmds.txt ADDED
@@ -0,0 +1,3 @@
+ export CUDA_HOME=$(dirname $(dirname $(which nvcc)))
+ export PATH=$CUDA_HOME/bin:$PATH
+ export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
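These exports derive `CUDA_HOME` from wherever `nvcc` resolves on `PATH`: one `dirname` strips the binary name, the second strips the `bin` directory. With a hypothetical install path (substitute `$(which nvcc)` on a real machine):

```shell
# Hypothetical nvcc location used only to illustrate the double-dirname
nvcc_path=/usr/local/cuda-12.1/bin/nvcc
CUDA_HOME=$(dirname "$(dirname "$nvcc_path")")
echo "$CUDA_HOME"
```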
output.txt ADDED
File without changes
randommoe.py ADDED
@@ -0,0 +1,1047 @@
+ from typing import Callable, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from ...activations import ACT2FN
+ from ...cache_utils import Cache, DynamicCache
+ from ...generation import GenerationMixin
+ from ...modeling_attn_mask_utils import AttentionMaskConverter
+ from ...modeling_flash_attention_utils import FlashAttentionKwargs
+ from ...modeling_layers import GradientCheckpointingLayer
+ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+ from ...processing_utils import Unpack
+ from ...utils import (
+     LossKwargs,
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     can_return_tuple,
+     is_torch_flex_attn_available,
+     logging,
+     replace_return_docstrings,
+ )
+ from .configuration_olmo import OlmoConfig
+
+
+ if is_torch_flex_attn_available():
+     from torch.nn.attention.flex_attention import BlockMask
+
+     from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+ logger = logging.get_logger(__name__)
+ _CONFIG_FOR_DOC = "OlmoConfig"
+
+
+ class OlmoLayerNorm(nn.Module):
+     """LayerNorm but with no learnable weight or bias."""
+
+     def __init__(self, hidden_size: int) -> None:
+         super().__init__()
+         self.normalized_shape = (hidden_size,)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         orig_dtype = hidden_states.dtype
+         return F.layer_norm(hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=1e-5).to(
+             orig_dtype
+         )
+
+
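OlmoLayerNorm normalizes in float32 but, unlike standard LayerNorm, learns no scale or bias. The normalization itself, on a plain list (eps matching the 1e-5 above):

```python
import math

def layer_norm(xs, eps=1e-5):
    # Subtract the mean, divide by sqrt(variance + eps); no weight, no bias
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)
    return [(x - mean) / math.sqrt(var + eps) for x in xs]

out = layer_norm([1.0, 2.0, 3.0, 4.0])
```

The output is zero-mean and (approximately) unit-variance; with no affine parameters there is nothing for `_init_weights` to initialize here, which matches the module above.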
+ class OlmoMLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+         self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(self, x):
+         down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+         return down_proj
+
+
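This is a gated (SwiGLU-style) MLP when `hidden_act` is `"silu"`: `down_proj(silu(gate_proj(x)) * up_proj(x))`. The elementwise gating, with `silu(x) = x · sigmoid(x)`:

```python
import math

def silu(x):
    # SiLU / swish activation: x * sigmoid(x)
    return x / (1.0 + math.exp(-x))

def gated_act(gate, up):
    # silu(gate_proj(x)) * up_proj(x), elementwise
    return [silu(g) * u for g, u in zip(gate, up)]

out = gated_act([0.0, 10.0], [3.0, 2.0])
```

A zero gate suppresses its channel entirely, while a strongly positive gate passes the up-projection through almost unchanged.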
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+     """Applies Rotary Position Embedding to the query and key tensors.
+
+     Args:
+         q (`torch.Tensor`): The query tensor.
+         k (`torch.Tensor`): The key tensor.
+         cos (`torch.Tensor`): The cosine part of the rotary embedding.
+         sin (`torch.Tensor`): The sine part of the rotary embedding.
+         position_ids (`torch.Tensor`, *optional*):
+             Deprecated and unused.
+         unsqueeze_dim (`int`, *optional*, defaults to 1):
+             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+     Returns:
+         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+     """
+     cos = cos.unsqueeze(unsqueeze_dim)
+     sin = sin.unsqueeze(unsqueeze_dim)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
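`rotate_half` splits the head dimension in half and maps `(x1, x2)` to `(-x2, x1)`, a 90° rotation of each dimension pair; applying it twice negates the input. The same operation on a plain list:

```python
def rotate_half(x):
    # (x1, x2) -> (-x2, x1) across the split halves of the vector
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-v for v in x2] + x1

r = rotate_half([1.0, 2.0, 3.0, 4.0])
```

Combined with the cos/sin tables, `x * cos + rotate_half(x) * sin` rotates each (x1, x2) pair by the position-dependent angle.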
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
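`repeat_kv` expands each KV head `n_rep` times so grouped-query attention can share one key/value head across several query heads; as the docstring notes, it is interleaved like `repeat_interleave`, not tiled. The head-axis semantics, with heads as strings:

```python
def repeat_kv(kv_heads, n_rep):
    # ["k0", "k1"] with n_rep=2 -> ["k0", "k0", "k1", "k1"] (interleaved, not tiled)
    out = []
    for head in kv_heads:
        out.extend([head] * n_rep)
    return out

expanded = repeat_kv(["k0", "k1"], n_rep=2)
```

Interleaving matters: query head `i` must line up with KV head `i // n_rep`, which tiling (`["k0", "k1", "k0", "k1"]`) would break.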
+ def eager_attention_forward(
+     module: nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: Optional[torch.Tensor],
+     scaling: float,
+     dropout: float = 0.0,
+     **kwargs,
+ ):
+     key_states = repeat_kv(key, module.num_key_value_groups)
+     value_states = repeat_kv(value, module.num_key_value_groups)
+
+     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+     if attention_mask is not None:
+         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+         attn_weights = attn_weights + causal_mask
+
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+     attn_output = torch.matmul(attn_weights, value_states)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+
+     return attn_output, attn_weights
+
+
+ class OlmoAttention(nn.Module):
+     """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+     def __init__(self, config: OlmoConfig, layer_idx: int):
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+         self.scaling = self.head_dim**-0.5
+         self.attention_dropout = config.attention_dropout
+         self.is_causal = True
+
+         self.q_proj = nn.Linear(
+             config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+         )
+         self.k_proj = nn.Linear(
+             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+         )
+         self.v_proj = nn.Linear(
+             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+         )
+         self.o_proj = nn.Linear(
+             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+         attention_mask: Optional[torch.Tensor],
+         past_key_value: Optional[Cache] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         input_shape = hidden_states.shape[:-1]
+         hidden_shape = (*input_shape, -1, self.head_dim)
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         if self.config.clip_qkv is not None:
+             query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+             key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+             value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
+         query_states = query_states.view(hidden_shape).transpose(1, 2)
+         key_states = key_states.view(hidden_shape).transpose(1, 2)
+         value_states = value_states.view(hidden_shape).transpose(1, 2)
+
+         cos, sin = position_embeddings
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+         if past_key_value is not None:
+             # sin and cos are specific to RoPE models; cache_position needed for the static cache
+             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+         attention_interface: Callable = eager_attention_forward
+         if self.config._attn_implementation != "eager":
+             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+                 logger.warning_once(
+                     "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                     'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                 )
+             else:
+                 attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+         attn_output, attn_weights = attention_interface(
+             self,
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             dropout=0.0 if not self.training else self.attention_dropout,
+             scaling=self.scaling,
+             **kwargs,
+         )
+
+         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+         attn_output = self.o_proj(attn_output)
+         return attn_output, attn_weights
+
+
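OLMo's distinctive `clip_qkv` step clamps every Q/K/V activation into `[-clip_qkv, clip_qkv]` before the heads are reshaped, which bounds attention logits for training stability. The clamp itself, as a pure function:

```python
def clip_qkv(values, clip):
    # Equivalent to tensor.clamp_(min=-clip, max=clip), elementwise
    return [max(-clip, min(clip, v)) for v in values]

clipped = clip_qkv([-12.0, 0.5, 9.0], clip=8.0)
```

Values already inside the range pass through unchanged; only outliers are saturated at ±clip.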
+ class OlmoDecoderLayer(GradientCheckpointingLayer):
227
+ def __init__(self, config: OlmoConfig, layer_idx: int):
228
+ super().__init__()
229
+ self.hidden_size = config.hidden_size
230
+ self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx)
231
+
232
+ self.mlp = OlmoMLP(config)
233
+ self.input_layernorm = OlmoLayerNorm(config.hidden_size)
234
+ self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
235
+
236
+ def forward(
237
+ self,
238
+ hidden_states: torch.Tensor,
239
+ attention_mask: Optional[torch.Tensor] = None,
240
+ position_ids: Optional[torch.LongTensor] = None,
241
+ past_key_value: Optional[Cache] = None,
242
+ output_attentions: Optional[bool] = False,
243
+ use_cache: Optional[bool] = False,
244
+ cache_position: Optional[torch.LongTensor] = None,
245
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
246
+ **kwargs: Unpack[FlashAttentionKwargs],
247
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
248
+ residual = hidden_states
249
+ hidden_states = self.input_layernorm(hidden_states)
250
+
251
+ # Self Attention
252
+ hidden_states, self_attn_weights = self.self_attn(
253
+ hidden_states=hidden_states,
254
+ attention_mask=attention_mask,
255
+ position_ids=position_ids,
256
+ past_key_value=past_key_value,
257
+ output_attentions=output_attentions,
258
+ use_cache=use_cache,
259
+ cache_position=cache_position,
260
+ position_embeddings=position_embeddings,
261
+ **kwargs,
262
+ )
263
+ hidden_states = residual + hidden_states
264
+
265
+ # Fully Connected
266
+ residual = hidden_states
267
+ hidden_states = self.post_attention_layernorm(hidden_states)
268
+ hidden_states = self.mlp(hidden_states)
269
+ hidden_states = residual + hidden_states
270
+
271
+ outputs = (hidden_states,)
272
+ if output_attentions:
273
+ outputs += (self_attn_weights,)
274
+
275
+ return outputs
276
+
277
+
+ class OlmoRotaryEmbedding(nn.Module):
+     def __init__(self, config: OlmoConfig, device=None):
+         super().__init__()
+         # BC: "rope_type" was originally "type"
+         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+         else:
+             self.rope_type = "default"
+         self.max_seq_len_cached = config.max_position_embeddings
+         self.original_max_seq_len = config.max_position_embeddings
+
+         self.config = config
+         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self.original_inv_freq = self.inv_freq
+
+     @torch.no_grad()
+     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+     def forward(self, x, position_ids):
+         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+         position_ids_expanded = position_ids[:, None, :].float()
+
+         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+             emb = torch.cat((freqs, freqs), dim=-1)
+             cos = emb.cos() * self.attention_scaling
+             sin = emb.sin() * self.attention_scaling
+
+         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
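For intuition, the frequency table that `OlmoRotaryEmbedding` materializes via `rope_init_fn` follows the standard RoPE recipe: one frequency per pair of channels, decaying geometrically with channel index, applied as a 2D rotation. A minimal pure-Python sketch (the helper names are illustrative, not part of the model code):

```python
import math

def rope_cos_sin(dim, positions, base=10000.0):
    """Compute RoPE cos/sin tables for an even head dimension `dim`."""
    # inv_freq[i] = 1 / base^(2i/dim), one frequency per (even, odd) channel pair
    inv_freq = [base ** (-2.0 * i / dim) for i in range(dim // 2)]
    cos = [[math.cos(p * f) for f in inv_freq] for p in positions]
    sin = [[math.sin(p * f) for f in inv_freq] for p in positions]
    return cos, sin

def rotate_pair(x, y, c, s):
    # 2D rotation applied to one channel pair; preserves the pair's norm
    return x * c - y * s, x * s + y * c
```

Position 0 yields the identity rotation, which is why absolute position drops out and only relative offsets affect attention scores.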
+ OLMO_START_DOCSTRING = r"""
+     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+     etc.)
+
+     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+     and behavior.
+
+     Parameters:
+         config ([`OlmoConfig`]):
+             Model configuration class with all the parameters of the model. Initializing with a config file does not
+             load the weights associated with the model, only the configuration. Check out the
+             [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+
+ @add_start_docstrings(
+     "The bare Olmo Model outputting raw hidden-states without any specific head on top.",
+     OLMO_START_DOCSTRING,
+ )
+ class OlmoPreTrainedModel(PreTrainedModel):
+     config_class = OlmoConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["OlmoDecoderLayer"]
+     _skip_keys_device_placement = ["past_key_values"]
+     _supports_flash_attn_2 = True
+     _supports_sdpa = True
+     _supports_flex_attn = True
+     _supports_cache_class = True
+     _supports_quantized_cache = True
+     _supports_static_cache = True
+     _supports_attention_backend = True
+
+     def _init_weights(self, module):
+         std = self.config.initializer_range
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+
+
+ OLMO_INPUTS_DOCSTRING = r"""
+     Args:
+         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+             it.
+
+             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+             [`PreTrainedTokenizer.__call__`] for details.
+
+             [What are input IDs?](../glossary#input-ids)
+         attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)` or `BlockMask`, *optional*):
+             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+
+             If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask,
+             but you can also pass a `BlockMask` object directly here.
+
+             [What are attention masks?](../glossary#attention-mask)
+
+             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+             [`PreTrainedTokenizer.__call__`] for details.
+
+             If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+             `past_key_values`).
+
+             If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+             and modify it to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+             information on the default strategy.
+
+             - 1 indicates the head is **not masked**,
+             - 0 indicates the head is **masked**.
+         position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
+             config.n_positions - 1]`.
+
+             [What are position IDs?](../glossary#position-ids)
+         past_key_values (`Cache`, *optional*):
+             Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+             blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+             returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+             It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+             If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+             have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+             of shape `(batch_size, sequence_length)`.
+         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+             model's internal embedding lookup matrix.
+         use_cache (`bool`, *optional*):
+             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+             `past_key_values`).
+         output_attentions (`bool`, *optional*):
+             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+             tensors for more detail.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+             more detail.
+         return_dict (`bool`, *optional*):
+             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+         cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+             Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+             this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+             the complete sequence length.
+ """
+
+
+ @add_start_docstrings(
+     "The bare Olmo Model outputting raw hidden-states without any specific head on top.",
+     OLMO_START_DOCSTRING,
+ )
+ class OlmoModel(OlmoPreTrainedModel):
+     """
+     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OlmoDecoderLayer`]
+
+     Args:
+         config: OlmoConfig
+     """
+
+     def __init__(self, config: OlmoConfig):
+         super().__init__(config)
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+         self.layers = nn.ModuleList(
+             [OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+         )
+         self.norm = OlmoLayerNorm(config.hidden_size)
+         self.rotary_emb = OlmoRotaryEmbedding(config=config)
+         self.gradient_checkpointing = False
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.embed_tokens = value
+
+     @can_return_tuple
+     @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING)
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Cache] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+     ) -> BaseModelOutputWithPast:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         if self.gradient_checkpointing and self.training and use_cache:
+             logger.warning_once(
+                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+             )
+             use_cache = False
+
+         # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
+         if not isinstance(past_key_values, (type(None), Cache)):
+             raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+
+         if use_cache and past_key_values is None:
+             past_key_values = DynamicCache()
+
+         if cache_position is None:
+             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+             cache_position = torch.arange(
+                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+             )
+
+         if position_ids is None:
+             position_ids = cache_position.unsqueeze(0)
+
+         causal_mask = self._update_causal_mask(
+             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+         )
+
+         hidden_states = inputs_embeds
+
+         # create position embeddings to be shared across the decoder layers
+         position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+
+         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             layer_outputs = decoder_layer(
+                 hidden_states,
+                 attention_mask=causal_mask,
+                 position_ids=position_ids,
+                 past_key_value=past_key_values,
+                 output_attentions=output_attentions,
+                 use_cache=use_cache,
+                 cache_position=cache_position,
+                 position_embeddings=position_embeddings,
+                 **flash_attn_kwargs,
+             )
+
+             hidden_states = layer_outputs[0]
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+         hidden_states = self.norm(hidden_states)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=past_key_values if use_cache else None,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+     def _update_causal_mask(
+         self,
+         attention_mask: Union[torch.Tensor, "BlockMask"],
+         input_tensor: torch.Tensor,
+         cache_position: torch.Tensor,
+         past_key_values: Cache,
+         output_attentions: bool = False,
+     ):
+         if self.config._attn_implementation == "flash_attention_2":
+             if attention_mask is not None and (attention_mask == 0.0).any():
+                 return attention_mask
+             return None
+         if self.config._attn_implementation == "flex_attention":
+             if isinstance(attention_mask, torch.Tensor):
+                 attention_mask = make_flex_block_causal_mask(attention_mask)
+             return attention_mask
+
+         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+         # to infer the attention mask.
+         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+         using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+         # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+         if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+             if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                 attention_mask,
+                 inputs_embeds=input_tensor,
+                 past_key_values_length=past_seen_tokens,
+                 is_training=self.training,
+             ):
+                 return None
+
+         dtype = input_tensor.dtype
+         sequence_length = input_tensor.shape[1]
+         if using_compilable_cache:
+             target_length = past_key_values.get_max_cache_shape()
+         else:
+             target_length = (
+                 attention_mask.shape[-1]
+                 if isinstance(attention_mask, torch.Tensor)
+                 else past_seen_tokens + sequence_length + 1
+             )
+
+         # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+         causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+             attention_mask,
+             sequence_length=sequence_length,
+             target_length=target_length,
+             dtype=dtype,
+             cache_position=cache_position,
+             batch_size=input_tensor.shape[0],
+         )
+
+         if (
+             self.config._attn_implementation == "sdpa"
+             and attention_mask is not None
+             and attention_mask.device.type in ["cuda", "xpu", "npu"]
+             and not output_attentions
+         ):
+             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+             # Details: https://github.com/pytorch/pytorch/issues/110213
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+         return causal_mask
+
+     @staticmethod
+     def _prepare_4d_causal_attention_mask_with_cache_position(
+         attention_mask: torch.Tensor,
+         sequence_length: int,
+         target_length: int,
+         dtype: torch.dtype,
+         cache_position: torch.Tensor,
+         batch_size: int,
+         **kwargs,
+     ):
+         """
+         Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+         `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+         Args:
+             attention_mask (`torch.Tensor`):
+                 A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+                 `(batch_size, 1, query_length, key_value_length)`.
+             sequence_length (`int`):
+                 The sequence length being processed.
+             target_length (`int`):
+                 The target length: when generating with static cache, the mask should be as long as the static cache,
+                 to account for the 0 padding, the part of the cache that is not filled yet.
+             dtype (`torch.dtype`):
+                 The dtype to use for the 4D attention mask.
+             cache_position (`torch.Tensor`):
+                 Indices depicting the position of the input sequence tokens in the sequence.
+             batch_size (`torch.Tensor`):
+                 Batch size.
+         """
+         if attention_mask is not None and attention_mask.dim() == 4:
+             # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+             causal_mask = attention_mask
+         else:
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = torch.full(
+                 (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+             )
+             if sequence_length != 1:
+                 causal_mask = torch.triu(causal_mask, diagonal=1)
+             causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+             if attention_mask is not None:
+                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                 mask_length = attention_mask.shape[-1]
+                 padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                     causal_mask.device
+                 )
+                 padding_mask = padding_mask == 0
+                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                     padding_mask, min_dtype
+                 )
+
+         return causal_mask
+
+
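`_prepare_4d_causal_attention_mask_with_cache_position` combines two constraints: a query may only attend to keys at or before its own cache position, and padded keys are always masked. A boolean sketch of that logic (the real code encodes "masked" as `torch.finfo(dtype).min` in an additive 4D float mask; the function name here is illustrative):

```python
def build_causal_mask(seq_len, past_len, padding):
    """Allowed[q][k] over a key axis of length past_len + seq_len.

    padding: per-key flags, 1 = real token, 0 = pad.
    """
    total = past_len + seq_len
    mask = []
    for q in range(seq_len):
        row = []
        for k in range(total):
            causal_ok = k <= past_len + q  # query q sits at cache position past_len + q
            not_pad = padding[k] == 1      # padded keys are never attended to
            row.append(causal_ok and not_pad)
        mask.append(row)
    return mask
```

With a one-token cache and a trailing pad, every query can see the cached key and itself, but never the pad column.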
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
+ class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
+     _tied_weights_keys = ["lm_head.weight"]
+     _tp_plan = {"lm_head": "colwise_rep"}
+     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = OlmoModel(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     @can_return_tuple
+     @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING)
+     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Cache] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         logits_to_keep: Union[int, torch.Tensor] = 0,
+         **kwargs: Unpack[KwargsForCausalLM],
+     ) -> CausalLMOutputWithPast:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+         logits_to_keep (`int` or `torch.Tensor`, *optional*):
+             If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+             If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+             This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+         Returns:
+
+         Example:
+
+         ```python
+         >>> from transformers import AutoTokenizer, OlmoForCausalLM
+
+         >>> model = OlmoForCausalLM.from_pretrained("allenai/OLMo-7B-hf")
+         >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-7B-hf")
+
+         >>> prompt = "Hey, are you conscious? Can you talk to me?"
+         >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+         >>> # Generate
+         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+         ```"""
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+
+         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+         outputs: BaseModelOutputWithPast = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             cache_position=cache_position,
+             **kwargs,
+         )
+
+         hidden_states = outputs.last_hidden_state
+         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+         logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+         loss = None
+         if labels is not None:
+             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
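The `slice_indices` trick in `forward` relies on `slice(-0, None)` being identical to `slice(0, None)`, so `logits_to_keep=0` naturally selects the whole sequence while any positive integer keeps only the trailing positions. A list-based sketch of that selection (the helper name is illustrative):

```python
def slice_for_logits(rows, logits_to_keep):
    if isinstance(logits_to_keep, int):
        # int: keep the last `logits_to_keep` rows; 0 is a special case meaning "all"
        # (slice(-0, None) == slice(0, None), so one expression covers both)
        return rows[slice(-logits_to_keep, None)]
    # sequence of indices: gather those positions explicitly (packed-tensor use case)
    return [rows[i] for i in logits_to_keep]
```

During generation only the last position's logits are needed, so `logits_to_keep=1` avoids materializing a `(batch, seq_len, vocab)` logits tensor.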
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class OlmoMoERouter(nn.Module):
+     """
+     Router module that uses random importance sampling instead of deterministic top-k.
+
+     This router computes logits for each expert, converts them to probabilities,
+     and then randomly samples experts based on these probabilities.
+     """
+     def __init__(self, config):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.num_experts = config.num_experts
+         self.router = nn.Linear(self.hidden_size, self.num_experts, bias=False)
+         self.top_k = config.num_selected_experts
+         self.temperature = getattr(config, "router_temperature", 1.0)
+
+     def forward(self, hidden_states):
+         """
+         Args:
+             hidden_states: [batch_size, sequence_length, hidden_size]
+
+         Returns:
+             routing_weights: [batch_size, sequence_length, top_k]
+             routing_indices: [batch_size, sequence_length, top_k]
+         """
+         batch_size, sequence_length, _ = hidden_states.shape
+
+         # Compute router logits and apply temperature
+         router_logits = self.router(hidden_states) / self.temperature  # [batch_size, sequence_length, num_experts]
+
+         # Convert to probabilities using softmax
+         router_probs = F.softmax(router_logits, dim=-1)  # [batch_size, sequence_length, num_experts]
+
+         # For random importance sampling, we:
+         # 1. Add Gumbel noise to the log probabilities to induce randomness
+         # 2. Sample top-k experts using the perturbed probabilities
+
+         # Add Gumbel noise
+         gumbel_noise = -torch.log(-torch.log(torch.rand_like(router_probs) + 1e-10) + 1e-10)
+         perturbed_logits = torch.log(router_probs + 1e-10) + gumbel_noise
+
+         # Sample top-k experts based on the perturbed scores; the perturbed values are only
+         # used to pick indices -- the mixing weights below come from the clean probabilities
+         _, routing_indices = torch.topk(perturbed_logits, self.top_k, dim=-1)
+
+         # Gather the selected experts' (unperturbed) probabilities and re-normalize them
+         routing_weights = router_probs.gather(-1, routing_indices)
+         routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+
+         return routing_weights, routing_indices
+
+
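The router's sampling step is the Gumbel-top-k trick: adding independent Gumbel noise to log-probabilities and taking the top-k yields a random sample of k distinct experts whose selection frequencies track the router probabilities. A pure-Python sketch of why this works (function name is illustrative; the `1e-12` guards mirror the `1e-10` guards above):

```python
import math
import random

def gumbel_topk(probs, k, rng):
    """Sample k distinct indices whose marginal selection follows `probs`."""
    perturbed = [
        # log p_i + Gumbel(0, 1) noise, with small epsilons to avoid log(0)
        math.log(p + 1e-10) - math.log(-math.log(rng.random() + 1e-12) + 1e-12)
        for p in probs
    ]
    # indices of the k largest perturbed scores, like torch.topk(..., k).indices
    return sorted(range(len(probs)), key=lambda i: perturbed[i], reverse=True)[:k]
```

With k = 1 this reduces to ordinary categorical sampling, which a quick Monte Carlo run confirms: each index is picked roughly in proportion to its probability.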
+ class OlmoExpertMLP(nn.Module):
+     """
+     Expert MLP module similar to OlmoMLP but used in the MoE architecture.
+     """
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+         self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(self, x):
+         down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+         return down_proj
+
+
+ class OlmoMixtureOfExperts(nn.Module):
+     """
+     Mixture of Experts layer that replaces the standard MLP in OLMo.
+     """
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.num_experts = config.num_experts
+         self.num_selected_experts = config.num_selected_experts  # top_k
+
+         # Create router
+         self.router = OlmoMoERouter(config)
+
+         # Create experts
+         self.experts = nn.ModuleList([OlmoExpertMLP(config) for _ in range(self.num_experts)])
+
+         # Expert capacity factor (to avoid load balancing issues)
+         self.capacity_factor = getattr(config, "expert_capacity_factor", 1.0)
+
+     def forward(self, hidden_states):
+         """
+         Args:
+             hidden_states: [batch_size, sequence_length, hidden_size]
+
+         Returns:
+             outputs: [batch_size, sequence_length, hidden_size]
+         """
+         batch_size, sequence_length, hidden_size = hidden_states.shape
+
+         # Get routing weights and indices
+         routing_weights, routing_indices = self.router(hidden_states)
+
+         # Reshape tensors for processing
+         flat_hidden_states = hidden_states.reshape(-1, hidden_size)  # [batch_size * sequence_length, hidden_size]
+
+         # Initialize expert outputs
+         final_output = torch.zeros_like(flat_hidden_states)
+
+         # For each expert, compute its contribution
+         for expert_idx in range(self.num_experts):
+             # Create a mask to identify which tokens use this expert
+             expert_mask = (routing_indices == expert_idx).any(dim=-1).reshape(-1)
+
+             if not expert_mask.any():
+                 continue  # Skip if no token routes to this expert
+
+             # Get the hidden states for tokens routed to this expert
+             expert_inputs = flat_hidden_states[expert_mask]
+
+             # Process these hidden states through the expert
+             expert_outputs = self.experts[expert_idx](expert_inputs)
+
+             # Gather this expert's routing weights (one slot per selected token, in token order,
+             # since top-k indices are distinct within a token)
+             expert_weights = routing_weights[routing_indices == expert_idx].reshape(-1, 1)
+
+             # Multiply outputs by the routing weights
+             weighted_outputs = expert_outputs * expert_weights
+
+             # Combine the expert outputs into the final output tensor
+             final_output[expert_mask] += weighted_outputs
+
+         # Reshape back to original dimensions
+         final_output = final_output.reshape(batch_size, sequence_length, hidden_size)
+
+         return final_output
+
+
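Per token, the loop above amounts to a weighted sum of the selected experts' outputs under the re-normalized routing weights. A scalar sketch of that combination step (names and the toy experts are illustrative, not the model's experts):

```python
def moe_combine(tokens, routing, experts):
    """tokens: list of scalars; routing: per-token list of (expert_idx, weight);
    experts: list of callables acting as toy experts."""
    out = []
    for x, choices in zip(tokens, routing):
        # weighted sum of the selected experts' outputs, as in OlmoMixtureOfExperts.forward
        out.append(sum(w * experts[i](x) for i, w in choices))
    return out
```

So a token routed to experts 0 and 1 with equal weight receives the average of both outputs, while a token routed to a single expert receives that expert's output unscaled.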
+ # Modified OlmoDecoderLayer to use MoE instead of standard MLP
+ class OlmoMoEDecoderLayer(GradientCheckpointingLayer):
+     def __init__(self, config: OlmoConfig, layer_idx: int):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx)
+
+         # Use MoE instead of standard MLP
+         self.mlp = OlmoMixtureOfExperts(config)
+         self.input_layernorm = OlmoLayerNorm(config.hidden_size)
+         self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Cache] = None,
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+         cache_position: Optional[torch.LongTensor] = None,
+         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         **kwargs: Unpack[FlashAttentionKwargs],
+     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+
+         # Self Attention
+         hidden_states, self_attn_weights = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+             cache_position=cache_position,
+             position_embeddings=position_embeddings,
+             **kwargs,
+         )
+         hidden_states = residual + hidden_states
+
+         # MoE instead of Fully Connected
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+         if output_attentions:
+             outputs += (self_attn_weights,)
+
+         return outputs
+
+
+ # Modified OlmoConfig to include MoE-specific parameters
+ class OlmoMoEConfig(OlmoConfig):
+     def __init__(
+         self,
+         num_experts=8,
+         num_selected_experts=2,
+         expert_capacity_factor=1.0,
+         router_temperature=0.1,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.num_experts = num_experts
+         self.num_selected_experts = num_selected_experts
+         self.expert_capacity_factor = expert_capacity_factor
+         self.router_temperature = router_temperature
+
+
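`OlmoMoEConfig` layers the MoE hyperparameters on top of the base config and forwards everything else through `**kwargs`, so base fields like `hidden_size` still reach `OlmoConfig.__init__` untouched. A minimal mimic of the pattern (the classes here are illustrative stand-ins, not the transformers classes):

```python
class BaseConfig:
    def __init__(self, hidden_size=64, **kwargs):
        self.hidden_size = hidden_size

class MoEConfig(BaseConfig):
    def __init__(self, num_experts=8, num_selected_experts=2, **kwargs):
        super().__init__(**kwargs)  # base hyperparameters pass through untouched
        self.num_experts = num_experts
        self.num_selected_experts = num_selected_experts
```

Unspecified MoE parameters fall back to their defaults, so a caller can override only what differs from the base model.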
+ # Modified OlmoModel to use MoE decoder layers
+ class OlmoMoEModel(OlmoModel):
+     def __init__(self, config: OlmoMoEConfig):
+         OlmoPreTrainedModel.__init__(self, config)
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+         # Use MoE decoder layers
+         self.layers = nn.ModuleList(
+             [OlmoMoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+         )
+         self.norm = OlmoLayerNorm(config.hidden_size)
+         self.rotary_emb = OlmoRotaryEmbedding(config=config)
+         self.gradient_checkpointing = False
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+
1036
+ # Modified OlmoForCausalLM to use MoE model
1037
+ class OlmoMoEForCausalLM(OlmoForCausalLM):
1038
+ def __init__(self, config):
1039
+ OlmoPreTrainedModel.__init__(self, config)
1040
+ self.model = OlmoMoEModel(config)
1041
+ self.vocab_size = config.vocab_size
1042
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1043
+
1044
+ # Initialize weights and apply final processing
1045
+ self.post_init()
1046
+
+__all__ = ["OlmoForCausalLM", "OlmoModel", "OlmoPreTrainedModel", "OlmoMoEForCausalLM", "OlmoMoEModel", "OlmoMoEConfig"]
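The `OlmoMoEConfig` above introduces `num_experts`, `num_selected_experts`, and `router_temperature`, while the routing itself lives inside the MoE MLP (not shown in this hunk). As a rough illustration of what those fields control, here is a minimal temperature-scaled top-k routing sketch in pure Python; the `route` function and its inputs are illustrative assumptions, not code from this commit:

```python
import math

def route(logits, num_selected_experts=2, router_temperature=0.1):
    """Temperature-scaled softmax over per-expert router logits, then top-k.

    Returns (selected_expert_indices, normalized_mixture_weights).
    """
    scaled = [l / router_temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the k most probable experts
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:num_selected_experts]
    # Renormalize so the selected experts' weights sum to 1
    denom = sum(probs[i] for i in top)
    return top, [probs[i] / denom for i in top]
```

A low `router_temperature` such as the default `0.1` sharpens the distribution, so the top-`num_selected_experts` weights dominate the mixture.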
requirements.txt ADDED
@@ -0,0 +1,9 @@
+torch>=2.0.0
+transformers>=4.34.0
+accelerate>=0.25.0
+datasets>=2.14.0
+tqdm>=4.66.0
+bitsandbytes>=0.41.0   # For 8-bit training if needed
+sentencepiece>=0.1.99  # For tokenization
+protobuf>=4.23.4       # For datasets loading
+tensorboard>=2.13.0    # For training monitoring
shellcommands.txt ADDED
@@ -0,0 +1,3 @@
+conda activate rlmoe
+cd SkipMoE
+python train.py
train.py ADDED
@@ -0,0 +1,130 @@
+#!/usr/bin/env python
+# train.py
+# Builds the argument list and launches train_olmoe_adapter.py.
+"""
+Run script for fine-tuning OlmoE with adapters on specific text domains.
+Handles argument parsing and configuration.
+"""
+
+import os
+import sys
+from dataclasses import dataclass, field
+
+from transformers import HfArgumentParser
+
+
+@dataclass
+class ScriptArguments:
+    """
+    Arguments for the run script that aren't covered by TrainingArguments.
+    """
+    model_path: str = field(
+        default="allenai/OLMo-7B-Instruct",
+        metadata={"help": "Path to the model to fine-tune"}
+    )
+    output_dir: str = field(
+        default="./output_olmoe_adapter",
+        metadata={"help": "Directory to save the model and logs"}
+    )
+    adapter_size: int = field(
+        default=64,
+        metadata={"help": "Size of the adapter layers"}
+    )
+    dataset_name: str = field(
+        default="mlfoundations/dclm-baseline-1.0",
+        metadata={"help": "Name of the dataset to use"}
+    )
+    max_steps: int = field(
+        default=10000,
+        metadata={"help": "Maximum number of training steps"}
+    )
+    learning_rate: float = field(
+        default=5e-5,
+        metadata={"help": "Learning rate for fine-tuning"}
+    )
+    per_device_batch_size: int = field(
+        default=8,
+        metadata={"help": "Batch size per device"}
+    )
+    gradient_accumulation_steps: int = field(
+        default=1,
+        metadata={"help": "Number of steps to accumulate gradients"}
+    )
+    # use_8bit: bool = field(
+    #     default=False,
+    #     metadata={"help": "Whether to use 8-bit precision"}
+    # )
+    # use_4bit: bool = field(
+    #     default=False,
+    #     metadata={"help": "Whether to use 4-bit precision"}
+    # )
+
+
+def main():
+    # Parse command-line arguments
+    parser = HfArgumentParser(ScriptArguments)
+    args = parser.parse_args_into_dataclasses()[0]
+
+    # Create output directory
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Prepare command for training
+    cmd = [
+        "python",
+        "train_olmoe_adapter.py",
+
+        # Model arguments
+        f"--model_name_or_path={args.model_path}",
+        f"--adapter_size={args.adapter_size}",
+        "--freeze_base_model=True",  # Always freeze the base model
+        f"--checkpoint_dir={args.output_dir}",
+
+        # Data arguments
+        f"--dataset_name={args.dataset_name}",
+        "--streaming=True",  # Always stream for large datasets
+        "--streaming_buffer_size=8192",
+        "--max_seq_length=1024",
+
+        # Training arguments
+        f"--output_dir={args.output_dir}",
+        f"--per_device_train_batch_size={args.per_device_batch_size}",
+        f"--gradient_accumulation_steps={args.gradient_accumulation_steps}",
+        f"--learning_rate={args.learning_rate}",
+        f"--max_steps={args.max_steps}",
+        "--warmup_steps=500",
+        "--logging_steps=10",
+        "--save_steps=1000",
+        "--save_total_limit=2",
+        "--dataloader_num_workers=4",
+        "--seed=42",
+    ]
+
+    # Add precision flags if needed
+    # if args.use_8bit:
+    #     cmd.append("--load_in_8bit")
+    # if args.use_4bit:
+    #     cmd.append("--load_in_4bit")
+
+    # Print the command for logging
+    cmd_str = " ".join(cmd)
+    print(f"Running command: {cmd_str}")
+
+    # Execute the training script
+    os.environ["PYTHONPATH"] = os.getcwd()
+    ret = os.system(cmd_str)
+
+    if ret != 0:
+        print(f"Training failed with exit code {ret}")
+        sys.exit(ret)
+
+    print("Training completed successfully!")
+
+
+if __name__ == "__main__":
+    main()
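One caveat worth noting about `train.py` above: joining `cmd` with plain spaces and handing the string to `os.system` breaks as soon as an argument value contains a space. If that ever matters, a small standard-library sketch (the `run_command` helper is illustrative, not part of this commit) quotes the log line with `shlex` and skips the shell entirely via `subprocess`:

```python
import shlex
import subprocess

def run_command(cmd):
    """Run a command list: shell-quote it for logging, but exec without a shell."""
    print(f"Running command: {shlex.join(cmd)}")  # quoted string, safe to copy-paste
    return subprocess.run(cmd, check=False).returncode
```

Passing the list directly to `subprocess.run` means each argument reaches the child process intact, no matter what characters it contains.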
train_olmoe_adapter.py ADDED
@@ -0,0 +1,404 @@
+#!/usr/bin/env python
+# train_olmoe_adapter.py
+"""
+Training script for the OlmoE model with adapters on the mlfoundations/dclm-baseline-1.0 dataset.
+This script demonstrates parameter-efficient fine-tuning using adapters.
+"""
+
+import os
+import logging
+from dataclasses import dataclass, field
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader, IterableDataset
+from torch.optim import AdamW
+
+from datasets import load_dataset
+from transformers import (
+    OlmoConfig,
+    OlmoForCausalLM,
+    AutoTokenizer,
+    DataCollatorForLanguageModeling,
+    HfArgumentParser,
+    TrainingArguments,
+    set_seed,
+    get_scheduler,
+)
+from tqdm import tqdm
+from accelerate import Accelerator
+
+from modeling_olmoe import (
+    OlmoEWithAdaptersForCausalLM,
+    OlmoEForCausalLM,
+)
+
+# Set up logging
+logger = logging.getLogger(__name__)
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%m/%d/%Y %H:%M:%S",
+    level=logging.INFO,
+)
+
+
+@dataclass
+class ModelArguments:
+    """Arguments for model configuration."""
+    model_name_or_path: str = field(
+        default="allenai/OLMo-7B-Instruct",
+        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
+    )
+    adapter_size: int = field(
+        default=64,
+        metadata={"help": "Size of the adapter layers"}
+    )
+    freeze_base_model: bool = field(
+        default=True,
+        metadata={"help": "Whether to freeze all parameters except the adapters"}
+    )
+    checkpoint_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to save model checkpoints"}
+    )
+
+
+@dataclass
+class DataArguments:
+    """Arguments for dataset configuration."""
+    dataset_name: str = field(
+        default="mlfoundations/dclm-baseline-1.0",
+        metadata={"help": "Dataset name or path for training"}
+    )
+    streaming: bool = field(
+        default=True,
+        metadata={"help": "Whether to stream the dataset"}
+    )
+    streaming_buffer_size: int = field(
+        default=8192,
+        metadata={"help": "Buffer size for streaming dataset"}
+    )
+    max_seq_length: int = field(
+        default=1024,
+        metadata={"help": "Maximum sequence length for training"}
+    )
+    preprocessing_num_workers: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of workers for preprocessing"}
+    )
+    text_column_name: str = field(
+        default="text",
+        metadata={"help": "Column name for text data"}
+    )
+
+class StreamingTextDataset(IterableDataset):
+    """Dataset for streaming text data."""
+
+    def __init__(
+        self,
+        dataset_name: str,
+        tokenizer,
+        max_seq_length: int,
+        streaming: bool = True,
+        text_column_name: str = "text",
+        buffer_size: int = 8192,
+        split: str = "train",
+    ):
+        self.tokenizer = tokenizer
+        self.max_seq_length = max_seq_length
+        self.text_column_name = text_column_name
+
+        # Load dataset in streaming mode
+        self.dataset = load_dataset(
+            dataset_name,
+            split=split,
+            streaming=streaming,
+        )
+        if streaming:
+            # Shuffle through a fixed-size buffer while streaming
+            self.dataset = self.dataset.shuffle(buffer_size=buffer_size)
+
+    def __iter__(self):
+        buffer = []
+
+        for example in self.dataset:
+            text = example[self.text_column_name]
+            if not text or len(text.strip()) == 0:
+                continue
+
+            tokenized = self.tokenizer(
+                text,
+                truncation=False,
+                return_attention_mask=False,
+                return_token_type_ids=False,
+                add_special_tokens=False,
+            )
+
+            buffer.extend(tokenized["input_ids"])
+
+            # Yield complete sequences; the remainder stays in the buffer
+            while len(buffer) >= self.max_seq_length:
+                yield {
+                    "input_ids": torch.tensor(buffer[:self.max_seq_length], dtype=torch.long),
+                    "labels": torch.tensor(buffer[:self.max_seq_length], dtype=torch.long),
+                }
+                buffer = buffer[self.max_seq_length:]
+
+
+def create_optimizer_and_scheduler(
+    model: nn.Module,
+    args: TrainingArguments,
+    num_training_steps: int,
+) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
+    """Create optimizer and learning rate scheduler."""
+
+    # Use only trainable parameters when adapters are trained with a frozen base model
+    if hasattr(model, "get_trainable_parameters"):
+        optimizer_params = model.get_trainable_parameters()
+        logger.info(f"Training with {len(optimizer_params)} trainable parameters")
+    else:
+        # No parameter filtering - take all parameters that require grad
+        optimizer_params = [p for p in model.parameters() if p.requires_grad]
+        logger.info(f"Training with {len(optimizer_params)} parameters")
+
+    # Create optimizer
+    optimizer = AdamW(
+        optimizer_params,
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        eps=args.adam_epsilon,
+        weight_decay=args.weight_decay,
+    )
+
+    # Create scheduler
+    scheduler = get_scheduler(
+        name=args.lr_scheduler_type,
+        optimizer=optimizer,
+        num_warmup_steps=args.warmup_steps,
+        num_training_steps=num_training_steps,
+    )
+
+    return optimizer, scheduler
+
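`StreamingTextDataset.__iter__` above packs token ids from consecutive documents into a rolling buffer and emits fixed-length blocks, discarding any short remainder at end of stream. The packing logic can be isolated and checked with a toy example (the `pack` helper below is an illustration, not code from this commit):

```python
def pack(token_streams, max_seq_length):
    """Concatenate per-document token lists and cut them into fixed-length blocks.

    Leftover tokens shorter than max_seq_length stay in the buffer and are
    dropped when the stream ends, mirroring the dataset's __iter__.
    """
    buffer = []
    for ids in token_streams:
        buffer.extend(ids)
        while len(buffer) >= max_seq_length:
            yield buffer[:max_seq_length]
            buffer = buffer[max_seq_length:]
```

With `max_seq_length=4`, the three short "documents" `[1,2,3]`, `[4,5]`, `[6,7,8,9]` yield two full blocks and silently drop the one-token remainder.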
+def train(
+    model_args: ModelArguments,
+    data_args: DataArguments,
+    training_args: TrainingArguments,
+):
+    """Main training function."""
+
+    # Set up accelerator
+    accelerator = Accelerator(
+        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
+        mixed_precision="fp16" if training_args.fp16 else "bf16" if training_args.bf16 else "no",
+    )
+
+    # Log information about the training setup
+    logger.info(accelerator.state)
+    if accelerator.is_local_main_process:
+        logger.info(f"Model arguments: {model_args}")
+        logger.info(f"Data arguments: {data_args}")
+        logger.info(f"Training arguments: {training_args}")
+
+    # Set seed for reproducibility
+    set_seed(training_args.seed)
+
+    # Load tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
+
+    # Ensure the tokenizer has a padding token set
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # Load model config and update it with the adapter size
+    config = OlmoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
+    config.adapter_size = model_args.adapter_size
+
+    # Load the base model
+    logger.info(f"Loading OlmoE model with adapters from {model_args.model_name_or_path}")
+    base_model = OlmoForCausalLM.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
+
+    # Create the adapter model from the base model weights
+    model = OlmoEWithAdaptersForCausalLM(config)
+
+    # Copy weights from the base model to the adapter model.
+    # strict=False because the adapter parameters have no counterpart in the base checkpoint.
+    model.load_state_dict(base_model.state_dict(), strict=False)
+
+    # Freeze base model parameters if requested
+    if model_args.freeze_base_model:
+        logger.info("Freezing base model parameters")
+        model.freeze_base_model()
+
+    # Set up streaming dataset
+    logger.info(f"Loading dataset: {data_args.dataset_name}")
+    train_dataset = StreamingTextDataset(
+        dataset_name=data_args.dataset_name,
+        tokenizer=tokenizer,
+        max_seq_length=data_args.max_seq_length,
+        streaming=data_args.streaming,
+        buffer_size=data_args.streaming_buffer_size,
+        text_column_name=data_args.text_column_name,
+    )
+
+    # Data collator to handle batching
+    data_collator = DataCollatorForLanguageModeling(
+        tokenizer=tokenizer,
+        mlm=False,
+    )
+
+    # Create data loader
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_size=training_args.per_device_train_batch_size,
+        collate_fn=data_collator,
+        num_workers=data_args.preprocessing_num_workers or 0,
+    )
+
+    # For a streaming dataset the epoch length is unknown, so train for a fixed number of steps
+    num_update_steps_per_epoch = training_args.max_steps
+    num_training_steps = training_args.max_steps
+
+    # Create optimizer and scheduler
+    optimizer, lr_scheduler = create_optimizer_and_scheduler(
+        model=model,
+        args=training_args,
+        num_training_steps=num_training_steps,
+    )
+
+    # Prepare for distributed training with the accelerator
+    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        model, optimizer, train_dataloader, lr_scheduler
+    )
+
+    # Compute the total batch size for logging
+    total_batch_size = (
+        training_args.per_device_train_batch_size
+        * accelerator.num_processes
+        * training_args.gradient_accumulation_steps
+    )
+    logger.info(f"Total batch size (with parallelism & accumulation): {total_batch_size}")
+    logger.info(f"Number of training steps: {num_training_steps}")
+    logger.info(f"Number of warmup steps: {training_args.warmup_steps}")
+
+    # Keep track of training progress
+    progress_bar = tqdm(
+        range(num_training_steps),
+        disable=not accelerator.is_local_main_process,
+        desc="Training",
+    )
+    completed_steps = 0
+    starting_epoch = 0
+    global_step = 0
+
+    # Training loop
+    logger.info("Starting training...")
+    model.train()
+
+    for step, batch in enumerate(train_dataloader):
+        # Skip steps when resuming
+        if starting_epoch > 0 and step < starting_epoch * num_update_steps_per_epoch:
+            progress_bar.update(1)
+            continue
+
+        with accelerator.accumulate(model):
+            # Forward pass
+            outputs = model(**batch)
+            loss = outputs.loss
+
+            # Backward pass
+            accelerator.backward(loss)
+
+            # Update weights
+            optimizer.step()
+            lr_scheduler.step()
+            optimizer.zero_grad()
+
+        # Update progress bar
+        progress_bar.update(1)
+        completed_steps += 1
+        global_step += 1
+
+        # Log metrics
+        if global_step % training_args.logging_steps == 0:
+            # Gather loss from all processes
+            loss_value = accelerator.gather(loss).mean().item()
+            logger.info(f"Step {global_step}: loss = {loss_value:.4f}, lr = {lr_scheduler.get_last_lr()[0]:.8f}")
+
+            # Log to the tracker (e.g. tensorboard) if one is configured
+            if accelerator.trackers and hasattr(accelerator.trackers[0], "store"):
+                accelerator.trackers[0].store({
+                    "loss": loss_value,
+                    "learning_rate": lr_scheduler.get_last_lr()[0],
+                    "step": global_step,
+                })
+
+        # Save checkpoint
+        if training_args.save_steps > 0 and global_step % training_args.save_steps == 0:
+            if model_args.checkpoint_dir is not None:
+                output_dir = os.path.join(model_args.checkpoint_dir, f"checkpoint-{global_step}")
+                accelerator.save_state(output_dir)
+                logger.info(f"Saved checkpoint to {output_dir}")
+
+                # Save the model separately
+                if accelerator.is_main_process:
+                    unwrapped_model = accelerator.unwrap_model(model)
+                    unwrapped_model.save_pretrained(
+                        output_dir,
+                        is_main_process=accelerator.is_main_process,
+                        save_function=accelerator.save,
+                    )
+                    tokenizer.save_pretrained(output_dir)
+
+        # Check if we've reached max steps
+        if completed_steps >= num_training_steps:
+            break
+
+    # Save final model
+    if model_args.checkpoint_dir is not None:
+        output_dir = os.path.join(model_args.checkpoint_dir, "final-model")
+        accelerator.save_state(output_dir)
+
+        # Save the model separately
+        if accelerator.is_main_process:
+            unwrapped_model = accelerator.unwrap_model(model)
+            unwrapped_model.save_pretrained(
+                output_dir,
+                is_main_process=accelerator.is_main_process,
+                save_function=accelerator.save,
+            )
+            tokenizer.save_pretrained(output_dir)
+
+        logger.info(f"Saved final model to {output_dir}")
+
+    logger.info("Training complete!")
+
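The `total_batch_size` logged in `train()` above is simply the product of per-device batch size, process count, and gradient-accumulation steps; multiplying by the sequence length gives tokens consumed per optimizer update. A quick sanity-check helper (the `effective_batch` name is illustrative, not part of this commit; the numbers below use this repo's `train.py` defaults):

```python
def effective_batch(per_device_batch, num_processes, grad_accum_steps, max_seq_length):
    """Sequences and tokens consumed per optimizer update."""
    sequences = per_device_batch * num_processes * grad_accum_steps
    return sequences, sequences * max_seq_length
```

With the defaults (batch 8, one process, no accumulation, 1024-token sequences), each optimizer step sees 8 sequences, i.e. 8192 tokens.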
+def main():
+    """Main entry point."""
+    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
+    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+    # Set up the output directory
+    if model_args.checkpoint_dir is None:
+        model_args.checkpoint_dir = training_args.output_dir
+    os.makedirs(model_args.checkpoint_dir, exist_ok=True)
+
+    # Run training
+    train(model_args, data_args, training_args)
+
+
+if __name__ == "__main__":
+    main()