File size: 16,050 Bytes
d39cef0
 
 
7bbc799
d39cef0
7bbc799
d39cef0
e522a44
 
 
 
 
 
 
 
0c7d0f3
 
 
7bbc799
0c7d0f3
7bbc799
 
 
 
 
 
 
 
0c7d0f3
 
 
 
 
 
 
 
7bbc799
 
 
 
 
 
 
 
 
 
0c7d0f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2506f2
0c7d0f3
 
 
 
a2506f2
0c7d0f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bbc799
 
9d017cf
7bbc799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a114fd
7bbc799
 
 
 
 
 
0c7d0f3
7bbc799
 
 
 
 
 
 
 
 
e522a44
7bbc799
 
 
 
 
0c7d0f3
 
7bbc799
0c7d0f3
 
 
 
 
 
 
 
 
 
 
 
 
 
7bbc799
0c7d0f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bbc799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c7d0f3
 
 
7bbc799
0c7d0f3
 
7bbc799
 
0c7d0f3
7bbc799
 
0c7d0f3
 
7bbc799
 
0c7d0f3
7bbc799
0c7d0f3
7bbc799
 
0c7d0f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bbc799
 
0c7d0f3
7bbc799
 
 
 
 
 
 
 
 
 
 
0c7d0f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bbc799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c7d0f3
 
 
6a114fd
7bbc799
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
import os
import pandas as pd
import shutil
import argparse
from datetime import datetime
from pathlib import Path


def _normalize_timestamp(ts: str) -> str:
    """Normalize timestamp strings with underscores instead of colons (cross-platform filenames)."""
    if 'T' in ts:
        date_part, time_part = ts.split('T', 1)
        return f"{date_part}T{time_part.replace('_', ':')}"
    return ts

def _day_bounds(start, end):
    """Expand an inclusive (start, end) date pair into full-day Timestamp bounds.

    The start is pinned to 00:00:00.000000 and the end to 23:59:59.999999 so
    that any timestamp falling on either boundary day lies inside the range.
    """
    start_ts = pd.to_datetime(start).replace(hour=0, minute=0, second=0, microsecond=0)
    end_ts = pd.to_datetime(end).replace(hour=23, minute=59, second=59, microsecond=999999)
    return start_ts, end_ts


def _find_split(file_time, split_periods):
    """Return the first split name owning a period that contains file_time, else None."""
    for split_name, periods in split_periods.items():
        if any(start <= file_time <= end for start, end in periods):
            return split_name
    return None


def _month_split(month):
    """Map a calendar month to the default split (Jan-Mar: val, Aug: test, rest: train)."""
    if month in (4, 5, 6, 7, 9, 10, 11, 12):
        return "train"
    if month in (1, 2, 3):
        return "val"
    if month == 8:
        return "test"
    return None


def split_data(input_folder, output_dir, data_type, flare_events_csv=None, repartition=False, 
               train_start=None, train_end=None, val_start=None, val_end=None, 
               test_start=None, test_end=None, use_buffer_strategy=False, copy_files=False):
    """
    Split data from input folder into train/val/test based on custom date ranges or month.

    Each file's timestamp is parsed from the part of its name before the first
    '.', and the file is moved (or copied) into ``output_dir/<split>/``. The
    split is chosen by, in order of precedence:

    1. the predefined buffer strategy (``use_buffer_strategy=True``). Dates
       not covered by any of its periods (e.g. 2023-07-21 and 2023-07-31)
       are skipped;
    2. custom date ranges, if any ``*_start`` / ``*_end`` argument is given.
       A split is only defined when both its start and end are supplied;
    3. the calendar month: Jan-Mar -> val, Aug -> test, all others -> train.

    Progress, a summary, and a check for filenames present in more than one
    output split are printed to stdout.

    Args:
        input_folder (str): Path to the input folder containing data files
        output_dir (str): Path to the output directory where split data will be saved
        data_type (str): Type of data ('aia' or 'sxr'); only validated and echoed in the log
        flare_events_csv (str, optional): Path to the flare events CSV file
            (accepted for interface compatibility; not used by this function)
        repartition (bool): If True, treat input_folder as already partitioned (has train/val/test subdirs)
        train_start (str, optional): Start date for training data (format: 'YYYY-MM-DD')
        train_end (str, optional): End date for training data (format: 'YYYY-MM-DD')
        val_start (str, optional): Start date for validation data (format: 'YYYY-MM-DD')
        val_end (str, optional): End date for validation data (format: 'YYYY-MM-DD')
        test_start (str, optional): Start date for test data (format: 'YYYY-MM-DD')
        test_end (str, optional): End date for test data (format: 'YYYY-MM-DD')
        use_buffer_strategy (bool): If True, use predefined buffer strategy with specific date ranges
        copy_files (bool): If True, copy files instead of moving them (default: False)

    Raises:
        ValueError: If input_folder does not exist, data_type is not 'aia' or
            'sxr', a custom date cannot be parsed, a range starts after it
            ends, or two custom ranges overlap.
    """

    # Validate input folder
    if not os.path.exists(input_folder):
        raise ValueError(f"Input folder does not exist: {input_folder}")

    # Validate data type
    if data_type.lower() not in ['aia', 'sxr']:
        raise ValueError("data_type must be 'aia' or 'sxr'")

    split_names = ["train", "val", "test"]

    # split name -> list of (start, end) periods; None selects month-based splitting.
    split_periods = None

    if use_buffer_strategy:
        print("Using predefined buffer strategy...")

        # Define the buffer strategy date ranges (multiple periods per split)
        split_periods = {
            'train': [
                _day_bounds("2012-01-01", "2022-12-31"),
                _day_bounds("2023-07-01", "2023-07-20"),
            ],
            'val': [
                _day_bounds("2023-01-01", "2023-06-30"),
                _day_bounds("2023-07-22", "2023-07-30"),
            ],
            'test': [
                _day_bounds("2023-08-01", "2025-09-30"),
            ],
        }
        # Print the buffer strategy date ranges
        print(split_periods)
    elif any([train_start, train_end, val_start, val_end, test_start, test_end]):
        print("Using custom date ranges for data splitting...")

        requested = {
            'train': ("Training", train_start, train_end),
            'val': ("Validation", val_start, val_end),
            'test': ("Test", test_start, test_end),
        }
        date_ranges = {}
        for split_name, (label, start, end) in requested.items():
            if not (start and end):
                if start or end:
                    # A half-specified range used to be dropped silently.
                    print(f"Warning: ignoring incomplete {split_name} date range "
                          f"(both start and end are required)")
                continue
            try:
                start_dt, end_dt = _day_bounds(start, end)
                if pd.isna(start_dt) or pd.isna(end_dt):
                    raise ValueError(f"could not parse {start!r} / {end!r} as dates")
            except Exception as e:
                # Any parsing failure is reported as ValueError, the type callers already expect.
                raise ValueError(f"Invalid date format. Please use 'YYYY-MM-DD' format. Error: {e}") from e
            if start_dt > end_dt:
                raise ValueError(f"Invalid {split_name} date range: start {start} is after end {end}")
            date_ranges[split_name] = (start_dt, end_dt)
            print(f"{label} data: {start} 00:00:00 to {end} 23:59:59")

        # Validate that date ranges don't overlap
        ranges = list(date_ranges.values())
        for i, range1 in enumerate(ranges):
            for range2 in ranges[i + 1:]:
                if not (range1[1] < range2[0] or range2[1] < range1[0]):
                    raise ValueError(f"Date ranges cannot overlap. Found overlap between ranges: {range1[0].strftime('%Y-%m-%d')} to {range1[1].strftime('%Y-%m-%d')} and {range2[0].strftime('%Y-%m-%d')} to {range2[1].strftime('%Y-%m-%d')}")

        split_periods = {name: [bounds] for name, bounds in date_ranges.items()}
    else:
        print("Using default month-based splitting...")

    date_based = split_periods is not None

    # Create output directory structure
    os.makedirs(output_dir, exist_ok=True)
    for split in split_names:
        os.makedirs(os.path.join(output_dir, split), exist_ok=True)

    print(f"Processing {data_type.upper()} data from: {input_folder}")
    print(f"Output directory: {output_dir}")

    # Get list of files in input folder
    if repartition:
        # For repartitioning, collect (file, current split) pairs from the
        # train/val/test subdirectories that exist.
        data_list = []
        for split in split_names:
            split_dir = os.path.join(input_folder, split)
            if os.path.exists(split_dir):
                found = os.listdir(split_dir)
                data_list.extend((name, split) for name in found)
                print(f"Found {len(found)} files in {split}/ directory")
        print(f"Total files to repartition: {len(data_list)}")
    else:
        # For normal splitting, get files directly from input folder
        data_list = [(name, None) for name in os.listdir(input_folder)]
        print(f"Found {len(data_list)} files to process")

    # Decide the operation once so the error path never depends on a
    # variable that is only assigned after a successful transfer.
    transfer = shutil.copy2 if copy_files else shutil.move
    transfer_noun = "copy" if copy_files else "move"
    transfer_gerund = "copying" if copy_files else "moving"

    moved_count = 0
    skipped_count = 0

    for file, original_split in data_list:
        try:
            # Extract timestamp from filename (assuming format like "2012-01-01T00:00:00.npy")
            file_time = pd.to_datetime(_normalize_timestamp(file.split(".")[0]))
        except (ValueError, TypeError, OverflowError):
            file_time = pd.NaT

        # An empty stem (e.g. a dotfile) parses to NaT instead of raising;
        # treat it like any other unparseable name.
        if pd.isna(file_time):
            print(f"Skipping file {file}: Invalid timestamp format")
            skipped_count += 1
            continue

        # Determine split based on date ranges or month
        if date_based:
            new_split_dir = _find_split(file_time, split_periods)
        else:
            new_split_dir = _month_split(file_time.month)

        if new_split_dir is None:
            if date_based:
                print(f"Skipping file {file}: No matching date range (file time: {file_time.strftime('%Y-%m-%d')})")
            else:
                print(f"Skipping file {file}: Unexpected month {file_time.month}")
            skipped_count += 1
            continue

        print(f"File {file} assigned to split {new_split_dir}")

        # Determine source and destination paths
        if repartition:
            src_path = os.path.join(input_folder, original_split, file)
        else:
            src_path = os.path.join(input_folder, file)
        dst_path = os.path.join(output_dir, new_split_dir, file)

        destination_exists = os.path.exists(dst_path)

        # Skip if file is already in the correct split and we're repartitioning
        if repartition and original_split == new_split_dir and destination_exists:
            print(f"File {file} already in correct split ({new_split_dir}), skipping.")
            skipped_count += 1
            continue

        # Never overwrite an existing destination file.
        if destination_exists:
            print(f"File {dst_path} already exists, skipping {transfer_noun}.")
            skipped_count += 1
            continue

        try:
            transfer(src_path, dst_path)
        except OSError as e:
            # shutil.Error derives from OSError; one failed file must not abort the run.
            print(f"Error {transfer_gerund} {file}: {e}")
            skipped_count += 1
        else:
            moved_count += 1

    action = "copied" if copy_files else "moved"
    print(f"\nProcessing complete!")
    print(f"Files {action}: {moved_count}")
    print(f"Files skipped: {skipped_count}")
    print(f"Total files processed: {moved_count + skipped_count}")

    # Check for overlapping files between splits
    print(f"\nChecking for overlapping files between splits...")
    overlap_found = False

    # Get all files in each split directory
    split_files = {}
    for split in split_names:
        split_dir = os.path.join(output_dir, split)
        split_files[split] = set(os.listdir(split_dir)) if os.path.exists(split_dir) else set()

    # Check for overlaps between each pair of splits
    for i, split1 in enumerate(split_names):
        for split2 in split_names[i + 1:]:
            overlap = split_files[split1] & split_files[split2]

            if overlap:
                overlap_found = True
                print(f"WARNING: Found {len(overlap)} overlapping files between {split1} and {split2}:")
                for name in sorted(overlap):
                    print(f"  - {name}")
            else:
                print(f"✓ No overlap between {split1} and {split2}")

    if not overlap_found:
        print("✓ No overlapping files found between any splits - data integrity verified!")
    else:
        print(f"\n⚠️  WARNING: Overlapping files detected! Please review the splitting logic.")

    # Summary of files per split
    print(f"\nFinal split summary:")
    for split in split_names:
        print(f"  {split}: {len(split_files[split])} files")

def main():
    """Command-line entry point: parse the arguments and run split_data.

    Paths are converted to absolute form before being handed on. In
    repartition mode a warning is printed for each missing train/val/test
    subdirectory of the input folder, but processing continues regardless.
    """
    cli = argparse.ArgumentParser(
        description='Split AIA or SXR data into train/val/test sets based on custom date ranges or month')
    cli.add_argument('--input_folder', type=str, required=True,
                     help='Path to the input folder containing data files (or partitioned folder for repartition)')
    cli.add_argument('--output_dir', type=str, required=True,
                     help='Path to the output directory where split data will be saved')
    cli.add_argument('--data_type', type=str, choices=['aia', 'sxr'], required=True,
                     help='Type of data: "aia" or "sxr"')
    cli.add_argument('--flare_events_csv', type=str, default=None,
                     help='Path to the flare events CSV file (optional)')
    cli.add_argument('--repartition', action='store_true',
                     help='Repartition an already partitioned folder (input_folder should have train/val/test subdirs)')

    # Custom date range arguments: --<split>_start / --<split>_end for each
    # split, registered in train, val, test order.
    split_labels = (('train', 'training'), ('val', 'validation'), ('test', 'test'))
    edge_labels = (('start', 'Start'), ('end', 'End'))
    for split_key, split_label in split_labels:
        for edge_key, edge_label in edge_labels:
            cli.add_argument(f'--{split_key}_{edge_key}', type=str, default=None,
                             help=f'{edge_label} date for {split_label} data (format: YYYY-MM-DD)')

    cli.add_argument('--use_buffer_strategy', action='store_true',
                     help='Use predefined buffer strategy with specific date ranges and buffer zones')
    cli.add_argument('--copy_files', action='store_true',
                     help='Copy files instead of moving them (keeps original files intact)')

    opts = cli.parse_args()

    # Convert to absolute paths
    source_root = os.path.abspath(opts.input_folder)
    target_root = os.path.abspath(opts.output_dir)
    flare_csv = None
    if opts.flare_events_csv:
        flare_csv = os.path.abspath(opts.flare_events_csv)

    # Repartition mode expects train/val/test subdirectories; warn about gaps.
    if opts.repartition:
        absent = [
            name for name in ('train', 'val', 'test')
            if not os.path.exists(os.path.join(source_root, name))
        ]
        if absent:
            print(f"Warning: Input folder is missing expected subdirectories: {absent}")
            print("Continuing with available directories...")

    split_data(
        input_folder=source_root,
        output_dir=target_root,
        data_type=opts.data_type,
        flare_events_csv=flare_csv,
        repartition=opts.repartition,
        train_start=opts.train_start,
        train_end=opts.train_end,
        val_start=opts.val_start,
        val_end=opts.val_end,
        test_start=opts.test_start,
        test_end=opts.test_end,
        use_buffer_strategy=opts.use_buffer_strategy,
        copy_files=opts.copy_files,
    )

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()