Spaces:
Paused
Paused
Upload folder using huggingface_hub
Browse files
- app.py +7 -4
- requirements.txt +1 -0
- run_transformers_training.py +54 -311
app.py
CHANGED
|
@@ -84,15 +84,17 @@ def display_config():
|
|
| 84 |
<ul>
|
| 85 |
<li><b>Model:</b> {model_name}</li>
|
| 86 |
<li><b>Learning Rate:</b> {training_config.get('learning_rate', '2e-5')}</li>
|
| 87 |
-
<li><b>Batch Size:</b> {
|
| 88 |
-
<li><b>
|
|
|
|
|
|
|
| 89 |
<li><b>Precision:</b> {'BF16' if transformers_config.get('bf16', True) else 'FP16' if transformers_config.get('fp16', False) else 'FP32'}</li>
|
| 90 |
<li><b>Max Sequence Length:</b> {transformers_config.get('tokenizer', {}).get('max_seq_length', 2048)}</li>
|
| 91 |
</ul>
|
| 92 |
|
| 93 |
<h3>Hardware</h3>
|
| 94 |
<ul>
|
| 95 |
-
<li><b>GPU:</b> {gpu_count}× {gpu_type} ({vram} GB)</li>
|
| 96 |
<li><b>Multi-GPU Strategy:</b> {hardware_config.get('training_optimizations', {}).get('multi_gpu_strategy', 'data_parallel')}</li>
|
| 97 |
<li><b>Memory Optimizations:</b> {'Gradient Checkpointing' if hardware_config.get('training_optimizations', {}).get('memory_optimizations', {}).get('use_gradient_checkpointing', True) else 'None'}</li>
|
| 98 |
</ul>
|
|
@@ -154,9 +156,10 @@ def create_interface():
|
|
| 154 |
gr.Markdown("## Training Information")
|
| 155 |
gr.Markdown("""
|
| 156 |
### Hardware:
|
| 157 |
-
- 4× NVIDIA L4 GPUs (24GB VRAM
|
| 158 |
- Training with BF16 precision
|
| 159 |
- Using Data Parallel for multi-GPU
|
|
|
|
| 160 |
|
| 161 |
### Notes:
|
| 162 |
- Training may take several hours depending on dataset size
|
|
|
|
| 84 |
<ul>
|
| 85 |
<li><b>Model:</b> {model_name}</li>
|
| 86 |
<li><b>Learning Rate:</b> {training_config.get('learning_rate', '2e-5')}</li>
|
| 87 |
+
<li><b>Per-Device Batch Size:</b> {batch_size}</li>
|
| 88 |
+
<li><b>Gradient Accumulation:</b> {grad_accum}</li>
|
| 89 |
+
<li><b>Total Effective Batch Size:</b> {batch_size} × {gpu_count} × {grad_accum} = {batch_size * gpu_count * grad_accum}</li>
|
| 90 |
+
<li><b>Epochs:</b> {epochs}</li>
|
| 91 |
<li><b>Precision:</b> {'BF16' if transformers_config.get('bf16', True) else 'FP16' if transformers_config.get('fp16', False) else 'FP32'}</li>
|
| 92 |
<li><b>Max Sequence Length:</b> {transformers_config.get('tokenizer', {}).get('max_seq_length', 2048)}</li>
|
| 93 |
</ul>
|
| 94 |
|
| 95 |
<h3>Hardware</h3>
|
| 96 |
<ul>
|
| 97 |
+
<li><b>GPU:</b> {gpu_count}× {gpu_type} ({vram} GB VRAM per GPU, total: {vram * gpu_count} GB)</li>
|
| 98 |
<li><b>Multi-GPU Strategy:</b> {hardware_config.get('training_optimizations', {}).get('multi_gpu_strategy', 'data_parallel')}</li>
|
| 99 |
<li><b>Memory Optimizations:</b> {'Gradient Checkpointing' if hardware_config.get('training_optimizations', {}).get('memory_optimizations', {}).get('use_gradient_checkpointing', True) else 'None'}</li>
|
| 100 |
</ul>
|
|
|
|
| 156 |
gr.Markdown("## Training Information")
|
| 157 |
gr.Markdown("""
|
| 158 |
### Hardware:
|
| 159 |
+
- 4× NVIDIA L4 GPUs (24GB VRAM per GPU, 96GB total)
|
| 160 |
- Training with BF16 precision
|
| 161 |
- Using Data Parallel for multi-GPU
|
| 162 |
+
- Effective batch size: 16 (per device) × 4 (GPUs) × 3 (gradient accumulation) = 192
|
| 163 |
|
| 164 |
### Notes:
|
| 165 |
- Training may take several hours depending on dataset size
|
requirements.txt
CHANGED
|
@@ -3,6 +3,7 @@ bitsandbytes>=0.41.0
|
|
| 3 |
datasets>=2.15.0
|
| 4 |
einops>=0.7.0
|
| 5 |
filelock>=3.13.1
|
|
|
|
| 6 |
gradio>=5.17.0
|
| 7 |
huggingface-hub>=0.19.0
|
| 8 |
matplotlib>=3.7.0
|
|
|
|
| 3 |
datasets>=2.15.0
|
| 4 |
einops>=0.7.0
|
| 5 |
filelock>=3.13.1
|
| 6 |
+
flash-attn>=2.5.1
|
| 7 |
gradio>=5.17.0
|
| 8 |
huggingface-hub>=0.19.0
|
| 9 |
matplotlib>=3.7.0
|
run_transformers_training.py
CHANGED
|
@@ -309,315 +309,58 @@ def load_dataset_with_mapping(dataset_config):
|
|
| 309 |
if source != target: # Only rename if names are different
|
| 310 |
dataset = dataset.rename_column(source, target)
|
| 311 |
|
| 312 |
-
# Add prompt_number field that increments based on original order
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
return examples
|
| 333 |
-
|
| 334 |
-
# Add prompt numbers to the dataset based on original order
|
| 335 |
-
logger.info("Adding prompt numbers based on original dataset order (starting at 1)")
|
| 336 |
-
try:
|
| 337 |
-
dataset = dataset.map(
|
| 338 |
-
add_prompt_numbers,
|
| 339 |
-
with_indices=True,
|
| 340 |
-
desc="Adding prompt numbers"
|
| 341 |
-
)
|
| 342 |
-
logger.info(f"Successfully added prompt_number field to dataset")
|
| 343 |
-
except Exception as e:
|
| 344 |
-
logger.error(f"Error adding prompt numbers: {e}")
|
| 345 |
-
# Create a fallback implementation that doesn't rely on with_indices
|
| 346 |
-
logger.info("Attempting fallback method for adding prompt numbers")
|
| 347 |
-
|
| 348 |
-
def add_prompt_numbers_fallback(example, idx):
|
| 349 |
-
example["prompt_number"] = idx + 1
|
| 350 |
return example
|
| 351 |
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
for i, example in enumerate(dataset):
|
| 355 |
-
updated_examples.append(add_prompt_numbers_fallback(dict(example), i))
|
| 356 |
-
|
| 357 |
-
# Create a new dataset with the updated examples
|
| 358 |
-
from datasets import Dataset
|
| 359 |
-
dataset = Dataset.from_list(updated_examples)
|
| 360 |
-
logger.info(f"Successfully added prompt_number field using fallback method")
|
| 361 |
-
|
| 362 |
-
# Rename 'id' to 'article_id' if it exists
|
| 363 |
-
if 'id' in dataset.column_names and 'article_id' not in dataset.column_names:
|
| 364 |
-
logger.info("Renaming 'id' column to 'article_id'")
|
| 365 |
-
dataset = dataset.rename_column('id', 'article_id')
|
| 366 |
-
|
| 367 |
-
# Reorder columns to make prompt_number first if it exists
|
| 368 |
-
if 'prompt_number' in dataset.column_names:
|
| 369 |
-
logger.info("Reordering columns to place prompt_number first")
|
| 370 |
-
# Get current column names
|
| 371 |
-
current_columns = dataset.column_names
|
| 372 |
-
# Create new column order with prompt_number first
|
| 373 |
-
new_column_order = ['prompt_number'] + [col for col in current_columns if col != 'prompt_number']
|
| 374 |
-
# Reorder columns
|
| 375 |
-
dataset = dataset.select_columns(new_column_order)
|
| 376 |
-
|
| 377 |
-
# Verify all new column names for logging
|
| 378 |
-
logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
|
| 379 |
-
logger.info(f"Dataset columns: {dataset.column_names}")
|
| 380 |
-
|
| 381 |
-
# Verify dataset is not empty
|
| 382 |
-
if len(dataset) == 0:
|
| 383 |
-
logger.error("Dataset is empty! This will cause errors during training.")
|
| 384 |
-
raise ValueError("Empty dataset loaded")
|
| 385 |
-
|
| 386 |
-
# Check for required columns
|
| 387 |
-
required_columns = ['conversations']
|
| 388 |
-
for col in required_columns:
|
| 389 |
-
if col not in dataset.column_names:
|
| 390 |
-
logger.error(f"Required column '{col}' not found in dataset!")
|
| 391 |
-
raise ValueError(f"Required column '{col}' missing from dataset")
|
| 392 |
-
|
| 393 |
-
# Verify expected columns exist
|
| 394 |
-
expected_columns = {"article_id", "conversations", "prompt_number"}
|
| 395 |
-
missing_columns = expected_columns - set(dataset.column_names)
|
| 396 |
-
if missing_columns:
|
| 397 |
-
logger.warning(f"Some expected columns are missing: {missing_columns}")
|
| 398 |
-
|
| 399 |
-
# If "conversations" is missing but "text" exists, attempt conversion
|
| 400 |
-
if "conversations" not in dataset.column_names and "text" in dataset.column_names:
|
| 401 |
-
logger.info("Converting 'text' field to 'conversations' format")
|
| 402 |
-
|
| 403 |
-
def convert_text_to_conversations(example):
|
| 404 |
-
# Check if text is already a list of conversation turns
|
| 405 |
-
if isinstance(example.get("text"), list):
|
| 406 |
-
example["conversations"] = example["text"]
|
| 407 |
-
# Otherwise, create a simple conversation with the text as user message
|
| 408 |
-
else:
|
| 409 |
-
example["conversations"] = [
|
| 410 |
-
{"role": "user", "content": str(example.get("text", ""))}
|
| 411 |
-
]
|
| 412 |
-
return example
|
| 413 |
-
|
| 414 |
-
dataset = dataset.map(convert_text_to_conversations)
|
| 415 |
-
logger.info("Successfully converted 'text' to 'conversations'")
|
| 416 |
-
|
| 417 |
-
# Verify data ordering requirements
|
| 418 |
-
processing_config = dataset_config.get("dataset", {}).get("processing", {})
|
| 419 |
-
data_loading_config = dataset_config.get("data_loading", {})
|
| 420 |
-
|
| 421 |
-
# Check if sorting is required
|
| 422 |
-
sort_by_article_id = processing_config.get("sort_by_article_id", False)
|
| 423 |
-
if sort_by_article_id and 'article_id' in dataset.column_names:
|
| 424 |
-
logger.info("Sorting dataset by article_id as specified in config")
|
| 425 |
-
dataset = dataset.sort("article_id")
|
| 426 |
-
sorted_ids = [example['article_id'] for example in dataset.select(range(min(5, len(dataset))))]
|
| 427 |
-
logger.info(f"First few article_ids after sorting: {sorted_ids}")
|
| 428 |
-
|
| 429 |
-
# Flag consolidation - we only need one flag to control sequence preservation
|
| 430 |
-
# Default to True to ensure safety
|
| 431 |
-
preserve_sequence = processing_config.get("preserve_entry_sequence", True)
|
| 432 |
-
shuffle_disabled = not data_loading_config.get("shuffle", False)
|
| 433 |
-
|
| 434 |
-
if not preserve_sequence:
|
| 435 |
-
logger.warning("CRITICAL: preserve_entry_sequence is set to False. This is NOT RECOMMENDED!")
|
| 436 |
-
logger.warning("Data sequence integrity is essential for proper model training.")
|
| 437 |
-
|
| 438 |
-
if not shuffle_disabled:
|
| 439 |
-
logger.error("CRITICAL: shuffle is enabled in the dataset config!")
|
| 440 |
-
logger.error("This will RANDOMIZE your dataset and break sequential order.")
|
| 441 |
-
logger.error("Please set shuffle: false in your data_loading configuration.")
|
| 442 |
-
# Actually enforce sequence preservation by raising an error
|
| 443 |
-
raise ValueError("Dataset shuffling is enabled but preserve_entry_sequence is required. " +
|
| 444 |
-
"Please disable shuffling in your configuration.")
|
| 445 |
-
|
| 446 |
-
# Verify the IDs are in sequential order if they're numeric
|
| 447 |
-
try:
|
| 448 |
-
if len(dataset) > 1:
|
| 449 |
-
# Check prompt numbers are sequential
|
| 450 |
-
sample_indices = range(min(10, len(dataset)))
|
| 451 |
-
sample_prompt_numbers = []
|
| 452 |
-
|
| 453 |
-
# Defensive collection of prompt numbers
|
| 454 |
-
for i in sample_indices:
|
| 455 |
-
try:
|
| 456 |
-
if i < len(dataset) and "prompt_number" in dataset[i]:
|
| 457 |
-
sample_prompt_numbers.append(dataset[i]["prompt_number"])
|
| 458 |
-
else:
|
| 459 |
-
# If prompt_number doesn't exist, use index+1 as fallback
|
| 460 |
-
sample_prompt_numbers.append(i + 1)
|
| 461 |
-
logger.warning(f"Sample at index {i} missing prompt_number, using {i+1} as fallback")
|
| 462 |
-
except Exception as e:
|
| 463 |
-
logger.warning(f"Error accessing sample at index {i}: {e}")
|
| 464 |
-
sample_prompt_numbers.append(i + 1) # Use fallback
|
| 465 |
-
|
| 466 |
-
logger.info(f"Verifying sequential integrity with prompt numbers: {sample_prompt_numbers}")
|
| 467 |
-
|
| 468 |
-
# Check if prompt numbers are sequential (1-indexed)
|
| 469 |
-
if sample_prompt_numbers:
|
| 470 |
-
is_sequential = all(sample_prompt_numbers[i] == i + 1 for i in range(len(sample_prompt_numbers)))
|
| 471 |
-
if not is_sequential:
|
| 472 |
-
logger.warning("WARNING: Prompt numbers are not in sequential order!")
|
| 473 |
-
logger.warning("This may indicate that data sequence is not preserved.")
|
| 474 |
-
else:
|
| 475 |
-
logger.info("Prompt numbers verify that samples are in sequential order.")
|
| 476 |
-
else:
|
| 477 |
-
logger.warning("Could not verify sequential integrity: no prompt numbers collected")
|
| 478 |
-
|
| 479 |
-
# Also check original IDs as a backup if numeric
|
| 480 |
-
try:
|
| 481 |
-
sample_examples = []
|
| 482 |
-
for i in sample_indices:
|
| 483 |
-
try:
|
| 484 |
-
if i < len(dataset):
|
| 485 |
-
sample_examples.append(dataset[i])
|
| 486 |
-
except Exception as e:
|
| 487 |
-
logger.warning(f"Error accessing dataset at index {i}: {e}")
|
| 488 |
-
|
| 489 |
-
if sample_examples:
|
| 490 |
-
id_field = 'article_id' if 'article_id' in dataset.column_names else 'id'
|
| 491 |
-
if all(isinstance(example.get(id_field, ''), (int, str)) for example in sample_examples):
|
| 492 |
-
sample_ids = [example.get(id_field, '') for example in sample_examples if id_field in example]
|
| 493 |
-
|
| 494 |
-
if sample_ids and all(isinstance(id, int) or (isinstance(id, str) and id.isdigit()) for id in sample_ids):
|
| 495 |
-
numeric_ids = [int(id) if isinstance(id, str) else id for id in sample_ids]
|
| 496 |
-
if len(numeric_ids) > 1:
|
| 497 |
-
is_ordered = all(numeric_ids[i] <= numeric_ids[i+1] for i in range(len(numeric_ids)-1))
|
| 498 |
-
if not is_ordered:
|
| 499 |
-
logger.warning(f"WARNING: Sample {id_field}s are not in sequential order.")
|
| 500 |
-
else:
|
| 501 |
-
logger.info(f"Sample {id_field}s appear to be in sequential order.")
|
| 502 |
-
except Exception as e:
|
| 503 |
-
logger.warning(f"Error checking ID sequence: {e}")
|
| 504 |
-
except Exception as e:
|
| 505 |
-
logger.warning(f"Could not verify sequential integrity: {e}")
|
| 506 |
|
| 507 |
-
#
|
| 508 |
-
if "conversations" in dataset.column_names:
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
first_few_indices = range(min(5, len(dataset)))
|
| 512 |
-
sample_prompt_numbers = []
|
| 513 |
-
sample_article_ids = []
|
| 514 |
-
|
| 515 |
-
for i in first_few_indices:
|
| 516 |
-
try:
|
| 517 |
-
example = dataset[i]
|
| 518 |
-
if 'prompt_number' in example:
|
| 519 |
-
sample_prompt_numbers.append(example['prompt_number'])
|
| 520 |
-
if 'article_id' in example:
|
| 521 |
-
sample_article_ids.append(example['article_id'])
|
| 522 |
-
except Exception as e:
|
| 523 |
-
logger.warning(f"Error accessing sample at index {i}: {e}")
|
| 524 |
-
|
| 525 |
-
logger.info(f"First few samples - Prompt numbers: {sample_prompt_numbers}, Article IDs: {sample_article_ids}")
|
| 526 |
-
|
| 527 |
-
# Log conversation structure without full content
|
| 528 |
-
if len(dataset) > 0:
|
| 529 |
-
try:
|
| 530 |
-
sample_conv_structure = []
|
| 531 |
-
first_example = dataset[0]
|
| 532 |
-
|
| 533 |
-
if 'conversations' in first_example and first_example['conversations'] is not None:
|
| 534 |
-
for msg in first_example['conversations']:
|
| 535 |
-
if isinstance(msg, dict):
|
| 536 |
-
content = msg.get('content', '')
|
| 537 |
-
preview = content[:50] + "..." if len(content) > 50 else content
|
| 538 |
-
sample_conv_structure.append({
|
| 539 |
-
"role": msg.get('role', ''),
|
| 540 |
-
"content_length": len(content),
|
| 541 |
-
"preview": preview
|
| 542 |
-
})
|
| 543 |
-
logger.info(f"Conversation structure: {sample_conv_structure}")
|
| 544 |
-
except Exception as e:
|
| 545 |
-
logger.warning(f"Error logging conversation structure: {e}")
|
| 546 |
-
except Exception as e:
|
| 547 |
-
logger.warning(f"Error logging sample examples: {e}")
|
| 548 |
|
|
|
|
| 549 |
logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
|
| 550 |
logger.info(f"Dataset columns: {dataset.column_names}")
|
| 551 |
|
| 552 |
-
#
|
| 553 |
-
if len(dataset)
|
| 554 |
-
|
| 555 |
-
|
|
|
|
|
|
|
| 556 |
|
| 557 |
-
# Check for required columns
|
| 558 |
-
required_cols = ['conversations', 'prompt_number']
|
| 559 |
-
for col in required_cols:
|
| 560 |
-
if col not in dataset.column_names:
|
| 561 |
-
logger.error(f"Required column '{col}' missing from dataset. Cannot proceed with training.")
|
| 562 |
-
return dataset
|
| 563 |
-
|
| 564 |
-
# Validate at least one sample can be processed
|
| 565 |
-
try:
|
| 566 |
-
if len(dataset) > 0:
|
| 567 |
-
sample = dataset[0]
|
| 568 |
-
if 'conversations' not in sample or not sample['conversations']:
|
| 569 |
-
logger.error("First sample has no conversations! Data format may be incorrect.")
|
| 570 |
-
return dataset
|
| 571 |
-
if not isinstance(sample['conversations'], list):
|
| 572 |
-
logger.error(f"Conversations field should be a list but got {type(sample['conversations'])}")
|
| 573 |
-
return dataset
|
| 574 |
-
except Exception as e:
|
| 575 |
-
logger.error(f"Error validating first sample: {e}")
|
| 576 |
-
return dataset
|
| 577 |
-
|
| 578 |
-
# Add metadata if specified
|
| 579 |
-
metadata_config = dataset_config.get("data_formatting", {}).get("metadata_handling", {})
|
| 580 |
-
if metadata_config:
|
| 581 |
-
include_article_id = metadata_config.get("include_article_id", False)
|
| 582 |
-
include_prompt_number = metadata_config.get("include_prompt_number", False)
|
| 583 |
-
metadata_format = metadata_config.get("metadata_format", "")
|
| 584 |
-
|
| 585 |
-
if (include_article_id or include_prompt_number) and metadata_format:
|
| 586 |
-
logger.info("Adding metadata to conversations")
|
| 587 |
-
|
| 588 |
-
def add_metadata(example):
|
| 589 |
-
if not example.get("conversations"):
|
| 590 |
-
return example
|
| 591 |
-
|
| 592 |
-
# Prepare metadata
|
| 593 |
-
metadata = metadata_format
|
| 594 |
-
if include_article_id and "article_id" in example:
|
| 595 |
-
metadata = metadata.replace("{article_id}", str(example.get("article_id", "")))
|
| 596 |
-
if include_prompt_number and "prompt_number" in example:
|
| 597 |
-
metadata = metadata.replace("{prompt_number}", str(example.get("prompt_number", "")))
|
| 598 |
-
|
| 599 |
-
# Add system message with metadata if not empty
|
| 600 |
-
if metadata.strip():
|
| 601 |
-
if example["conversations"] and isinstance(example["conversations"], list):
|
| 602 |
-
# Check if first message is already a system message
|
| 603 |
-
if (isinstance(example["conversations"][0], dict) and
|
| 604 |
-
example["conversations"][0].get("role") == "system"):
|
| 605 |
-
# Append to existing system message
|
| 606 |
-
example["conversations"][0]["content"] += f"\n\nMetadata: {metadata}"
|
| 607 |
-
else:
|
| 608 |
-
# Add new system message at the beginning
|
| 609 |
-
example["conversations"].insert(0, {
|
| 610 |
-
"role": "system",
|
| 611 |
-
"content": f"Metadata: {metadata}"
|
| 612 |
-
})
|
| 613 |
-
|
| 614 |
-
return example
|
| 615 |
-
|
| 616 |
-
dataset = dataset.map(add_metadata)
|
| 617 |
-
logger.info("Metadata added to conversations")
|
| 618 |
-
|
| 619 |
return dataset
|
| 620 |
-
|
| 621 |
except Exception as e:
|
| 622 |
logger.error(f"Error loading dataset: {str(e)}")
|
| 623 |
raise
|
|
@@ -1112,6 +855,10 @@ def main():
|
|
| 1112 |
per_device_batch_size = transformers_config.get("training", {}).get("per_device_train_batch_size", 16)
|
| 1113 |
gradient_accumulation_steps = transformers_config.get("training", {}).get("gradient_accumulation_steps", 3)
|
| 1114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1115 |
# For multi-GPU setup, adjust for better balance
|
| 1116 |
if CUDA_AVAILABLE and NUM_GPUS > 1:
|
| 1117 |
log_info(f"Multi-GPU setup: Adjusting for {NUM_GPUS} GPUs")
|
|
@@ -1213,21 +960,17 @@ def main():
|
|
| 1213 |
"""Custom dataloader that preserves original dataset order"""
|
| 1214 |
log_info("Creating sequential dataloader to maintain original dataset order")
|
| 1215 |
|
| 1216 |
-
#
|
| 1217 |
-
|
| 1218 |
-
sequential_processing = data_loading_config.get("sequential_processing", True)
|
| 1219 |
-
shuffle_disabled = not data_loading_config.get("shuffle", False)
|
| 1220 |
|
| 1221 |
-
|
| 1222 |
-
|
| 1223 |
-
|
| 1224 |
-
# Force sequential processing regardless of flag
|
| 1225 |
|
| 1226 |
-
if
|
| 1227 |
-
log_info("CRITICAL ERROR: Shuffle is
|
| 1228 |
-
# Actually handle the error rather than just logging it
|
| 1229 |
raise ValueError("Dataset shuffling is enabled but sequential processing is required. " +
|
| 1230 |
-
|
| 1231 |
|
| 1232 |
# Calculate batch size based on device availability
|
| 1233 |
if getattr(training_args, "no_cuda", False):
|
|
|
|
| 309 |
if source != target: # Only rename if names are different
|
| 310 |
dataset = dataset.rename_column(source, target)
|
| 311 |
|
| 312 |
+
# Add prompt_number field that increments based on original order - simple approach
|
| 313 |
+
logger.info("Adding prompt_number based on original dataset order (starting at 1)")
|
| 314 |
+
|
| 315 |
+
# Simple approach 1: Add index as a column during dataset creation
|
| 316 |
+
# Create a list of dicts with indices
|
| 317 |
+
examples_with_idx = []
|
| 318 |
+
for i, example in enumerate(dataset):
|
| 319 |
+
example = dict(example) # Make a copy to avoid modifying the original
|
| 320 |
+
example['prompt_number'] = i + 1 # 1-indexed
|
| 321 |
+
examples_with_idx.append(example)
|
| 322 |
+
|
| 323 |
+
# Recreate dataset with prompt_number included
|
| 324 |
+
from datasets import Dataset
|
| 325 |
+
dataset = Dataset.from_list(examples_with_idx)
|
| 326 |
+
logger.info("Successfully added prompt_number to dataset")
|
| 327 |
+
|
| 328 |
+
# If conversations is missing but text exists, attempt conversion
|
| 329 |
+
if "conversations" not in dataset.column_names and "text" in dataset.column_names:
|
| 330 |
+
logger.info("Converting 'text' field to 'conversations' format")
|
| 331 |
|
| 332 |
+
def convert_text_to_conversations(example):
|
| 333 |
+
# Check if text is already a list of conversation turns
|
| 334 |
+
if isinstance(example.get("text"), list):
|
| 335 |
+
example["conversations"] = example["text"]
|
| 336 |
+
# Otherwise, create a simple conversation with the text as user message
|
| 337 |
+
else:
|
| 338 |
+
example["conversations"] = [
|
| 339 |
+
{"role": "user", "content": str(example.get("text", ""))}
|
| 340 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
return example
|
| 342 |
|
| 343 |
+
dataset = dataset.map(convert_text_to_conversations)
|
| 344 |
+
logger.info("Successfully converted 'text' to 'conversations'")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
|
| 346 |
+
# Verify we have the required columns
|
| 347 |
+
if "conversations" not in dataset.column_names:
|
| 348 |
+
logger.error("Required 'conversations' column not found in dataset!")
|
| 349 |
+
raise ValueError("Required 'conversations' column missing from dataset")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
+
# Log column names and a sample
|
| 352 |
logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
|
| 353 |
logger.info(f"Dataset columns: {dataset.column_names}")
|
| 354 |
|
| 355 |
+
# Log a sample for inspection
|
| 356 |
+
if len(dataset) > 0:
|
| 357 |
+
sample = dataset[0]
|
| 358 |
+
prompt_num = sample.get("prompt_number", "N/A")
|
| 359 |
+
article_id = sample.get("article_id", sample.get("id", "N/A"))
|
| 360 |
+
logger.info(f"First sample - Prompt number: {prompt_num}, ID: {article_id}")
|
| 361 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
return dataset
|
| 363 |
+
|
| 364 |
except Exception as e:
|
| 365 |
logger.error(f"Error loading dataset: {str(e)}")
|
| 366 |
raise
|
|
|
|
| 855 |
per_device_batch_size = transformers_config.get("training", {}).get("per_device_train_batch_size", 16)
|
| 856 |
gradient_accumulation_steps = transformers_config.get("training", {}).get("gradient_accumulation_steps", 3)
|
| 857 |
|
| 858 |
+
# Get multi-GPU strategy from hardware config (default to data_parallel)
|
| 859 |
+
multi_gpu_strategy = hardware_config.get("training_optimizations", {}).get("multi_gpu_strategy", "data_parallel")
|
| 860 |
+
logger.info(f"Multi-GPU strategy: {multi_gpu_strategy}")
|
| 861 |
+
|
| 862 |
# For multi-GPU setup, adjust for better balance
|
| 863 |
if CUDA_AVAILABLE and NUM_GPUS > 1:
|
| 864 |
log_info(f"Multi-GPU setup: Adjusting for {NUM_GPUS} GPUs")
|
|
|
|
| 960 |
"""Custom dataloader that preserves original dataset order"""
|
| 961 |
log_info("Creating sequential dataloader to maintain original dataset order")
|
| 962 |
|
| 963 |
+
# Create a simple sequential sampler
|
| 964 |
+
sequential_sampler = torch.utils.data.SequentialSampler(dataset)
|
|
|
|
|
|
|
| 965 |
|
| 966 |
+
# Verification of sequence preservation flags - simplified
|
| 967 |
+
data_loading_config = dataset_config.get("data_loading", {})
|
| 968 |
+
shuffle_enabled = data_loading_config.get("shuffle", False)
|
|
|
|
| 969 |
|
| 970 |
+
if shuffle_enabled:
|
| 971 |
+
log_info("CRITICAL ERROR: Shuffle is enabled! This will randomize data entry order!")
|
|
|
|
| 972 |
raise ValueError("Dataset shuffling is enabled but sequential processing is required. " +
|
| 973 |
+
"Please disable shuffling in your configuration.")
|
| 974 |
|
| 975 |
# Calculate batch size based on device availability
|
| 976 |
if getattr(training_args, "no_cuda", False):
|