Upload folder using huggingface_hub

run_transformers_training.py   CHANGED   (+113 -24)
@@ -337,6 +337,31 @@ def load_dataset_with_mapping(dataset_config):
         if len(dataset) == 0:
             raise ValueError(f"Dataset {dataset_name} (split {dataset_split}) is empty (contains 0 examples)")
 
+        # Verify conversations field specifically - this is critical for training
+        if "conversations" not in dataset.column_names:
+            raise ValueError(f"Dataset {dataset_name} missing required 'conversations' column")
+
+        # Check a sample of conversation entries to validate structure
+        logger.info("Validating conversation structure...")
+        for i in range(min(5, len(dataset))):
+            conv = dataset[i].get("conversations")
+            if conv is None:
+                logger.warning(f"Example {i} has None as 'conversations' value")
+            elif not isinstance(conv, list):
+                logger.warning(f"Example {i} has non-list 'conversations': {type(conv)}")
+            elif len(conv) == 0:
+                logger.warning(f"Example {i} has empty conversations list")
+            else:
+                # Look at the first conversation entry
+                first_entry = conv[0]
+                logger.info(f"Sample conversation: {str(first_entry)[:100]}...")
+
+                # Make sure content field exists
+                if isinstance(first_entry, dict) and "content" in first_entry:
+                    logger.info(f"Content field example: {str(first_entry['content'])[:50]}...")
+                else:
+                    logger.warning(f"Example {i} missing 'content' key in conversation")
+
     except Exception as dataset_error:
         logger.error(f"Failed to load dataset {dataset_name}: {str(dataset_error)}")
         logger.error("Make sure the dataset exists and you have proper access permissions")

@@ -478,32 +503,59 @@ class SimpleDataCollator:
         for example in features:
             try:
                 # Get ID
-                paper_id = example.get("id", "")
+                paper_id = example.get("article_id", example.get("id", ""))
 
+                # Get conversations
+                raw_conversations = example.get("conversations", [])
+                if not raw_conversations:
+                    logger.warning(f"Empty conversations for example {paper_id}")
                     self.stats["skipped"] += 1
                     continue
 
+                # Extract only the 'content' field from each conversation item
+                # This simplifies the structure and avoids potential NoneType errors
                 try:
+                    # Convert conversations to the simple format with only content
+                    simplified_conversations = []
+                    for item in raw_conversations:
+                        if isinstance(item, dict) and "content" in item:
+                            # Keep only the content field
+                            content = item["content"]
+                            simplified_conversations.append({"role": "user", "content": content})
+                        elif isinstance(item, str):
+                            # If it's just a string, treat it as content
+                            simplified_conversations.append({"role": "user", "content": item})
+                        else:
+                            logger.warning(f"Skipping invalid conversation item: {item}")
+
+                    # Skip if no valid conversations after filtering
+                    if not simplified_conversations:
+                        logger.warning(f"No valid conversations after filtering for example {paper_id}")
+                        self.stats["skipped"] += 1
+                        continue
+
+                    # Log the simplified content for debugging
+                    if len(simplified_conversations) > 0:
+                        first_content = simplified_conversations[0]["content"]
+                        logger.debug(f"First content: {first_content[:50]}...")
+
+                    # Let tokenizer handle the simplified conversations
                     inputs = self.tokenizer.apply_chat_template(
+                        simplified_conversations,
                         return_tensors=None,
                         add_generation_prompt=False
                     )
                 except Exception as chat_error:
                     # Fallback if apply_chat_template fails
+                    logger.warning(f"Chat template application failed for example {paper_id}: {str(chat_error)}")
 
+                    # Create a basic representation of just the content
                     conversation_text = ""
+                    for msg in raw_conversations:
                         if isinstance(msg, dict) and 'content' in msg:
+                            conversation_text += msg['content'] + "\n\n"
+                        elif isinstance(msg, str):
+                            conversation_text += msg + "\n\n"
 
                     # Basic tokenization
                     inputs = self.tokenizer(
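
The collator delegates formatting to the tokenizer's chat template. A minimal sketch of that call in isolation; the checkpoint name is an assumption, and any tokenizer that ships a chat template behaves the same way:

from transformers import AutoTokenizer

# Illustrative checkpoint; substitute the tokenizer actually used for training.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

simplified_conversations = [{"role": "user", "content": "Explain the methodology."}]

# With return_tensors=None the template is applied and tokenized into a plain
# Python list of token ids, which the collator can pad and batch later.
token_ids = tokenizer.apply_chat_template(
    simplified_conversations,
    return_tensors=None,
    add_generation_prompt=False,
)
print(len(token_ids))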

@@ -537,7 +589,7 @@ class SimpleDataCollator:
                     logger.info(f"Example {self.stats['processed']}:")
                     logger.info(f"Paper ID: {paper_id}")
                     logger.info(f"Token count: {len(inputs)}")
+                    logger.info(f"Conversation entries: {len(raw_conversations)}")
                 else:
                     self.stats["skipped"] += 1
             except Exception as e:

@@ -1004,6 +1056,14 @@ def main():
         """Custom dataloader that preserves original dataset order"""
         log_info("Creating sequential dataloader to maintain original dataset order")
 
+        # Safety check - make sure dataset exists and is not None
+        if dataset is None:
+            raise ValueError("Dataset is None - cannot create dataloader")
+
+        # Make sure dataset is not empty
+        if len(dataset) == 0:
+            raise ValueError("Dataset is empty - cannot create dataloader")
+
         # Create a simple sequential sampler
         sequential_sampler = torch.utils.data.SequentialSampler(dataset)
 
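
SequentialSampler is what provides the ordering guarantee here; a self-contained sketch using only torch:

import torch
from torch.utils.data import DataLoader, SequentialSampler

data = list(range(10))  # stand-in for the processed dataset
loader = DataLoader(data, batch_size=4, sampler=SequentialSampler(data))

for batch in loader:
    print(batch)  # tensor([0..3]), tensor([4..7]), tensor([8, 9]) - original order preserved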

@@ -1018,10 +1078,16 @@ def main():
         # Log our approach clearly
         log_info("Using SequentialSampler to guarantee dataset order is preserved based on prompt_number")
 
-        # Verify column order
+        # Verify column order and check for 'conversations' field
         expected_order = ["prompt_number", "article_id", "conversations"]
         if hasattr(dataset, 'column_names'):
             actual_order = dataset.column_names
+
+            # Verify all required fields exist
+            missing_fields = [field for field in ["conversations"] if field not in actual_order]
+            if missing_fields:
+                raise ValueError(f"Dataset missing critical fields: {missing_fields}")
+
             if actual_order == expected_order:
                 log_info(f"Confirmed dataset columns are in expected order: {', '.join(expected_order)}")
             else:
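
If the required columns were present but merely misordered, the datasets library could reorder them; a hypothetical fix-up, not part of this commit, assuming Dataset.select_columns is available (datasets >= 2.7):

# Hypothetical recovery path: select_columns returns a new dataset whose
# columns follow the requested order.
expected_order = ["prompt_number", "article_id", "conversations"]
if set(expected_order).issubset(dataset.column_names):
    dataset = dataset.select_columns(expected_order)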

@@ -1030,6 +1096,16 @@ def main():
 
         log_info("Dataset is pre-processed with prompt_number field indicating the correct sequence")
 
+        # Validate a few samples before proceeding
+        for i in range(min(3, len(dataset))):
+            sample = dataset[i]
+            if "conversations" not in sample:
+                log_info(f"WARNING: Sample {i} missing 'conversations' field")
+            elif sample["conversations"] is None:
+                log_info(f"WARNING: Sample {i} has None 'conversations' field")
+            elif not isinstance(sample["conversations"], list):
+                log_info(f"WARNING: Sample {i} has non-list 'conversations' field: {type(sample['conversations'])}")
+
         # Calculate batch size based on device availability
         if getattr(training_args, "no_cuda", False):
             batch_size = training_args.per_device_train_batch_size

@@ -1038,16 +1114,29 @@ def main():
 
         log_info(f"Using sequential sampler with batch size {batch_size}")
 
-        # Return DataLoader with sequential sampler
+        # Return DataLoader with sequential sampler and extra error handling
+        try:
+            return torch.utils.data.DataLoader(
+                dataset,
+                batch_size=batch_size,
+                sampler=sequential_sampler,  # Always use sequential sampler
+                collate_fn=data_collator,
+                drop_last=training_args.dataloader_drop_last,
+                num_workers=training_args.dataloader_num_workers,
+                pin_memory=training_args.dataloader_pin_memory,
+            )
+        except Exception as e:
+            log_info(f"Error creating DataLoader: {str(e)}")
+            # Try again with minimal settings
+            log_info("Attempting to create DataLoader with minimal settings")
+            return torch.utils.data.DataLoader(
+                dataset,
+                batch_size=1,  # Minimal batch size
+                sampler=sequential_sampler,
+                collate_fn=data_collator,
+                num_workers=0,  # No parallel workers
+                pin_memory=False,
+            )
 
     # Override the get_train_dataloader method
     trainer.get_train_dataloader = custom_get_train_dataloader
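
The assignment on the last line swaps the loader in at the instance level: Trainer calls self.get_train_dataloader() inside train(), and an instance attribute shadows the class method. A minimal sketch of the pattern, assuming trainer, its train_dataset, and its data_collator already exist:

from torch.utils.data import DataLoader, SequentialSampler

def custom_get_train_dataloader():
    # Zero-argument function: it closes over `trainer` rather than taking self.
    return DataLoader(
        trainer.train_dataset,
        batch_size=trainer.args.per_device_train_batch_size,
        sampler=SequentialSampler(trainer.train_dataset),
        collate_fn=trainer.data_collator,
    )

# The instance attribute shadows Trainer.get_train_dataloader; the next
# trainer.train() call builds batches with the sequential loader above.
trainer.get_train_dataloader = custom_get_train_dataloader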