Commit ·
ec2b4e7
1
Parent(s): 10ad6fc
Refactor pipeline configuration and update data processing scripts
Browse files- .gitignore +1 -1
- .dockerignore → Untracked/.dockerignore +1 -1
- QUICKSTART_DOCKER.md → Untracked/QUICKSTART_DOCKER.md +0 -0
- data/align_data.py +10 -29
- data/euv_data_cleaning.py +48 -62
- data/iti_data_processing.py +61 -81
- data/pipeline_config.py +106 -0
- data/pipeline_config.yaml +47 -0
- data/process_data_pipeline.py +5 -13
- data/sxr_data_processing.py +29 -29
- download/download_sdo.py +27 -50
- download/sxr_downloader.py +1 -3
- forecasting/data_loaders/SDOAIA_dataloader.py +13 -3
- forecasting/training/train.py +122 -212
- forecasting/training/train_config.yaml +5 -5
- pipeline_config.yaml +87 -0
- requirements.txt +1 -0
- run_pipeline.py +364 -0
.gitignore
CHANGED
|
@@ -154,4 +154,4 @@ wandb/
|
|
| 154 |
*.code-workspace
|
| 155 |
|
| 156 |
.claude/
|
| 157 |
-
|
|
|
|
| 154 |
*.code-workspace
|
| 155 |
|
| 156 |
.claude/
|
| 157 |
+
Untracked/
|
.dockerignore → Untracked/.dockerignore
RENAMED
|
@@ -24,7 +24,7 @@ venv.bak/
|
|
| 24 |
|
| 25 |
# IDE
|
| 26 |
.vscode/
|
| 27 |
-
.idea/
|
| 28 |
*.swp
|
| 29 |
*.swo
|
| 30 |
*~
|
|
|
|
| 24 |
|
| 25 |
# IDE
|
| 26 |
.vscode/
|
| 27 |
+
../.idea/
|
| 28 |
*.swp
|
| 29 |
*.swo
|
| 30 |
*~
|
QUICKSTART_DOCKER.md → Untracked/QUICKSTART_DOCKER.md
RENAMED
|
File without changes
|
data/align_data.py
CHANGED
|
@@ -38,9 +38,8 @@ def load_config():
|
|
| 38 |
'alignment': {
|
| 39 |
'goes_data_dir': "/mnt/data/PAPER/GOES-timespan/combined",
|
| 40 |
'aia_processed_dir': "/mnt/data/PAPER/SDOITI",
|
| 41 |
-
'
|
| 42 |
-
'
|
| 43 |
-
'aia_missing_dir': "/mnt/data/PAPER/AIA_ITI_MISSING"
|
| 44 |
},
|
| 45 |
'processing': {
|
| 46 |
'batch_size_multiplier': 4,
|
|
@@ -56,8 +55,7 @@ GOES_DATA_DIR = config['alignment']['goes_data_dir']
|
|
| 56 |
AIA_PROCESSED_DIR = config['alignment']['aia_processed_dir']
|
| 57 |
|
| 58 |
# Output directories
|
| 59 |
-
|
| 60 |
-
OUTPUT_SXR_B_DIR = config['alignment']['output_sxr_b_dir']
|
| 61 |
AIA_MISSING_DIR = config['alignment']['aia_missing_dir']
|
| 62 |
|
| 63 |
# Processing configuration
|
|
@@ -136,7 +134,6 @@ def create_combined_lookup_table(goes_data_dict, target_timestamps):
|
|
| 136 |
|
| 137 |
# For each target timestamp, average over all available instruments at that time
|
| 138 |
for target_time in tqdm(target_times, desc="Building lookup table"):
|
| 139 |
-
sxr_a_values = []
|
| 140 |
sxr_b_values = []
|
| 141 |
available_instruments = []
|
| 142 |
|
|
@@ -144,22 +141,15 @@ def create_combined_lookup_table(goes_data_dict, target_timestamps):
|
|
| 144 |
goes_data = goes_data_dict[g_number]
|
| 145 |
if target_time in goes_data.index:
|
| 146 |
row = goes_data.loc[target_time]
|
| 147 |
-
sxr_a = row['xrsa_flux']
|
| 148 |
sxr_b = row['xrsb_flux']
|
| 149 |
-
# Only care about xrsb_flux for validity
|
| 150 |
if not pd.isna(sxr_b):
|
| 151 |
sxr_b_values.append(float(sxr_b))
|
| 152 |
-
if not pd.isna(sxr_a):
|
| 153 |
-
sxr_a_values.append(float(sxr_a))
|
| 154 |
available_instruments.append(f"GOES-{g_number}")
|
| 155 |
|
| 156 |
if sxr_b_values:
|
| 157 |
-
avg_sxr_b = float(np.mean(sxr_b_values))
|
| 158 |
-
avg_sxr_a = float(np.mean(sxr_a_values)) if sxr_a_values else float('nan')
|
| 159 |
lookup_data.append({
|
| 160 |
'timestamp': target_time.strftime('%Y-%m-%dT%H:%M:%S'),
|
| 161 |
-
'
|
| 162 |
-
'sxr_b': avg_sxr_b,
|
| 163 |
'instrument': ",".join(available_instruments)
|
| 164 |
})
|
| 165 |
|
|
@@ -179,21 +169,14 @@ def process_batch(batch_data):
|
|
| 179 |
for data in batch_data:
|
| 180 |
try:
|
| 181 |
timestamp = data['timestamp']
|
| 182 |
-
sxr_a = data['sxr_a']
|
| 183 |
sxr_b = data['sxr_b']
|
| 184 |
instrument = data['instrument']
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
sxr_b_data = np.array([sxr_b], dtype=np.float32)
|
| 189 |
-
|
| 190 |
-
# Save data to disk using configured directories
|
| 191 |
-
np.save(f"{OUTPUT_SXR_A_DIR}/{timestamp}.npy", sxr_a_data)
|
| 192 |
-
np.save(f"{OUTPUT_SXR_B_DIR}/{timestamp}.npy", sxr_b_data)
|
| 193 |
-
|
| 194 |
successful_count += 1
|
| 195 |
results.append((timestamp, True, f"Success using {instrument}"))
|
| 196 |
-
|
| 197 |
except Exception as e:
|
| 198 |
failed_count += 1
|
| 199 |
results.append((timestamp, False, f"Error processing timestamp {timestamp}: {e}"))
|
|
@@ -213,14 +196,12 @@ def main():
|
|
| 213 |
print("=" * 60)
|
| 214 |
print(f"GOES data directory: {GOES_DATA_DIR}")
|
| 215 |
print(f"AIA processed directory: {AIA_PROCESSED_DIR}")
|
| 216 |
-
print(f"Output SXR
|
| 217 |
-
print(f"Output SXR-B directory: {OUTPUT_SXR_B_DIR}")
|
| 218 |
print(f"AIA missing directory: {AIA_MISSING_DIR}")
|
| 219 |
print("=" * 60)
|
| 220 |
|
| 221 |
# Make output directories if they don't exist
|
| 222 |
-
os.makedirs(
|
| 223 |
-
os.makedirs(OUTPUT_SXR_B_DIR, exist_ok=True)
|
| 224 |
os.makedirs(AIA_MISSING_DIR, exist_ok=True)
|
| 225 |
|
| 226 |
# Load and prepare GOES data with optimizations
|
|
|
|
| 38 |
'alignment': {
|
| 39 |
'goes_data_dir': "/mnt/data/PAPER/GOES-timespan/combined",
|
| 40 |
'aia_processed_dir': "/mnt/data/PAPER/SDOITI",
|
| 41 |
+
'output_sxr_dir': "/Volumes/T9/Data_FOXES/SXR_processed",
|
| 42 |
+
'aia_missing_dir': "/Volumes/T9/Data_FOXES/AIA_missing"
|
|
|
|
| 43 |
},
|
| 44 |
'processing': {
|
| 45 |
'batch_size_multiplier': 4,
|
|
|
|
| 55 |
AIA_PROCESSED_DIR = config['alignment']['aia_processed_dir']
|
| 56 |
|
| 57 |
# Output directories
|
| 58 |
+
OUTPUT_SXR_DIR = config['alignment']['output_sxr_dir']
|
|
|
|
| 59 |
AIA_MISSING_DIR = config['alignment']['aia_missing_dir']
|
| 60 |
|
| 61 |
# Processing configuration
|
|
|
|
| 134 |
|
| 135 |
# For each target timestamp, average over all available instruments at that time
|
| 136 |
for target_time in tqdm(target_times, desc="Building lookup table"):
|
|
|
|
| 137 |
sxr_b_values = []
|
| 138 |
available_instruments = []
|
| 139 |
|
|
|
|
| 141 |
goes_data = goes_data_dict[g_number]
|
| 142 |
if target_time in goes_data.index:
|
| 143 |
row = goes_data.loc[target_time]
|
|
|
|
| 144 |
sxr_b = row['xrsb_flux']
|
|
|
|
| 145 |
if not pd.isna(sxr_b):
|
| 146 |
sxr_b_values.append(float(sxr_b))
|
|
|
|
|
|
|
| 147 |
available_instruments.append(f"GOES-{g_number}")
|
| 148 |
|
| 149 |
if sxr_b_values:
|
|
|
|
|
|
|
| 150 |
lookup_data.append({
|
| 151 |
'timestamp': target_time.strftime('%Y-%m-%dT%H:%M:%S'),
|
| 152 |
+
'sxr_b': float(np.mean(sxr_b_values)),
|
|
|
|
| 153 |
'instrument': ",".join(available_instruments)
|
| 154 |
})
|
| 155 |
|
|
|
|
| 169 |
for data in batch_data:
|
| 170 |
try:
|
| 171 |
timestamp = data['timestamp']
|
|
|
|
| 172 |
sxr_b = data['sxr_b']
|
| 173 |
instrument = data['instrument']
|
| 174 |
+
|
| 175 |
+
np.save(f"{OUTPUT_SXR_DIR}/{timestamp}.npy", np.array([sxr_b], dtype=np.float32))
|
| 176 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
successful_count += 1
|
| 178 |
results.append((timestamp, True, f"Success using {instrument}"))
|
| 179 |
+
|
| 180 |
except Exception as e:
|
| 181 |
failed_count += 1
|
| 182 |
results.append((timestamp, False, f"Error processing timestamp {timestamp}: {e}"))
|
|
|
|
| 196 |
print("=" * 60)
|
| 197 |
print(f"GOES data directory: {GOES_DATA_DIR}")
|
| 198 |
print(f"AIA processed directory: {AIA_PROCESSED_DIR}")
|
| 199 |
+
print(f"Output SXR directory: {OUTPUT_SXR_DIR}")
|
|
|
|
| 200 |
print(f"AIA missing directory: {AIA_MISSING_DIR}")
|
| 201 |
print("=" * 60)
|
| 202 |
|
| 203 |
# Make output directories if they don't exist
|
| 204 |
+
os.makedirs(OUTPUT_SXR_DIR, exist_ok=True)
|
|
|
|
| 205 |
os.makedirs(AIA_MISSING_DIR, exist_ok=True)
|
| 206 |
|
| 207 |
# Load and prepare GOES data with optimizations
|
data/euv_data_cleaning.py
CHANGED
|
@@ -14,36 +14,18 @@ collections.MutableMapping = collections.abc.MutableMapping
|
|
| 14 |
from itipy.data.dataset import get_intersecting_files
|
| 15 |
from astropy.io import fits
|
| 16 |
|
| 17 |
-
# Configuration for all wavelengths to process
|
| 18 |
-
# Load configuration from environment or use defaults
|
| 19 |
-
import os
|
| 20 |
import json
|
| 21 |
|
|
|
|
| 22 |
def load_config():
|
| 23 |
"""Load configuration from environment or use defaults."""
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
# Default configuration
|
| 32 |
-
return {
|
| 33 |
-
'euv': {
|
| 34 |
-
'wavelengths': [94, 131, 171, 193, 211, 304],
|
| 35 |
-
'input_folder': '/mnt/data/PAPER/SDOData',
|
| 36 |
-
'bad_files_dir': '/mnt/data/PAPER/SDO-AIA_bad'
|
| 37 |
-
}
|
| 38 |
-
}
|
| 39 |
-
|
| 40 |
-
config = load_config()
|
| 41 |
-
wavelengths = config['euv']['wavelengths']
|
| 42 |
-
base_input_folder = config['euv']['input_folder']
|
| 43 |
-
|
| 44 |
-
aia_files = get_intersecting_files(base_input_folder, wavelengths)
|
| 45 |
-
|
| 46 |
-
# Function to process a single file
|
| 47 |
def process_fits_file(file_path):
|
| 48 |
try:
|
| 49 |
with fits.open(file_path) as hdu:
|
|
@@ -59,39 +41,43 @@ def process_fits_file(file_path):
|
|
| 59 |
print(f"Error processing {file_path}: {e}")
|
| 60 |
return None
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
from itipy.data.dataset import get_intersecting_files
|
| 15 |
from astropy.io import fits
|
| 16 |
|
|
|
|
|
|
|
|
|
|
| 17 |
import json
|
| 18 |
|
| 19 |
+
|
| 20 |
def load_config():
|
| 21 |
"""Load configuration from environment or use defaults."""
|
| 22 |
+
try:
|
| 23 |
+
config = json.loads(os.environ['PIPELINE_CONFIG'])
|
| 24 |
+
return config
|
| 25 |
+
except:
|
| 26 |
+
pass
|
| 27 |
+
|
| 28 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
def process_fits_file(file_path):
|
| 30 |
try:
|
| 31 |
with fits.open(file_path) as hdu:
|
|
|
|
| 41 |
print(f"Error processing {file_path}: {e}")
|
| 42 |
return None
|
| 43 |
|
| 44 |
+
|
| 45 |
+
if __name__ == '__main__':
|
| 46 |
+
config = load_config()
|
| 47 |
+
wavelengths = config['euv']['wavelengths']
|
| 48 |
+
base_input_folder = config['euv']['input_folder']
|
| 49 |
+
|
| 50 |
+
aia_files = get_intersecting_files(base_input_folder, wavelengths)
|
| 51 |
+
file_list = aia_files[0] # List of FITS file paths
|
| 52 |
+
|
| 53 |
+
with Pool(processes=os.cpu_count()) as pool:
|
| 54 |
+
results = list(tqdm(pool.imap(process_fits_file, file_list), total=len(file_list)))
|
| 55 |
+
|
| 56 |
+
# Filter out None results (in case of failed files)
|
| 57 |
+
results = [r for r in results if r is not None]
|
| 58 |
+
|
| 59 |
+
# Convert to DataFrame
|
| 60 |
+
aia_header = pd.DataFrame(results)
|
| 61 |
+
aia_header['DATE-OBS'] = pd.to_datetime(aia_header['DATE-OBS'])
|
| 62 |
+
|
| 63 |
+
# Add a column for date difference between DATE-OBS and FILENAME
|
| 64 |
+
aia_header['DATE_DIFF'] = (
|
| 65 |
+
pd.to_datetime(aia_header['FILENAME']) - pd.to_datetime(aia_header['DATE-OBS'])
|
| 66 |
+
).dt.total_seconds()
|
| 67 |
+
|
| 68 |
+
# Remove rows where DATE_DIFF is greater than ±60 seconds
|
| 69 |
+
files_to_remove = aia_header[(aia_header['DATE_DIFF'] <= -60) | (aia_header['DATE_DIFF'] >= 60)]
|
| 70 |
+
print(f"{len(files_to_remove)} bad files found")
|
| 71 |
+
|
| 72 |
+
for wavelength in wavelengths:
|
| 73 |
+
print(f"\nProcessing wavelength: {wavelength}")
|
| 74 |
+
for names in files_to_remove['FILENAME'].to_numpy():
|
| 75 |
+
filename = pd.to_datetime(names).strftime('%Y-%m-%dT%H:%M:%S') + ".fits"
|
| 76 |
+
file_path = os.path.join(base_input_folder, f"{wavelength}/{filename}")
|
| 77 |
+
destination_folder = os.path.join(config['euv']['bad_files_dir'], str(wavelength))
|
| 78 |
+
os.makedirs(destination_folder, exist_ok=True)
|
| 79 |
+
if os.path.exists(file_path):
|
| 80 |
+
shutil.move(file_path, destination_folder)
|
| 81 |
+
print(f"Moved: {file_path}")
|
| 82 |
+
else:
|
| 83 |
+
print(f"Not found: {file_path}")
|
data/iti_data_processing.py
CHANGED
|
@@ -10,38 +10,18 @@ from astropy.visualization import ImageNormalize, AsinhStretch
|
|
| 10 |
from itipy.data.dataset import StackDataset, get_intersecting_files, AIADataset
|
| 11 |
from itipy.data.editor import BrightestPixelPatchEditor, sdo_norms
|
| 12 |
import os
|
|
|
|
| 13 |
from multiprocessing import Pool
|
| 14 |
from tqdm import tqdm
|
| 15 |
|
| 16 |
-
# Configuration for all wavelengths to process
|
| 17 |
-
# Load configuration from environment or use defaults
|
| 18 |
-
import os
|
| 19 |
-
import json
|
| 20 |
|
| 21 |
def load_config():
|
| 22 |
"""Load configuration from environment or use defaults."""
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
pass
|
| 29 |
-
|
| 30 |
-
# Default configuration
|
| 31 |
-
return {
|
| 32 |
-
'iti': {
|
| 33 |
-
'wavelengths': [94, 131, 171, 193, 211, 304],
|
| 34 |
-
'input_folder': '/mnt/data/PAPER/SDOData',
|
| 35 |
-
'output_folder': '/mnt/data/PAPER/SDOITI'
|
| 36 |
-
}
|
| 37 |
-
}
|
| 38 |
-
|
| 39 |
-
config = load_config()
|
| 40 |
-
wavelengths = config['iti']['wavelengths']
|
| 41 |
-
base_input_folder = config['iti']['input_folder']
|
| 42 |
-
output_folder = config['iti']['output_folder']
|
| 43 |
-
os.makedirs(output_folder, exist_ok=True)
|
| 44 |
-
|
| 45 |
|
| 46 |
|
| 47 |
|
|
@@ -72,72 +52,72 @@ class SDODataset_flaring(StackDataset):
|
|
| 72 |
self.addEditor(BrightestPixelPatchEditor(patch_shape))
|
| 73 |
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
files = get_intersecting_files(base_input_folder, wavelengths, ext='.fits')
|
| 81 |
if not files or len(files) == 0:
|
| 82 |
return 0, 0
|
| 83 |
-
|
| 84 |
-
# Count existing output files - need to check for each wavelength combination
|
| 85 |
existing_count = 0
|
| 86 |
-
total_expected = len(files[0])
|
| 87 |
-
|
| 88 |
-
# Check each time step (index across all wavelengths)
|
| 89 |
for i in range(total_expected):
|
| 90 |
-
|
| 91 |
-
# The output filename should be based on the first wavelength's filename
|
| 92 |
-
first_wl_file = files[0][i] # Use first wavelength as reference
|
| 93 |
base_name = os.path.splitext(os.path.basename(first_wl_file))[0]
|
| 94 |
-
# Remove wavelength suffix if present (e.g., "_171" from filename)
|
| 95 |
if '_' in base_name:
|
| 96 |
base_name = '_'.join(base_name.split('_')[:-1])
|
| 97 |
output_path = os.path.join(output_folder, base_name) + '.npy'
|
| 98 |
-
|
| 99 |
if os.path.exists(output_path):
|
| 100 |
existing_count += 1
|
| 101 |
-
|
| 102 |
return existing_count, total_expected
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
def get_unprocessed_indices():
|
| 118 |
-
unprocessed = []
|
| 119 |
-
for i in range(len(aia_dataset)):
|
| 120 |
-
file_path = os.path.join(output_folder, aia_dataset.getId(i)) + '.npy'
|
| 121 |
-
if not os.path.exists(file_path):
|
| 122 |
-
unprocessed.append(i)
|
| 123 |
-
return unprocessed
|
| 124 |
-
|
| 125 |
-
def save_sample(i):
|
| 126 |
-
try:
|
| 127 |
-
data = aia_dataset[i]
|
| 128 |
-
file_path = os.path.join(output_folder, aia_dataset.getId(i)) + '.npy'
|
| 129 |
-
np.save(file_path, data)
|
| 130 |
-
except Exception as e:
|
| 131 |
-
print(f"Warning: Could not process sample {i} (ID: {aia_dataset.getId(i)}): {e}")
|
| 132 |
-
return # Skip this sample and continue with the next one
|
| 133 |
-
|
| 134 |
-
# Get only unprocessed indices
|
| 135 |
-
unprocessed_indices = get_unprocessed_indices()
|
| 136 |
-
print(f"Processing {len(unprocessed_indices)} unprocessed samples")
|
| 137 |
-
|
| 138 |
-
if unprocessed_indices:
|
| 139 |
-
with Pool(processes=os.cpu_count()) as pool:
|
| 140 |
-
list(tqdm(pool.imap(save_sample, unprocessed_indices), total=len(unprocessed_indices)))
|
| 141 |
-
print("AIA data processing completed.")
|
| 142 |
else:
|
| 143 |
-
print("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
from itipy.data.dataset import StackDataset, get_intersecting_files, AIADataset
|
| 11 |
from itipy.data.editor import BrightestPixelPatchEditor, sdo_norms
|
| 12 |
import os
|
| 13 |
+
import json
|
| 14 |
from multiprocessing import Pool
|
| 15 |
from tqdm import tqdm
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def load_config():
|
| 19 |
"""Load configuration from environment or use defaults."""
|
| 20 |
+
try:
|
| 21 |
+
config = json.loads(os.environ['PIPELINE_CONFIG'])
|
| 22 |
+
return config
|
| 23 |
+
except:
|
| 24 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
|
|
|
|
| 52 |
self.addEditor(BrightestPixelPatchEditor(patch_shape))
|
| 53 |
|
| 54 |
|
| 55 |
+
_aia_dataset = None
|
| 56 |
+
_output_folder = None
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _init_worker(dataset, out_folder):
|
| 60 |
+
global _aia_dataset, _output_folder
|
| 61 |
+
_aia_dataset = dataset
|
| 62 |
+
_output_folder = out_folder
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def save_sample(i):
|
| 66 |
+
try:
|
| 67 |
+
data = _aia_dataset[i]
|
| 68 |
+
file_path = os.path.join(_output_folder, _aia_dataset.getId(i)) + '.npy'
|
| 69 |
+
np.save(file_path, data)
|
| 70 |
+
except Exception as e:
|
| 71 |
+
print(f"Warning: Could not process sample {i} (ID: {_aia_dataset.getId(i)}): {e}")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def check_existing_files(base_input_folder, wavelengths, output_folder):
|
| 75 |
+
"""Check how many files already exist without loading the full dataset."""
|
| 76 |
files = get_intersecting_files(base_input_folder, wavelengths, ext='.fits')
|
| 77 |
if not files or len(files) == 0:
|
| 78 |
return 0, 0
|
| 79 |
+
|
|
|
|
| 80 |
existing_count = 0
|
| 81 |
+
total_expected = len(files[0])
|
| 82 |
+
|
|
|
|
| 83 |
for i in range(total_expected):
|
| 84 |
+
first_wl_file = files[0][i]
|
|
|
|
|
|
|
| 85 |
base_name = os.path.splitext(os.path.basename(first_wl_file))[0]
|
|
|
|
| 86 |
if '_' in base_name:
|
| 87 |
base_name = '_'.join(base_name.split('_')[:-1])
|
| 88 |
output_path = os.path.join(output_folder, base_name) + '.npy'
|
|
|
|
| 89 |
if os.path.exists(output_path):
|
| 90 |
existing_count += 1
|
| 91 |
+
|
| 92 |
return existing_count, total_expected
|
| 93 |
|
| 94 |
+
|
| 95 |
+
if __name__ == '__main__':
|
| 96 |
+
config = load_config()
|
| 97 |
+
wavelengths = config['iti']['wavelengths']
|
| 98 |
+
base_input_folder = config['iti']['input_folder']
|
| 99 |
+
output_folder = config['iti']['output_folder']
|
| 100 |
+
os.makedirs(output_folder, exist_ok=True)
|
| 101 |
+
|
| 102 |
+
existing_files, total_expected = check_existing_files(base_input_folder, wavelengths, output_folder)
|
| 103 |
+
print(f"Found {existing_files} existing files out of {total_expected} expected files")
|
| 104 |
+
|
| 105 |
+
if existing_files >= total_expected:
|
| 106 |
+
print("All files already processed. Nothing to do.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
else:
|
| 108 |
+
print(f"Need to process {total_expected - existing_files} remaining files")
|
| 109 |
+
|
| 110 |
+
aia_dataset = SDODataset_flaring(data=base_input_folder, wavelengths=wavelengths, resolution=512, allow_errors=True)
|
| 111 |
+
|
| 112 |
+
unprocessed_indices = [
|
| 113 |
+
i for i in range(len(aia_dataset))
|
| 114 |
+
if not os.path.exists(os.path.join(output_folder, aia_dataset.getId(i)) + '.npy')
|
| 115 |
+
]
|
| 116 |
+
print(f"Processing {len(unprocessed_indices)} unprocessed samples")
|
| 117 |
+
|
| 118 |
+
if unprocessed_indices:
|
| 119 |
+
with Pool(processes=os.cpu_count(), initializer=_init_worker, initargs=(aia_dataset, output_folder)) as pool:
|
| 120 |
+
list(tqdm(pool.imap(save_sample, unprocessed_indices), total=len(unprocessed_indices)))
|
| 121 |
+
print("AIA data processing completed.")
|
| 122 |
+
else:
|
| 123 |
+
print("All samples already processed. Nothing to do.")
|
data/pipeline_config.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PipelineConfig: loads and validates the data processing pipeline configuration.
|
| 3 |
+
|
| 4 |
+
Used by process_data_pipeline.py. Config is read from a YAML file and passed
|
| 5 |
+
to each sub-script as a JSON string via the PIPELINE_CONFIG environment variable.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import yaml
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
TEMPLATE_PATH = Path(__file__).parent / "pipeline_config.yaml"
|
| 15 |
+
|
| 16 |
+
# Paths that must exist before the pipeline runs
|
| 17 |
+
REQUIRED_INPUT_PATHS = [
|
| 18 |
+
("euv", "input_folder"),
|
| 19 |
+
("iti", "input_folder"),
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
# Paths that should be created before the pipeline runs
|
| 23 |
+
OUTPUT_PATHS = [
|
| 24 |
+
("euv", "bad_files_dir"),
|
| 25 |
+
("iti", "output_folder"),
|
| 26 |
+
("alignment", "output_sxr_dir"),
|
| 27 |
+
("alignment", "aia_missing_dir"),
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class PipelineConfig:
|
| 32 |
+
def __init__(self, config_path: str = None):
|
| 33 |
+
if config_path:
|
| 34 |
+
with open(config_path, "r") as f:
|
| 35 |
+
self.config = yaml.safe_load(f)
|
| 36 |
+
else:
|
| 37 |
+
self.config = self._defaults()
|
| 38 |
+
|
| 39 |
+
# ------------------------------------------------------------------
|
| 40 |
+
|
| 41 |
+
def get_path(self, section: str, key: str) -> str:
|
| 42 |
+
"""Return config[section][key], or config[section] if key == section."""
|
| 43 |
+
section_data = self.config.get(section, {})
|
| 44 |
+
if isinstance(section_data, dict):
|
| 45 |
+
return section_data.get(key, "")
|
| 46 |
+
return section_data # scalar value (e.g. base_data_dir)
|
| 47 |
+
|
| 48 |
+
def to_json(self) -> str:
|
| 49 |
+
"""Serialize config to JSON string for passing via environment variable."""
|
| 50 |
+
return json.dumps(self.config)
|
| 51 |
+
|
| 52 |
+
# ------------------------------------------------------------------
|
| 53 |
+
|
| 54 |
+
def validate_paths(self) -> tuple[bool, list[str]]:
|
| 55 |
+
"""Check that all required input paths exist. Returns (valid, missing)."""
|
| 56 |
+
missing = []
|
| 57 |
+
for section, key in REQUIRED_INPUT_PATHS:
|
| 58 |
+
p = self.get_path(section, key)
|
| 59 |
+
if p and not Path(p).exists():
|
| 60 |
+
missing.append(f"{section}.{key}: {p}")
|
| 61 |
+
return (len(missing) == 0, missing)
|
| 62 |
+
|
| 63 |
+
def create_directories(self):
|
| 64 |
+
"""Create all output directories."""
|
| 65 |
+
for section, key in OUTPUT_PATHS:
|
| 66 |
+
p = self.get_path(section, key)
|
| 67 |
+
if p:
|
| 68 |
+
Path(p).mkdir(parents=True, exist_ok=True)
|
| 69 |
+
|
| 70 |
+
def print_config(self):
|
| 71 |
+
print(yaml.dump(self.config, default_flow_style=False))
|
| 72 |
+
|
| 73 |
+
def save_config_template(self, path: str = None):
|
| 74 |
+
dest = Path(path) if path else TEMPLATE_PATH
|
| 75 |
+
with open(dest, "w") as f:
|
| 76 |
+
yaml.dump(self._defaults(), f, default_flow_style=False)
|
| 77 |
+
print(f"Template saved to {dest}")
|
| 78 |
+
|
| 79 |
+
# ------------------------------------------------------------------
|
| 80 |
+
|
| 81 |
+
@staticmethod
|
| 82 |
+
def _defaults() -> dict:
|
| 83 |
+
return {
|
| 84 |
+
"base_data_dir": "/Volumes/T9/Data_FOXES",
|
| 85 |
+
"euv": {
|
| 86 |
+
"input_folder": "/Volumes/T9/Data_FOXES/AIA_raw",
|
| 87 |
+
"bad_files_dir": "/Volumes/T9/Data_FOXES/AIA_bad",
|
| 88 |
+
"wavelengths": [94, 131, 171, 193, 211, 304, 335],
|
| 89 |
+
},
|
| 90 |
+
"iti": {
|
| 91 |
+
"input_folder": "/Volumes/T9/Data_FOXES/AIA_raw",
|
| 92 |
+
"output_folder": "/Volumes/T9/Data_FOXES/AIA_processed",
|
| 93 |
+
"wavelengths": [94, 131, 171, 193, 211, 304, 335],
|
| 94 |
+
},
|
| 95 |
+
"alignment": {
|
| 96 |
+
"goes_data_dir": "/Volumes/T9/Data_FOXES/SXR_raw/combined",
|
| 97 |
+
"aia_processed_dir": "/Volumes/T9/Data_FOXES/AIA_processed",
|
| 98 |
+
"output_sxr_dir": "/Volumes/T9/Data_FOXES/SXR_processed",
|
| 99 |
+
"aia_missing_dir": "/Volumes/T9/Data_FOXES/AIA_missing",
|
| 100 |
+
},
|
| 101 |
+
"processing": {
|
| 102 |
+
"max_processes": None,
|
| 103 |
+
"batch_size_multiplier": 4,
|
| 104 |
+
"min_batch_size": 1,
|
| 105 |
+
},
|
| 106 |
+
}
|
data/pipeline_config.yaml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Data Processing Pipeline Configuration
|
| 2 |
+
#
|
| 3 |
+
# Usage: python process_data_pipeline.py --config pipeline_config.yaml
|
| 4 |
+
#
|
| 5 |
+
# Directory flow:
|
| 6 |
+
# AIA_raw → euv_data_cleaning → bad files moved to AIA_bad
|
| 7 |
+
# AIA_raw → iti_data_processing → AIA_processed (512x512 .npy)
|
| 8 |
+
# AIA_processed ┐
|
| 9 |
+
# ├─ align_data ──────→ SXR_processed (xrsb_flux .npy per timestamp)
|
| 10 |
+
# SXR_raw/combined ┘ AIA_missing (AIA files with no SXR match)
|
| 11 |
+
|
| 12 |
+
base_data_dir: /Volumes/T9/Data_FOXES
|
| 13 |
+
|
| 14 |
+
euv:
|
| 15 |
+
input_folder: /Volumes/T9/Data_FOXES/AIA_raw
|
| 16 |
+
bad_files_dir: /Volumes/T9/Data_FOXES/AIA_bad
|
| 17 |
+
wavelengths:
|
| 18 |
+
- 94
|
| 19 |
+
- 131
|
| 20 |
+
- 171
|
| 21 |
+
- 193
|
| 22 |
+
- 211
|
| 23 |
+
- 304
|
| 24 |
+
- 335
|
| 25 |
+
|
| 26 |
+
iti:
|
| 27 |
+
input_folder: /Volumes/T9/Data_FOXES/AIA_raw # same as euv (bad files already moved out)
|
| 28 |
+
output_folder: /Volumes/T9/Data_FOXES/AIA_processed
|
| 29 |
+
wavelengths:
|
| 30 |
+
- 94
|
| 31 |
+
- 131
|
| 32 |
+
- 171
|
| 33 |
+
- 193
|
| 34 |
+
- 211
|
| 35 |
+
- 304
|
| 36 |
+
- 335
|
| 37 |
+
|
| 38 |
+
alignment:
|
| 39 |
+
goes_data_dir: /Volumes/T9/Data_FOXES/SXR_raw/combined # output of sxr_downloader concat
|
| 40 |
+
aia_processed_dir: /Volumes/T9/Data_FOXES/AIA_processed # must match iti.output_folder
|
| 41 |
+
output_sxr_dir: /Volumes/T9/Data_FOXES/SXR_processed
|
| 42 |
+
aia_missing_dir: /Volumes/T9/Data_FOXES/AIA_missing
|
| 43 |
+
|
| 44 |
+
processing:
|
| 45 |
+
max_processes: null # null = use all available cores
|
| 46 |
+
batch_size_multiplier: 4
|
| 47 |
+
min_batch_size: 1
|
data/process_data_pipeline.py
CHANGED
|
@@ -105,11 +105,8 @@ class DataProcessingPipeline:
|
|
| 105 |
"""
|
| 106 |
Check if data alignment is complete by looking for output directories.
|
| 107 |
"""
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
Path(self.config.get_path('alignment', 'output_sxr_b_dir'))
|
| 111 |
-
]
|
| 112 |
-
return all(d.exists() and any(d.iterdir()) for d in output_dirs)
|
| 113 |
|
| 114 |
def run_script(self, script_name, step_info):
|
| 115 |
"""
|
|
@@ -135,7 +132,7 @@ class DataProcessingPipeline:
|
|
| 135 |
# Create environment variables for configuration
|
| 136 |
env = os.environ.copy()
|
| 137 |
env.update({
|
| 138 |
-
'PIPELINE_CONFIG':
|
| 139 |
'BASE_DATA_DIR': self.config.get_path('base_data_dir', 'base_data_dir')
|
| 140 |
})
|
| 141 |
|
|
@@ -145,23 +142,18 @@ class DataProcessingPipeline:
|
|
| 145 |
# Run the script
|
| 146 |
result = subprocess.run(
|
| 147 |
[sys.executable, str(script_path)],
|
| 148 |
-
capture_output=True,
|
| 149 |
-
text=True,
|
| 150 |
cwd=self.base_dir,
|
| 151 |
env=env
|
| 152 |
)
|
| 153 |
-
|
| 154 |
end_time = time.time()
|
| 155 |
duration = end_time - start_time
|
| 156 |
-
|
| 157 |
if result.returncode == 0:
|
| 158 |
logger.info(f"✓ {step_info['name']} completed successfully in {duration:.2f} seconds")
|
| 159 |
-
if result.stdout:
|
| 160 |
-
logger.debug(f"Output: {result.stdout}")
|
| 161 |
return True
|
| 162 |
else:
|
| 163 |
logger.error(f"✗ {step_info['name']} failed with return code {result.returncode}")
|
| 164 |
-
logger.error(f"Error output: {result.stderr}")
|
| 165 |
return False
|
| 166 |
|
| 167 |
except Exception as e:
|
|
|
|
| 105 |
"""
|
| 106 |
Check if data alignment is complete by looking for output directories.
|
| 107 |
"""
|
| 108 |
+
output_dir = Path(self.config.get_path('alignment', 'output_sxr_dir'))
|
| 109 |
+
return output_dir.exists() and any(output_dir.iterdir())
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
def run_script(self, script_name, step_info):
|
| 112 |
"""
|
|
|
|
| 132 |
# Create environment variables for configuration
|
| 133 |
env = os.environ.copy()
|
| 134 |
env.update({
|
| 135 |
+
'PIPELINE_CONFIG': self.config.to_json(),
|
| 136 |
'BASE_DATA_DIR': self.config.get_path('base_data_dir', 'base_data_dir')
|
| 137 |
})
|
| 138 |
|
|
|
|
| 142 |
# Run the script
|
| 143 |
result = subprocess.run(
|
| 144 |
[sys.executable, str(script_path)],
|
|
|
|
|
|
|
| 145 |
cwd=self.base_dir,
|
| 146 |
env=env
|
| 147 |
)
|
| 148 |
+
|
| 149 |
end_time = time.time()
|
| 150 |
duration = end_time - start_time
|
| 151 |
+
|
| 152 |
if result.returncode == 0:
|
| 153 |
logger.info(f"✓ {step_info['name']} completed successfully in {duration:.2f} seconds")
|
|
|
|
|
|
|
| 154 |
return True
|
| 155 |
else:
|
| 156 |
logger.error(f"✗ {step_info['name']} failed with return code {result.returncode}")
|
|
|
|
| 157 |
return False
|
| 158 |
|
| 159 |
except Exception as e:
|
data/sxr_data_processing.py
CHANGED
|
@@ -50,10 +50,10 @@ class SXRDataProcessor:
|
|
| 50 |
total_files = len(g13_files) + len(g14_files) + len(g15_files) + len(g16_files) + len(g17_files) + len(g18_files)
|
| 51 |
logging.info(
|
| 52 |
f"Found {len(g13_files)} GOES-13 files, {len(g14_files)} GOES-14 files, {len(g15_files)} GOES-15 files, {len(g16_files)} GOES-16 files, {len(g17_files)} GOES-17 files, and {len(g18_files)} GOES-18 files.")
|
| 53 |
-
print(f"
|
| 54 |
|
| 55 |
if total_files == 0:
|
| 56 |
-
print("
|
| 57 |
return
|
| 58 |
|
| 59 |
def process_files(files, satellite_name, output_file, used_file_list):
|
|
@@ -84,37 +84,37 @@ class SXRDataProcessor:
|
|
| 84 |
ds.close()
|
| 85 |
|
| 86 |
if not datasets:
|
| 87 |
-
print(f"
|
| 88 |
logging.warning(f"No valid datasets for {satellite_name}")
|
| 89 |
return
|
| 90 |
|
| 91 |
-
print(f"
|
| 92 |
|
| 93 |
try:
|
| 94 |
-
print(f"
|
| 95 |
combined_ds = xr.concat(datasets, dim='time').sortby('time')
|
| 96 |
|
| 97 |
# Scaling factors for GOES-13, GOES-14, and GOES-15
|
| 98 |
if satellite_name in ['GOES-13', 'GOES-14', 'GOES-15']:
|
| 99 |
-
print(f"
|
| 100 |
combined_ds['xrsa_flux'] = combined_ds['xrsa_flux'] / .85
|
| 101 |
combined_ds['xrsb_flux'] = combined_ds['xrsb_flux'] / .7
|
| 102 |
|
| 103 |
-
print(f"
|
| 104 |
df = combined_ds.to_dataframe().reset_index()
|
| 105 |
|
| 106 |
if 'quad_diode' in df.columns:
|
| 107 |
-
print(f"
|
| 108 |
df = df[df['quad_diode'] == 0] # Filter out quad diode data
|
| 109 |
|
| 110 |
#Filter out data where xrsb_flux has a quality flag of >0
|
| 111 |
-
print(f"
|
| 112 |
-
df = df[df['
|
| 113 |
|
| 114 |
df['time'] = pd.to_datetime(df['time'])
|
| 115 |
df.set_index('time', inplace=True)
|
| 116 |
|
| 117 |
-
print(f"
|
| 118 |
df_log = np.log10(df[columns_to_interp].replace(0, np.nan))
|
| 119 |
|
| 120 |
# Step 3: Interpolate in log space
|
|
@@ -128,14 +128,14 @@ class SXRDataProcessor:
|
|
| 128 |
max_date = df.index.max().strftime('%Y%m%d')
|
| 129 |
filename = f"{str(output_file)}_{min_date}_{max_date}.csv"
|
| 130 |
|
| 131 |
-
print(f"
|
| 132 |
df.to_csv(filename, index=True)
|
| 133 |
|
| 134 |
-
print(f"
|
| 135 |
logging.info(f"Saved combined file: {output_file}")
|
| 136 |
|
| 137 |
except Exception as e:
|
| 138 |
-
print(f"
|
| 139 |
logging.error(f"Failed to write {output_file}: {e}")
|
| 140 |
finally:
|
| 141 |
for ds in datasets:
|
|
@@ -156,7 +156,7 @@ class SXRDataProcessor:
|
|
| 156 |
if len(g18_files) != 0:
|
| 157 |
satellites_to_process.append((g18_files, "GOES-18", self.output_dir / "combined_g18_avg1m", self.used_g18_files))
|
| 158 |
|
| 159 |
-
print(f"\
|
| 160 |
|
| 161 |
# Process each satellite with overall progress tracking
|
| 162 |
successful_satellites = 0
|
|
@@ -164,36 +164,36 @@ class SXRDataProcessor:
|
|
| 164 |
|
| 165 |
for i, (files, satellite_name, output_file, used_file_list) in enumerate(satellites_to_process, 1):
|
| 166 |
print(f"\n{'='*60}")
|
| 167 |
-
print(f"
|
| 168 |
print(f"{'='*60}")
|
| 169 |
|
| 170 |
try:
|
| 171 |
process_files(files, satellite_name, output_file, used_file_list)
|
| 172 |
successful_satellites += 1
|
| 173 |
except Exception as e:
|
| 174 |
-
print(f"
|
| 175 |
failed_satellites += 1
|
| 176 |
logging.error(f"Failed to process {satellite_name}: {e}")
|
| 177 |
|
| 178 |
# Print final summary
|
| 179 |
print(f"\n{'='*60}")
|
| 180 |
-
print(f"
|
| 181 |
print(f"{'='*60}")
|
| 182 |
-
print(f"
|
| 183 |
-
print(f"
|
| 184 |
-
print(f"
|
| 185 |
-
print(f"
|
| 186 |
|
| 187 |
# Print file usage statistics
|
| 188 |
total_used_files = (len(self.used_g13_files) + len(self.used_g14_files) +
|
| 189 |
len(self.used_g15_files) + len(self.used_g16_files) +
|
| 190 |
len(self.used_g17_files) + len(self.used_g18_files))
|
| 191 |
-
print(f"
|
| 192 |
|
| 193 |
if successful_satellites > 0:
|
| 194 |
-
print(f"\
|
| 195 |
else:
|
| 196 |
-
print(f"\n⚠
|
| 197 |
|
| 198 |
|
| 199 |
if __name__ == '__main__':
|
|
@@ -204,13 +204,13 @@ if __name__ == '__main__':
|
|
| 204 |
help='Directory where combined GOES data will be saved.')
|
| 205 |
args = parser.parse_args()
|
| 206 |
|
| 207 |
-
print("
|
| 208 |
print("=" * 50)
|
| 209 |
-
print(f"
|
| 210 |
-
print(f"
|
| 211 |
print("=" * 50)
|
| 212 |
|
| 213 |
processor = SXRDataProcessor(data_dir=args.data_dir, output_dir=args.output_dir)
|
| 214 |
processor.combine_goes_data()
|
| 215 |
|
| 216 |
-
print("\
|
|
|
|
| 50 |
total_files = len(g13_files) + len(g14_files) + len(g15_files) + len(g16_files) + len(g17_files) + len(g18_files)
|
| 51 |
logging.info(
|
| 52 |
f"Found {len(g13_files)} GOES-13 files, {len(g14_files)} GOES-14 files, {len(g15_files)} GOES-15 files, {len(g16_files)} GOES-16 files, {len(g17_files)} GOES-17 files, and {len(g18_files)} GOES-18 files.")
|
| 53 |
+
print(f"Total files found: {total_files}")
|
| 54 |
|
| 55 |
if total_files == 0:
|
| 56 |
+
print("No GOES data files found in the specified directory.")
|
| 57 |
return
|
| 58 |
|
| 59 |
def process_files(files, satellite_name, output_file, used_file_list):
|
|
|
|
| 84 |
ds.close()
|
| 85 |
|
| 86 |
if not datasets:
|
| 87 |
+
print(f"No valid datasets for {satellite_name}")
|
| 88 |
logging.warning(f"No valid datasets for {satellite_name}")
|
| 89 |
return
|
| 90 |
|
| 91 |
+
print(f"Processing {len(datasets)} datasets for {satellite_name}...")
|
| 92 |
|
| 93 |
try:
|
| 94 |
+
print(f"Concatenating datasets...")
|
| 95 |
combined_ds = xr.concat(datasets, dim='time').sortby('time')
|
| 96 |
|
| 97 |
# Scaling factors for GOES-13, GOES-14, and GOES-15
|
| 98 |
if satellite_name in ['GOES-13', 'GOES-14', 'GOES-15']:
|
| 99 |
+
print(f"Applying scaling factors for {satellite_name}...")
|
| 100 |
combined_ds['xrsa_flux'] = combined_ds['xrsa_flux'] / .85
|
| 101 |
combined_ds['xrsb_flux'] = combined_ds['xrsb_flux'] / .7
|
| 102 |
|
| 103 |
+
print(f"Converting to DataFrame...")
|
| 104 |
df = combined_ds.to_dataframe().reset_index()
|
| 105 |
|
| 106 |
if 'quad_diode' in df.columns:
|
| 107 |
+
print(f"Filtering quad diode data...")
|
| 108 |
df = df[df['quad_diode'] == 0] # Filter out quad diode data
|
| 109 |
|
| 110 |
#Filter out data where xrsb_flux has a quality flag of >0
|
| 111 |
+
print(f"Filtering out data where xrsb_flux has a quality flag of >0...")
|
| 112 |
+
df = df[df['xrsb_flag'] == 0]
|
| 113 |
|
| 114 |
df['time'] = pd.to_datetime(df['time'])
|
| 115 |
df.set_index('time', inplace=True)
|
| 116 |
|
| 117 |
+
print(f"Applying log interpolation...")
|
| 118 |
df_log = np.log10(df[columns_to_interp].replace(0, np.nan))
|
| 119 |
|
| 120 |
# Step 3: Interpolate in log space
|
|
|
|
| 128 |
max_date = df.index.max().strftime('%Y%m%d')
|
| 129 |
filename = f"{str(output_file)}_{min_date}_{max_date}.csv"
|
| 130 |
|
| 131 |
+
print(f"Saving to {filename}...")
|
| 132 |
df.to_csv(filename, index=True)
|
| 133 |
|
| 134 |
+
print(f"Successfully processed {satellite_name}: {successful_files} files loaded, {failed_files} failed")
|
| 135 |
logging.info(f"Saved combined file: {output_file}")
|
| 136 |
|
| 137 |
except Exception as e:
|
| 138 |
+
print(f"Failed to process {satellite_name}: {e}")
|
| 139 |
logging.error(f"Failed to write {output_file}: {e}")
|
| 140 |
finally:
|
| 141 |
for ds in datasets:
|
|
|
|
| 156 |
if len(g18_files) != 0:
|
| 157 |
satellites_to_process.append((g18_files, "GOES-18", self.output_dir / "combined_g18_avg1m", self.used_g18_files))
|
| 158 |
|
| 159 |
+
print(f"\nStarting processing of {len(satellites_to_process)} satellites...")
|
| 160 |
|
| 161 |
# Process each satellite with overall progress tracking
|
| 162 |
successful_satellites = 0
|
|
|
|
| 164 |
|
| 165 |
for i, (files, satellite_name, output_file, used_file_list) in enumerate(satellites_to_process, 1):
|
| 166 |
print(f"\n{'='*60}")
|
| 167 |
+
print(f"Processing satellite {i}/{len(satellites_to_process)}: {satellite_name}")
|
| 168 |
print(f"{'='*60}")
|
| 169 |
|
| 170 |
try:
|
| 171 |
process_files(files, satellite_name, output_file, used_file_list)
|
| 172 |
successful_satellites += 1
|
| 173 |
except Exception as e:
|
| 174 |
+
print(f"Failed to process {satellite_name}: {e}")
|
| 175 |
failed_satellites += 1
|
| 176 |
logging.error(f"Failed to process {satellite_name}: {e}")
|
| 177 |
|
| 178 |
# Print final summary
|
| 179 |
print(f"\n{'='*60}")
|
| 180 |
+
print(f"PROCESSING COMPLETE")
|
| 181 |
print(f"{'='*60}")
|
| 182 |
+
print(f"Successfully processed: {successful_satellites} satellites")
|
| 183 |
+
print(f"Failed: {failed_satellites} satellites")
|
| 184 |
+
print(f"Total files processed: {total_files}")
|
| 185 |
+
print(f"Output directory: {self.output_dir}")
|
| 186 |
|
| 187 |
# Print file usage statistics
|
| 188 |
total_used_files = (len(self.used_g13_files) + len(self.used_g14_files) +
|
| 189 |
len(self.used_g15_files) + len(self.used_g16_files) +
|
| 190 |
len(self.used_g17_files) + len(self.used_g18_files))
|
| 191 |
+
print(f"Files used in processing: {total_used_files}")
|
| 192 |
|
| 193 |
if successful_satellites > 0:
|
| 194 |
+
print(f"\nSXR data processing completed successfully!")
|
| 195 |
else:
|
| 196 |
+
print(f"\n⚠No satellites were processed successfully.")
|
| 197 |
|
| 198 |
|
| 199 |
if __name__ == '__main__':
|
|
|
|
| 204 |
help='Directory where combined GOES data will be saved.')
|
| 205 |
args = parser.parse_args()
|
| 206 |
|
| 207 |
+
print("GOES SXR Data Processing Tool")
|
| 208 |
print("=" * 50)
|
| 209 |
+
print(f"Data directory: {args.data_dir}")
|
| 210 |
+
print(f"Output directory: {args.output_dir}")
|
| 211 |
print("=" * 50)
|
| 212 |
|
| 213 |
processor = SXRDataProcessor(data_dir=args.data_dir, output_dir=args.output_dir)
|
| 214 |
processor.combine_goes_data()
|
| 215 |
|
| 216 |
+
print("\nAll processing tasks completed.")
|
download/download_sdo.py
CHANGED
|
@@ -27,7 +27,7 @@ class SDODownloader:
|
|
| 27 |
wavelengths (list): List of wavelengths to download.
|
| 28 |
n_workers (int): Number of worker threads for parallel download.
|
| 29 |
"""
|
| 30 |
-
def __init__(self, base_path='/mnt/data/PAPER/SDOData', email=None, wavelengths=['94', '131', '171', '193', '211', '304'], n_workers=4, cadence=60):
|
| 31 |
self.ds_path = base_path
|
| 32 |
self.wavelengths = [str(wl) for wl in wavelengths]
|
| 33 |
self.n_workers = n_workers
|
|
@@ -53,7 +53,11 @@ class SDODownloader:
|
|
| 53 |
if os.path.exists(map_path):
|
| 54 |
return map_path
|
| 55 |
# load map
|
|
|
|
|
|
|
|
|
|
| 56 |
url = 'http://jsoc.stanford.edu' + segment
|
|
|
|
| 57 |
|
| 58 |
# Retry download with exponential backoff
|
| 59 |
max_retries = 3
|
|
@@ -111,27 +115,20 @@ class SDODownloader:
|
|
| 111 |
id = date.isoformat()
|
| 112 |
|
| 113 |
logging.info('Start download: %s' % id)
|
| 114 |
-
# query Magnetogram
|
| 115 |
-
#time_param = '%sZ' % date.isoformat('_', timespec='seconds')
|
| 116 |
-
#ds_hmi = 'hmi.M_720s[%s]{magnetogram}' % time_param
|
| 117 |
-
#keys_hmi = self.drms_client.keys(ds_hmi)
|
| 118 |
-
#header_hmi, segment_hmi = self.drms_client.query(ds_hmi, key=','.join(keys_hmi), seg='magnetogram')
|
| 119 |
-
#if len(header_hmi) != 1 or np.any(header_hmi.QUALITY != 0):
|
| 120 |
-
# self.fetchDataFallback(date)
|
| 121 |
-
# return
|
| 122 |
|
| 123 |
# query EUV
|
| 124 |
time_param = '%sZ' % date.isoformat('_', timespec='seconds')
|
| 125 |
ds_euv = 'aia.lev1_euv_12s[%s][%s]{image}' % (time_param, ','.join(self.wavelengths))
|
| 126 |
keys_euv = self.drms_client.keys(ds_euv)
|
| 127 |
header_euv, segment_euv = self.drms_client.query(ds_euv, key=','.join(keys_euv), seg='image')
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
| 129 |
self.fetchDataFallback(date)
|
| 130 |
return
|
| 131 |
|
| 132 |
queue = []
|
| 133 |
-
#for (idx, h), s in zip(header_hmi.iterrows(), segment_hmi.magnetogram):
|
| 134 |
-
# queue += [(h.to_dict(), s, date)]
|
| 135 |
for (idx, h), s in zip(header_euv.iterrows(), segment_euv.image):
|
| 136 |
queue += [(h.to_dict(), s, date)]
|
| 137 |
|
|
@@ -155,28 +152,6 @@ class SDODownloader:
|
|
| 155 |
id = date.isoformat()
|
| 156 |
|
| 157 |
logging.info('Fallback download: %s' % id)
|
| 158 |
-
# query Magnetogram
|
| 159 |
-
t = date - timedelta(hours=24)
|
| 160 |
-
ds_hmi = 'hmi.M_720s[%sZ/12h@720s]{magnetogram}' % t.replace(tzinfo=None).isoformat('_', timespec='seconds')
|
| 161 |
-
keys_hmi = self.drms_client.keys(ds_hmi)
|
| 162 |
-
header_tmp, segment_tmp = self.drms_client.query(ds_hmi, key=','.join(keys_hmi), seg='magnetogram')
|
| 163 |
-
assert len(header_tmp) != 0, 'No data found!'
|
| 164 |
-
date_str = header_tmp['DATE__OBS'].replace('MISSING', '').str.replace('60', '59') # fix date format
|
| 165 |
-
date_diff = np.abs(pd.to_datetime(date_str).dt.tz_localize(None) - date)
|
| 166 |
-
# sort and filter
|
| 167 |
-
header_tmp['date_diff'] = date_diff
|
| 168 |
-
header_tmp.sort_values('date_diff')
|
| 169 |
-
segment_tmp['date_diff'] = date_diff
|
| 170 |
-
segment_tmp.sort_values('date_diff')
|
| 171 |
-
cond_tmp = header_tmp.QUALITY == 0
|
| 172 |
-
header_tmp = header_tmp[cond_tmp]
|
| 173 |
-
segment_tmp = segment_tmp[cond_tmp]
|
| 174 |
-
assert len(header_tmp) > 0, 'No valid quality flag found'
|
| 175 |
-
# replace invalid
|
| 176 |
-
header_hmi = header_tmp.iloc[0].drop('date_diff')
|
| 177 |
-
segment_hmi = segment_tmp.iloc[0].drop('date_diff')
|
| 178 |
-
############################################################
|
| 179 |
-
# query EUV
|
| 180 |
header_euv, segment_euv = [], []
|
| 181 |
t = date - timedelta(hours=6)
|
| 182 |
for wl in self.wavelengths:
|
|
@@ -184,21 +159,23 @@ class SDODownloader:
|
|
| 184 |
t.replace(tzinfo=None).isoformat('_', timespec='seconds'), wl)
|
| 185 |
keys_euv = self.drms_client.keys(euv_ds)
|
| 186 |
header_tmp, segment_tmp = self.drms_client.query(euv_ds, key=','.join(keys_euv), seg='image')
|
| 187 |
-
|
|
|
|
| 188 |
date_str = header_tmp['DATE__OBS'].replace('MISSING', '').str.replace('60', '59') # fix date format
|
| 189 |
date_diff = (pd.to_datetime(date_str).dt.tz_localize(None) - date).abs()
|
| 190 |
# sort and filter
|
| 191 |
header_tmp['date_diff'] = date_diff
|
| 192 |
-
header_tmp.sort_values('date_diff')
|
| 193 |
segment_tmp['date_diff'] = date_diff
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
|
|
|
|
|
|
| 202 |
|
| 203 |
queue = []
|
| 204 |
#queue += [(header_hmi.to_dict(), segment_hmi.magnetogram, date)]
|
|
@@ -218,8 +195,8 @@ if __name__ == '__main__':
|
|
| 218 |
parser.add_argument('--download_dir', type=str, help='path to the download directory.')
|
| 219 |
parser.add_argument('--email', type=str, help='registered email address for JSOC.')
|
| 220 |
parser.add_argument('--start_date', type=str, help='start date in format YYYY-MM-DD.')
|
| 221 |
-
parser.add_argument('--end_date', type=str, help='end date in format YYYY-MM-DD.', required=False,
|
| 222 |
-
default=
|
| 223 |
parser.add_argument('--cadence', type=int, help='cadence in minutes.', required=False, default=60)
|
| 224 |
|
| 225 |
args = parser.parse_args()
|
|
@@ -228,7 +205,7 @@ if __name__ == '__main__':
|
|
| 228 |
end_date = args.end_date
|
| 229 |
cadence = args.cadence
|
| 230 |
|
| 231 |
-
[os.makedirs(os.path.join(download_dir, str(c)), exist_ok=True) for c in [94, 131, 171, 193, 211, 304]]
|
| 232 |
downloader = SDODownloader(base_path=download_dir, email=args.email)
|
| 233 |
start_date_datetime = datetime.strptime(start_date, "%Y-%m-%d %H:%M:%S")
|
| 234 |
#end_date = datetime.now()
|
|
@@ -236,10 +213,10 @@ if __name__ == '__main__':
|
|
| 236 |
|
| 237 |
|
| 238 |
#Skip over dates that already exist in the download directory
|
| 239 |
-
for d in [start_date_datetime + i * timedelta(minutes=
|
| 240 |
-
range((end_date_datetime - start_date_datetime) // timedelta(minutes=
|
| 241 |
#make sure the file exists in all wavelengths directories
|
| 242 |
-
for wl in [94, 131, 171, 193, 211, 304]:
|
| 243 |
if not os.path.exists(os.path.join(
|
| 244 |
download_dir,
|
| 245 |
str(wl),
|
|
|
|
| 27 |
wavelengths (list): List of wavelengths to download.
|
| 28 |
n_workers (int): Number of worker threads for parallel download.
|
| 29 |
"""
|
| 30 |
+
def __init__(self, base_path='/mnt/data/PAPER/SDOData', email=None, wavelengths=['94', '131', '171', '193', '211', '304', '335'], n_workers=4, cadence=60):
|
| 31 |
self.ds_path = base_path
|
| 32 |
self.wavelengths = [str(wl) for wl in wavelengths]
|
| 33 |
self.n_workers = n_workers
|
|
|
|
| 53 |
if os.path.exists(map_path):
|
| 54 |
return map_path
|
| 55 |
# load map
|
| 56 |
+
if not segment or pd.isna(segment):
|
| 57 |
+
logging.error('Segment path is null for %s — data may not be in JSOC cache' % header.get('DATE__OBS'))
|
| 58 |
+
raise ValueError('Null segment path for %s' % header.get('DATE__OBS'))
|
| 59 |
url = 'http://jsoc.stanford.edu' + segment
|
| 60 |
+
logging.info('Downloading: %s' % url)
|
| 61 |
|
| 62 |
# Retry download with exponential backoff
|
| 63 |
max_retries = 3
|
|
|
|
| 115 |
id = date.isoformat()
|
| 116 |
|
| 117 |
logging.info('Start download: %s' % id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
# query EUV
|
| 120 |
time_param = '%sZ' % date.isoformat('_', timespec='seconds')
|
| 121 |
ds_euv = 'aia.lev1_euv_12s[%s][%s]{image}' % (time_param, ','.join(self.wavelengths))
|
| 122 |
keys_euv = self.drms_client.keys(ds_euv)
|
| 123 |
header_euv, segment_euv = self.drms_client.query(ds_euv, key=','.join(keys_euv), seg='image')
|
| 124 |
+
logging.info('Fast-path query returned %d rows (need %d), qualities: %s' % (
|
| 125 |
+
len(header_euv), len(self.wavelengths),
|
| 126 |
+
list(header_euv.QUALITY) if len(header_euv) > 0 else []))
|
| 127 |
+
if len(header_euv) != len(self.wavelengths) or np.any(header_euv.QUALITY.fillna(0) != 0):
|
| 128 |
self.fetchDataFallback(date)
|
| 129 |
return
|
| 130 |
|
| 131 |
queue = []
|
|
|
|
|
|
|
| 132 |
for (idx, h), s in zip(header_euv.iterrows(), segment_euv.image):
|
| 133 |
queue += [(h.to_dict(), s, date)]
|
| 134 |
|
|
|
|
| 152 |
id = date.isoformat()
|
| 153 |
|
| 154 |
logging.info('Fallback download: %s' % id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
header_euv, segment_euv = [], []
|
| 156 |
t = date - timedelta(hours=6)
|
| 157 |
for wl in self.wavelengths:
|
|
|
|
| 159 |
t.replace(tzinfo=None).isoformat('_', timespec='seconds'), wl)
|
| 160 |
keys_euv = self.drms_client.keys(euv_ds)
|
| 161 |
header_tmp, segment_tmp = self.drms_client.query(euv_ds, key=','.join(keys_euv), seg='image')
|
| 162 |
+
logging.info('Fallback query wl=%s returned %d rows' % (wl, len(header_tmp)))
|
| 163 |
+
assert len(header_tmp) != 0, 'No data found for wl=%s at %s' % (wl, id)
|
| 164 |
date_str = header_tmp['DATE__OBS'].replace('MISSING', '').str.replace('60', '59') # fix date format
|
| 165 |
date_diff = (pd.to_datetime(date_str).dt.tz_localize(None) - date).abs()
|
| 166 |
# sort and filter
|
| 167 |
header_tmp['date_diff'] = date_diff
|
|
|
|
| 168 |
segment_tmp['date_diff'] = date_diff
|
| 169 |
+
cond_tmp = (header_tmp.QUALITY == 0) | header_tmp.QUALITY.isna()
|
| 170 |
+
header_filtered = header_tmp[cond_tmp]
|
| 171 |
+
segment_filtered = segment_tmp[cond_tmp]
|
| 172 |
+
if len(header_filtered) > 0:
|
| 173 |
+
header_tmp = header_filtered
|
| 174 |
+
segment_tmp = segment_filtered
|
| 175 |
+
else:
|
| 176 |
+
logging.warning('No quality-0 EUV frames for wl=%s at %s — using closest available' % (wl, id))
|
| 177 |
+
header_euv.append(header_tmp.sort_values('date_diff').iloc[0].drop('date_diff'))
|
| 178 |
+
segment_euv.append(segment_tmp.sort_values('date_diff').iloc[0].drop('date_diff'))
|
| 179 |
|
| 180 |
queue = []
|
| 181 |
#queue += [(header_hmi.to_dict(), segment_hmi.magnetogram, date)]
|
|
|
|
| 195 |
parser.add_argument('--download_dir', type=str, help='path to the download directory.')
|
| 196 |
parser.add_argument('--email', type=str, help='registered email address for JSOC.')
|
| 197 |
parser.add_argument('--start_date', type=str, help='start date in format YYYY-MM-DD.')
|
| 198 |
+
parser.add_argument('--end_date', type=str, help='end date in format YYYY-MM-DD HH:MM:SS.', required=False,
|
| 199 |
+
default=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
| 200 |
parser.add_argument('--cadence', type=int, help='cadence in minutes.', required=False, default=60)
|
| 201 |
|
| 202 |
args = parser.parse_args()
|
|
|
|
| 205 |
end_date = args.end_date
|
| 206 |
cadence = args.cadence
|
| 207 |
|
| 208 |
+
[os.makedirs(os.path.join(download_dir, str(c)), exist_ok=True) for c in [94, 131, 171, 193, 211, 304, 335]]
|
| 209 |
downloader = SDODownloader(base_path=download_dir, email=args.email)
|
| 210 |
start_date_datetime = datetime.strptime(start_date, "%Y-%m-%d %H:%M:%S")
|
| 211 |
#end_date = datetime.now()
|
|
|
|
| 213 |
|
| 214 |
|
| 215 |
#Skip over dates that already exist in the download directory
|
| 216 |
+
for d in [start_date_datetime + i * timedelta(minutes=cadence) for i in
|
| 217 |
+
range((end_date_datetime - start_date_datetime) // timedelta(minutes=cadence))]:
|
| 218 |
#make sure the file exists in all wavelengths directories
|
| 219 |
+
for wl in [94, 131, 171, 193, 211, 304, 335]:
|
| 220 |
if not os.path.exists(os.path.join(
|
| 221 |
download_dir,
|
| 222 |
str(wl),
|
download/sxr_downloader.py
CHANGED
|
@@ -10,11 +10,9 @@ import pandas as pd
|
|
| 10 |
class SXRDownloader:
|
| 11 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 12 |
|
| 13 |
-
def __init__(self, save_dir: str = '/mnt/data/PAPER/GOES-timespan'
|
| 14 |
self.save_dir = Path(save_dir)
|
| 15 |
self.save_dir.mkdir(exist_ok=True)
|
| 16 |
-
self.concat_dir = Path(concat_dir)
|
| 17 |
-
self.concat_dir.mkdir(exist_ok=True)
|
| 18 |
self.used_g13_files = []
|
| 19 |
self.used_g14_files = []
|
| 20 |
self.used_g15_files = []
|
|
|
|
| 10 |
class SXRDownloader:
|
| 11 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 12 |
|
| 13 |
+
def __init__(self, save_dir: str = '/mnt/data/PAPER/GOES-timespan'):
|
| 14 |
self.save_dir = Path(save_dir)
|
| 15 |
self.save_dir.mkdir(exist_ok=True)
|
|
|
|
|
|
|
| 16 |
self.used_g13_files = []
|
| 17 |
self.used_g14_files = []
|
| 18 |
self.used_g15_files = []
|
forecasting/data_loaders/SDOAIA_dataloader.py
CHANGED
|
@@ -12,6 +12,16 @@ import glob
|
|
| 12 |
import os
|
| 13 |
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
class AIA_GOESDataset(torch.utils.data.Dataset):
|
| 16 |
"""
|
| 17 |
PyTorch Dataset for loading paired AIA (EUV images) and GOES (SXR flux) data.
|
|
@@ -354,7 +364,7 @@ class AIA_GOESDataModule(LightningDataModule):
|
|
| 354 |
self.train_ds = AIA_GOESDataset(
|
| 355 |
aia_dir=self.aia_train_dir,
|
| 356 |
sxr_dir=self.sxr_train_dir,
|
| 357 |
-
sxr_transform=
|
| 358 |
target_size=(512, 512),
|
| 359 |
wavelengths=self.wavelengths,
|
| 360 |
cadence=1,
|
|
@@ -368,7 +378,7 @@ class AIA_GOESDataModule(LightningDataModule):
|
|
| 368 |
self.val_ds = AIA_GOESDataset(
|
| 369 |
aia_dir=self.aia_val_dir,
|
| 370 |
sxr_dir=self.sxr_val_dir,
|
| 371 |
-
sxr_transform=
|
| 372 |
target_size=(512, 512),
|
| 373 |
wavelengths=self.wavelengths,
|
| 374 |
cadence=1,
|
|
@@ -382,7 +392,7 @@ class AIA_GOESDataModule(LightningDataModule):
|
|
| 382 |
self.test_ds = AIA_GOESDataset(
|
| 383 |
aia_dir=self.aia_test_dir,
|
| 384 |
sxr_dir=self.sxr_test_dir,
|
| 385 |
-
sxr_transform=
|
| 386 |
target_size=(512, 512),
|
| 387 |
wavelengths=self.wavelengths,
|
| 388 |
cadence=1,
|
|
|
|
| 12 |
import os
|
| 13 |
|
| 14 |
|
| 15 |
+
class SXRLogNormTransform:
|
| 16 |
+
"""Picklable SXR log-normalization transform (replaces T.Lambda for spawn compatibility)."""
|
| 17 |
+
def __init__(self, mean: float, std: float):
|
| 18 |
+
self.mean = mean
|
| 19 |
+
self.std = std
|
| 20 |
+
|
| 21 |
+
def __call__(self, x: float) -> float:
|
| 22 |
+
return (np.log10(x + 1e-8) - self.mean) / self.std
|
| 23 |
+
|
| 24 |
+
|
| 25 |
class AIA_GOESDataset(torch.utils.data.Dataset):
|
| 26 |
"""
|
| 27 |
PyTorch Dataset for loading paired AIA (EUV images) and GOES (SXR flux) data.
|
|
|
|
| 364 |
self.train_ds = AIA_GOESDataset(
|
| 365 |
aia_dir=self.aia_train_dir,
|
| 366 |
sxr_dir=self.sxr_train_dir,
|
| 367 |
+
sxr_transform=SXRLogNormTransform(self.sxr_norm[0], self.sxr_norm[1]),
|
| 368 |
target_size=(512, 512),
|
| 369 |
wavelengths=self.wavelengths,
|
| 370 |
cadence=1,
|
|
|
|
| 378 |
self.val_ds = AIA_GOESDataset(
|
| 379 |
aia_dir=self.aia_val_dir,
|
| 380 |
sxr_dir=self.sxr_val_dir,
|
| 381 |
+
sxr_transform=SXRLogNormTransform(self.sxr_norm[0], self.sxr_norm[1]),
|
| 382 |
target_size=(512, 512),
|
| 383 |
wavelengths=self.wavelengths,
|
| 384 |
cadence=1,
|
|
|
|
| 392 |
self.test_ds = AIA_GOESDataset(
|
| 393 |
aia_dir=self.aia_test_dir,
|
| 394 |
sxr_dir=self.sxr_test_dir,
|
| 395 |
+
sxr_transform=SXRLogNormTransform(self.sxr_norm[0], self.sxr_norm[1]),
|
| 396 |
target_size=(512, 512),
|
| 397 |
wavelengths=self.wavelengths,
|
| 398 |
cadence=1,
|
forecasting/training/train.py
CHANGED
|
@@ -84,69 +84,6 @@ def resolve_config_variables(config_dict):
|
|
| 84 |
return recursive_substitute(config_dict, variables)
|
| 85 |
|
| 86 |
|
| 87 |
-
# Parser
|
| 88 |
-
parser = argparse.ArgumentParser()
|
| 89 |
-
parser.add_argument('-config', type=str, default='config.yaml', required=True, help='Path to config YAML.')
|
| 90 |
-
args = parser.parse_args()
|
| 91 |
-
|
| 92 |
-
# Load config with variable substitution
|
| 93 |
-
with open(args.config, 'r') as stream:
|
| 94 |
-
config_data = yaml.load(stream, Loader=yaml.SafeLoader)
|
| 95 |
-
|
| 96 |
-
# Resolve variables like ${base_data_dir}
|
| 97 |
-
config_data = resolve_config_variables(config_data)
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
# Debug: Print resolved paths
|
| 101 |
-
print("Resolved paths:")
|
| 102 |
-
print(f"AIA dir: {config_data['data']['aia_dir']}")
|
| 103 |
-
print(f"SXR dir: {config_data['data']['sxr_dir']}")
|
| 104 |
-
print(f"Checkpoints dir: {config_data['data']['checkpoints_dir']}")
|
| 105 |
-
|
| 106 |
-
sxr_norm = np.load(config_data['data']['sxr_norm_path'])
|
| 107 |
-
training_wavelengths = config_data['wavelengths']
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
# DataModule
|
| 111 |
-
data_loader = AIA_GOESDataModule(
|
| 112 |
-
aia_train_dir= config_data['data']['aia_dir']+"/train",
|
| 113 |
-
aia_val_dir=config_data['data']['aia_dir']+"/val",
|
| 114 |
-
aia_test_dir=config_data['data']['aia_dir']+"/test",
|
| 115 |
-
sxr_train_dir=config_data['data']['sxr_dir']+"/train",
|
| 116 |
-
sxr_val_dir=config_data['data']['sxr_dir']+"/val",
|
| 117 |
-
sxr_test_dir=config_data['data']['sxr_dir']+"/test",
|
| 118 |
-
batch_size=config_data['batch_size'],
|
| 119 |
-
num_workers=min(8, os.cpu_count()), # Limit workers to prevent shm issues
|
| 120 |
-
sxr_norm=sxr_norm,
|
| 121 |
-
wavelengths=training_wavelengths,
|
| 122 |
-
oversample=config_data['oversample'],
|
| 123 |
-
balance_strategy=config_data['balance_strategy'],
|
| 124 |
-
)
|
| 125 |
-
data_loader.setup()
|
| 126 |
-
# Logger
|
| 127 |
-
#wb_name = f"{instrument}_{n}" if len(combined_parameters) > 1 else "aia_sxr_model"
|
| 128 |
-
wandb_logger = WandbLogger(
|
| 129 |
-
entity=config_data['wandb']['entity'],
|
| 130 |
-
project=config_data['wandb']['project'],
|
| 131 |
-
job_type=config_data['wandb']['job_type'],
|
| 132 |
-
tags=config_data['wandb']['tags'],
|
| 133 |
-
name=config_data['wandb']['run_name'],
|
| 134 |
-
notes=config_data['wandb']['notes'],
|
| 135 |
-
config=config_data
|
| 136 |
-
)
|
| 137 |
-
|
| 138 |
-
# Logging callback
|
| 139 |
-
total_n_valid = len(data_loader.val_ds)
|
| 140 |
-
plot_data = [data_loader.val_ds[i] for i in range(0, total_n_valid, max(1, total_n_valid // 4))]
|
| 141 |
-
plot_samples = plot_data # Keep as list of ((aia, sxr), target)
|
| 142 |
-
#sxr_callback = SXRPredictionLogger(plot_samples)
|
| 143 |
-
|
| 144 |
-
sxr_plot_callback = ImagePredictionLogger_SXR(plot_samples, sxr_norm)
|
| 145 |
-
# Attention map callback - get patch size from config
|
| 146 |
-
patch_size = config_data.get('vit_architecture', {}).get('patch_size', 16)
|
| 147 |
-
attention = AttentionMapCallback(patch_size=patch_size, use_local_attention=True)
|
| 148 |
-
|
| 149 |
-
|
| 150 |
class PTHCheckpointCallback(Callback):
|
| 151 |
"""
|
| 152 |
Custom PyTorch Lightning callback to save model checkpoints in `.pth` format.
|
|
@@ -209,65 +146,6 @@ class PTHCheckpointCallback(Callback):
|
|
| 209 |
|
| 210 |
|
| 211 |
|
| 212 |
-
# Checkpoint callback
|
| 213 |
-
checkpoint_callback = ModelCheckpoint(
|
| 214 |
-
dirpath=config_data['data']['checkpoints_dir'],
|
| 215 |
-
monitor='val_total_loss',
|
| 216 |
-
mode='min',
|
| 217 |
-
save_top_k=10,
|
| 218 |
-
filename=f"{config_data['wandb']['run_name']}-{{epoch:02d}}-{{val_total_loss:.4f}}"
|
| 219 |
-
)
|
| 220 |
-
|
| 221 |
-
pth_callback = PTHCheckpointCallback(
|
| 222 |
-
dirpath=config_data['data']['checkpoints_dir'],
|
| 223 |
-
monitor='val_total_loss',
|
| 224 |
-
mode='min',
|
| 225 |
-
save_top_k=1,
|
| 226 |
-
filename_prefix=config_data['wandb']['run_name']
|
| 227 |
-
)
|
| 228 |
-
|
| 229 |
-
def process_batch(batch_data, sxr_norm, c_threshold, m_threshold, x_threshold):
|
| 230 |
-
"""
|
| 231 |
-
Process a batch of SXR data to count flare occurrences in different intensity classes.
|
| 232 |
-
|
| 233 |
-
Parameters
|
| 234 |
-
----------
|
| 235 |
-
batch_data : tuple
|
| 236 |
-
Tuple containing (batch, batch_idx).
|
| 237 |
-
sxr_norm : np.ndarray
|
| 238 |
-
Normalization parameters for SXR values.
|
| 239 |
-
c_threshold, m_threshold, x_threshold : float
|
| 240 |
-
Thresholds defining flare intensity categories.
|
| 241 |
-
|
| 242 |
-
Returns
|
| 243 |
-
-------
|
| 244 |
-
dict
|
| 245 |
-
Dictionary containing counts for quiet, C, M, and X class flares.
|
| 246 |
-
"""
|
| 247 |
-
from forecasting.models.vit_patch_model import unnormalize_sxr
|
| 248 |
-
|
| 249 |
-
batch, batch_idx = batch_data
|
| 250 |
-
_, sxr = batch
|
| 251 |
-
|
| 252 |
-
# Unnormalize the SXR values
|
| 253 |
-
sxr_un = unnormalize_sxr(sxr, sxr_norm)
|
| 254 |
-
sxr_un_flat = sxr_un.view(-1).cpu().numpy()
|
| 255 |
-
|
| 256 |
-
total = len(sxr_un_flat)
|
| 257 |
-
quiet_count = ((sxr_un_flat < c_threshold)).sum()
|
| 258 |
-
c_count = ((sxr_un_flat >= c_threshold) & (sxr_un_flat < m_threshold)).sum()
|
| 259 |
-
m_count = ((sxr_un_flat >= m_threshold) & (sxr_un_flat < x_threshold)).sum()
|
| 260 |
-
x_count = ((sxr_un_flat >= x_threshold)).sum()
|
| 261 |
-
|
| 262 |
-
return {
|
| 263 |
-
'total': total,
|
| 264 |
-
'quiet_count': quiet_count,
|
| 265 |
-
'c_count': c_count,
|
| 266 |
-
'm_count': m_count,
|
| 267 |
-
'x_count': x_count,
|
| 268 |
-
'batch_idx': batch_idx
|
| 269 |
-
}
|
| 270 |
-
|
| 271 |
def get_base_weights(data_loader, sxr_norm):
|
| 272 |
"""
|
| 273 |
Compute inverse-frequency weights for flare classes based on training data.
|
|
@@ -353,94 +231,126 @@ def get_base_weights(data_loader, sxr_norm):
|
|
| 353 |
|
| 354 |
|
| 355 |
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
""
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
else:
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
}
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
|
|
|
|
|
| 84 |
return recursive_substitute(config_dict, variables)
|
| 85 |
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
class PTHCheckpointCallback(Callback):
|
| 88 |
"""
|
| 89 |
Custom PyTorch Lightning callback to save model checkpoints in `.pth` format.
|
|
|
|
| 146 |
|
| 147 |
|
| 148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
def get_base_weights(data_loader, sxr_norm):
|
| 150 |
"""
|
| 151 |
Compute inverse-frequency weights for flare classes based on training data.
|
|
|
|
| 231 |
|
| 232 |
|
| 233 |
|
| 234 |
+
if __name__ == '__main__':
|
| 235 |
+
# Parser
|
| 236 |
+
parser = argparse.ArgumentParser()
|
| 237 |
+
parser.add_argument('-config', type=str, default='config.yaml', required=True, help='Path to config YAML.')
|
| 238 |
+
args = parser.parse_args()
|
| 239 |
+
|
| 240 |
+
# Load config with variable substitution
|
| 241 |
+
with open(args.config, 'r') as stream:
|
| 242 |
+
config_data = yaml.load(stream, Loader=yaml.SafeLoader)
|
| 243 |
+
config_data = resolve_config_variables(config_data)
|
| 244 |
+
|
| 245 |
+
print("Resolved paths:")
|
| 246 |
+
print(f"AIA dir: {config_data['data']['aia_dir']}")
|
| 247 |
+
print(f"SXR dir: {config_data['data']['sxr_dir']}")
|
| 248 |
+
print(f"Checkpoints dir: {config_data['data']['checkpoints_dir']}")
|
| 249 |
+
|
| 250 |
+
sxr_norm = np.load(config_data['data']['sxr_norm_path'])
|
| 251 |
+
training_wavelengths = config_data['wavelengths']
|
| 252 |
+
|
| 253 |
+
# DataModule
|
| 254 |
+
data_loader = AIA_GOESDataModule(
|
| 255 |
+
aia_train_dir=config_data['data']['aia_dir'] + "/train",
|
| 256 |
+
aia_val_dir=config_data['data']['aia_dir'] + "/val",
|
| 257 |
+
aia_test_dir=config_data['data']['aia_dir'] + "/test",
|
| 258 |
+
sxr_train_dir=config_data['data']['sxr_dir'] + "/train",
|
| 259 |
+
sxr_val_dir=config_data['data']['sxr_dir'] + "/val",
|
| 260 |
+
sxr_test_dir=config_data['data']['sxr_dir'] + "/test",
|
| 261 |
+
batch_size=config_data['batch_size'],
|
| 262 |
+
num_workers=min(8, os.cpu_count()),
|
| 263 |
+
sxr_norm=sxr_norm,
|
| 264 |
+
wavelengths=training_wavelengths,
|
| 265 |
+
oversample=config_data['oversample'],
|
| 266 |
+
balance_strategy=config_data['balance_strategy'],
|
| 267 |
+
)
|
| 268 |
+
data_loader.setup()
|
| 269 |
+
|
| 270 |
+
# Logger
|
| 271 |
+
wandb_logger = WandbLogger(
|
| 272 |
+
entity=config_data['wandb']['entity'],
|
| 273 |
+
project=config_data['wandb']['project'],
|
| 274 |
+
job_type=config_data['wandb']['job_type'],
|
| 275 |
+
tags=config_data['wandb']['tags'],
|
| 276 |
+
name=config_data['wandb']['run_name'],
|
| 277 |
+
notes=config_data['wandb']['notes'],
|
| 278 |
+
config=config_data
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
# Callbacks
|
| 282 |
+
total_n_valid = len(data_loader.val_ds)
|
| 283 |
+
plot_samples = [data_loader.val_ds[i] for i in range(0, total_n_valid, max(1, total_n_valid // 4))]
|
| 284 |
+
sxr_plot_callback = ImagePredictionLogger_SXR(plot_samples, sxr_norm)
|
| 285 |
+
patch_size = config_data.get('vit_architecture', {}).get('patch_size', 16)
|
| 286 |
+
attention = AttentionMapCallback(patch_size=patch_size, use_local_attention=True)
|
| 287 |
+
|
| 288 |
+
base_weights = get_base_weights(data_loader, sxr_norm) if config_data.get('calculate_base_weights', True) else None
|
| 289 |
+
model = ViTLocal(model_kwargs=config_data['vit_architecture'], sxr_norm=sxr_norm, base_weights=base_weights)
|
| 290 |
+
|
| 291 |
+
# Checkpoint callbacks
|
| 292 |
+
checkpoint_callback = ModelCheckpoint(
|
| 293 |
+
dirpath=config_data['data']['checkpoints_dir'],
|
| 294 |
+
monitor='val_total_loss',
|
| 295 |
+
mode='min',
|
| 296 |
+
save_top_k=10,
|
| 297 |
+
filename=f"{config_data['wandb']['run_name']}-{{epoch:02d}}-{{val_total_loss:.4f}}"
|
| 298 |
+
)
|
| 299 |
+
pth_callback = PTHCheckpointCallback(
|
| 300 |
+
dirpath=config_data['data']['checkpoints_dir'],
|
| 301 |
+
monitor='val_total_loss',
|
| 302 |
+
mode='min',
|
| 303 |
+
save_top_k=1,
|
| 304 |
+
filename_prefix=config_data['wandb']['run_name']
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
# Set device based on config
|
| 308 |
+
gpu_config = config_data.get('gpu_ids', config_data.get('gpu_id', 0))
|
| 309 |
+
if gpu_config == -1:
|
| 310 |
+
accelerator, devices, strategy = "cpu", 1, "auto"
|
| 311 |
+
print("Using CPU for training")
|
| 312 |
+
elif gpu_config == "all":
|
| 313 |
+
if torch.cuda.is_available():
|
| 314 |
+
accelerator, devices, strategy = "gpu", -1, "auto"
|
| 315 |
+
num_gpus = torch.cuda.device_count()
|
| 316 |
+
print(f"Using all available GPUs ({num_gpus} GPUs)")
|
| 317 |
+
else:
|
| 318 |
+
accelerator, devices, strategy = "cpu", 1, "auto"
|
| 319 |
+
print("No GPUs available, falling back to CPU")
|
| 320 |
+
elif isinstance(gpu_config, list):
|
| 321 |
+
if torch.cuda.is_available():
|
| 322 |
+
accelerator, devices, strategy = "gpu", gpu_config, "auto"
|
| 323 |
+
print(f"Using GPUs: {gpu_config}")
|
| 324 |
+
else:
|
| 325 |
+
accelerator, devices, strategy = "cpu", 1, "auto"
|
| 326 |
+
print("No GPUs available, falling back to CPU")
|
| 327 |
else:
|
| 328 |
+
if torch.cuda.is_available():
|
| 329 |
+
accelerator, devices, strategy = "gpu", [gpu_config], "auto"
|
| 330 |
+
print(f"Using GPU {gpu_config}")
|
| 331 |
+
else:
|
| 332 |
+
accelerator, devices, strategy = "cpu", 1, "auto"
|
| 333 |
+
print(f"GPU {gpu_config} not available, falling back to CPU")
|
| 334 |
+
|
| 335 |
+
# Trainer
|
| 336 |
+
trainer = Trainer(
|
| 337 |
+
default_root_dir=config_data['data']['checkpoints_dir'],
|
| 338 |
+
accelerator=accelerator,
|
| 339 |
+
devices=devices,
|
| 340 |
+
strategy=strategy,
|
| 341 |
+
max_epochs=config_data['epochs'],
|
| 342 |
+
callbacks=[attention, checkpoint_callback],
|
| 343 |
+
logger=wandb_logger,
|
| 344 |
+
log_every_n_steps=10,
|
| 345 |
+
)
|
| 346 |
+
trainer.fit(model, data_loader)
|
| 347 |
+
|
| 348 |
+
# Save final checkpoint
|
| 349 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 350 |
+
final_checkpoint_path = os.path.join(
|
| 351 |
+
config_data['data']['checkpoints_dir'],
|
| 352 |
+
f"{config_data['wandb']['run_name']}-final-{timestamp}.pth"
|
| 353 |
+
)
|
| 354 |
+
torch.save({'model': model, 'state_dict': model.state_dict()}, final_checkpoint_path)
|
| 355 |
+
print(f"Saved final PyTorch checkpoint: {final_checkpoint_path}")
|
| 356 |
+
wandb.finish()
|
forecasting/training/train_config.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
|
| 2 |
#Base directories - change these to switch datasets
|
| 3 |
-
base_data_dir: "/Volumes/T9/
|
| 4 |
-
base_checkpoint_dir: "/Volumes/T9/
|
| 5 |
wavelengths: [94, 131, 171, 193, 211, 304, 335] # AIA wavelengths in Angstroms
|
| 6 |
|
| 7 |
# GPU configuration
|
|
@@ -35,11 +35,11 @@ vit_architecture:
|
|
| 35 |
# Data paths (automatically constructed from base directories)
|
| 36 |
data:
|
| 37 |
aia_dir:
|
| 38 |
-
"${base_data_dir}/
|
| 39 |
sxr_dir:
|
| 40 |
-
"${base_data_dir}/
|
| 41 |
sxr_norm_path:
|
| 42 |
-
"${base_data_dir}/
|
| 43 |
checkpoints_dir:
|
| 44 |
"${base_checkpoint_dir}/new-checkpoint/"
|
| 45 |
|
|
|
|
| 1 |
|
| 2 |
#Base directories - change these to switch datasets
|
| 3 |
+
base_data_dir: "/Volumes/T9/Data_FOXES" # Change this line for different datasets
|
| 4 |
+
base_checkpoint_dir: "/Volumes/T9/Data_FOXES" # Change this line for different datasets
|
| 5 |
wavelengths: [94, 131, 171, 193, 211, 304, 335] # AIA wavelengths in Angstroms
|
| 6 |
|
| 7 |
# GPU configuration
|
|
|
|
| 35 |
# Data paths (automatically constructed from base directories)
|
| 36 |
data:
|
| 37 |
aia_dir:
|
| 38 |
+
"${base_data_dir}/AIA_processed"
|
| 39 |
sxr_dir:
|
| 40 |
+
"${base_data_dir}/SXR_processed"
|
| 41 |
sxr_norm_path:
|
| 42 |
+
"${base_data_dir}/SXR_processed/normalized_sxr.npy"
|
| 43 |
checkpoints_dir:
|
| 44 |
"${base_checkpoint_dir}/new-checkpoint/"
|
| 45 |
|
pipeline_config.yaml
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# FOXES Pipeline Configuration
|
| 3 |
+
# =============================================================================
|
| 4 |
+
# Used by run_pipeline.py to run any combination of pipeline steps.
|
| 5 |
+
#
|
| 6 |
+
# Usage:
|
| 7 |
+
# python run_pipeline.py --config pipeline_config.yaml --steps all
|
| 8 |
+
# python run_pipeline.py --config pipeline_config.yaml --steps train,inference,flare_analysis
|
| 9 |
+
# python run_pipeline.py --list
|
| 10 |
+
|
| 11 |
+
# -----------------------------------------------------------------------------
|
| 12 |
+
# Shared date range (used by download_aia and download_sxr)
|
| 13 |
+
# -----------------------------------------------------------------------------
|
| 14 |
+
start_date: "2014-07-01 00:00:00"
|
| 15 |
+
end_date: "2014-07-08 00:00:00"
|
| 16 |
+
|
| 17 |
+
# -----------------------------------------------------------------------------
|
| 18 |
+
# AIA download (step: download_aia)
|
| 19 |
+
# -----------------------------------------------------------------------------
|
| 20 |
+
aia:
|
| 21 |
+
download_dir: "/Volumes/T9/Data_FOXES/AIA_raw"
|
| 22 |
+
email: "" # Must be registered at http://jsoc.stanford.edu
|
| 23 |
+
cadence: 1 # Minutes between frames
|
| 24 |
+
|
| 25 |
+
# -----------------------------------------------------------------------------
|
| 26 |
+
# SXR download (step: download_sxr)
|
| 27 |
+
# -----------------------------------------------------------------------------
|
| 28 |
+
sxr:
|
| 29 |
+
save_dir: "/Volumes/T9/Data_FOXES/SXR_raw"
|
| 30 |
+
|
| 31 |
+
# -----------------------------------------------------------------------------
|
| 32 |
+
# Preprocessing (step: preprocess)
|
| 33 |
+
# -----------------------------------------------------------------------------
|
| 34 |
+
preprocess:
|
| 35 |
+
config: "data/pipeline_config.yaml" # PipelineConfig for process_data_pipeline.py
|
| 36 |
+
|
| 37 |
+
# -----------------------------------------------------------------------------
|
| 38 |
+
# SXR normalization (step: normalize)
|
| 39 |
+
# -----------------------------------------------------------------------------
|
| 40 |
+
normalize:
|
| 41 |
+
sxr_dir: "/Volumes/T9/Data_FOXES/SXR_processed/train"
|
| 42 |
+
output_path: "/Volumes/T9/Data_FOXES/SXR_processed/normalized_sxr.npy"
|
| 43 |
+
|
| 44 |
+
# -----------------------------------------------------------------------------
|
| 45 |
+
# Train/val/test split (step: split)
|
| 46 |
+
# Runs split_data.py once for AIA and once for SXR.
|
| 47 |
+
# -----------------------------------------------------------------------------
|
| 48 |
+
split:
|
| 49 |
+
aia_input_dir: "/Volumes/T9/Data_FOXES/AIA_processed" # splits into AIA_processed/train|val|test
|
| 50 |
+
sxr_input_dir: "/Volumes/T9/Data_FOXES/SXR_processed" # splits into SXR_processed/train|val|test
|
| 51 |
+
train_start: "2014-07-01"
|
| 52 |
+
train_end: "2014-07-05"
|
| 53 |
+
val_start: "2014-07-06"
|
| 54 |
+
val_end: "2014-07-07"
|
| 55 |
+
test_start: "2014-07-08"
|
| 56 |
+
test_end: "2025-12-31"
|
| 57 |
+
|
| 58 |
+
# -----------------------------------------------------------------------------
|
| 59 |
+
# Training (step: train)
|
| 60 |
+
# -----------------------------------------------------------------------------
|
| 61 |
+
train:
|
| 62 |
+
config: "forecasting/training/train_config.yaml"
|
| 63 |
+
overrides: # Any key from train_config.yaml can go here
|
| 64 |
+
base_data_dir: "/Volumes/T9/Data_FOXES"
|
| 65 |
+
base_checkpoint_dir: "/Volumes/T9/Data_FOXES"
|
| 66 |
+
epochs: 150
|
| 67 |
+
batch_size: 6
|
| 68 |
+
wandb:
|
| 69 |
+
run_name: "pipeline-run"
|
| 70 |
+
entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
|
| 71 |
+
project: Paper
|
| 72 |
+
job_type: training
|
| 73 |
+
tags:
|
| 74 |
+
- aia
|
| 75 |
+
- sxr
|
| 76 |
+
- regression
|
| 77 |
+
run_name: paper-8-patch-4ch
|
| 78 |
+
notes: Regression from AIA images to SXR images using ViTLocal model with 8x8 patches
|
| 79 |
+
|
| 80 |
+
# -----------------------------------------------------------------------------
|
| 81 |
+
# Inference & flare analysis (steps: inference, flare_analysis)
|
| 82 |
+
# -----------------------------------------------------------------------------
|
| 83 |
+
inference:
|
| 84 |
+
config: "forecasting/inference/local_config.yaml"
|
| 85 |
+
overrides: # Any key from local_config.yaml can go here
|
| 86 |
+
paths:
|
| 87 |
+
data_dir: "/Volumes/T9/Data_FOXES"
|
requirements.txt
CHANGED
|
@@ -2,6 +2,7 @@
|
|
| 2 |
sunpy[all]
|
| 3 |
sunpy-soar
|
| 4 |
astropy
|
|
|
|
| 5 |
drms
|
| 6 |
itipy
|
| 7 |
|
|
|
|
| 2 |
sunpy[all]
|
| 3 |
sunpy-soar
|
| 4 |
astropy
|
| 5 |
+
aiapy==0.6.4 # itipy 0.1.1 requires calibrate.util which was renamed in 0.7.0
|
| 6 |
drms
|
| 7 |
itipy
|
| 8 |
|
run_pipeline.py
ADDED
|
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
FOXES End-to-End Pipeline Orchestrator
|
| 4 |
+
|
| 5 |
+
Runs any combination of pipeline steps in order:
|
| 6 |
+
|
| 7 |
+
1. download_aia - Download SDO/AIA EUV images from JSOC (download/download_sdo.py)
|
| 8 |
+
2. download_sxr - Download GOES SXR flux data (download/sxr_downloader.py)
|
| 9 |
+
3. combine_sxr - Combine raw GOES .nc files into per-satellite CSVs (data/sxr_data_processing.py)
|
| 10 |
+
4. preprocess - EUV cleaning, ITI processing, data alignment (data/process_data_pipeline.py)
|
| 11 |
+
5. split - Split AIA + SXR into train/val/test (data/split_data.py)
|
| 12 |
+
6. normalize - Compute SXR normalization stats on train split (data/sxr_normalization.py)
|
| 13 |
+
7. train - Train the ViTLocal forecasting model (forecasting/training/train.py)
|
| 14 |
+
8. inference - Run batch inference on val/test data (forecasting/inference/inference.py)
|
| 15 |
+
9. flare_analysis - Detect, track, and match flares (forecasting/inference/flare_analysis.py)
|
| 16 |
+
|
| 17 |
+
Usage:
|
| 18 |
+
python run_pipeline.py --list
|
| 19 |
+
python run_pipeline.py --config pipeline_config.yaml --steps all
|
| 20 |
+
python run_pipeline.py --config pipeline_config.yaml --steps train,inference,flare_analysis
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import logging
|
| 25 |
+
import subprocess
|
| 26 |
+
import sys
|
| 27 |
+
import time
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
|
| 30 |
+
import yaml
|
| 31 |
+
|
| 32 |
+
ROOT = Path(__file__).parent
|
| 33 |
+
|
| 34 |
+
logging.basicConfig(
|
| 35 |
+
level=logging.INFO,
|
| 36 |
+
format="%(asctime)s %(levelname)s %(message)s",
|
| 37 |
+
handlers=[
|
| 38 |
+
logging.StreamHandler(sys.stdout),
|
| 39 |
+
logging.FileHandler(ROOT / "pipeline.log"),
|
| 40 |
+
],
|
| 41 |
+
)
|
| 42 |
+
log = logging.getLogger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ---------------------------------------------------------------------------
|
| 46 |
+
# Config helpers
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
|
| 49 |
+
def deep_merge(base: dict, overrides: dict) -> dict:
|
| 50 |
+
"""Recursively merge overrides into base, modifying base in-place."""
|
| 51 |
+
for key, val in overrides.items():
|
| 52 |
+
if isinstance(val, dict) and isinstance(base.get(key), dict):
|
| 53 |
+
deep_merge(base[key], val)
|
| 54 |
+
else:
|
| 55 |
+
base[key] = val
|
| 56 |
+
return base
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def write_merged_config(base_path: str, overrides: dict, out_name: str) -> Path:
|
| 60 |
+
"""
|
| 61 |
+
Load base_path YAML, apply overrides, write merged result to ROOT/.{out_name}.yaml.
|
| 62 |
+
Returns the path of the merged file.
|
| 63 |
+
"""
|
| 64 |
+
with open(base_path) as f:
|
| 65 |
+
base = yaml.safe_load(f) or {}
|
| 66 |
+
deep_merge(base, overrides)
|
| 67 |
+
out = ROOT / f".merged_{out_name}.yaml"
|
| 68 |
+
with open(out, "w") as f:
|
| 69 |
+
yaml.dump(base, f, default_flow_style=False)
|
| 70 |
+
log.info(f" Merged config written to {out}")
|
| 71 |
+
return out
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# ---------------------------------------------------------------------------
|
| 75 |
+
# Step definitions
|
| 76 |
+
# ---------------------------------------------------------------------------
|
| 77 |
+
|
| 78 |
+
STEP_ORDER = [
|
| 79 |
+
"download_aia",
|
| 80 |
+
"download_sxr",
|
| 81 |
+
"combine_sxr",
|
| 82 |
+
"preprocess",
|
| 83 |
+
"split",
|
| 84 |
+
"normalize",
|
| 85 |
+
"train",
|
| 86 |
+
"inference",
|
| 87 |
+
"flare_analysis",
|
| 88 |
+
]
|
| 89 |
+
|
| 90 |
+
STEP_INFO = {
|
| 91 |
+
"download_aia": {
|
| 92 |
+
"description": "Download SDO/AIA EUV images from JSOC",
|
| 93 |
+
"script": ROOT / "download" / "download_sdo.py",
|
| 94 |
+
},
|
| 95 |
+
"download_sxr": {
|
| 96 |
+
"description": "Download GOES SXR flux data via SXRDownloader",
|
| 97 |
+
"script": None, # invoked inline via python -c
|
| 98 |
+
},
|
| 99 |
+
"combine_sxr": {
|
| 100 |
+
"description": "Combine raw GOES .nc files into per-satellite CSVs for alignment",
|
| 101 |
+
"script": ROOT / "data" / "sxr_data_processing.py",
|
| 102 |
+
},
|
| 103 |
+
"preprocess": {
|
| 104 |
+
"description": "EUV cleaning, ITI processing, and AIA/SXR data alignment",
|
| 105 |
+
"script": ROOT / "data" / "process_data_pipeline.py",
|
| 106 |
+
},
|
| 107 |
+
"normalize": {
|
| 108 |
+
"description": "Compute SXR log-normalization statistics (mean/std)",
|
| 109 |
+
"script": ROOT / "data" / "sxr_normalization.py",
|
| 110 |
+
},
|
| 111 |
+
"split": {
|
| 112 |
+
"description": "Split AIA and SXR data into train/val/test by date range",
|
| 113 |
+
"script": ROOT / "data" / "split_data.py",
|
| 114 |
+
},
|
| 115 |
+
"train": {
|
| 116 |
+
"description": "Train the ViTLocal solar flare forecasting model",
|
| 117 |
+
"script": ROOT / "forecasting" / "training" / "train.py",
|
| 118 |
+
},
|
| 119 |
+
"inference": {
|
| 120 |
+
"description": "Run batch inference and save predictions CSV",
|
| 121 |
+
"script": ROOT / "forecasting" / "inference" / "inference.py",
|
| 122 |
+
},
|
| 123 |
+
"flare_analysis": {
|
| 124 |
+
"description": "Detect, track, and match flares; generate plots/movies",
|
| 125 |
+
"script": ROOT / "forecasting" / "inference" / "flare_analysis.py",
|
| 126 |
+
},
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# ---------------------------------------------------------------------------
|
| 131 |
+
# Command builders
|
| 132 |
+
# ---------------------------------------------------------------------------
|
| 133 |
+
|
| 134 |
+
def build_commands(step: str, cfg: dict, force: bool) -> list[list[str]] | None:
|
| 135 |
+
"""
|
| 136 |
+
Return a list of subprocess commands for a given step, or None if required config is missing.
|
| 137 |
+
Most steps return a single command; 'split' returns two (AIA then SXR).
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
def require(keys: list[str], section: str = None) -> bool:
|
| 141 |
+
src = cfg.get(section, {}) if section else cfg
|
| 142 |
+
missing = [k for k in keys if not src.get(k)]
|
| 143 |
+
if missing:
|
| 144 |
+
prefix = f"{section}." if section else ""
|
| 145 |
+
log.error(f"pipeline_config.yaml missing required keys: {[prefix + k for k in missing]}")
|
| 146 |
+
return False
|
| 147 |
+
return True
|
| 148 |
+
|
| 149 |
+
if step == "download_aia":
|
| 150 |
+
if not require(["download_dir", "email"], "aia") or not require(["start_date"]):
|
| 151 |
+
return None
|
| 152 |
+
aia = cfg["aia"]
|
| 153 |
+
cmd = [sys.executable, str(STEP_INFO[step]["script"]),
|
| 154 |
+
"--download_dir", aia["download_dir"],
|
| 155 |
+
"--email", aia["email"],
|
| 156 |
+
"--start_date", cfg["start_date"]]
|
| 157 |
+
if cfg.get("end_date"):
|
| 158 |
+
cmd += ["--end_date", cfg["end_date"]]
|
| 159 |
+
if aia.get("cadence"):
|
| 160 |
+
cmd += ["--cadence", str(aia["cadence"])]
|
| 161 |
+
return [cmd]
|
| 162 |
+
|
| 163 |
+
if step == "download_sxr":
|
| 164 |
+
if not require(["save_dir"], "sxr") or not require(["start_date"]):
|
| 165 |
+
return None
|
| 166 |
+
start = cfg["start_date"]
|
| 167 |
+
end = cfg.get("end_date", start)
|
| 168 |
+
save_dir = cfg["sxr"]["save_dir"]
|
| 169 |
+
inline = (
|
| 170 |
+
f"import sys; sys.path.insert(0, r'{ROOT}'); "
|
| 171 |
+
f"from download.sxr_downloader import SXRDownloader; "
|
| 172 |
+
f"d = SXRDownloader(save_dir=r'{save_dir}'); "
|
| 173 |
+
f"d.download_and_save_goes_data(start='{start}', end='{end}')"
|
| 174 |
+
)
|
| 175 |
+
return [[sys.executable, "-c", inline]]
|
| 176 |
+
|
| 177 |
+
if step == "combine_sxr":
|
| 178 |
+
if not require(["save_dir"], "sxr"):
|
| 179 |
+
return None
|
| 180 |
+
raw_dir = cfg["sxr"]["save_dir"]
|
| 181 |
+
combined_dir = str(Path(raw_dir) / "combined")
|
| 182 |
+
return [[sys.executable, str(STEP_INFO[step]["script"]),
|
| 183 |
+
"--data_dir", raw_dir,
|
| 184 |
+
"--output_dir", combined_dir]]
|
| 185 |
+
|
| 186 |
+
script = STEP_INFO[step]["script"]
|
| 187 |
+
base = [sys.executable, str(script)]
|
| 188 |
+
|
| 189 |
+
if step == "preprocess":
|
| 190 |
+
pre = cfg.get("preprocess", {})
|
| 191 |
+
cmd = base[:]
|
| 192 |
+
if pre.get("config"):
|
| 193 |
+
cmd += ["--config", pre["config"]]
|
| 194 |
+
if force:
|
| 195 |
+
cmd += ["--force"]
|
| 196 |
+
return [cmd]
|
| 197 |
+
|
| 198 |
+
if step == "normalize":
|
| 199 |
+
if not require(["sxr_dir", "output_path"], "normalize"):
|
| 200 |
+
return None
|
| 201 |
+
n = cfg["normalize"]
|
| 202 |
+
return [base + ["--sxr_dir", n["sxr_dir"], "--output_path", n["output_path"]]]
|
| 203 |
+
|
| 204 |
+
if step == "split":
|
| 205 |
+
if not require(["aia_input_dir", "sxr_input_dir"], "split"):
|
| 206 |
+
return None
|
| 207 |
+
s = cfg["split"]
|
| 208 |
+
date_args = []
|
| 209 |
+
for key in ("train_start", "train_end", "val_start", "val_end", "test_start", "test_end"):
|
| 210 |
+
if s.get(key):
|
| 211 |
+
date_args += [f"--{key}", s[key]]
|
| 212 |
+
# Each data type splits into its own input directory (creates train/val/test subdirs there)
|
| 213 |
+
aia_cmd = base + ["--input_folder", s["aia_input_dir"], "--output_dir", s["aia_input_dir"],
|
| 214 |
+
"--data_type", "aia"] + date_args
|
| 215 |
+
sxr_cmd = base + ["--input_folder", s["sxr_input_dir"], "--output_dir", s["sxr_input_dir"],
|
| 216 |
+
"--data_type", "sxr"] + date_args
|
| 217 |
+
return [aia_cmd, sxr_cmd]
|
| 218 |
+
|
| 219 |
+
if step == "train":
|
| 220 |
+
if not require(["config"], "train"):
|
| 221 |
+
return None
|
| 222 |
+
t = cfg["train"]
|
| 223 |
+
config_path = t["config"]
|
| 224 |
+
if t.get("overrides"):
|
| 225 |
+
config_path = str(write_merged_config(config_path, t["overrides"], "train_config"))
|
| 226 |
+
return [base + ["-config", config_path]]
|
| 227 |
+
|
| 228 |
+
if step == "inference":
|
| 229 |
+
if not require(["config"], "inference"):
|
| 230 |
+
return None
|
| 231 |
+
inf = cfg["inference"]
|
| 232 |
+
config_path = inf["config"]
|
| 233 |
+
if inf.get("overrides"):
|
| 234 |
+
config_path = str(write_merged_config(config_path, inf["overrides"], "inference_config"))
|
| 235 |
+
return [base + ["-config", config_path]]
|
| 236 |
+
|
| 237 |
+
if step == "flare_analysis":
|
| 238 |
+
if not require(["config"], "inference"):
|
| 239 |
+
return None
|
| 240 |
+
inf = cfg["inference"]
|
| 241 |
+
config_path = inf["config"]
|
| 242 |
+
if inf.get("overrides"):
|
| 243 |
+
config_path = str(write_merged_config(config_path, inf["overrides"], "inference_config"))
|
| 244 |
+
return [base + ["--config", config_path]]
|
| 245 |
+
|
| 246 |
+
return [base]
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# ---------------------------------------------------------------------------
|
| 250 |
+
# Runner
|
| 251 |
+
# ---------------------------------------------------------------------------
|
| 252 |
+
|
| 253 |
+
def run_step(step: str, cmds: list[list[str]]) -> bool:
|
| 254 |
+
info = STEP_INFO[step]
|
| 255 |
+
total_start = time.time()
|
| 256 |
+
|
| 257 |
+
for i, cmd in enumerate(cmds):
|
| 258 |
+
label = f"{step.upper()}" + (f" ({i + 1}/{len(cmds)})" if len(cmds) > 1 else "")
|
| 259 |
+
log.info("")
|
| 260 |
+
log.info("=" * 70)
|
| 261 |
+
log.info(f" STEP: {label}")
|
| 262 |
+
log.info(f" {info['description']}")
|
| 263 |
+
log.info(f" {' '.join(str(c) for c in cmd)}")
|
| 264 |
+
log.info("=" * 70)
|
| 265 |
+
|
| 266 |
+
result = subprocess.run(cmd, cwd=ROOT)
|
| 267 |
+
if result.returncode != 0:
|
| 268 |
+
log.error(f" FAILED {label} exited with code {result.returncode}")
|
| 269 |
+
return False
|
| 270 |
+
|
| 271 |
+
elapsed = time.time() - total_start
|
| 272 |
+
log.info(f" DONE {step} completed in {elapsed:.1f}s")
|
| 273 |
+
return True
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
# ---------------------------------------------------------------------------
|
| 277 |
+
# CLI
|
| 278 |
+
# ---------------------------------------------------------------------------
|
| 279 |
+
|
| 280 |
+
def list_steps():
|
| 281 |
+
print("\nAvailable pipeline steps (in order):\n")
|
| 282 |
+
for i, step in enumerate(STEP_ORDER, 1):
|
| 283 |
+
print(f" {i}. {step:<16} {STEP_INFO[step]['description']}")
|
| 284 |
+
print()
|
| 285 |
+
print("Use --steps all to run every step, or comma-separate specific steps.")
|
| 286 |
+
print("Example: --steps train,inference,flare_analysis\n")
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def main():
|
| 290 |
+
parser = argparse.ArgumentParser(
|
| 291 |
+
description="FOXES End-to-End Pipeline Orchestrator",
|
| 292 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 293 |
+
)
|
| 294 |
+
parser.add_argument("--config", type=str, default=None, help="Path to pipeline_config.yaml")
|
| 295 |
+
parser.add_argument("--steps", type=str, default=None,
|
| 296 |
+
help=f"Comma-separated steps to run, or 'all'. Available: {', '.join(STEP_ORDER)}")
|
| 297 |
+
parser.add_argument("--list", action="store_true", help="List all available steps and exit")
|
| 298 |
+
parser.add_argument("--force", action="store_true", help="Force re-run (forwarded to preprocess step)")
|
| 299 |
+
|
| 300 |
+
args = parser.parse_args()
|
| 301 |
+
|
| 302 |
+
if args.list:
|
| 303 |
+
list_steps()
|
| 304 |
+
return
|
| 305 |
+
|
| 306 |
+
if not args.steps:
|
| 307 |
+
parser.print_help()
|
| 308 |
+
return
|
| 309 |
+
|
| 310 |
+
if not args.config:
|
| 311 |
+
log.error("--config is required. Point it at your pipeline_config.yaml.")
|
| 312 |
+
sys.exit(1)
|
| 313 |
+
|
| 314 |
+
with open(args.config, "r") as f:
|
| 315 |
+
cfg = yaml.safe_load(f)
|
| 316 |
+
|
| 317 |
+
# Resolve step list
|
| 318 |
+
if args.steps.strip().lower() == "all":
|
| 319 |
+
selected = list(STEP_ORDER)
|
| 320 |
+
else:
|
| 321 |
+
selected = [s.strip() for s in args.steps.split(",")]
|
| 322 |
+
unknown = [s for s in selected if s not in STEP_INFO]
|
| 323 |
+
if unknown:
|
| 324 |
+
log.error(f"Unknown steps: {', '.join(unknown)}")
|
| 325 |
+
list_steps()
|
| 326 |
+
sys.exit(1)
|
| 327 |
+
selected = [s for s in STEP_ORDER if s in selected] # preserve order
|
| 328 |
+
|
| 329 |
+
log.info(f"Config: {args.config}")
|
| 330 |
+
log.info(f"Running {len(selected)} step(s): {' -> '.join(selected)}")
|
| 331 |
+
|
| 332 |
+
passed, failed = [], []
|
| 333 |
+
|
| 334 |
+
for step in selected:
|
| 335 |
+
cmds = build_commands(step, cfg, args.force)
|
| 336 |
+
if cmds is None:
|
| 337 |
+
failed.append(step)
|
| 338 |
+
break
|
| 339 |
+
|
| 340 |
+
if run_step(step, cmds):
|
| 341 |
+
passed.append(step)
|
| 342 |
+
else:
|
| 343 |
+
failed.append(step)
|
| 344 |
+
log.error(f"Pipeline stopped at '{step}'.")
|
| 345 |
+
break
|
| 346 |
+
|
| 347 |
+
# Summary
|
| 348 |
+
log.info("")
|
| 349 |
+
log.info("=" * 70)
|
| 350 |
+
log.info("PIPELINE SUMMARY")
|
| 351 |
+
log.info("=" * 70)
|
| 352 |
+
for s in passed:
|
| 353 |
+
log.info(f" PASSED {s}")
|
| 354 |
+
for s in failed:
|
| 355 |
+
log.error(f" FAILED {s}")
|
| 356 |
+
for s in [s for s in selected if s not in passed and s not in failed]:
|
| 357 |
+
log.info(f" SKIPPED {s}")
|
| 358 |
+
log.info("=" * 70)
|
| 359 |
+
|
| 360 |
+
sys.exit(0 if not failed else 1)
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
if __name__ == "__main__":
|
| 364 |
+
main()
|