import collections.abc
# Compatibility shim: Python 3.10 removed the ABC aliases (Iterable, Mapping,
# MutableSet, MutableMapping) from the top-level `collections` namespace.
# Restore them so legacy third-party code imported below can still resolve
# the old names at import time.
collections.Iterable = collections.abc.Iterable
collections.Mapping = collections.abc.Mapping
collections.MutableSet = collections.abc.MutableSet
collections.MutableMapping = collections.abc.MutableMapping
# Shim in place — the library imports below (presumably needing the old
# aliases; TODO confirm which ones) can now be performed safely.
import numpy as np
from astropy.visualization import ImageNormalize, AsinhStretch
from itipy.data.dataset import StackDataset, get_intersecting_files, AIADataset
from itipy.data.editor import BrightestPixelPatchEditor, sdo_norms
import os
import json
from multiprocessing import Pool
from tqdm import tqdm
def load_config():
    """Load pipeline configuration from the PIPELINE_CONFIG environment variable.

    The variable is expected to contain a JSON document.

    Returns:
        dict | None: Parsed configuration, or None when the variable is unset
        or does not contain valid JSON (a warning is printed in either case).
    """
    # Catch only the two expected failure modes instead of a bare `except:`,
    # which would also swallow KeyboardInterrupt/SystemExit and hide bugs.
    try:
        return json.loads(os.environ['PIPELINE_CONFIG'])
    except KeyError:
        print("Warning: PIPELINE_CONFIG environment variable is not set")
        return None
    except json.JSONDecodeError as e:
        print(f"Warning: PIPELINE_CONFIG contains invalid JSON: {e}")
        return None
class SDODataset_flaring(StackDataset):
    """
    Stacked dataset of SDO/AIA channels for flaring-region processing.

    Args:
        data: Either a base folder to scan or an explicit list of per-channel
            file lists.
        patch_shape (tuple): If given, a BrightestPixelPatchEditor of this
            shape is appended to the editor pipeline.
        wavelengths (list): AIA wavelengths to load (one sub-dataset each).
        resolution (int): Target resolution for every channel.
        ext (str): File extension used when scanning `data` as a folder.
        allow_errors (bool): Forwarded to each AIADataset.
        **kwargs: Passed through to file discovery and StackDataset.
    """
    def __init__(self, data, patch_shape=None, wavelengths=None, resolution=2048, ext='.fits', allow_errors=False, **kwargs):
        # Accept either pre-resolved file lists or a folder to scan.
        if isinstance(data, list):
            file_lists = data
        else:
            file_lists = get_intersecting_files(data, wavelengths, ext=ext, **kwargs)
        # Every supported channel is currently backed by AIADataset.
        supported = (94, 131, 171, 193, 211, 304, 335, 1600, 1700, 4500, 6173)
        channel_cls = {wl: AIADataset for wl in supported}
        stacked = [
            channel_cls[wl](file_list, wavelength=wl, resolution=resolution, ext=ext, allow_errors=allow_errors)
            for wl, file_list in zip(wavelengths, file_lists)
        ]
        super().__init__(stacked, **kwargs)
        if patch_shape is not None:
            self.addEditor(BrightestPixelPatchEditor(patch_shape))
# Per-worker shared state, populated once by the Pool initializer so each
# task does not need the dataset re-pickled with every call.
_aia_dataset = None
_output_folder = None


def _init_worker(dataset, out_folder):
    """Pool initializer: stash the dataset and output folder in worker globals."""
    global _aia_dataset, _output_folder
    _aia_dataset, _output_folder = dataset, out_folder
def save_sample(i):
    """Load sample `i` from the shared worker dataset and save it as .npy.

    Best-effort: any failure is reported as a warning and the sample is
    skipped, so one bad file cannot kill the worker pool.

    Args:
        i (int): Index into the dataset set up by `_init_worker`.
    """
    # Resolve the id inside the try-block but BEFORE any further use: the
    # original code re-called getId(i) inside the except handler, so a
    # failing getId would raise again there and crash the worker instead
    # of producing the warning.
    sample_id = None
    try:
        sample_id = _aia_dataset.getId(i)
        data = _aia_dataset[i]
        file_path = os.path.join(_output_folder, sample_id) + '.npy'
        np.save(file_path, data)
    except Exception as e:
        print(f"Warning: Could not process sample {i} (ID: {sample_id}): {e}")
def check_existing_files(base_input_folder, wavelengths, output_folder):
    """Count already-converted output files without constructing the dataset.

    Args:
        base_input_folder: Folder scanned for intersecting .fits files.
        wavelengths: Wavelength channels expected for each sample.
        output_folder: Folder holding the converted .npy files.

    Returns:
        tuple[int, int]: (number of existing outputs, total expected outputs).
    """
    files = get_intersecting_files(base_input_folder, wavelengths, ext='.fits')
    if not files:
        return 0, 0
    # The first wavelength's file list defines one entry per sample.
    total_expected = len(files[0])
    existing_count = 0
    for input_path in files[0]:
        stem = os.path.splitext(os.path.basename(input_path))[0]
        # Outputs are named without the trailing "_<suffix>" segment
        # (presumably the wavelength suffix — TODO confirm against getId).
        if '_' in stem:
            stem = stem.rsplit('_', 1)[0]
        if os.path.exists(os.path.join(output_folder, stem) + '.npy'):
            existing_count += 1
    return existing_count, total_expected
if __name__ == '__main__':
    # Configuration comes from the PIPELINE_CONFIG environment variable.
    config = load_config()
    wavelengths = config['iti']['wavelengths']
    base_input_folder = config['iti']['input_folder']
    output_folder = config['iti']['output_folder']
    os.makedirs(output_folder, exist_ok=True)
    # Cheap pre-check: count existing .npy outputs from filenames alone,
    # without paying the cost of constructing the full dataset.
    existing_files, total_expected = check_existing_files(base_input_folder, wavelengths, output_folder)
    print(f"Found {existing_files} existing files out of {total_expected} expected files")
    if existing_files >= total_expected:
        print("All files already processed. Nothing to do.")
    else:
        print(f"Need to process {total_expected - existing_files} remaining files")
        # NOTE(review): dataset is built at resolution 512 here, while the
        # class default is 2048 — confirm 512 is the intended output size.
        aia_dataset = SDODataset_flaring(data=base_input_folder, wavelengths=wavelengths, resolution=512, allow_errors=True)
        # Re-derive the pending work from the dataset's own ids, so a
        # previously interrupted run resumes where it left off.
        unprocessed_indices = [
            i for i in range(len(aia_dataset))
            if not os.path.exists(os.path.join(output_folder, aia_dataset.getId(i)) + '.npy')
        ]
        print(f"Processing {len(unprocessed_indices)} unprocessed samples")
        if unprocessed_indices:
            # The initializer hands the dataset and output folder to each
            # worker once, instead of pickling them with every task.
            with Pool(processes=os.cpu_count(), initializer=_init_worker, initargs=(aia_dataset, output_folder)) as pool:
                # list() drains the lazy imap iterator; tqdm shows progress.
                list(tqdm(pool.imap(save_sample, unprocessed_indices), total=len(unprocessed_indices)))
            print("AIA data processing completed.")
        else:
            print("All samples already processed. Nothing to do.")