| import os |
| import pandas as pd |
| import shutil |
| import argparse |
| from datetime import datetime |
| from pathlib import Path |
|
|
|
|
| def _normalize_timestamp(ts: str) -> str: |
| """Normalize timestamp strings with underscores instead of colons (cross-platform filenames).""" |
| if 'T' in ts: |
| date_part, time_part = ts.split('T', 1) |
| return f"{date_part}T{time_part.replace('_', ':')}" |
| return ts |
|
|
# Names of the split sub-directories, in the order they are created/reported.
_SPLIT_NAMES = ("train", "val", "test")

# Predefined "buffer strategy" periods (inclusive calendar days). Days that
# fall between periods (e.g. 2023-07-21 and 2023-07-31) match no split, so
# files from those days are skipped.
_BUFFER_STRATEGY_DATES = {
    "train": [("2012-01-01", "2022-12-31"), ("2023-07-01", "2023-07-20")],
    "val": [("2023-01-01", "2023-06-30"), ("2023-07-22", "2023-07-30")],
    "test": [("2023-08-01", "2025-09-30")],
}

# Default month-based assignment used when no date ranges are supplied.
_MONTH_TO_SPLIT = {
    1: "val", 2: "val", 3: "val",
    4: "train", 5: "train", 6: "train", 7: "train",
    8: "test",
    9: "train", 10: "train", 11: "train", 12: "train",
}


def _inclusive_day_range(start, end):
    """Return ``(start, end)`` as Timestamps spanning whole days, inclusive."""
    first = pd.to_datetime(start).replace(hour=0, minute=0, second=0, microsecond=0)
    last = pd.to_datetime(end).replace(hour=23, minute=59, second=59, microsecond=999999)
    if pd.isna(first) or pd.isna(last):
        # Strings such as "NaT" parse without error but are not usable dates.
        raise ValueError(f"could not interpret {start!r} / {end!r} as dates")
    return first, last


def _custom_date_windows(requested):
    """Validate user-supplied ranges and return ``{split: [(start, end)]}``.

    ``requested`` is an iterable of ``(split_name, label, start, end)`` where
    ``start``/``end`` are the raw date strings (or ``None``).
    """
    windows = {}
    for split_name, label, start, end in requested:
        if not start and not end:
            continue
        if not (start and end):
            # A half-specified range used to be dropped silently, which made
            # every file fall through to "no matching date range".
            raise ValueError(
                f"Both a start and an end date are required for the {split_name} split "
                f"(got start={start!r}, end={end!r})"
            )
        try:
            first, last = _inclusive_day_range(start, end)
        except (ValueError, TypeError, OverflowError) as exc:
            raise ValueError(
                f"Invalid date format. Please use 'YYYY-MM-DD' format. Error: {exc}"
            ) from exc
        if first > last:
            raise ValueError(
                f"Start date {start} is after end date {end} for the {split_name} split"
            )
        windows[split_name] = [(first, last)]
        print(f"{label} data: {start} 00:00:00 to {end} 23:59:59")

    # Reject any pair of ranges that share at least one instant.
    spans = [periods[0] for periods in windows.values()]
    for idx, (a_start, a_end) in enumerate(spans):
        for b_start, b_end in spans[idx + 1:]:
            if a_end >= b_start and b_end >= a_start:
                raise ValueError(f"Date ranges cannot overlap. Found overlap between ranges: {a_start.strftime('%Y-%m-%d')} to {a_end.strftime('%Y-%m-%d')} and {b_start.strftime('%Y-%m-%d')} to {b_end.strftime('%Y-%m-%d')}")
    return windows


def _match_window(moment, windows):
    """Return the first split whose periods contain ``moment``, else ``None``."""
    for split_name, periods in windows.items():
        if any(first <= moment <= last for first, last in periods):
            return split_name
    return None


def _report_split_integrity(output_dir):
    """Print filename overlaps between output splits and a per-split file count."""
    print(f"\nChecking for overlapping files between splits...")
    listing = {}
    for split in _SPLIT_NAMES:
        split_dir = os.path.join(output_dir, split)
        listing[split] = set(os.listdir(split_dir)) if os.path.exists(split_dir) else set()

    overlap_found = False
    names = list(listing)
    for idx, split1 in enumerate(names):
        for split2 in names[idx + 1:]:
            overlap = listing[split1] & listing[split2]
            if overlap:
                overlap_found = True
                print(f"WARNING: Found {len(overlap)} overlapping files between {split1} and {split2}:")
                for name in sorted(overlap):
                    print(f"  - {name}")
            else:
                print(f"✓ No overlap between {split1} and {split2}")

    if not overlap_found:
        print("✓ No overlapping files found between any splits - data integrity verified!")
    else:
        print(f"\n⚠️  WARNING: Overlapping files detected! Please review the splitting logic.")

    print(f"\nFinal split summary:")
    for split in _SPLIT_NAMES:
        print(f"  {split}: {len(listing[split])} files")


def split_data(input_folder, output_dir, data_type, flare_events_csv=None, repartition=False,
               train_start=None, train_end=None, val_start=None, val_end=None,
               test_start=None, test_end=None, use_buffer_strategy=False, copy_files=False):
    """
    Split data from input folder into train/val/test based on custom date ranges or month.

    Each entry's timestamp is taken from its file name (the text before the
    first '.'). The split is chosen, in order of precedence, by:

    1. the predefined buffer-strategy periods (``use_buffer_strategy=True``),
    2. the custom ``*_start`` / ``*_end`` ranges, if any are given,
    3. the month of the timestamp (Jan-Mar -> val, Aug -> test, otherwise train).

    Entries whose name is not a parsable timestamp, that match no range, or
    whose destination already exists are skipped and counted as such.

    Args:
        input_folder (str): Path to the input folder containing data files
        output_dir (str): Path to the output directory where split data will be saved
        data_type (str): Type of data ('aia' or 'sxr'); only validated and echoed in the log
        flare_events_csv (str, optional): Path to the flare events CSV file.
            Accepted for interface compatibility; this function never reads it.
        repartition (bool): If True, treat input_folder as already partitioned (has train/val/test subdirs)
        train_start (str, optional): Start date for training data (format: 'YYYY-MM-DD')
        train_end (str, optional): End date for training data (format: 'YYYY-MM-DD')
        val_start (str, optional): Start date for validation data (format: 'YYYY-MM-DD')
        val_end (str, optional): End date for validation data (format: 'YYYY-MM-DD')
        test_start (str, optional): Start date for test data (format: 'YYYY-MM-DD')
        test_end (str, optional): End date for test data (format: 'YYYY-MM-DD')
        use_buffer_strategy (bool): If True, use predefined buffer strategy with specific date ranges
            (custom date arguments are then ignored)
        copy_files (bool): If True, copy files instead of moving them (default: False)

    Raises:
        ValueError: If input_folder is not an existing directory, data_type is
            not 'aia'/'sxr', a date cannot be parsed, a range has only one of
            start/end, a range starts after it ends, or two ranges overlap.
    """
    if not os.path.isdir(input_folder):
        raise ValueError(f"Input folder does not exist: {input_folder}")

    if data_type.lower() not in ['aia', 'sxr']:
        raise ValueError("data_type must be 'aia' or 'sxr'")

    custom_args = [train_start, train_end, val_start, val_end, test_start, test_end]

    # ``date_windows`` maps split -> list of inclusive (start, end) periods;
    # None selects the month-based fallback.
    date_windows = None
    if use_buffer_strategy:
        print("Using predefined buffer strategy...")
        if any(custom_args):
            print("Note: custom date ranges are ignored when the buffer strategy is enabled.")
        date_windows = {
            split: [_inclusive_day_range(start, end) for start, end in periods]
            for split, periods in _BUFFER_STRATEGY_DATES.items()
        }
        print(date_windows)
    elif any(custom_args):
        print("Using custom date ranges for data splitting...")
        date_windows = _custom_date_windows([
            ('train', 'Training', train_start, train_end),
            ('val', 'Validation', val_start, val_end),
            ('test', 'Test', test_start, test_end),
        ])
    else:
        print("Using default month-based splitting...")

    os.makedirs(output_dir, exist_ok=True)
    for split in _SPLIT_NAMES:
        os.makedirs(os.path.join(output_dir, split), exist_ok=True)

    print(f"Processing {data_type.upper()} data from: {input_folder}")
    print(f"Output directory: {output_dir}")

    # Collect (file name, split it currently lives in or None) pairs.
    entries = []
    if repartition:
        for split in _SPLIT_NAMES:
            split_dir = os.path.join(input_folder, split)
            if os.path.exists(split_dir):
                split_files = sorted(os.listdir(split_dir))
                entries.extend((name, split) for name in split_files)
                print(f"Found {len(split_files)} files in {split}/ directory")
        print(f"Total files to repartition: {len(entries)}")
    else:
        entries = [(name, None) for name in sorted(os.listdir(input_folder))]
        print(f"Found {len(entries)} files to process")

    moved_count = 0
    skipped_count = 0
    verb, gerund = ("copy", "copying") if copy_files else ("move", "moving")

    for file, original_split in entries:
        try:
            file_time = pd.to_datetime(_normalize_timestamp(file.split(".")[0]))
        except (ValueError, TypeError, OverflowError):
            file_time = pd.NaT
        if pd.isna(file_time):
            # Also covers names with an empty stem (e.g. ".DS_Store"), which
            # parse to NaT instead of raising.
            print(f"Skipping file {file}: Invalid timestamp format")
            skipped_count += 1
            continue

        if date_windows is not None:
            new_split_dir = _match_window(file_time, date_windows)
            if new_split_dir is None:
                print(f"Skipping file {file}: No matching date range (file time: {file_time.strftime('%Y-%m-%d')})")
                skipped_count += 1
                continue
        else:
            new_split_dir = _MONTH_TO_SPLIT[file_time.month]

        print(f"File {file} assigned to split {new_split_dir}")

        source_dir = os.path.join(input_folder, original_split) if repartition else input_folder
        src_path = os.path.join(source_dir, file)
        dst_path = os.path.join(output_dir, new_split_dir, file)

        if os.path.exists(dst_path):
            if repartition and original_split == new_split_dir:
                print(f"File {file} already in correct split ({new_split_dir}), skipping.")
            else:
                print(f"File {dst_path} already exists, skipping {verb}.")
            skipped_count += 1
            continue

        try:
            if copy_files:
                shutil.copy2(src_path, dst_path)
            else:
                shutil.move(src_path, dst_path)
        except OSError as e:
            # Best effort: report the failure and carry on with the next file.
            print(f"Error {gerund} {file}: {e}")
            skipped_count += 1
        else:
            moved_count += 1

    action = "copied" if copy_files else "moved"
    print(f"\nProcessing complete!")
    print(f"Files {action}: {moved_count}")
    print(f"Files skipped: {skipped_count}")
    print(f"Total files processed: {moved_count + skipped_count}")

    _report_split_integrity(output_dir)
|
|
def main():
    """Command-line entry point: parse the options and hand them to ``split_data``.

    Input, output and flare-CSV paths are made absolute before the call. With
    ``--repartition``, missing ``train``/``val``/``test`` sub-directories of
    the input folder are only reported as a warning; processing continues.
    """
    cli = argparse.ArgumentParser(description='Split AIA or SXR data into train/val/test sets based on custom date ranges or month')
    cli.add_argument('--input_folder', type=str, required=True,
                     help='Path to the input folder containing data files (or partitioned folder for repartition)')
    cli.add_argument('--output_dir', type=str, required=True,
                     help='Path to the output directory where split data will be saved')
    cli.add_argument('--data_type', type=str, choices=['aia', 'sxr'], required=True,
                     help='Type of data: "aia" or "sxr"')
    cli.add_argument('--flare_events_csv', type=str, default=None,
                     help='Path to the flare events CSV file (optional)')
    cli.add_argument('--repartition', action='store_true',
                     help='Repartition an already partitioned folder (input_folder should have train/val/test subdirs)')

    # One --<split>_start / --<split>_end pair per split, registered in the
    # order train, val, test so the generated help keeps its layout.
    for prefix, noun in (('train', 'training'), ('val', 'validation'), ('test', 'test')):
        for edge, word in (('start', 'Start'), ('end', 'End')):
            cli.add_argument(f'--{prefix}_{edge}', type=str, default=None,
                             help=f'{word} date for {noun} data (format: YYYY-MM-DD)')

    cli.add_argument('--use_buffer_strategy', action='store_true',
                     help='Use predefined buffer strategy with specific date ranges and buffer zones')
    cli.add_argument('--copy_files', action='store_true',
                     help='Copy files instead of moving them (keeps original files intact)')

    opts = cli.parse_args()

    source_root = os.path.abspath(opts.input_folder)
    target_root = os.path.abspath(opts.output_dir)
    flare_csv = None
    if opts.flare_events_csv:
        flare_csv = os.path.abspath(opts.flare_events_csv)

    if opts.repartition:
        absent = [
            name for name in ('train', 'val', 'test')
            if not os.path.exists(os.path.join(source_root, name))
        ]
        if absent:
            print(f"Warning: Input folder is missing expected subdirectories: {absent}")
            print("Continuing with available directories...")

    split_data(
        source_root,
        target_root,
        opts.data_type,
        flare_events_csv=flare_csv,
        repartition=opts.repartition,
        train_start=opts.train_start,
        train_end=opts.train_end,
        val_start=opts.val_start,
        val_end=opts.val_end,
        test_start=opts.test_start,
        test_end=opts.test_end,
        use_buffer_strategy=opts.use_buffer_strategy,
        copy_files=opts.copy_files,
    )
|
|
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|