Upload scripts/train_alizee_v2_stage1_sft.py with huggingface_hub
Browse files
scripts/train_alizee_v2_stage1_sft.py
CHANGED
|
@@ -50,7 +50,9 @@ LEARNING_RATE = 5e-5
|
|
| 50 |
EFFECTIVE_BATCH_SIZE = 256
|
| 51 |
PER_DEVICE_BATCH = 1
|
| 52 |
GRADIENT_ACCUMULATION = EFFECTIVE_BATCH_SIZE // PER_DEVICE_BATCH
|
| 53 |
-
|
|
|
|
|
|
|
| 54 |
NUM_EPOCHS = 2
|
| 55 |
WARMUP_RATIO = 0.05
|
| 56 |
|
|
@@ -160,8 +162,8 @@ for i, sample in enumerate(coding_ds):
|
|
| 160 |
coding_ds_final = Dataset.from_list(coding_samples)
|
| 161 |
print(f" Collected {len(coding_ds_final)} coding samples")
|
| 162 |
|
| 163 |
-
# Approximate character limit for
|
| 164 |
-
MAX_CHARS = MAX_SEQ_LENGTH * 3 # ~
|
| 165 |
|
| 166 |
# Format functions for different data sources
|
| 167 |
def format_reasoning_sample(example):
|
|
@@ -172,10 +174,10 @@ def format_reasoning_sample(example):
|
|
| 172 |
- output: reasoning trace / expected output explanation
|
| 173 |
- solution: the actual code
|
| 174 |
"""
|
| 175 |
-
# Truncate
|
| 176 |
-
input_text = str(example.get('input', ''))[:
|
| 177 |
-
output_text = str(example.get('output', ''))[:
|
| 178 |
-
solution_text = str(example.get('solution', ''))[:
|
| 179 |
|
| 180 |
# Create a reasoning-enhanced prompt
|
| 181 |
messages = [
|
|
@@ -193,8 +195,8 @@ def format_reasoning_sample(example):
|
|
| 193 |
|
| 194 |
def format_coding_sample(example):
|
| 195 |
"""Format starcoderdata sample for capability preservation."""
|
| 196 |
-
# Extract code content and truncate to
|
| 197 |
-
content = str(example.get("content", ""))[:
|
| 198 |
|
| 199 |
# Create a simple code completion task
|
| 200 |
lines = content.split("\n")
|
|
|
|
| 50 |
EFFECTIVE_BATCH_SIZE = 256
|
| 51 |
PER_DEVICE_BATCH = 1
|
| 52 |
GRADIENT_ACCUMULATION = EFFECTIVE_BATCH_SIZE // PER_DEVICE_BATCH
|
| 53 |
+
# Reduced from 32K to 8K to avoid disk storage overflow during tokenization
|
| 54 |
+
# (32K × 860K samples creates ~100GB+ tokenized dataset that exceeds pod storage)
|
| 55 |
+
MAX_SEQ_LENGTH = 8192
|
| 56 |
NUM_EPOCHS = 2
|
| 57 |
WARMUP_RATIO = 0.05
|
| 58 |
|
|
|
|
| 162 |
coding_ds_final = Dataset.from_list(coding_samples)
|
| 163 |
print(f" Collected {len(coding_ds_final)} coding samples")
|
| 164 |
|
| 165 |
+
# Approximate character limit for 8K tokens (assuming ~4 chars per token average)
|
| 166 |
+
MAX_CHARS = MAX_SEQ_LENGTH * 3 # ~24K chars for 8K tokens
|
| 167 |
|
| 168 |
# Format functions for different data sources
|
| 169 |
def format_reasoning_sample(example):
|
|
|
|
| 174 |
- output: reasoning trace / expected output explanation
|
| 175 |
- solution: the actual code
|
| 176 |
"""
|
| 177 |
+
# Truncate fields to fit within 8K tokens (~24K chars total)
|
| 178 |
+
input_text = str(example.get('input', ''))[:8000]
|
| 179 |
+
output_text = str(example.get('output', ''))[:12000]
|
| 180 |
+
solution_text = str(example.get('solution', ''))[:4000]
|
| 181 |
|
| 182 |
# Create a reasoning-enhanced prompt
|
| 183 |
messages = [
|
|
|
|
| 195 |
|
| 196 |
def format_coding_sample(example):
|
| 197 |
"""Format starcoderdata sample for capability preservation."""
|
| 198 |
+
# Extract code content and truncate to fit 8K context (~20K chars)
|
| 199 |
+
content = str(example.get("content", ""))[:20000]
|
| 200 |
|
| 201 |
# Create a simple code completion task
|
| 202 |
lines = content.split("\n")
|