Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- run_transformers_training.py +199 -169
- transformers_config.json +6 -5
run_transformers_training.py
CHANGED
|
@@ -123,30 +123,22 @@ def load_env_variables():
|
|
| 123 |
os.environ["HUGGING_FACE_HUB_TOKEN"] = os.environ.get("HF_TOKEN")
|
| 124 |
|
| 125 |
def load_configs(base_path):
|
| 126 |
-
"""Load
|
| 127 |
-
configs = {}
|
| 128 |
-
|
| 129 |
# Using a single consolidated config file
|
| 130 |
-
config_file =
|
| 131 |
|
| 132 |
-
file_path = os.path.join(base_path, config_file)
|
| 133 |
try:
|
| 134 |
-
with open(
|
| 135 |
config = json.load(f)
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
configs["hardware"] = config.get("hardware", {})
|
| 139 |
-
configs["dataset"] = config.get("dataset", {})
|
| 140 |
-
logger.info(f"Loaded consolidated configuration from {file_path}")
|
| 141 |
except Exception as e:
|
| 142 |
logger.error(f"Error loading {config_file}: {e}")
|
| 143 |
raise
|
| 144 |
-
|
| 145 |
-
return configs
|
| 146 |
|
| 147 |
def parse_args():
|
| 148 |
parser = argparse.ArgumentParser(description="Fine-tune a language model on a text dataset")
|
| 149 |
-
parser.add_argument("--
|
| 150 |
return parser.parse_args()
|
| 151 |
|
| 152 |
def load_model_and_tokenizer(config):
|
|
@@ -157,8 +149,8 @@ def load_model_and_tokenizer(config):
|
|
| 157 |
logger.error("Please ensure unsloth is in requirements.txt")
|
| 158 |
raise ImportError("Unsloth is required for this training setup")
|
| 159 |
|
| 160 |
-
# Get model name correctly from
|
| 161 |
-
model_name = config.get("
|
| 162 |
logger.info(f"Loading model: {model_name}")
|
| 163 |
|
| 164 |
if not model_name:
|
|
@@ -166,14 +158,12 @@ def load_model_and_tokenizer(config):
|
|
| 166 |
|
| 167 |
logger.info("Using Unsloth optimizations with pre-quantized model")
|
| 168 |
|
| 169 |
-
# Check for flash attention
|
| 170 |
use_flash_attention = config.get("use_flash_attention", True)
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
logger.
|
| 174 |
-
except ImportError:
|
| 175 |
use_flash_attention = False
|
| 176 |
-
logger.warning("Flash attention not available, falling back to standard attention")
|
| 177 |
|
| 178 |
# First detect if we have a GPU
|
| 179 |
if torch.cuda.is_available():
|
|
@@ -321,13 +311,24 @@ def load_dataset_with_mapping(dataset_config):
|
|
| 321 |
|
| 322 |
# Add prompt_number field that increments based on original order
|
| 323 |
def add_prompt_numbers(examples, indices):
|
| 324 |
-
# Defensive check to ensure indices is not None
|
| 325 |
if indices is None:
|
| 326 |
logger.warning("Warning: indices is None in add_prompt_numbers, using empty list")
|
| 327 |
indices = []
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
|
| 329 |
-
#
|
| 330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
return examples
|
| 332 |
|
| 333 |
# Add prompt numbers to the dataset based on original order
|
|
@@ -358,37 +359,73 @@ def load_dataset_with_mapping(dataset_config):
|
|
| 358 |
dataset = Dataset.from_list(updated_examples)
|
| 359 |
logger.info(f"Successfully added prompt_number field using fallback method")
|
| 360 |
|
| 361 |
-
#
|
| 362 |
-
|
| 363 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
if col not in dataset.column_names:
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
logger.info("Converting 'text' field to 'conversations' format")
|
| 368 |
-
|
| 369 |
-
def convert_text_to_conversations(example):
|
| 370 |
-
# Check if text is already a list of conversation turns
|
| 371 |
-
if isinstance(example.get("text"), list):
|
| 372 |
-
return {"conversations": example["text"]}
|
| 373 |
-
# Otherwise, create a simple conversation with the text as user message
|
| 374 |
-
else:
|
| 375 |
-
return {
|
| 376 |
-
"conversations": [
|
| 377 |
-
{"role": "user", "content": str(example.get("text", ""))}
|
| 378 |
-
]
|
| 379 |
-
}
|
| 380 |
-
|
| 381 |
-
dataset = dataset.map(convert_text_to_conversations)
|
| 382 |
-
else:
|
| 383 |
-
logger.warning(f"Expected column '{col}' not found in dataset")
|
| 384 |
|
| 385 |
-
#
|
| 386 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
|
| 388 |
-
#
|
| 389 |
processing_config = dataset_config.get("dataset", {}).get("processing", {})
|
| 390 |
data_loading_config = dataset_config.get("data_loading", {})
|
| 391 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
# Flag consolidation - we only need one flag to control sequence preservation
|
| 393 |
# Default to True to ensure safety
|
| 394 |
preserve_sequence = processing_config.get("preserve_entry_sequence", True)
|
|
@@ -450,17 +487,18 @@ def load_dataset_with_mapping(dataset_config):
|
|
| 450 |
logger.warning(f"Error accessing dataset at index {i}: {e}")
|
| 451 |
|
| 452 |
if sample_examples:
|
| 453 |
-
|
| 454 |
-
|
|
|
|
| 455 |
|
| 456 |
if sample_ids and all(isinstance(id, int) or (isinstance(id, str) and id.isdigit()) for id in sample_ids):
|
| 457 |
numeric_ids = [int(id) if isinstance(id, str) else id for id in sample_ids]
|
| 458 |
if len(numeric_ids) > 1:
|
| 459 |
is_ordered = all(numeric_ids[i] <= numeric_ids[i+1] for i in range(len(numeric_ids)-1))
|
| 460 |
if not is_ordered:
|
| 461 |
-
logger.warning("WARNING: Sample
|
| 462 |
else:
|
| 463 |
-
logger.info("Sample
|
| 464 |
except Exception as e:
|
| 465 |
logger.warning(f"Error checking ID sequence: {e}")
|
| 466 |
except Exception as e:
|
|
@@ -472,19 +510,19 @@ def load_dataset_with_mapping(dataset_config):
|
|
| 472 |
# Safely get first few samples
|
| 473 |
first_few_indices = range(min(5, len(dataset)))
|
| 474 |
sample_prompt_numbers = []
|
| 475 |
-
|
| 476 |
|
| 477 |
for i in first_few_indices:
|
| 478 |
try:
|
| 479 |
example = dataset[i]
|
| 480 |
if 'prompt_number' in example:
|
| 481 |
sample_prompt_numbers.append(example['prompt_number'])
|
| 482 |
-
if '
|
| 483 |
-
|
| 484 |
except Exception as e:
|
| 485 |
logger.warning(f"Error accessing sample at index {i}: {e}")
|
| 486 |
|
| 487 |
-
logger.info(f"First few samples - Prompt numbers: {sample_prompt_numbers}, IDs: {
|
| 488 |
|
| 489 |
# Log conversation structure without full content
|
| 490 |
if len(dataset) > 0:
|
|
@@ -510,6 +548,74 @@ def load_dataset_with_mapping(dataset_config):
|
|
| 510 |
|
| 511 |
logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
|
| 512 |
logger.info(f"Dataset columns: {dataset.column_names}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
return dataset
|
| 514 |
|
| 515 |
except Exception as e:
|
|
@@ -752,13 +858,13 @@ class LoggingCallback(TrainerCallback):
|
|
| 752 |
is_sequence_maintained = False
|
| 753 |
|
| 754 |
# Also compare IDs as a backup check
|
| 755 |
-
elif ('
|
| 756 |
-
'
|
| 757 |
-
orig_sample['
|
| 758 |
-
current_sample['
|
| 759 |
|
| 760 |
-
if orig_sample['
|
| 761 |
-
log_info(f"WARNING: Sequence integrity compromised! Sample {i}
|
| 762 |
is_sequence_maintained = False
|
| 763 |
|
| 764 |
# Compare input fingerprints
|
|
@@ -899,12 +1005,11 @@ def check_dependencies():
|
|
| 899 |
missing_packages.append("peft>=0.9.0")
|
| 900 |
|
| 901 |
# Optional packages - don't add to missing list, just log
|
| 902 |
-
|
| 903 |
-
import flash_attn
|
| 904 |
logger.info("flash-attn found. Flash attention will be used for faster training.")
|
| 905 |
-
|
| 906 |
logger.warning("flash-attn not found. Training will work but may be slower.")
|
| 907 |
-
|
| 908 |
|
| 909 |
# If critical packages are missing, exit with instructions
|
| 910 |
if missing_packages:
|
|
@@ -918,115 +1023,44 @@ def check_dependencies():
|
|
| 918 |
|
| 919 |
def main():
|
| 920 |
# Set up logging
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
# Log hardware information
|
| 924 |
-
log_info(f"Hardware detection: CUDA {'available' if CUDA_AVAILABLE else 'not available'}")
|
| 925 |
-
if CUDA_AVAILABLE:
|
| 926 |
-
log_info(f"Found {NUM_GPUS} GPUs")
|
| 927 |
-
for i in range(NUM_GPUS):
|
| 928 |
-
log_info(f" GPU {i}: {torch.cuda.get_device_name(i)}")
|
| 929 |
-
else:
|
| 930 |
-
log_info("Running on CPU (training will be very slow)")
|
| 931 |
|
| 932 |
# Parse arguments
|
| 933 |
args = parse_args()
|
| 934 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 935 |
# Check dependencies
|
| 936 |
if not check_dependencies():
|
| 937 |
logger.error("Aborting due to missing critical dependencies")
|
| 938 |
return 1
|
| 939 |
|
| 940 |
-
# Load environment variables
|
| 941 |
-
load_env_variables()
|
| 942 |
-
|
| 943 |
# Check if we're in distributed mode
|
| 944 |
is_distributed = "WORLD_SIZE" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1
|
| 945 |
if is_distributed:
|
| 946 |
-
|
|
|
|
| 947 |
else:
|
| 948 |
log_info("Running in non-distributed mode (single process)")
|
| 949 |
|
| 950 |
-
# Load all configurations - do this once
|
| 951 |
-
try:
|
| 952 |
-
configs = load_configs(args.config_dir)
|
| 953 |
-
|
| 954 |
-
# Extract specific configs immediately after loading
|
| 955 |
-
if not configs:
|
| 956 |
-
logger.error("Failed to load configuration")
|
| 957 |
-
return 1
|
| 958 |
-
|
| 959 |
-
# Store configurations in clear variables
|
| 960 |
-
transformers_config = configs.get("transformers", {})
|
| 961 |
-
hardware_config = configs.get("hardware", {})
|
| 962 |
-
dataset_config = configs.get("dataset", {})
|
| 963 |
-
|
| 964 |
-
# Verify configuration sections exist
|
| 965 |
-
if not transformers_config:
|
| 966 |
-
logger.error("transformers_config.json not found or invalid")
|
| 967 |
-
return 1
|
| 968 |
-
|
| 969 |
-
if not hardware_config:
|
| 970 |
-
logger.warning("Hardware configuration section not found in transformers_config.json. Using default hardware configuration.")
|
| 971 |
-
|
| 972 |
-
if not dataset_config:
|
| 973 |
-
logger.error("Dataset configuration section not found in transformers_config.json")
|
| 974 |
-
return 1
|
| 975 |
-
|
| 976 |
-
# Validate model configuration
|
| 977 |
-
model_name = (transformers_config.get("model", {}).get("name") or
|
| 978 |
-
transformers_config.get("model_name_or_path") or
|
| 979 |
-
transformers_config.get("model_name"))
|
| 980 |
-
|
| 981 |
-
if not model_name:
|
| 982 |
-
logger.error("Model name not specified in configuration")
|
| 983 |
-
logger.error("Please ensure 'name' is specified under 'model' in transformers_config.json")
|
| 984 |
-
return 1
|
| 985 |
-
|
| 986 |
-
log_info(f"Using model: {model_name}")
|
| 987 |
-
log_info("All configurations loaded successfully")
|
| 988 |
-
|
| 989 |
-
# Apply hardware-specific settings if available
|
| 990 |
-
if hardware_config:
|
| 991 |
-
# Get training optimizations from hardware config
|
| 992 |
-
training_opts = hardware_config.get("training_optimizations", {})
|
| 993 |
-
|
| 994 |
-
# Apply batch size and gradient accumulation settings
|
| 995 |
-
if training_opts.get("per_device_batch_size") and transformers_config.get("training"):
|
| 996 |
-
batch_size = training_opts.get("per_device_batch_size")
|
| 997 |
-
transformers_config["training"]["per_device_train_batch_size"] = batch_size
|
| 998 |
-
log_info(f"Applied hardware-optimized batch size: {batch_size}")
|
| 999 |
-
|
| 1000 |
-
if training_opts.get("gradient_accumulation_steps") and transformers_config.get("training"):
|
| 1001 |
-
grad_steps = training_opts.get("gradient_accumulation_steps")
|
| 1002 |
-
transformers_config["training"]["gradient_accumulation_steps"] = grad_steps
|
| 1003 |
-
log_info(f"Applied hardware-optimized gradient accumulation: {grad_steps}")
|
| 1004 |
-
|
| 1005 |
-
# Apply memory optimizations
|
| 1006 |
-
memory_opts = training_opts.get("memory_optimizations", {})
|
| 1007 |
-
if memory_opts.get("use_gradient_checkpointing") is not None and transformers_config.get("training"):
|
| 1008 |
-
grad_ckpt = memory_opts.get("use_gradient_checkpointing")
|
| 1009 |
-
transformers_config["training"]["gradient_checkpointing"] = grad_ckpt
|
| 1010 |
-
log_info(f"Applied hardware-optimized gradient checkpointing: {grad_ckpt}")
|
| 1011 |
-
|
| 1012 |
-
# Apply system settings
|
| 1013 |
-
system_settings = hardware_config.get("system_settings", {})
|
| 1014 |
-
if system_settings.get("dataloader_num_workers") is not None:
|
| 1015 |
-
workers = system_settings.get("dataloader_num_workers")
|
| 1016 |
-
log_info(f"Using {workers} dataloader workers from hardware config")
|
| 1017 |
-
|
| 1018 |
-
# Get distribution strategy
|
| 1019 |
-
multi_gpu_strategy = training_opts.get("multi_gpu_strategy", "data_parallel")
|
| 1020 |
-
log_info(f"Hardware config specifies {multi_gpu_strategy} for multi-GPU training")
|
| 1021 |
-
|
| 1022 |
-
except Exception as e:
|
| 1023 |
-
logger.error(f"Error loading configurations: {e}")
|
| 1024 |
-
return 1
|
| 1025 |
-
|
| 1026 |
# Set random seed for reproducibility
|
| 1027 |
seed = transformers_config.get("seed", 42)
|
| 1028 |
set_seed(seed)
|
| 1029 |
-
|
|
|
|
|
|
|
|
|
|
| 1030 |
|
| 1031 |
# Empty CUDA cache to ensure clean state
|
| 1032 |
if CUDA_AVAILABLE:
|
|
@@ -1043,17 +1077,13 @@ def main():
|
|
| 1043 |
log_info(f"Set CUDA memory allocation limit to expandable with max_split_size_mb:128")
|
| 1044 |
|
| 1045 |
try:
|
| 1046 |
-
log_info("Loading
|
| 1047 |
-
|
| 1048 |
-
log_info("
|
| 1049 |
|
| 1050 |
-
#
|
| 1051 |
-
|
| 1052 |
-
|
| 1053 |
-
dataset = load_dataset_with_mapping(dataset_config)
|
| 1054 |
-
log_info(f"Dataset loaded with {len(dataset)} examples")
|
| 1055 |
-
except Exception as e:
|
| 1056 |
-
logger.error(f"Error loading dataset: {e}")
|
| 1057 |
return 1
|
| 1058 |
|
| 1059 |
# Create data collator
|
|
|
|
| 123 |
os.environ["HUGGING_FACE_HUB_TOKEN"] = os.environ.get("HF_TOKEN")
|
| 124 |
|
| 125 |
def load_configs(base_path):
|
| 126 |
+
"""Load configuration from transformers_config.json file."""
|
|
|
|
|
|
|
| 127 |
# Using a single consolidated config file
|
| 128 |
+
config_file = base_path
|
| 129 |
|
|
|
|
| 130 |
try:
|
| 131 |
+
with open(config_file, "r") as f:
|
| 132 |
config = json.load(f)
|
| 133 |
+
logger.info(f"Loaded configuration from {config_file}")
|
| 134 |
+
return config
|
|
|
|
|
|
|
|
|
|
| 135 |
except Exception as e:
|
| 136 |
logger.error(f"Error loading {config_file}: {e}")
|
| 137 |
raise
|
|
|
|
|
|
|
| 138 |
|
| 139 |
def parse_args():
|
| 140 |
parser = argparse.ArgumentParser(description="Fine-tune a language model on a text dataset")
|
| 141 |
+
parser.add_argument("--config", type=str, default="transformers_config.json", help="Path to configuration file")
|
| 142 |
return parser.parse_args()
|
| 143 |
|
| 144 |
def load_model_and_tokenizer(config):
|
|
|
|
| 149 |
logger.error("Please ensure unsloth is in requirements.txt")
|
| 150 |
raise ImportError("Unsloth is required for this training setup")
|
| 151 |
|
| 152 |
+
# Get model name correctly from config
|
| 153 |
+
model_name = config.get("model_name") or config.get("model", {}).get("name")
|
| 154 |
logger.info(f"Loading model: {model_name}")
|
| 155 |
|
| 156 |
if not model_name:
|
|
|
|
| 158 |
|
| 159 |
logger.info("Using Unsloth optimizations with pre-quantized model")
|
| 160 |
|
| 161 |
+
# Check for flash attention
|
| 162 |
use_flash_attention = config.get("use_flash_attention", True)
|
| 163 |
+
if use_flash_attention and not find_spec("flash_attn"):
|
| 164 |
+
logger.warning("flash-attn not found. Will continue without flash attention.")
|
| 165 |
+
logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
|
|
|
|
| 166 |
use_flash_attention = False
|
|
|
|
| 167 |
|
| 168 |
# First detect if we have a GPU
|
| 169 |
if torch.cuda.is_available():
|
|
|
|
| 311 |
|
| 312 |
# Add prompt_number field that increments based on original order
|
| 313 |
def add_prompt_numbers(examples, indices):
|
| 314 |
+
# Defensive check to ensure indices is not None and is iterable
|
| 315 |
if indices is None:
|
| 316 |
logger.warning("Warning: indices is None in add_prompt_numbers, using empty list")
|
| 317 |
indices = []
|
| 318 |
+
elif isinstance(indices, int):
|
| 319 |
+
# Handle case where indices is a single integer
|
| 320 |
+
logger.warning(f"Warning: indices is an integer ({indices}) in add_prompt_numbers, converting to list")
|
| 321 |
+
indices = [indices]
|
| 322 |
|
| 323 |
+
# Ensure indices is always a list/iterable
|
| 324 |
+
try:
|
| 325 |
+
# Create a new field with the dataset index as the prompt number, starting at 1
|
| 326 |
+
examples["prompt_number"] = [idx + 1 for idx in indices] # Adding 1 to make it 1-indexed
|
| 327 |
+
except TypeError:
|
| 328 |
+
# Fallback for non-iterable types
|
| 329 |
+
logger.warning(f"Warning: non-iterable indices in add_prompt_numbers: {type(indices)}, using default")
|
| 330 |
+
examples["prompt_number"] = [1] * len(next(iter(examples.values())))
|
| 331 |
+
|
| 332 |
return examples
|
| 333 |
|
| 334 |
# Add prompt numbers to the dataset based on original order
|
|
|
|
| 359 |
dataset = Dataset.from_list(updated_examples)
|
| 360 |
logger.info(f"Successfully added prompt_number field using fallback method")
|
| 361 |
|
| 362 |
+
# Rename 'id' to 'article_id' if it exists
|
| 363 |
+
if 'id' in dataset.column_names and 'article_id' not in dataset.column_names:
|
| 364 |
+
logger.info("Renaming 'id' column to 'article_id'")
|
| 365 |
+
dataset = dataset.rename_column('id', 'article_id')
|
| 366 |
+
|
| 367 |
+
# Reorder columns to make prompt_number first if it exists
|
| 368 |
+
if 'prompt_number' in dataset.column_names:
|
| 369 |
+
logger.info("Reordering columns to place prompt_number first")
|
| 370 |
+
# Get current column names
|
| 371 |
+
current_columns = dataset.column_names
|
| 372 |
+
# Create new column order with prompt_number first
|
| 373 |
+
new_column_order = ['prompt_number'] + [col for col in current_columns if col != 'prompt_number']
|
| 374 |
+
# Reorder columns
|
| 375 |
+
dataset = dataset.select_columns(new_column_order)
|
| 376 |
+
|
| 377 |
+
# Verify all new column names for logging
|
| 378 |
+
logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
|
| 379 |
+
logger.info(f"Dataset columns: {dataset.column_names}")
|
| 380 |
+
|
| 381 |
+
# Verify dataset is not empty
|
| 382 |
+
if len(dataset) == 0:
|
| 383 |
+
logger.error("Dataset is empty! This will cause errors during training.")
|
| 384 |
+
raise ValueError("Empty dataset loaded")
|
| 385 |
+
|
| 386 |
+
# Check for required columns
|
| 387 |
+
required_columns = ['conversations']
|
| 388 |
+
for col in required_columns:
|
| 389 |
if col not in dataset.column_names:
|
| 390 |
+
logger.error(f"Required column '{col}' not found in dataset!")
|
| 391 |
+
raise ValueError(f"Required column '{col}' missing from dataset")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
|
| 393 |
+
# Verify expected columns exist
|
| 394 |
+
expected_columns = {"article_id", "conversations", "prompt_number"}
|
| 395 |
+
missing_columns = expected_columns - set(dataset.column_names)
|
| 396 |
+
if missing_columns:
|
| 397 |
+
logger.warning(f"Some expected columns are missing: {missing_columns}")
|
| 398 |
+
|
| 399 |
+
# If "conversations" is missing but "text" exists, attempt conversion
|
| 400 |
+
if "conversations" not in dataset.column_names and "text" in dataset.column_names:
|
| 401 |
+
logger.info("Converting 'text' field to 'conversations' format")
|
| 402 |
+
|
| 403 |
+
def convert_text_to_conversations(example):
|
| 404 |
+
# Check if text is already a list of conversation turns
|
| 405 |
+
if isinstance(example.get("text"), list):
|
| 406 |
+
example["conversations"] = example["text"]
|
| 407 |
+
# Otherwise, create a simple conversation with the text as user message
|
| 408 |
+
else:
|
| 409 |
+
example["conversations"] = [
|
| 410 |
+
{"role": "user", "content": str(example.get("text", ""))}
|
| 411 |
+
]
|
| 412 |
+
return example
|
| 413 |
+
|
| 414 |
+
dataset = dataset.map(convert_text_to_conversations)
|
| 415 |
+
logger.info("Successfully converted 'text' to 'conversations'")
|
| 416 |
|
| 417 |
+
# Verify data ordering requirements
|
| 418 |
processing_config = dataset_config.get("dataset", {}).get("processing", {})
|
| 419 |
data_loading_config = dataset_config.get("data_loading", {})
|
| 420 |
|
| 421 |
+
# Check if sorting is required
|
| 422 |
+
sort_by_article_id = processing_config.get("sort_by_article_id", False)
|
| 423 |
+
if sort_by_article_id and 'article_id' in dataset.column_names:
|
| 424 |
+
logger.info("Sorting dataset by article_id as specified in config")
|
| 425 |
+
dataset = dataset.sort("article_id")
|
| 426 |
+
sorted_ids = [example['article_id'] for example in dataset.select(range(min(5, len(dataset))))]
|
| 427 |
+
logger.info(f"First few article_ids after sorting: {sorted_ids}")
|
| 428 |
+
|
| 429 |
# Flag consolidation - we only need one flag to control sequence preservation
|
| 430 |
# Default to True to ensure safety
|
| 431 |
preserve_sequence = processing_config.get("preserve_entry_sequence", True)
|
|
|
|
| 487 |
logger.warning(f"Error accessing dataset at index {i}: {e}")
|
| 488 |
|
| 489 |
if sample_examples:
|
| 490 |
+
id_field = 'article_id' if 'article_id' in dataset.column_names else 'id'
|
| 491 |
+
if all(isinstance(example.get(id_field, ''), (int, str)) for example in sample_examples):
|
| 492 |
+
sample_ids = [example.get(id_field, '') for example in sample_examples if id_field in example]
|
| 493 |
|
| 494 |
if sample_ids and all(isinstance(id, int) or (isinstance(id, str) and id.isdigit()) for id in sample_ids):
|
| 495 |
numeric_ids = [int(id) if isinstance(id, str) else id for id in sample_ids]
|
| 496 |
if len(numeric_ids) > 1:
|
| 497 |
is_ordered = all(numeric_ids[i] <= numeric_ids[i+1] for i in range(len(numeric_ids)-1))
|
| 498 |
if not is_ordered:
|
| 499 |
+
logger.warning(f"WARNING: Sample {id_field}s are not in sequential order.")
|
| 500 |
else:
|
| 501 |
+
logger.info(f"Sample {id_field}s appear to be in sequential order.")
|
| 502 |
except Exception as e:
|
| 503 |
logger.warning(f"Error checking ID sequence: {e}")
|
| 504 |
except Exception as e:
|
|
|
|
| 510 |
# Safely get first few samples
|
| 511 |
first_few_indices = range(min(5, len(dataset)))
|
| 512 |
sample_prompt_numbers = []
|
| 513 |
+
sample_article_ids = []
|
| 514 |
|
| 515 |
for i in first_few_indices:
|
| 516 |
try:
|
| 517 |
example = dataset[i]
|
| 518 |
if 'prompt_number' in example:
|
| 519 |
sample_prompt_numbers.append(example['prompt_number'])
|
| 520 |
+
if 'article_id' in example:
|
| 521 |
+
sample_article_ids.append(example['article_id'])
|
| 522 |
except Exception as e:
|
| 523 |
logger.warning(f"Error accessing sample at index {i}: {e}")
|
| 524 |
|
| 525 |
+
logger.info(f"First few samples - Prompt numbers: {sample_prompt_numbers}, Article IDs: {sample_article_ids}")
|
| 526 |
|
| 527 |
# Log conversation structure without full content
|
| 528 |
if len(dataset) > 0:
|
|
|
|
| 548 |
|
| 549 |
logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
|
| 550 |
logger.info(f"Dataset columns: {dataset.column_names}")
|
| 551 |
+
|
| 552 |
+
# Verify dataset is not empty
|
| 553 |
+
if len(dataset) == 0:
|
| 554 |
+
logger.error("Dataset is empty! Cannot proceed with training.")
|
| 555 |
+
return dataset
|
| 556 |
+
|
| 557 |
+
# Check for required columns
|
| 558 |
+
required_cols = ['conversations', 'prompt_number']
|
| 559 |
+
for col in required_cols:
|
| 560 |
+
if col not in dataset.column_names:
|
| 561 |
+
logger.error(f"Required column '{col}' missing from dataset. Cannot proceed with training.")
|
| 562 |
+
return dataset
|
| 563 |
+
|
| 564 |
+
# Validate at least one sample can be processed
|
| 565 |
+
try:
|
| 566 |
+
if len(dataset) > 0:
|
| 567 |
+
sample = dataset[0]
|
| 568 |
+
if 'conversations' not in sample or not sample['conversations']:
|
| 569 |
+
logger.error("First sample has no conversations! Data format may be incorrect.")
|
| 570 |
+
return dataset
|
| 571 |
+
if not isinstance(sample['conversations'], list):
|
| 572 |
+
logger.error(f"Conversations field should be a list but got {type(sample['conversations'])}")
|
| 573 |
+
return dataset
|
| 574 |
+
except Exception as e:
|
| 575 |
+
logger.error(f"Error validating first sample: {e}")
|
| 576 |
+
return dataset
|
| 577 |
+
|
| 578 |
+
# Add metadata if specified
|
| 579 |
+
metadata_config = dataset_config.get("data_formatting", {}).get("metadata_handling", {})
|
| 580 |
+
if metadata_config:
|
| 581 |
+
include_article_id = metadata_config.get("include_article_id", False)
|
| 582 |
+
include_prompt_number = metadata_config.get("include_prompt_number", False)
|
| 583 |
+
metadata_format = metadata_config.get("metadata_format", "")
|
| 584 |
+
|
| 585 |
+
if (include_article_id or include_prompt_number) and metadata_format:
|
| 586 |
+
logger.info("Adding metadata to conversations")
|
| 587 |
+
|
| 588 |
+
def add_metadata(example):
|
| 589 |
+
if not example.get("conversations"):
|
| 590 |
+
return example
|
| 591 |
+
|
| 592 |
+
# Prepare metadata
|
| 593 |
+
metadata = metadata_format
|
| 594 |
+
if include_article_id and "article_id" in example:
|
| 595 |
+
metadata = metadata.replace("{article_id}", str(example.get("article_id", "")))
|
| 596 |
+
if include_prompt_number and "prompt_number" in example:
|
| 597 |
+
metadata = metadata.replace("{prompt_number}", str(example.get("prompt_number", "")))
|
| 598 |
+
|
| 599 |
+
# Add system message with metadata if not empty
|
| 600 |
+
if metadata.strip():
|
| 601 |
+
if example["conversations"] and isinstance(example["conversations"], list):
|
| 602 |
+
# Check if first message is already a system message
|
| 603 |
+
if (isinstance(example["conversations"][0], dict) and
|
| 604 |
+
example["conversations"][0].get("role") == "system"):
|
| 605 |
+
# Append to existing system message
|
| 606 |
+
example["conversations"][0]["content"] += f"\n\nMetadata: {metadata}"
|
| 607 |
+
else:
|
| 608 |
+
# Add new system message at the beginning
|
| 609 |
+
example["conversations"].insert(0, {
|
| 610 |
+
"role": "system",
|
| 611 |
+
"content": f"Metadata: {metadata}"
|
| 612 |
+
})
|
| 613 |
+
|
| 614 |
+
return example
|
| 615 |
+
|
| 616 |
+
dataset = dataset.map(add_metadata)
|
| 617 |
+
logger.info("Metadata added to conversations")
|
| 618 |
+
|
| 619 |
return dataset
|
| 620 |
|
| 621 |
except Exception as e:
|
|
|
|
| 858 |
is_sequence_maintained = False
|
| 859 |
|
| 860 |
# Also compare IDs as a backup check
|
| 861 |
+
elif ('article_id' in orig_sample and
|
| 862 |
+
'article_id' in current_sample and
|
| 863 |
+
orig_sample['article_id'] is not None and
|
| 864 |
+
current_sample['article_id'] is not None):
|
| 865 |
|
| 866 |
+
if orig_sample['article_id'] != current_sample['article_id']:
|
| 867 |
+
log_info(f"WARNING: Sequence integrity compromised! Sample {i} article_id changed from {orig_sample['article_id']} to {current_sample['article_id']}")
|
| 868 |
is_sequence_maintained = False
|
| 869 |
|
| 870 |
# Compare input fingerprints
|
|
|
|
| 1005 |
missing_packages.append("peft>=0.9.0")
|
| 1006 |
|
| 1007 |
# Optional packages - don't add to missing list, just log
|
| 1008 |
+
if find_spec("flash_attn"):
|
|
|
|
| 1009 |
logger.info("flash-attn found. Flash attention will be used for faster training.")
|
| 1010 |
+
else:
|
| 1011 |
logger.warning("flash-attn not found. Training will work but may be slower.")
|
| 1012 |
+
logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
|
| 1013 |
|
| 1014 |
# If critical packages are missing, exit with instructions
|
| 1015 |
if missing_packages:
|
|
|
|
| 1023 |
|
| 1024 |
def main():
|
| 1025 |
# Set up logging
|
| 1026 |
+
logger.info("Starting training process")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1027 |
|
| 1028 |
# Parse arguments
|
| 1029 |
args = parse_args()
|
| 1030 |
|
| 1031 |
+
# Load environment variables
|
| 1032 |
+
load_env_variables()
|
| 1033 |
+
|
| 1034 |
+
# Load configuration
|
| 1035 |
+
try:
|
| 1036 |
+
transformers_config = load_configs(args.config)
|
| 1037 |
+
hardware_config = transformers_config.get("hardware", {})
|
| 1038 |
+
dataset_config = transformers_config.get("dataset", {})
|
| 1039 |
+
logger.info("Configuration loaded successfully")
|
| 1040 |
+
except Exception as e:
|
| 1041 |
+
logger.error(f"Error loading configuration: {e}")
|
| 1042 |
+
return 1
|
| 1043 |
+
|
| 1044 |
# Check dependencies
|
| 1045 |
if not check_dependencies():
|
| 1046 |
logger.error("Aborting due to missing critical dependencies")
|
| 1047 |
return 1
|
| 1048 |
|
|
|
|
|
|
|
|
|
|
| 1049 |
# Check if we're in distributed mode
|
| 1050 |
is_distributed = "WORLD_SIZE" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1
|
| 1051 |
if is_distributed:
|
| 1052 |
+
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
| 1053 |
+
log_info(f"Running in distributed mode with {os.environ.get('WORLD_SIZE')} processes, local_rank: {local_rank}")
|
| 1054 |
else:
|
| 1055 |
log_info("Running in non-distributed mode (single process)")
|
| 1056 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1057 |
# Set random seed for reproducibility
|
| 1058 |
seed = transformers_config.get("seed", 42)
|
| 1059 |
set_seed(seed)
|
| 1060 |
+
logger.info(f"Set random seed to {seed}")
|
| 1061 |
+
|
| 1062 |
+
# Load model and tokenizer using the consolidated config
|
| 1063 |
+
model, tokenizer = load_model_and_tokenizer(transformers_config)
|
| 1064 |
|
| 1065 |
# Empty CUDA cache to ensure clean state
|
| 1066 |
if CUDA_AVAILABLE:
|
|
|
|
| 1077 |
log_info(f"Set CUDA memory allocation limit to expandable with max_split_size_mb:128")
|
| 1078 |
|
| 1079 |
try:
|
| 1080 |
+
log_info("Loading dataset...")
|
| 1081 |
+
dataset = load_dataset_with_mapping(dataset_config)
|
| 1082 |
+
log_info(f"Dataset loaded with {len(dataset)} examples")
|
| 1083 |
|
| 1084 |
+
# Minimal validation before proceeding
|
| 1085 |
+
if dataset is None or len(dataset) == 0:
|
| 1086 |
+
logger.error("Dataset is empty or None! Cannot proceed with training.")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1087 |
return 1
|
| 1088 |
|
| 1089 |
# Create data collator
|
transformers_config.json
CHANGED
|
@@ -134,10 +134,11 @@
|
|
| 134 |
"name": "George-API/cognitive-data",
|
| 135 |
"split": "train",
|
| 136 |
"column_mapping": {
|
| 137 |
-
"conversations": "text"
|
|
|
|
| 138 |
},
|
| 139 |
"processing": {
|
| 140 |
-
"
|
| 141 |
"maintain_paper_order": true,
|
| 142 |
"preserve_entry_sequence": true,
|
| 143 |
"max_seq_length": 2048
|
|
@@ -152,9 +153,9 @@
|
|
| 152 |
"user": "Human: {content}\n\n"
|
| 153 |
},
|
| 154 |
"metadata_handling": {
|
| 155 |
-
"
|
| 156 |
-
"
|
| 157 |
-
"metadata_format": "
|
| 158 |
}
|
| 159 |
},
|
| 160 |
"data_loading": {
|
|
|
|
| 134 |
"name": "George-API/cognitive-data",
|
| 135 |
"split": "train",
|
| 136 |
"column_mapping": {
|
| 137 |
+
"conversations": "text",
|
| 138 |
+
"article_id": "id"
|
| 139 |
},
|
| 140 |
"processing": {
|
| 141 |
+
"sort_by_article_id": true,
|
| 142 |
"maintain_paper_order": true,
|
| 143 |
"preserve_entry_sequence": true,
|
| 144 |
"max_seq_length": 2048
|
|
|
|
| 153 |
"user": "Human: {content}\n\n"
|
| 154 |
},
|
| 155 |
"metadata_handling": {
|
| 156 |
+
"include_article_id": true,
|
| 157 |
+
"include_prompt_number": true,
|
| 158 |
+
"metadata_format": "Article ID: {article_id} | Prompt: {prompt_number}"
|
| 159 |
}
|
| 160 |
},
|
| 161 |
"data_loading": {
|