# FOXES/data/split_data.py
# Author: griffingoodwin04 — commit e522a44 ("Fixed normalization of timestamp names")
import os
import pandas as pd
import shutil
import argparse
from datetime import datetime
from pathlib import Path
def _normalize_timestamp(ts: str) -> str:
"""Normalize timestamp strings with underscores instead of colons (cross-platform filenames)."""
if 'T' in ts:
date_part, time_part = ts.split('T', 1)
return f"{date_part}T{time_part.replace('_', ':')}"
return ts
def _day_span(start, end):
    """Return (start-of-day, end-of-day) pandas Timestamps for 'YYYY-MM-DD' strings.

    The end bound is pushed to 23:59:59.999999 so comparisons are inclusive
    of the whole end day.
    """
    start_dt = pd.to_datetime(start).replace(hour=0, minute=0, second=0, microsecond=0)
    end_dt = pd.to_datetime(end).replace(hour=23, minute=59, second=59, microsecond=999999)
    return start_dt, end_dt


def _verify_splits(output_dir):
    """Check the output train/val/test directories for duplicated filenames.

    Prints a warning listing any filename that appears in more than one split,
    then a per-split file-count summary.
    """
    print(f"\nChecking for overlapping files between splits...")
    overlap_found = False
    # Snapshot the filename sets of each split directory (empty set if absent).
    split_files = {}
    for split in ["train", "val", "test"]:
        split_dir = os.path.join(output_dir, split)
        split_files[split] = set(os.listdir(split_dir)) if os.path.exists(split_dir) else set()
    splits = list(split_files.keys())
    for i in range(len(splits)):
        for j in range(i + 1, len(splits)):
            split1, split2 = splits[i], splits[j]
            overlap = split_files[split1] & split_files[split2]
            if overlap:
                overlap_found = True
                print(f"WARNING: Found {len(overlap)} overlapping files between {split1} and {split2}:")
                for file in sorted(overlap):
                    print(f"  - {file}")
            else:
                print(f"✓ No overlap between {split1} and {split2}")
    if not overlap_found:
        print("✓ No overlapping files found between any splits - data integrity verified!")
    else:
        print(f"\n⚠️ WARNING: Overlapping files detected! Please review the splitting logic.")
    print(f"\nFinal split summary:")
    for split in ["train", "val", "test"]:
        print(f"  {split}: {len(split_files[split])} files")


def split_data(input_folder, output_dir, data_type, flare_events_csv=None, repartition=False,
               train_start=None, train_end=None, val_start=None, val_end=None,
               test_start=None, test_end=None, use_buffer_strategy=False, copy_files=False):
    """
    Split data files into train/val/test subdirectories of ``output_dir``.

    Files are assigned by the timestamp embedded in their filename
    (e.g. ``2012-01-01T00_00_00.npy``) using one of three strategies, in
    priority order: the predefined buffer strategy, explicit custom date
    ranges, or the default month-based split (Aug -> test, Jan-Mar -> val,
    all other months -> train).

    Args:
        input_folder (str): Folder containing the data files (or, with
            ``repartition=True``, a folder with train/val/test subdirs).
        output_dir (str): Destination directory; train/val/test subdirs are created.
        data_type (str): 'aia' or 'sxr' (validated and used for logging only).
        flare_events_csv (str, optional): Path to the flare events CSV.
            NOTE(review): accepted but currently unused by this function.
        repartition (bool): Treat input_folder as already partitioned and
            re-assign its files.
        train_start (str, optional): Start date for training data ('YYYY-MM-DD').
        train_end (str, optional): End date for training data ('YYYY-MM-DD').
        val_start (str, optional): Start date for validation data ('YYYY-MM-DD').
        val_end (str, optional): End date for validation data ('YYYY-MM-DD').
        test_start (str, optional): Start date for test data ('YYYY-MM-DD').
        test_end (str, optional): End date for test data ('YYYY-MM-DD').
        use_buffer_strategy (bool): Use the predefined buffer-strategy ranges.
        copy_files (bool): Copy files instead of moving them (default: False).

    Raises:
        ValueError: Missing input folder, bad data_type, unparseable dates,
            or overlapping custom date ranges.
    """
    if not os.path.exists(input_folder):
        raise ValueError(f"Input folder does not exist: {input_folder}")
    if data_type.lower() not in ['aia', 'sxr']:
        raise ValueError("data_type must be 'aia' or 'sxr'")

    date_ranges = {}
    custom_dates = False
    buffer_strategy_ranges = None

    if use_buffer_strategy:
        print("Using predefined buffer strategy...")
        # Multiple periods per split, with deliberate gaps (buffer zones)
        # between train/val/test so adjacent samples don't leak across splits.
        buffer_strategy_ranges = {
            'train': [_day_span("2012-01-01", "2022-12-31"),
                      _day_span("2023-07-01", "2023-07-20")],
            'val': [_day_span("2023-01-01", "2023-06-30"),
                    _day_span("2023-07-22", "2023-07-30")],
            'test': [_day_span("2023-08-01", "2025-09-30")],
        }
        print(buffer_strategy_ranges)
    elif any([train_start, train_end, val_start, val_end, test_start, test_end]):
        custom_dates = True
        print("Using custom date ranges for data splitting...")
        requested = {'train': (train_start, train_end, 'Training'),
                     'val': (val_start, val_end, 'Validation'),
                     'test': (test_start, test_end, 'Test')}
        try:
            for split_name, (start, end, label) in requested.items():
                # A range is only built when BOTH endpoints were supplied.
                if start and end:
                    date_ranges[split_name] = _day_span(start, end)
                    print(f"{label} data: {start} 00:00:00 to {end} 23:59:59")
        except Exception as e:
            raise ValueError(f"Invalid date format. Please use 'YYYY-MM-DD' format. Error: {e}") from e
        # Reject overlapping custom ranges up front.
        ranges = list(date_ranges.values())
        for i in range(len(ranges)):
            for j in range(i + 1, len(ranges)):
                range1, range2 = ranges[i], ranges[j]
                if not (range1[1] < range2[0] or range2[1] < range1[0]):
                    raise ValueError(f"Date ranges cannot overlap. Found overlap between ranges: {range1[0].strftime('%Y-%m-%d')} to {range1[1].strftime('%Y-%m-%d')} and {range2[0].strftime('%Y-%m-%d')} to {range2[1].strftime('%Y-%m-%d')}")
    else:
        print("Using default month-based splitting...")

    # Create the output directory structure.
    os.makedirs(output_dir, exist_ok=True)
    for split in ["train", "val", "test"]:
        os.makedirs(os.path.join(output_dir, split), exist_ok=True)

    print(f"Processing {data_type.upper()} data from: {input_folder}")
    print(f"Output directory: {output_dir}")

    if repartition:
        # Collect (filename, current_split) pairs from the existing subdirs
        # so the source path of each file can be reconstructed later.
        data_list = []
        for split in ["train", "val", "test"]:
            split_dir = os.path.join(input_folder, split)
            if os.path.exists(split_dir):
                split_files = os.listdir(split_dir)
                for file in split_files:
                    data_list.append((file, split))
                print(f"Found {len(split_files)} files in {split}/ directory")
        print(f"Total files to repartition: {len(data_list)}")
    else:
        data_list = os.listdir(input_folder)
        print(f"Found {len(data_list)} files to process")

    moved_count = 0
    skipped_count = 0
    for file_info in data_list:
        if repartition:
            file, original_split = file_info
        else:
            file, original_split = file_info, None

        try:
            # Filenames look like "2012-01-01T00_00_00.npy"; underscores in
            # the time part are normalized back to colons before parsing.
            file_time = pd.to_datetime(_normalize_timestamp(file.split(".")[0]))
        except ValueError:
            print(f"Skipping file {file}: Invalid timestamp format")
            skipped_count += 1
            continue

        # Decide which split this file belongs to.
        new_split_dir = None
        if buffer_strategy_ranges:
            for split_name, periods in buffer_strategy_ranges.items():
                if any(start_date <= file_time <= end_date for start_date, end_date in periods):
                    new_split_dir = split_name
                    break
        elif custom_dates:
            for split_name, (start_date, end_date) in date_ranges.items():
                if start_date <= file_time <= end_date:
                    new_split_dir = split_name
                    break
            print(f"File {file} assigned to split {new_split_dir}")

        if new_split_dir is None:
            if custom_dates or buffer_strategy_ranges:
                # Date-driven strategies simply drop files outside every range.
                print(f"Skipping file {file}: No matching date range (file time: {file_time.strftime('%Y-%m-%d')})")
                skipped_count += 1
                continue
            # Default month-based split. month is always 1-12, so every file
            # gets an assignment (the old "unexpected month" branch was dead).
            month = file_time.month
            if month in [1, 2, 3]:
                new_split_dir = "val"
            elif month == 8:
                new_split_dir = "test"
            else:
                new_split_dir = "train"

        # Resolve source and destination paths.
        if repartition:
            src_path = os.path.join(input_folder, original_split, file)
        else:
            src_path = os.path.join(input_folder, file)
        dst_path = os.path.join(output_dir, new_split_dir, file)

        # Repartitioning a file already in the right place is a no-op.
        if repartition and original_split == new_split_dir and os.path.exists(dst_path):
            print(f"File {file} already in correct split ({new_split_dir}), skipping.")
            skipped_count += 1
            continue

        if os.path.exists(dst_path):
            action = "copy" if copy_files else "move"
            print(f"File {dst_path} already exists, skipping {action}.")
            skipped_count += 1
            continue

        try:
            if copy_files:
                shutil.copy2(src_path, dst_path)
            else:
                shutil.move(src_path, dst_path)
            moved_count += 1
        except Exception as e:
            # BUGFIX: the old handler referenced `action`, which was only
            # assigned AFTER a successful copy/move — a failing operation hit
            # a NameError (or a stale value) that masked the real error.
            verb = "copying" if copy_files else "moving"
            print(f"Error {verb} {file}: {e}")
            skipped_count += 1

    action = "copied" if copy_files else "moved"
    print(f"\nProcessing complete!")
    print(f"Files {action}: {moved_count}")
    print(f"Files skipped: {skipped_count}")
    print(f"Total files processed: {moved_count + skipped_count}")

    # Post-run integrity check: no filename may appear in two splits.
    _verify_splits(output_dir)
def main():
    """Command-line entry point: parse arguments and delegate to split_data."""
    parser = argparse.ArgumentParser(description='Split AIA or SXR data into train/val/test sets based on custom date ranges or month')
    parser.add_argument('--input_folder', type=str, required=True,
                        help='Path to the input folder containing data files (or partitioned folder for repartition)')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Path to the output directory where split data will be saved')
    parser.add_argument('--data_type', type=str, choices=['aia', 'sxr'], required=True,
                        help='Type of data: "aia" or "sxr"')
    parser.add_argument('--flare_events_csv', type=str, default=None,
                        help='Path to the flare events CSV file (optional)')
    parser.add_argument('--repartition', action='store_true',
                        help='Repartition an already partitioned folder (input_folder should have train/val/test subdirs)')
    # One --<split>_start / --<split>_end pair per split.
    for prefix, label in (('train', 'training'), ('val', 'validation'), ('test', 'test')):
        parser.add_argument(f'--{prefix}_start', type=str, default=None,
                            help=f'Start date for {label} data (format: YYYY-MM-DD)')
        parser.add_argument(f'--{prefix}_end', type=str, default=None,
                            help=f'End date for {label} data (format: YYYY-MM-DD)')
    parser.add_argument('--use_buffer_strategy', action='store_true',
                        help='Use predefined buffer strategy with specific date ranges and buffer zones')
    parser.add_argument('--copy_files', action='store_true',
                        help='Copy files instead of moving them (keeps original files intact)')
    args = parser.parse_args()

    # Work with absolute paths throughout.
    src_root = os.path.abspath(args.input_folder)
    dst_root = os.path.abspath(args.output_dir)
    flare_csv = os.path.abspath(args.flare_events_csv) if args.flare_events_csv else None

    if args.repartition:
        # Warn (but continue) when any expected split subdirectory is absent.
        missing_dirs = [name for name in ('train', 'val', 'test')
                        if not os.path.exists(os.path.join(src_root, name))]
        if missing_dirs:
            print(f"Warning: Input folder is missing expected subdirectories: {missing_dirs}")
            print("Continuing with available directories...")

    split_data(src_root, dst_root, args.data_type, flare_csv, args.repartition,
               args.train_start, args.train_end, args.val_start, args.val_end,
               args.test_start, args.test_end, args.use_buffer_strategy, args.copy_files)


if __name__ == "__main__":
    main()