File size: 16,050 Bytes
d39cef0 7bbc799 d39cef0 7bbc799 d39cef0 e522a44 0c7d0f3 7bbc799 0c7d0f3 7bbc799 0c7d0f3 7bbc799 0c7d0f3 a2506f2 0c7d0f3 a2506f2 0c7d0f3 7bbc799 9d017cf 7bbc799 6a114fd 7bbc799 0c7d0f3 7bbc799 e522a44 7bbc799 0c7d0f3 7bbc799 0c7d0f3 7bbc799 0c7d0f3 7bbc799 0c7d0f3 7bbc799 0c7d0f3 7bbc799 0c7d0f3 7bbc799 0c7d0f3 7bbc799 0c7d0f3 7bbc799 0c7d0f3 7bbc799 0c7d0f3 7bbc799 0c7d0f3 7bbc799 0c7d0f3 7bbc799 0c7d0f3 6a114fd 7bbc799 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 | import os
import pandas as pd
import shutil
import argparse
from datetime import datetime
from pathlib import Path
def _normalize_timestamp(ts: str) -> str:
"""Normalize timestamp strings with underscores instead of colons (cross-platform filenames)."""
if 'T' in ts:
date_part, time_part = ts.split('T', 1)
return f"{date_part}T{time_part.replace('_', ':')}"
return ts
def _day_start(date_str):
    """Parse *date_str* ('YYYY-MM-DD') as a datetime at 00:00:00.000000."""
    return pd.to_datetime(date_str).replace(hour=0, minute=0, second=0, microsecond=0)


def _day_end(date_str):
    """Parse *date_str* ('YYYY-MM-DD') as a datetime at 23:59:59.999999."""
    return pd.to_datetime(date_str).replace(hour=23, minute=59, second=59, microsecond=999999)


def _buffer_strategy_ranges():
    """Return the predefined buffer-strategy date ranges.

    Each split maps to a list of (start, end) inclusive datetime periods;
    gaps between periods (e.g. 2023-07-21) act as buffer zones.
    """
    return {
        'train': [
            (_day_start("2012-01-01"), _day_end("2022-12-31")),
            (_day_start("2023-07-01"), _day_end("2023-07-20")),
        ],
        'val': [
            (_day_start("2023-01-01"), _day_end("2023-06-30")),
            (_day_start("2023-07-22"), _day_end("2023-07-30")),
        ],
        'test': [
            (_day_start("2023-08-01"), _day_end("2025-09-30")),
        ],
    }


def _verify_splits(output_dir):
    """Check for files appearing in more than one split and print a summary.

    Reads back the train/val/test directories under *output_dir*, warns on
    any pairwise filename overlap, and prints a per-split file count.
    """
    print(f"\nChecking for overlapping files between splits...")
    overlap_found = False
    split_files = {}
    for split in ["train", "val", "test"]:
        split_dir = os.path.join(output_dir, split)
        split_files[split] = set(os.listdir(split_dir)) if os.path.exists(split_dir) else set()
    # Compare every unordered pair of splits exactly once.
    splits = list(split_files.keys())
    for i in range(len(splits)):
        for j in range(i + 1, len(splits)):
            split1, split2 = splits[i], splits[j]
            overlap = split_files[split1] & split_files[split2]
            if overlap:
                overlap_found = True
                print(f"WARNING: Found {len(overlap)} overlapping files between {split1} and {split2}:")
                for file in sorted(overlap):
                    print(f"  - {file}")
            else:
                print(f"✓ No overlap between {split1} and {split2}")
    if not overlap_found:
        print("✓ No overlapping files found between any splits - data integrity verified!")
    else:
        print(f"\n⚠️ WARNING: Overlapping files detected! Please review the splitting logic.")
    print(f"\nFinal split summary:")
    for split in ["train", "val", "test"]:
        file_count = len(split_files[split])
        print(f"  {split}: {file_count} files")


def split_data(input_folder, output_dir, data_type, flare_events_csv=None, repartition=False,
               train_start=None, train_end=None, val_start=None, val_end=None,
               test_start=None, test_end=None, use_buffer_strategy=False, copy_files=False):
    """
    Split data from input folder into train/val/test based on custom date ranges or month.

    Files are expected to be named after their timestamp, e.g.
    "2012-01-01T00:00:00.npy" (underscores may replace colons for
    filesystem safety). Files whose name cannot be parsed as a timestamp
    are skipped with a message.

    Args:
        input_folder (str): Path to the input folder containing data files
        output_dir (str): Path to the output directory where split data will be saved
        data_type (str): Type of data ('aia' or 'sxr')
        flare_events_csv (str, optional): Path to the flare events CSV file.
            NOTE(review): accepted for interface compatibility but currently unused.
        repartition (bool): If True, treat input_folder as already partitioned (has train/val/test subdirs)
        train_start (str, optional): Start date for training data (format: 'YYYY-MM-DD')
        train_end (str, optional): End date for training data (format: 'YYYY-MM-DD')
        val_start (str, optional): Start date for validation data (format: 'YYYY-MM-DD')
        val_end (str, optional): End date for validation data (format: 'YYYY-MM-DD')
        test_start (str, optional): Start date for test data (format: 'YYYY-MM-DD')
        test_end (str, optional): End date for test data (format: 'YYYY-MM-DD')
        use_buffer_strategy (bool): If True, use predefined buffer strategy with specific date ranges
        copy_files (bool): If True, copy files instead of moving them (default: False)

    Raises:
        ValueError: If the input folder is missing, data_type is invalid,
            a custom date string cannot be parsed, or custom ranges overlap.
    """
    if not os.path.exists(input_folder):
        raise ValueError(f"Input folder does not exist: {input_folder}")
    if data_type.lower() not in ['aia', 'sxr']:
        raise ValueError("data_type must be 'aia' or 'sxr'")

    # Decide the split-assignment policy: buffer strategy > custom ranges > month-based.
    date_ranges = {}
    custom_dates = False
    buffer_strategy_ranges = None
    if use_buffer_strategy:
        print("Using predefined buffer strategy...")
        buffer_strategy_ranges = _buffer_strategy_ranges()
        print(buffer_strategy_ranges)
    elif any([train_start, train_end, val_start, val_end, test_start, test_end]):
        custom_dates = True
        print("Using custom date ranges for data splitting...")
        # A split is only configured when BOTH its start and end are given.
        try:
            if train_start and train_end:
                date_ranges['train'] = (_day_start(train_start), _day_end(train_end))
                print(f"Training data: {train_start} 00:00:00 to {train_end} 23:59:59")
            if val_start and val_end:
                date_ranges['val'] = (_day_start(val_start), _day_end(val_end))
                print(f"Validation data: {val_start} 00:00:00 to {val_end} 23:59:59")
            if test_start and test_end:
                date_ranges['test'] = (_day_start(test_start), _day_end(test_end))
                print(f"Test data: {test_start} 00:00:00 to {test_end} 23:59:59")
        except Exception as e:
            raise ValueError(f"Invalid date format. Please use 'YYYY-MM-DD' format. Error: {e}")
        # Two closed intervals overlap unless one ends strictly before the other starts.
        ranges = list(date_ranges.values())
        for i in range(len(ranges)):
            for j in range(i + 1, len(ranges)):
                range1, range2 = ranges[i], ranges[j]
                if not (range1[1] < range2[0] or range2[1] < range1[0]):
                    raise ValueError(f"Date ranges cannot overlap. Found overlap between ranges: {range1[0].strftime('%Y-%m-%d')} to {range1[1].strftime('%Y-%m-%d')} and {range2[0].strftime('%Y-%m-%d')} to {range2[1].strftime('%Y-%m-%d')}")
    else:
        print("Using default month-based splitting...")

    # Create output directory structure.
    os.makedirs(output_dir, exist_ok=True)
    for split in ["train", "val", "test"]:
        os.makedirs(os.path.join(output_dir, split), exist_ok=True)
    print(f"Processing {data_type.upper()} data from: {input_folder}")
    print(f"Output directory: {output_dir}")

    # Collect the files to process.
    if repartition:
        # For repartitioning, collect files from train/val/test subdirectories,
        # remembering each file's current split so unchanged files can be skipped.
        data_list = []
        for split in ["train", "val", "test"]:
            split_dir = os.path.join(input_folder, split)
            if os.path.exists(split_dir):
                split_files = os.listdir(split_dir)
                for file in split_files:
                    data_list.append((file, split))
                print(f"Found {len(split_files)} files in {split}/ directory")
        print(f"Total files to repartition: {len(data_list)}")
    else:
        data_list = os.listdir(input_folder)
        print(f"Found {len(data_list)} files to process")

    moved_count = 0
    skipped_count = 0
    for file_info in data_list:
        if repartition:
            file, original_split = file_info
        else:
            file = file_info
            original_split = None

        # Extract the timestamp from the filename stem; underscores in the
        # time part are normalized back to colons before parsing.
        try:
            stem = file.split(".")[0]
            if 'T' in stem:
                date_part, time_part = stem.split('T', 1)
                stem = f"{date_part}T{time_part.replace('_', ':')}"
            file_time = pd.to_datetime(stem)
        except (ValueError, TypeError):
            # pd.to_datetime raises TypeError for some non-string-like inputs,
            # not just ValueError.
            print(f"Skipping file {file}: Invalid timestamp format")
            skipped_count += 1
            continue

        # Assign the file to a split according to the active policy.
        new_split_dir = None
        if buffer_strategy_ranges:
            for split_name, periods in buffer_strategy_ranges.items():
                if any(start <= file_time <= end for start, end in periods):
                    new_split_dir = split_name
                    break
        elif custom_dates:
            for split_name, (start_date, end_date) in date_ranges.items():
                if start_date <= file_time <= end_date:
                    new_split_dir = split_name
                    break
            print(f"File {file} assigned to split {new_split_dir}")
        if new_split_dir is None:
            if custom_dates or buffer_strategy_ranges:
                print(f"Skipping file {file}: No matching date range (file time: {file_time.strftime('%Y-%m-%d')})")
            else:
                # Default month-based splitting: Aug -> test, Jan-Mar -> val,
                # every other month -> train.
                month = file_time.month
                if month in [4, 5, 6, 7, 9, 10, 11, 12]:
                    new_split_dir = "train"
                elif month in [1, 2, 3]:
                    new_split_dir = "val"
                elif month == 8:
                    new_split_dir = "test"
                else:
                    print(f"Skipping file {file}: Unexpected month {month}")
                    skipped_count += 1
                    continue
        if new_split_dir is None:
            skipped_count += 1
            continue

        # Resolve source and destination paths.
        if repartition:
            src_path = os.path.join(input_folder, original_split, file)
        else:
            src_path = os.path.join(input_folder, file)
        dst_path = os.path.join(output_dir, new_split_dir, file)

        # When repartitioning, leave files that are already in the right place.
        if repartition and original_split == new_split_dir and os.path.exists(dst_path):
            print(f"File {file} already in correct split ({new_split_dir}), skipping.")
            skipped_count += 1
            continue

        if not os.path.exists(dst_path):
            # BUGFIX: the verb must be bound BEFORE the transfer is attempted;
            # previously it was assigned only after a successful copy/move, so
            # the except handler raised NameError on any transfer failure.
            verb = "copying" if copy_files else "moving"
            try:
                if copy_files:
                    shutil.copy2(src_path, dst_path)
                else:
                    shutil.move(src_path, dst_path)
                moved_count += 1
            except Exception as e:
                print(f"Error {verb} {file}: {e}")
                skipped_count += 1
        else:
            action = "copy" if copy_files else "move"
            print(f"File {dst_path} already exists, skipping {action}.")
            skipped_count += 1

    action = "copied" if copy_files else "moved"
    print(f"\nProcessing complete!")
    print(f"Files {action}: {moved_count}")
    print(f"Files skipped: {skipped_count}")
    print(f"Total files processed: {moved_count + skipped_count}")

    # Post-hoc integrity check and summary of the output directories.
    _verify_splits(output_dir)
def main():
    """Parse command-line arguments and run split_data with them."""
    parser = argparse.ArgumentParser(description='Split AIA or SXR data into train/val/test sets based on custom date ranges or month')
    parser.add_argument('--input_folder', type=str, required=True,
                        help='Path to the input folder containing data files (or partitioned folder for repartition)')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Path to the output directory where split data will be saved')
    parser.add_argument('--data_type', type=str, choices=['aia', 'sxr'], required=True,
                        help='Type of data: "aia" or "sxr"')
    parser.add_argument('--flare_events_csv', type=str, default=None,
                        help='Path to the flare events CSV file (optional)')
    parser.add_argument('--repartition', action='store_true',
                        help='Repartition an already partitioned folder (input_folder should have train/val/test subdirs)')
    # Custom date range arguments (one start/end pair per split).
    for flag, description in (
        ('--train_start', 'Start date for training data (format: YYYY-MM-DD)'),
        ('--train_end', 'End date for training data (format: YYYY-MM-DD)'),
        ('--val_start', 'Start date for validation data (format: YYYY-MM-DD)'),
        ('--val_end', 'End date for validation data (format: YYYY-MM-DD)'),
        ('--test_start', 'Start date for test data (format: YYYY-MM-DD)'),
        ('--test_end', 'End date for test data (format: YYYY-MM-DD)'),
    ):
        parser.add_argument(flag, type=str, default=None, help=description)
    parser.add_argument('--use_buffer_strategy', action='store_true',
                        help='Use predefined buffer strategy with specific date ranges and buffer zones')
    parser.add_argument('--copy_files', action='store_true',
                        help='Copy files instead of moving them (keeps original files intact)')
    args = parser.parse_args()

    # Resolve all paths to absolute form before handing them off.
    input_folder = os.path.abspath(args.input_folder)
    output_dir = os.path.abspath(args.output_dir)
    flare_events_csv = os.path.abspath(args.flare_events_csv) if args.flare_events_csv else None

    # In repartition mode, warn (but do not fail) when any of the expected
    # train/val/test subdirectories is absent.
    if args.repartition:
        missing_dirs = [name for name in ['train', 'val', 'test']
                        if not os.path.exists(os.path.join(input_folder, name))]
        if missing_dirs:
            print(f"Warning: Input folder is missing expected subdirectories: {missing_dirs}")
            print("Continuing with available directories...")

    split_data(input_folder, output_dir, args.data_type, flare_events_csv, args.repartition,
               args.train_start, args.train_end, args.val_start, args.val_end,
               args.test_start, args.test_end, args.use_buffer_strategy, args.copy_files)


if __name__ == "__main__":
    main()
|