griffingoodwin04 commited on
Commit
ec2b4e7
·
1 Parent(s): 10ad6fc

Refactor pipeline configuration and update data processing scripts

Browse files
.gitignore CHANGED
@@ -154,4 +154,4 @@ wandb/
154
  *.code-workspace
155
 
156
  .claude/
157
- misc/
 
154
  *.code-workspace
155
 
156
  .claude/
157
+ Untracked/
.dockerignore → Untracked/.dockerignore RENAMED
@@ -24,7 +24,7 @@ venv.bak/
24
 
25
  # IDE
26
  .vscode/
27
- .idea/
28
  *.swp
29
  *.swo
30
  *~
 
24
 
25
  # IDE
26
  .vscode/
27
+ ../.idea/
28
  *.swp
29
  *.swo
30
  *~
QUICKSTART_DOCKER.md → Untracked/QUICKSTART_DOCKER.md RENAMED
File without changes
data/align_data.py CHANGED
@@ -38,9 +38,8 @@ def load_config():
38
  'alignment': {
39
  'goes_data_dir': "/mnt/data/PAPER/GOES-timespan/combined",
40
  'aia_processed_dir': "/mnt/data/PAPER/SDOITI",
41
- 'output_sxr_a_dir': "/mnt/data/PAPER/GOES-SXR-A",
42
- 'output_sxr_b_dir': "/mnt/data/PAPER/GOES-SXR-B",
43
- 'aia_missing_dir': "/mnt/data/PAPER/AIA_ITI_MISSING"
44
  },
45
  'processing': {
46
  'batch_size_multiplier': 4,
@@ -56,8 +55,7 @@ GOES_DATA_DIR = config['alignment']['goes_data_dir']
56
  AIA_PROCESSED_DIR = config['alignment']['aia_processed_dir']
57
 
58
  # Output directories
59
- OUTPUT_SXR_A_DIR = config['alignment']['output_sxr_a_dir']
60
- OUTPUT_SXR_B_DIR = config['alignment']['output_sxr_b_dir']
61
  AIA_MISSING_DIR = config['alignment']['aia_missing_dir']
62
 
63
  # Processing configuration
@@ -136,7 +134,6 @@ def create_combined_lookup_table(goes_data_dict, target_timestamps):
136
 
137
  # For each target timestamp, average over all available instruments at that time
138
  for target_time in tqdm(target_times, desc="Building lookup table"):
139
- sxr_a_values = []
140
  sxr_b_values = []
141
  available_instruments = []
142
 
@@ -144,22 +141,15 @@ def create_combined_lookup_table(goes_data_dict, target_timestamps):
144
  goes_data = goes_data_dict[g_number]
145
  if target_time in goes_data.index:
146
  row = goes_data.loc[target_time]
147
- sxr_a = row['xrsa_flux']
148
  sxr_b = row['xrsb_flux']
149
- # Only care about xrsb_flux for validity
150
  if not pd.isna(sxr_b):
151
  sxr_b_values.append(float(sxr_b))
152
- if not pd.isna(sxr_a):
153
- sxr_a_values.append(float(sxr_a))
154
  available_instruments.append(f"GOES-{g_number}")
155
 
156
  if sxr_b_values:
157
- avg_sxr_b = float(np.mean(sxr_b_values))
158
- avg_sxr_a = float(np.mean(sxr_a_values)) if sxr_a_values else float('nan')
159
  lookup_data.append({
160
  'timestamp': target_time.strftime('%Y-%m-%dT%H:%M:%S'),
161
- 'sxr_a': avg_sxr_a,
162
- 'sxr_b': avg_sxr_b,
163
  'instrument': ",".join(available_instruments)
164
  })
165
 
@@ -179,21 +169,14 @@ def process_batch(batch_data):
179
  for data in batch_data:
180
  try:
181
  timestamp = data['timestamp']
182
- sxr_a = data['sxr_a']
183
  sxr_b = data['sxr_b']
184
  instrument = data['instrument']
185
-
186
- # Create arrays
187
- sxr_a_data = np.array([sxr_a], dtype=np.float32)
188
- sxr_b_data = np.array([sxr_b], dtype=np.float32)
189
-
190
- # Save data to disk using configured directories
191
- np.save(f"{OUTPUT_SXR_A_DIR}/{timestamp}.npy", sxr_a_data)
192
- np.save(f"{OUTPUT_SXR_B_DIR}/{timestamp}.npy", sxr_b_data)
193
-
194
  successful_count += 1
195
  results.append((timestamp, True, f"Success using {instrument}"))
196
-
197
  except Exception as e:
198
  failed_count += 1
199
  results.append((timestamp, False, f"Error processing timestamp {timestamp}: {e}"))
@@ -213,14 +196,12 @@ def main():
213
  print("=" * 60)
214
  print(f"GOES data directory: {GOES_DATA_DIR}")
215
  print(f"AIA processed directory: {AIA_PROCESSED_DIR}")
216
- print(f"Output SXR-A directory: {OUTPUT_SXR_A_DIR}")
217
- print(f"Output SXR-B directory: {OUTPUT_SXR_B_DIR}")
218
  print(f"AIA missing directory: {AIA_MISSING_DIR}")
219
  print("=" * 60)
220
 
221
  # Make output directories if they don't exist
222
- os.makedirs(OUTPUT_SXR_A_DIR, exist_ok=True)
223
- os.makedirs(OUTPUT_SXR_B_DIR, exist_ok=True)
224
  os.makedirs(AIA_MISSING_DIR, exist_ok=True)
225
 
226
  # Load and prepare GOES data with optimizations
 
38
  'alignment': {
39
  'goes_data_dir': "/mnt/data/PAPER/GOES-timespan/combined",
40
  'aia_processed_dir': "/mnt/data/PAPER/SDOITI",
41
+ 'output_sxr_dir': "/Volumes/T9/Data_FOXES/SXR_processed",
42
+ 'aia_missing_dir': "/Volumes/T9/Data_FOXES/AIA_missing"
 
43
  },
44
  'processing': {
45
  'batch_size_multiplier': 4,
 
55
  AIA_PROCESSED_DIR = config['alignment']['aia_processed_dir']
56
 
57
  # Output directories
58
+ OUTPUT_SXR_DIR = config['alignment']['output_sxr_dir']
 
59
  AIA_MISSING_DIR = config['alignment']['aia_missing_dir']
60
 
61
  # Processing configuration
 
134
 
135
  # For each target timestamp, average over all available instruments at that time
136
  for target_time in tqdm(target_times, desc="Building lookup table"):
 
137
  sxr_b_values = []
138
  available_instruments = []
139
 
 
141
  goes_data = goes_data_dict[g_number]
142
  if target_time in goes_data.index:
143
  row = goes_data.loc[target_time]
 
144
  sxr_b = row['xrsb_flux']
 
145
  if not pd.isna(sxr_b):
146
  sxr_b_values.append(float(sxr_b))
 
 
147
  available_instruments.append(f"GOES-{g_number}")
148
 
149
  if sxr_b_values:
 
 
150
  lookup_data.append({
151
  'timestamp': target_time.strftime('%Y-%m-%dT%H:%M:%S'),
152
+ 'sxr_b': float(np.mean(sxr_b_values)),
 
153
  'instrument': ",".join(available_instruments)
154
  })
155
 
 
169
  for data in batch_data:
170
  try:
171
  timestamp = data['timestamp']
 
172
  sxr_b = data['sxr_b']
173
  instrument = data['instrument']
174
+
175
+ np.save(f"{OUTPUT_SXR_DIR}/{timestamp}.npy", np.array([sxr_b], dtype=np.float32))
176
+
 
 
 
 
 
 
177
  successful_count += 1
178
  results.append((timestamp, True, f"Success using {instrument}"))
179
+
180
  except Exception as e:
181
  failed_count += 1
182
  results.append((timestamp, False, f"Error processing timestamp {timestamp}: {e}"))
 
196
  print("=" * 60)
197
  print(f"GOES data directory: {GOES_DATA_DIR}")
198
  print(f"AIA processed directory: {AIA_PROCESSED_DIR}")
199
+ print(f"Output SXR directory: {OUTPUT_SXR_DIR}")
 
200
  print(f"AIA missing directory: {AIA_MISSING_DIR}")
201
  print("=" * 60)
202
 
203
  # Make output directories if they don't exist
204
+ os.makedirs(OUTPUT_SXR_DIR, exist_ok=True)
 
205
  os.makedirs(AIA_MISSING_DIR, exist_ok=True)
206
 
207
  # Load and prepare GOES data with optimizations
data/euv_data_cleaning.py CHANGED
@@ -14,36 +14,18 @@ collections.MutableMapping = collections.abc.MutableMapping
14
  from itipy.data.dataset import get_intersecting_files
15
  from astropy.io import fits
16
 
17
- # Configuration for all wavelengths to process
18
- # Load configuration from environment or use defaults
19
- import os
20
  import json
21
 
 
22
  def load_config():
23
  """Load configuration from environment or use defaults."""
24
- if 'PIPELINE_CONFIG' in os.environ:
25
- try:
26
- config = json.loads(os.environ['PIPELINE_CONFIG'])
27
- return config
28
- except:
29
- pass
30
-
31
- # Default configuration
32
- return {
33
- 'euv': {
34
- 'wavelengths': [94, 131, 171, 193, 211, 304],
35
- 'input_folder': '/mnt/data/PAPER/SDOData',
36
- 'bad_files_dir': '/mnt/data/PAPER/SDO-AIA_bad'
37
- }
38
- }
39
-
40
- config = load_config()
41
- wavelengths = config['euv']['wavelengths']
42
- base_input_folder = config['euv']['input_folder']
43
-
44
- aia_files = get_intersecting_files(base_input_folder, wavelengths)
45
-
46
- # Function to process a single file
47
  def process_fits_file(file_path):
48
  try:
49
  with fits.open(file_path) as hdu:
@@ -59,39 +41,43 @@ def process_fits_file(file_path):
59
  print(f"Error processing {file_path}: {e}")
60
  return None
61
 
62
- file_list = aia_files[0] # List of FITS file paths
63
-
64
- with Pool(processes=os.cpu_count()) as pool:
65
- results = list(tqdm(pool.imap(process_fits_file, file_list), total=len(file_list)))
66
-
67
- # Filter out None results (in case of failed files)
68
- results = [r for r in results if r is not None]
69
-
70
- # Convert to DataFrame
71
- aia_header = pd.DataFrame(results)
72
- # Ensure DATE-OBS is datetime (already timezone-naive from processing)
73
- aia_header['DATE-OBS'] = pd.to_datetime(aia_header['DATE-OBS'])
74
-
75
- # add a column for date difference between DATE-OBS and FILENAME
76
- aia_header['DATE_DIFF'] = (
77
- pd.to_datetime(aia_header['FILENAME']) - pd.to_datetime(aia_header['DATE-OBS'])).dt.total_seconds()
78
-
79
- # remove rows where DATE_DIFF is greater than plus or minus 60 seconds in a list
80
- files_to_remove = aia_header[(aia_header['DATE_DIFF'] <= -60) | (aia_header['DATE_DIFF'] >= 60)]
81
- print(len(files_to_remove))
82
- # Loop through each wavelength
83
- for wavelength in wavelengths:
84
- #print(f"\nProcessing wavelength: {wavelength}")
85
- for names in files_to_remove['FILENAME'].to_numpy():
86
- # Construct file path
87
- filename = pd.to_datetime(names).strftime('%Y-%m-%dT%H:%M:%S') + ".fits"
88
- file_path = os.path.join(base_input_folder, f"{wavelength}/{filename}")
89
- # Destination path
90
- destination_folder = os.path.join(config['euv']['bad_files_dir'], str(wavelength))
91
- os.makedirs(destination_folder, exist_ok=True)
92
- # Move or report missing
93
- if os.path.exists(file_path):
94
- shutil.move(file_path, destination_folder)
95
- print(f"Removed file: {file_path}")
96
- else:
97
- print(f"File not found: {file_path}")
 
 
 
 
 
14
  from itipy.data.dataset import get_intersecting_files
15
  from astropy.io import fits
16
 
 
 
 
17
  import json
18
 
19
+
20
  def load_config():
21
  """Load configuration from environment or use defaults."""
22
+ try:
23
+ config = json.loads(os.environ['PIPELINE_CONFIG'])
24
+ return config
25
+ except:
26
+ pass
27
+
28
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def process_fits_file(file_path):
30
  try:
31
  with fits.open(file_path) as hdu:
 
41
  print(f"Error processing {file_path}: {e}")
42
  return None
43
 
44
+
45
+ if __name__ == '__main__':
46
+ config = load_config()
47
+ wavelengths = config['euv']['wavelengths']
48
+ base_input_folder = config['euv']['input_folder']
49
+
50
+ aia_files = get_intersecting_files(base_input_folder, wavelengths)
51
+ file_list = aia_files[0] # List of FITS file paths
52
+
53
+ with Pool(processes=os.cpu_count()) as pool:
54
+ results = list(tqdm(pool.imap(process_fits_file, file_list), total=len(file_list)))
55
+
56
+ # Filter out None results (in case of failed files)
57
+ results = [r for r in results if r is not None]
58
+
59
+ # Convert to DataFrame
60
+ aia_header = pd.DataFrame(results)
61
+ aia_header['DATE-OBS'] = pd.to_datetime(aia_header['DATE-OBS'])
62
+
63
+ # Add a column for date difference between DATE-OBS and FILENAME
64
+ aia_header['DATE_DIFF'] = (
65
+ pd.to_datetime(aia_header['FILENAME']) - pd.to_datetime(aia_header['DATE-OBS'])
66
+ ).dt.total_seconds()
67
+
68
+ # Remove rows where DATE_DIFF is greater than ±60 seconds
69
+ files_to_remove = aia_header[(aia_header['DATE_DIFF'] <= -60) | (aia_header['DATE_DIFF'] >= 60)]
70
+ print(f"{len(files_to_remove)} bad files found")
71
+
72
+ for wavelength in wavelengths:
73
+ print(f"\nProcessing wavelength: {wavelength}")
74
+ for names in files_to_remove['FILENAME'].to_numpy():
75
+ filename = pd.to_datetime(names).strftime('%Y-%m-%dT%H:%M:%S') + ".fits"
76
+ file_path = os.path.join(base_input_folder, f"{wavelength}/{filename}")
77
+ destination_folder = os.path.join(config['euv']['bad_files_dir'], str(wavelength))
78
+ os.makedirs(destination_folder, exist_ok=True)
79
+ if os.path.exists(file_path):
80
+ shutil.move(file_path, destination_folder)
81
+ print(f"Moved: {file_path}")
82
+ else:
83
+ print(f"Not found: {file_path}")
data/iti_data_processing.py CHANGED
@@ -10,38 +10,18 @@ from astropy.visualization import ImageNormalize, AsinhStretch
10
  from itipy.data.dataset import StackDataset, get_intersecting_files, AIADataset
11
  from itipy.data.editor import BrightestPixelPatchEditor, sdo_norms
12
  import os
 
13
  from multiprocessing import Pool
14
  from tqdm import tqdm
15
 
16
- # Configuration for all wavelengths to process
17
- # Load configuration from environment or use defaults
18
- import os
19
- import json
20
 
21
  def load_config():
22
  """Load configuration from environment or use defaults."""
23
- if 'PIPELINE_CONFIG' in os.environ:
24
- try:
25
- config = json.loads(os.environ['PIPELINE_CONFIG'])
26
- return config
27
- except:
28
- pass
29
-
30
- # Default configuration
31
- return {
32
- 'iti': {
33
- 'wavelengths': [94, 131, 171, 193, 211, 304],
34
- 'input_folder': '/mnt/data/PAPER/SDOData',
35
- 'output_folder': '/mnt/data/PAPER/SDOITI'
36
- }
37
- }
38
-
39
- config = load_config()
40
- wavelengths = config['iti']['wavelengths']
41
- base_input_folder = config['iti']['input_folder']
42
- output_folder = config['iti']['output_folder']
43
- os.makedirs(output_folder, exist_ok=True)
44
-
45
 
46
 
47
 
@@ -72,72 +52,72 @@ class SDODataset_flaring(StackDataset):
72
  self.addEditor(BrightestPixelPatchEditor(patch_shape))
73
 
74
 
75
- # Check if we need to process anything before loading the dataset
76
- def check_existing_files():
77
- """Check how many files already exist without loading the full dataset"""
78
- # Get file list from the base folder to estimate total samples
79
- from itipy.data.dataset import get_intersecting_files
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  files = get_intersecting_files(base_input_folder, wavelengths, ext='.fits')
81
  if not files or len(files) == 0:
82
  return 0, 0
83
-
84
- # Count existing output files - need to check for each wavelength combination
85
  existing_count = 0
86
- total_expected = len(files[0]) # All wavelength lists should have same length
87
-
88
- # Check each time step (index across all wavelengths)
89
  for i in range(total_expected):
90
- # Check if output file exists for this time step
91
- # The output filename should be based on the first wavelength's filename
92
- first_wl_file = files[0][i] # Use first wavelength as reference
93
  base_name = os.path.splitext(os.path.basename(first_wl_file))[0]
94
- # Remove wavelength suffix if present (e.g., "_171" from filename)
95
  if '_' in base_name:
96
  base_name = '_'.join(base_name.split('_')[:-1])
97
  output_path = os.path.join(output_folder, base_name) + '.npy'
98
-
99
  if os.path.exists(output_path):
100
  existing_count += 1
101
-
102
  return existing_count, total_expected
103
 
104
- # Check existing files first
105
- existing_files, total_expected = check_existing_files()
106
- print(f"Found {existing_files} existing files out of {total_expected} expected files")
107
-
108
- if existing_files >= total_expected:
109
- print("All files already processed. Nothing to do.")
110
- else:
111
- print(f"Need to process {total_expected - existing_files} remaining files")
112
-
113
- # Only load the dataset if we need to process files
114
- aia_dataset = SDODataset_flaring(data=base_input_folder, wavelengths=wavelengths, resolution=512, allow_errors=True)
115
-
116
- # Filter out indices that already have processed files
117
- def get_unprocessed_indices():
118
- unprocessed = []
119
- for i in range(len(aia_dataset)):
120
- file_path = os.path.join(output_folder, aia_dataset.getId(i)) + '.npy'
121
- if not os.path.exists(file_path):
122
- unprocessed.append(i)
123
- return unprocessed
124
-
125
- def save_sample(i):
126
- try:
127
- data = aia_dataset[i]
128
- file_path = os.path.join(output_folder, aia_dataset.getId(i)) + '.npy'
129
- np.save(file_path, data)
130
- except Exception as e:
131
- print(f"Warning: Could not process sample {i} (ID: {aia_dataset.getId(i)}): {e}")
132
- return # Skip this sample and continue with the next one
133
-
134
- # Get only unprocessed indices
135
- unprocessed_indices = get_unprocessed_indices()
136
- print(f"Processing {len(unprocessed_indices)} unprocessed samples")
137
-
138
- if unprocessed_indices:
139
- with Pool(processes=os.cpu_count()) as pool:
140
- list(tqdm(pool.imap(save_sample, unprocessed_indices), total=len(unprocessed_indices)))
141
- print("AIA data processing completed.")
142
  else:
143
- print("All samples already processed. Nothing to do.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  from itipy.data.dataset import StackDataset, get_intersecting_files, AIADataset
11
  from itipy.data.editor import BrightestPixelPatchEditor, sdo_norms
12
  import os
13
+ import json
14
  from multiprocessing import Pool
15
  from tqdm import tqdm
16
 
 
 
 
 
17
 
18
  def load_config():
19
  """Load configuration from environment or use defaults."""
20
+ try:
21
+ config = json.loads(os.environ['PIPELINE_CONFIG'])
22
+ return config
23
+ except:
24
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
 
27
 
 
52
  self.addEditor(BrightestPixelPatchEditor(patch_shape))
53
 
54
 
55
+ _aia_dataset = None
56
+ _output_folder = None
57
+
58
+
59
+ def _init_worker(dataset, out_folder):
60
+ global _aia_dataset, _output_folder
61
+ _aia_dataset = dataset
62
+ _output_folder = out_folder
63
+
64
+
65
+ def save_sample(i):
66
+ try:
67
+ data = _aia_dataset[i]
68
+ file_path = os.path.join(_output_folder, _aia_dataset.getId(i)) + '.npy'
69
+ np.save(file_path, data)
70
+ except Exception as e:
71
+ print(f"Warning: Could not process sample {i} (ID: {_aia_dataset.getId(i)}): {e}")
72
+
73
+
74
+ def check_existing_files(base_input_folder, wavelengths, output_folder):
75
+ """Check how many files already exist without loading the full dataset."""
76
  files = get_intersecting_files(base_input_folder, wavelengths, ext='.fits')
77
  if not files or len(files) == 0:
78
  return 0, 0
79
+
 
80
  existing_count = 0
81
+ total_expected = len(files[0])
82
+
 
83
  for i in range(total_expected):
84
+ first_wl_file = files[0][i]
 
 
85
  base_name = os.path.splitext(os.path.basename(first_wl_file))[0]
 
86
  if '_' in base_name:
87
  base_name = '_'.join(base_name.split('_')[:-1])
88
  output_path = os.path.join(output_folder, base_name) + '.npy'
 
89
  if os.path.exists(output_path):
90
  existing_count += 1
91
+
92
  return existing_count, total_expected
93
 
94
+
95
+ if __name__ == '__main__':
96
+ config = load_config()
97
+ wavelengths = config['iti']['wavelengths']
98
+ base_input_folder = config['iti']['input_folder']
99
+ output_folder = config['iti']['output_folder']
100
+ os.makedirs(output_folder, exist_ok=True)
101
+
102
+ existing_files, total_expected = check_existing_files(base_input_folder, wavelengths, output_folder)
103
+ print(f"Found {existing_files} existing files out of {total_expected} expected files")
104
+
105
+ if existing_files >= total_expected:
106
+ print("All files already processed. Nothing to do.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  else:
108
+ print(f"Need to process {total_expected - existing_files} remaining files")
109
+
110
+ aia_dataset = SDODataset_flaring(data=base_input_folder, wavelengths=wavelengths, resolution=512, allow_errors=True)
111
+
112
+ unprocessed_indices = [
113
+ i for i in range(len(aia_dataset))
114
+ if not os.path.exists(os.path.join(output_folder, aia_dataset.getId(i)) + '.npy')
115
+ ]
116
+ print(f"Processing {len(unprocessed_indices)} unprocessed samples")
117
+
118
+ if unprocessed_indices:
119
+ with Pool(processes=os.cpu_count(), initializer=_init_worker, initargs=(aia_dataset, output_folder)) as pool:
120
+ list(tqdm(pool.imap(save_sample, unprocessed_indices), total=len(unprocessed_indices)))
121
+ print("AIA data processing completed.")
122
+ else:
123
+ print("All samples already processed. Nothing to do.")
data/pipeline_config.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PipelineConfig: loads and validates the data processing pipeline configuration.
3
+
4
+ Used by process_data_pipeline.py. Config is read from a YAML file and passed
5
+ to each sub-script as a JSON string via the PIPELINE_CONFIG environment variable.
6
+ """
7
+
8
+ import json
9
+ from pathlib import Path
10
+
11
+ import yaml
12
+
13
+
14
+ TEMPLATE_PATH = Path(__file__).parent / "pipeline_config.yaml"
15
+
16
+ # Paths that must exist before the pipeline runs
17
+ REQUIRED_INPUT_PATHS = [
18
+ ("euv", "input_folder"),
19
+ ("iti", "input_folder"),
20
+ ]
21
+
22
+ # Paths that should be created before the pipeline runs
23
+ OUTPUT_PATHS = [
24
+ ("euv", "bad_files_dir"),
25
+ ("iti", "output_folder"),
26
+ ("alignment", "output_sxr_dir"),
27
+ ("alignment", "aia_missing_dir"),
28
+ ]
29
+
30
+
31
+ class PipelineConfig:
32
+ def __init__(self, config_path: str = None):
33
+ if config_path:
34
+ with open(config_path, "r") as f:
35
+ self.config = yaml.safe_load(f)
36
+ else:
37
+ self.config = self._defaults()
38
+
39
+ # ------------------------------------------------------------------
40
+
41
+ def get_path(self, section: str, key: str) -> str:
42
+ """Return config[section][key], or config[section] if key == section."""
43
+ section_data = self.config.get(section, {})
44
+ if isinstance(section_data, dict):
45
+ return section_data.get(key, "")
46
+ return section_data # scalar value (e.g. base_data_dir)
47
+
48
+ def to_json(self) -> str:
49
+ """Serialize config to JSON string for passing via environment variable."""
50
+ return json.dumps(self.config)
51
+
52
+ # ------------------------------------------------------------------
53
+
54
+ def validate_paths(self) -> tuple[bool, list[str]]:
55
+ """Check that all required input paths exist. Returns (valid, missing)."""
56
+ missing = []
57
+ for section, key in REQUIRED_INPUT_PATHS:
58
+ p = self.get_path(section, key)
59
+ if p and not Path(p).exists():
60
+ missing.append(f"{section}.{key}: {p}")
61
+ return (len(missing) == 0, missing)
62
+
63
+ def create_directories(self):
64
+ """Create all output directories."""
65
+ for section, key in OUTPUT_PATHS:
66
+ p = self.get_path(section, key)
67
+ if p:
68
+ Path(p).mkdir(parents=True, exist_ok=True)
69
+
70
+ def print_config(self):
71
+ print(yaml.dump(self.config, default_flow_style=False))
72
+
73
+ def save_config_template(self, path: str = None):
74
+ dest = Path(path) if path else TEMPLATE_PATH
75
+ with open(dest, "w") as f:
76
+ yaml.dump(self._defaults(), f, default_flow_style=False)
77
+ print(f"Template saved to {dest}")
78
+
79
+ # ------------------------------------------------------------------
80
+
81
+ @staticmethod
82
+ def _defaults() -> dict:
83
+ return {
84
+ "base_data_dir": "/Volumes/T9/Data_FOXES",
85
+ "euv": {
86
+ "input_folder": "/Volumes/T9/Data_FOXES/AIA_raw",
87
+ "bad_files_dir": "/Volumes/T9/Data_FOXES/AIA_bad",
88
+ "wavelengths": [94, 131, 171, 193, 211, 304, 335],
89
+ },
90
+ "iti": {
91
+ "input_folder": "/Volumes/T9/Data_FOXES/AIA_raw",
92
+ "output_folder": "/Volumes/T9/Data_FOXES/AIA_processed",
93
+ "wavelengths": [94, 131, 171, 193, 211, 304, 335],
94
+ },
95
+ "alignment": {
96
+ "goes_data_dir": "/Volumes/T9/Data_FOXES/SXR_raw/combined",
97
+ "aia_processed_dir": "/Volumes/T9/Data_FOXES/AIA_processed",
98
+ "output_sxr_dir": "/Volumes/T9/Data_FOXES/SXR_processed",
99
+ "aia_missing_dir": "/Volumes/T9/Data_FOXES/AIA_missing",
100
+ },
101
+ "processing": {
102
+ "max_processes": None,
103
+ "batch_size_multiplier": 4,
104
+ "min_batch_size": 1,
105
+ },
106
+ }
data/pipeline_config.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Data Processing Pipeline Configuration
2
+ #
3
+ # Usage: python process_data_pipeline.py --config pipeline_config.yaml
4
+ #
5
+ # Directory flow:
6
+ # AIA_raw → euv_data_cleaning → bad files moved to AIA_bad
7
+ # AIA_raw → iti_data_processing → AIA_processed (512x512 .npy)
8
+ # AIA_processed ┐
9
+ # ├─ align_data ──────→ SXR_processed (xrsb_flux .npy per timestamp)
10
+ # SXR_raw/combined ┘ AIA_missing (AIA files with no SXR match)
11
+
12
+ base_data_dir: /Volumes/T9/Data_FOXES
13
+
14
+ euv:
15
+ input_folder: /Volumes/T9/Data_FOXES/AIA_raw
16
+ bad_files_dir: /Volumes/T9/Data_FOXES/AIA_bad
17
+ wavelengths:
18
+ - 94
19
+ - 131
20
+ - 171
21
+ - 193
22
+ - 211
23
+ - 304
24
+ - 335
25
+
26
+ iti:
27
+ input_folder: /Volumes/T9/Data_FOXES/AIA_raw # same as euv (bad files already moved out)
28
+ output_folder: /Volumes/T9/Data_FOXES/AIA_processed
29
+ wavelengths:
30
+ - 94
31
+ - 131
32
+ - 171
33
+ - 193
34
+ - 211
35
+ - 304
36
+ - 335
37
+
38
+ alignment:
39
+ goes_data_dir: /Volumes/T9/Data_FOXES/SXR_raw/combined # output of sxr_downloader concat
40
+ aia_processed_dir: /Volumes/T9/Data_FOXES/AIA_processed # must match iti.output_folder
41
+ output_sxr_dir: /Volumes/T9/Data_FOXES/SXR_processed
42
+ aia_missing_dir: /Volumes/T9/Data_FOXES/AIA_missing
43
+
44
+ processing:
45
+ max_processes: null # null = use all available cores
46
+ batch_size_multiplier: 4
47
+ min_batch_size: 1
data/process_data_pipeline.py CHANGED
@@ -105,11 +105,8 @@ class DataProcessingPipeline:
105
  """
106
  Check if data alignment is complete by looking for output directories.
107
  """
108
- output_dirs = [
109
- Path(self.config.get_path('alignment', 'output_sxr_a_dir')),
110
- Path(self.config.get_path('alignment', 'output_sxr_b_dir'))
111
- ]
112
- return all(d.exists() and any(d.iterdir()) for d in output_dirs)
113
 
114
  def run_script(self, script_name, step_info):
115
  """
@@ -135,7 +132,7 @@ class DataProcessingPipeline:
135
  # Create environment variables for configuration
136
  env = os.environ.copy()
137
  env.update({
138
- 'PIPELINE_CONFIG': str(self.config.config),
139
  'BASE_DATA_DIR': self.config.get_path('base_data_dir', 'base_data_dir')
140
  })
141
 
@@ -145,23 +142,18 @@ class DataProcessingPipeline:
145
  # Run the script
146
  result = subprocess.run(
147
  [sys.executable, str(script_path)],
148
- capture_output=True,
149
- text=True,
150
  cwd=self.base_dir,
151
  env=env
152
  )
153
-
154
  end_time = time.time()
155
  duration = end_time - start_time
156
-
157
  if result.returncode == 0:
158
  logger.info(f"✓ {step_info['name']} completed successfully in {duration:.2f} seconds")
159
- if result.stdout:
160
- logger.debug(f"Output: {result.stdout}")
161
  return True
162
  else:
163
  logger.error(f"✗ {step_info['name']} failed with return code {result.returncode}")
164
- logger.error(f"Error output: {result.stderr}")
165
  return False
166
 
167
  except Exception as e:
 
105
  """
106
  Check if data alignment is complete by looking for output directories.
107
  """
108
+ output_dir = Path(self.config.get_path('alignment', 'output_sxr_dir'))
109
+ return output_dir.exists() and any(output_dir.iterdir())
 
 
 
110
 
111
  def run_script(self, script_name, step_info):
112
  """
 
132
  # Create environment variables for configuration
133
  env = os.environ.copy()
134
  env.update({
135
+ 'PIPELINE_CONFIG': self.config.to_json(),
136
  'BASE_DATA_DIR': self.config.get_path('base_data_dir', 'base_data_dir')
137
  })
138
 
 
142
  # Run the script
143
  result = subprocess.run(
144
  [sys.executable, str(script_path)],
 
 
145
  cwd=self.base_dir,
146
  env=env
147
  )
148
+
149
  end_time = time.time()
150
  duration = end_time - start_time
151
+
152
  if result.returncode == 0:
153
  logger.info(f"✓ {step_info['name']} completed successfully in {duration:.2f} seconds")
 
 
154
  return True
155
  else:
156
  logger.error(f"✗ {step_info['name']} failed with return code {result.returncode}")
 
157
  return False
158
 
159
  except Exception as e:
data/sxr_data_processing.py CHANGED
@@ -50,10 +50,10 @@ class SXRDataProcessor:
50
  total_files = len(g13_files) + len(g14_files) + len(g15_files) + len(g16_files) + len(g17_files) + len(g18_files)
51
  logging.info(
52
  f"Found {len(g13_files)} GOES-13 files, {len(g14_files)} GOES-14 files, {len(g15_files)} GOES-15 files, {len(g16_files)} GOES-16 files, {len(g17_files)} GOES-17 files, and {len(g18_files)} GOES-18 files.")
53
- print(f"📊 Total files found: {total_files}")
54
 
55
  if total_files == 0:
56
- print("⚠️ No GOES data files found in the specified directory.")
57
  return
58
 
59
  def process_files(files, satellite_name, output_file, used_file_list):
@@ -84,37 +84,37 @@ class SXRDataProcessor:
84
  ds.close()
85
 
86
  if not datasets:
87
- print(f"No valid datasets for {satellite_name}")
88
  logging.warning(f"No valid datasets for {satellite_name}")
89
  return
90
 
91
- print(f"📊 Processing {len(datasets)} datasets for {satellite_name}...")
92
 
93
  try:
94
- print(f"🔗 Concatenating datasets...")
95
  combined_ds = xr.concat(datasets, dim='time').sortby('time')
96
 
97
  # Scaling factors for GOES-13, GOES-14, and GOES-15
98
  if satellite_name in ['GOES-13', 'GOES-14', 'GOES-15']:
99
- print(f"⚖️ Applying scaling factors for {satellite_name}...")
100
  combined_ds['xrsa_flux'] = combined_ds['xrsa_flux'] / .85
101
  combined_ds['xrsb_flux'] = combined_ds['xrsb_flux'] / .7
102
 
103
- print(f"🔄 Converting to DataFrame...")
104
  df = combined_ds.to_dataframe().reset_index()
105
 
106
  if 'quad_diode' in df.columns:
107
- print(f"🔍 Filtering quad diode data...")
108
  df = df[df['quad_diode'] == 0] # Filter out quad diode data
109
 
110
  #Filter out data where xrsb_flux has a quality flag of >0
111
- print(f"🔍 Filtering out data where xrsb_flux has a quality flag of >0...")
112
- df = df[df['xrsb_quality'] == 0]
113
 
114
  df['time'] = pd.to_datetime(df['time'])
115
  df.set_index('time', inplace=True)
116
 
117
- print(f"📈 Applying log interpolation...")
118
  df_log = np.log10(df[columns_to_interp].replace(0, np.nan))
119
 
120
  # Step 3: Interpolate in log space
@@ -128,14 +128,14 @@ class SXRDataProcessor:
128
  max_date = df.index.max().strftime('%Y%m%d')
129
  filename = f"{str(output_file)}_{min_date}_{max_date}.csv"
130
 
131
- print(f"💾 Saving to {filename}...")
132
  df.to_csv(filename, index=True)
133
 
134
- print(f"Successfully processed {satellite_name}: {successful_files} files loaded, {failed_files} failed")
135
  logging.info(f"Saved combined file: {output_file}")
136
 
137
  except Exception as e:
138
- print(f"Failed to process {satellite_name}: {e}")
139
  logging.error(f"Failed to write {output_file}: {e}")
140
  finally:
141
  for ds in datasets:
@@ -156,7 +156,7 @@ class SXRDataProcessor:
156
  if len(g18_files) != 0:
157
  satellites_to_process.append((g18_files, "GOES-18", self.output_dir / "combined_g18_avg1m", self.used_g18_files))
158
 
159
- print(f"\n🚀 Starting processing of {len(satellites_to_process)} satellites...")
160
 
161
  # Process each satellite with overall progress tracking
162
  successful_satellites = 0
@@ -164,36 +164,36 @@ class SXRDataProcessor:
164
 
165
  for i, (files, satellite_name, output_file, used_file_list) in enumerate(satellites_to_process, 1):
166
  print(f"\n{'='*60}")
167
- print(f"📡 Processing satellite {i}/{len(satellites_to_process)}: {satellite_name}")
168
  print(f"{'='*60}")
169
 
170
  try:
171
  process_files(files, satellite_name, output_file, used_file_list)
172
  successful_satellites += 1
173
  except Exception as e:
174
- print(f"Failed to process {satellite_name}: {e}")
175
  failed_satellites += 1
176
  logging.error(f"Failed to process {satellite_name}: {e}")
177
 
178
  # Print final summary
179
  print(f"\n{'='*60}")
180
- print(f"📊 PROCESSING COMPLETE")
181
  print(f"{'='*60}")
182
- print(f"Successfully processed: {successful_satellites} satellites")
183
- print(f"Failed: {failed_satellites} satellites")
184
- print(f"📁 Total files processed: {total_files}")
185
- print(f"📂 Output directory: {self.output_dir}")
186
 
187
  # Print file usage statistics
188
  total_used_files = (len(self.used_g13_files) + len(self.used_g14_files) +
189
  len(self.used_g15_files) + len(self.used_g16_files) +
190
  len(self.used_g17_files) + len(self.used_g18_files))
191
- print(f"📋 Files used in processing: {total_used_files}")
192
 
193
  if successful_satellites > 0:
194
- print(f"\n🎉 SXR data processing completed successfully!")
195
  else:
196
- print(f"\n⚠No satellites were processed successfully.")
197
 
198
 
199
  if __name__ == '__main__':
@@ -204,13 +204,13 @@ if __name__ == '__main__':
204
  help='Directory where combined GOES data will be saved.')
205
  args = parser.parse_args()
206
 
207
- print("🌞 GOES SXR Data Processing Tool")
208
  print("=" * 50)
209
- print(f"📂 Data directory: {args.data_dir}")
210
- print(f"📁 Output directory: {args.output_dir}")
211
  print("=" * 50)
212
 
213
  processor = SXRDataProcessor(data_dir=args.data_dir, output_dir=args.output_dir)
214
  processor.combine_goes_data()
215
 
216
- print("\n🏁 All processing tasks completed.")
 
50
  total_files = len(g13_files) + len(g14_files) + len(g15_files) + len(g16_files) + len(g17_files) + len(g18_files)
51
  logging.info(
52
  f"Found {len(g13_files)} GOES-13 files, {len(g14_files)} GOES-14 files, {len(g15_files)} GOES-15 files, {len(g16_files)} GOES-16 files, {len(g17_files)} GOES-17 files, and {len(g18_files)} GOES-18 files.")
53
+ print(f"Total files found: {total_files}")
54
 
55
  if total_files == 0:
56
+ print("No GOES data files found in the specified directory.")
57
  return
58
 
59
  def process_files(files, satellite_name, output_file, used_file_list):
 
84
  ds.close()
85
 
86
  if not datasets:
87
+ print(f"No valid datasets for {satellite_name}")
88
  logging.warning(f"No valid datasets for {satellite_name}")
89
  return
90
 
91
+ print(f"Processing {len(datasets)} datasets for {satellite_name}...")
92
 
93
  try:
94
+ print(f"Concatenating datasets...")
95
  combined_ds = xr.concat(datasets, dim='time').sortby('time')
96
 
97
  # Scaling factors for GOES-13, GOES-14, and GOES-15
98
  if satellite_name in ['GOES-13', 'GOES-14', 'GOES-15']:
99
+ print(f"Applying scaling factors for {satellite_name}...")
100
  combined_ds['xrsa_flux'] = combined_ds['xrsa_flux'] / .85
101
  combined_ds['xrsb_flux'] = combined_ds['xrsb_flux'] / .7
102
 
103
+ print(f"Converting to DataFrame...")
104
  df = combined_ds.to_dataframe().reset_index()
105
 
106
  if 'quad_diode' in df.columns:
107
+ print(f"Filtering quad diode data...")
108
  df = df[df['quad_diode'] == 0] # Filter out quad diode data
109
 
110
  #Filter out data where xrsb_flux has a quality flag of >0
111
+ print(f"Filtering out data where xrsb_flux has a quality flag of >0...")
112
+ df = df[df['xrsb_flag'] == 0]
113
 
114
  df['time'] = pd.to_datetime(df['time'])
115
  df.set_index('time', inplace=True)
116
 
117
+ print(f"Applying log interpolation...")
118
  df_log = np.log10(df[columns_to_interp].replace(0, np.nan))
119
 
120
  # Step 3: Interpolate in log space
 
128
  max_date = df.index.max().strftime('%Y%m%d')
129
  filename = f"{str(output_file)}_{min_date}_{max_date}.csv"
130
 
131
+ print(f"Saving to {filename}...")
132
  df.to_csv(filename, index=True)
133
 
134
+ print(f"Successfully processed {satellite_name}: {successful_files} files loaded, {failed_files} failed")
135
  logging.info(f"Saved combined file: {output_file}")
136
 
137
  except Exception as e:
138
+ print(f"Failed to process {satellite_name}: {e}")
139
  logging.error(f"Failed to write {output_file}: {e}")
140
  finally:
141
  for ds in datasets:
 
156
  if len(g18_files) != 0:
157
  satellites_to_process.append((g18_files, "GOES-18", self.output_dir / "combined_g18_avg1m", self.used_g18_files))
158
 
159
+ print(f"\nStarting processing of {len(satellites_to_process)} satellites...")
160
 
161
  # Process each satellite with overall progress tracking
162
  successful_satellites = 0
 
164
 
165
  for i, (files, satellite_name, output_file, used_file_list) in enumerate(satellites_to_process, 1):
166
  print(f"\n{'='*60}")
167
+ print(f"Processing satellite {i}/{len(satellites_to_process)}: {satellite_name}")
168
  print(f"{'='*60}")
169
 
170
  try:
171
  process_files(files, satellite_name, output_file, used_file_list)
172
  successful_satellites += 1
173
  except Exception as e:
174
+ print(f"Failed to process {satellite_name}: {e}")
175
  failed_satellites += 1
176
  logging.error(f"Failed to process {satellite_name}: {e}")
177
 
178
  # Print final summary
179
  print(f"\n{'='*60}")
180
+ print(f"PROCESSING COMPLETE")
181
  print(f"{'='*60}")
182
+ print(f"Successfully processed: {successful_satellites} satellites")
183
+ print(f"Failed: {failed_satellites} satellites")
184
+ print(f"Total files processed: {total_files}")
185
+ print(f"Output directory: {self.output_dir}")
186
 
187
  # Print file usage statistics
188
  total_used_files = (len(self.used_g13_files) + len(self.used_g14_files) +
189
  len(self.used_g15_files) + len(self.used_g16_files) +
190
  len(self.used_g17_files) + len(self.used_g18_files))
191
+ print(f"Files used in processing: {total_used_files}")
192
 
193
  if successful_satellites > 0:
194
+ print(f"\nSXR data processing completed successfully!")
195
  else:
196
+ print(f"\n⚠No satellites were processed successfully.")
197
 
198
 
199
  if __name__ == '__main__':
 
204
  help='Directory where combined GOES data will be saved.')
205
  args = parser.parse_args()
206
 
207
+ print("GOES SXR Data Processing Tool")
208
  print("=" * 50)
209
+ print(f"Data directory: {args.data_dir}")
210
+ print(f"Output directory: {args.output_dir}")
211
  print("=" * 50)
212
 
213
  processor = SXRDataProcessor(data_dir=args.data_dir, output_dir=args.output_dir)
214
  processor.combine_goes_data()
215
 
216
+ print("\nAll processing tasks completed.")
download/download_sdo.py CHANGED
@@ -27,7 +27,7 @@ class SDODownloader:
27
  wavelengths (list): List of wavelengths to download.
28
  n_workers (int): Number of worker threads for parallel download.
29
  """
30
- def __init__(self, base_path='/mnt/data/PAPER/SDOData', email=None, wavelengths=['94', '131', '171', '193', '211', '304'], n_workers=4, cadence=60):
31
  self.ds_path = base_path
32
  self.wavelengths = [str(wl) for wl in wavelengths]
33
  self.n_workers = n_workers
@@ -53,7 +53,11 @@ class SDODownloader:
53
  if os.path.exists(map_path):
54
  return map_path
55
  # load map
 
 
 
56
  url = 'http://jsoc.stanford.edu' + segment
 
57
 
58
  # Retry download with exponential backoff
59
  max_retries = 3
@@ -111,27 +115,20 @@ class SDODownloader:
111
  id = date.isoformat()
112
 
113
  logging.info('Start download: %s' % id)
114
- # query Magnetogram
115
- #time_param = '%sZ' % date.isoformat('_', timespec='seconds')
116
- #ds_hmi = 'hmi.M_720s[%s]{magnetogram}' % time_param
117
- #keys_hmi = self.drms_client.keys(ds_hmi)
118
- #header_hmi, segment_hmi = self.drms_client.query(ds_hmi, key=','.join(keys_hmi), seg='magnetogram')
119
- #if len(header_hmi) != 1 or np.any(header_hmi.QUALITY != 0):
120
- # self.fetchDataFallback(date)
121
- # return
122
 
123
  # query EUV
124
  time_param = '%sZ' % date.isoformat('_', timespec='seconds')
125
  ds_euv = 'aia.lev1_euv_12s[%s][%s]{image}' % (time_param, ','.join(self.wavelengths))
126
  keys_euv = self.drms_client.keys(ds_euv)
127
  header_euv, segment_euv = self.drms_client.query(ds_euv, key=','.join(keys_euv), seg='image')
128
- if len(header_euv) != len(self.wavelengths) or np.any(header_euv.QUALITY != 0):
 
 
 
129
  self.fetchDataFallback(date)
130
  return
131
 
132
  queue = []
133
- #for (idx, h), s in zip(header_hmi.iterrows(), segment_hmi.magnetogram):
134
- # queue += [(h.to_dict(), s, date)]
135
  for (idx, h), s in zip(header_euv.iterrows(), segment_euv.image):
136
  queue += [(h.to_dict(), s, date)]
137
 
@@ -155,28 +152,6 @@ class SDODownloader:
155
  id = date.isoformat()
156
 
157
  logging.info('Fallback download: %s' % id)
158
- # query Magnetogram
159
- t = date - timedelta(hours=24)
160
- ds_hmi = 'hmi.M_720s[%sZ/12h@720s]{magnetogram}' % t.replace(tzinfo=None).isoformat('_', timespec='seconds')
161
- keys_hmi = self.drms_client.keys(ds_hmi)
162
- header_tmp, segment_tmp = self.drms_client.query(ds_hmi, key=','.join(keys_hmi), seg='magnetogram')
163
- assert len(header_tmp) != 0, 'No data found!'
164
- date_str = header_tmp['DATE__OBS'].replace('MISSING', '').str.replace('60', '59') # fix date format
165
- date_diff = np.abs(pd.to_datetime(date_str).dt.tz_localize(None) - date)
166
- # sort and filter
167
- header_tmp['date_diff'] = date_diff
168
- header_tmp.sort_values('date_diff')
169
- segment_tmp['date_diff'] = date_diff
170
- segment_tmp.sort_values('date_diff')
171
- cond_tmp = header_tmp.QUALITY == 0
172
- header_tmp = header_tmp[cond_tmp]
173
- segment_tmp = segment_tmp[cond_tmp]
174
- assert len(header_tmp) > 0, 'No valid quality flag found'
175
- # replace invalid
176
- header_hmi = header_tmp.iloc[0].drop('date_diff')
177
- segment_hmi = segment_tmp.iloc[0].drop('date_diff')
178
- ############################################################
179
- # query EUV
180
  header_euv, segment_euv = [], []
181
  t = date - timedelta(hours=6)
182
  for wl in self.wavelengths:
@@ -184,21 +159,23 @@ class SDODownloader:
184
  t.replace(tzinfo=None).isoformat('_', timespec='seconds'), wl)
185
  keys_euv = self.drms_client.keys(euv_ds)
186
  header_tmp, segment_tmp = self.drms_client.query(euv_ds, key=','.join(keys_euv), seg='image')
187
- assert len(header_tmp) != 0, 'No data found!'
 
188
  date_str = header_tmp['DATE__OBS'].replace('MISSING', '').str.replace('60', '59') # fix date format
189
  date_diff = (pd.to_datetime(date_str).dt.tz_localize(None) - date).abs()
190
  # sort and filter
191
  header_tmp['date_diff'] = date_diff
192
- header_tmp.sort_values('date_diff')
193
  segment_tmp['date_diff'] = date_diff
194
- segment_tmp.sort_values('date_diff')
195
- cond_tmp = header_tmp.QUALITY == 0
196
- header_tmp = header_tmp[cond_tmp]
197
- segment_tmp = segment_tmp[cond_tmp]
198
- assert len(header_tmp) > 0, 'No valid quality flag found'
199
- # replace invalid
200
- header_euv.append(header_tmp.iloc[0].drop('date_diff'))
201
- segment_euv.append(segment_tmp.iloc[0].drop('date_diff'))
 
 
202
 
203
  queue = []
204
  #queue += [(header_hmi.to_dict(), segment_hmi.magnetogram, date)]
@@ -218,8 +195,8 @@ if __name__ == '__main__':
218
  parser.add_argument('--download_dir', type=str, help='path to the download directory.')
219
  parser.add_argument('--email', type=str, help='registered email address for JSOC.')
220
  parser.add_argument('--start_date', type=str, help='start date in format YYYY-MM-DD.')
221
- parser.add_argument('--end_date', type=str, help='end date in format YYYY-MM-DD.', required=False,
222
- default=str(datetime.now()).split(' ')[0])
223
  parser.add_argument('--cadence', type=int, help='cadence in minutes.', required=False, default=60)
224
 
225
  args = parser.parse_args()
@@ -228,7 +205,7 @@ if __name__ == '__main__':
228
  end_date = args.end_date
229
  cadence = args.cadence
230
 
231
- [os.makedirs(os.path.join(download_dir, str(c)), exist_ok=True) for c in [94, 131, 171, 193, 211, 304]]
232
  downloader = SDODownloader(base_path=download_dir, email=args.email)
233
  start_date_datetime = datetime.strptime(start_date, "%Y-%m-%d %H:%M:%S")
234
  #end_date = datetime.now()
@@ -236,10 +213,10 @@ if __name__ == '__main__':
236
 
237
 
238
  #Skip over dates that already exist in the download directory
239
- for d in [start_date_datetime + i * timedelta(minutes=1) for i in
240
- range((end_date_datetime - start_date_datetime) // timedelta(minutes=1))]:
241
  #make sure the file exists in all wavelengths directories
242
- for wl in [94, 131, 171, 193, 211, 304]:
243
  if not os.path.exists(os.path.join(
244
  download_dir,
245
  str(wl),
 
27
  wavelengths (list): List of wavelengths to download.
28
  n_workers (int): Number of worker threads for parallel download.
29
  """
30
+ def __init__(self, base_path='/mnt/data/PAPER/SDOData', email=None, wavelengths=['94', '131', '171', '193', '211', '304', '335'], n_workers=4, cadence=60):
31
  self.ds_path = base_path
32
  self.wavelengths = [str(wl) for wl in wavelengths]
33
  self.n_workers = n_workers
 
53
  if os.path.exists(map_path):
54
  return map_path
55
  # load map
56
+ if not segment or pd.isna(segment):
57
+ logging.error('Segment path is null for %s — data may not be in JSOC cache' % header.get('DATE__OBS'))
58
+ raise ValueError('Null segment path for %s' % header.get('DATE__OBS'))
59
  url = 'http://jsoc.stanford.edu' + segment
60
+ logging.info('Downloading: %s' % url)
61
 
62
  # Retry download with exponential backoff
63
  max_retries = 3
 
115
  id = date.isoformat()
116
 
117
  logging.info('Start download: %s' % id)
 
 
 
 
 
 
 
 
118
 
119
  # query EUV
120
  time_param = '%sZ' % date.isoformat('_', timespec='seconds')
121
  ds_euv = 'aia.lev1_euv_12s[%s][%s]{image}' % (time_param, ','.join(self.wavelengths))
122
  keys_euv = self.drms_client.keys(ds_euv)
123
  header_euv, segment_euv = self.drms_client.query(ds_euv, key=','.join(keys_euv), seg='image')
124
+ logging.info('Fast-path query returned %d rows (need %d), qualities: %s' % (
125
+ len(header_euv), len(self.wavelengths),
126
+ list(header_euv.QUALITY) if len(header_euv) > 0 else []))
127
+ if len(header_euv) != len(self.wavelengths) or np.any(header_euv.QUALITY.fillna(0) != 0):
128
  self.fetchDataFallback(date)
129
  return
130
 
131
  queue = []
 
 
132
  for (idx, h), s in zip(header_euv.iterrows(), segment_euv.image):
133
  queue += [(h.to_dict(), s, date)]
134
 
 
152
  id = date.isoformat()
153
 
154
  logging.info('Fallback download: %s' % id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  header_euv, segment_euv = [], []
156
  t = date - timedelta(hours=6)
157
  for wl in self.wavelengths:
 
159
  t.replace(tzinfo=None).isoformat('_', timespec='seconds'), wl)
160
  keys_euv = self.drms_client.keys(euv_ds)
161
  header_tmp, segment_tmp = self.drms_client.query(euv_ds, key=','.join(keys_euv), seg='image')
162
+ logging.info('Fallback query wl=%s returned %d rows' % (wl, len(header_tmp)))
163
+ assert len(header_tmp) != 0, 'No data found for wl=%s at %s' % (wl, id)
164
  date_str = header_tmp['DATE__OBS'].replace('MISSING', '').str.replace('60', '59') # fix date format
165
  date_diff = (pd.to_datetime(date_str).dt.tz_localize(None) - date).abs()
166
  # sort and filter
167
  header_tmp['date_diff'] = date_diff
 
168
  segment_tmp['date_diff'] = date_diff
169
+ cond_tmp = (header_tmp.QUALITY == 0) | header_tmp.QUALITY.isna()
170
+ header_filtered = header_tmp[cond_tmp]
171
+ segment_filtered = segment_tmp[cond_tmp]
172
+ if len(header_filtered) > 0:
173
+ header_tmp = header_filtered
174
+ segment_tmp = segment_filtered
175
+ else:
176
+ logging.warning('No quality-0 EUV frames for wl=%s at %s — using closest available' % (wl, id))
177
+ header_euv.append(header_tmp.sort_values('date_diff').iloc[0].drop('date_diff'))
178
+ segment_euv.append(segment_tmp.sort_values('date_diff').iloc[0].drop('date_diff'))
179
 
180
  queue = []
181
  #queue += [(header_hmi.to_dict(), segment_hmi.magnetogram, date)]
 
195
  parser.add_argument('--download_dir', type=str, help='path to the download directory.')
196
  parser.add_argument('--email', type=str, help='registered email address for JSOC.')
197
  parser.add_argument('--start_date', type=str, help='start date in format YYYY-MM-DD.')
198
+ parser.add_argument('--end_date', type=str, help='end date in format YYYY-MM-DD HH:MM:SS.', required=False,
199
+ default=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
200
  parser.add_argument('--cadence', type=int, help='cadence in minutes.', required=False, default=60)
201
 
202
  args = parser.parse_args()
 
205
  end_date = args.end_date
206
  cadence = args.cadence
207
 
208
+ [os.makedirs(os.path.join(download_dir, str(c)), exist_ok=True) for c in [94, 131, 171, 193, 211, 304, 335]]
209
  downloader = SDODownloader(base_path=download_dir, email=args.email)
210
  start_date_datetime = datetime.strptime(start_date, "%Y-%m-%d %H:%M:%S")
211
  #end_date = datetime.now()
 
213
 
214
 
215
  #Skip over dates that already exist in the download directory
216
+ for d in [start_date_datetime + i * timedelta(minutes=cadence) for i in
217
+ range((end_date_datetime - start_date_datetime) // timedelta(minutes=cadence))]:
218
  #make sure the file exists in all wavelengths directories
219
+ for wl in [94, 131, 171, 193, 211, 304, 335]:
220
  if not os.path.exists(os.path.join(
221
  download_dir,
222
  str(wl),
download/sxr_downloader.py CHANGED
@@ -10,11 +10,9 @@ import pandas as pd
10
  class SXRDownloader:
11
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
12
 
13
- def __init__(self, save_dir: str = '/mnt/data/PAPER/GOES-timespan', concat_dir: str = '/mnt/data/PAPER/GOES-timespan/combined'):
14
  self.save_dir = Path(save_dir)
15
  self.save_dir.mkdir(exist_ok=True)
16
- self.concat_dir = Path(concat_dir)
17
- self.concat_dir.mkdir(exist_ok=True)
18
  self.used_g13_files = []
19
  self.used_g14_files = []
20
  self.used_g15_files = []
 
10
  class SXRDownloader:
11
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
12
 
13
+ def __init__(self, save_dir: str = '/mnt/data/PAPER/GOES-timespan'):
14
  self.save_dir = Path(save_dir)
15
  self.save_dir.mkdir(exist_ok=True)
 
 
16
  self.used_g13_files = []
17
  self.used_g14_files = []
18
  self.used_g15_files = []
forecasting/data_loaders/SDOAIA_dataloader.py CHANGED
@@ -12,6 +12,16 @@ import glob
12
  import os
13
 
14
 
 
 
 
 
 
 
 
 
 
 
15
  class AIA_GOESDataset(torch.utils.data.Dataset):
16
  """
17
  PyTorch Dataset for loading paired AIA (EUV images) and GOES (SXR flux) data.
@@ -354,7 +364,7 @@ class AIA_GOESDataModule(LightningDataModule):
354
  self.train_ds = AIA_GOESDataset(
355
  aia_dir=self.aia_train_dir,
356
  sxr_dir=self.sxr_train_dir,
357
- sxr_transform=T.Lambda(lambda x: (np.log10(x + 1e-8) - self.sxr_norm[0]) / self.sxr_norm[1]),
358
  target_size=(512, 512),
359
  wavelengths=self.wavelengths,
360
  cadence=1,
@@ -368,7 +378,7 @@ class AIA_GOESDataModule(LightningDataModule):
368
  self.val_ds = AIA_GOESDataset(
369
  aia_dir=self.aia_val_dir,
370
  sxr_dir=self.sxr_val_dir,
371
- sxr_transform=T.Lambda(lambda x: (np.log10(x + 1e-8) - self.sxr_norm[0]) / self.sxr_norm[1]),
372
  target_size=(512, 512),
373
  wavelengths=self.wavelengths,
374
  cadence=1,
@@ -382,7 +392,7 @@ class AIA_GOESDataModule(LightningDataModule):
382
  self.test_ds = AIA_GOESDataset(
383
  aia_dir=self.aia_test_dir,
384
  sxr_dir=self.sxr_test_dir,
385
- sxr_transform=T.Lambda(lambda x: (np.log10(x + 1e-8) - self.sxr_norm[0]) / self.sxr_norm[1]),
386
  target_size=(512, 512),
387
  wavelengths=self.wavelengths,
388
  cadence=1,
 
12
  import os
13
 
14
 
15
+ class SXRLogNormTransform:
16
+ """Picklable SXR log-normalization transform (replaces T.Lambda for spawn compatibility)."""
17
+ def __init__(self, mean: float, std: float):
18
+ self.mean = mean
19
+ self.std = std
20
+
21
+ def __call__(self, x: float) -> float:
22
+ return (np.log10(x + 1e-8) - self.mean) / self.std
23
+
24
+
25
  class AIA_GOESDataset(torch.utils.data.Dataset):
26
  """
27
  PyTorch Dataset for loading paired AIA (EUV images) and GOES (SXR flux) data.
 
364
  self.train_ds = AIA_GOESDataset(
365
  aia_dir=self.aia_train_dir,
366
  sxr_dir=self.sxr_train_dir,
367
+ sxr_transform=SXRLogNormTransform(self.sxr_norm[0], self.sxr_norm[1]),
368
  target_size=(512, 512),
369
  wavelengths=self.wavelengths,
370
  cadence=1,
 
378
  self.val_ds = AIA_GOESDataset(
379
  aia_dir=self.aia_val_dir,
380
  sxr_dir=self.sxr_val_dir,
381
+ sxr_transform=SXRLogNormTransform(self.sxr_norm[0], self.sxr_norm[1]),
382
  target_size=(512, 512),
383
  wavelengths=self.wavelengths,
384
  cadence=1,
 
392
  self.test_ds = AIA_GOESDataset(
393
  aia_dir=self.aia_test_dir,
394
  sxr_dir=self.sxr_test_dir,
395
+ sxr_transform=SXRLogNormTransform(self.sxr_norm[0], self.sxr_norm[1]),
396
  target_size=(512, 512),
397
  wavelengths=self.wavelengths,
398
  cadence=1,
forecasting/training/train.py CHANGED
@@ -84,69 +84,6 @@ def resolve_config_variables(config_dict):
84
  return recursive_substitute(config_dict, variables)
85
 
86
 
87
- # Parser
88
- parser = argparse.ArgumentParser()
89
- parser.add_argument('-config', type=str, default='config.yaml', required=True, help='Path to config YAML.')
90
- args = parser.parse_args()
91
-
92
- # Load config with variable substitution
93
- with open(args.config, 'r') as stream:
94
- config_data = yaml.load(stream, Loader=yaml.SafeLoader)
95
-
96
- # Resolve variables like ${base_data_dir}
97
- config_data = resolve_config_variables(config_data)
98
-
99
-
100
- # Debug: Print resolved paths
101
- print("Resolved paths:")
102
- print(f"AIA dir: {config_data['data']['aia_dir']}")
103
- print(f"SXR dir: {config_data['data']['sxr_dir']}")
104
- print(f"Checkpoints dir: {config_data['data']['checkpoints_dir']}")
105
-
106
- sxr_norm = np.load(config_data['data']['sxr_norm_path'])
107
- training_wavelengths = config_data['wavelengths']
108
-
109
-
110
- # DataModule
111
- data_loader = AIA_GOESDataModule(
112
- aia_train_dir= config_data['data']['aia_dir']+"/train",
113
- aia_val_dir=config_data['data']['aia_dir']+"/val",
114
- aia_test_dir=config_data['data']['aia_dir']+"/test",
115
- sxr_train_dir=config_data['data']['sxr_dir']+"/train",
116
- sxr_val_dir=config_data['data']['sxr_dir']+"/val",
117
- sxr_test_dir=config_data['data']['sxr_dir']+"/test",
118
- batch_size=config_data['batch_size'],
119
- num_workers=min(8, os.cpu_count()), # Limit workers to prevent shm issues
120
- sxr_norm=sxr_norm,
121
- wavelengths=training_wavelengths,
122
- oversample=config_data['oversample'],
123
- balance_strategy=config_data['balance_strategy'],
124
- )
125
- data_loader.setup()
126
- # Logger
127
- #wb_name = f"{instrument}_{n}" if len(combined_parameters) > 1 else "aia_sxr_model"
128
- wandb_logger = WandbLogger(
129
- entity=config_data['wandb']['entity'],
130
- project=config_data['wandb']['project'],
131
- job_type=config_data['wandb']['job_type'],
132
- tags=config_data['wandb']['tags'],
133
- name=config_data['wandb']['run_name'],
134
- notes=config_data['wandb']['notes'],
135
- config=config_data
136
- )
137
-
138
- # Logging callback
139
- total_n_valid = len(data_loader.val_ds)
140
- plot_data = [data_loader.val_ds[i] for i in range(0, total_n_valid, max(1, total_n_valid // 4))]
141
- plot_samples = plot_data # Keep as list of ((aia, sxr), target)
142
- #sxr_callback = SXRPredictionLogger(plot_samples)
143
-
144
- sxr_plot_callback = ImagePredictionLogger_SXR(plot_samples, sxr_norm)
145
- # Attention map callback - get patch size from config
146
- patch_size = config_data.get('vit_architecture', {}).get('patch_size', 16)
147
- attention = AttentionMapCallback(patch_size=patch_size, use_local_attention=True)
148
-
149
-
150
  class PTHCheckpointCallback(Callback):
151
  """
152
  Custom PyTorch Lightning callback to save model checkpoints in `.pth` format.
@@ -209,65 +146,6 @@ class PTHCheckpointCallback(Callback):
209
 
210
 
211
 
212
- # Checkpoint callback
213
- checkpoint_callback = ModelCheckpoint(
214
- dirpath=config_data['data']['checkpoints_dir'],
215
- monitor='val_total_loss',
216
- mode='min',
217
- save_top_k=10,
218
- filename=f"{config_data['wandb']['run_name']}-{{epoch:02d}}-{{val_total_loss:.4f}}"
219
- )
220
-
221
- pth_callback = PTHCheckpointCallback(
222
- dirpath=config_data['data']['checkpoints_dir'],
223
- monitor='val_total_loss',
224
- mode='min',
225
- save_top_k=1,
226
- filename_prefix=config_data['wandb']['run_name']
227
- )
228
-
229
- def process_batch(batch_data, sxr_norm, c_threshold, m_threshold, x_threshold):
230
- """
231
- Process a batch of SXR data to count flare occurrences in different intensity classes.
232
-
233
- Parameters
234
- ----------
235
- batch_data : tuple
236
- Tuple containing (batch, batch_idx).
237
- sxr_norm : np.ndarray
238
- Normalization parameters for SXR values.
239
- c_threshold, m_threshold, x_threshold : float
240
- Thresholds defining flare intensity categories.
241
-
242
- Returns
243
- -------
244
- dict
245
- Dictionary containing counts for quiet, C, M, and X class flares.
246
- """
247
- from forecasting.models.vit_patch_model import unnormalize_sxr
248
-
249
- batch, batch_idx = batch_data
250
- _, sxr = batch
251
-
252
- # Unnormalize the SXR values
253
- sxr_un = unnormalize_sxr(sxr, sxr_norm)
254
- sxr_un_flat = sxr_un.view(-1).cpu().numpy()
255
-
256
- total = len(sxr_un_flat)
257
- quiet_count = ((sxr_un_flat < c_threshold)).sum()
258
- c_count = ((sxr_un_flat >= c_threshold) & (sxr_un_flat < m_threshold)).sum()
259
- m_count = ((sxr_un_flat >= m_threshold) & (sxr_un_flat < x_threshold)).sum()
260
- x_count = ((sxr_un_flat >= x_threshold)).sum()
261
-
262
- return {
263
- 'total': total,
264
- 'quiet_count': quiet_count,
265
- 'c_count': c_count,
266
- 'm_count': m_count,
267
- 'x_count': x_count,
268
- 'batch_idx': batch_idx
269
- }
270
-
271
  def get_base_weights(data_loader, sxr_norm):
272
  """
273
  Compute inverse-frequency weights for flare classes based on training data.
@@ -353,94 +231,126 @@ def get_base_weights(data_loader, sxr_norm):
353
 
354
 
355
 
356
- base_weights = get_base_weights(data_loader, sxr_norm) if config_data.get('calculate_base_weights', True) else None
357
- model = ViTLocal(model_kwargs=config_data['vit_architecture'], sxr_norm = sxr_norm, base_weights=base_weights)
358
-
359
-
360
- # Set device based on config
361
- # Support both old 'gpu_id' and new 'gpu_ids' config keys for backward compatibility
362
- gpu_config = config_data.get('gpu_ids', config_data.get('gpu_id', 0))
363
-
364
- if gpu_config == -1:
365
- """
366
- Use CPU for training if GPU config is set to -1.
367
- """
368
- # CPU only
369
- accelerator = "cpu"
370
- devices = 1
371
- strategy = "auto"
372
- print("Using CPU for training")
373
- elif gpu_config == "all":
374
- """
375
- Use all available GPUs if GPU config is set to 'all'.
376
- """
377
- # Use all available GPUs
378
- if torch.cuda.is_available():
379
- accelerator = "gpu"
380
- devices = -1 # -1 means use all available GPUs
381
- num_gpus = torch.cuda.device_count()
382
- strategy = "auto"
383
- print(f"Using all available GPUs ({num_gpus} GPUs)")
384
- if num_gpus > 1:
385
- print(f"Multi-GPU training with DDP: Effective batch size = {config_data['batch_size']} x {num_gpus} GPUs = {config_data['batch_size'] * num_gpus}")
386
- else:
387
- accelerator = "cpu"
388
- devices = 1
389
- strategy = "auto"
390
- print("No GPUs available, falling back to CPU")
391
- elif isinstance(gpu_config, list):
392
- """
393
- Use specific GPU IDs if provided as a list.
394
- """
395
- # Multiple specific GPUs
396
- if torch.cuda.is_available():
397
- accelerator = "gpu"
398
- devices = gpu_config
399
- strategy = "auto"
400
- print(f"Using GPUs: {gpu_config}")
401
- if len(gpu_config) > 1:
402
- print(f"Multi-GPU training with DDP: Effective batch size = {config_data['batch_size']} x {len(gpu_config)} GPUs = {config_data['batch_size'] * len(gpu_config)}")
403
- else:
404
- accelerator = "cpu"
405
- devices = 1
406
- strategy = "auto"
407
- print("No GPUs available, falling back to CPU")
408
- else:
409
- """
410
- Use a single GPU or CPU based on availability.
411
- """
412
- # Single GPU (integer)
413
- if torch.cuda.is_available():
414
- accelerator = "gpu"
415
- devices = [gpu_config]
416
- strategy = "auto"
417
- print(f"Using GPU {gpu_config}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
  else:
419
- accelerator = "cpu"
420
- devices = 1
421
- strategy = "auto"
422
- print(f"GPU {gpu_config} not available, falling back to CPU")
423
-
424
- # Trainer
425
- trainer = Trainer(
426
- default_root_dir=config_data['data']['checkpoints_dir'],
427
- accelerator=accelerator,
428
- devices=devices,
429
- strategy=strategy,
430
- max_epochs=config_data['epochs'],
431
- callbacks=[attention, checkpoint_callback],
432
- logger=wandb_logger,
433
- log_every_n_steps=10,
434
- )
435
- trainer.fit(model, data_loader)
436
-
437
- # Save final PyTorch checkpoint with model and state_dict
438
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
439
- final_checkpoint_path = os.path.join(config_data['data']['checkpoints_dir'], f"{config_data['wandb']['run_name']}-final-{timestamp}.pth")
440
- torch.save({
441
- 'model': model,
442
- 'state_dict': model.state_dict()
443
- }, final_checkpoint_path)
444
- print(f"Saved final PyTorch checkpoint: {final_checkpoint_path}")
445
- # Finalize
446
- wandb.finish()
 
 
84
  return recursive_substitute(config_dict, variables)
85
 
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  class PTHCheckpointCallback(Callback):
88
  """
89
  Custom PyTorch Lightning callback to save model checkpoints in `.pth` format.
 
146
 
147
 
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  def get_base_weights(data_loader, sxr_norm):
150
  """
151
  Compute inverse-frequency weights for flare classes based on training data.
 
231
 
232
 
233
 
234
+ if __name__ == '__main__':
235
+ # Parser
236
+ parser = argparse.ArgumentParser()
237
+ parser.add_argument('-config', type=str, default='config.yaml', required=True, help='Path to config YAML.')
238
+ args = parser.parse_args()
239
+
240
+ # Load config with variable substitution
241
+ with open(args.config, 'r') as stream:
242
+ config_data = yaml.load(stream, Loader=yaml.SafeLoader)
243
+ config_data = resolve_config_variables(config_data)
244
+
245
+ print("Resolved paths:")
246
+ print(f"AIA dir: {config_data['data']['aia_dir']}")
247
+ print(f"SXR dir: {config_data['data']['sxr_dir']}")
248
+ print(f"Checkpoints dir: {config_data['data']['checkpoints_dir']}")
249
+
250
+ sxr_norm = np.load(config_data['data']['sxr_norm_path'])
251
+ training_wavelengths = config_data['wavelengths']
252
+
253
+ # DataModule
254
+ data_loader = AIA_GOESDataModule(
255
+ aia_train_dir=config_data['data']['aia_dir'] + "/train",
256
+ aia_val_dir=config_data['data']['aia_dir'] + "/val",
257
+ aia_test_dir=config_data['data']['aia_dir'] + "/test",
258
+ sxr_train_dir=config_data['data']['sxr_dir'] + "/train",
259
+ sxr_val_dir=config_data['data']['sxr_dir'] + "/val",
260
+ sxr_test_dir=config_data['data']['sxr_dir'] + "/test",
261
+ batch_size=config_data['batch_size'],
262
+ num_workers=min(8, os.cpu_count()),
263
+ sxr_norm=sxr_norm,
264
+ wavelengths=training_wavelengths,
265
+ oversample=config_data['oversample'],
266
+ balance_strategy=config_data['balance_strategy'],
267
+ )
268
+ data_loader.setup()
269
+
270
+ # Logger
271
+ wandb_logger = WandbLogger(
272
+ entity=config_data['wandb']['entity'],
273
+ project=config_data['wandb']['project'],
274
+ job_type=config_data['wandb']['job_type'],
275
+ tags=config_data['wandb']['tags'],
276
+ name=config_data['wandb']['run_name'],
277
+ notes=config_data['wandb']['notes'],
278
+ config=config_data
279
+ )
280
+
281
+ # Callbacks
282
+ total_n_valid = len(data_loader.val_ds)
283
+ plot_samples = [data_loader.val_ds[i] for i in range(0, total_n_valid, max(1, total_n_valid // 4))]
284
+ sxr_plot_callback = ImagePredictionLogger_SXR(plot_samples, sxr_norm)
285
+ patch_size = config_data.get('vit_architecture', {}).get('patch_size', 16)
286
+ attention = AttentionMapCallback(patch_size=patch_size, use_local_attention=True)
287
+
288
+ base_weights = get_base_weights(data_loader, sxr_norm) if config_data.get('calculate_base_weights', True) else None
289
+ model = ViTLocal(model_kwargs=config_data['vit_architecture'], sxr_norm=sxr_norm, base_weights=base_weights)
290
+
291
+ # Checkpoint callbacks
292
+ checkpoint_callback = ModelCheckpoint(
293
+ dirpath=config_data['data']['checkpoints_dir'],
294
+ monitor='val_total_loss',
295
+ mode='min',
296
+ save_top_k=10,
297
+ filename=f"{config_data['wandb']['run_name']}-{{epoch:02d}}-{{val_total_loss:.4f}}"
298
+ )
299
+ pth_callback = PTHCheckpointCallback(
300
+ dirpath=config_data['data']['checkpoints_dir'],
301
+ monitor='val_total_loss',
302
+ mode='min',
303
+ save_top_k=1,
304
+ filename_prefix=config_data['wandb']['run_name']
305
+ )
306
+
307
+ # Set device based on config
308
+ gpu_config = config_data.get('gpu_ids', config_data.get('gpu_id', 0))
309
+ if gpu_config == -1:
310
+ accelerator, devices, strategy = "cpu", 1, "auto"
311
+ print("Using CPU for training")
312
+ elif gpu_config == "all":
313
+ if torch.cuda.is_available():
314
+ accelerator, devices, strategy = "gpu", -1, "auto"
315
+ num_gpus = torch.cuda.device_count()
316
+ print(f"Using all available GPUs ({num_gpus} GPUs)")
317
+ else:
318
+ accelerator, devices, strategy = "cpu", 1, "auto"
319
+ print("No GPUs available, falling back to CPU")
320
+ elif isinstance(gpu_config, list):
321
+ if torch.cuda.is_available():
322
+ accelerator, devices, strategy = "gpu", gpu_config, "auto"
323
+ print(f"Using GPUs: {gpu_config}")
324
+ else:
325
+ accelerator, devices, strategy = "cpu", 1, "auto"
326
+ print("No GPUs available, falling back to CPU")
327
  else:
328
+ if torch.cuda.is_available():
329
+ accelerator, devices, strategy = "gpu", [gpu_config], "auto"
330
+ print(f"Using GPU {gpu_config}")
331
+ else:
332
+ accelerator, devices, strategy = "cpu", 1, "auto"
333
+ print(f"GPU {gpu_config} not available, falling back to CPU")
334
+
335
+ # Trainer
336
+ trainer = Trainer(
337
+ default_root_dir=config_data['data']['checkpoints_dir'],
338
+ accelerator=accelerator,
339
+ devices=devices,
340
+ strategy=strategy,
341
+ max_epochs=config_data['epochs'],
342
+ callbacks=[attention, checkpoint_callback],
343
+ logger=wandb_logger,
344
+ log_every_n_steps=10,
345
+ )
346
+ trainer.fit(model, data_loader)
347
+
348
+ # Save final checkpoint
349
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
350
+ final_checkpoint_path = os.path.join(
351
+ config_data['data']['checkpoints_dir'],
352
+ f"{config_data['wandb']['run_name']}-final-{timestamp}.pth"
353
+ )
354
+ torch.save({'model': model, 'state_dict': model.state_dict()}, final_checkpoint_path)
355
+ print(f"Saved final PyTorch checkpoint: {final_checkpoint_path}")
356
+ wandb.finish()
forecasting/training/train_config.yaml CHANGED
@@ -1,7 +1,7 @@
1
 
2
  #Base directories - change these to switch datasets
3
- base_data_dir: "/Volumes/T9/FOXES_Data" # Change this line for different datasets
4
- base_checkpoint_dir: "/Volumes/T9/FOXES_Data" # Change this line for different datasets
5
  wavelengths: [94, 131, 171, 193, 211, 304, 335] # AIA wavelengths in Angstroms
6
 
7
  # GPU configuration
@@ -35,11 +35,11 @@ vit_architecture:
35
  # Data paths (automatically constructed from base directories)
36
  data:
37
  aia_dir:
38
- "${base_data_dir}/AIA"
39
  sxr_dir:
40
- "${base_data_dir}/SXR"
41
  sxr_norm_path:
42
- "${base_data_dir}/SXR/normalized_sxr.npy"
43
  checkpoints_dir:
44
  "${base_checkpoint_dir}/new-checkpoint/"
45
 
 
1
 
2
  #Base directories - change these to switch datasets
3
+ base_data_dir: "/Volumes/T9/Data_FOXES" # Change this line for different datasets
4
+ base_checkpoint_dir: "/Volumes/T9/Data_FOXES" # Change this line for different datasets
5
  wavelengths: [94, 131, 171, 193, 211, 304, 335] # AIA wavelengths in Angstroms
6
 
7
  # GPU configuration
 
35
  # Data paths (automatically constructed from base directories)
36
  data:
37
  aia_dir:
38
+ "${base_data_dir}/AIA_processed"
39
  sxr_dir:
40
+ "${base_data_dir}/SXR_processed"
41
  sxr_norm_path:
42
+ "${base_data_dir}/SXR_processed/normalized_sxr.npy"
43
  checkpoints_dir:
44
  "${base_checkpoint_dir}/new-checkpoint/"
45
 
pipeline_config.yaml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # FOXES Pipeline Configuration
3
+ # =============================================================================
4
+ # Used by run_pipeline.py to run any combination of pipeline steps.
5
+ #
6
+ # Usage:
7
+ # python run_pipeline.py --config pipeline_config.yaml --steps all
8
+ # python run_pipeline.py --config pipeline_config.yaml --steps train,inference,flare_analysis
9
+ # python run_pipeline.py --list
10
+
11
+ # -----------------------------------------------------------------------------
12
+ # Shared date range (used by download_aia and download_sxr)
13
+ # -----------------------------------------------------------------------------
14
+ start_date: "2014-07-01 00:00:00"
15
+ end_date: "2014-07-08 00:00:00"
16
+
17
+ # -----------------------------------------------------------------------------
18
+ # AIA download (step: download_aia)
19
+ # -----------------------------------------------------------------------------
20
+ aia:
21
+ download_dir: "/Volumes/T9/Data_FOXES/AIA_raw"
22
+ email: "" # Must be registered at http://jsoc.stanford.edu
23
+ cadence: 1 # Minutes between frames
24
+
25
+ # -----------------------------------------------------------------------------
26
+ # SXR download (step: download_sxr)
27
+ # -----------------------------------------------------------------------------
28
+ sxr:
29
+ save_dir: "/Volumes/T9/Data_FOXES/SXR_raw"
30
+
31
+ # -----------------------------------------------------------------------------
32
+ # Preprocessing (step: preprocess)
33
+ # -----------------------------------------------------------------------------
34
+ preprocess:
35
+ config: "data/pipeline_config.yaml" # PipelineConfig for process_data_pipeline.py
36
+
37
+ # -----------------------------------------------------------------------------
38
+ # SXR normalization (step: normalize)
39
+ # -----------------------------------------------------------------------------
40
+ normalize:
41
+ sxr_dir: "/Volumes/T9/Data_FOXES/SXR_processed/train"
42
+ output_path: "/Volumes/T9/Data_FOXES/SXR_processed/normalized_sxr.npy"
43
+
44
+ # -----------------------------------------------------------------------------
45
+ # Train/val/test split (step: split)
46
+ # Runs split_data.py once for AIA and once for SXR.
47
+ # -----------------------------------------------------------------------------
48
+ split:
49
+ aia_input_dir: "/Volumes/T9/Data_FOXES/AIA_processed" # splits into AIA_processed/train|val|test
50
+ sxr_input_dir: "/Volumes/T9/Data_FOXES/SXR_processed" # splits into SXR_processed/train|val|test
51
+ train_start: "2014-07-01"
52
+ train_end: "2014-07-05"
53
+ val_start: "2014-07-06"
54
+ val_end: "2014-07-07"
55
+ test_start: "2014-07-08"
56
+ test_end: "2025-12-31"
57
+
58
+ # -----------------------------------------------------------------------------
59
+ # Training (step: train)
60
+ # -----------------------------------------------------------------------------
61
+ train:
62
+ config: "forecasting/training/train_config.yaml"
63
+ overrides: # Any key from train_config.yaml can go here
64
+ base_data_dir: "/Volumes/T9/Data_FOXES"
65
+ base_checkpoint_dir: "/Volumes/T9/Data_FOXES"
66
+ epochs: 150
67
+ batch_size: 6
68
+ wandb:
69
+ run_name: "pipeline-run"
70
+ entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
71
+ project: Paper
72
+ job_type: training
73
+ tags:
74
+ - aia
75
+ - sxr
76
+ - regression
77
+ run_name: paper-8-patch-4ch
78
+ notes: Regression from AIA images to SXR images using ViTLocal model with 8x8 patches
79
+
80
+ # -----------------------------------------------------------------------------
81
+ # Inference & flare analysis (steps: inference, flare_analysis)
82
+ # -----------------------------------------------------------------------------
83
+ inference:
84
+ config: "forecasting/inference/local_config.yaml"
85
+ overrides: # Any key from local_config.yaml can go here
86
+ paths:
87
+ data_dir: "/Volumes/T9/Data_FOXES"
requirements.txt CHANGED
@@ -2,6 +2,7 @@
2
  sunpy[all]
3
  sunpy-soar
4
  astropy
 
5
  drms
6
  itipy
7
 
 
2
  sunpy[all]
3
  sunpy-soar
4
  astropy
5
+ aiapy==0.6.4 # itipy 0.1.1 requires calibrate.util which was renamed in 0.7.0
6
  drms
7
  itipy
8
 
run_pipeline.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ FOXES End-to-End Pipeline Orchestrator
4
+
5
+ Runs any combination of pipeline steps in order:
6
+
7
+ 1. download_aia - Download SDO/AIA EUV images from JSOC (download/download_sdo.py)
8
+ 2. download_sxr - Download GOES SXR flux data (download/sxr_downloader.py)
9
+ 3. combine_sxr - Combine raw GOES .nc files into per-satellite CSVs (data/sxr_data_processing.py)
10
+ 4. preprocess - EUV cleaning, ITI processing, data alignment (data/process_data_pipeline.py)
11
+ 5. split - Split AIA + SXR into train/val/test (data/split_data.py)
12
+ 6. normalize - Compute SXR normalization stats on train split (data/sxr_normalization.py)
13
+ 7. train - Train the ViTLocal forecasting model (forecasting/training/train.py)
14
+ 8. inference - Run batch inference on val/test data (forecasting/inference/inference.py)
15
+ 9. flare_analysis - Detect, track, and match flares (forecasting/inference/flare_analysis.py)
16
+
17
+ Usage:
18
+ python run_pipeline.py --list
19
+ python run_pipeline.py --config pipeline_config.yaml --steps all
20
+ python run_pipeline.py --config pipeline_config.yaml --steps train,inference,flare_analysis
21
+ """
22
+
23
+ import argparse
24
+ import logging
25
+ import subprocess
26
+ import sys
27
+ import time
28
+ from pathlib import Path
29
+
30
+ import yaml
31
+
32
+ ROOT = Path(__file__).parent
33
+
34
+ logging.basicConfig(
35
+ level=logging.INFO,
36
+ format="%(asctime)s %(levelname)s %(message)s",
37
+ handlers=[
38
+ logging.StreamHandler(sys.stdout),
39
+ logging.FileHandler(ROOT / "pipeline.log"),
40
+ ],
41
+ )
42
+ log = logging.getLogger(__name__)
43
+
44
+
45
+ # ---------------------------------------------------------------------------
46
+ # Config helpers
47
+ # ---------------------------------------------------------------------------
48
+
49
+ def deep_merge(base: dict, overrides: dict) -> dict:
50
+ """Recursively merge overrides into base, modifying base in-place."""
51
+ for key, val in overrides.items():
52
+ if isinstance(val, dict) and isinstance(base.get(key), dict):
53
+ deep_merge(base[key], val)
54
+ else:
55
+ base[key] = val
56
+ return base
57
+
58
+
59
+ def write_merged_config(base_path: str, overrides: dict, out_name: str) -> Path:
60
+ """
61
+ Load base_path YAML, apply overrides, write merged result to ROOT/.{out_name}.yaml.
62
+ Returns the path of the merged file.
63
+ """
64
+ with open(base_path) as f:
65
+ base = yaml.safe_load(f) or {}
66
+ deep_merge(base, overrides)
67
+ out = ROOT / f".merged_{out_name}.yaml"
68
+ with open(out, "w") as f:
69
+ yaml.dump(base, f, default_flow_style=False)
70
+ log.info(f" Merged config written to {out}")
71
+ return out
72
+
73
+
74
+ # ---------------------------------------------------------------------------
75
+ # Step definitions
76
+ # ---------------------------------------------------------------------------
77
+
78
+ STEP_ORDER = [
79
+ "download_aia",
80
+ "download_sxr",
81
+ "combine_sxr",
82
+ "preprocess",
83
+ "split",
84
+ "normalize",
85
+ "train",
86
+ "inference",
87
+ "flare_analysis",
88
+ ]
89
+
90
+ STEP_INFO = {
91
+ "download_aia": {
92
+ "description": "Download SDO/AIA EUV images from JSOC",
93
+ "script": ROOT / "download" / "download_sdo.py",
94
+ },
95
+ "download_sxr": {
96
+ "description": "Download GOES SXR flux data via SXRDownloader",
97
+ "script": None, # invoked inline via python -c
98
+ },
99
+ "combine_sxr": {
100
+ "description": "Combine raw GOES .nc files into per-satellite CSVs for alignment",
101
+ "script": ROOT / "data" / "sxr_data_processing.py",
102
+ },
103
+ "preprocess": {
104
+ "description": "EUV cleaning, ITI processing, and AIA/SXR data alignment",
105
+ "script": ROOT / "data" / "process_data_pipeline.py",
106
+ },
107
+ "normalize": {
108
+ "description": "Compute SXR log-normalization statistics (mean/std)",
109
+ "script": ROOT / "data" / "sxr_normalization.py",
110
+ },
111
+ "split": {
112
+ "description": "Split AIA and SXR data into train/val/test by date range",
113
+ "script": ROOT / "data" / "split_data.py",
114
+ },
115
+ "train": {
116
+ "description": "Train the ViTLocal solar flare forecasting model",
117
+ "script": ROOT / "forecasting" / "training" / "train.py",
118
+ },
119
+ "inference": {
120
+ "description": "Run batch inference and save predictions CSV",
121
+ "script": ROOT / "forecasting" / "inference" / "inference.py",
122
+ },
123
+ "flare_analysis": {
124
+ "description": "Detect, track, and match flares; generate plots/movies",
125
+ "script": ROOT / "forecasting" / "inference" / "flare_analysis.py",
126
+ },
127
+ }
128
+
129
+
130
+ # ---------------------------------------------------------------------------
131
+ # Command builders
132
+ # ---------------------------------------------------------------------------
133
+
134
+ def build_commands(step: str, cfg: dict, force: bool) -> list[list[str]] | None:
135
+ """
136
+ Return a list of subprocess commands for a given step, or None if required config is missing.
137
+ Most steps return a single command; 'split' returns two (AIA then SXR).
138
+ """
139
+
140
+ def require(keys: list[str], section: str = None) -> bool:
141
+ src = cfg.get(section, {}) if section else cfg
142
+ missing = [k for k in keys if not src.get(k)]
143
+ if missing:
144
+ prefix = f"{section}." if section else ""
145
+ log.error(f"pipeline_config.yaml missing required keys: {[prefix + k for k in missing]}")
146
+ return False
147
+ return True
148
+
149
+ if step == "download_aia":
150
+ if not require(["download_dir", "email"], "aia") or not require(["start_date"]):
151
+ return None
152
+ aia = cfg["aia"]
153
+ cmd = [sys.executable, str(STEP_INFO[step]["script"]),
154
+ "--download_dir", aia["download_dir"],
155
+ "--email", aia["email"],
156
+ "--start_date", cfg["start_date"]]
157
+ if cfg.get("end_date"):
158
+ cmd += ["--end_date", cfg["end_date"]]
159
+ if aia.get("cadence"):
160
+ cmd += ["--cadence", str(aia["cadence"])]
161
+ return [cmd]
162
+
163
+ if step == "download_sxr":
164
+ if not require(["save_dir"], "sxr") or not require(["start_date"]):
165
+ return None
166
+ start = cfg["start_date"]
167
+ end = cfg.get("end_date", start)
168
+ save_dir = cfg["sxr"]["save_dir"]
169
+ inline = (
170
+ f"import sys; sys.path.insert(0, r'{ROOT}'); "
171
+ f"from download.sxr_downloader import SXRDownloader; "
172
+ f"d = SXRDownloader(save_dir=r'{save_dir}'); "
173
+ f"d.download_and_save_goes_data(start='{start}', end='{end}')"
174
+ )
175
+ return [[sys.executable, "-c", inline]]
176
+
177
+ if step == "combine_sxr":
178
+ if not require(["save_dir"], "sxr"):
179
+ return None
180
+ raw_dir = cfg["sxr"]["save_dir"]
181
+ combined_dir = str(Path(raw_dir) / "combined")
182
+ return [[sys.executable, str(STEP_INFO[step]["script"]),
183
+ "--data_dir", raw_dir,
184
+ "--output_dir", combined_dir]]
185
+
186
+ script = STEP_INFO[step]["script"]
187
+ base = [sys.executable, str(script)]
188
+
189
+ if step == "preprocess":
190
+ pre = cfg.get("preprocess", {})
191
+ cmd = base[:]
192
+ if pre.get("config"):
193
+ cmd += ["--config", pre["config"]]
194
+ if force:
195
+ cmd += ["--force"]
196
+ return [cmd]
197
+
198
+ if step == "normalize":
199
+ if not require(["sxr_dir", "output_path"], "normalize"):
200
+ return None
201
+ n = cfg["normalize"]
202
+ return [base + ["--sxr_dir", n["sxr_dir"], "--output_path", n["output_path"]]]
203
+
204
+ if step == "split":
205
+ if not require(["aia_input_dir", "sxr_input_dir"], "split"):
206
+ return None
207
+ s = cfg["split"]
208
+ date_args = []
209
+ for key in ("train_start", "train_end", "val_start", "val_end", "test_start", "test_end"):
210
+ if s.get(key):
211
+ date_args += [f"--{key}", s[key]]
212
+ # Each data type splits into its own input directory (creates train/val/test subdirs there)
213
+ aia_cmd = base + ["--input_folder", s["aia_input_dir"], "--output_dir", s["aia_input_dir"],
214
+ "--data_type", "aia"] + date_args
215
+ sxr_cmd = base + ["--input_folder", s["sxr_input_dir"], "--output_dir", s["sxr_input_dir"],
216
+ "--data_type", "sxr"] + date_args
217
+ return [aia_cmd, sxr_cmd]
218
+
219
+ if step == "train":
220
+ if not require(["config"], "train"):
221
+ return None
222
+ t = cfg["train"]
223
+ config_path = t["config"]
224
+ if t.get("overrides"):
225
+ config_path = str(write_merged_config(config_path, t["overrides"], "train_config"))
226
+ return [base + ["-config", config_path]]
227
+
228
+ if step == "inference":
229
+ if not require(["config"], "inference"):
230
+ return None
231
+ inf = cfg["inference"]
232
+ config_path = inf["config"]
233
+ if inf.get("overrides"):
234
+ config_path = str(write_merged_config(config_path, inf["overrides"], "inference_config"))
235
+ return [base + ["-config", config_path]]
236
+
237
+ if step == "flare_analysis":
238
+ if not require(["config"], "inference"):
239
+ return None
240
+ inf = cfg["inference"]
241
+ config_path = inf["config"]
242
+ if inf.get("overrides"):
243
+ config_path = str(write_merged_config(config_path, inf["overrides"], "inference_config"))
244
+ return [base + ["--config", config_path]]
245
+
246
+ return [base]
247
+
248
+
249
+ # ---------------------------------------------------------------------------
250
+ # Runner
251
+ # ---------------------------------------------------------------------------
252
+
253
+ def run_step(step: str, cmds: list[list[str]]) -> bool:
254
+ info = STEP_INFO[step]
255
+ total_start = time.time()
256
+
257
+ for i, cmd in enumerate(cmds):
258
+ label = f"{step.upper()}" + (f" ({i + 1}/{len(cmds)})" if len(cmds) > 1 else "")
259
+ log.info("")
260
+ log.info("=" * 70)
261
+ log.info(f" STEP: {label}")
262
+ log.info(f" {info['description']}")
263
+ log.info(f" {' '.join(str(c) for c in cmd)}")
264
+ log.info("=" * 70)
265
+
266
+ result = subprocess.run(cmd, cwd=ROOT)
267
+ if result.returncode != 0:
268
+ log.error(f" FAILED {label} exited with code {result.returncode}")
269
+ return False
270
+
271
+ elapsed = time.time() - total_start
272
+ log.info(f" DONE {step} completed in {elapsed:.1f}s")
273
+ return True
274
+
275
+
276
+ # ---------------------------------------------------------------------------
277
+ # CLI
278
+ # ---------------------------------------------------------------------------
279
+
280
+ def list_steps():
281
+ print("\nAvailable pipeline steps (in order):\n")
282
+ for i, step in enumerate(STEP_ORDER, 1):
283
+ print(f" {i}. {step:<16} {STEP_INFO[step]['description']}")
284
+ print()
285
+ print("Use --steps all to run every step, or comma-separate specific steps.")
286
+ print("Example: --steps train,inference,flare_analysis\n")
287
+
288
+
289
+ def main():
290
+ parser = argparse.ArgumentParser(
291
+ description="FOXES End-to-End Pipeline Orchestrator",
292
+ formatter_class=argparse.RawDescriptionHelpFormatter,
293
+ )
294
+ parser.add_argument("--config", type=str, default=None, help="Path to pipeline_config.yaml")
295
+ parser.add_argument("--steps", type=str, default=None,
296
+ help=f"Comma-separated steps to run, or 'all'. Available: {', '.join(STEP_ORDER)}")
297
+ parser.add_argument("--list", action="store_true", help="List all available steps and exit")
298
+ parser.add_argument("--force", action="store_true", help="Force re-run (forwarded to preprocess step)")
299
+
300
+ args = parser.parse_args()
301
+
302
+ if args.list:
303
+ list_steps()
304
+ return
305
+
306
+ if not args.steps:
307
+ parser.print_help()
308
+ return
309
+
310
+ if not args.config:
311
+ log.error("--config is required. Point it at your pipeline_config.yaml.")
312
+ sys.exit(1)
313
+
314
+ with open(args.config, "r") as f:
315
+ cfg = yaml.safe_load(f)
316
+
317
+ # Resolve step list
318
+ if args.steps.strip().lower() == "all":
319
+ selected = list(STEP_ORDER)
320
+ else:
321
+ selected = [s.strip() for s in args.steps.split(",")]
322
+ unknown = [s for s in selected if s not in STEP_INFO]
323
+ if unknown:
324
+ log.error(f"Unknown steps: {', '.join(unknown)}")
325
+ list_steps()
326
+ sys.exit(1)
327
+ selected = [s for s in STEP_ORDER if s in selected] # preserve order
328
+
329
+ log.info(f"Config: {args.config}")
330
+ log.info(f"Running {len(selected)} step(s): {' -> '.join(selected)}")
331
+
332
+ passed, failed = [], []
333
+
334
+ for step in selected:
335
+ cmds = build_commands(step, cfg, args.force)
336
+ if cmds is None:
337
+ failed.append(step)
338
+ break
339
+
340
+ if run_step(step, cmds):
341
+ passed.append(step)
342
+ else:
343
+ failed.append(step)
344
+ log.error(f"Pipeline stopped at '{step}'.")
345
+ break
346
+
347
+ # Summary
348
+ log.info("")
349
+ log.info("=" * 70)
350
+ log.info("PIPELINE SUMMARY")
351
+ log.info("=" * 70)
352
+ for s in passed:
353
+ log.info(f" PASSED {s}")
354
+ for s in failed:
355
+ log.error(f" FAILED {s}")
356
+ for s in [s for s in selected if s not in passed and s not in failed]:
357
+ log.info(f" SKIPPED {s}")
358
+ log.info("=" * 70)
359
+
360
+ sys.exit(0 if not failed else 1)
361
+
362
+
363
+ if __name__ == "__main__":
364
+ main()