SirajRLX committed on
Commit e527a65 · verified · 1 Parent(s): 89ec0cd

Add Training Scripts
trainer-kit/.gitignore ADDED
@@ -0,0 +1,55 @@
+ # Python
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ .Python
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+ pythonenv*
+ .pytest_cache/
+ ipynb_checkpoints/
+
+ # Virtualenv
+ .venv
+ venv/
+ virtualenv/
+ env/
+
+ # IDE
+ .vscode/
+ .idea/
+ *.sublime-workspace
+ *.sublime-project
+ *.swp
+ *.swo
+
+ # Build
+ build/
+ dist/
+ *.egg-info/
+
+ # Data and logs
+ data/
+ logs/
+ *.log
+ runs/**
+ output/**
+
+ # Jupyter
+ .ipynb_checkpoints/
+
+ # Environment
+ .env
+ .ENV
+ .env.bak
+ .venv
+ venv
+ venv.bak
+
+ # OS generated files
+ .DS_Store
+ Thumbs.db
trainer-kit/CPT-14b/README.md ADDED
@@ -0,0 +1,189 @@
+ # Trainer‑Kit: Config‑Driven CPT (LoRA / QLoRA) with Packing, Logging, Resume, and Merge
+
+ Trainer‑Kit is a small, config‑driven training runner for **continued pretraining (CPT)** on causal LMs.
+ It supports **LoRA** and **QLoRA**, data **packing** (strict or padding‑masked), **checkpointing + resume**, **JSONL logging**, periodic **eval with perplexity**, and an optional **merge** step to export a final merged model.
+
+ ---
+
+ ## What we built
+
+ ### ✅ Core goals implemented
+
+ * **CPT training loop** controlled entirely via a **YAML config**
+ * **Local model support** (load from filesystem) and optional **HF download** (if `repo_id` is a hub id)
+ * **JSONL datasets** for train (+ optional eval split)
+ * **CPT‑style token stream packing** into fixed‑length blocks
+ * **Two packing modes**
+   * `drop`: strict CPT, drop remainder tokens (preferred for real CPT)
+   * `pad`: pad the remainder to `block_size` and **mask loss** on padding (useful for small datasets / debugging)
+ * **Checkpointing + resume**
+   * `resume_from_checkpoint: "auto"` resumes from the latest checkpoint under `run_dir/checkpoints`
+ * **JSONL logs** written locally
+   * training logs: `run_dir/logs/train.jsonl`
+   * eval logs: `run_dir/logs/eval.jsonl`
+ * **Evaluation**
+   * logs `eval_loss` and the computed `perplexity = exp(eval_loss)` (with a safe overflow guard)
+ * **Adapter output**
+   * saves the final/best adapter to `run_dir/best_adapter`
+ * **Merge workflow**
+   * `--merge-only` merges an existing adapter later
+   * merging is done **on CPU** to avoid GPU OOM
+   * the merged model is stored under the configured merge output directory (relative to `run_dir` if a relative path)
+
+ ---
+
+ ## Repository layout (outputs)
+
+ A run produces the following structure under `run.run_dir`:
+
+ ```
+ runs/<run_name>/
+ ├─ checkpoints/          # trainer checkpoints (for resume)
+ ├─ best_adapter/         # saved LoRA adapter
+ ├─ logs/
+ │  ├─ train.jsonl        # step-wise training logs
+ │  └─ eval.jsonl         # eval logs (eval_loss + perplexity)
+ ├─ eval_final.json       # final eval metrics summary (if eval is enabled)
+ └─ config_resolved.yaml  # exact config used for the run
+ ```
+
+ If merge is used, the merged model is written to:
+
+ * `run_dir/<merge.output_dir>` if `merge.output_dir` is relative (e.g. `./merged_model`)
+ * or the absolute path if it is absolute.
+
+ ---
+
+ ## Supported training modes
+
+ ### 1) LoRA vs QLoRA (same script)
+
+ * **QLoRA** is used when `model.use_4bit: true`
+   * base weights are loaded in 4‑bit using bitsandbytes
+   * training updates only the LoRA parameters
+ * **LoRA** is used when `model.use_4bit: false`
+   * base weights are loaded in fp16/bf16 (as configured)
+   * training updates only the LoRA parameters
+
+ No "full finetune" mode is enabled by default in this runner.
+
+ ---
+
+ ## Data pipeline (CPT behavior)
+
+ ### Input format
+
+ * JSONL file where each line contains a text field (default `"text"`).
+ * Example: `{"text": "some training text..."}`
+
+ ### Packing (token stream → fixed blocks)
+
+ * Each sample is tokenized without truncation.
+ * An **EOS token is appended** per document to preserve boundaries.
+ * Token lists are concatenated and split into **fixed‑length blocks** of `data.block_size`.
+
+ Two modes:
+
+ * **`drop` (strict CPT):** remainder tokens that don't fill a full block are discarded.
+ * **`pad` (debug/small data):** the remainder is padded to `block_size`:
+   * `attention_mask = 0` for padded positions
+   * `labels = -100` for padded positions (loss masking)
+
+ This is what allowed training to proceed even with tiny dummy datasets at `block_size=1024`.
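The two packing modes above can be sketched as a small function over plain token-id lists (this mirrors the `group_texts` logic in `run_cpt.py`; `pad_id=0` here is only an illustrative choice):

```python
# Minimal sketch of CPT packing: split a concatenated token stream into
# fixed-length blocks. "drop" discards the remainder; "pad" keeps it,
# padding input_ids, zeroing attention_mask, and masking labels with -100.
def pack_tokens(token_ids, block_size, pack_mode="drop", pad_id=0):
    blocks = []
    full_len = (len(token_ids) // block_size) * block_size
    for i in range(0, full_len, block_size):
        chunk = token_ids[i:i + block_size]
        blocks.append({"input_ids": chunk,
                       "attention_mask": [1] * block_size,
                       "labels": chunk.copy()})
    remainder = len(token_ids) - full_len
    if remainder > 0 and pack_mode == "pad":
        pad_len = block_size - remainder
        chunk = token_ids[full_len:] + [pad_id] * pad_len
        labels = chunk.copy()
        labels[-pad_len:] = [-100] * pad_len  # loss mask on padding
        blocks.append({"input_ids": chunk,
                       "attention_mask": [1] * remainder + [0] * pad_len,
                       "labels": labels})
    return blocks
```

With 10 tokens and `block_size=4`, `drop` yields 2 blocks while `pad` yields 3, the last one loss-masked on its final two positions.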
+
+ ---
+
+ ## Logging
+
+ Trainer‑Kit writes **machine‑readable logs** in JSONL.
+
+ ### Training logs (`logs/train.jsonl`)
+
+ Entries include:
+
+ * `step`
+ * `loss`
+ * `grad_norm`
+ * `learning_rate`
+ * `progress_pct` (step progress when `max_steps` is active)
+ * an ETA estimate
+
+ ### Eval logs (`logs/eval.jsonl`)
+
+ Entries include:
+
+ * `eval_loss`
+ * `perplexity`
+
+ Notes:
+
+ * When using `max_steps`, the Trainer's internal `epoch` counter can grow unexpectedly on tiny datasets (because steps/epoch becomes ~1).
+   **Use `progress_pct` as the reliable indicator** for step‑based runs.
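The overflow guard mentioned above is just a clamp on the loss before exponentiating (this mirrors `_safe_exp` in `run_cpt.py`; the `50.0` cap comes from that helper):

```python
import math

# Clamp eval_loss before exponentiating so perplexity never overflows
# to inf, even for pathological losses early in training.
def safe_perplexity(eval_loss: float) -> float:
    return math.exp(min(float(eval_loss), 50.0))
```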
+
+ ---
+
+ ## Checkpointing and resume
+
+ The trainer saves checkpoints under `run_dir/checkpoints/`.
+
+ Resume options:
+
+ * `resume_from_checkpoint: "auto"` → picks the latest checkpoint automatically
+ * `resume_from_checkpoint: "/path/to/checkpoint"` → resumes from a specific checkpoint
+ * `resume_from_checkpoint: null` → fresh run
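The `"auto"` behavior can be sketched as picking the `checkpoint-<step>` directory with the highest step number (a simplified stand-in: `run_cpt.py` itself delegates this to `get_last_checkpoint` from `transformers.trainer_utils`):

```python
import re
from pathlib import Path
from typing import Optional

# Simplified sketch of resume_from_checkpoint: "auto" — return the
# checkpoint-<step> subdirectory with the largest step, or None.
def latest_checkpoint(checkpoints_dir: str) -> Optional[str]:
    pattern = re.compile(r"^checkpoint-(\d+)$")
    best_step, best_path = -1, None
    for p in Path(checkpoints_dir).glob("checkpoint-*"):
        m = pattern.match(p.name)
        if p.is_dir() and m and int(m.group(1)) > best_step:
            best_step, best_path = int(m.group(1)), str(p)
    return best_path
```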
+
+ ---
+
+ ## Merging adapters into a final model
+
+ Trainer‑Kit supports exporting a merged model:
+
+ ### Merge after training
+
+ * Enable merge in the config (`merge.enabled: true`)
+ * The script will:
+   1. save the adapter
+   2. free GPU memory
+   3. reload the base model on **CPU**
+   4. load the adapter
+   5. call `merge_and_unload()`
+   6. save the final merged model
+
+ ### Merge later
+
+ Run:
+
+ ```
+ python run_cpt.py --config config.yaml --merge-only
+ ```
+
+ This skips training and merges `run_dir/best_adapter` into the base model.
+
+ ---
+
+ ## How to run
+
+ ### Train
+
+ ```
+ python run_cpt.py --config config.yaml
+ ```
+
+ ### Merge only
+
+ ```
+ python run_cpt.py --config config.yaml --merge-only
+ ```
trainer-kit/CPT-14b/README_instruct.md ADDED
@@ -0,0 +1,168 @@
+ # Instruction Fine-Tuning Script
+
+ This script (`run_instruct.py`) is designed for fine-tuning language models on instruction-following tasks. It is based on the original CPT script but adapted specifically for instruction input/output pairs.
+
+ ## Key Differences from CPT
+
+ 1. **Data Format**: Handles structured instruction data with separate fields for instruction, input, and output
+ 2. **Formatting Options**: Supports multiple instruction formats (ChatML, Alpaca, custom templates)
+ 3. **No Text Packing**: Each example is treated as a complete instruction-response pair
+ 4. **Proper Loss Masking**: Loss is only computed on the response/output portion, not on the instruction and input
+ 5. **Automatic Label Creation**: Labels are automatically created with -100 masking for instruction tokens
+
+ ## Supported Data Formats
+
+ ### JSONL Structure
+ Each line should be a JSON object with the following fields:
+ ```json
+ {
+   "instruction": "Your instruction here",
+   "input": "Optional input context (can be empty string)",
+   "output": "Expected response"
+ }
+ ```
+
+ ### Formatting Options
+
+ #### 1. ChatML Format (Default)
+ Uses the model's chat template with system/user/assistant roles:
+ ```yaml
+ data:
+   format_type: "chatml"
+   system_prompt: "You are a helpful assistant."
+ ```
+
+ #### 2. Alpaca Format
+ Uses the classic Alpaca instruction format:
+ ```yaml
+ data:
+   format_type: "alpaca"
+ ```
+
+ #### 3. Custom Format
+ Define your own template:
+ ```yaml
+ data:
+   format_type: "custom"
+   custom_template: "Instruction: {instruction}\nInput: {input}\nOutput: {output}"
+ ```
+
+ ## Configuration
+
+ Key configuration options in `config_instruct.yaml`:
+
+ ### Data Configuration
+ ```yaml
+ data:
+   train_jsonl: "path/to/your/train.jsonl"
+   eval_jsonl: "path/to/your/eval.jsonl"  # optional
+   eval_split_ratio: 0.1                  # if no eval file provided
+
+   # Field names in your data
+   instruction_field: "instruction"
+   input_field: "input"
+   output_field: "output"
+
+   # Formatting
+   format_type: "chatml"  # "chatml" | "alpaca" | "custom"
+   system_prompt: "You are a helpful assistant."
+
+   # Tokenization
+   max_length: 2048
+ ```
+
+ ### Training Configuration
+ ```yaml
+ train:
+   max_steps: 100
+   num_train_epochs: 3
+   per_device_train_batch_size: 1
+   gradient_accumulation_steps: 16
+   learning_rate: 5e-5
+   # ... other training parameters
+ ```
+
+ ## Usage
+
+ ### Basic Usage
+ ```bash
+ python run_instruct.py --config config_instruct.yaml
+ ```
+
+ ### Merge Only (after training)
+ ```bash
+ python run_instruct.py --config config_instruct.yaml --merge-only
+ ```
+
+ ## Example Data Format
+
+ See `instruct_data.jsonl` for examples of the expected data format. Here are a few examples:
+
+ ```json
+ {"instruction": "What is the capital of France?", "input": "", "output": "The capital of France is Paris."}
+ {"instruction": "Translate the following English text to French.", "input": "Hello, how are you today?", "output": "Bonjour, comment allez-vous aujourd'hui?"}
+ {"instruction": "Write a Python function that calculates factorial.", "input": "", "output": "def factorial(n):\n    if n < 0:\n        raise ValueError(...)"}
+ ```
+
+ ## Key Features
+
+ 1. **Multiple Format Support**: ChatML, Alpaca, and custom templates
+ 2. **Flexible Field Mapping**: Configure custom field names for your data
+ 3. **Proper Loss Masking**: Only computes loss on the response portion
+ 4. **PEFT/LoRA Support**: Efficient fine-tuning with LoRA
+ 5. **Evaluation Support**: Automatic evaluation split or separate eval file
+ 6. **Checkpointing**: Resume training from checkpoints
+ 7. **Model Merging**: Merge trained adapters with the base model
+
+ ## Best Practices
+
+ 1. **Data Quality**: Ensure your instruction-response pairs are high-quality and consistent
+ 2. **Format Consistency**: Use the same format for training and inference
+ 3. **System Prompts**: Choose appropriate system prompts for your use case
+ 4. **Token Length**: Set an appropriate `max_length` based on your model and data
+ 5. **Batch Size**: Adjust batch size and gradient accumulation based on your GPU memory
+
+ ## Troubleshooting
+
+ ### Common Issues
+
+ 1. **CUDA Out of Memory**: Reduce batch size or enable 4-bit quantization
+ 2. **Slow Training**: Increase `gradient_accumulation_steps` or reduce `max_length`
+ 3. **Poor Quality**: Check data format consistency and quality
+ 4. **Tokenizer Issues**: Ensure your model has proper chat template support
+
+ ### Debug Mode
+ Add logging to see formatted examples:
+ ```python
+ # In format_instruction function, add:
+ print(f"Formatted: {formatted_text}")
+ ```
+
+ ## File Structure
+
+ ```
+ CPT/
+ ├── run_instruct.py        # Main instruction fine-tuning script
+ ├── config_instruct.yaml   # Configuration file
+ ├── instruct_data.jsonl    # Example instruction data
+ ├── README_instruct.md     # This documentation
+ └── runs/                  # Training outputs
+     └── instruct_run_v1/
+         ├── logs/
+         ├── checkpoints/
+         ├── best_adapter/
+         └── final_model/
+ ```
+
+ ## Migration from CPT
+
+ To migrate from the original CPT script:
+
+ 1. Convert your text data to instruction format
+ 2. Update your configuration file
+ 3. Choose appropriate formatting options
+ 4. Adjust training parameters (instruction fine-tuning typically needs fewer steps)
+
+ The script maintains the same CLI interface and most configuration options for easy migration.
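The loss-masking behavior the README describes (labels mirror the input ids, with `-100` on every prompt token so loss falls only on the response) can be sketched over plain token-id lists; real tokenizer output and the exact helper name in `run_instruct.py` are not shown here, so treat this as illustrative:

```python
# Build (input_ids, labels) for one instruction example: prompt tokens
# (instruction + input) get label -100, response tokens keep their ids.
def build_labels(prompt_ids, response_ids, max_length=2048):
    input_ids = (prompt_ids + response_ids)[:max_length]
    labels = [-100] * min(len(prompt_ids), max_length)
    labels += input_ids[len(labels):]  # response tokens keep their ids
    return input_ids, labels
```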
trainer-kit/CPT-14b/commands.md ADDED
@@ -0,0 +1,15 @@
+ ## Commands
+
+ Train (no merge):
+
+ ```bash
+ python run_cpt.py --config config.yaml
+ ```
+
+ Merge later:
+
+ ```bash
+ python run_cpt.py --config config.yaml --merge-only
+ ```
+
+ ---
trainer-kit/CPT-14b/config.yaml ADDED
@@ -0,0 +1,91 @@
+ run:
+   run_dir: "./runs/cpt_run_14b"
+   seed: 42
+
+ # WandB integration for experiment tracking
+ wandb:
+   enabled: true                  # Set to true to enable wandb logging
+   project: "cpt-training"        # WandB project name
+   entity: null                   # WandB entity/team (optional)
+   name: null                     # Run name (optional, will auto-generate if null)
+   tags: ["cpt-lora", "sft-14b"]  # List of tags for the run (e.g. ["lora", "qlora", "experiment-1"])
+   notes: null                    # Run description/notes (optional)
+
+ model:
+   # Local model path (no download)
+   repo_id: "/workspace/Models/Qwen2.5-Coder-14B"
+   revision: null
+
+   # Used only when repo_id is a HF repo (not a local path)
+   base_local_dir: "base_model"
+
+   trust_remote_code: true
+   tokenizer_use_fast: true
+   device_map: "auto"
+
+   torch_dtype: "bfloat16"  # "float16" | "bfloat16" | "float32"
+
+   # QLoRA
+   use_4bit: false
+   bnb_4bit_quant_type: "nf4"
+   bnb_4bit_use_double_quant: false
+   bnb_4bit_compute_dtype: "bfloat16"
+
+   # optional: "flash_attention_2" | "sdpa" | null
+   attn_implementation: null
+
+ data:
+   train_jsonl: "all_data_with_descriptions.jsonl"
+   eval_jsonl: null
+   eval_split_ratio: 0.1
+   text_field: "text"
+   block_size: 4096
+   shuffle: true
+   num_proc: 4
+
+   # ✅ NEW: packing behavior
+   # "drop" = strict CPT (drop remainder)
+   # "pad"  = pad remainder to block_size + loss mask (-100) + attention_mask=0
+   pack_mode: "pad"
+
+ peft:
+   enabled: true
+   r: 32
+   lora_alpha: 64
+   lora_dropout: 0.05
+   bias: "none"
+   target_modules: "auto"
+
+ train:
+   # max_steps: 1000
+   num_train_epochs: 2
+
+   per_device_train_batch_size: 1
+   per_device_eval_batch_size: 1
+   gradient_accumulation_steps: 16
+
+   learning_rate: 2e-5
+   weight_decay: 0.0
+   warmup_ratio: 0.1
+   lr_scheduler_type: "cosine"
+
+   optim: "paged_adamw_8bit"
+   max_grad_norm: 1.0
+   gradient_checkpointing: true
+
+   logging_steps: 1
+   save_strategy: "steps"
+   save_steps: 100
+   save_total_limit: 7
+
+   evaluation_strategy: "steps"
+   eval_steps: 50
+   load_best_model_at_end: true
+
+   resume_from_checkpoint: "auto"
+
+ merge:
+   enabled: true
+   merged_dtype: "float16"
+   max_shard_size: "2GB"
+   output_dir: "./merged_14b_cpt_lora"
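As the README notes, `merge.output_dir` resolves relative to `run_dir` when it is a relative path (so the config above writes the merged model under `runs/cpt_run_14b/`). A minimal sketch of that resolution rule; the helper name is illustrative, not taken from `run_cpt.py`:

```python
from pathlib import Path

# Resolve merge.output_dir: relative paths land under run_dir,
# absolute paths are used as-is.
def resolve_merge_dir(run_dir: str, output_dir: str) -> Path:
    out = Path(output_dir)
    return out if out.is_absolute() else Path(run_dir) / out
```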
trainer-kit/CPT-14b/dummy_data.jsonl ADDED
@@ -0,0 +1,6 @@
+ {"text": "This is a test sentence for the dummy dataset."}
+ {"text": "Another sentence to check if training works."}
+ {"text": "We need enough data to form a batch."}
+ {"text": "FSDP and LoRA are cool technologies."}
+ {"text": "Fine-tuning LLMs is fun and useful."}
+ {"text": "This is the end of the dummy dataset."}
trainer-kit/CPT-14b/requirements.txt ADDED
@@ -0,0 +1,25 @@
+ # Core
+ torch>=2.1.0
+ transformers>=4.41.0
+ datasets>=2.18.0
+ accelerate>=0.30.0
+
+ # PEFT / QLoRA
+ peft>=0.11.1
+ bitsandbytes>=0.43.1
+
+ # Hugging Face Hub (local + download support)
+ huggingface_hub>=0.23.0
+
+ # Config + utilities
+ pyyaml>=6.0
+ tqdm>=4.66.0
+
+ # Optional but recommended (tokenizers speed)
+ tokenizers>=0.15.0
+ safetensors>=0.4.2
+ # Optional (for eval)
+ rouge-score>=0.1.2
+
+ # Experiment tracking
+ wandb>=0.16.0
trainer-kit/CPT-14b/run_cpt.py ADDED
@@ -0,0 +1,772 @@
1
+ import argparse
2
+ import json
3
+ import inspect # Added for Transformers version compatibility
4
+ import math
5
+ import time
6
+ from pathlib import Path
7
+ from typing import Any, Dict, Optional, Tuple, List
8
+
9
+ import torch
10
+ import yaml
11
+ from datasets import load_dataset, DatasetDict
12
+ from huggingface_hub import snapshot_download
13
+ from transformers import (
14
+ AutoModelForCausalLM,
15
+ AutoTokenizer,
16
+ PreTrainedTokenizerFast,
17
+ TrainingArguments,
18
+ Trainer,
19
+ TrainerCallback,
20
+ default_data_collator,
21
+ set_seed,
22
+ )
23
+ from transformers.trainer_utils import get_last_checkpoint
24
+ from peft import (
25
+ LoraConfig,
26
+ get_peft_model,
27
+ prepare_model_for_kbit_training,
28
+ PeftModel,
29
+ )
30
+
31
+ try:
32
+ from transformers import BitsAndBytesConfig
33
+ except ImportError: # older transformers
34
+ BitsAndBytesConfig = None
35
+
36
+ try:
37
+ import wandb
38
+ WANDB_AVAILABLE = True
39
+ except ImportError:
40
+ WANDB_AVAILABLE = False
41
+ wandb = None
42
+
43
+
44
+ # --------------------------
45
+ # Helpers
46
+ # --------------------------
47
+
48
+ def _dtype_from_str(s: str) -> torch.dtype:
49
+ s = (s or "").lower()
50
+ if s in ("float16", "fp16"):
51
+ return torch.float16
52
+ if s in ("bfloat16", "bf16"):
53
+ return torch.bfloat16
54
+ if s in ("float32", "fp32"):
55
+ return torch.float32
56
+ raise ValueError(f"Unknown torch_dtype: {s}")
57
+
58
+ def _now_iso() -> str:
59
+ return time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
60
+
61
+ def _safe_exp(x: float) -> float:
62
+ x = min(float(x), 50.0)
63
+ return float(math.exp(x))
64
+
65
+ def _ensure_dir(p: Path) -> Path:
66
+ p.mkdir(parents=True, exist_ok=True)
67
+ return p
68
+
69
+ def _looks_like_model_dir(p: Path) -> bool:
70
+ if not p.exists() or not p.is_dir():
71
+ return False
72
+ if (p / "config.json").exists():
73
+ return True
74
+ if any(p.glob("*.safetensors")) or any(p.glob("pytorch_model*.bin")):
75
+ return True
76
+ return False
77
+
78
+ def _detect_text_field(example: Dict[str, Any]) -> Optional[str]:
79
+ for k, v in example.items():
80
+ if isinstance(v, str) and v.strip():
81
+ return k
82
+ return None
83
+
84
+ def _load_tokenizer(base_dir: Path, use_fast: bool, trust_remote_code: bool):
85
+ try:
86
+ return AutoTokenizer.from_pretrained(
87
+ str(base_dir),
88
+ use_fast=use_fast,
89
+ trust_remote_code=trust_remote_code,
90
+ )
91
+ except ValueError as e:
92
+ if "TokenizersBackend" not in str(e):
93
+ raise
94
+ tok_file = base_dir / "tokenizer.json"
95
+ tok_cfg_path = base_dir / "tokenizer_config.json"
96
+ if not tok_file.exists():
97
+ raise
98
+
99
+ tok_kwargs: Dict[str, Any] = {}
100
+ if tok_cfg_path.exists():
101
+ with tok_cfg_path.open("r", encoding="utf-8") as f:
102
+ tok_cfg = json.load(f)
103
+ for key in ("bos_token", "eos_token", "pad_token", "unk_token", "model_max_length"):
104
+ if tok_cfg.get(key) is not None:
105
+ tok_kwargs[key] = tok_cfg[key]
106
+ extra = tok_cfg.get("additional_special_tokens") or tok_cfg.get("extra_special_tokens")
107
+ if extra:
108
+ tok_kwargs["additional_special_tokens"] = extra
109
+
110
+ return PreTrainedTokenizerFast(tokenizer_file=str(tok_file), **tok_kwargs)
111
+
112
+ def _infer_target_modules(model) -> List[str]:
113
+ names = set()
114
+ for n, _ in model.named_modules():
115
+ names.add(n.split(".")[-1])
116
+
117
+ for group in [
118
+ ["q_proj", "k_proj", "v_proj", "o_proj"],
119
+ ["Wqkv", "out_proj"],
120
+ ["query_key_value", "dense"],
121
+ ["c_attn", "c_proj"],
122
+ ]:
123
+ if all(x in names for x in group):
124
+ return group
125
+
126
+ fallback = [x for x in ["q_proj", "k_proj", "v_proj", "o_proj", "c_attn", "c_proj", "out_proj", "dense"] if x in names]
127
+ if fallback:
128
+ return fallback
129
+
130
+ raise ValueError("Could not auto-infer target_modules. Set peft.target_modules explicitly.")
131
+
132
+ def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
133
+ return cfg.get("model", {}).get("attn_implementation", None)
134
+
135
+
136
+ # --------------------------
137
+ # Wandb Integration
138
+ # --------------------------
139
+
140
+ def setup_wandb(cfg: Dict[str, Any], run_dir: Path):
141
+ """Initialize Wandb if enabled in configuration."""
142
+ wandb_cfg = cfg.get("wandb", {})
143
+
144
+ if not wandb_cfg.get("enabled", False):
145
+ print("Wandb logging disabled")
146
+ return None
147
+
148
+ if not WANDB_AVAILABLE:
149
+ print("Wandb not available. Install with: pip install wandb")
150
+ return None
151
+
152
+ # Extract wandb configuration
153
+ project = wandb_cfg.get("project", "cpt-training")
154
+ entity = wandb_cfg.get("entity", None)
155
+ name = wandb_cfg.get("name", None)
156
+ tags = wandb_cfg.get("tags", [])
157
+ notes = wandb_cfg.get("notes", None)
158
+
159
+ # Initialize wandb
160
+ try:
161
+ wandb.init(
162
+ project=project,
163
+ entity=entity,
164
+ name=name,
165
+ tags=tags,
166
+ notes=notes,
167
+ dir=str(run_dir),
168
+ config={
169
+ "model": cfg.get("model", {}),
170
+ "data": cfg.get("data", {}),
171
+ "peft": cfg.get("peft", {}),
172
+ "train": cfg.get("train", {}),
173
+ "run_dir": str(run_dir),
174
+ }
175
+ )
176
+ print(f"Wandb initialized: project='{project}', name='{name or 'auto-generated'}'")
177
+ return wandb
178
+ except Exception as e:
179
+ print(f"Failed to initialize Wandb: {e}")
180
+ return None
181
+
182
+
183
+ def finish_wandb():
184
+ """Finish Wandb run if active."""
185
+ if WANDB_AVAILABLE and wandb.run is not None:
186
+ wandb.finish()
187
+ print("Wandb run finished")
188
+
189
+
190
+ # --------------------------
191
+ # JSONL Logger Callback
192
+ # --------------------------
193
+
194
+ class JsonlLoggerCallback(TrainerCallback):
195
+ def __init__(self, run_dir: Path):
196
+ self.run_dir = run_dir
197
+ self.train_log_path = _ensure_dir(run_dir / "logs") / "train.jsonl"
198
+ self.eval_log_path = _ensure_dir(run_dir / "logs") / "eval.jsonl"
199
+ self.start_time = None
200
+
201
+ def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
202
+ if self.start_time is None or global_step <= 0 or max_steps <= 0:
203
+ return None
204
+ elapsed = time.time() - self.start_time
205
+ sec_per_step = elapsed / global_step
206
+ remaining = max(0, max_steps - global_step) * sec_per_step
207
+ h = int(remaining // 3600)
208
+ m = int((remaining % 3600) // 60)
209
+ s = int(remaining % 60)
210
+ return f"{h:02d}:{m:02d}:{s:02d}"
211
+
212
+ def on_train_begin(self, args, state, control, **kwargs):
213
+ self.start_time = time.time()
214
+
215
+ def on_log(self, args, state, control, logs=None, **kwargs):
216
+ if not logs:
217
+ return
218
+
219
+ max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
220
+ progress_pct = (100.0 * state.global_step / max_steps) if max_steps > 0 else None
221
+ epoch_pct = None
222
+ if state.epoch is not None and args.num_train_epochs and args.num_train_epochs > 0:
223
+ epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs))
224
+
225
+ payload = {
226
+ "ts": _now_iso(),
227
+ "event": "train_log",
228
+ "step": int(state.global_step),
229
+ "epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
230
+ "progress_pct": round(progress_pct, 2) if progress_pct is not None else None,
231
+ "epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None,
232
+ "eta": self._eta(int(state.global_step), max_steps),
233
+ "max_grad_norm": getattr(args, "max_grad_norm", None),
234
+ **logs,
235
+ }
236
+
237
+ with self.train_log_path.open("a", encoding="utf-8") as f:
238
+ f.write(json.dumps(payload, ensure_ascii=False) + "\n")
239
+
240
+ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
241
+ if not metrics:
242
+ return
243
+ eval_loss = metrics.get("eval_loss", None)
244
+ ppl = _safe_exp(eval_loss) if eval_loss is not None else None
245
+
246
+ payload = {
247
+ "ts": _now_iso(),
248
+ "event": "eval",
249
+ "step": int(state.global_step),
250
+ "epoch": float(state.epoch) if state.epoch is not None else None,
251
+ **metrics,
252
+ "perplexity": ppl,
253
+ }
254
+ with self.eval_log_path.open("a", encoding="utf-8") as f:
255
+ f.write(json.dumps(payload, ensure_ascii=False) + "\n")
256
+
257
+
258
+ # --------------------------
259
+ # Data Pipeline (EOS + Packing)
260
+ # --------------------------
261
+
262
+ def build_datasets(cfg: Dict[str, Any], tokenizer) -> Tuple[Any, Any]:
263
+ data_cfg = cfg["data"]
264
+ train_path = data_cfg["train_jsonl"]
265
+ eval_path = data_cfg.get("eval_jsonl", None)
266
+ split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
267
+ text_field = data_cfg.get("text_field", "text")
268
+ block_size = int(data_cfg.get("block_size", 2048))
269
+ shuffle = bool(data_cfg.get("shuffle", True))
270
+ num_proc = int(data_cfg.get("num_proc", 4))
271
+
272
+ pack_mode = str(data_cfg.get("pack_mode", "drop")).lower().strip()
273
+ if pack_mode not in ("drop", "pad"):
274
+ raise ValueError(f"data.pack_mode must be 'drop' or 'pad', got: {pack_mode}")
275
+
276
+ eos_id = tokenizer.eos_token_id
277
+ if eos_id is None:
278
+ raise ValueError("Tokenizer has no eos_token_id; CPT packing needs an EOS delimiter.")
279
+
280
+ if tokenizer.pad_token_id is None:
281
+ # safe default for many causal LMs
282
+ tokenizer.pad_token = tokenizer.eos_token
283
+ pad_id = tokenizer.pad_token_id
284
+
285
+ ds = load_dataset("json", data_files={"train": train_path})
286
+
287
+ if eval_path:
288
+ ds_eval = load_dataset("json", data_files={"eval": eval_path})
289
+ dsd = DatasetDict({"train": ds["train"], "eval": ds_eval["eval"]})
290
+ else:
291
+ if 0.0 < split_ratio < 1.0:
292
+ split = ds["train"].train_test_split(test_size=split_ratio, seed=int(cfg["run"].get("seed", 42)))
293
+ dsd = DatasetDict({"train": split["train"], "eval": split["test"]})
294
+ else:
295
+ dsd = DatasetDict({"train": ds["train"], "eval": None})
296
+
297
+ if text_field not in dsd["train"].column_names:
298
+ auto_field = _detect_text_field(dsd["train"][0])
299
+ if not auto_field:
300
+ raise ValueError(f"Could not find text field. Columns: {dsd['train'].column_names}")
301
+ text_field = auto_field
302
+
303
+ def tokenize_fn(examples):
304
+ out = tokenizer(
305
+ examples[text_field],
306
+ add_special_tokens=False,
307
+ truncation=False,
308
+ padding=False,
309
+ )
310
+ if "token_type_ids" in out:
311
+ del out["token_type_ids"]
312
+ # Add EOS between docs
313
+ out["input_ids"] = [ids + [eos_id] for ids in out["input_ids"]]
314
+ out["attention_mask"] = [m + [1] for m in out["attention_mask"]]
315
+ return out
316
+
317
+ tokenized_train = dsd["train"].map(
        tokenize_fn,
        batched=True,
        num_proc=num_proc,
        remove_columns=dsd["train"].column_names,
        desc="Tokenizing train",
    )

    tokenized_eval = None
    if dsd["eval"] is not None:
        tokenized_eval = dsd["eval"].map(
            tokenize_fn,
            batched=True,
            num_proc=num_proc,
            remove_columns=dsd["eval"].column_names,
            desc="Tokenizing eval",
        )

    def group_texts(examples):
        # Concatenate all tokenized documents into one flat stream per field.
        concatenated = {k: sum(examples[k], []) for k in examples.keys()}
        total_length = len(concatenated["input_ids"])

        if total_length == 0:
            return {"input_ids": [], "attention_mask": [], "labels": []}

        full_len = (total_length // block_size) * block_size
        blocks_input, blocks_attn, blocks_labels = [], [], []

        # Full blocks: labels are a copy of input_ids (standard causal-LM loss).
        for i in range(0, full_len, block_size):
            chunk = concatenated["input_ids"][i:i + block_size]
            attn = concatenated["attention_mask"][i:i + block_size]
            blocks_input.append(chunk)
            blocks_attn.append(attn)
            blocks_labels.append(chunk.copy())

        # Remainder: dropped in "drop" mode, padded + loss-masked in "pad" mode.
        remainder = total_length - full_len
        if remainder > 0 and pack_mode == "pad":
            chunk = concatenated["input_ids"][full_len:full_len + remainder]
            attn = concatenated["attention_mask"][full_len:full_len + remainder]

            pad_len = block_size - remainder
            chunk_padded = chunk + [pad_id] * pad_len
            attn_padded = attn + [0] * pad_len

            labels = chunk_padded.copy()
            labels[-pad_len:] = [-100] * pad_len  # loss mask on padding

            blocks_input.append(chunk_padded)
            blocks_attn.append(attn_padded)
            blocks_labels.append(labels)

        return {
            "input_ids": blocks_input,
            "attention_mask": blocks_attn,
            "labels": blocks_labels,
        }

    tokenized_train = tokenized_train.map(
        group_texts,
        batched=True,
        num_proc=num_proc,
        desc=f"Packing train blocks (mode={pack_mode})",
    )
    if tokenized_eval is not None:
        tokenized_eval = tokenized_eval.map(
            group_texts,
            batched=True,
            num_proc=num_proc,
            desc=f"Packing eval blocks (mode={pack_mode})",
        )

    if len(tokenized_train) == 0:
        raise ValueError(
            "Train dataset is empty after packing. "
            "Either increase data, reduce block_size, or set data.pack_mode='pad'."
        )

    if shuffle:
        tokenized_train = tokenized_train.shuffle(seed=int(cfg["run"].get("seed", 42)))

    return tokenized_train, tokenized_eval


# --------------------------
# Model Loading + PEFT
# --------------------------

def _select_model_loader(base_dir: Path):
    """Inspect config.json to decide whether this is a plain causal LM
    or a *ForConditionalGeneration architecture."""
    cfg_path = base_dir / "config.json"
    if not cfg_path.exists():
        return {"kind": "causal", "arch": None}
    with cfg_path.open("r", encoding="utf-8") as f:
        cfg = json.load(f)
    arch = cfg.get("architectures") or []
    arch_name = arch[0] if arch else None
    if any("ForConditionalGeneration" in a for a in arch):
        return {"kind": "conditional", "arch": arch_name}
    return {"kind": "causal", "arch": arch_name}


def _resolve_model_class(arch_name: str):
    import transformers
    cls = getattr(transformers, arch_name, None)
    if cls is None:
        raise ValueError(f"Model class '{arch_name}' is not available in installed transformers.")
    return cls


def load_base_model_and_tokenizer(cfg: Dict[str, Any], base_dir: Path):
    model_cfg = cfg["model"]
    trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
    use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
    device_map = model_cfg.get("device_map", "auto")

    tokenizer = _load_tokenizer(base_dir, use_fast=use_fast, trust_remote_code=trust_remote_code)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    torch_dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
    use_4bit = bool(model_cfg.get("use_4bit", False))

    quant_cfg = None
    if use_4bit:
        if BitsAndBytesConfig is None:
            raise ImportError("BitsAndBytesConfig is not available in this transformers version.")
        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4")),
            bnb_4bit_use_double_quant=bool(model_cfg.get("bnb_4bit_use_double_quant", True)),
            bnb_4bit_compute_dtype=_dtype_from_str(model_cfg.get("bnb_4bit_compute_dtype", "bfloat16")),
        )

    attn_impl = _choose_attn_impl(cfg)
    model_meta = _select_model_loader(base_dir)

    def _load(attn_implementation):
        kwargs = dict(
            device_map=device_map,
            trust_remote_code=trust_remote_code,
            low_cpu_mem_usage=True,
            torch_dtype=(torch_dtype if not use_4bit else None),
            quantization_config=quant_cfg,
        )
        if attn_implementation is not None:
            kwargs["attn_implementation"] = attn_implementation
        if model_meta["kind"] == "conditional":
            model_cls = _resolve_model_class(model_meta["arch"]) if model_meta["arch"] else None
            if model_cls is None:
                raise ValueError("Conditional model architecture not specified in config.json.")
            return model_cls.from_pretrained(str(base_dir), **kwargs)
        return AutoModelForCausalLM.from_pretrained(str(base_dir), **kwargs)

    try:
        model = _load(attn_impl)
    except Exception as e:
        if attn_impl is not None:
            print(f"[warn] attn_implementation='{attn_impl}' failed: {e}")
            print("[warn] Falling back to default attention implementation.")
        model = _load(None)

    return model, tokenizer


def apply_peft(cfg: Dict[str, Any], model):
    peft_cfg = cfg["peft"]
    model_cfg = cfg["model"]
    tr_cfg = cfg["train"]

    if not bool(peft_cfg.get("enabled", True)):
        return model, None

    use_4bit = bool(model_cfg.get("use_4bit", False))
    gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))

    if gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()
        if hasattr(model, "config"):
            model.config.use_cache = False

    if use_4bit:
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=gradient_checkpointing,
        )

    target_modules = peft_cfg.get("target_modules", "auto")
    if target_modules == "auto":
        target_modules = _infer_target_modules(model)

    lora_config = LoraConfig(
        r=int(peft_cfg.get("r", 16)),
        lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
        lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
        bias=str(peft_cfg.get("bias", "none")),
        task_type="CAUSAL_LM",
        target_modules=target_modules,
    )
    model = get_peft_model(model, lora_config)
    return model, lora_config


# --------------------------
# Merge Logic
# --------------------------

def merge_adapter(cfg: Dict[str, Any], base_dir: Path, adapter_dir: Path, final_dir: Path):
    print(f"--- Merge: {adapter_dir} + {base_dir} -> {final_dir} ---")

    model_cfg = cfg["model"]
    merge_cfg = cfg.get("merge", {})
    trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
    use_fast = bool(model_cfg.get("tokenizer_use_fast", True))

    merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
    max_shard_size = str(merge_cfg.get("max_shard_size", "2GB"))

    # Reload the base model on CPU to avoid GPU OOM during the merge.
    model_meta = _select_model_loader(base_dir)
    if model_meta["kind"] == "conditional":
        base_cls = _resolve_model_class(model_meta["arch"]) if model_meta["arch"] else None
        if base_cls is None:
            raise ValueError("Conditional model architecture not specified in config.json.")
        base = base_cls.from_pretrained(
            str(base_dir),
            torch_dtype=merged_dtype,
            device_map="cpu",
            low_cpu_mem_usage=True,
            trust_remote_code=trust_remote_code,
        )
    else:
        base = AutoModelForCausalLM.from_pretrained(
            str(base_dir),
            torch_dtype=merged_dtype,
            device_map="cpu",
            low_cpu_mem_usage=True,
            trust_remote_code=trust_remote_code,
        )

    merged = PeftModel.from_pretrained(base, str(adapter_dir))
    merged = merged.merge_and_unload()

    _ensure_dir(final_dir)
    merged.save_pretrained(str(final_dir), safe_serialization=True, max_shard_size=max_shard_size)

    tok = _load_tokenizer(base_dir, use_fast=use_fast, trust_remote_code=trust_remote_code)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.save_pretrained(str(final_dir))

    print("--- Merge complete ---")


# --------------------------
# Main
# --------------------------

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True, help="Path to YAML config")
    ap.add_argument("--merge-only", action="store_true", help="Skip training, just merge adapter")
    args = ap.parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
    _ensure_dir(run_dir / "logs")

    # Persist the exact config used for this run.
    with (run_dir / "config_resolved.yaml").open("w", encoding="utf-8") as f:
        yaml.safe_dump(cfg, f, sort_keys=False)

    model_cfg = cfg["model"]
    repo_id = str(model_cfg["repo_id"]).strip()
    repo_path = Path(repo_id)

    # ✅ Local model path -> load directly; no download
    if repo_path.exists() and repo_path.is_dir():
        base_dir = repo_path
        if not _looks_like_model_dir(base_dir):
            raise ValueError(f"model.repo_id points to a directory, but it doesn't look like a HF model dir: {base_dir}")
    else:
        # HF repo_id -> download into run_dir/base_local_dir
        base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
        if not _looks_like_model_dir(base_dir):
            print(f"Base model not found at {base_dir}, downloading from {repo_id} ...")
            snapshot_download(
                repo_id=repo_id,
                revision=model_cfg.get("revision", None),
                local_dir=str(base_dir),
                local_dir_use_symlinks=False,
            )

    ckpt_dir = _ensure_dir(run_dir / "checkpoints")
    best_adapter_dir = _ensure_dir(run_dir / "best_adapter")

    merge_cfg = cfg.get("merge", {}) or {}
    if merge_cfg.get("output_dir"):
        od = Path(str(merge_cfg["output_dir"]))
        final_dir = od if od.is_absolute() else (run_dir / od)
    else:
        final_dir = run_dir / "final_model"

    # Merge-only
    if args.merge_only:
        if not _looks_like_model_dir(best_adapter_dir):
            raise FileNotFoundError(f"Adapter not found at {best_adapter_dir}")
        merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
        return

    # Initialize Wandb
    wandb_run = setup_wandb(cfg, run_dir)

    # Training
    set_seed(int(cfg["run"].get("seed", 42)))

    model, tokenizer = load_base_model_and_tokenizer(cfg, base_dir)
    model, _ = apply_peft(cfg, model)

    train_ds, eval_ds = build_datasets(cfg, tokenizer)

    tr_cfg = cfg["train"]

    dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
    use_fp16 = (dtype == torch.float16)
    use_bf16 = (dtype == torch.bfloat16)

    max_steps = int(tr_cfg.get("max_steps", 0))
    num_train_epochs = float(tr_cfg.get("num_train_epochs", 1))

    # --- Dynamic evaluation strategy parameter handling ---
    # transformers renamed "evaluation_strategy" to "eval_strategy"; support both.
    ta_params = inspect.signature(TrainingArguments.__init__).parameters
    eval_key = "eval_strategy" if "eval_strategy" in ta_params else "evaluation_strategy"

    # Setup reporting based on wandb availability
    report_to = []
    if wandb_run is not None:
        report_to.append("wandb")

    desired_ta_kwargs = dict(
        output_dir=str(ckpt_dir),
        max_steps=max_steps if max_steps > 0 else -1,
        num_train_epochs=num_train_epochs,

        per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 1)),
        per_device_eval_batch_size=int(tr_cfg.get("per_device_eval_batch_size", tr_cfg.get("per_device_train_batch_size", 1))),
        gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 1)),

        learning_rate=float(tr_cfg.get("learning_rate", 2e-5)),
        weight_decay=float(tr_cfg.get("weight_decay", 0.0)),
        warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
        lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "cosine")),

        optim=str(tr_cfg.get("optim", "paged_adamw_8bit" if bool(model_cfg.get("use_4bit", False)) else "adamw_torch")),
        max_grad_norm=float(tr_cfg.get("max_grad_norm", 1.0)),

        logging_steps=int(tr_cfg.get("logging_steps", 10)),

        save_strategy=str(tr_cfg.get("save_strategy", "steps")),
        save_steps=int(tr_cfg.get("save_steps", 200)),
        save_total_limit=int(tr_cfg.get("save_total_limit", 3)),

        eval_steps=int(tr_cfg.get("eval_steps", 200)),

        load_best_model_at_end=bool(tr_cfg.get("load_best_model_at_end", True)) if eval_ds is not None else False,
        metric_for_best_model="eval_loss",
        greater_is_better=False,

        fp16=use_fp16,
        bf16=use_bf16,

        report_to=report_to,
        remove_unused_columns=False,
        save_safetensors=True,
        overwrite_output_dir=False,
    )

    # Set the correct argument name for this transformers version
    desired_ta_kwargs[eval_key] = str(tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no"))
    ta_kwargs = {k: v for k, v in desired_ta_kwargs.items() if k in ta_params}

    training_args = TrainingArguments(**ta_kwargs)

    trainer_params = inspect.signature(Trainer.__init__).parameters
    desired_trainer_kwargs = dict(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=default_data_collator,
        callbacks=[JsonlLoggerCallback(run_dir)],
    )
    # Newer transformers renamed Trainer's "tokenizer" argument to
    # "processing_class"; pass the tokenizer under whichever name this
    # version accepts (never both, to avoid a duplicate-argument error).
    if "processing_class" in trainer_params:
        desired_trainer_kwargs["processing_class"] = tokenizer
    else:
        desired_trainer_kwargs["tokenizer"] = tokenizer
    trainer_kwargs = {k: v for k, v in desired_trainer_kwargs.items() if k in trainer_params}
    trainer = Trainer(**trainer_kwargs)

    # Resume
    resume_from = tr_cfg.get("resume_from_checkpoint", None)
    if resume_from == "auto":
        last = get_last_checkpoint(str(ckpt_dir))
        resume_from = last if last else None
    if resume_from:
        print(f"Resuming from {resume_from}")

    print("Starting training...")
    trainer.train(resume_from_checkpoint=resume_from)

    trainer.save_model(str(best_adapter_dir))
    print(f"Saved best adapter -> {best_adapter_dir}")

    if eval_ds is not None:
        metrics = trainer.evaluate()
        eval_loss = metrics.get("eval_loss", None)
        metrics["perplexity"] = _safe_exp(eval_loss) if eval_loss is not None else None
        with (run_dir / "eval_final.json").open("w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)
        print(f"Final eval_loss={eval_loss}, ppl={metrics['perplexity']}")

    if bool(cfg.get("merge", {}).get("enabled", False)):
        del trainer, model
        torch.cuda.empty_cache()
        merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
    else:
        print("Merge disabled. Run with --merge-only later if needed.")

    # Finish Wandb run
    finish_wandb()


if __name__ == "__main__":
    main()
trainer-kit/CPT/README.md ADDED
@@ -0,0 +1,189 @@
# Trainer‑Kit : Config‑Driven CPT (LoRA / QLoRA) with Packing, Logging, Resume, and Merge

Trainer‑Kit is a small, config‑driven training runner for **continued pretraining (CPT)** on causal LMs.
It supports **LoRA** and **QLoRA**, data **packing** (strict or padding‑masked), **checkpointing + resume**, **JSONL logging**, periodic **eval with perplexity**, and an optional **merge** step to export a final merged model.

---

## What we built

### ✅ Core goals implemented

* **CPT training loop** controlled entirely via a **YAML config**
* **Local model support** (load from filesystem) and optional **HF download** (if `repo_id` is a hub id)
* **JSONL datasets** for train (+ optional eval split)
* **CPT‑style token stream packing** into fixed‑length blocks
* **Two packing modes**
  * `drop`: strict CPT, drop remainder tokens (preferred for real CPT)
  * `pad`: pad the remainder to `block_size` and **mask loss** on padding (useful for small datasets / debugging)
* **Checkpointing + resume**
  * `resume_from_checkpoint: "auto"` resumes from the latest checkpoint under `run_dir/checkpoints`
* **JSONL logs** written locally
  * training logs: `run_dir/logs/train.jsonl`
  * eval logs: `run_dir/logs/eval.jsonl`
* **Evaluation**
  * logs `eval_loss` and computed `perplexity = exp(eval_loss)` (with a safe overflow guard)
* **Adapter output**
  * saves the final/best adapter to `run_dir/best_adapter`
* **Merge workflow**
  * `--merge-only` merges an existing adapter later
  * the merge runs **on CPU** to avoid GPU OOM
  * the merged model is stored under the configured merge output directory (relative to `run_dir` if given as a relative path)

---

## Repository layout (outputs)

A run produces the following structure under `run.run_dir`:

```
runs/<run_name>/
├─ checkpoints/          # trainer checkpoints (for resume)
├─ best_adapter/         # saved LoRA adapter
├─ logs/
│  ├─ train.jsonl        # step-wise training logs
│  └─ eval.jsonl         # eval logs (eval_loss + perplexity)
├─ eval_final.json       # final eval metrics summary (if eval is enabled)
└─ config_resolved.yaml  # exact config used for the run
```

If merge is used, the merged model is written to:

* `run_dir/<merge.output_dir>` if `merge.output_dir` is relative (e.g. `./merged_model`)
* or the absolute path if it is absolute.

---

## Supported training modes

### 1) LoRA vs QLoRA (same script)

* **QLoRA** is used when `model.use_4bit: true`
  * base weights are loaded in 4‑bit using bitsandbytes
  * training updates only the LoRA parameters
* **LoRA** is used when `model.use_4bit: false`
  * base weights are loaded in fp16/bf16 (as configured)
  * training updates only the LoRA parameters

No "full finetune" mode is enabled by default in this runner.

---
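The switch between the two modes boils down to a single flag in the `model` section of the config. A minimal illustration (the values shown are examples, not defaults):

```yaml
model:
  torch_dtype: "bfloat16"
  use_4bit: false    # LoRA: base weights kept in bf16/fp16
  # use_4bit: true   # QLoRA: base weights loaded in 4-bit via bitsandbytes
```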

## Data pipeline (CPT behavior)

### Input format

* A JSONL file where each line contains a text field (default `"text"`), e.g.:
  * `{"text": "some training text..."}`

### Packing (token stream → fixed blocks)

* Each sample is tokenized without truncation.
* An **EOS token is appended** per document to preserve boundaries.
* Token lists are concatenated and converted into **fixed‑length blocks** of `data.block_size`.

Two modes:

* **`drop` (strict CPT):** remainder tokens that don't fill a full block are discarded.
* **`pad` (debug/small data):** the remainder is padded to `block_size`:
  * `attention_mask = 0` for padded positions
  * `labels = -100` for padded positions (loss masking)

This is what allowed training to proceed even with tiny dummy datasets at `block_size=1024`.
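The block arithmetic described above can be sketched as a standalone function. This is a simplified toy version of the real `group_texts` in `run_cpt.py`; `pad_id=0` and the 10-token stream are made-up example values:

```python
def pack(tokens, block_size, pack_mode="drop", pad_id=0):
    """Cut a flat token stream into fixed-length blocks (toy sketch)."""
    blocks = []
    full_len = (len(tokens) // block_size) * block_size
    for i in range(0, full_len, block_size):
        chunk = tokens[i:i + block_size]
        # Full blocks: labels mirror input_ids (standard causal-LM loss).
        blocks.append({"input_ids": chunk,
                       "attention_mask": [1] * block_size,
                       "labels": list(chunk)})
    remainder = len(tokens) - full_len
    if remainder > 0 and pack_mode == "pad":
        pad_len = block_size - remainder
        chunk = tokens[full_len:] + [pad_id] * pad_len
        labels = list(chunk)
        labels[-pad_len:] = [-100] * pad_len  # mask loss on padding
        blocks.append({"input_ids": chunk,
                       "attention_mask": [1] * remainder + [0] * pad_len,
                       "labels": labels})
    return blocks

stream = list(range(10))                # 10 tokens, block_size = 4
print(len(pack(stream, 4, "drop")))     # 2 full blocks; 2 remainder tokens dropped
print(len(pack(stream, 4, "pad")))      # 3 blocks; last one padded + loss-masked
```

With 10 tokens and `block_size=4`, `drop` keeps 2 blocks (8 tokens) while `pad` emits a third block `[8, 9, pad, pad]` whose padded positions are excluded from both attention and loss.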

---

## Logging

Trainer‑Kit writes **machine‑readable logs** in JSONL.

### Training logs (`logs/train.jsonl`)

Each entry includes:

* `step`
* `loss`
* `grad_norm`
* `learning_rate`
* `progress_pct` (step progress when `max_steps` is active)
* an ETA estimate

### Eval logs (`logs/eval.jsonl`)

Each entry includes:

* `eval_loss`
* `perplexity`

Notes:

* When using `max_steps`, the Trainer's internal `epoch` counter can grow unexpectedly on tiny datasets (because steps/epoch becomes ~1).
  **Use `progress_pct` as the reliable indicator** for step‑based runs.

---
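Because both logs are plain JSONL, they can be consumed with a few lines of Python. The sketch below reads a log file and recomputes guarded perplexity; `safe_ppl` mirrors the "safe overflow guard" mentioned above but is an illustrative stand-in, not the exact `_safe_exp` used by `run_cpt.py`:

```python
import json
import math

def safe_ppl(eval_loss, cap=700.0):
    # math.exp overflows for arguments above ~709; return inf instead of raising.
    if eval_loss is None or eval_loss > cap:
        return float("inf")
    return math.exp(eval_loss)

def read_jsonl(path):
    # One JSON object per line; skip blank lines.
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

# e.g. records = read_jsonl("runs/cpt_run_v1/logs/eval.jsonl")
#      print(records[-1]["eval_loss"], safe_ppl(records[-1]["eval_loss"]))
print(round(safe_ppl(2.0), 3))  # exp(2.0) ≈ 7.389
```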

## Checkpointing and resume

The trainer saves checkpoints under:

* `run_dir/checkpoints/`

Resume options:

* `resume_from_checkpoint: "auto"` → picks the latest checkpoint automatically
* `resume_from_checkpoint: "/path/to/checkpoint"` → resumes from a specific checkpoint
* `resume_from_checkpoint: null` → fresh run

---

## Merging adapters into a final model

Trainer‑Kit supports exporting a merged model:

### Merge after training

* Enable merge in the config (`merge.enabled: true`)
* The script will then:
  1. save the adapter
  2. free GPU memory
  3. reload the base model on **CPU**
  4. load the adapter
  5. `merge_and_unload()`
  6. save the final merged model

### Merge later

Run:

```
python run_cpt.py --config config.yaml --merge-only
```

This skips training and merges `run_dir/best_adapter` into the base model.

---

## How to run

### Train

```
python run_cpt.py --config config.yaml
```

### Merge only

```
python run_cpt.py --config config.yaml --merge-only
```
trainer-kit/CPT/commands.md ADDED
@@ -0,0 +1,15 @@
## Commands

Train (no merge):

```bash
python run_cpt.py --config config.yaml
```

Merge later:

```bash
python run_cpt.py --config config.yaml --merge-only
```

---
trainer-kit/CPT/config.yaml ADDED
@@ -0,0 +1,82 @@
run:
  run_dir: "./runs/cpt_run_v1"
  seed: 42

model:
  # Local model path (no download)
  repo_id: "/workspace/Models/Devstral-Small-2-24B-Instruct-2512"
  revision: null

  # Used only when repo_id is a HF repo (not a local path)
  base_local_dir: "base_model"

  trust_remote_code: true
  tokenizer_use_fast: true
  device_map: "auto"

  torch_dtype: "bfloat16"  # "float16" | "bfloat16" | "float32"

  # QLoRA
  use_4bit: false
  bnb_4bit_quant_type: "nf4"
  bnb_4bit_use_double_quant: false
  bnb_4bit_compute_dtype: "bfloat16"

  # optional: "flash_attention_2" | "sdpa" | null
  attn_implementation: null

data:
  train_jsonl: "/workspace/all_data_with_descriptions.jsonl"
  eval_jsonl: null
  eval_split_ratio: 0.1
  text_field: "text"
  block_size: 4096
  shuffle: true
  num_proc: 4

  # ✅ NEW: packing behavior
  #   "drop" = strict CPT (drop remainder)
  #   "pad"  = pad remainder to block_size + loss mask (-100) + attention_mask=0
  pack_mode: "pad"

peft:
  enabled: true
  r: 64
  lora_alpha: 128
  lora_dropout: 0.05
  bias: "none"
  target_modules: "auto"

train:
  # max_steps: 1000
  num_train_epochs: 2

  per_device_train_batch_size: 1
  per_device_eval_batch_size: 1
  gradient_accumulation_steps: 16

  learning_rate: 2e-5
  weight_decay: 0.0
  warmup_ratio: 0.1
  lr_scheduler_type: "cosine"

  optim: "paged_adamw_8bit"
  max_grad_norm: 1.0
  gradient_checkpointing: true

  logging_steps: 1
  save_strategy: "steps"
  save_steps: 100
  save_total_limit: 4

  evaluation_strategy: "steps"
  eval_steps: 50
  load_best_model_at_end: true

  resume_from_checkpoint: "auto"

merge:
  enabled: true
  merged_dtype: "float16"
  max_shard_size: "2GB"
  output_dir: "./merged_24b_cpt_lora"
trainer-kit/CPT/detailed_parameter_documentation.md ADDED
@@ -0,0 +1,795 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CPT Configuration Parameters: Detailed Guide
2
+
3
+ This document provides a comprehensive explanation of all configuration parameters in `config.yaml` and how they're implemented in `run_cpt.py`.
4
+
5
+ ## Table of Contents
6
+ - [Run Parameters](#run-parameters)
7
+ - [Model Parameters](#model-parameters)
8
+ - [Data Parameters](#data-parameters)
9
+ - [PEFT Parameters](#peft-parameters)
10
+ - [Training Parameters](#training-parameters)
11
+ - [Merge Parameters](#merge-parameters)
12
+
13
+ ---
14
+
15
+ ## Run Parameters
16
+
17
+ ### `run.run_dir`
18
+ - **Type**: String (path)
19
+ - **Required**: Yes
20
+ - **Default**: No default
21
+ - **Description**: Directory where training outputs will be saved
22
+ - **Used in**: Line ~480 in `run_cpt.py`
23
+ - **Implementation**:
24
+ ```python
25
+ run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
26
+ ```
27
+ - **Example Values**:
28
+ - `./runs/cpt_run_v1`
29
+ - `/workspace/outputs/my_experiment`
30
+ - `./checkpoints/cpt_experiment`
31
+
32
+ ### `run.seed`
33
+ - **Type**: Integer
34
+ - **Required**: No
35
+ - **Default**: None
36
+ - **Description**: Random seed for reproducibility
37
+ - **Used in**: Lines ~460, ~240 in `run_cpt.py`
38
+ - **Implementation**:
39
+ ```python
40
+ set_seed(int(cfg["run"].get("seed", 42)))
41
+ # Used in data shuffling and train/test split
42
+ ```
43
+ - **Example Values**: `42`, `123`, `2023`
44
+
45
+ ---
46
+
47
+ ## Model Parameters
48
+
49
+ ### `model.repo_id`
50
+ - **Type**: String (path or HuggingFace repo)
51
+ - **Required**: Yes
52
+ - **Default**: No default
53
+ - **Description**: Model identifier - can be local path or HuggingFace repository
54
+ - **Used in**: Lines ~480-500 in `run_cpt.py`
55
+ - **Implementation**:
56
+ ```python
57
+ repo_id = str(model_cfg["repo_id"]).strip()
58
+ repo_path = Path(repo_id)
59
+ if repo_path.exists() and repo_path.is_dir():
60
+ base_dir = repo_path # Local path
61
+ else:
62
+ # Download from HuggingFace
63
+ snapshot_download(repo_id=repo_id, ...)
64
+ ```
65
+ - **Example Values**:
66
+ - Local: `/workspace/Models/Devstral-Small-2-24B-Instruct-2512`
67
+ - HF Repo: `meta-llama/Llama-2-7b-hf`
68
+
69
+ ### `model.revision`
70
+ - **Type**: String or null
71
+ - **Required**: No
72
+ - **Default**: null
73
+ - **Description**: Specific model revision/branch/tag from HuggingFace
74
+ - **Used in**: Line ~495 in `run_cpt.py`
75
+ - **Implementation**:
76
+ ```python
77
+ snapshot_download(
78
+ repo_id=repo_id,
79
+ revision=model_cfg.get("revision", None),
80
+ ...
81
+ )
82
+ ```
83
+ - **Example Values**:
84
+ - `"main"` - Main branch
85
+ - `"v1.0"` - Specific tag
86
+ - `"abc123def"` - Specific commit hash
87
+ - `null` - Latest version
88
+
89
+ ### `model.base_local_dir`
90
+ - **Type**: String (path)
91
+ - **Required**: No
92
+ - **Default**: `"base_model"`
93
+ - **Description**: Directory name for downloaded model when using HF repo
94
+ - **Used in**: Line ~495 in `run_cpt.py`
95
+ - **Implementation**:
96
+ ```python
97
+ base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
98
+ ```
99
+ - **Example Values**: `"base_model"`, `"downloaded_model"`, `"model_files"`
100
+
101
+ ### `model.trust_remote_code`
102
+ - **Type**: Boolean
103
+ - **Required**: No
104
+ - **Default**: `true`
105
+ - **Description**: Allow loading models with custom code
106
+ - **Used in**: Lines ~320, ~340, ~450 in `run_cpt.py`
107
+ - **Implementation**:
108
+ ```python
109
+ tokenizer = _load_tokenizer(base_dir, use_fast=use_fast, trust_remote_code=trust_remote_code)
110
+ model = AutoModelForCausalLM.from_pretrained(..., trust_remote_code=trust_remote_code, ...)
111
+ ```
112
+ - **Example Values**: `true`, `false`
113
+
114
+ ### `model.tokenizer_use_fast`
115
+ - **Type**: Boolean
116
+ - **Required**: No
117
+ - **Default**: `true`
118
+ - **Description**: Use fast tokenizer implementation
119
+ - **Used in**: Lines ~320, ~450 in `run_cpt.py`
120
+ - **Implementation**:
121
+ ```python
122
+ tokenizer = _load_tokenizer(base_dir, use_fast=use_fast, trust_remote_code=trust_remote_code)
123
+ ```
124
+ - **Example Values**: `true`, `false`
125
+
126
+ ### `model.device_map`
127
+ - **Type**: String
128
+ - **Required**: No
129
+ - **Default**: `"auto"`
130
+ - **Description**: How to distribute model across devices
131
+ - **Used in**: Lines ~350, ~370 in `run_cpt.py`
132
+ - **Implementation**:
133
+ ```python
134
+ model = AutoModelForCausalLM.from_pretrained(..., device_map=device_map, ...)
135
+ ```
136
+ - **Example Values**:
137
+ - `"auto"` - Automatic distribution
138
+ - `"cpu"` - CPU only
139
+ - `"cuda:0"` - Single GPU
140
+ - `{"": 0}` - Manual mapping
141
+
142
+ ### `model.torch_dtype`
143
+ - **Type**: String
144
+ - **Required**: No
145
+ - **Default**: `"bfloat16"`
146
+ - **Description**: Data type for model tensors
147
+ - **Used in**: Lines ~45, ~350 in `run_cpt.py`
148
+ - **Implementation**:
149
+ ```python
150
+ def _dtype_from_str(s: str) -> torch.dtype:
151
+ if s in ("float16", "fp16"): return torch.float16
152
+ if s in ("bfloat16", "bf16"): return torch.bfloat16
153
+ if s in ("float32", "fp32"): return torch.float32
154
+ ```
155
+ - **Example Values**:
156
+ - `"float16"` - 16-bit floats (faster, less memory, less stable)
157
+ - `"bfloat16"` - Brain float16 (stable, good for training)
158
+ - `"float32"` - 32-bit floats (slowest, most memory)
159
+
160
+ ### `model.use_4bit`
161
+ - **Type**: Boolean
162
+ - **Required**: No
163
+ - **Default**: `false`
164
+ - **Description**: Use 4-bit quantization for memory efficiency
165
+ - **Used in**: Lines ~325, ~395 in `run_cpt.py`
166
+ - **Implementation**:
167
+ ```python
168
+ use_4bit = bool(model_cfg.get("use_4bit", False))
169
+ if use_4bit:
170
+ quant_cfg = BitsAndBytesConfig(load_in_4bit=True, ...)
171
+ ```
172
+ - **Example Values**: `true`, `false`
173
+
174
+ ### `model.bnb_4bit_quant_type`
175
+ - **Type**: String
176
+ - **Required**: No
177
+ - **Default**: `"nf4"`
178
+ - **Description**: 4-bit quantization type
179
+ - **Used in**: Lines ~328 in `run_cpt.py`
180
+ - **Implementation**:
181
+ ```python
182
+ bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4"))
183
+ ```
184
+ - **Example Values**:
185
+ - `"nf4"` - NormalFloat4 (recommended)
186
+ - `"fp4"` - FloatingPoint4
187
+ - `"int4"` - Integer4
188
+
189
+ ### `model.bnb_4bit_use_double_quant`
190
+ - **Type**: Boolean
191
+ - **Required**: No
192
+ - **Default**: `false`
193
+ - **Description**: Use double quantization for memory efficiency
194
+ - **Used in**: Lines ~329 in `run_cpt.py`
195
+ - **Implementation**:
196
+ ```python
197
+ bnb_4bit_use_double_quant=bool(model_cfg.get("bnb_4bit_use_double_quant", True))
198
+ ```
199
+ - **Example Values**: `true`, `false`
200
+
201
+ ### `model.bnb_4bit_compute_dtype`
202
+ - **Type**: String
203
+ - **Required**: No
204
+ - **Default**: `"bfloat16"`
205
+ - **Description**: Compute dtype for 4-bit quantization
206
+ - **Used in**: Lines ~330 in `run_cpt.py`
207
+ - **Implementation**:
208
+ ```python
209
+ bnb_4bit_compute_dtype=_dtype_from_str(model_cfg.get("bnb_4bit_compute_dtype", "bfloat16"))
210
+ ```
211
+ - **Example Values**: `"float16"`, `"bfloat16"`, `"float32"`
212
+
213
+ ### `model.attn_implementation`
214
+ - **Type**: String or null
215
+ - **Required**: No
216
+ - **Default**: `null`
217
+ - **Description**: Attention implementation to use
218
+ - **Used in**: Lines ~155, ~350 in `run_cpt.py`
219
+ - **Implementation**:
220
+ ```python
221
+ def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
222
+ return cfg.get("model", {}).get("attn_implementation", None)
223
+ # Used in model.from_pretrained(..., attn_implementation=attn_impl, ...)
224
+ ```
225
+ - **Example Values**:
226
+ - `"flash_attention_2"` - Flash Attention 2 (fastest; requires the `flash-attn` package)
227
+ - `"sdpa"` - Scaled Dot-Product Attention
228
+ - `null` - Default implementation
229
+
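Taken together, the `model.*` quantization keys above are consumed as one unit when the quantization config is built. A minimal sketch of that mapping (a plain dict so it runs without `transformers` installed; the defaults mirror the fallbacks quoted from `run_cpt.py`, and the `model_cfg` values are illustrative):

```python
# Sketch: how the model.* quantization keys combine into BitsAndBytesConfig
# kwargs. Defaults mirror the .get(...) fallbacks used in run_cpt.py.
def bnb_kwargs(model_cfg: dict) -> dict:
    if not model_cfg.get("use_4bit", False):
        return {}  # no quantization config at all
    return {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": str(model_cfg.get("bnb_4bit_quant_type", "nf4")),
        "bnb_4bit_use_double_quant": bool(model_cfg.get("bnb_4bit_use_double_quant", True)),
        "bnb_4bit_compute_dtype": model_cfg.get("bnb_4bit_compute_dtype", "bfloat16"),
    }

print(bnb_kwargs({"use_4bit": True, "bnb_4bit_quant_type": "fp4"}))
```

In the real script these kwargs feed `BitsAndBytesConfig(...)`, with `bnb_4bit_compute_dtype` first converted by `_dtype_from_str`.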
230
+ ---
231
+
232
+ ## Data Parameters
233
+
234
+ ### `data.train_jsonl`
235
+ - **Type**: String (path)
236
+ - **Required**: Yes
237
+ - **Default**: No default
238
+ - **Description**: Path to training data in JSONL format
239
+ - **Used in**: Lines ~170 in `run_cpt.py`
240
+ - **Implementation**:
241
+ ```python
242
+ train_path = data_cfg["train_jsonl"]
243
+ ds = load_dataset("json", data_files={"train": train_path})
244
+ ```
245
+ - **Example Values**: `"/workspace/all_data_with_descriptions.jsonl"`
246
+
247
+ ### `data.eval_jsonl`
248
+ - **Type**: String (path) or null
249
+ - **Required**: No
250
+ - **Default**: `null`
251
+ - **Description**: Path to evaluation data in JSONL format
252
+ - **Used in**: Lines ~175 in `run_cpt.py`
253
+ - **Implementation**:
254
+ ```python
255
+ eval_path = data_cfg.get("eval_jsonl", None)
256
+ if eval_path:
257
+ ds_eval = load_dataset("json", data_files={"eval": eval_path})
258
+ ```
259
+ - **Example Values**: `null` (no separate eval file), `"/workspace/eval_data.jsonl"`
260
+
261
+ ### `data.eval_split_ratio`
262
+ - **Type**: Float
263
+ - **Required**: No
264
+ - **Default**: `0.0` (no split; the code falls back to `0.0` when the key is absent)
265
+ - **Description**: Ratio of training data to use for evaluation split
266
+ - **Used in**: Lines ~177 in `run_cpt.py`
267
+ - **Implementation**:
268
+ ```python
269
+ split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
270
+ if 0.0 < split_ratio < 1.0:
271
+ split = ds["train"].train_test_split(test_size=split_ratio, seed=seed)
272
+ ```
273
+ - **Example Values**: `0.1` (10%), `0.2` (20%), `0.05` (5%)
274
+
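The precedence between `eval_jsonl` and `eval_split_ratio` can be summarized in a few lines (a sketch of the selection logic in `build_datasets`; the helper name `eval_source` is ours):

```python
# Which eval set does a run get? An explicit eval_jsonl wins; otherwise a
# valid eval_split_ratio carves eval data out of the training set.
def eval_source(data_cfg: dict) -> str:
    if data_cfg.get("eval_jsonl"):
        return "file"
    if 0.0 < float(data_cfg.get("eval_split_ratio", 0.0)) < 1.0:
        return "split"
    return "none"

print(eval_source({"eval_jsonl": "/workspace/eval_data.jsonl"}))  # file
print(eval_source({"eval_split_ratio": 0.1}))                     # split
```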
275
+ ### `data.text_field`
276
+ - **Type**: String
277
+ - **Required**: No
278
+ - **Default**: `"text"`
279
+ - **Description**: Field name in JSONL containing the text data
280
+ - **Used in**: Lines ~185 in `run_cpt.py`
281
+ - **Implementation**:
282
+ ```python
283
+ text_field = data_cfg.get("text_field", "text")
284
+ # Read inside tokenize_fn for each example
+ out = tokenizer(
+ examples[text_field],
+ add_special_tokens=False,
+ truncation=False,
+ padding=False,
+ )
291
+ ```
292
+ - **Example Values**: `"text"`, `"content"`, `"prompt"`, `"input"`
293
+
294
+ ### `data.block_size`
295
+ - **Type**: Integer
296
+ - **Required**: No
297
+ - **Default**: `2048` (the code falls back to `2048` when the key is absent)
298
+ - **Description**: Maximum sequence length for training
299
+ - **Used in**: Lines ~180 in `run_cpt.py`
300
+ - **Implementation**:
301
+ ```python
302
+ block_size = int(data_cfg.get("block_size", 2048))
303
+ # Used in grouping texts into blocks
304
+ for i in range(0, full_len, block_size):
305
+ chunk = concatenated["input_ids"][i:i + block_size]
306
+ ```
307
+ - **Example Values**: `2048`, `4096`, `8192`
308
+
309
+ ### `data.shuffle`
310
+ - **Type**: Boolean
311
+ - **Required**: No
312
+ - **Default**: `true`
313
+ - **Description**: Whether to shuffle training data
314
+ - **Used in**: Lines ~235 in `run_cpt.py`
315
+ - **Implementation**:
316
+ ```python
317
+ if shuffle:
318
+ tokenized_train = tokenized_train.shuffle(seed=int(cfg["run"].get("seed", 42)))
319
+ ```
320
+ - **Example Values**: `true`, `false`
321
+
322
+ ### `data.num_proc`
323
+ - **Type**: Integer
324
+ - **Required**: No
325
+ - **Default**: `4`
326
+ - **Description**: Number of worker processes used by `datasets.map` for tokenization and packing
327
+ - **Used in**: Lines ~200, ~210 in `run_cpt.py`
328
+ - **Implementation**:
329
+ ```python
330
+ num_proc = int(data_cfg.get("num_proc", 4))
331
+ tokenized_train = dsd["train"].map(
332
+ tokenize_fn,
333
+ batched=True,
334
+ num_proc=num_proc,
335
+ ...
336
+ )
337
+ ```
338
+ - **Example Values**: `1`, `4`, `8`, `16`
339
+
340
+ ### `data.pack_mode`
341
+ - **Type**: String
342
+ - **Required**: No
343
+ - **Default**: `"drop"` (the code falls back to `"drop"` when the key is absent)
344
+ - **Description**: How to handle remainder tokens in final block
345
+ - **Used in**: Lines ~150-230 in `run_cpt.py`
346
+ - **Implementation**:
347
+ ```python
348
+ pack_mode = str(data_cfg.get("pack_mode", "drop")).lower().strip()
349
+ if pack_mode == "pad":
350
+ # Pad remainder and mask loss
351
+ labels[-pad_len:] = [-100] * pad_len
352
+ # If "drop": ignore remainder entirely
353
+ ```
354
+ - **Example Values**:
355
+ - `"drop"` - Drop incomplete blocks (strict CPT)
356
+ - `"pad"` - Pad incomplete blocks with masked loss
357
+
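A toy illustration of the two packing modes (this mirrors the `group_texts` logic on plain Python lists; the token ids and helper name are made up):

```python
# Concatenate token ids, cut them into block_size chunks, and either drop
# the remainder ("drop") or pad it with loss-masked labels ("pad").
def pack(ids, block_size, pack_mode="drop", pad_id=0):
    blocks = []
    full_len = (len(ids) // block_size) * block_size
    for i in range(0, full_len, block_size):
        chunk = ids[i:i + block_size]
        blocks.append({"input_ids": chunk, "labels": chunk.copy()})
    rem = len(ids) - full_len
    if rem and pack_mode == "pad":
        pad_len = block_size - rem
        chunk = ids[full_len:] + [pad_id] * pad_len
        labels = chunk.copy()
        labels[-pad_len:] = [-100] * pad_len  # -100 masks the loss on padding
        blocks.append({"input_ids": chunk, "labels": labels})
    return blocks

print(len(pack(list(range(10)), 4)))          # drop -> 2 full blocks
print(len(pack(list(range(10)), 4, "pad")))   # pad  -> 3 blocks
```

With 10 tokens and `block_size=4`, `"drop"` loses the last 2 tokens, while `"pad"` keeps them in a third block whose padded positions contribute no loss.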
358
+ ---
359
+
360
+ ## PEFT Parameters
361
+
362
+ ### `peft.enabled`
363
+ - **Type**: Boolean
364
+ - **Required**: No
365
+ - **Default**: `true`
366
+ - **Description**: Whether to use PEFT (Parameter-Efficient Fine-Tuning)
367
+ - **Used in**: Lines ~395 in `run_cpt.py`
368
+ - **Implementation**:
369
+ ```python
370
+ if not bool(peft_cfg.get("enabled", True)):
371
+ return model, None
372
+ # Otherwise proceed with LoRA configuration
373
+ ```
374
+ - **Example Values**: `true`, `false`
375
+
376
+ ### `peft.r`
377
+ - **Type**: Integer
378
+ - **Required**: No
379
+ - **Default**: `16` (the code falls back to `16` when the key is absent)
380
+ - **Description**: LoRA rank - dimension of low-rank matrices
381
+ - **Used in**: Lines ~415 in `run_cpt.py`
382
+ - **Implementation**:
383
+ ```python
384
+ lora_config = LoraConfig(
385
+ r=int(peft_cfg.get("r", 16)),
386
+ ...
387
+ )
388
+ ```
389
+ - **Example Values**: `8`, `16`, `32`, `64`, `128`
390
+ - **Note**: Higher values = more parameters but potentially better performance
391
+
392
+ ### `peft.lora_alpha`
393
+ - **Type**: Integer
394
+ - **Required**: No
395
+ - **Default**: `32` (the code falls back to `32` when the key is absent)
396
+ - **Description**: LoRA alpha scaling parameter
397
+ - **Used in**: Lines ~416 in `run_cpt.py`
398
+ - **Implementation**:
399
+ ```python
400
+ lora_config = LoraConfig(
401
+ lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
402
+ ...
403
+ )
404
+ ```
405
+ - **Example Values**: `16`, `32`, `64`, `128`, `256`
406
+
407
+ ### `peft.lora_dropout`
408
+ - **Type**: Float
409
+ - **Required**: No
410
+ - **Default**: `0.05`
411
+ - **Description**: Dropout rate for LoRA layers
412
+ - **Used in**: Lines ~417 in `run_cpt.py`
413
+ - **Implementation**:
414
+ ```python
415
+ lora_config = LoraConfig(
416
+ lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
417
+ ...
418
+ )
419
+ ```
420
+ - **Example Values**: `0.0`, `0.05`, `0.1`, `0.2`
421
+
422
+ ### `peft.bias`
423
+ - **Type**: String
424
+ - **Required**: No
425
+ - **Default**: `"none"`
426
+ - **Description**: Bias training strategy
427
+ - **Used in**: Lines ~418 in `run_cpt.py`
428
+ - **Implementation**:
429
+ ```python
430
+ lora_config = LoraConfig(
431
+ bias=str(peft_cfg.get("bias", "none")),
432
+ ...
433
+ )
434
+ ```
435
+ - **Example Values**:
436
+ - `"none"` - No bias training
437
+ - `"all"` - Train all biases
438
+ - `"lora_only"` - Only LoRA bias
439
+
440
+ ### `peft.target_modules`
441
+ - **Type**: String or List
442
+ - **Required**: No
443
+ - **Default**: `"auto"`
444
+ - **Description**: Which modules to apply LoRA to
445
+ - **Used in**: Lines ~405, ~140-170 in `run_cpt.py`
446
+ - **Implementation**:
447
+ ```python
448
+ target_modules = peft_cfg.get("target_modules", "auto")
449
+ if target_modules == "auto":
450
+ target_modules = _infer_target_modules(model)
451
+ ```
452
+ - **Example Values**:
453
+ - `"auto"` - Automatic detection
454
+ - `["q_proj", "k_proj", "v_proj", "o_proj"]` - Explicit list
455
+ - `["mlp.gate_proj", "mlp.up_proj", "mlp.down_proj"]` - MLP only
456
+
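The `"auto"` detection can be sketched on plain sets of leaf-module names (same candidate groups as `_infer_target_modules`; the standalone function here is ours):

```python
# Pick the first known attention-projection group that the model's leaf
# module names fully contain, in the same order run_cpt.py tries them.
def infer_target_modules(leaf_names: set) -> list:
    for group in (["q_proj", "k_proj", "v_proj", "o_proj"],   # Llama-style
                  ["Wqkv", "out_proj"],                        # MPT-style
                  ["query_key_value", "dense"],                # Falcon-style
                  ["c_attn", "c_proj"]):                       # GPT-2-style
        if all(n in leaf_names for n in group):
            return group
    raise ValueError("set peft.target_modules explicitly")

print(infer_target_modules({"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"}))
```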
457
+ ---
458
+
459
+ ## Training Parameters
460
+
461
+ ### `train.num_train_epochs`
462
+ - **Type**: Float
463
+ - **Required**: No
464
+ - **Default**: `1` (the code falls back to `1` when the key is absent)
465
+ - **Description**: Number of epochs to train
466
+ - **Used in**: Lines ~470 in `run_cpt.py`
467
+ - **Implementation**:
468
+ ```python
469
+ num_train_epochs = float(tr_cfg.get("num_train_epochs", 1))
470
+ # Used in TrainingArguments
471
+ ```
472
+ - **Example Values**: `1.0`, `2.0`, `3.5`
473
+
474
+ ### `train.per_device_train_batch_size`
475
+ - **Type**: Integer
476
+ - **Required**: No
477
+ - **Default**: `1`
478
+ - **Description**: Training batch size per device
479
+ - **Used in**: Lines ~475 in `run_cpt.py`
480
+ - **Implementation**:
481
+ ```python
482
+ per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 1))
483
+ ```
484
+ - **Example Values**: `1`, `2`, `4`, `8`
485
+
486
+ ### `train.per_device_eval_batch_size`
487
+ - **Type**: Integer
488
+ - **Required**: No
489
+ - **Default**: Same as train batch size
490
+ - **Description**: Evaluation batch size per device
491
+ - **Used in**: Lines ~476 in `run_cpt.py`
492
+ - **Implementation**:
493
+ ```python
494
+ per_device_eval_batch_size=int(tr_cfg.get("per_device_eval_batch_size", tr_cfg.get("per_device_train_batch_size", 1)))
495
+ ```
496
+ - **Example Values**: `1`, `2`, `4`, `8`
497
+
498
+ ### `train.gradient_accumulation_steps`
499
+ - **Type**: Integer
500
+ - **Required**: No
501
+ - **Default**: `1` (the code falls back to `1` when the key is absent)
502
+ - **Description**: Number of steps to accumulate gradients
503
+ - **Used in**: Lines ~477 in `run_cpt.py`
504
+ - **Implementation**:
505
+ ```python
506
+ gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 1))
507
+ ```
508
+ - **Example Values**: `1`, `4`, `8`, `16`, `32`
509
+
510
+ ### `train.learning_rate`
511
+ - **Type**: Float
512
+ - **Required**: No
513
+ - **Default**: `2e-5`
514
+ - **Description**: Learning rate for optimizer
515
+ - **Used in**: Lines ~478 in `run_cpt.py`
516
+ - **Implementation**:
517
+ ```python
518
+ learning_rate=float(tr_cfg.get("learning_rate", 2e-5))
519
+ ```
520
+ - **Example Values**: `1e-5`, `2e-5`, `5e-5`, `1e-4`
521
+
522
+ ### `train.weight_decay`
523
+ - **Type**: Float
524
+ - **Required**: No
525
+ - **Default**: `0.0`
526
+ - **Description**: Weight decay for regularization
527
+ - **Used in**: Lines ~479 in `run_cpt.py`
528
+ - **Implementation**:
529
+ ```python
530
+ weight_decay=float(tr_cfg.get("weight_decay", 0.0))
531
+ ```
532
+ - **Example Values**: `0.0`, `0.01`, `0.1`
533
+
534
+ ### `train.warmup_ratio`
535
+ - **Type**: Float
536
+ - **Required**: No
537
+ - **Default**: `0.0` (the code falls back to `0.0` when the key is absent)
538
+ - **Description**: Ratio of steps for learning rate warmup
539
+ - **Used in**: Lines ~480 in `run_cpt.py`
540
+ - **Implementation**:
541
+ ```python
542
+ warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0))
543
+ ```
544
+ - **Example Values**: `0.0`, `0.1`, `0.2`
545
+
546
+ ### `train.lr_scheduler_type`
547
+ - **Type**: String
548
+ - **Required**: No
549
+ - **Default**: `"cosine"`
550
+ - **Description**: Learning rate scheduler type
551
+ - **Used in**: Lines ~481 in `run_cpt.py`
552
+ - **Implementation**:
553
+ ```python
554
+ lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "cosine"))
555
+ ```
556
+ - **Example Values**:
557
+ - `"cosine"` - Cosine annealing
558
+ - `"linear"` - Linear decay
559
+ - `"constant"` - Constant rate
560
+ - `"polynomial"` - Polynomial decay
561
+
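For reference, `warmup_ratio` is turned into a warmup step count roughly like this (HF `Trainer` uses the ratio only when no explicit `warmup_steps` is given; this is a sketch of that arithmetic, not the exact library code):

```python
import math

# warmup_ratio -> number of warmup steps, given the total optimizer steps.
def warmup_steps(total_steps: int, warmup_ratio: float) -> int:
    return math.ceil(total_steps * warmup_ratio)

print(warmup_steps(1000, 0.1))  # 100
```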
562
+ ### `train.optim`
563
+ - **Type**: String
564
+ - **Required**: No
565
+ - **Default**: `"paged_adamw_8bit"` (if 4-bit), `"adamw_torch"` (otherwise)
566
+ - **Description**: Optimizer type
567
+ - **Used in**: Lines ~482 in `run_cpt.py`
568
+ - **Implementation**:
569
+ ```python
570
+ optim=str(tr_cfg.get("optim", "paged_adamw_8bit" if bool(model_cfg.get("use_4bit", False)) else "adamw_torch"))
571
+ ```
572
+ - **Example Values**:
573
+ - `"adamw_torch"` - AdamW (standard)
574
+ - `"paged_adamw_8bit"` - paged AdamW with 8-bit optimizer states (default when `use_4bit` is on)
575
+ - `"sgd"` - SGD
576
+ - `"adafactor"` - Adafactor
577
+
578
+ ### `train.max_grad_norm`
579
+ - **Type**: Float
580
+ - **Required**: No
581
+ - **Default**: `1.0`
582
+ - **Description**: Maximum gradient norm for clipping
583
+ - **Used in**: Lines ~483 in `run_cpt.py`
584
+ - **Implementation**:
585
+ ```python
586
+ max_grad_norm=float(tr_cfg.get("max_grad_norm", 1.0))
587
+ ```
588
+ - **Example Values**: `0.5`, `1.0`, `2.0`
589
+
590
+ ### `train.gradient_checkpointing`
591
+ - **Type**: Boolean
592
+ - **Required**: No
593
+ - **Default**: `true`
594
+ - **Description**: Use gradient checkpointing to save memory
595
+ - **Used in**: Lines ~396-400 in `run_cpt.py`
596
+ - **Implementation**:
597
+ ```python
598
+ gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))
599
+ if gradient_checkpointing:
600
+ model.gradient_checkpointing_enable()
601
+ ```
602
+ - **Example Values**: `true`, `false`
603
+
604
+ ### `train.logging_steps`
605
+ - **Type**: Integer
606
+ - **Required**: No
607
+ - **Default**: `10` (the code falls back to `10` when the key is absent)
608
+ - **Description**: Log training progress every N steps
609
+ - **Used in**: Lines ~485 in `run_cpt.py`
610
+ - **Implementation**:
611
+ ```python
612
+ logging_steps=int(tr_cfg.get("logging_steps", 10))
613
+ ```
614
+ - **Example Values**: `1`, `10`, `50`, `100`
615
+
616
+ ### `train.save_strategy`
617
+ - **Type**: String
618
+ - **Required**: No
619
+ - **Default**: `"steps"`
620
+ - **Description**: When to save model checkpoints
621
+ - **Used in**: Lines ~487 in `run_cpt.py`
622
+ - **Implementation**:
623
+ ```python
624
+ save_strategy=str(tr_cfg.get("save_strategy", "steps"))
625
+ ```
626
+ - **Example Values**:
627
+ - `"steps"` - Save every N steps
628
+ - `"epoch"` - Save every epoch
629
+ - `"no"` - Don't save
630
+
631
+ ### `train.save_steps`
632
+ - **Type**: Integer
633
+ - **Required**: No
634
+ - **Default**: `200` (the code falls back to `200` when the key is absent)
635
+ - **Description**: Save checkpoint every N steps
636
+ - **Used in**: Lines ~488 in `run_cpt.py`
637
+ - **Implementation**:
638
+ ```python
639
+ save_steps=int(tr_cfg.get("save_steps", 200))
640
+ ```
641
+ - **Example Values**: `50`, `100`, `200`, `500`
642
+
643
+ ### `train.save_total_limit`
644
+ - **Type**: Integer
645
+ - **Required**: No
646
+ - **Default**: `3` (the code falls back to `3` when the key is absent)
647
+ - **Description**: Maximum number of checkpoints to keep
648
+ - **Used in**: Lines ~489 in `run_cpt.py`
649
+ - **Implementation**:
650
+ ```python
651
+ save_total_limit=int(tr_cfg.get("save_total_limit", 3))
652
+ ```
653
+ - **Example Values**: `1`, `2`, `3`, `5`
654
+
655
+ ### `train.evaluation_strategy`
656
+ - **Type**: String
657
+ - **Required**: No
658
+ - **Default**: `"steps"` (if eval data), `"no"` (otherwise)
659
+ - **Description**: When to evaluate model
660
+ - **Used in**: Lines ~494 in `run_cpt.py`
661
+ - **Implementation**:
662
+ ```python
663
+ evaluation_strategy=str(tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no"))
664
+ ```
665
+ - **Example Values**:
666
+ - `"steps"` - Evaluate every N steps
667
+ - `"epoch"` - Evaluate every epoch
668
+ - `"no"` - Don't evaluate
669
+
670
+ ### `train.eval_steps`
671
+ - **Type**: Integer
672
+ - **Required**: No
673
+ - **Default**: `200` (the code falls back to `200` when the key is absent)
674
+ - **Description**: Evaluate every N steps
675
+ - **Used in**: Lines ~491 in `run_cpt.py`
676
+ - **Implementation**:
677
+ ```python
678
+ eval_steps=int(tr_cfg.get("eval_steps", 200))
679
+ ```
680
+ - **Example Values**: `25`, `50`, `100`, `200`
681
+
682
+ ### `train.load_best_model_at_end`
683
+ - **Type**: Boolean
684
+ - **Required**: No
685
+ - **Default**: `true` (if eval data), `false` (otherwise)
686
+ - **Description**: Load best model at end of training
687
+ - **Used in**: Lines ~492-493 in `run_cpt.py`
688
+ - **Implementation**:
689
+ ```python
690
+ load_best_model_at_end=bool(tr_cfg.get("load_best_model_at_end", True)) if eval_ds is not None else False
691
+ ```
692
+ - **Example Values**: `true`, `false`
693
+
694
+ ### `train.resume_from_checkpoint`
695
+ - **Type**: String
696
+ - **Required**: No
697
+ - **Default**: `null` (no resume; set `"auto"` to auto-detect the latest checkpoint)
698
+ - **Description**: Resume training from checkpoint
699
+ - **Used in**: Lines ~510-520 in `run_cpt.py`
700
+ - **Implementation**:
701
+ ```python
702
+ resume_from = tr_cfg.get("resume_from_checkpoint", None)
703
+ if resume_from == "auto":
704
+ last = get_last_checkpoint(str(ckpt_dir))
705
+ resume_from = last if last else None
706
+ ```
707
+ - **Example Values**:
708
+ - `"auto"` - Auto-detect latest checkpoint
709
+ - `"checkpoint-100"` - Specific checkpoint
710
+ - `null` - Start from scratch
711
+
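The `"auto"` resolution is essentially "newest `checkpoint-*` directory wins". A standalone sketch (in the script this is delegated to `transformers.trainer_utils.get_last_checkpoint`; the helper name here is ours):

```python
from pathlib import Path

# Resolve "auto" to the highest-numbered checkpoint-* directory under the
# run's checkpoint dir; pass through explicit paths and None unchanged.
def resolve_resume(resume_from, ckpt_dir: Path):
    if resume_from != "auto":
        return resume_from  # explicit checkpoint path, or None for a fresh start
    ckpts = [p for p in ckpt_dir.glob("checkpoint-*") if p.is_dir()]
    if not ckpts:
        return None
    return str(max(ckpts, key=lambda p: int(p.name.rsplit("-", 1)[-1])))
```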
712
+ ---
713
+
714
+ ## Merge Parameters
715
+
716
+ ### `merge.enabled`
717
+ - **Type**: Boolean
718
+ - **Required**: No
719
+ - **Default**: `false`
720
+ - **Description**: Whether to merge LoRA adapters with base model
721
+ - **Used in**: Lines ~545 in `run_cpt.py`
722
+ - **Implementation**:
723
+ ```python
724
+ if bool(cfg.get("merge", {}).get("enabled", False)):
725
+ merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
726
+ ```
727
+ - **Example Values**: `true`, `false`
728
+
729
+ ### `merge.merged_dtype`
730
+ - **Type**: String
731
+ - **Required**: No
732
+ - **Default**: `"float16"`
733
+ - **Description**: Data type for merged model
734
+ - **Used in**: Lines ~430 in `run_cpt.py`
735
+ - **Implementation**:
736
+ ```python
737
+ merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
738
+ ```
739
+ - **Example Values**: `"float16"`, `"bfloat16"`, `"float32"`
740
+
741
+ ### `merge.max_shard_size`
742
+ - **Type**: String
743
+ - **Required**: No
744
+ - **Default**: `"2GB"`
745
+ - **Description**: Maximum size per shard when saving
746
+ - **Used in**: Lines ~445 in `run_cpt.py`
747
+ - **Implementation**:
748
+ ```python
749
+ merged.save_pretrained(str(final_dir), safe_serialization=True, max_shard_size=max_shard_size)
750
+ ```
751
+ - **Example Values**: `"1GB"`, `"2GB"`, `"5GB"`
752
+
753
+ ### `merge.output_dir`
754
+ - **Type**: String (path)
755
+ - **Required**: No
756
+ - **Default**: `"final_model"` under the run directory (the code falls back to `run_dir / "final_model"`)
757
+ - **Description**: Directory for merged model output
758
+ - **Used in**: Lines ~505-510 in `run_cpt.py`
759
+ - **Implementation**:
760
+ ```python
761
+ if merge_cfg.get("output_dir"):
762
+ od = Path(str(merge_cfg["output_dir"]))
763
+ final_dir = od if od.is_absolute() else (run_dir / od)
764
+ else:
765
+ final_dir = run_dir / "final_model"
766
+ ```
767
+ - **Example Values**: `"./merged_model"`, `"/workspace/final_model"`, `"./models/merged"`
768
+
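How `output_dir` is resolved against the run directory (a sketch matching the implementation snippet above; `run_dir` stands for wherever the run writes its artifacts):

```python
from pathlib import Path

# Relative output_dir values are resolved against the run directory;
# absolute paths are used as-is; unset falls back to run_dir/final_model.
def resolve_final_dir(merge_cfg: dict, run_dir: Path) -> Path:
    if merge_cfg.get("output_dir"):
        od = Path(str(merge_cfg["output_dir"]))
        return od if od.is_absolute() else run_dir / od
    return run_dir / "final_model"

print(resolve_final_dir({}, Path("/runs/exp1")))  # /runs/exp1/final_model
```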
769
+ ---
770
+
771
+ ## Parameter Dependencies and Interactions
772
+
773
+ ### Memory-Related Dependencies
774
+ - `per_device_train_batch_size` × `gradient_accumulation_steps` × number of devices = effective batch size
775
+ - `block_size` affects memory usage significantly
776
+ - `use_4bit` + `bnb_4bit_*` parameters work together for quantization
777
+ - `gradient_checkpointing` can enable larger `block_size` or `batch_size`
778
+
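The first dependency above is simple arithmetic worth checking before launching a run (a sketch; `world_size` means the number of training processes/GPUs):

```python
# Effective (global) batch size in sequences per optimizer step.
def effective_batch(per_device: int, grad_accum: int, world_size: int = 1) -> int:
    return per_device * grad_accum * world_size

print(effective_batch(1, 16))    # 16 sequences per optimizer step
print(effective_batch(2, 8, 4))  # 64 on 4 GPUs
```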
779
+ ### Training Strategy Dependencies
780
+ - `evaluation_strategy` requires either `eval_jsonl` or `eval_split_ratio > 0`
781
+ - `load_best_model_at_end` requires `evaluation_strategy` to be enabled
782
+ - `save_strategy` should be compatible with `evaluation_strategy`
783
+ - `lr_scheduler_type` affects warmup calculations
784
+
785
+ ### Model-Specific Dependencies
786
+ - `target_modules` must match the actual module names in your model
787
+ - `torch_dtype` should be compatible with your GPU hardware
788
+ - `device_map` affects whether you can use certain optimizations
789
+
790
+ ### Data Processing Dependencies
791
+ - `text_field` must exist in your JSONL data
792
+ - `pack_mode: "pad"` requires `block_size` to be set appropriately
793
+ - `eval_split_ratio` is ignored if `eval_jsonl` is provided
794
+
795
+ This comprehensive documentation should help you understand and configure all parameters in the CPT training system according to your specific needs and constraints.
trainer-kit/CPT/dummy_data.jsonl ADDED
@@ -0,0 +1,6 @@
1
+ {"text": "This is a test sentence for the dummy dataset."}
2
+ {"text": "Another sentence to check if training works."}
3
+ {"text": "We need enough data to form a batch."}
4
+ {"text": "FSDP and LoRA are cool technologies."}
5
+ {"text": "Fine-tuning LLMs is fun and useful."}
6
+ {"text": "This is the end of the dummy dataset."}
trainer-kit/CPT/requirements.txt ADDED
@@ -0,0 +1,22 @@
1
+ # Core
2
+ torch>=2.1.0
3
+ transformers>=4.41.0
4
+ datasets>=2.18.0
5
+ accelerate>=0.30.0
6
+
7
+ # PEFT / QLoRA
8
+ peft>=0.11.1
9
+ bitsandbytes>=0.43.1
10
+
11
+ # Hugging Face Hub (local + download support)
12
+ huggingface_hub>=0.23.0
13
+
14
+ # Config + utilities
15
+ pyyaml>=6.0
16
+ tqdm>=4.66.0
17
+
18
+ # Optional but recommended (tokenizers speed)
19
+ tokenizers>=0.15.0
20
+ safetensors>=0.4.2
21
+ # Optional (for eval)
22
+ rouge-score>=0.1.2
trainer-kit/CPT/run_cpt.py ADDED
@@ -0,0 +1,708 @@
1
+ import argparse
2
+ import json
3
+ import inspect # Added for Transformers version compatibility
4
+ import math
5
+ import time
6
+ from pathlib import Path
7
+ from typing import Any, Dict, Optional, Tuple, List
8
+
9
+ import torch
10
+ import yaml
11
+ from datasets import load_dataset, DatasetDict
12
+ from huggingface_hub import snapshot_download
13
+ from transformers import (
14
+ AutoModelForCausalLM,
15
+ AutoTokenizer,
16
+ PreTrainedTokenizerFast,
17
+ TrainingArguments,
18
+ Trainer,
19
+ TrainerCallback,
20
+ default_data_collator,
21
+ set_seed,
22
+ )
23
+ from transformers.trainer_utils import get_last_checkpoint
24
+ from peft import (
25
+ LoraConfig,
26
+ get_peft_model,
27
+ prepare_model_for_kbit_training,
28
+ PeftModel,
29
+ )
30
+
31
+ try:
32
+ from transformers import BitsAndBytesConfig
33
+ except ImportError: # older transformers
34
+ BitsAndBytesConfig = None
35
+
36
+
37
+ # --------------------------
38
+ # Helpers
39
+ # --------------------------
40
+
41
+ def _dtype_from_str(s: str) -> torch.dtype:
42
+ s = (s or "").lower()
43
+ if s in ("float16", "fp16"):
44
+ return torch.float16
45
+ if s in ("bfloat16", "bf16"):
46
+ return torch.bfloat16
47
+ if s in ("float32", "fp32"):
48
+ return torch.float32
49
+ raise ValueError(f"Unknown torch_dtype: {s}")
50
+
51
+ def _now_iso() -> str:
52
+ return time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
53
+
54
+ def _safe_exp(x: float) -> float:
55
+ x = min(float(x), 50.0)
56
+ return float(math.exp(x))
57
+
58
+ def _ensure_dir(p: Path) -> Path:
59
+ p.mkdir(parents=True, exist_ok=True)
60
+ return p
61
+
62
+ def _looks_like_model_dir(p: Path) -> bool:
63
+ if not p.exists() or not p.is_dir():
64
+ return False
65
+ if (p / "config.json").exists():
66
+ return True
67
+ if any(p.glob("*.safetensors")) or any(p.glob("pytorch_model*.bin")):
68
+ return True
69
+ return False
70
+
71
+ def _detect_text_field(example: Dict[str, Any]) -> Optional[str]:
72
+ for k, v in example.items():
73
+ if isinstance(v, str) and v.strip():
74
+ return k
75
+ return None
76
+
77
+ def _load_tokenizer(base_dir: Path, use_fast: bool, trust_remote_code: bool):
78
+ try:
79
+ return AutoTokenizer.from_pretrained(
80
+ str(base_dir),
81
+ use_fast=use_fast,
82
+ trust_remote_code=trust_remote_code,
83
+ )
84
+ except ValueError as e:
85
+ if "TokenizersBackend" not in str(e):
86
+ raise
87
+ tok_file = base_dir / "tokenizer.json"
88
+ tok_cfg_path = base_dir / "tokenizer_config.json"
89
+ if not tok_file.exists():
90
+ raise
91
+
92
+ tok_kwargs: Dict[str, Any] = {}
93
+ if tok_cfg_path.exists():
94
+ with tok_cfg_path.open("r", encoding="utf-8") as f:
95
+ tok_cfg = json.load(f)
96
+ for key in ("bos_token", "eos_token", "pad_token", "unk_token", "model_max_length"):
97
+ if tok_cfg.get(key) is not None:
98
+ tok_kwargs[key] = tok_cfg[key]
99
+ extra = tok_cfg.get("additional_special_tokens") or tok_cfg.get("extra_special_tokens")
100
+ if extra:
101
+ tok_kwargs["additional_special_tokens"] = extra
102
+
103
+ return PreTrainedTokenizerFast(tokenizer_file=str(tok_file), **tok_kwargs)
104
+
105
+ def _infer_target_modules(model) -> List[str]:
106
+ names = set()
107
+ for n, _ in model.named_modules():
108
+ names.add(n.split(".")[-1])
109
+
110
+ for group in [
111
+ ["q_proj", "k_proj", "v_proj", "o_proj"],
112
+ ["Wqkv", "out_proj"],
113
+ ["query_key_value", "dense"],
114
+ ["c_attn", "c_proj"],
115
+ ]:
116
+ if all(x in names for x in group):
117
+ return group
118
+
119
+ fallback = [x for x in ["q_proj", "k_proj", "v_proj", "o_proj", "c_attn", "c_proj", "out_proj", "dense"] if x in names]
120
+ if fallback:
121
+ return fallback
122
+
123
+ raise ValueError("Could not auto-infer target_modules. Set peft.target_modules explicitly.")
124
+
125
+ def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
126
+ return cfg.get("model", {}).get("attn_implementation", None)
127
+
128
+
129
+ # --------------------------
130
+ # JSONL Logger Callback
131
+ # --------------------------
132
+
133
+ class JsonlLoggerCallback(TrainerCallback):
134
+ def __init__(self, run_dir: Path):
135
+ self.run_dir = run_dir
136
+ self.train_log_path = _ensure_dir(run_dir / "logs") / "train.jsonl"
137
+ self.eval_log_path = _ensure_dir(run_dir / "logs") / "eval.jsonl"
138
+ self.start_time = None
139
+
140
+ def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
141
+ if self.start_time is None or global_step <= 0 or max_steps <= 0:
142
+ return None
143
+ elapsed = time.time() - self.start_time
144
+ sec_per_step = elapsed / global_step
145
+ remaining = max(0, max_steps - global_step) * sec_per_step
146
+ h = int(remaining // 3600)
147
+ m = int((remaining % 3600) // 60)
148
+ s = int(remaining % 60)
149
+ return f"{h:02d}:{m:02d}:{s:02d}"
150
+
151
+ def on_train_begin(self, args, state, control, **kwargs):
152
+ self.start_time = time.time()
153
+
154
+ def on_log(self, args, state, control, logs=None, **kwargs):
155
+ if not logs:
156
+ return
157
+
158
+ max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
159
+ progress_pct = (100.0 * state.global_step / max_steps) if max_steps > 0 else None
160
+ epoch_pct = None
161
+ if state.epoch is not None and args.num_train_epochs and args.num_train_epochs > 0:
162
+ epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs))
163
+
164
+ payload = {
165
+ "ts": _now_iso(),
166
+ "event": "train_log",
167
+ "step": int(state.global_step),
168
+ "epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
169
+ "progress_pct": round(progress_pct, 2) if progress_pct is not None else None,
170
+ "epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None,
171
+ "eta": self._eta(int(state.global_step), max_steps),
172
+ "max_grad_norm": getattr(args, "max_grad_norm", None),
173
+ **logs,
174
+ }
175
+
176
+ with self.train_log_path.open("a", encoding="utf-8") as f:
177
+ f.write(json.dumps(payload, ensure_ascii=False) + "\n")
178
+
179
+ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
180
+ if not metrics:
181
+ return
182
+ eval_loss = metrics.get("eval_loss", None)
183
+ ppl = _safe_exp(eval_loss) if eval_loss is not None else None
184
+
185
+ payload = {
186
+ "ts": _now_iso(),
187
+ "event": "eval",
188
+ "step": int(state.global_step),
189
+ "epoch": float(state.epoch) if state.epoch is not None else None,
190
+ **metrics,
191
+ "perplexity": ppl,
192
+ }
193
+ with self.eval_log_path.open("a", encoding="utf-8") as f:
194
+ f.write(json.dumps(payload, ensure_ascii=False) + "\n")
195
+
196
+
197
+ # --------------------------
198
+ # Data Pipeline (EOS + Packing)
199
+ # --------------------------
200
+
201
+ def build_datasets(cfg: Dict[str, Any], tokenizer) -> Tuple[Any, Any]:
202
+ data_cfg = cfg["data"]
203
+ train_path = data_cfg["train_jsonl"]
204
+ eval_path = data_cfg.get("eval_jsonl", None)
205
+ split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
206
+ text_field = data_cfg.get("text_field", "text")
207
+ block_size = int(data_cfg.get("block_size", 2048))
208
+ shuffle = bool(data_cfg.get("shuffle", True))
209
+ num_proc = int(data_cfg.get("num_proc", 4))
210
+
211
+ pack_mode = str(data_cfg.get("pack_mode", "drop")).lower().strip()
212
+ if pack_mode not in ("drop", "pad"):
213
+ raise ValueError(f"data.pack_mode must be 'drop' or 'pad', got: {pack_mode}")
214
+
215
+ eos_id = tokenizer.eos_token_id
216
+ if eos_id is None:
217
+ raise ValueError("Tokenizer has no eos_token_id; CPT packing needs an EOS delimiter.")
218
+
219
+ if tokenizer.pad_token_id is None:
220
+ # safe default for many causal LMs
221
+ tokenizer.pad_token = tokenizer.eos_token
222
+ pad_id = tokenizer.pad_token_id
223
+
224
+ ds = load_dataset("json", data_files={"train": train_path})
225
+
226
+ if eval_path:
227
+ ds_eval = load_dataset("json", data_files={"eval": eval_path})
228
+ dsd = DatasetDict({"train": ds["train"], "eval": ds_eval["eval"]})
229
+ else:
230
+ if 0.0 < split_ratio < 1.0:
231
+ split = ds["train"].train_test_split(test_size=split_ratio, seed=int(cfg["run"].get("seed", 42)))
232
+ dsd = DatasetDict({"train": split["train"], "eval": split["test"]})
233
+ else:
234
+ dsd = DatasetDict({"train": ds["train"], "eval": None})
235
+
236
+ if text_field not in dsd["train"].column_names:
237
+ auto_field = _detect_text_field(dsd["train"][0])
238
+ if not auto_field:
239
+ raise ValueError(f"Could not find text field. Columns: {dsd['train'].column_names}")
240
+ text_field = auto_field
241
+
242
+ def tokenize_fn(examples):
243
+ out = tokenizer(
244
+ examples[text_field],
245
+ add_special_tokens=False,
246
+ truncation=False,
247
+ padding=False,
248
+ )
249
+ if "token_type_ids" in out:
250
+ del out["token_type_ids"]
251
+ # Add EOS between docs
252
+ out["input_ids"] = [ids + [eos_id] for ids in out["input_ids"]]
253
+ out["attention_mask"] = [m + [1] for m in out["attention_mask"]]
254
+ return out
255
+
256
+ tokenized_train = dsd["train"].map(
257
+ tokenize_fn,
258
+ batched=True,
259
+ num_proc=num_proc,
260
+ remove_columns=dsd["train"].column_names,
261
+ desc="Tokenizing train",
262
+ )
263
+
264
+ tokenized_eval = None
265
+ if dsd["eval"] is not None:
266
+ tokenized_eval = dsd["eval"].map(
267
+ tokenize_fn,
268
+ batched=True,
269
+ num_proc=num_proc,
270
+ remove_columns=dsd["eval"].column_names,
271
+ desc="Tokenizing eval",
272
+ )
273
+
274
+ def group_texts(examples):
275
+ concatenated = {k: sum(examples[k], []) for k in examples.keys()}
276
+ total_length = len(concatenated["input_ids"])
277
+
278
+ if total_length == 0:
279
+ return {"input_ids": [], "attention_mask": [], "labels": []}
280
+
281
+ full_len = (total_length // block_size) * block_size
282
+ blocks_input, blocks_attn, blocks_labels = [], [], []
283
+
284
+ # full blocks
285
+ for i in range(0, full_len, block_size):
286
+ chunk = concatenated["input_ids"][i:i + block_size]
287
+ attn = concatenated["attention_mask"][i:i + block_size]
288
+ blocks_input.append(chunk)
289
+ blocks_attn.append(attn)
290
+ blocks_labels.append(chunk.copy())
291
+
292
+ # remainder
293
+ remainder = total_length - full_len
294
+ if remainder > 0 and pack_mode == "pad":
295
+ chunk = concatenated["input_ids"][full_len:full_len + remainder]
296
+ attn = concatenated["attention_mask"][full_len:full_len + remainder]
297
+
298
+ pad_len = block_size - remainder
299
+ chunk_padded = chunk + [pad_id] * pad_len
300
+ attn_padded = attn + [0] * pad_len
301
+
302
+ labels = chunk_padded.copy()
303
+ labels[-pad_len:] = [-100] * pad_len # loss mask
304
+
305
+ blocks_input.append(chunk_padded)
306
+ blocks_attn.append(attn_padded)
307
+ blocks_labels.append(labels)
308
+
309
+ return {
310
+ "input_ids": blocks_input,
311
+ "attention_mask": blocks_attn,
312
+ "labels": blocks_labels,
313
+ }
314
+
315
+ tokenized_train = tokenized_train.map(
316
+ group_texts,
317
+ batched=True,
318
+ num_proc=num_proc,
319
+ desc=f"Packing train blocks (mode={pack_mode})",
320
+ )
321
+ if tokenized_eval is not None:
322
+ tokenized_eval = tokenized_eval.map(
323
+ group_texts,
324
+ batched=True,
325
+ num_proc=num_proc,
326
+ desc=f"Packing eval blocks (mode={pack_mode})",
327
+ )
328
+
329
+ if len(tokenized_train) == 0:
330
+ raise ValueError(
331
+ "Train dataset is empty after packing. "
332
+ "Either increase data, reduce block_size, or set data.pack_mode='pad'."
333
+ )
334
+
335
+ if shuffle:
336
+ tokenized_train = tokenized_train.shuffle(seed=int(cfg["run"].get("seed", 42)))
337
+
338
+ return tokenized_train, tokenized_eval
339
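The drop/pad packing rule in `group_texts` can be sketched in isolation: concatenate the token stream, cut full `block_size` blocks, and either drop the remainder or pad it with the loss masked out. This is a minimal standalone sketch (hypothetical `pack_blocks` helper, not part of the kit) mirroring the logic above:

```python
def pack_blocks(ids, block_size, pack_mode="drop", pad_id=0):
    # Cut the concatenated token stream into fixed-size blocks.
    full_len = (len(ids) // block_size) * block_size
    blocks = [ids[i:i + block_size] for i in range(0, full_len, block_size)]
    labels = [b[:] for b in blocks]  # labels == input_ids for full blocks
    remainder = ids[full_len:]
    if remainder and pack_mode == "pad":
        pad_len = block_size - len(remainder)
        blocks.append(remainder + [pad_id] * pad_len)
        labels.append(remainder + [-100] * pad_len)  # -100 masks loss on padding
    return blocks, labels

# 10 tokens, block_size 4: "drop" yields 2 blocks, "pad" keeps the 2-token tail
blocks, labels = pack_blocks(list(range(10)), 4, pack_mode="pad")
```

Note that `pad` mode preserves every token at the cost of one partially padded block per map batch, while `drop` discards up to `block_size - 1` tokens per batch.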
+
340
+
341
+ # --------------------------
342
+ # Model Loading + PEFT
343
+ # --------------------------
344
+
345
+ def _select_model_loader(base_dir: Path):
346
+ cfg_path = base_dir / "config.json"
347
+ if not cfg_path.exists():
348
+ return {"kind": "causal", "arch": None}
349
+ with cfg_path.open("r", encoding="utf-8") as f:
350
+ cfg = json.load(f)
351
+ arch = cfg.get("architectures") or []
352
+ arch_name = arch[0] if arch else None
353
+ if any("ForConditionalGeneration" in a for a in arch):
354
+ return {"kind": "conditional", "arch": arch_name}
355
+ return {"kind": "causal", "arch": arch_name}
356
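The loader-selection rule above reduces to a single check on `config.json`: if any listed architecture name contains `ForConditionalGeneration`, use that class, otherwise fall back to `AutoModelForCausalLM`. A minimal sketch of the same classification (hypothetical `classify_arch` helper operating on an already-parsed config dict):

```python
def classify_arch(config: dict) -> str:
    # Mirrors _select_model_loader: the "architectures" field decides the kind.
    archs = config.get("architectures") or []
    if any("ForConditionalGeneration" in a for a in archs):
        return "conditional"
    return "causal"  # default when the field is absent or unrecognized
```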
+
357
+ def _resolve_model_class(arch_name: str):
358
+ import transformers
359
+ cls = getattr(transformers, arch_name, None)
360
+ if cls is None:
361
+ raise ValueError(f"Model class '{arch_name}' is not available in installed transformers.")
362
+ return cls
363
+
364
+
365
+ def load_base_model_and_tokenizer(cfg: Dict[str, Any], base_dir: Path):
366
+ model_cfg = cfg["model"]
367
+ trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
368
+ use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
369
+ device_map = model_cfg.get("device_map", "auto")
370
+
371
+ tokenizer = _load_tokenizer(base_dir, use_fast=use_fast, trust_remote_code=trust_remote_code)
372
+ if tokenizer.pad_token is None:
373
+ tokenizer.pad_token = tokenizer.eos_token
374
+
375
+ torch_dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
376
+ use_4bit = bool(model_cfg.get("use_4bit", False))
377
+
378
+ quant_cfg = None
379
+ if use_4bit:
380
+ if BitsAndBytesConfig is None:
381
+ raise ImportError("BitsAndBytesConfig is not available in this transformers version.")
382
+ quant_cfg = BitsAndBytesConfig(
383
+ load_in_4bit=True,
384
+ bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4")),
385
+ bnb_4bit_use_double_quant=bool(model_cfg.get("bnb_4bit_use_double_quant", True)),
386
+ bnb_4bit_compute_dtype=_dtype_from_str(model_cfg.get("bnb_4bit_compute_dtype", "bfloat16")),
387
+ )
388
+
389
+ attn_impl = _choose_attn_impl(cfg)
390
+ model_meta = _select_model_loader(base_dir)
391
+
392
+ try:
393
+ if model_meta["kind"] == "conditional":
394
+ model_cls = _resolve_model_class(model_meta["arch"]) if model_meta["arch"] else None
395
+ if model_cls is None:
396
+ raise ValueError("Conditional model architecture not specified in config.json.")
397
+ model = model_cls.from_pretrained(
398
+ str(base_dir),
399
+ device_map=device_map,
400
+ trust_remote_code=trust_remote_code,
401
+ low_cpu_mem_usage=True,
402
+ torch_dtype=(torch_dtype if not use_4bit else None),
403
+ quantization_config=quant_cfg,
404
+ attn_implementation=attn_impl,
405
+ )
406
+ else:
407
+ model = AutoModelForCausalLM.from_pretrained(
408
+ str(base_dir),
409
+ device_map=device_map,
410
+ trust_remote_code=trust_remote_code,
411
+ low_cpu_mem_usage=True,
412
+ torch_dtype=(torch_dtype if not use_4bit else None),
413
+ quantization_config=quant_cfg,
414
+ attn_implementation=attn_impl,
415
+ )
416
+ except Exception as e:
417
+ if attn_impl is not None:
418
+ print(f"[warn] attn_implementation='{attn_impl}' failed: {e}")
419
+ print("[warn] Falling back to default attention implementation.")
420
+ if model_meta["kind"] == "conditional":
421
+ model_cls = _resolve_model_class(model_meta["arch"]) if model_meta["arch"] else None
422
+ if model_cls is None:
423
+ raise ValueError("Conditional model architecture not specified in config.json.")
424
+ model = model_cls.from_pretrained(
425
+ str(base_dir),
426
+ device_map=device_map,
427
+ trust_remote_code=trust_remote_code,
428
+ low_cpu_mem_usage=True,
429
+ torch_dtype=(torch_dtype if not use_4bit else None),
430
+ quantization_config=quant_cfg,
431
+ )
432
+ else:
433
+ model = AutoModelForCausalLM.from_pretrained(
434
+ str(base_dir),
435
+ device_map=device_map,
436
+ trust_remote_code=trust_remote_code,
437
+ low_cpu_mem_usage=True,
438
+ torch_dtype=(torch_dtype if not use_4bit else None),
439
+ quantization_config=quant_cfg,
440
+ )
441
+
442
+ return model, tokenizer
443
+
444
+
445
+ def apply_peft(cfg: Dict[str, Any], model):
446
+ peft_cfg = cfg["peft"]
447
+ model_cfg = cfg["model"]
448
+ tr_cfg = cfg["train"]
449
+
450
+ if not bool(peft_cfg.get("enabled", True)):
451
+ return model, None
452
+
453
+ use_4bit = bool(model_cfg.get("use_4bit", False))
454
+ gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))
455
+
456
+ if gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
457
+ model.gradient_checkpointing_enable()
458
+ if hasattr(model, "config"):
459
+ model.config.use_cache = False
460
+
461
+ if use_4bit:
462
+ model = prepare_model_for_kbit_training(
463
+ model,
464
+ use_gradient_checkpointing=gradient_checkpointing,
465
+ )
466
+
467
+ target_modules = peft_cfg.get("target_modules", "auto")
468
+ if target_modules == "auto":
469
+ target_modules = _infer_target_modules(model)
470
+
471
+ lora_config = LoraConfig(
472
+ r=int(peft_cfg.get("r", 16)),
473
+ lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
474
+ lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
475
+ bias=str(peft_cfg.get("bias", "none")),
476
+ task_type="CAUSAL_LM",
477
+ target_modules=target_modules,
478
+ )
479
+ model = get_peft_model(model, lora_config)
480
+ return model, lora_config
481
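The cost of the LoRA config above is easy to estimate: each targeted `d_out x d_in` linear layer gains an `A` matrix of shape `r x d_in` and a `B` matrix of shape `d_out x r`, i.e. `r * (d_in + d_out)` trainable parameters. A small arithmetic sketch (illustrative shapes, not taken from any specific model):

```python
def lora_params(shapes, r):
    # Trainable parameters added by LoRA over the given (d_out, d_in) linears.
    return sum(r * (d_in + d_out) for d_out, d_in in shapes)

# e.g. four hypothetical 4096x4096 attention projections at the default r=16
added = lora_params([(4096, 4096)] * 4, r=16)
```

This is why `r` (here defaulting to 16) trades adapter capacity directly against memory: doubling `r` doubles the trainable-parameter count.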
+
482
+
483
+ # --------------------------
484
+ # Merge Logic
485
+ # --------------------------
486
+
487
+ def merge_adapter(cfg: Dict[str, Any], base_dir: Path, adapter_dir: Path, final_dir: Path):
488
+ print(f"--- Merge: {adapter_dir} + {base_dir} -> {final_dir} ---")
489
+
490
+ model_cfg = cfg["model"]
491
+ merge_cfg = cfg.get("merge", {})
492
+ trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
493
+ use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
494
+
495
+ merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
496
+ max_shard_size = str(merge_cfg.get("max_shard_size", "2GB"))
497
+
498
+ model_meta = _select_model_loader(base_dir)
499
+ if model_meta["kind"] == "conditional":
500
+ base_cls = _resolve_model_class(model_meta["arch"]) if model_meta["arch"] else None
501
+ if base_cls is None:
502
+ raise ValueError("Conditional model architecture not specified in config.json.")
503
+ base = base_cls.from_pretrained(
504
+ str(base_dir),
505
+ torch_dtype=merged_dtype,
506
+ device_map="cpu",
507
+ low_cpu_mem_usage=True,
508
+ trust_remote_code=trust_remote_code,
509
+ )
510
+ else:
511
+ base = AutoModelForCausalLM.from_pretrained(
512
+ str(base_dir),
513
+ torch_dtype=merged_dtype,
514
+ device_map="cpu",
515
+ low_cpu_mem_usage=True,
516
+ trust_remote_code=trust_remote_code,
517
+ )
518
+
519
+ merged = PeftModel.from_pretrained(base, str(adapter_dir))
520
+ merged = merged.merge_and_unload()
521
+
522
+ _ensure_dir(final_dir)
523
+ # Fix for transformers weight conversion bug with quantized models
524
+ # Clear weight conversions to avoid NotImplementedError in reverse_transform
525
+ if hasattr(merged, '_weight_conversions'):
526
+ merged._weight_conversions = []
527
+ merged.save_pretrained(
528
+ str(final_dir),
529
+ safe_serialization=True,
530
+ max_shard_size=max_shard_size
531
+ )
532
+
533
+ tok = _load_tokenizer(base_dir, use_fast=use_fast, trust_remote_code=trust_remote_code)
534
+ if tok.pad_token is None:
535
+ tok.pad_token = tok.eos_token
536
+ tok.save_pretrained(str(final_dir))
537
+
538
+ print("--- Merge complete ---")
539
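What `merge_and_unload()` computes per targeted layer is `W' = W + (alpha / r) * B @ A`. A toy sketch of that update with plain Python lists (illustrative only; PEFT does this in-place on the real weight tensors):

```python
def merge_lora(W, A, B, alpha, r):
    # W' = W + (alpha / r) * B @ A, with B: d_out x r and A: r x d_in.
    s = alpha / r
    rows, cols, k = len(W), len(W[0]), len(A)
    return [[W[i][j] + s * sum(B[i][t] * A[t][j] for t in range(k))
             for j in range(cols)] for i in range(rows)]

# rank-1 example: identity base weight plus a scaled outer product
merged = merge_lora([[1, 0], [0, 1]], A=[[1, 1]], B=[[1], [1]], alpha=2, r=1)
```

After the merge the adapter matrices are folded into the base weights, which is why the exported `final_model` needs no PEFT at inference time.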
+
540
+
541
+ # --------------------------
542
+ # Main
543
+ # --------------------------
544
+
545
+ def main():
546
+ ap = argparse.ArgumentParser()
547
+ ap.add_argument("--config", required=True, help="Path to YAML config")
548
+ ap.add_argument("--merge-only", action="store_true", help="Skip training, just merge adapter")
549
+ args = ap.parse_args()
550
+
551
+ with open(args.config, "r", encoding="utf-8") as f:
552
+ cfg = yaml.safe_load(f)
553
+
554
+ run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
555
+ _ensure_dir(run_dir / "logs")
556
+
557
+ with (run_dir / "config_resolved.yaml").open("w", encoding="utf-8") as f:
558
+ yaml.safe_dump(cfg, f, sort_keys=False)
559
+
560
+ model_cfg = cfg["model"]
561
+ repo_id = str(model_cfg["repo_id"]).strip()
562
+ repo_path = Path(repo_id)
563
+
564
+ # ✅ Local model path -> load directly; no download
565
+ if repo_path.exists() and repo_path.is_dir():
566
+ base_dir = repo_path
567
+ if not _looks_like_model_dir(base_dir):
568
+ raise ValueError(f"model.repo_id points to a directory, but it doesn't look like a HF model dir: {base_dir}")
569
+ else:
570
+ # HF repo_id -> download into run_dir/base_local_dir
571
+ base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
572
+ if not _looks_like_model_dir(base_dir):
573
+ print(f"Base model not found at {base_dir}, downloading from {repo_id} ...")
574
+ snapshot_download(
575
+ repo_id=repo_id,
576
+ revision=model_cfg.get("revision", None),
577
+ local_dir=str(base_dir),
578
+ local_dir_use_symlinks=False,
579
+ )
580
+
581
+ ckpt_dir = _ensure_dir(run_dir / "checkpoints")
582
+ best_adapter_dir = _ensure_dir(run_dir / "best_adapter")
583
+
584
+ merge_cfg = cfg.get("merge", {}) or {}
585
+ if merge_cfg.get("output_dir"):
586
+ od = Path(str(merge_cfg["output_dir"]))
587
+ final_dir = od if od.is_absolute() else (run_dir / od)
588
+ else:
589
+ final_dir = run_dir / "final_model"
590
+
591
+ # Merge-only
592
+ if args.merge_only:
593
+ if not _looks_like_model_dir(best_adapter_dir):
594
+ raise FileNotFoundError(f"Adapter not found at {best_adapter_dir}")
595
+ merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
596
+ return
597
+
598
+ # Training
599
+ set_seed(int(cfg["run"].get("seed", 42)))
600
+
601
+ model, tokenizer = load_base_model_and_tokenizer(cfg, base_dir)
602
+ model, _ = apply_peft(cfg, model)
603
+
604
+ train_ds, eval_ds = build_datasets(cfg, tokenizer)
605
+
606
+ tr_cfg = cfg["train"]
607
+
608
+ dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
609
+ use_fp16 = (dtype == torch.float16)
610
+ use_bf16 = (dtype == torch.bfloat16)
611
+
612
+ max_steps = int(tr_cfg.get("max_steps", 0))
613
+ num_train_epochs = float(tr_cfg.get("num_train_epochs", 1))
614
+
615
+ # --- Dynamic evaluation strategy parameter handling ---
616
+ ta_params = inspect.signature(TrainingArguments.__init__).parameters
617
+ eval_key = "eval_strategy" if "eval_strategy" in ta_params else "evaluation_strategy"
618
+
619
+ desired_ta_kwargs = dict(
620
+ output_dir=str(ckpt_dir),
621
+ max_steps=max_steps if max_steps > 0 else -1,
622
+ num_train_epochs=num_train_epochs,
623
+
624
+ per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 1)),
625
+ per_device_eval_batch_size=int(tr_cfg.get("per_device_eval_batch_size", tr_cfg.get("per_device_train_batch_size", 1))),
626
+ gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 1)),
627
+
628
+ learning_rate=float(tr_cfg.get("learning_rate", 2e-5)),
629
+ weight_decay=float(tr_cfg.get("weight_decay", 0.0)),
630
+ warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
631
+ lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "cosine")),
632
+
633
+ optim=str(tr_cfg.get("optim", "paged_adamw_8bit" if bool(model_cfg.get("use_4bit", False)) else "adamw_torch")),
634
+ max_grad_norm=float(tr_cfg.get("max_grad_norm", 1.0)),
635
+
636
+ logging_steps=int(tr_cfg.get("logging_steps", 10)),
637
+
638
+ save_strategy=str(tr_cfg.get("save_strategy", "steps")),
639
+ save_steps=int(tr_cfg.get("save_steps", 200)),
640
+ save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
641
+
642
+ eval_steps=int(tr_cfg.get("eval_steps", 200)),
643
+
644
+ load_best_model_at_end=bool(tr_cfg.get("load_best_model_at_end", True)) if eval_ds is not None else False,
645
+ metric_for_best_model="eval_loss",
646
+ greater_is_better=False,
647
+
648
+ fp16=use_fp16,
649
+ bf16=use_bf16,
650
+
651
+ report_to=[],
652
+ remove_unused_columns=False,
653
+ save_safetensors=True,
654
+ overwrite_output_dir=False,
655
+ )
656
+
657
+ # Set the correct argument name for this transformers version
658
+ desired_ta_kwargs[eval_key] = str(tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no"))
659
+ ta_kwargs = {k: v for k, v in desired_ta_kwargs.items() if k in ta_params}
660
+
661
+ training_args = TrainingArguments(**ta_kwargs)
662
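The version-compatibility trick used for both `TrainingArguments` and `Trainer` is the same: build the full kwargs dict, then keep only the keys the installed signature actually accepts. A minimal standalone sketch of that shim (hypothetical `filter_kwargs` helper and `demo` function):

```python
import inspect

def filter_kwargs(fn, kwargs):
    # Keep only kwargs that fn's signature actually accepts.
    params = inspect.signature(fn).parameters
    return {k: v for k, v in kwargs.items() if k in params}

def demo(a, b=2):
    return a + b

# "c" is silently dropped instead of raising TypeError at call time
result = demo(**filter_kwargs(demo, {"a": 1, "b": 3, "c": 9}))
```

The trade-off: unknown keys are dropped silently, so a typo in the YAML config is ignored rather than reported.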
+
663
+ trainer_params = inspect.signature(Trainer.__init__).parameters
664
+ desired_trainer_kwargs = dict(
665
+ model=model,
666
+ args=training_args,
667
+ train_dataset=train_ds,
668
+ eval_dataset=eval_ds,
669
+ tokenizer=tokenizer,
670
+ processing_class=tokenizer,
671
+ data_collator=default_data_collator,
672
+ callbacks=[JsonlLoggerCallback(run_dir)],
673
+ )
674
+ trainer_kwargs = {k: v for k, v in desired_trainer_kwargs.items() if k in trainer_params}
+ # Newer transformers accepts both names but rejects passing them together;
+ # prefer processing_class when the installed version supports it.
+ if "processing_class" in trainer_kwargs:
+     trainer_kwargs.pop("tokenizer", None)
675
+ trainer = Trainer(**trainer_kwargs)
676
+
677
+ # Resume
678
+ resume_from = tr_cfg.get("resume_from_checkpoint", None)
679
+ if resume_from == "auto":
680
+ last = get_last_checkpoint(str(ckpt_dir))
681
+ resume_from = last if last else None
682
+ if resume_from:
683
+ print(f"Resuming from {resume_from}")
684
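`resume_from_checkpoint: "auto"` relies on `get_last_checkpoint` picking the highest-numbered `checkpoint-N` subdirectory. A sketch of equivalent selection logic (hypothetical `last_checkpoint` helper, not the transformers implementation):

```python
import os
import re

def last_checkpoint(ckpt_dir):
    # Return the path of the highest-numbered "checkpoint-N" subdirectory, or None.
    pat = re.compile(r"^checkpoint-(\d+)$")
    best, best_step = None, -1
    for name in os.listdir(ckpt_dir):
        m = pat.match(name)
        if m and os.path.isdir(os.path.join(ckpt_dir, name)):
            step = int(m.group(1))
            if step > best_step:
                best, best_step = os.path.join(ckpt_dir, name), step
    return best
```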
+
685
+ print("Starting training...")
686
+ trainer.train(resume_from_checkpoint=resume_from)
687
+
688
+ trainer.save_model(str(best_adapter_dir))
689
+ print(f"Saved best adapter -> {best_adapter_dir}")
690
+
691
+ if eval_ds is not None:
692
+ metrics = trainer.evaluate()
693
+ eval_loss = metrics.get("eval_loss", None)
694
+ metrics["perplexity"] = _safe_exp(eval_loss) if eval_loss is not None else None
695
+ with (run_dir / "eval_final.json").open("w", encoding="utf-8") as f:
696
+ json.dump(metrics, f, indent=2)
697
+ print(f"Final eval_loss={eval_loss}, ppl={metrics['perplexity']}")
698
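The perplexity written to `eval_final.json` is just `exp(eval_loss)` with the exponent capped (that is what `_safe_exp` does) so a diverged loss cannot overflow to `inf`. A minimal sketch of that computation:

```python
import math

def perplexity(eval_loss, cap=50.0):
    # ppl = exp(mean negative log-likelihood); cap the exponent to avoid overflow.
    return math.exp(min(float(eval_loss), cap))
```

For example, an eval loss of `ln(8) ≈ 2.079` corresponds to a perplexity of 8, i.e. the model is on average as uncertain as a uniform choice over 8 tokens.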
+
699
+ if bool(cfg.get("merge", {}).get("enabled", False)):
700
+ del trainer, model
701
+ torch.cuda.empty_cache()
702
+ merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
703
+ else:
704
+ print("Merge disabled. Run with --merge-only later if needed.")
705
+
706
+
707
+ if __name__ == "__main__":
708
+ main()
trainer-kit/SFT-14b/.DS_Store ADDED
Binary file (6.15 kB). View file
 
trainer-kit/SFT-14b/config_instruct.yaml ADDED
@@ -0,0 +1,144 @@
1
+ run:
2
+ run_dir: "./runs/instruct_run_14b_v1"
3
+ seed: 42
4
+
5
+ # WandB integration for experiment tracking
6
+ wandb:
7
+ enabled: true # Set to true to enable wandb logging
8
+ project: "sft-training" # WandB project name
9
+ entity: null # WandB entity/team (optional)
10
+ name: null # Run name (optional, will auto-generate if null)
11
+ tags: ["sft-lora", "instruction-tuning"] # List of tags for the run (e.g., ["lora", "qlora", "experiment-1"])
12
+ notes: null # Run description/notes (optional)
13
+
14
+ model:
15
+ # Local path to the CPT-merged Qwen2.5-Coder-14B model (a HF repo_id also works here)
16
+ repo_id: "./runs/cpt_run_14b/merged_14b_cpt_lora"
17
+ revision: null
18
+
19
+ # Used only when repo_id is a HF repo (not a local path)
20
+ base_local_dir: "base_model"
21
+
22
+ trust_remote_code: true
23
+ tokenizer_use_fast: true
24
+ device_map: "auto"
25
+
26
+ torch_dtype: "bfloat16" # "float16" | "bfloat16" | "float32"
27
+
28
+ # QLoRA
29
+ use_4bit: false
30
+ bnb_4bit_quant_type: "nf4"
31
+ bnb_4bit_use_double_quant: false
32
+ bnb_4bit_compute_dtype: "bfloat16"
33
+
34
+ # optional: "flash_attention_2" | "sdpa" | null
35
+ attn_implementation: null
36
+
37
+ data:
38
+ train_jsonl: "sft_dataset.jsonl"
39
+ eval_jsonl: null
40
+ eval_split_ratio: 0.1
41
+
42
+ # Field names in your JSONL data
43
+ instruction_field: "instruction" # This will be the system prompt
44
+ input_field: "input" # This is the task description
45
+ output_field: "output" # This is the analysis + selection
46
+
47
+ # Formatting options
48
+ format_type: "custom" # "chatml" | "alpaca" | "custom"
49
+
50
+ # For chatml format
51
+ system_prompt: |
52
+ You are a Hyperswitch Rust code analyzer. Identify functions/structs that need modification for a given task.
53
+
54
+ ## Output Format
55
+
56
+ ##OUTPUT
57
+ Explain the data flow and why each component must change:
58
+ - Flow: [Input → Processing → Output with arrows]
59
+ - For each component: "The [ComponentName] ([path]) must [action] because [reason]—without this, [consequence]"
60
+ - Explain coupling between components
61
+
62
+ ##SELECT
63
+ modify::crates/path/to/file.rs::impl::ComponentName
64
+ add::crates/another/file.rs::function::AnotherComponent
65
+ <EOS>
66
+
67
+ ## Rules
68
+
69
+ 1. Use full paths: `remove::crates/folder/file.rs::Type::Name`
70
+ 2. Use `::` for nested items: `status::StructName::Type::Name`
71
+ 3. Always explain "must change because" and "without this"
72
+ 4. Types of components: function, struct, enum, impl, trait
73
+ 5. If there is extra information (e.g., enum variants), include that too.
74
+ 6. Start with ##OUTPUT, end with ##SELECT, terminate with <EOS>
75
+
76
+ ## Example
77
+
78
+ ##TASK
79
+ Add webhook subscription support
80
+
81
+ ##OUTPUT
82
+ The webhook system routes events via EventClass enum. Flow: webhook → EventClass → handler → processing. The EventClass enum (crates/common_enums/src/enums.rs::EventClass) must add Subscriptions variant because it defines event routing—without this, subscription events cannot be processed. The SubscriptionStatus impl (crates/common_enums/src/transformers.rs::SubscriptionStatus) must map to EventType because it converts status to events—without this, status changes don't trigger webhooks. These are coupled: EventClass routes to handlers that use SubscriptionStatus mappings.
83
+
84
+ ##SELECT
85
+ crates/common_enums/src/enums.rs::EventClass
86
+ crates/common_enums/src/transformers.rs::SubscriptionStatus
87
+ <EOS>
88
+
89
+ # For custom format (only used when format_type="custom")
90
+ custom_template: "##INSTRUCTION\n{instruction}<|im_end|>\n##TASK\n{input}<|im_end|>\n##OUTPUT\n{output}<|im_end|>"
91
+
92
+ max_length: 2048
93
+ shuffle: true
94
+ num_proc: 4
95
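When `format_type: "custom"` is used, the runner fills `custom_template` with the three configured fields. A sketch of that rendering with hypothetical example values (the template string below copies the config; the record contents are made up):

```python
# Same placeholders as the config's custom_template
template = ("##INSTRUCTION\n{instruction}<|im_end|>\n"
            "##TASK\n{input}<|im_end|>\n"
            "##OUTPUT\n{output}<|im_end|>")

example = {
    "instruction": "You are a code analyzer.",      # hypothetical record
    "input": "Add webhook support",
    "output": "##SELECT\ncrates/x.rs::Foo\n<EOS>",
}
text = template.format(**example)
```

Every segment is terminated by `<|im_end|>`, so the tokenizer sees explicit turn boundaries even without a chat template.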
+
96
+ peft:
97
+ enabled: true
98
+ r: 16
99
+ lora_alpha: 32
100
+ lora_dropout: 0.05
101
+ bias: "none"
102
+ target_modules: "auto"
103
+
104
+ train:
105
+ # max_steps: 10
106
+ num_train_epochs: 6
107
+
108
+ per_device_train_batch_size: 1
109
+ per_device_eval_batch_size: 1
110
+ gradient_accumulation_steps: 8
111
+
112
+ learning_rate: 2e-4
113
+ weight_decay: 0.0
114
+ warmup_ratio: 0.08
115
+ lr_scheduler_type: "cosine"
116
+
117
+ optim: "adamw_torch" # ✅ Changed from paged_adamw_8bit (requires use_4bit=true)
118
+ max_grad_norm: 1.0
119
+ gradient_checkpointing: true
120
+
121
+ logging_steps: 2
122
+ save_strategy: "steps"
123
+ save_steps: 500
124
+ save_total_limit: 20
125
+
126
+ evaluation_strategy: "steps"
127
+ eval_steps: 100
128
+ load_best_model_at_end: true
129
+
130
+ # Early stopping
131
+ early_stopping:
132
+ enabled: true
133
+ patience: 3 # Number of evaluations with no improvement before stopping
134
+ min_delta: 0.001 # Minimum change to qualify as improvement
135
+ metric: "eval_loss" # Metric to monitor
136
+ mode: "min" # "min" for loss, "max" for accuracy/etc.
137
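The early-stopping block above maps onto a simple rule: track the best monitored value and stop once `patience` consecutive evaluations fail to improve it by at least `min_delta`. A standalone sketch of that rule for the `mode: "min"` case (hypothetical `should_stop` helper, not the transformers callback):

```python
def should_stop(losses, patience=3, min_delta=0.001):
    # Stop when the loss hasn't improved by min_delta for `patience` evals in a row.
    best, bad = float("inf"), 0
    for loss in losses:
        if loss < best - min_delta:
            best, bad = loss, 0  # real improvement: reset the counter
        else:
            bad += 1
            if bad >= patience:
                return True
    return False
```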
+
138
+ resume_from_checkpoint: "auto"
139
+
140
+ merge:
141
+ enabled: true
142
+ merged_dtype: "float16"
143
+ max_shard_size: "2GB"
144
+ output_dir: "./merged_14b_instruct_lora"
trainer-kit/SFT-14b/instruct_data.jsonl ADDED
@@ -0,0 +1,4 @@
1
+ {"input": "##TASK\nAdd subscription support to webhook event metadata and outgoing webhook content systems to enable proper subscription lifecycle event handling and webhook delivery for subscription-related operations.", "output": "##OUTPUT\nThe webhook system requires subscription event support to handle subscription lifecycle events (creation, payment, cancellation) and deliver them to merchants via outgoing webhooks. This affects the event metadata enumeration and webhook content transformation layers. The complete flow is: subscription event occurs → EventMetadata::Subscription variant created with subscription_id, invoice_id, payment_id → stored in database → retrieved for webhook delivery → OutgoingWebhookEventContent::foreign_from converts EventMetadata to OutgoingWebhookEventContent::Subscription → StripeWebhookObject::From transforms to StripeWebhookObject::Subscriptions for Stripe compatibility. The OutgoingWebhookEventContent::foreign_from function (router/src/core/webhooks/outgoing_v2.rs::OutgoingWebhookEventContent::foreign_from) must add the Subscription variant match arm because it's the bridge from storage EventMetadata to webhook content representation—without this, subscription events cannot be converted to webhook content and would be silently dropped during webhook processing. The StripeWebhookObject::From implementation (router/src/compatibility/stripe/webhooks.rs::StripeWebhookObject::From) must handle OutgoingWebhookContent::SubscriptionDetails because it provides Stripe-compatible webhook formatting—without this, merchants using Stripe webhook endpoints would receive malformed or missing subscription events. The EventMetadata enum (implicitly referenced in foreign_from) requires the Subscription variant to exist in the diesel_models layer, which enables persistent storage of subscription event references.\nThese components form a tightly coupled webhook pipeline where each depends on the previous: EventMetadata → OutgoingWebhookEventContent → StripeWebhookObject. If foreign_from lacks the Subscription match arm, subscription events stored as EventMetadata::Subscription would cause panic during webhook delivery. If StripeWebhookObject::From lacks SubscriptionDetails handling, Stripe webhook endpoints would receive unparseable responses, breaking merchant integrations that rely on Stripe-compatible subscription webhooks for revenue recognition and customer management systems.\n\n##SELECT\ncrates/router/src/core/webhooks/outgoing_v2.rs::OutgoingWebhookEventContent::foreign_from\ncrates/router/src/compatibility/stripe/webhooks.rs::StripeWebhookObject::From\n<EOS>"}
2
+ {"input": "##TASK\nUpdate subscription response mapping to use common_enums::SubscriptionStatus instead of subscription_types::SubscriptionStatus for consistent enum usage across the codebase.", "output": "##OUTPUT\nThe subscription system needs to standardize on a single SubscriptionStatus enum type to avoid type inconsistencies and import conflicts. The current flow is: subscription handler creates SubscriptionResponse → calls to_subscription_response method → converts string status to subscription_types::SubscriptionStatus → returns response. The to_subscription_response method (crates/subscriptions/src/core/subscription_handler.rs::SubscriptionWithHandler::to_subscription_response) must change because it's the conversion point where the status enum is instantiated. Using common_enums::SubscriptionStatus ensures consistency with other parts of the system that already use this common enum type. Without this change, the subscription module would maintain its own duplicate enum, creating potential mismatches and requiring conversions throughout the codebase when interacting with other modules that expect the common enum type.\n\n##SELECT\ncrates/subscriptions/src/core/subscription_handler.rs::SubscriptionWithHandler::to_subscription_response\n<EOS>"}
3
+ {"input": "##TASK\nAdd subscription support to webhook event system by including Subscriptions in EventClass enum and implementing SubscriptionStatus to EventType mapping for proper webhook routing.", "output": "##OUTPUT\nThe webhook system needs to handle subscription-related events for proper event classification and routing. The current flow is: webhook event received → EventClass determines routing category → EventType specifies the exact event type → handlers process based on event type. The EventClass enum (crates/common_enums/src/enums.rs::EventClass) must include Subscriptions because it's the primary classification that determines which subsystem handles the event. The SubscriptionStatus transformer (crates/common_enums/src/transformers.rs::impl::Option<EventType>) must map subscription status changes to appropriate EventType values because it converts domain state changes into webhook events. Without EventClass::Subscriptions, subscription events would be misclassified or unhandled. Without the SubscriptionStatus mapping, subscription status changes wouldn't trigger appropriate webhook notifications, breaking integration with external systems that depend on subscription lifecycle events.\n\n##SELECT\ncrates/common_enums/src/enums.rs::EventClass\ncrates/common_enums/src/transformers.rs::impl::Option<EventType>\n<EOS>"}
4
+ {"input": "##TASK\nAdd subscription ID support to webhook object reference system for proper webhook payload identification and routing.", "output": "##OUTPUT\nThe webhook system needs to reference subscription entities in event payloads for proper event correlation and processing. The current flow is: webhook event generated → ObjectReferenceId identifies the affected entity → webhook payload includes reference → consumers process based on entity type. The ObjectReferenceId enum (crates/api_models/src/webhooks.rs::ObjectReferenceId) must include SubscriptionId because it's the type-safe identifier used throughout the webhook payload structure to specify which subscription triggered the event. Without SubscriptionId, webhook events related to subscriptions couldn't properly reference the subscription entity, making it impossible for consumers to correlate events with specific subscriptions. This would break webhook consumers that need to update their local state or trigger business logic based on subscription events.\n\n##SELECT\ncrates/api_models/src/webhooks.rs::ObjectReferenceId\n<EOS>"}
trainer-kit/SFT-14b/requirements.txt ADDED
@@ -0,0 +1,23 @@
1
+ # Core
2
+ torch>=2.1.0
3
+ transformers>=4.41.0
4
+ datasets>=2.18.0
5
+ accelerate>=0.30.0
6
+
7
+ # PEFT / QLoRA
8
+ peft>=0.11.1
9
+ bitsandbytes>=0.43.1
10
+
11
+ # Hugging Face Hub (local + download support)
12
+ huggingface_hub>=0.23.0
13
+
14
+ # Config + utilities
15
+ pyyaml>=6.0
16
+ tqdm>=4.66.0
17
+
18
+ # Optional but recommended (tokenizers speed)
19
+ tokenizers>=0.15.0
20
+ safetensors>=0.4.2
21
+
22
+ # Experiment tracking
23
+ wandb>=0.16.0
trainer-kit/SFT-14b/run_instruct.py ADDED
@@ -0,0 +1,844 @@
1
+ import argparse
2
+ import json
3
+ import inspect # Added for Transformers version compatibility
4
+ import math
5
+ import time
6
+ from pathlib import Path
7
+ from typing import Any, Dict, Optional, Tuple, List
8
+
9
+ import torch
10
+ import yaml
11
+ from datasets import load_dataset, DatasetDict
12
+ from huggingface_hub import snapshot_download
13
+ from transformers import (
14
+ AutoTokenizer,
15
+ AutoModelForCausalLM,
16
+ BitsAndBytesConfig,
17
+ TrainingArguments,
18
+ Trainer,
19
+ TrainerCallback,
20
+ EarlyStoppingCallback,
21
+ default_data_collator,
22
+ set_seed,
23
+ )
24
+ from transformers.trainer_utils import get_last_checkpoint
25
+ from peft import (
26
+ LoraConfig,
27
+ get_peft_model,
28
+ prepare_model_for_kbit_training,
29
+ PeftModel,
30
+ )
31
+
32
+ try:
33
+ import wandb
34
+ WANDB_AVAILABLE = True
35
+ except ImportError:
36
+ WANDB_AVAILABLE = False
37
+ wandb = None
38
+
39
+
40
+ # --------------------------
41
+ # Helpers
42
+ # --------------------------
43
+
44
+
45
+ def _dtype_from_str(s: str) -> torch.dtype:
46
+ s = (s or "").lower()
47
+ if s in ("float16", "fp16"):
48
+ return torch.float16
49
+ if s in ("bfloat16", "bf16"):
50
+ return torch.bfloat16
51
+ if s in ("float32", "fp32"):
52
+ return torch.float32
53
+ raise ValueError(f"Unknown torch_dtype: {s}")
54
+
55
+
56
+ def _now_iso() -> str:
57
+ return time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
58
+
59
+
60
+ def _safe_exp(x: float) -> float:
61
+ x = min(float(x), 50.0)
62
+ return float(math.exp(x))
63
+
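`_safe_exp` caps the exponent at 50 so the perplexity column in the JSONL logs stays finite even if the eval loss diverges; the same behavior in isolation:

```python
import math

def safe_exp(x: float) -> float:
    # Clamp the exponent at 50 so a diverging eval loss maps to a large
    # but finite perplexity instead of overflowing to inf.
    return math.exp(min(float(x), 50.0))
```

For any reasonable loss (well below 50) this is a plain `exp`, so logged perplexities are exact in the normal regime.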
64
+
65
+ def _ensure_dir(p: Path) -> Path:
66
+ p.mkdir(parents=True, exist_ok=True)
67
+ return p
68
+
69
+
70
+ def _looks_like_model_dir(p: Path) -> bool:
71
+ if not p.exists() or not p.is_dir():
72
+ return False
73
+ if (p / "config.json").exists():
74
+ return True
75
+ if any(p.glob("*.safetensors")) or any(p.glob("pytorch_model*.bin")):
76
+ return True
77
+ return False
78
+
79
+
80
+ def _infer_target_modules(model) -> List[str]:
81
+ names = set()
82
+ for n, _ in model.named_modules():
83
+ names.add(n.split(".")[-1])
84
+
85
+ for group in [
86
+ ["q_proj", "k_proj", "v_proj", "o_proj"],
87
+ ["Wqkv", "out_proj"],
88
+ ["query_key_value", "dense"],
89
+ ["c_attn", "c_proj"],
90
+ ]:
91
+ if all(x in names for x in group):
92
+ return group
93
+
94
+ fallback = [
95
+ x
96
+ for x in [
97
+ "q_proj",
98
+ "k_proj",
99
+ "v_proj",
100
+ "o_proj",
101
+ "c_attn",
102
+ "c_proj",
103
+ "out_proj",
104
+ "dense",
105
+ ]
106
+ if x in names
107
+ ]
108
+ if fallback:
109
+ return fallback
110
+
111
+ raise ValueError(
112
+ "Could not auto-infer target_modules. Set peft.target_modules explicitly."
113
+ )
114
+
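The grouping logic in `_infer_target_modules` reduces to a first-match lookup over the model's leaf module names; a torch-free sketch of that lookup (the fallback list is omitted here for brevity):

```python
from typing import List, Set

def pick_lora_targets(leaf_names: Set[str]) -> List[str]:
    # Known attention projection groups, checked in order:
    # LLaMA/Qwen-style, MPT-style, GPT-NeoX-style, GPT-2-style.
    for group in [
        ["q_proj", "k_proj", "v_proj", "o_proj"],
        ["Wqkv", "out_proj"],
        ["query_key_value", "dense"],
        ["c_attn", "c_proj"],
    ]:
        if all(x in leaf_names for x in group):
            return group
    raise ValueError("Set peft.target_modules explicitly.")
```

Because the first fully-present group wins, a model exposing both LLaMA-style and GPT-2-style names would get the LLaMA-style projections.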
115
+
116
+ def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
117
+ return cfg.get("model", {}).get("attn_implementation", None)
118
+
119
+
120
+ # --------------------------
121
+ # Wandb Integration
122
+ # --------------------------
123
+
124
+ def setup_wandb(cfg: Dict[str, Any], run_dir: Path):
125
+ """Initialize Wandb if enabled in configuration."""
126
+ wandb_cfg = cfg.get("wandb", {})
127
+
128
+ if not wandb_cfg.get("enabled", False):
129
+ print("Wandb logging disabled")
130
+ return None
131
+
132
+ if not WANDB_AVAILABLE:
133
+ print("Wandb not available. Install with: pip install wandb")
134
+ return None
135
+
136
+ # Extract wandb configuration
137
+ project = wandb_cfg.get("project", "sft-training")
138
+ entity = wandb_cfg.get("entity", None)
139
+ name = wandb_cfg.get("name", None)
140
+ tags = wandb_cfg.get("tags", [])
141
+ notes = wandb_cfg.get("notes", None)
142
+
143
+ # Initialize wandb
144
+ try:
145
+ wandb.init(
146
+ project=project,
147
+ entity=entity,
148
+ name=name,
149
+ tags=tags,
150
+ notes=notes,
151
+ dir=str(run_dir),
152
+ config={
153
+ "model": cfg.get("model", {}),
154
+ "data": cfg.get("data", {}),
155
+ "peft": cfg.get("peft", {}),
156
+ "train": cfg.get("train", {}),
157
+ "run_dir": str(run_dir),
158
+ }
159
+ )
160
+ print(f"Wandb initialized: project='{project}', name='{name or 'auto-generated'}'")
161
+ return wandb
162
+ except Exception as e:
163
+ print(f"Failed to initialize Wandb: {e}")
164
+ return None
165
+
166
+
167
+ def finish_wandb():
168
+ """Finish Wandb run if active."""
169
+ if WANDB_AVAILABLE and wandb.run is not None:
170
+ wandb.finish()
171
+ print("Wandb run finished")
172
+
173
+
174
+ # --------------------------
175
+ # JSONL Logger Callback
176
+ # --------------------------
177
+
178
+
179
+ class JsonlLoggerCallback(TrainerCallback):
180
+ def __init__(self, run_dir: Path):
181
+ self.run_dir = run_dir
182
+ self.train_log_path = _ensure_dir(run_dir / "logs") / "train.jsonl"
183
+ self.eval_log_path = _ensure_dir(run_dir / "logs") / "eval.jsonl"
184
+ self.start_time = None
185
+
186
+ def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
187
+ if self.start_time is None or global_step <= 0 or max_steps <= 0:
188
+ return None
189
+ elapsed = time.time() - self.start_time
190
+ sec_per_step = elapsed / global_step
191
+ remaining = max(0, max_steps - global_step) * sec_per_step
192
+ h = int(remaining // 3600)
193
+ m = int((remaining % 3600) // 60)
194
+ s = int(remaining % 60)
195
+ return f"{h:02d}:{m:02d}:{s:02d}"
196
+
197
+ def on_train_begin(self, args, state, control, **kwargs):
198
+ self.start_time = time.time()
199
+
200
+ def on_log(self, args, state, control, logs=None, **kwargs):
201
+ if not logs:
202
+ return
203
+
204
+ max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
205
+ progress_pct = (
206
+ (100.0 * state.global_step / max_steps) if max_steps > 0 else None
207
+ )
208
+ epoch_pct = None
209
+ if (
210
+ state.epoch is not None
211
+ and args.num_train_epochs
212
+ and args.num_train_epochs > 0
213
+ ):
214
+ epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs))
215
+
216
+ payload = {
217
+ "ts": _now_iso(),
218
+ "event": "train_log",
219
+ "step": int(state.global_step),
220
+ "epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
221
+ "progress_pct": (
222
+ round(progress_pct, 2) if progress_pct is not None else None
223
+ ),
224
+ "epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None,
225
+ "eta": self._eta(int(state.global_step), max_steps),
226
+ "max_grad_norm": getattr(args, "max_grad_norm", None),
227
+ **logs,
228
+ }
229
+
230
+ with self.train_log_path.open("a", encoding="utf-8") as f:
231
+ f.write(json.dumps(payload, ensure_ascii=False) + "\n")
232
+
233
+ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
234
+ if not metrics:
235
+ return
236
+ eval_loss = metrics.get("eval_loss", None)
237
+ ppl = _safe_exp(eval_loss) if eval_loss is not None else None
238
+
239
+ payload = {
240
+ "ts": _now_iso(),
241
+ "event": "eval",
242
+ "step": int(state.global_step),
243
+ "epoch": float(state.epoch) if state.epoch is not None else None,
244
+ **metrics,
245
+ "perplexity": ppl,
246
+ }
247
+ with self.eval_log_path.open("a", encoding="utf-8") as f:
248
+ f.write(json.dumps(payload, ensure_ascii=False) + "\n")
249
+
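The ETA string written by `JsonlLoggerCallback._eta` is a linear projection from the average step time observed so far; isolated as a plain function:

```python
def eta_string(elapsed_s: float, global_step: int, max_steps: int) -> str:
    # Project remaining wall time from the mean seconds-per-step so far.
    sec_per_step = elapsed_s / global_step
    remaining = max(0, max_steps - global_step) * sec_per_step
    h = int(remaining // 3600)
    m = int((remaining % 3600) // 60)
    s = int(remaining % 60)
    return f"{h:02d}:{m:02d}:{s:02d}"
```

For example, 600 s elapsed over 100 of 1000 steps averages 6 s/step, so the 900 remaining steps project to 5400 s, i.e. `"01:30:00"`.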
250
+
251
+ # --------------------------
252
+ # Data Pipeline (Instruction Formatting)
253
+ # --------------------------
254
+
255
+
256
+ def format_instruction(
257
+ example: Dict[str, Any], cfg: Dict[str, Any], tokenizer
258
+ ) -> Dict[str, Any]:
259
+ """
260
+ Format instruction data for training.
261
+ Supports multiple formats: chatml, alpaca, custom templates.
262
+ Returns both formatted text and the response start position for loss masking.
263
+ """
264
+ data_cfg = cfg["data"]
265
+ format_type = data_cfg.get("format_type", "chatml")
266
+
267
+ # Get field names from config
268
+ input_field = data_cfg.get("input_field", "input")
269
+ output_field = data_cfg.get("output_field", "output")
270
+ instruction_field = data_cfg.get("instruction_field", "instruction")
271
+
272
+ # Extract text from example
273
+ instruction = example.get(instruction_field, "")
274
+ input_text = example.get(input_field, "")
275
+ output_text = example.get(output_field, "")
276
+
277
+ if format_type == "chatml":
278
+ # ChatML format with special tokens
279
+ system_prompt = data_cfg.get("system_prompt", "You are a helpful assistant.")
280
+
281
+ messages = []
282
+ if system_prompt:
283
+ messages.append({"role": "system", "content": system_prompt})
284
+
285
+ user_content = instruction
286
+ if input_text:
287
+ user_content = f"{instruction}\n\n{input_text}"
288
+ messages.append({"role": "user", "content": user_content})
289
+ messages.append({"role": "assistant", "content": output_text})
290
+
291
+ # Apply chat template
292
+ formatted_text = tokenizer.apply_chat_template(
293
+ messages, tokenize=False, add_generation_prompt=False
294
+ )
295
+
296
+ # Add EOS token if not present
297
+ if tokenizer.eos_token and not formatted_text.endswith(tokenizer.eos_token):
298
+ formatted_text += tokenizer.eos_token
299
+
300
+ # Find where the assistant response starts for loss masking
301
+ # Try multiple possible markers for robustness
302
+ markers = ["<|im_start|>assistant", "<|assistant|>", "Assistant:", "assistant\n"]
303
+ response_start_pos = -1
304
+
305
+ for marker in markers:
306
+ idx = formatted_text.find(marker)
307
+ if idx != -1:
308
+ # Find the newline after the marker
309
+ newline_idx = formatted_text.find("\n", idx)
310
+ if newline_idx != -1:
311
+ response_start_pos = newline_idx + 1
312
+ break
313
+
314
+ # Fallback: find where the actual output starts
315
+ if response_start_pos == -1:
316
+ output_idx = formatted_text.find(output_text)
317
+ if output_idx != -1:
318
+ response_start_pos = output_idx
319
+ else:
320
+ # Last resort: split at last occurrence of newline before end
321
+ response_start_pos = formatted_text.rfind("\n", 0, len(formatted_text) - len(output_text)) + 1
322
+
323
+ elif format_type == "alpaca":
324
+ # Alpaca format
325
+ if input_text:
326
+ prefix = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
327
+ else:
328
+ prefix = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n"
329
+
330
+ formatted_text = prefix + output_text
331
+
332
+ # Add EOS token
333
+ if tokenizer.eos_token:
334
+ formatted_text += tokenizer.eos_token
335
+
336
+ # Response starts after the prefix
337
+ response_start_pos = len(prefix)
338
+
339
+ elif format_type == "custom":
340
+ # Custom template from config
341
+ template = data_cfg.get("custom_template", "{instruction}\n{input}\n{output}")
342
+
343
+ # For custom format, use system_prompt as instruction if instruction field is empty
344
+ if not instruction:
345
+ instruction = data_cfg.get("system_prompt", "")
346
+
347
+ # For custom templates, we need to find where {output} starts
348
+ template_parts = template.split("{output}")
349
+ prefix = template_parts[0].format(instruction=instruction, input=input_text)
350
+ formatted_text = prefix + output_text
351
+
352
+ # Add EOS token if not already in template
353
+ if tokenizer.eos_token and not formatted_text.endswith(tokenizer.eos_token):
354
+ formatted_text += tokenizer.eos_token
355
+
356
+ # Response starts after the prefix
357
+ response_start_pos = len(prefix)
358
+ else:
359
+ raise ValueError(f"Unsupported format_type: {format_type}")
360
+
361
+ return {"text": formatted_text, "response_start_pos": response_start_pos}
362
+
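A standalone sketch mirroring the `alpaca` branch above: `response_start_pos` marks where loss-bearing tokens begin, so everything before it is prompt. The `"</s>"` EOS string here is a placeholder assumption; the real value comes from the tokenizer.

```python
from typing import Tuple

def alpaca_format(instruction: str, input_text: str, output_text: str,
                  eos: str = "</s>") -> Tuple[str, int]:
    # Build the Alpaca-style prompt prefix; the response starts right after it.
    if input_text:
        prefix = (
            "Below is an instruction that describes a task, paired with an input "
            "that provides further context. Write a response that appropriately "
            "completes the request.\n\n### Instruction:\n"
            f"{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
        )
    else:
        prefix = (
            "Below is an instruction that describes a task. Write a response "
            "that appropriately completes the request.\n\n### Instruction:\n"
            f"{instruction}\n\n### Response:\n"
        )
    return prefix + output_text + eos, len(prefix)
```

Slicing the returned text at the returned position yields exactly the response plus EOS, which is the span that keeps its labels after masking.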
363
+
364
+ def build_datasets(cfg: Dict[str, Any], tokenizer) -> Tuple[Any, Any]:
365
+ """
366
+ Build datasets for instruction fine-tuning.
367
+ """
368
+ data_cfg = cfg["data"]
369
+ train_path = data_cfg["train_jsonl"]
370
+ eval_path = data_cfg.get("eval_jsonl", None)
371
+ split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
372
+ max_length = int(data_cfg.get("max_length", 2048))
373
+ shuffle = bool(data_cfg.get("shuffle", True))
374
+ num_proc = int(data_cfg.get("num_proc", 4))
375
+
376
+ # Ensure tokenizer has pad token
377
+ if tokenizer.pad_token is None:
378
+ tokenizer.pad_token = tokenizer.eos_token
379
+
380
+ # Load datasets
381
+ ds = load_dataset("json", data_files={"train": train_path})
382
+
383
+ if eval_path:
384
+ ds_eval = load_dataset("json", data_files={"eval": eval_path})
385
+ dsd = DatasetDict({"train": ds["train"], "eval": ds_eval["eval"]})
386
+ else:
387
+ if 0.0 < split_ratio < 1.0:
388
+ split = ds["train"].train_test_split(
389
+ test_size=split_ratio, seed=int(cfg["run"].get("seed", 42))
390
+ )
391
+ dsd = DatasetDict({"train": split["train"], "eval": split["test"]})
392
+ else:
393
+ dsd = DatasetDict({"train": ds["train"], "eval": None})
394
+
395
+ # Format instructions and track response start positions
396
+ def format_fn(examples):
397
+ formatted_examples = []
398
+ response_start_positions = []
399
+ for i in range(len(examples[list(examples.keys())[0]])):
400
+ example = {k: examples[k][i] for k in examples.keys()}
401
+ formatted = format_instruction(example, cfg, tokenizer)
402
+ formatted_examples.append(formatted["text"])
403
+ response_start_positions.append(formatted["response_start_pos"])
404
+ return {
405
+ "text": formatted_examples,
406
+ "response_start_pos": response_start_positions
407
+ }
408
+
409
+ formatted_train = dsd["train"].map(
410
+ format_fn,
411
+ batched=True,
412
+ num_proc=num_proc,
413
+ remove_columns=dsd["train"].column_names,
414
+ desc="Formatting train instructions",
415
+ )
416
+
417
+ formatted_eval = None
418
+ if dsd["eval"] is not None:
419
+ formatted_eval = dsd["eval"].map(
420
+ format_fn,
421
+ batched=True,
422
+ num_proc=num_proc,
423
+ remove_columns=dsd["eval"].column_names,
424
+ desc="Formatting eval instructions",
425
+ )
426
+
427
+ # Tokenize and apply loss masking
428
+ def tokenize_and_mask_fn(examples):
429
+ tokenized = tokenizer(
430
+ examples["text"],
431
+ truncation=True,
432
+ padding=False,
433
+ max_length=max_length,
434
+ return_overflowing_tokens=False,
435
+ )
436
+
437
+ # Apply loss masking - CRITICAL for SFT
438
+ labels = []
439
+ attention_masks = []
440
+
441
+ for i in range(len(tokenized["input_ids"])):
442
+ input_ids = tokenized["input_ids"][i]
443
+ response_start_pos = examples["response_start_pos"][i]
444
+
445
+ # Get the instruction part (before response)
446
+ full_text = examples["text"][i]
447
+ instruction_text = full_text[:response_start_pos]
448
+
449
+ # Create labels masked by default
450
+ label_ids = [-100] * len(input_ids)
451
+
452
+ # Approximate where the response starts using a character-offset ratio.
453
+ # This avoids tokenizing the prefix separately (which can insert
454
+ # different special tokens) but only approximates the token boundary.
455
+ char_ratio = response_start_pos / max(len(full_text), 1)
456
+ response_start_idx = int(len(input_ids) * char_ratio)
457
+
458
+ # Ensure we have valid bounds (at least position 1, at most len-1)
459
+ response_start_idx = max(1, min(response_start_idx, len(input_ids) - 1))
460
+
461
+ # Unmask response tokens (including EOS)
462
+ for j in range(response_start_idx, len(input_ids)):
463
+ label_ids[j] = input_ids[j]
464
+
465
+ # Create attention mask (1 for real tokens, 0 for padding)
466
+ attention_mask = [1] * len(input_ids)
467
+
468
+ labels.append(label_ids)
469
+ attention_masks.append(attention_mask)
470
+
471
+ tokenized["labels"] = labels
472
+ tokenized["attention_mask"] = attention_masks
473
+ return tokenized
474
+
475
+ tokenized_train = formatted_train.map(
476
+ tokenize_and_mask_fn,
477
+ batched=True,
478
+ num_proc=num_proc,
+ remove_columns=formatted_train.column_names,
479
+ desc="Tokenizing and masking train",
480
+ )
481
+
482
+ tokenized_eval = None
483
+ if formatted_eval is not None:
484
+ tokenized_eval = formatted_eval.map(
485
+ tokenize_and_mask_fn,
486
+ batched=True,
487
+ num_proc=num_proc,
+ remove_columns=formatted_eval.column_names,
488
+ desc="Tokenizing and masking eval",
489
+ )
490
+
491
+ if shuffle:
492
+ tokenized_train = tokenized_train.shuffle(seed=int(cfg["run"].get("seed", 42)))
493
+
494
+ return tokenized_train, tokenized_eval
495
+
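The character-ratio heuristic inside `tokenize_and_mask_fn` can be exercised in isolation. A minimal sketch with toy token IDs (all values hypothetical): the first response token index is approximated from the character offset, clamped to valid bounds, and everything before it is masked with `-100` so it contributes no loss.

```python
from typing import List

def mask_labels(input_ids: List[int], full_text_len: int,
                response_start_pos: int) -> List[int]:
    # Approximate the first response token by the character-offset ratio.
    char_ratio = response_start_pos / max(full_text_len, 1)
    start = int(len(input_ids) * char_ratio)
    # Clamp to [1, len-1] so at least one token is masked and one is kept.
    start = max(1, min(start, len(input_ids) - 1))
    return [-100] * start + input_ids[start:]
```

With 10 tokens and a response starting 60% of the way through the text, the first 6 labels are masked; the clamp guarantees the mask never covers the whole sequence.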
496
+
497
+ # --------------------------
498
+ # Model Loading + PEFT
499
+ # --------------------------
500
+
501
+
502
+ def load_base_model_and_tokenizer(cfg: Dict[str, Any], base_dir: Path):
503
+ model_cfg = cfg["model"]
504
+ trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
505
+ use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
506
+ device_map = model_cfg.get("device_map", "auto")
507
+
508
+ tokenizer = AutoTokenizer.from_pretrained(
509
+ str(base_dir),
510
+ use_fast=use_fast,
511
+ trust_remote_code=trust_remote_code,
512
+ )
513
+ if tokenizer.pad_token is None:
514
+ tokenizer.pad_token = tokenizer.eos_token
515
+
516
+ torch_dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
517
+ use_4bit = bool(model_cfg.get("use_4bit", False))
518
+
519
+ quant_cfg = None
520
+ if use_4bit:
521
+ quant_cfg = BitsAndBytesConfig(
522
+ load_in_4bit=True,
523
+ bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4")),
524
+ bnb_4bit_use_double_quant=bool(
525
+ model_cfg.get("bnb_4bit_use_double_quant", True)
526
+ ),
527
+ bnb_4bit_compute_dtype=_dtype_from_str(
528
+ model_cfg.get("bnb_4bit_compute_dtype", "bfloat16")
529
+ ),
530
+ )
531
+
532
+ attn_impl = _choose_attn_impl(cfg)
533
+
534
+ try:
535
+ model = AutoModelForCausalLM.from_pretrained(
536
+ str(base_dir),
537
+ device_map=device_map,
538
+ trust_remote_code=trust_remote_code,
539
+ low_cpu_mem_usage=True,
540
+ torch_dtype=(torch_dtype if not use_4bit else None),
541
+ quantization_config=quant_cfg,
542
+ attn_implementation=attn_impl,
543
+ )
544
+ except Exception as e:
545
+ if attn_impl is not None:
546
+ print(f"[warn] attn_implementation='{attn_impl}' failed: {e}")
547
+ print("[warn] Falling back to default attention implementation.")
548
+ model = AutoModelForCausalLM.from_pretrained(
549
+ str(base_dir),
550
+ device_map=device_map,
551
+ trust_remote_code=trust_remote_code,
552
+ low_cpu_mem_usage=True,
553
+ torch_dtype=(torch_dtype if not use_4bit else None),
554
+ quantization_config=quant_cfg,
555
+ )
556
+
557
+ return model, tokenizer
558
+
559
+
560
+ def apply_peft(cfg: Dict[str, Any], model):
561
+ peft_cfg = cfg["peft"]
562
+ model_cfg = cfg["model"]
563
+ tr_cfg = cfg["train"]
564
+
565
+ if not bool(peft_cfg.get("enabled", True)):
566
+ return model, None
567
+
568
+ use_4bit = bool(model_cfg.get("use_4bit", False))
569
+ gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))
570
+
571
+ if gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
572
+ model.gradient_checkpointing_enable()
573
+ if hasattr(model, "config"):
574
+ model.config.use_cache = False
575
+
576
+ if use_4bit:
577
+ model = prepare_model_for_kbit_training(
578
+ model,
579
+ use_gradient_checkpointing=gradient_checkpointing,
580
+ )
581
+
582
+ target_modules = peft_cfg.get("target_modules", "auto")
583
+ if target_modules == "auto":
584
+ target_modules = _infer_target_modules(model)
585
+
586
+ lora_config = LoraConfig(
587
+ r=int(peft_cfg.get("r", 16)),
588
+ lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
589
+ lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
590
+ bias=str(peft_cfg.get("bias", "none")),
591
+ task_type="CAUSAL_LM",
592
+ target_modules=target_modules,
593
+ )
594
+ model = get_peft_model(model, lora_config)
595
+ return model, lora_config
596
+
597
+
598
+ # --------------------------
599
+ # Merge Logic
600
+ # --------------------------
601
+
602
+
603
+ def merge_adapter(
604
+ cfg: Dict[str, Any], base_dir: Path, adapter_dir: Path, final_dir: Path
605
+ ):
606
+ print(f"--- Merge: {adapter_dir} + {base_dir} -> {final_dir} ---")
607
+
608
+ model_cfg = cfg["model"]
609
+ merge_cfg = cfg.get("merge", {})
610
+ trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
611
+
612
+ merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
613
+ max_shard_size = str(merge_cfg.get("max_shard_size", "2GB"))
614
+
615
+ base = AutoModelForCausalLM.from_pretrained(
616
+ str(base_dir),
617
+ torch_dtype=merged_dtype,
618
+ device_map="cpu",
619
+ low_cpu_mem_usage=True,
620
+ trust_remote_code=trust_remote_code,
621
+ )
622
+
623
+ merged = PeftModel.from_pretrained(base, str(adapter_dir))
624
+ merged = merged.merge_and_unload()
625
+
626
+ _ensure_dir(final_dir)
627
+ merged.save_pretrained(
628
+ str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
629
+ )
630
+
631
+ tok = AutoTokenizer.from_pretrained(
632
+ str(base_dir), trust_remote_code=trust_remote_code
633
+ )
634
+ if tok.pad_token is None:
635
+ tok.pad_token = tok.eos_token
636
+ tok.save_pretrained(str(final_dir))
637
+
638
+ print("--- Merge complete ---")
639
+
640
+
641
+ # --------------------------
642
+ # Main
643
+ # --------------------------
644
+
645
+
646
+ def main():
647
+ ap = argparse.ArgumentParser()
648
+ ap.add_argument("--config", required=True, help="Path to YAML config")
649
+ ap.add_argument(
650
+ "--merge-only", action="store_true", help="Skip training, just merge adapter"
651
+ )
652
+ args = ap.parse_args()
653
+
654
+ with open(args.config, "r", encoding="utf-8") as f:
655
+ cfg = yaml.safe_load(f)
656
+
657
+ run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
658
+ _ensure_dir(run_dir / "logs")
659
+
660
+ with (run_dir / "config_resolved.yaml").open("w", encoding="utf-8") as f:
661
+ yaml.safe_dump(cfg, f, sort_keys=False)
662
+
663
+ model_cfg = cfg["model"]
664
+ repo_id = str(model_cfg["repo_id"]).strip()
665
+ repo_path = Path(repo_id)
666
+
667
+ # ✅ Local model path -> load directly; no download
668
+ if repo_path.exists() and repo_path.is_dir() and _looks_like_model_dir(repo_path):
669
+ base_dir = repo_path
670
+ print(f"Using local model at: {base_dir}")
671
+ elif repo_path.exists() and repo_path.is_dir():
672
+ raise ValueError(
673
+ f"model.repo_id points to a directory, but it doesn't look like a HF model dir: {base_dir}"
674
+ )
675
+ else:
676
+ # HF repo_id -> download into run_dir/base_local_dir
677
+ base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
678
+ if not _looks_like_model_dir(base_dir):
679
+ print(f"Base model not found at {base_dir}, downloading from {repo_id} ...")
680
+ snapshot_download(
681
+ repo_id=repo_id,
682
+ revision=model_cfg.get("revision", None),
683
+ local_dir=str(base_dir),
684
+ local_dir_use_symlinks=False,
685
+ )
686
+
687
+ ckpt_dir = _ensure_dir(run_dir / "checkpoints")
688
+ best_adapter_dir = _ensure_dir(run_dir / "best_adapter")
689
+
690
+ merge_cfg = cfg.get("merge", {}) or {}
691
+ if merge_cfg.get("output_dir"):
692
+ od = Path(str(merge_cfg["output_dir"]))
693
+ final_dir = od if od.is_absolute() else (run_dir / od)
694
+ else:
695
+ final_dir = run_dir / "final_model"
696
+
697
+ # Merge-only
698
+ if args.merge_only:
699
+ if not _looks_like_model_dir(best_adapter_dir):
700
+ raise FileNotFoundError(f"Adapter not found at {best_adapter_dir}")
701
+ merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
702
+ return
703
+
704
+ # Initialize Wandb
705
+ wandb_run = setup_wandb(cfg, run_dir)
706
+
707
+ # Training
708
+ set_seed(int(cfg["run"].get("seed", 42)))
709
+
710
+ model, tokenizer = load_base_model_and_tokenizer(cfg, base_dir)
711
+ model, _ = apply_peft(cfg, model)
712
+
713
+ train_ds, eval_ds = build_datasets(cfg, tokenizer)
714
+
715
+ tr_cfg = cfg["train"]
716
+
717
+ dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
718
+ use_fp16 = dtype == torch.float16
719
+ use_bf16 = dtype == torch.bfloat16
720
+
721
+ max_steps = int(tr_cfg.get("max_steps", 0))
722
+ num_train_epochs = float(tr_cfg.get("num_train_epochs", 1))
723
+
724
+ # --- Dynamic evaluation strategy parameter handling ---
725
+ ta_params = inspect.signature(TrainingArguments.__init__).parameters
726
+ eval_key = (
727
+ "eval_strategy" if "eval_strategy" in ta_params else "evaluation_strategy"
728
+ )
729
+
730
+ # Setup reporting based on wandb availability
731
+ report_to = []
732
+ if wandb_run is not None:
733
+ report_to.append("wandb")
734
+
735
+ ta_kwargs = dict(
736
+ output_dir=str(ckpt_dir),
737
+ max_steps=max_steps if max_steps > 0 else -1,
738
+ num_train_epochs=num_train_epochs,
739
+ per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 1)),
740
+ per_device_eval_batch_size=int(
741
+ tr_cfg.get(
742
+ "per_device_eval_batch_size",
743
+ tr_cfg.get("per_device_train_batch_size", 1),
744
+ )
745
+ ),
746
+ gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 1)),
747
+ learning_rate=float(tr_cfg.get("learning_rate", 2e-5)),
748
+ weight_decay=float(tr_cfg.get("weight_decay", 0.0)),
749
+ warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
750
+ lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "cosine")),
751
+ optim=str(
752
+ tr_cfg.get(
753
+ "optim",
754
+ (
755
+ "paged_adamw_8bit"
756
+ if bool(model_cfg.get("use_4bit", False))
757
+ else "adamw_torch"
758
+ ),
759
+ )
760
+ ),
761
+ max_grad_norm=float(tr_cfg.get("max_grad_norm", 1.0)),
762
+ logging_steps=int(tr_cfg.get("logging_steps", 10)),
763
+ save_strategy=str(tr_cfg.get("save_strategy", "steps")),
764
+ save_steps=int(tr_cfg.get("save_steps", 200)),
765
+ save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
766
+ eval_steps=int(tr_cfg.get("eval_steps", 200)),
767
+ load_best_model_at_end=(
768
+ bool(tr_cfg.get("load_best_model_at_end", True))
769
+ if eval_ds is not None
770
+ else False
771
+ ),
772
+ metric_for_best_model="eval_loss",
773
+ greater_is_better=False,
774
+ fp16=use_fp16,
775
+ bf16=use_bf16,
776
+ report_to=report_to,
777
+ remove_unused_columns=False,
778
+ )
779
+
780
+ # Set the correct argument name for this transformers version
781
+ ta_kwargs[eval_key] = str(
782
+ tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no")
783
+ )
784
+
785
+ training_args = TrainingArguments(**ta_kwargs)
786
+
787
+ # Setup callbacks
788
+ callbacks = [JsonlLoggerCallback(run_dir)]
789
+
790
+ # Add early stopping callback if enabled
791
+ early_stopping_cfg = tr_cfg.get("early_stopping", {})
792
+ if early_stopping_cfg.get("enabled", False) and eval_ds is not None:
793
+ early_stopping_callback = EarlyStoppingCallback(
794
+ early_stopping_patience=int(early_stopping_cfg.get("patience", 3)),
795
+ early_stopping_threshold=float(early_stopping_cfg.get("min_delta", 0.001)),
796
+ )
797
+ callbacks.append(early_stopping_callback)
798
+ print(f"Early stopping enabled: patience={early_stopping_cfg.get('patience', 3)}, "
799
+ f"min_delta={early_stopping_cfg.get('min_delta', 0.001)}")
800
+
801
+ trainer = Trainer(
802
+ model=model,
803
+ args=training_args,
804
+ train_dataset=train_ds,
805
+ eval_dataset=eval_ds,
806
+ data_collator=default_data_collator,  # sequences are unpadded, so per-device batch size must stay 1
807
+ callbacks=callbacks,
808
+ )
809
+
810
+ # Resume
811
+ resume_from = tr_cfg.get("resume_from_checkpoint", None)
812
+ if resume_from == "auto":
813
+ last = get_last_checkpoint(str(ckpt_dir))
814
+ resume_from = last if last else None
815
+ if resume_from:
816
+ print(f"Resuming from {resume_from}")
817
+
818
+ print("Starting instruction fine-tuning...")
819
+ trainer.train(resume_from_checkpoint=resume_from)
820
+
821
+ trainer.save_model(str(best_adapter_dir))
822
+ print(f"Saved best adapter -> {best_adapter_dir}")
823
+
824
+ if eval_ds is not None:
825
+ metrics = trainer.evaluate()
826
+ eval_loss = metrics.get("eval_loss", None)
827
+ metrics["perplexity"] = _safe_exp(eval_loss) if eval_loss is not None else None
828
+ with (run_dir / "eval_final.json").open("w", encoding="utf-8") as f:
829
+ json.dump(metrics, f, indent=2)
830
+ print(f"Final eval_loss={eval_loss}, ppl={metrics['perplexity']}")
831
+
832
+ if bool(cfg.get("merge", {}).get("enabled", False)):
833
+ del trainer, model
834
+ torch.cuda.empty_cache()
835
+ merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
836
+ else:
837
+ print("Merge disabled. Run with --merge-only later if needed.")
838
+
839
+ # Finish Wandb run
840
+ finish_wandb()
841
+
842
+
843
+ if __name__ == "__main__":
844
+ main()
trainer-kit/SFT/.DS_Store ADDED
Binary file (6.15 kB).
 
trainer-kit/SFT/config_instruct.yaml ADDED
@@ -0,0 +1,144 @@
1
+ run:
2
+ run_dir: "./runs/instruct_run_24b"
3
+ seed: 42
4
+
5
+ # WandB integration for experiment tracking
6
+ wandb:
7
+ enabled: true # Set to true to enable wandb logging
8
+ project: "sft-training" # WandB project name
9
+ entity: null # WandB entity/team (optional)
10
+ name: null # Run name (optional, will auto-generate if null)
11
+ tags: ["sft-lora", "24b-Devstral"] # List of tags for the run (e.g., ["lora", "qlora", "experiment-1"])
12
+ notes: null # Run description/notes (optional)
13
+
14
+ model:
15
+ # Use the locally merged 24B CPT model as the SFT base
16
+ repo_id: "./CPT/runs/cpt_run_v1/merged_24b_cpt_lora"
17
+ revision: null
18
+
19
+ # Used only when repo_id is a HF repo (not a local path)
20
+ base_local_dir: "base_model"
21
+
22
+ trust_remote_code: true
23
+ tokenizer_use_fast: true
24
+ device_map: "auto"
25
+
26
+ torch_dtype: "bfloat16" # "float16" | "bfloat16" | "float32"
27
+
28
+ # QLoRA
29
+ use_4bit: false
30
+ bnb_4bit_quant_type: "nf4"
31
+ bnb_4bit_use_double_quant: false
32
+ bnb_4bit_compute_dtype: "bfloat16"
33
+
34
+ # optional: "flash_attention_2" | "sdpa" | null
35
+ attn_implementation: null
36
+
37
+ data:
38
+ train_jsonl: "../sft_dataset.jsonl"
39
+ eval_jsonl: null
40
+ eval_split_ratio: 0.1
41
+
42
+ # Field names in your JSONL data
43
+ instruction_field: "instruction" # This will be the system prompt
44
+ input_field: "input" # This is the task description
45
+ output_field: "output" # This is the analysis + selection
46
+
47
+ # Formatting options
48
+ format_type: "custom" # "chatml" | "alpaca" | "custom"
49
+
50
+ # For chatml format
51
+ system_prompt: |
52
+ You are a Hyperswitch Rust code analyzer. Identify functions/structs that need modification for a given task.
53
+
54
+ ## Output Format
55
+
56
+ ##OUTPUT
57
+ Explain the data flow and why each component must change:
58
+ - Flow: [Input → Processing → Output with arrows]
59
+ - For each component: "The [ComponentName] ([path]) must [action] because [reason]—without this, [consequence]"
60
+ - Explain coupling between components
61
+
62
+ ##SELECT
63
+ modify::crates/path/to/file.rs::impl::ComponentName
64
+ add::crates/another/file.rs::function::AnotherComponent
65
+ <EOS>
66
+
67
+ ## Rules
68
+
69
+ 1. Use full paths: `remove::crates/folder/file.rs::Type::Name`
70
+ 2. Use `::` for nested items: `status::StructName::Type::Name`
71
+ 3. Always explain "must change because" and "without this"
72
+ 4. Types of components: function, struct, enum, impl, trait
73
+ 5. If there is extra information (e.g., enum variants), include that too.
74
+ 6. Start with ##OUTPUT, end with ##SELECT, terminate with <EOS>
75
+
76
+ ## Example
77
+
78
+ ##TASK
79
+ Add webhook subscription support
80
+
81
+ ##OUTPUT
82
+ The webhook system routes events via EventClass enum. Flow: webhook → EventClass → handler → processing. The EventClass enum (crates/common_enums/src/enums.rs::EventClass) must add Subscriptions variant because it defines event routing—without this, subscription events cannot be processed. The SubscriptionStatus impl (crates/common_enums/src/transformers.rs::SubscriptionStatus) must map to EventType because it converts status to events—without this, status changes don't trigger webhooks. These are coupled: EventClass routes to handlers that use SubscriptionStatus mappings.
83
+
84
+ ##SELECT
85
+ crates/common_enums/src/enums.rs::EventClass
86
+ crates/common_enums/src/transformers.rs::SubscriptionStatus
87
+ <EOS>
88
+
89
+ # For custom format (only used when format_type="custom")
90
+ custom_template: "##INSTRUCTION\n{instruction}<|im_end|>\n##TASK\n{input}<|im_end|>\n##OUTPUT\n{output}<|im_end|>"
91
+
92
+ max_length: 2048
93
+ shuffle: true
94
+ num_proc: 4
95
+
96
+ peft:
97
+ enabled: true
98
+ r: 8
99
+ lora_alpha: 16
100
+ lora_dropout: 0.05
101
+ bias: "none"
102
+ target_modules: "auto"
103
+
104
+ train:
105
+ # max_steps: 10
106
+ num_train_epochs: 6
107
+
108
+ per_device_train_batch_size: 1
109
+ per_device_eval_batch_size: 1
110
+ gradient_accumulation_steps: 8
111
+
112
+ learning_rate: 1e-4
113
+ weight_decay: 0.0
114
+ warmup_ratio: 0.08
115
+ lr_scheduler_type: "cosine"
116
+
117
+ optim: "adamw_torch" # ✅ Changed from paged_adamw_8bit (requires use_4bit=true)
118
+ max_grad_norm: 0.8
119
+ gradient_checkpointing: true
120
+
121
+ logging_steps: 2
122
+ save_strategy: "steps"
123
+ save_steps: 500
124
+ save_total_limit: 20
125
+
126
+ evaluation_strategy: "steps"
127
+ eval_steps: 100
128
+ load_best_model_at_end: true
129
+
130
+ # Early stopping
131
+ early_stopping:
132
+ enabled: true
133
+ patience: 3 # Number of evaluations with no improvement before stopping
134
+ min_delta: 0.001 # Minimum change to qualify as improvement
135
+ metric: "eval_loss" # Metric to monitor
136
+ mode: "min" # "min" for loss, "max" for accuracy/etc.
137
+
138
+ resume_from_checkpoint: "auto"
139
+
140
+ merge:
141
+ enabled: true
142
+ merged_dtype: "float16"
143
+ max_shard_size: "2GB"
144
+ output_dir: "./merged_24b_instruct_lora"
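The `early_stopping` keys above do not map one-to-one onto `transformers` argument names. A minimal sketch of how such a config block could be translated into the standard `EarlyStoppingCallback` and `TrainingArguments` kwargs (the mapping function and variable names here are illustrative, not part of the kit):

```python
def early_stopping_kwargs(cfg: dict) -> dict:
    """Translate the YAML early_stopping block into transformers-style kwargs.

    Returns an empty dict when early stopping is disabled.
    """
    es = cfg.get("early_stopping", {})
    if not es.get("enabled", False):
        return {}
    return {
        # EarlyStoppingCallback arguments
        "early_stopping_patience": int(es.get("patience", 3)),
        "early_stopping_threshold": float(es.get("min_delta", 0.0)),
        # TrainingArguments: which metric to track and its direction
        "metric_for_best_model": es.get("metric", "eval_loss"),
        "greater_is_better": es.get("mode", "min") == "max",
    }


cfg = {"early_stopping": {"enabled": True, "patience": 3,
                          "min_delta": 0.001, "metric": "eval_loss",
                          "mode": "min"}}
```

With the config shown above, `early_stopping_kwargs(cfg)` yields patience 3, threshold 0.001, and `greater_is_better=False` (because the monitored metric is a loss).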
trainer-kit/SFT/instruct_data.jsonl ADDED
@@ -0,0 +1,4 @@
1
+ {"input": "##TASK\nAdd subscription support to webhook event metadata and outgoing webhook content systems to enable proper subscription lifecycle event handling and webhook delivery for subscription-related operations.", "output": "##OUTPUT\nThe webhook system requires subscription event support to handle subscription lifecycle events (creation, payment, cancellation) and deliver them to merchants via outgoing webhooks. This affects the event metadata enumeration and webhook content transformation layers. The complete flow is: subscription event occurs → EventMetadata::Subscription variant created with subscription_id, invoice_id, payment_id → stored in database → retrieved for webhook delivery → OutgoingWebhookEventContent::foreign_from converts EventMetadata to OutgoingWebhookEventContent::Subscription → StripeWebhookObject::From transforms to StripeWebhookObject::Subscriptions for Stripe compatibility. The OutgoingWebhookEventContent::foreign_from function (router/src/core/webhooks/outgoing_v2.rs::OutgoingWebhookEventContent::foreign_from) must add the Subscription variant match arm because it's the bridge from storage EventMetadata to webhook content representation—without this, subscription events cannot be converted to webhook content and would be silently dropped during webhook processing. The StripeWebhookObject::From implementation (router/src/compatibility/stripe/webhooks.rs::StripeWebhookObject::From) must handle OutgoingWebhookContent::SubscriptionDetails because it provides Stripe-compatible webhook formatting—without this, merchants using Stripe webhook endpoints would receive malformed or missing subscription events. The EventMetadata enum (implicitly referenced in foreign_from) requires the Subscription variant to exist in the diesel_models layer, which enables persistent storage of subscription event references. These components form a tightly coupled webhook pipeline where each depends on the previous: EventMetadata → OutgoingWebhookEventContent → StripeWebhookObject. If foreign_from lacks the Subscription match arm, subscription events stored as EventMetadata::Subscription would cause panic during webhook delivery. If StripeWebhookObject::From lacks SubscriptionDetails handling, Stripe webhook endpoints would receive unparseable responses, breaking merchant integrations that rely on Stripe-compatible subscription webhooks for revenue recognition and customer management systems.\n\n##SELECT\ncrates/router/src/core/webhooks/outgoing_v2.rs::OutgoingWebhookEventContent::foreign_from\ncrates/router/src/compatibility/stripe/webhooks.rs::StripeWebhookObject::From\n<EOS>"}
2
+ {"input": "##TASK\nUpdate subscription response mapping to use common_enums::SubscriptionStatus instead of subscription_types::SubscriptionStatus for consistent enum usage across the codebase.", "output": "##OUTPUT\nThe subscription system needs to standardize on a single SubscriptionStatus enum type to avoid type inconsistencies and import conflicts. The current flow is: subscription handler creates SubscriptionResponse → calls to_subscription_response method → converts string status to subscription_types::SubscriptionStatus → returns response. The to_subscription_response method (crates/subscriptions/src/core/subscription_handler.rs::SubscriptionWithHandler::to_subscription_response) must change because it's the conversion point where the status enum is instantiated. Using common_enums::SubscriptionStatus ensures consistency with other parts of the system that already use this common enum type. Without this change, the subscription module would maintain its own duplicate enum, creating potential mismatches and requiring conversions throughout the codebase when interacting with other modules that expect the common enum type.\n\n##SELECT\ncrates/subscriptions/src/core/subscription_handler.rs::SubscriptionWithHandler::to_subscription_response\n<EOS>"}
3
+ {"input": "##TASK\nAdd subscription support to webhook event system by including Subscriptions in EventClass enum and implementing SubscriptionStatus to EventType mapping for proper webhook routing.", "output": "##OUTPUT\nThe webhook system needs to handle subscription-related events for proper event classification and routing. The current flow is: webhook event received → EventClass determines routing category → EventType specifies the exact event type → handlers process based on event type. The EventClass enum (crates/common_enums/src/enums.rs::EventClass) must include Subscriptions because it's the primary classification that determines which subsystem handles the event. The SubscriptionStatus transformer (crates/common_enums/src/transformers.rs::impl::Option<EventType>) must map subscription status changes to appropriate EventType values because it converts domain state changes into webhook events. Without EventClass::Subscriptions, subscription events would be misclassified or unhandled. Without the SubscriptionStatus mapping, subscription status changes wouldn't trigger appropriate webhook notifications, breaking integration with external systems that depend on subscription lifecycle events.\n\n##SELECT\ncrates/common_enums/src/enums.rs::EventClass\ncrates/common_enums/src/transformers.rs::impl::Option<EventType>\n<EOS>"}
4
+ {"input": "##TASK\nAdd subscription ID support to webhook object reference system for proper webhook payload identification and routing.", "output": "##OUTPUT\nThe webhook system needs to reference subscription entities in event payloads for proper event correlation and processing. The current flow is: webhook event generated → ObjectReferenceId identifies the affected entity → webhook payload includes reference → consumers process based on entity type. The ObjectReferenceId enum (crates/api_models/src/webhooks.rs::ObjectReferenceId) must include SubscriptionId because it's the type-safe identifier used throughout the webhook payload structure to specify which subscription triggered the event. Without SubscriptionId, webhook events related to subscriptions couldn't properly reference the subscription entity, making it impossible for consumers to correlate events with specific subscriptions. This would break webhook consumers that need to update their local state or trigger business logic based on subscription events.\n\n##SELECT\ncrates/api_models/src/webhooks.rs::ObjectReferenceId\n<EOS>"}
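Each record above packs the task into `input` and the `##OUTPUT`/`##SELECT`/`<EOS>` structure into `output`. A small stdlib-only validator (a sketch; the marker strings follow the format shown in these examples, and the function name is illustrative) can catch malformed records before training:

```python
import json


def validate_record(line: str) -> list:
    """Return a list of problems found in one instruct JSONL line."""
    problems = []
    try:
        rec = json.loads(line)
    except json.JSONDecodeError as e:
        return [f"invalid JSON: {e}"]
    if "##TASK" not in rec.get("input", ""):
        problems.append("input missing ##TASK")
    out = rec.get("output", "")
    for marker in ("##OUTPUT", "##SELECT", "<EOS>"):
        if marker not in out:
            problems.append(f"output missing {marker}")
    # If both markers are present, ##OUTPUT must come first
    if out.find("##OUTPUT") > out.find("##SELECT") >= 0:
        problems.append("##SELECT appears before ##OUTPUT")
    return problems


sample = json.dumps({"input": "##TASK\nAdd webhook support",
                     "output": "##OUTPUT\n...\n##SELECT\npath::Item\n<EOS>"})
```

Running such a check over the whole file before `load_dataset` gives a much clearer error than a tokenization failure mid-run.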
trainer-kit/SFT/requirements.txt ADDED
@@ -0,0 +1,23 @@
1
+ # Core
2
+ torch>=2.1.0
3
+ transformers>=4.41.0
4
+ datasets>=2.18.0
5
+ accelerate>=0.30.0
6
+
7
+ # PEFT / QLoRA
8
+ peft>=0.11.1
9
+ bitsandbytes>=0.43.1
10
+
11
+ # Hugging Face Hub (local + download support)
12
+ huggingface_hub>=0.23.0
13
+
14
+ # Config + utilities
15
+ pyyaml>=6.0
16
+ tqdm>=4.66.0
17
+
18
+ # Optional but recommended (tokenizers speed)
19
+ tokenizers>=0.15.0
20
+ safetensors>=0.4.2
21
+
22
+ # Experiment tracking
23
+ wandb>=0.16.0
trainer-kit/SFT/run_instruct.py ADDED
@@ -0,0 +1,921 @@
1
+ import argparse
2
+ import json
3
+ import inspect # Added for Transformers version compatibility
4
+ import math
5
+ import time
6
+ from pathlib import Path
7
+ from typing import Any, Dict, Optional, Tuple, List
8
+
9
+ import torch
10
+ import yaml
11
+ from datasets import load_dataset, DatasetDict
12
+ from huggingface_hub import snapshot_download
13
+ from transformers import (
14
+ AutoTokenizer,
15
+ AutoModelForCausalLM,
16
+ AutoModel,
17
+ AutoConfig,
18
+ BitsAndBytesConfig,
19
+ TrainingArguments,
20
+ Trainer,
21
+ TrainerCallback,
22
+ EarlyStoppingCallback,
23
+ default_data_collator,
24
+ set_seed,
25
+ )
26
+ from transformers.trainer_utils import get_last_checkpoint
27
+ from peft import (
28
+ LoraConfig,
29
+ get_peft_model,
30
+ prepare_model_for_kbit_training,
31
+ PeftModel,
32
+ )
33
+
34
+ try:
35
+ import wandb
36
+ WANDB_AVAILABLE = True
37
+ except ImportError:
38
+ WANDB_AVAILABLE = False
39
+ wandb = None
40
+
41
+
42
+ # --------------------------
43
+ # Helpers
44
+ # --------------------------
45
+
46
+
47
+ def _dtype_from_str(s: str) -> torch.dtype:
48
+ s = (s or "").lower()
49
+ if s in ("float16", "fp16"):
50
+ return torch.float16
51
+ if s in ("bfloat16", "bf16"):
52
+ return torch.bfloat16
53
+ if s in ("float32", "fp32"):
54
+ return torch.float32
55
+ raise ValueError(f"Unknown torch_dtype: {s}")
56
+
57
+
58
+ def _now_iso() -> str:
59
+ return time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
60
+
61
+
62
+ def _safe_exp(x: float) -> float:
63
+ x = min(float(x), 50.0)
64
+ return float(math.exp(x))
65
+
66
+
67
+ def _ensure_dir(p: Path) -> Path:
68
+ p.mkdir(parents=True, exist_ok=True)
69
+ return p
70
+
71
+
72
+ def _looks_like_model_dir(p: Path) -> bool:
73
+ if not p.exists() or not p.is_dir():
74
+ return False
75
+ if (p / "config.json").exists():
76
+ return True
77
+ if any(p.glob("*.safetensors")) or any(p.glob("pytorch_model*.bin")):
78
+ return True
79
+ return False
80
+
81
+
82
+ def _infer_target_modules(model) -> List[str]:
83
+ names = set()
84
+ for n, _ in model.named_modules():
85
+ names.add(n.split(".")[-1])
86
+
87
+ for group in [
88
+ ["q_proj", "k_proj", "v_proj", "o_proj"],
89
+ ["Wqkv", "out_proj"],
90
+ ["query_key_value", "dense"],
91
+ ["c_attn", "c_proj"],
92
+ ]:
93
+ if all(x in names for x in group):
94
+ return group
95
+
96
+ fallback = [
97
+ x
98
+ for x in [
99
+ "q_proj",
100
+ "k_proj",
101
+ "v_proj",
102
+ "o_proj",
103
+ "c_attn",
104
+ "c_proj",
105
+ "out_proj",
106
+ "dense",
107
+ ]
108
+ if x in names
109
+ ]
110
+ if fallback:
111
+ return fallback
112
+
113
+ raise ValueError(
114
+ "Could not auto-infer target_modules. Set peft.target_modules explicitly."
115
+ )
116
+
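`_infer_target_modules` above picks the first known attention-projection group whose members all appear among the model's leaf module names. The matching logic in isolation (the module-name sets below are illustrative, not taken from a real checkpoint):

```python
KNOWN_GROUPS = [
    ["q_proj", "k_proj", "v_proj", "o_proj"],  # LLaMA/Mistral-style
    ["Wqkv", "out_proj"],                      # fused-QKV style
    ["query_key_value", "dense"],              # GPT-NeoX-style
    ["c_attn", "c_proj"],                      # GPT-2-style
]


def pick_target_modules(leaf_names: set) -> list:
    """Return the first known projection group fully present in leaf_names."""
    for group in KNOWN_GROUPS:
        if all(name in leaf_names for name in group):
            return group
    raise ValueError("no known attention projection group found")


llama_names = {"embed_tokens", "q_proj", "k_proj", "v_proj", "o_proj",
               "gate_proj", "up_proj", "down_proj", "lm_head"}
```

The group list is ordered from most to least specific, so a LLaMA-style model resolves to the four `*_proj` attention projections even though `dense`-style names might also be present.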
117
+
118
+ def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
119
+ return cfg.get("model", {}).get("attn_implementation", None)
120
+
121
+
122
+ # --------------------------
123
+ # Wandb Integration
124
+ # --------------------------
125
+
126
+ def setup_wandb(cfg: Dict[str, Any], run_dir: Path):
127
+ """Initialize Wandb if enabled in configuration."""
128
+ wandb_cfg = cfg.get("wandb", {})
129
+
130
+ if not wandb_cfg.get("enabled", False):
131
+ print("Wandb logging disabled")
132
+ return None
133
+
134
+ if not WANDB_AVAILABLE:
135
+ print("Wandb not available. Install with: pip install wandb")
136
+ return None
137
+
138
+ # Extract wandb configuration
139
+ project = wandb_cfg.get("project", "sft-training")
140
+ entity = wandb_cfg.get("entity", None)
141
+ name = wandb_cfg.get("name", None)
142
+ tags = wandb_cfg.get("tags", [])
143
+ notes = wandb_cfg.get("notes", None)
144
+
145
+ # Initialize wandb
146
+ try:
147
+ wandb.init(
148
+ project=project,
149
+ entity=entity,
150
+ name=name,
151
+ tags=tags,
152
+ notes=notes,
153
+ dir=str(run_dir),
154
+ config={
155
+ "model": cfg.get("model", {}),
156
+ "data": cfg.get("data", {}),
157
+ "peft": cfg.get("peft", {}),
158
+ "train": cfg.get("train", {}),
159
+ "run_dir": str(run_dir),
160
+ }
161
+ )
162
+ print(f"Wandb initialized: project='{project}', name='{name or 'auto-generated'}'")
163
+ return wandb
164
+ except Exception as e:
165
+ print(f"Failed to initialize Wandb: {e}")
166
+ return None
167
+
168
+
169
+ def finish_wandb():
170
+ """Finish Wandb run if active."""
171
+ if WANDB_AVAILABLE and wandb.run is not None:
172
+ wandb.finish()
173
+ print("Wandb run finished")
174
+
175
+
176
+ # --------------------------
177
+ # JSONL Logger Callback
178
+ # --------------------------
179
+
180
+
181
+ class JsonlLoggerCallback(TrainerCallback):
182
+ def __init__(self, run_dir: Path):
183
+ self.run_dir = run_dir
184
+ self.train_log_path = _ensure_dir(run_dir / "logs") / "train.jsonl"
185
+ self.eval_log_path = _ensure_dir(run_dir / "logs") / "eval.jsonl"
186
+ self.start_time = None
187
+
188
+ def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
189
+ if self.start_time is None or global_step <= 0 or max_steps <= 0:
190
+ return None
191
+ elapsed = time.time() - self.start_time
192
+ sec_per_step = elapsed / global_step
193
+ remaining = max(0, max_steps - global_step) * sec_per_step
194
+ h = int(remaining // 3600)
195
+ m = int((remaining % 3600) // 60)
196
+ s = int(remaining % 60)
197
+ return f"{h:02d}:{m:02d}:{s:02d}"
198
+
199
+ def on_train_begin(self, args, state, control, **kwargs):
200
+ self.start_time = time.time()
201
+
202
+ def on_log(self, args, state, control, logs=None, **kwargs):
203
+ if not logs:
204
+ return
205
+
206
+ max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
207
+ progress_pct = (
208
+ (100.0 * state.global_step / max_steps) if max_steps > 0 else None
209
+ )
210
+ epoch_pct = None
211
+ if (
212
+ state.epoch is not None
213
+ and args.num_train_epochs
214
+ and args.num_train_epochs > 0
215
+ ):
216
+ epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs))
217
+
218
+ payload = {
219
+ "ts": _now_iso(),
220
+ "event": "train_log",
221
+ "step": int(state.global_step),
222
+ "epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
223
+ "progress_pct": (
224
+ round(progress_pct, 2) if progress_pct is not None else None
225
+ ),
226
+ "epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None,
227
+ "eta": self._eta(int(state.global_step), max_steps),
228
+ "max_grad_norm": getattr(args, "max_grad_norm", None),
229
+ **logs,
230
+ }
231
+
232
+ with self.train_log_path.open("a", encoding="utf-8") as f:
233
+ f.write(json.dumps(payload, ensure_ascii=False) + "\n")
234
+
235
+ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
236
+ if not metrics:
237
+ return
238
+ eval_loss = metrics.get("eval_loss", None)
239
+ ppl = _safe_exp(eval_loss) if eval_loss is not None else None
240
+
241
+ payload = {
242
+ "ts": _now_iso(),
243
+ "event": "eval",
244
+ "step": int(state.global_step),
245
+ "epoch": float(state.epoch) if state.epoch is not None else None,
246
+ **metrics,
247
+ "perplexity": ppl,
248
+ }
249
+ with self.eval_log_path.open("a", encoding="utf-8") as f:
250
+ f.write(json.dumps(payload, ensure_ascii=False) + "\n")
251
+
252
+
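The callback's ETA estimate assumes a constant seconds-per-step rate over the run. The arithmetic in isolation (the helper name is illustrative):

```python
from typing import Optional


def eta_hms(elapsed_s: float, global_step: int, max_steps: int) -> Optional[str]:
    """Estimate remaining wall time as HH:MM:SS from the average step time."""
    if global_step <= 0 or max_steps <= 0:
        return None
    remaining = max(0, max_steps - global_step) * (elapsed_s / global_step)
    h, rem = divmod(int(remaining), 3600)
    m, s = divmod(rem, 60)
    return f"{h:02d}:{m:02d}:{s:02d}"
```

For example, 10 steps in 100 seconds with 100 steps left gives an ETA of "00:16:40". Because early steps include warmup and compilation overhead, the estimate typically shrinks as training progresses.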
253
+ # --------------------------
254
+ # Data Pipeline (Instruction Formatting)
255
+ # --------------------------
256
+
257
+
258
+ def format_instruction(
259
+ example: Dict[str, Any], cfg: Dict[str, Any], tokenizer
260
+ ) -> Dict[str, Any]:
261
+ """
262
+ Format instruction data for training.
263
+ Supports multiple formats: chatml, alpaca, custom templates.
264
+ Returns both formatted text and the response start position for loss masking.
265
+ """
266
+ data_cfg = cfg["data"]
267
+ format_type = data_cfg.get("format_type", "chatml")
268
+
269
+ # Get field names from config
270
+ input_field = data_cfg.get("input_field", "input")
271
+ output_field = data_cfg.get("output_field", "output")
272
+ instruction_field = data_cfg.get("instruction_field", "instruction")
273
+
274
+ # Extract text from example
275
+ instruction = example.get(instruction_field, "")
276
+ input_text = example.get(input_field, "")
277
+ output_text = example.get(output_field, "")
278
+
279
+ if format_type == "chatml":
280
+ # ChatML format with special tokens
281
+ system_prompt = data_cfg.get("system_prompt", "You are a helpful assistant.")
282
+
283
+ messages = []
284
+ if system_prompt:
285
+ messages.append({"role": "system", "content": system_prompt})
286
+
287
+ user_content = instruction
288
+ if input_text:
289
+ user_content = f"{instruction}\n\n{input_text}"
290
+ messages.append({"role": "user", "content": user_content})
291
+ messages.append({"role": "assistant", "content": output_text})
292
+
293
+ # Apply chat template
294
+ formatted_text = tokenizer.apply_chat_template(
295
+ messages, tokenize=False, add_generation_prompt=False
296
+ )
297
+
298
+ # Add EOS token if not present
299
+ if tokenizer.eos_token and not formatted_text.endswith(tokenizer.eos_token):
300
+ formatted_text += tokenizer.eos_token
301
+
302
+ # Find where the assistant response starts for loss masking
303
+ # Try multiple possible markers for robustness
304
+ markers = ["<|im_start|>assistant", "<|assistant|>", "Assistant:", "assistant\n"]
305
+ response_start_pos = -1
306
+
307
+ for marker in markers:
308
+ idx = formatted_text.find(marker)
309
+ if idx != -1:
310
+ # Find the newline after the marker
311
+ newline_idx = formatted_text.find("\n", idx)
312
+ if newline_idx != -1:
313
+ response_start_pos = newline_idx + 1
314
+ break
315
+
316
+ # Fallback: find where the actual output starts
317
+ if response_start_pos == -1:
318
+ output_idx = formatted_text.find(output_text)
319
+ if output_idx != -1:
320
+ response_start_pos = output_idx
321
+ else:
322
+ # Last resort: split at last occurrence of newline before end
323
+ response_start_pos = formatted_text.rfind("\n", 0, len(formatted_text) - len(output_text)) + 1
324
+
325
+ elif format_type == "alpaca":
326
+ # Alpaca format
327
+ if input_text:
328
+ prefix = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
329
+ else:
330
+ prefix = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n"
331
+
332
+ formatted_text = prefix + output_text
333
+
334
+ # Add EOS token
335
+ if tokenizer.eos_token:
336
+ formatted_text += tokenizer.eos_token
337
+
338
+ # Response starts after the prefix
339
+ response_start_pos = len(prefix)
340
+
341
+ elif format_type == "custom":
342
+ # Custom template from config
343
+ template = data_cfg.get("custom_template", "{instruction}\n{input}\n{output}")
344
+
345
+ # For custom format, use system_prompt as instruction if instruction field is empty
346
+ if not instruction:
347
+ instruction = data_cfg.get("system_prompt", "")
348
+
349
+ # For custom templates, we need to find where {output} starts
350
+ template_parts = template.split("{output}")
351
+ prefix = template_parts[0].format(instruction=instruction, input=input_text)
352
+ formatted_text = prefix + output_text
353
+
354
+ # Add EOS token if not already in template
355
+ if tokenizer.eos_token and not formatted_text.endswith(tokenizer.eos_token):
356
+ formatted_text += tokenizer.eos_token
357
+
358
+ # Response starts after the prefix
359
+ response_start_pos = len(prefix)
360
+ else:
361
+ raise ValueError(f"Unsupported format_type: {format_type}")
362
+
363
+ return {"text": formatted_text, "response_start_pos": response_start_pos}
364
+
365
+
366
+ def build_datasets(cfg: Dict[str, Any], tokenizer) -> Tuple[Any, Any]:
367
+ """
368
+ Build datasets for instruction fine-tuning.
369
+ """
370
+ data_cfg = cfg["data"]
371
+ train_path = data_cfg["train_jsonl"]
372
+ eval_path = data_cfg.get("eval_jsonl", None)
373
+ split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
374
+ max_length = int(data_cfg.get("max_length", 2048))
375
+ shuffle = bool(data_cfg.get("shuffle", True))
376
+ num_proc = int(data_cfg.get("num_proc", 4))
377
+
378
+ # Ensure tokenizer has pad token
379
+ if tokenizer.pad_token is None:
380
+ tokenizer.pad_token = tokenizer.eos_token
381
+
382
+ # Load datasets
383
+ ds = load_dataset("json", data_files={"train": train_path})
384
+
385
+ if eval_path:
386
+ ds_eval = load_dataset("json", data_files={"eval": eval_path})
387
+ dsd = DatasetDict({"train": ds["train"], "eval": ds_eval["eval"]})
388
+ else:
389
+ if 0.0 < split_ratio < 1.0:
390
+ split = ds["train"].train_test_split(
391
+ test_size=split_ratio, seed=int(cfg["run"].get("seed", 42))
392
+ )
393
+ dsd = DatasetDict({"train": split["train"], "eval": split["test"]})
394
+ else:
395
+ dsd = DatasetDict({"train": ds["train"], "eval": None})
396
+
397
+ # Format instructions and track response start positions
398
+ def format_fn(examples):
399
+ formatted_examples = []
400
+ response_start_positions = []
401
+ for i in range(len(examples[list(examples.keys())[0]])):
402
+ example = {k: examples[k][i] for k in examples.keys()}
403
+ formatted = format_instruction(example, cfg, tokenizer)
404
+ formatted_examples.append(formatted["text"])
405
+ response_start_positions.append(formatted["response_start_pos"])
406
+ return {
407
+ "text": formatted_examples,
408
+ "response_start_pos": response_start_positions
409
+ }
410
+
411
+ formatted_train = dsd["train"].map(
412
+ format_fn,
413
+ batched=True,
414
+ num_proc=num_proc,
415
+ remove_columns=dsd["train"].column_names,
416
+ desc="Formatting train instructions",
417
+ )
418
+
419
+ formatted_eval = None
420
+ if dsd["eval"] is not None:
421
+ formatted_eval = dsd["eval"].map(
422
+ format_fn,
423
+ batched=True,
424
+ num_proc=num_proc,
425
+ remove_columns=dsd["eval"].column_names,
426
+ desc="Formatting eval instructions",
427
+ )
428
+
429
+ # Tokenize and apply loss masking
430
+ def tokenize_and_mask_fn(examples):
431
+ tokenized = tokenizer(
432
+ examples["text"],
433
+ truncation=True,
434
+ padding=False,
435
+ max_length=max_length,
436
+ return_overflowing_tokens=False,
437
+ )
438
+
439
+ # Apply loss masking - CRITICAL for SFT
440
+ labels = []
441
+ attention_masks = []
442
+
443
+ for i in range(len(tokenized["input_ids"])):
444
+ input_ids = tokenized["input_ids"][i]
445
+ response_start_pos = examples["response_start_pos"][i]
446
+
447
+ # Full formatted text; the response begins at response_start_pos
448
+ full_text = examples["text"][i]
450
+
451
+ # Create labels masked by default
452
+ label_ids = [-100] * len(input_ids)
453
+
454
+ # Approximate the token index where the response starts from the
455
+ # character-position ratio. Tokenizing the prefix separately can insert
456
+ # different special tokens; this avoids that, but the boundary is approximate.
457
+ char_ratio = response_start_pos / max(len(full_text), 1)
458
+ response_start_idx = int(len(input_ids) * char_ratio)
459
+
460
+ # Ensure we have valid bounds (at least position 1, at most len-1)
461
+ response_start_idx = max(1, min(response_start_idx, len(input_ids) - 1))
462
+
463
+ # Unmask response tokens (including EOS)
464
+ for j in range(response_start_idx, len(input_ids)):
465
+ label_ids[j] = input_ids[j]
466
+
467
+ # No padding at this stage (padding=False), so the attention mask is all ones
468
+ attention_mask = [1] * len(input_ids)
469
+
470
+ labels.append(label_ids)
471
+ attention_masks.append(attention_mask)
472
+
473
+ tokenized["labels"] = labels
474
+ tokenized["attention_mask"] = attention_masks
475
+ return tokenized
476
+
477
+ tokenized_train = formatted_train.map(
478
+ tokenize_and_mask_fn,
479
+ batched=True,
480
+ num_proc=num_proc,
481
+ desc="Tokenizing and masking train",
482
+ )
483
+
484
+ tokenized_eval = None
485
+ if formatted_eval is not None:
486
+ tokenized_eval = formatted_eval.map(
487
+ tokenize_and_mask_fn,
488
+ batched=True,
489
+ num_proc=num_proc,
490
+ desc="Tokenizing and masking eval",
491
+ )
492
+
493
+ if shuffle:
494
+ tokenized_train = tokenized_train.shuffle(seed=int(cfg["run"].get("seed", 42)))
495
+
496
+ return tokenized_train, tokenized_eval
497
+
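The loss masking inside `tokenize_and_mask_fn` approximates the token index of the response boundary from its character position, then masks everything before it with -100. The core of that masking in isolation (the token lists below are toy values, not real tokenizer output):

```python
def mask_labels(input_ids: list, full_text: str, response_start_pos: int) -> list:
    """Mask instruction tokens with -100 so loss covers only the response.

    The token boundary is derived from the character-position ratio, so it is
    only as accurate as the chars-per-token ratio is uniform across the text.
    """
    char_ratio = response_start_pos / max(len(full_text), 1)
    boundary = int(len(input_ids) * char_ratio)
    boundary = max(1, min(boundary, len(input_ids) - 1))  # keep valid bounds
    return [-100] * boundary + list(input_ids[boundary:])


text = "INSTRUCTION:" + "x" * 38 + "RESPONSE:" + "y" * 41  # 100 chars total
ids = list(range(20))                # toy "tokens", 20 of them
labels = mask_labels(ids, text, 50)  # response starts at character 50
```

Here the response starts halfway through the text, so the first half of the toy token list is masked and the second half is trained on.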
498
+
499
+ # --------------------------
500
+ # Model Loading + PEFT
501
+ # --------------------------
502
+
503
+
504
+ def load_base_model_and_tokenizer(cfg: Dict[str, Any], base_dir: Path):
505
+ model_cfg = cfg["model"]
506
+ trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
507
+ use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
508
+ device_map = model_cfg.get("device_map", "auto")
509
+
510
+ tokenizer = AutoTokenizer.from_pretrained(
511
+ str(base_dir),
512
+ use_fast=use_fast,
513
+ trust_remote_code=trust_remote_code,
514
+ )
515
+ if tokenizer.pad_token is None:
516
+ tokenizer.pad_token = tokenizer.eos_token
517
+
518
+ torch_dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
519
+ use_4bit = bool(model_cfg.get("use_4bit", False))
520
+
521
+ quant_cfg = None
522
+ if use_4bit:
523
+ quant_cfg = BitsAndBytesConfig(
524
+ load_in_4bit=True,
525
+ bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4")),
526
+ bnb_4bit_use_double_quant=bool(
527
+ model_cfg.get("bnb_4bit_use_double_quant", True)
528
+ ),
529
+ bnb_4bit_compute_dtype=_dtype_from_str(
530
+ model_cfg.get("bnb_4bit_compute_dtype", "bfloat16")
531
+ ),
532
+ )
533
+
534
+ attn_impl = _choose_attn_impl(cfg)
535
+
536
+ # First check the model type to determine loading strategy
537
+ try:
538
+ config = AutoConfig.from_pretrained(str(base_dir), trust_remote_code=True)
539
+ model_type = config.model_type
540
+ architectures = getattr(config, 'architectures', [])
541
+
542
+ # Handle Mistral3 (multimodal) models
543
+ if model_type == "mistral3" or (architectures and "Mistral3" in architectures[0]):
544
+ print(f"[info] Detected Mistral3 model architecture, loading with specific class")
545
+ from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration
546
+
547
+ try:
548
+ model = Mistral3ForConditionalGeneration.from_pretrained(
549
+ str(base_dir),
550
+ config=config,
551
+ device_map=device_map,
552
+ low_cpu_mem_usage=True,
553
+ torch_dtype=(torch_dtype if not use_4bit else None),
554
+ quantization_config=quant_cfg,
555
+ attn_implementation=attn_impl,
556
+ )
557
+ except Exception as e:
558
+ if attn_impl is not None:
559
+ print(f"[warn] attn_implementation='{attn_impl}' failed: {e}")
560
+ print("[warn] Falling back to default attention implementation.")
561
+ model = Mistral3ForConditionalGeneration.from_pretrained(
562
+ str(base_dir),
563
+ config=config,
564
+ device_map=device_map,
565
+ low_cpu_mem_usage=True,
566
+ torch_dtype=(torch_dtype if not use_4bit else None),
567
+ quantization_config=quant_cfg,
568
+ )
569
+ else:
570
+ raise e
571
+ else:
572
+ # Standard AutoModelForCausalLM loading for other models
573
+ try:
574
+ model = AutoModelForCausalLM.from_pretrained(
575
+ str(base_dir),
576
+ device_map=device_map,
577
+ trust_remote_code=True,
578
+ low_cpu_mem_usage=True,
579
+ torch_dtype=(torch_dtype if not use_4bit else None),
580
+ quantization_config=quant_cfg,
581
+ attn_implementation=attn_impl,
582
+ )
583
+ except Exception as e:
584
+ if attn_impl is not None:
585
+ print(f"[warn] attn_implementation='{attn_impl}' failed: {e}")
586
+ print("[warn] Falling back to default attention implementation.")
587
+ model = AutoModelForCausalLM.from_pretrained(
588
+ str(base_dir),
589
+ device_map=device_map,
590
+ trust_remote_code=True,
591
+ low_cpu_mem_usage=True,
592
+ torch_dtype=(torch_dtype if not use_4bit else None),
593
+ quantization_config=quant_cfg,
594
+ )
595
+ else:
596
+ raise e
597
+ except Exception as e:
598
+ print(f"[error] Failed to load model: {e}")
599
+ raise e
600
+
601
+ # Ensure all parameters are off meta device
602
+ print("[info] Ensuring all parameters are materialized...")
603
+ meta_params = []
604
+ for name, param in model.named_parameters():
605
+ if param.device.type == 'meta':
606
+ meta_params.append(name)
607
+
608
+ if meta_params:
609
+ print(f"[warn] Found {len(meta_params)} parameters on meta device")
610
+ # For multimodal models, freeze vision components if doing text-only training
611
+ if hasattr(model, 'vision_tower'):
612
        print("[info] Freezing vision tower for text-only training")
        for param in model.vision_tower.parameters():
            param.requires_grad = False

    return model, tokenizer


def apply_peft(cfg: Dict[str, Any], model):
    peft_cfg = cfg["peft"]
    model_cfg = cfg["model"]
    tr_cfg = cfg["train"]

    if not bool(peft_cfg.get("enabled", True)):
        return model, None

    use_4bit = bool(model_cfg.get("use_4bit", False))
    gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))

    # For multimodal models, ensure the vision tower doesn't use gradient checkpointing
    if gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
        if hasattr(model, "vision_tower"):
            print("[info] Disabling gradient checkpointing for vision tower")
            # Only enable gradient checkpointing on the language model
            if hasattr(model, "language_model"):
                model.language_model.gradient_checkpointing_enable()
            elif hasattr(model, "lm_head"):
                model.gradient_checkpointing_enable()
        else:
            model.gradient_checkpointing_enable()

    if hasattr(model, "config"):
        model.config.use_cache = False

    if use_4bit:
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=gradient_checkpointing,
        )

    target_modules = peft_cfg.get("target_modules", "auto")
    if target_modules == "auto":
        target_modules = _infer_target_modules(model)

    # For multimodal models, only target language-model modules
    if hasattr(model, "vision_tower") and isinstance(target_modules, list):
        print("[info] Filtering target modules to exclude vision tower")
        target_modules = [m for m in target_modules if "vision" not in m.lower()]
        print(f"[info] LoRA target modules: {target_modules}")

    lora_config = LoraConfig(
        r=int(peft_cfg.get("r", 16)),
        lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
        lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
        bias=str(peft_cfg.get("bias", "none")),
        task_type="CAUSAL_LM",
        target_modules=target_modules,
        modules_to_save=None,  # don't make any additional modules trainable
    )
    model = get_peft_model(model, lora_config)
    return model, lora_config

# --------------------------
# Merge Logic
# --------------------------


def merge_adapter(
    cfg: Dict[str, Any], base_dir: Path, adapter_dir: Path, final_dir: Path
):
    print(f"--- Merge: {adapter_dir} + {base_dir} -> {final_dir} ---")

    model_cfg = cfg["model"]
    merge_cfg = cfg.get("merge", {})
    trust_remote_code = bool(model_cfg.get("trust_remote_code", True))

    merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
    max_shard_size = str(merge_cfg.get("max_shard_size", "2GB"))

    base = AutoModelForCausalLM.from_pretrained(
        str(base_dir),
        torch_dtype=merged_dtype,
        device_map="cpu",
        low_cpu_mem_usage=True,
        trust_remote_code=trust_remote_code,
    )

    merged = PeftModel.from_pretrained(base, str(adapter_dir))
    merged = merged.merge_and_unload()

    _ensure_dir(final_dir)
    merged.save_pretrained(
        str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
    )

    tok = AutoTokenizer.from_pretrained(
        str(base_dir), trust_remote_code=trust_remote_code
    )
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.save_pretrained(str(final_dir))

    print("--- Merge complete ---")


# --------------------------
# Main
# --------------------------


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True, help="Path to YAML config")
    ap.add_argument(
        "--merge-only", action="store_true", help="Skip training, just merge adapter"
    )
    args = ap.parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
    _ensure_dir(run_dir / "logs")

    with (run_dir / "config_resolved.yaml").open("w", encoding="utf-8") as f:
        yaml.safe_dump(cfg, f, sort_keys=False)

    model_cfg = cfg["model"]
    repo_id = str(model_cfg["repo_id"]).strip()
    repo_path = Path(repo_id)

    # ✅ Local model path -> load directly; no download
    if repo_path.exists() and repo_path.is_dir() and _looks_like_model_dir(repo_path):
        base_dir = repo_path
        print(f"Using local model at: {base_dir}")
    elif repo_path.exists() and repo_path.is_dir():
        raise ValueError(
            f"model.repo_id points to a directory, but it doesn't look like a HF model dir: {repo_path}"
        )
    else:
        # HF repo_id -> download into run_dir/base_local_dir
        base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
        if not _looks_like_model_dir(base_dir):
            print(f"Base model not found at {base_dir}, downloading from {repo_id} ...")
            snapshot_download(
                repo_id=repo_id,
                revision=model_cfg.get("revision", None),
                local_dir=str(base_dir),
                local_dir_use_symlinks=False,
            )

    ckpt_dir = _ensure_dir(run_dir / "checkpoints")
    best_adapter_dir = _ensure_dir(run_dir / "best_adapter")

    merge_cfg = cfg.get("merge", {}) or {}
    if merge_cfg.get("output_dir"):
        od = Path(str(merge_cfg["output_dir"]))
        final_dir = od if od.is_absolute() else (run_dir / od)
    else:
        final_dir = run_dir / "final_model"

    # Merge-only
    if args.merge_only:
        if not _looks_like_model_dir(best_adapter_dir):
            raise FileNotFoundError(f"Adapter not found at {best_adapter_dir}")
        merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
        return

    # Initialize Wandb
    wandb_run = setup_wandb(cfg, run_dir)

    # Training
    set_seed(int(cfg["run"].get("seed", 42)))

    model, tokenizer = load_base_model_and_tokenizer(cfg, base_dir)
    model, _ = apply_peft(cfg, model)

    train_ds, eval_ds = build_datasets(cfg, tokenizer)

    tr_cfg = cfg["train"]

    dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
    use_fp16 = dtype == torch.float16
    use_bf16 = dtype == torch.bfloat16

    max_steps = int(tr_cfg.get("max_steps", 0))
    num_train_epochs = float(tr_cfg.get("num_train_epochs", 1))

    # --- Dynamic evaluation strategy parameter handling ---
    ta_params = inspect.signature(TrainingArguments.__init__).parameters
    eval_key = (
        "eval_strategy" if "eval_strategy" in ta_params else "evaluation_strategy"
    )

    # Set up reporting based on wandb availability
    report_to = []
    if wandb_run is not None:
        report_to.append("wandb")

    ta_kwargs = dict(
        output_dir=str(ckpt_dir),
        max_steps=max_steps if max_steps > 0 else -1,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 1)),
        per_device_eval_batch_size=int(
            tr_cfg.get(
                "per_device_eval_batch_size",
                tr_cfg.get("per_device_train_batch_size", 1),
            )
        ),
        gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 1)),
        learning_rate=float(tr_cfg.get("learning_rate", 2e-5)),
        weight_decay=float(tr_cfg.get("weight_decay", 0.0)),
        warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
        lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "cosine")),
        optim=str(
            tr_cfg.get(
                "optim",
                (
                    "paged_adamw_8bit"
                    if bool(model_cfg.get("use_4bit", False))
                    else "adamw_torch"
                ),
            )
        ),
        max_grad_norm=float(tr_cfg.get("max_grad_norm", 1.0)),
        logging_steps=int(tr_cfg.get("logging_steps", 10)),
        save_strategy=str(tr_cfg.get("save_strategy", "steps")),
        save_steps=int(tr_cfg.get("save_steps", 200)),
        save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
        eval_steps=int(tr_cfg.get("eval_steps", 200)),
        load_best_model_at_end=(
            bool(tr_cfg.get("load_best_model_at_end", True))
            if eval_ds is not None
            else False
        ),
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        fp16=use_fp16,
        bf16=use_bf16,
        report_to=report_to,
        remove_unused_columns=False,
    )

    # Set the correct argument name for this transformers version
    ta_kwargs[eval_key] = str(
        tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no")
    )

    training_args = TrainingArguments(**ta_kwargs)

    # Set up callbacks
    callbacks = [JsonlLoggerCallback(run_dir)]

    # Add early stopping callback if enabled
    early_stopping_cfg = tr_cfg.get("early_stopping", {})
    if early_stopping_cfg.get("enabled", False) and eval_ds is not None:
        early_stopping_callback = EarlyStoppingCallback(
            early_stopping_patience=int(early_stopping_cfg.get("patience", 3)),
            early_stopping_threshold=float(early_stopping_cfg.get("min_delta", 0.001)),
        )
        callbacks.append(early_stopping_callback)
        print(
            f"Early stopping enabled: patience={early_stopping_cfg.get('patience', 3)}, "
            f"min_delta={early_stopping_cfg.get('min_delta', 0.001)}"
        )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=default_data_collator,
        callbacks=callbacks,
    )

    # Resume
    resume_from = tr_cfg.get("resume_from_checkpoint", None)
    if resume_from == "auto":
        last = get_last_checkpoint(str(ckpt_dir))
        resume_from = last if last else None
    if resume_from:
        print(f"Resuming from {resume_from}")

    print("Starting continued pretraining...")
    trainer.train(resume_from_checkpoint=resume_from)

    trainer.save_model(str(best_adapter_dir))
    print(f"Saved best adapter -> {best_adapter_dir}")

    if eval_ds is not None:
        metrics = trainer.evaluate()
        eval_loss = metrics.get("eval_loss", None)
        metrics["perplexity"] = _safe_exp(eval_loss) if eval_loss is not None else None
        with (run_dir / "eval_final.json").open("w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)
        print(f"Final eval_loss={eval_loss}, ppl={metrics['perplexity']}")

    if bool(cfg.get("merge", {}).get("enabled", False)):
        del trainer, model
        torch.cuda.empty_cache()
        merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
    else:
        print("Merge disabled. Run with --merge-only later if needed.")

    # Finish Wandb run
    finish_wandb()


if __name__ == "__main__":
    main()
trainer-kit/documentation.md ADDED
@@ -0,0 +1,296 @@
# CPT Training Different Modules Guide

## Overview

By default, the CPT (Continued Pre-Training) configuration in `/workspace/Trainer-kit/CPT/config.yaml` trains only **attention projection layers** using LoRA adapters. This guide explains how to modify the configuration to train other modules.

## Current Default Configuration

```yaml
peft:
  enabled: true
  target_modules: "auto"
```

When `target_modules: "auto"` is set, the script automatically detects and trains these attention layers:
- `q_proj` - Query projection
- `k_proj` - Key projection
- `v_proj` - Value projection
- `o_proj` - Output projection
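
A rough sense of what this default costs: each linear layer of shape `(d_out, d_in)` that LoRA wraps adds `r * (d_in + d_out)` trainable parameters. The sketch below is a back-of-the-envelope count; the hidden size, layer count, and rank are illustrative assumptions, not values read from any config.

```python
# Back-of-the-envelope LoRA parameter count for attention-only targeting.
# All concrete numbers here are assumed for illustration.
r = 16               # assumed LoRA rank
hidden = 5120        # assumed model hidden size
num_layers = 48      # assumed number of transformer layers

# Treat q/k/v/o as hidden x hidden linears (ignoring grouped-query head splits),
# so each of the four adds r * (hidden + hidden) parameters.
per_layer = 4 * r * (hidden + hidden)
total = per_layer * num_layers
print(f"{total:,} trainable LoRA parameters")  # -> 31,457,280 trainable LoRA parameters
```

Even with all four attention projections wrapped, this is a small fraction of the base model's parameters, which is why the default configuration is cheap to train.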

## How to Train Other Modules

### Method 1: Explicit Target Modules

Replace `"auto"` with a list of the specific module names you want to train:

```yaml
peft:
  enabled: true
  target_modules:
    - "q_proj"
    - "k_proj"
    - "v_proj"
    - "o_proj"
    - "mlp.down_proj"   # Add MLP down projection
    - "mlp.gate_proj"   # Add MLP gate projection
    - "mlp.up_proj"     # Add MLP up projection
```

### Method 2: Custom Module Lists

For different model architectures, here are common modules you can train:

#### LLaMA-style Models
```yaml
peft:
  enabled: true
  target_modules:
    - "q_proj"
    - "k_proj"
    - "v_proj"
    - "o_proj"
    - "mlp.gate_proj"
    - "mlp.up_proj"
    - "mlp.down_proj"
```

#### Qwen-style Models
```yaml
peft:
  enabled: true
  target_modules:
    - "q_proj"
    - "k_proj"
    - "v_proj"
    - "o_proj"
    - "mlp.gate_proj"
    - "mlp.up_proj"
    - "mlp.down_proj"
```

#### Mixtral-style (MoE) Models
```yaml
peft:
  enabled: true
  target_modules:
    - "q_proj"
    - "k_proj"
    - "v_proj"
    - "o_proj"
    - "mlp.experts.*.w1"   # Expert layer 1
    - "mlp.experts.*.w2"   # Expert layer 2
    - "mlp.experts.*.w3"   # Expert layer 3
```

## Module Types You Can Train

### 1. Attention Layers
- `q_proj` - Query projections
- `k_proj` - Key projections
- `v_proj` - Value projections
- `o_proj` - Output projections
- `qkv_proj` - Combined QKV projection (in some models)
- `c_attn` - Attention projection in older GPT-style models

### 2. MLP/Feed-Forward Layers
- `mlp.gate_proj` - Gate projection
- `mlp.up_proj` - Up projection
- `mlp.down_proj` - Down projection
- `mlp.fc1` - First layer
- `mlp.fc2` - Second layer
- `w1`, `w2`, `w3` - Alternative naming

### 3. Embedding Layers
```yaml
peft:
  enabled: true
  target_modules:
    - "model.embed_tokens"   # Token embeddings
    - "lm_head"              # Language model head
```

### 4. Normalization Layers
Note that LoRA only wraps linear and embedding modules, so listing normalization layers in `target_modules` will typically fail; they generally need to be made fully trainable instead (e.g. via PEFT's `modules_to_save`, which this script currently pins to `None`).
```yaml
peft:
  enabled: true
  target_modules:
    - "input_layernorm"            # Input normalization
    - "post_attention_layernorm"   # Post-attention norm
    - "final_layernorm"            # Final normalization
```

### 5. MoE (Mixture of Experts) Layers
```yaml
peft:
  enabled: true
  target_modules:
    - "mlp.experts.*.w1"   # Expert layer 1
    - "mlp.experts.*.w2"   # Expert layer 2
    - "mlp.experts.*.w3"   # Expert layer 3
    - "mlp.gate"           # Expert routing gate
```

## Advanced Configuration Examples

### Train Multiple Layer Types
```yaml
peft:
  enabled: true
  target_modules:
    - "q_proj"
    - "k_proj"
    - "v_proj"
    - "o_proj"
    - "mlp.gate_proj"
    - "mlp.up_proj"
    - "mlp.down_proj"
    - "input_layernorm"
    - "post_attention_layernorm"
```

### Conservative Approach (Only MLPs)
```yaml
peft:
  enabled: true
  target_modules:
    - "mlp.gate_proj"
    - "mlp.up_proj"
    - "mlp.down_proj"
```

### Comprehensive Approach (All Main Layers)
```yaml
peft:
  enabled: true
  target_modules:
    - "q_proj"
    - "k_proj"
    - "v_proj"
    - "o_proj"
    - "mlp.gate_proj"
    - "mlp.up_proj"
    - "mlp.down_proj"
    - "input_layernorm"
    - "post_attention_layernorm"
```

## How to Find Module Names for Your Model

### Method 1: Automatic Detection
Run the script once with `target_modules: "auto"`; it will log which modules it found:

```
Using auto-inferred target_modules: ['q_proj', 'k_proj', 'v_proj', 'o_proj']
```

### Method 2: Manual Inspection
Inspect your model's structure directly:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("/workspace/Models/YourModel")

# Print all module names
for name, module in model.named_modules():
    print(name)
```

### Method 3: Use PEFT's Built-in Function
The script includes an `_infer_target_modules()` helper that can identify the available modules.
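
The helper itself is not reproduced here, but the idea behind this kind of inference can be sketched as a suffix scan over the names from `named_modules()`. The candidate set and function below are illustrative, not the script's actual implementation:

```python
# Illustrative sketch of "auto" target-module inference: keep the trailing
# name component of modules that look like attention/MLP projections.
CANDIDATES = {
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj", "qkv_proj", "c_attn",
}

def infer_target_modules(module_names):
    found = set()
    for name in module_names:
        # "model.layers.0.self_attn.q_proj" -> "q_proj"
        suffix = name.rsplit(".", 1)[-1]
        if suffix in CANDIDATES:
            found.add(suffix)
    return sorted(found)

# In practice the names would come from: [n for n, _ in model.named_modules()]
names = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
    "model.layers.0.self_attn.o_proj",
]
print(infer_target_modules(names))  # -> ['k_proj', 'o_proj', 'q_proj', 'v_proj']
```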

## Considerations

### 1. Memory Usage
- **More modules = more memory**: Training additional layers requires more GPU memory
- **Monitor VRAM usage**: Use `nvidia-smi` to monitor memory consumption
- **Adjust batch size**: You may need to reduce `per_device_train_batch_size`

### 2. Training Time
- **More modules = longer training**: Each additional layer increases computation time
- **Learning rate adjustments**: You might need to reduce `learning_rate` when training more layers

### 3. Performance Trade-offs
- **Attention only**: Fast training, good for language understanding
- **MLP only**: Fast training, good for knowledge storage
- **Both attention + MLP**: Slower but potentially better performance
- **All layers**: Slowest but most comprehensive adaptation

### 4. Model Architecture Differences
Different model families use different module naming conventions:
- **LLaMA**: `mlp.gate_proj`, `mlp.up_proj`, `mlp.down_proj`
- **Qwen**: `mlp.gate_proj`, `mlp.up_proj`, `mlp.down_proj`
- **Gemma**: `mlp.gate_proj`, `mlp.up_proj`, `mlp.down_proj`
- **Mixtral**: `mlp.experts.*.w1`, etc.

## Best Practices

### 1. Start Conservative
Begin with just the attention layers, then gradually add more modules if needed:
```yaml
# Phase 1: Start here
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]

# Phase 2: Add MLPs
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "mlp.down_proj"]

# Phase 3: Add more if needed
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "mlp.gate_proj", "mlp.up_proj", "mlp.down_proj"]
```

### 2. Monitor Overfitting
- Use the evaluation split to monitor performance
- Reduce `learning_rate` if overfitting occurs
- Consider raising `lora_dropout` to reduce overfitting

### 3. Resource Management
- Start with a small LoRA rank (`r: 16`) if training many modules
- Increase `gradient_accumulation_steps` if you reduce the batch size
- Monitor GPU memory usage throughout training
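
The interaction between the last two bullets can be checked with simple arithmetic: the effective batch size the optimizer sees is the product of the per-device batch size, the accumulation steps, and the GPU count (the numbers below are illustrative):

```python
# Effective batch size = per-device batch * gradient accumulation steps * GPU count.
# Illustrative values for a single-GPU, memory-constrained run.
per_device_train_batch_size = 1
gradient_accumulation_steps = 16
num_gpus = 1

effective_batch = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
print(effective_batch)  # -> 16
```

If you halve the per-device batch size to save memory, doubling `gradient_accumulation_steps` keeps the effective batch size (and therefore the optimization behavior) roughly unchanged.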

### 4. Model-Specific Tuning
Different models may benefit from different module combinations:
- **Code models**: Focus on attention + MLP layers
- **Chat models**: Attention layers are most important
- **Reasoning models**: All layers might be beneficial

## Example: Training Custom Modules

### Complete Configuration Example
```yaml
model:
  repo_id: "/workspace/Models/Devstral-Small-2-24B-Instruct-2512"
  torch_dtype: "bfloat16"

peft:
  enabled: true
  r: 64
  lora_alpha: 128
  lora_dropout: 0.05
  bias: "none"
  target_modules:
    - "q_proj"
    - "k_proj"
    - "v_proj"
    - "o_proj"
    - "mlp.gate_proj"
    - "mlp.up_proj"
    - "mlp.down_proj"
    - "input_layernorm"

train:
  num_train_epochs: 2
  learning_rate: 1e-5   # Reduced due to more modules
  per_device_train_batch_size: 1
  gradient_accumulation_steps: 16
```

This configuration will train:
- All attention projection layers
- All MLP projection layers
- The input normalization layers

It also uses a reduced learning rate to accommodate the additional trainable parameters.

Remember to always test with a small number of steps first to ensure your configuration works correctly before running full training.
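
One way to run that smoke test, using `train` keys this script already reads, is to temporarily cap the run (the specific values here are just a suggestion):

```yaml
train:
  max_steps: 20        # a value > 0 overrides num_train_epochs
  logging_steps: 1
  save_steps: 10
  eval_steps: 10
```

Once logging, checkpointing, and evaluation all look sane, remove `max_steps` (or set it back to 0) and launch the full run.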