stmasson committed on
Commit
ab24506
·
verified ·
1 Parent(s): ab21097

Upload scripts/train_alizee_v2_stage1_sft.py with huggingface_hub

Browse files
scripts/train_alizee_v2_stage1_sft.py CHANGED
@@ -50,7 +50,9 @@ LEARNING_RATE = 5e-5
50
  EFFECTIVE_BATCH_SIZE = 256
51
  PER_DEVICE_BATCH = 1
52
  GRADIENT_ACCUMULATION = EFFECTIVE_BATCH_SIZE // PER_DEVICE_BATCH
53
- MAX_SEQ_LENGTH = 32768
 
 
54
  NUM_EPOCHS = 2
55
  WARMUP_RATIO = 0.05
56
 
@@ -160,8 +162,8 @@ for i, sample in enumerate(coding_ds):
160
  coding_ds_final = Dataset.from_list(coding_samples)
161
  print(f" Collected {len(coding_ds_final)} coding samples")
162
 
163
- # Approximate character limit for 32K tokens (assuming ~4 chars per token average)
164
- MAX_CHARS = MAX_SEQ_LENGTH * 3 # ~98K chars, slightly conservative
165
 
166
  # Format functions for different data sources
167
  def format_reasoning_sample(example):
@@ -172,10 +174,10 @@ def format_reasoning_sample(example):
172
  - output: reasoning trace / expected output explanation
173
  - solution: the actual code
174
  """
175
- # Truncate very long fields to prevent memory issues
176
- input_text = str(example.get('input', ''))[:30000]
177
- output_text = str(example.get('output', ''))[:50000]
178
- solution_text = str(example.get('solution', ''))[:15000]
179
 
180
  # Create a reasoning-enhanced prompt
181
  messages = [
@@ -193,8 +195,8 @@ def format_reasoning_sample(example):
193
 
194
  def format_coding_sample(example):
195
  """Format starcoderdata sample for capability preservation."""
196
- # Extract code content and truncate to prevent memory issues
197
- content = str(example.get("content", ""))[:40000]
198
 
199
  # Create a simple code completion task
200
  lines = content.split("\n")
 
50
  EFFECTIVE_BATCH_SIZE = 256
51
  PER_DEVICE_BATCH = 1
52
  GRADIENT_ACCUMULATION = EFFECTIVE_BATCH_SIZE // PER_DEVICE_BATCH
53
+ # Reduced from 32K to 8K to avoid disk storage overflow during tokenization
54
+ # (32K × 860K samples creates ~100GB+ tokenized dataset that exceeds pod storage)
55
+ MAX_SEQ_LENGTH = 8192
56
  NUM_EPOCHS = 2
57
  WARMUP_RATIO = 0.05
58
 
 
162
  coding_ds_final = Dataset.from_list(coding_samples)
163
  print(f" Collected {len(coding_ds_final)} coding samples")
164
 
165
+ # Approximate character limit for 8K tokens (assuming ~4 chars per token average)
166
+ MAX_CHARS = MAX_SEQ_LENGTH * 3 # ~24K chars for 8K tokens
167
 
168
  # Format functions for different data sources
169
  def format_reasoning_sample(example):
 
174
  - output: reasoning trace / expected output explanation
175
  - solution: the actual code
176
  """
177
+ # Truncate fields to fit within 8K tokens (~24K chars total)
178
+ input_text = str(example.get('input', ''))[:8000]
179
+ output_text = str(example.get('output', ''))[:12000]
180
+ solution_text = str(example.get('solution', ''))[:4000]
181
 
182
  # Create a reasoning-enhanced prompt
183
  messages = [
 
195
 
196
  def format_coding_sample(example):
197
  """Format starcoderdata sample for capability preservation."""
198
+ # Extract code content and truncate to fit 8K context (~20K chars)
199
+ content = str(example.get("content", ""))[:20000]
200
 
201
  # Create a simple code completion task
202
  lines = content.split("\n")