Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- run_transformers_training.py +142 -99
run_transformers_training.py
CHANGED
|
@@ -151,13 +151,55 @@ def load_model_and_tokenizer(config):
|
|
| 151 |
use_flash_attention = False
|
| 152 |
logger.warning("Flash attention not available, falling back to standard attention")
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
# Apply Unsloth's training optimizations with config parameters
|
| 163 |
unsloth_config = config.get("unsloth", {})
|
|
@@ -332,14 +374,16 @@ class SimpleDataCollator:
|
|
| 332 |
self.dataset_config = dataset_config
|
| 333 |
self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
|
| 334 |
self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
| 335 |
-
self.prompt_counter = 0
|
| 336 |
self.paper_counters = {}
|
| 337 |
self.max_seq_length = dataset_config.get("dataset", {}).get("processing", {}).get("max_seq_length", 2048)
|
| 338 |
-
self.include_metadata =
|
| 339 |
-
self.include_chunk = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("include_chunk_number", True)
|
| 340 |
-
self.metadata_format = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("metadata_format", "Paper ID: {paper_id} | Chunk: {chunk_number}")
|
| 341 |
self.roles = dataset_config.get("data_formatting", {}).get("roles", {})
|
| 342 |
logger.info(f"SimpleDataCollator initialized - using phi-4 chat format with max_seq_length={self.max_seq_length}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
|
| 344 |
def normalize_conversation(self, conversation):
|
| 345 |
"""Normalize conversation format to ensure consistent structure."""
|
|
@@ -353,6 +397,23 @@ class SimpleDataCollator:
|
|
| 353 |
else:
|
| 354 |
return []
|
| 355 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
for turn in conversation:
|
| 357 |
# Skip empty or None entries
|
| 358 |
if not turn:
|
|
@@ -406,23 +467,6 @@ class SimpleDataCollator:
|
|
| 406 |
self.stats["skipped"] += 1
|
| 407 |
continue
|
| 408 |
|
| 409 |
-
# Track paper chunks
|
| 410 |
-
if paper_id not in self.paper_counters:
|
| 411 |
-
self.paper_counters[paper_id] = 0
|
| 412 |
-
self.paper_counters[paper_id] += 1
|
| 413 |
-
|
| 414 |
-
# Add metadata if configured
|
| 415 |
-
if self.include_metadata:
|
| 416 |
-
# Format metadata according to configured format
|
| 417 |
-
metadata_content = self.metadata_format.format(
|
| 418 |
-
paper_id=paper_id,
|
| 419 |
-
chunk_number=self.paper_counters[paper_id]
|
| 420 |
-
)
|
| 421 |
-
|
| 422 |
-
# Add as system message if not already in conversation
|
| 423 |
-
if not any(msg.get("role") == "system" for msg in conversation):
|
| 424 |
-
conversation = [{"role": "system", "content": metadata_content}] + conversation
|
| 425 |
-
|
| 426 |
# Format conversation with research introduction and chunk info
|
| 427 |
formatted_content = format_phi_chat(conversation, self.dataset_config)
|
| 428 |
|
|
@@ -433,6 +477,7 @@ class SimpleDataCollator:
|
|
| 433 |
truncation=True,
|
| 434 |
max_length=self.max_seq_length,
|
| 435 |
return_tensors=None,
|
|
|
|
| 436 |
)
|
| 437 |
|
| 438 |
if len(inputs["input_ids"]) > 0:
|
|
@@ -450,7 +495,7 @@ class SimpleDataCollator:
|
|
| 450 |
log_samples = self.dataset_config.get("validation", {}).get("log_samples", 3)
|
| 451 |
if self.stats["processed"] <= log_samples:
|
| 452 |
logger.info(f"Example {self.stats['processed']} format:")
|
| 453 |
-
logger.info(f"Paper ID: {paper_id}
|
| 454 |
logger.info(f"Token count: {len(inputs['input_ids'])}")
|
| 455 |
logger.info(f"Content preview:\n{formatted_content[:500]}...")
|
| 456 |
logger.info(f"Conversation structure: {conversation[:2]}...")
|
|
@@ -464,6 +509,7 @@ class SimpleDataCollator:
|
|
| 464 |
|
| 465 |
if not batch["input_ids"]:
|
| 466 |
logger.warning("Empty batch, returning dummy tensors")
|
|
|
|
| 467 |
return {
|
| 468 |
"input_ids": torch.zeros((1, 1), dtype=torch.long),
|
| 469 |
"attention_mask": torch.zeros((1, 1), dtype=torch.long),
|
|
@@ -480,8 +526,8 @@ class SimpleDataCollator:
|
|
| 480 |
batch["attention_mask"][i].extend([0] * padding_length)
|
| 481 |
batch["labels"][i].extend([-100] * padding_length)
|
| 482 |
|
| 483 |
-
# Convert to tensors
|
| 484 |
-
batch = {k: torch.tensor(v) for k, v in batch.items()}
|
| 485 |
|
| 486 |
# Log stats periodically
|
| 487 |
log_interval = self.dataset_config.get("validation", {}).get("log_interval", 100)
|
|
@@ -601,6 +647,18 @@ def main():
|
|
| 601 |
set_seed(seed)
|
| 602 |
logger.info(f"Set random seed to {seed}")
|
| 603 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 604 |
try:
|
| 605 |
model, tokenizer = load_model_and_tokenizer(model_config)
|
| 606 |
logger.info("Model and tokenizer loaded successfully")
|
|
@@ -612,7 +670,7 @@ def main():
|
|
| 612 |
except Exception as e:
|
| 613 |
logger.error(f"Error loading dataset: {e}")
|
| 614 |
return 1
|
| 615 |
-
|
| 616 |
# Create data collator
|
| 617 |
data_collator = SimpleDataCollator(tokenizer, dataset_config)
|
| 618 |
|
|
@@ -627,6 +685,13 @@ def main():
|
|
| 627 |
if (state.global_step % 50 == 0) or (current_time - self.last_log_time > 300):
|
| 628 |
logger.info(f"Step {state.global_step}: Loss {state.log_history[-1]['loss'] if state.log_history else 'N/A'}")
|
| 629 |
self.last_log_time = current_time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 630 |
|
| 631 |
# Verify precision settings - ensure only one of bf16/fp16 is set, with bf16 taking precedence
|
| 632 |
use_bf16 = model_config.get("bf16", False) or model_config.get("torch_dtype", "") == "bfloat16"
|
|
@@ -658,95 +723,73 @@ def main():
|
|
| 658 |
report_to="tensorboard",
|
| 659 |
remove_unused_columns=False, # Keep all columns
|
| 660 |
gradient_checkpointing=model_config.get("training", {}).get("gradient_checkpointing", True),
|
| 661 |
-
dataloader_pin_memory=
|
| 662 |
optim=model_config.get("training", {}).get("optim", "adamw_torch"),
|
| 663 |
ddp_find_unused_parameters=False, # Improve distributed training efficiency
|
| 664 |
dataloader_drop_last=False, # Process all examples
|
| 665 |
dataloader_num_workers=4, # Sequential data loading
|
|
|
|
| 666 |
)
|
| 667 |
|
| 668 |
-
#
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 683 |
|
|
|
|
|
|
|
| 684 |
trainer = Trainer(
|
| 685 |
model=model,
|
| 686 |
args=training_args,
|
| 687 |
-
|
|
|
|
| 688 |
data_collator=data_collator,
|
| 689 |
callbacks=[LoggingCallback()]
|
| 690 |
)
|
| 691 |
|
| 692 |
-
# Override the default data loader to disable shuffling
|
| 693 |
-
# This is necessary because TrainingArguments doesn't have a direct shuffle parameter
|
| 694 |
-
def get_train_dataloader_no_shuffle():
|
| 695 |
-
"""Create a train DataLoader with shuffling disabled."""
|
| 696 |
-
logger.info("Creating train dataloader with sequential sampler (no shuffling)")
|
| 697 |
-
|
| 698 |
-
# Create a sequential sampler to ensure dataset is processed in order
|
| 699 |
-
train_sampler = torch.utils.data.SequentialSampler(dataset)
|
| 700 |
-
|
| 701 |
-
return torch.utils.data.DataLoader(
|
| 702 |
-
dataset,
|
| 703 |
-
batch_size=training_args.per_device_train_batch_size,
|
| 704 |
-
sampler=train_sampler, # Use sequential sampler instead of shuffle parameter
|
| 705 |
-
collate_fn=data_collator,
|
| 706 |
-
drop_last=False,
|
| 707 |
-
num_workers=0,
|
| 708 |
-
pin_memory=False
|
| 709 |
-
)
|
| 710 |
-
|
| 711 |
-
# Replace the default data loader with our non-shuffling version
|
| 712 |
-
trainer.get_train_dataloader = get_train_dataloader_no_shuffle
|
| 713 |
-
|
| 714 |
# Start training
|
| 715 |
-
logger.info("Starting training")
|
| 716 |
-
logger.info(f"Processing with batch size = {training_args.per_device_train_batch_size}, each entry processed independently")
|
| 717 |
-
|
| 718 |
-
# Create a lock file to indicate training is in progress
|
| 719 |
-
lock_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "TRAINING_IN_PROGRESS.lock")
|
| 720 |
-
with open(lock_file, "w") as f:
|
| 721 |
-
f.write(f"Training started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
| 722 |
-
f.write(f"Expected completion: After {training_args.num_train_epochs} epochs\n")
|
| 723 |
-
f.write("DO NOT UPDATE OR RESTART THIS SPACE UNTIL TRAINING COMPLETES\n")
|
| 724 |
-
logger.info(f"Created lock file: {lock_file}")
|
| 725 |
-
|
| 726 |
try:
|
| 727 |
-
trainer.train(
|
| 728 |
logger.info("Training completed successfully")
|
| 729 |
|
| 730 |
-
# Save model
|
| 731 |
-
|
| 732 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 733 |
trainer.push_to_hub()
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
logger.info(f"Saving model to {model_config.get('output_dir', './results')}")
|
| 737 |
-
trainer.save_model()
|
| 738 |
-
logger.info("Model saved successfully")
|
| 739 |
except Exception as e:
|
| 740 |
logger.error(f"Training failed with error: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 741 |
raise
|
| 742 |
-
|
| 743 |
-
# Remove the lock file when training completes or fails
|
| 744 |
-
if os.path.exists(lock_file):
|
| 745 |
-
os.remove(lock_file)
|
| 746 |
-
logger.info(f"Removed lock file: {lock_file}")
|
| 747 |
-
|
| 748 |
-
return 0
|
| 749 |
-
|
| 750 |
except Exception as e:
|
| 751 |
logger.error(f"Error in main training loop: {str(e)}")
|
| 752 |
return 1
|
|
|
|
| 151 |
use_flash_attention = False
|
| 152 |
logger.warning("Flash attention not available, falling back to standard attention")
|
| 153 |
|
| 154 |
+
# First detect if we have a GPU
|
| 155 |
+
if torch.cuda.is_available():
|
| 156 |
+
logger.info(f"CUDA available, found {torch.cuda.device_count()} GPU(s)")
|
| 157 |
+
device_map = "auto"
|
| 158 |
+
else:
|
| 159 |
+
logger.warning("No CUDA available, falling back to CPU")
|
| 160 |
+
device_map = {"": "cpu"} # Force CPU placement
|
| 161 |
+
|
| 162 |
+
# Set default dtype for better numerics
|
| 163 |
+
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
|
| 164 |
+
# Use bfloat16 for Ampere or newer
|
| 165 |
+
dtype = torch.bfloat16
|
| 166 |
+
logger.info("Using bfloat16 precision (Ampere+ GPU)")
|
| 167 |
+
elif torch.cuda.is_available():
|
| 168 |
+
# Use float16 for older GPUs
|
| 169 |
+
dtype = torch.float16
|
| 170 |
+
logger.info("Using float16 precision (pre-Ampere GPU)")
|
| 171 |
+
else:
|
| 172 |
+
# CPU, use default dtype
|
| 173 |
+
dtype = None
|
| 174 |
+
logger.info("Using default precision (CPU)")
|
| 175 |
+
|
| 176 |
+
# Load model with proper error handling for out-of-memory
|
| 177 |
+
try:
|
| 178 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 179 |
+
model_name=model_name,
|
| 180 |
+
max_seq_length=config.get("max_seq_length", 2048) or config.get("tokenizer", {}).get("max_seq_length", 2048),
|
| 181 |
+
dtype=dtype,
|
| 182 |
+
device_map=device_map,
|
| 183 |
+
# Don't explicitly use flash attention config here, let Unsloth handle it
|
| 184 |
+
)
|
| 185 |
+
except RuntimeError as e:
|
| 186 |
+
if "CUDA out of memory" in str(e):
|
| 187 |
+
logger.error("Out of GPU memory. Consider using a smaller batch size or gradient accumulation steps.")
|
| 188 |
+
raise
|
| 189 |
+
else:
|
| 190 |
+
# Try again with CPU placement to see if it's a memory issue
|
| 191 |
+
logger.warning(f"Error loading model on default device: {str(e)}")
|
| 192 |
+
logger.warning("Attempting to load with device_map='cpu' and no specific dtype")
|
| 193 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 194 |
+
model_name=model_name,
|
| 195 |
+
max_seq_length=config.get("max_seq_length", 2048) or config.get("tokenizer", {}).get("max_seq_length", 2048),
|
| 196 |
+
dtype=None,
|
| 197 |
+
device_map={"": "cpu"},
|
| 198 |
+
)
|
| 199 |
+
logger.warning("Model loaded on CPU. Training will be very slow.")
|
| 200 |
+
|
| 201 |
+
# Ensure model and optimizer init is on the same device
|
| 202 |
+
logger.info(f"Model device map: {model.hf_device_map if hasattr(model, 'hf_device_map') else 'Not available'}")
|
| 203 |
|
| 204 |
# Apply Unsloth's training optimizations with config parameters
|
| 205 |
unsloth_config = config.get("unsloth", {})
|
|
|
|
| 374 |
self.dataset_config = dataset_config
|
| 375 |
self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
|
| 376 |
self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
|
|
|
| 377 |
self.paper_counters = {}
|
| 378 |
self.max_seq_length = dataset_config.get("dataset", {}).get("processing", {}).get("max_seq_length", 2048)
|
| 379 |
+
self.include_metadata = False # Disable automatic metadata inclusion as it's already in content
|
|
|
|
|
|
|
| 380 |
self.roles = dataset_config.get("data_formatting", {}).get("roles", {})
|
| 381 |
logger.info(f"SimpleDataCollator initialized - using phi-4 chat format with max_seq_length={self.max_seq_length}")
|
| 382 |
+
logger.info("Metadata handling disabled - using metadata from content field")
|
| 383 |
+
|
| 384 |
+
# Check if we're on GPU
|
| 385 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 386 |
+
logger.info(f"SimpleDataCollator using device: {self.device}")
|
| 387 |
|
| 388 |
def normalize_conversation(self, conversation):
|
| 389 |
"""Normalize conversation format to ensure consistent structure."""
|
|
|
|
| 397 |
else:
|
| 398 |
return []
|
| 399 |
|
| 400 |
+
# Get introductory message if present (should be first and without chunk number)
|
| 401 |
+
intro_msg = None
|
| 402 |
+
for i, turn in enumerate(conversation):
|
| 403 |
+
if isinstance(turn, dict) and turn.get('content') and "[RESEARCH INTRODUCTION]" in turn.get('content', ''):
|
| 404 |
+
intro_msg = turn
|
| 405 |
+
break
|
| 406 |
+
|
| 407 |
+
# Process introduction message first if found
|
| 408 |
+
if intro_msg:
|
| 409 |
+
normalized.append({
|
| 410 |
+
"role": "system",
|
| 411 |
+
"content": intro_msg.get('content', '')
|
| 412 |
+
})
|
| 413 |
+
# Remove intro from further processing
|
| 414 |
+
conversation = [t for t in conversation if t != intro_msg]
|
| 415 |
+
|
| 416 |
+
# Process remaining messages
|
| 417 |
for turn in conversation:
|
| 418 |
# Skip empty or None entries
|
| 419 |
if not turn:
|
|
|
|
| 467 |
self.stats["skipped"] += 1
|
| 468 |
continue
|
| 469 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 470 |
# Format conversation with research introduction and chunk info
|
| 471 |
formatted_content = format_phi_chat(conversation, self.dataset_config)
|
| 472 |
|
|
|
|
| 477 |
truncation=True,
|
| 478 |
max_length=self.max_seq_length,
|
| 479 |
return_tensors=None,
|
| 480 |
+
padding=False, # Don't pad here, we'll pad the batch later
|
| 481 |
)
|
| 482 |
|
| 483 |
if len(inputs["input_ids"]) > 0:
|
|
|
|
| 495 |
log_samples = self.dataset_config.get("validation", {}).get("log_samples", 3)
|
| 496 |
if self.stats["processed"] <= log_samples:
|
| 497 |
logger.info(f"Example {self.stats['processed']} format:")
|
| 498 |
+
logger.info(f"Paper ID: {paper_id}")
|
| 499 |
logger.info(f"Token count: {len(inputs['input_ids'])}")
|
| 500 |
logger.info(f"Content preview:\n{formatted_content[:500]}...")
|
| 501 |
logger.info(f"Conversation structure: {conversation[:2]}...")
|
|
|
|
| 509 |
|
| 510 |
if not batch["input_ids"]:
|
| 511 |
logger.warning("Empty batch, returning dummy tensors")
|
| 512 |
+
# Return tensors on the right device
|
| 513 |
return {
|
| 514 |
"input_ids": torch.zeros((1, 1), dtype=torch.long),
|
| 515 |
"attention_mask": torch.zeros((1, 1), dtype=torch.long),
|
|
|
|
| 526 |
batch["attention_mask"][i].extend([0] * padding_length)
|
| 527 |
batch["labels"][i].extend([-100] * padding_length)
|
| 528 |
|
| 529 |
+
# Convert to tensors on CPU first
|
| 530 |
+
batch = {k: torch.tensor(v, dtype=torch.long) for k, v in batch.items()}
|
| 531 |
|
| 532 |
# Log stats periodically
|
| 533 |
log_interval = self.dataset_config.get("validation", {}).get("log_interval", 100)
|
|
|
|
| 647 |
set_seed(seed)
|
| 648 |
logger.info(f"Set random seed to {seed}")
|
| 649 |
|
| 650 |
+
# Check CUDA and set environment variables for better memory management
|
| 651 |
+
if torch.cuda.is_available():
|
| 652 |
+
# Empty CUDA cache
|
| 653 |
+
torch.cuda.empty_cache()
|
| 654 |
+
# Set memory management env vars (optional)
|
| 655 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
|
| 656 |
+
# Log memory information
|
| 657 |
+
for i in range(torch.cuda.device_count()):
|
| 658 |
+
logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
|
| 659 |
+
logger.info(f"Memory Allocated: {torch.cuda.memory_allocated(i) / 1024**2:.2f} MB")
|
| 660 |
+
logger.info(f"Memory Reserved: {torch.cuda.memory_reserved(i) / 1024**2:.2f} MB")
|
| 661 |
+
|
| 662 |
try:
|
| 663 |
model, tokenizer = load_model_and_tokenizer(model_config)
|
| 664 |
logger.info("Model and tokenizer loaded successfully")
|
|
|
|
| 670 |
except Exception as e:
|
| 671 |
logger.error(f"Error loading dataset: {e}")
|
| 672 |
return 1
|
| 673 |
+
|
| 674 |
# Create data collator
|
| 675 |
data_collator = SimpleDataCollator(tokenizer, dataset_config)
|
| 676 |
|
|
|
|
| 685 |
if (state.global_step % 50 == 0) or (current_time - self.last_log_time > 300):
|
| 686 |
logger.info(f"Step {state.global_step}: Loss {state.log_history[-1]['loss'] if state.log_history else 'N/A'}")
|
| 687 |
self.last_log_time = current_time
|
| 688 |
+
|
| 689 |
+
def on_train_begin(self, args, state, control, **kwargs):
|
| 690 |
+
logger.info("Training is starting...")
|
| 691 |
+
# Log memory information
|
| 692 |
+
if torch.cuda.is_available():
|
| 693 |
+
for i in range(torch.cuda.device_count()):
|
| 694 |
+
logger.info(f"GPU {i} Memory Allocated: {torch.cuda.memory_allocated(i) / 1024**2:.2f} MB")
|
| 695 |
|
| 696 |
# Verify precision settings - ensure only one of bf16/fp16 is set, with bf16 taking precedence
|
| 697 |
use_bf16 = model_config.get("bf16", False) or model_config.get("torch_dtype", "") == "bfloat16"
|
|
|
|
| 723 |
report_to="tensorboard",
|
| 724 |
remove_unused_columns=False, # Keep all columns
|
| 725 |
gradient_checkpointing=model_config.get("training", {}).get("gradient_checkpointing", True),
|
| 726 |
+
dataloader_pin_memory=True, # Keep data in pinned memory for faster transfer
|
| 727 |
optim=model_config.get("training", {}).get("optim", "adamw_torch"),
|
| 728 |
ddp_find_unused_parameters=False, # Improve distributed training efficiency
|
| 729 |
dataloader_drop_last=False, # Process all examples
|
| 730 |
dataloader_num_workers=4, # Sequential data loading
|
| 731 |
+
no_cuda=False if torch.cuda.is_available() else True, # Use CUDA if available
|
| 732 |
)
|
| 733 |
|
| 734 |
+
# Custom dataloader to ensure no shuffling of dataset
|
| 735 |
+
# This preserves the order of chunks in papers
|
| 736 |
+
def get_train_dataloader_no_shuffle():
|
| 737 |
+
logger.info("Creating data loader with sequential sampler to maintain paper order")
|
| 738 |
+
if getattr(training_args, "no_cuda", False):
|
| 739 |
+
batch_size = training_args.per_device_train_batch_size
|
| 740 |
+
else:
|
| 741 |
+
batch_size = max(training_args.per_device_train_batch_size * torch.cuda.device_count(), 1)
|
| 742 |
+
|
| 743 |
+
# Use sequential sampler to preserve order
|
| 744 |
+
sequential_sampler = torch.utils.data.SequentialSampler(dataset["train"])
|
| 745 |
+
logger.info(f"Using sequential sampler for batch size {batch_size}")
|
| 746 |
+
|
| 747 |
+
return torch.utils.data.DataLoader(
|
| 748 |
+
dataset["train"],
|
| 749 |
+
batch_size=batch_size,
|
| 750 |
+
sampler=sequential_sampler,
|
| 751 |
+
collate_fn=data_collator,
|
| 752 |
+
drop_last=training_args.dataloader_drop_last,
|
| 753 |
+
num_workers=training_args.dataloader_num_workers,
|
| 754 |
+
pin_memory=training_args.dataloader_pin_memory,
|
| 755 |
+
)
|
| 756 |
|
| 757 |
+
# Set up trainer with custom dataloader
|
| 758 |
+
logger.info("Initializing Trainer")
|
| 759 |
trainer = Trainer(
|
| 760 |
model=model,
|
| 761 |
args=training_args,
|
| 762 |
+
get_train_dataloader=get_train_dataloader_no_shuffle,
|
| 763 |
+
tokenizer=tokenizer,
|
| 764 |
data_collator=data_collator,
|
| 765 |
callbacks=[LoggingCallback()]
|
| 766 |
)
|
| 767 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 768 |
# Start training
|
| 769 |
+
logger.info("Starting training process")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 770 |
try:
|
| 771 |
+
trainer.train()
|
| 772 |
logger.info("Training completed successfully")
|
| 773 |
|
| 774 |
+
# Save the final model
|
| 775 |
+
logger.info("Saving final model")
|
| 776 |
+
trainer.save_model()
|
| 777 |
+
|
| 778 |
+
# Push to hub if enabled
|
| 779 |
+
if model_config.get("huggingface_hub", {}).get("push_to_hub", False):
|
| 780 |
+
logger.info("Pushing model to Hugging Face Hub")
|
| 781 |
trainer.push_to_hub()
|
| 782 |
+
|
| 783 |
+
return 0
|
|
|
|
|
|
|
|
|
|
| 784 |
except Exception as e:
|
| 785 |
logger.error(f"Training failed with error: {str(e)}")
|
| 786 |
+
# Log CUDA memory info if available
|
| 787 |
+
if torch.cuda.is_available():
|
| 788 |
+
for i in range(torch.cuda.device_count()):
|
| 789 |
+
logger.info(f"GPU {i} Memory Allocated: {torch.cuda.memory_allocated(i) / 1024**2:.2f} MB")
|
| 790 |
+
logger.info(f"GPU {i} Memory Reserved: {torch.cuda.memory_reserved(i) / 1024**2:.2f} MB")
|
| 791 |
raise
|
| 792 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 793 |
except Exception as e:
|
| 794 |
logger.error(f"Error in main training loop: {str(e)}")
|
| 795 |
return 1
|