File size: 4,595 Bytes
8ce6086
2b2b44a
8ce6086
 
 
 
 
 
 
2b2b44a
5d769aa
e1af142
ec2b4e7
2b2b44a
 
e1af142
c14af1c
 
 
ec2b4e7
 
 
 
 
38e0aee
8ce6086
2b2b44a
8ce6086
 
 
 
 
 
 
 
 
 
 
 
2b2b44a
655c5c3
8ce6086
 
e1af142
8ce6086
2b2b44a
38e0aee
655c5c3
8ce6086
 
 
 
 
2b2b44a
ec2b4e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
655c5c3
 
 
ec2b4e7
655c5c3
ec2b4e7
 
655c5c3
ec2b4e7
655c5c3
 
 
 
 
 
ec2b4e7
655c5c3
8ce6086
ec2b4e7
 
 
 
 
 
 
 
 
 
 
 
 
655c5c3
ec2b4e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import collections.abc

collections.Iterable = collections.abc.Iterable
collections.Mapping = collections.abc.Mapping
collections.MutableSet = collections.abc.MutableSet
collections.MutableMapping = collections.abc.MutableMapping
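# The aliases above restore names removed from `collections` in Python 3.10; keep them ahead of the
# third-party imports below (no `hyper` import exists in this file -- TODO confirm which dependency needs them).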
import numpy as np
from astropy.visualization import ImageNormalize, AsinhStretch
from itipy.data.dataset import StackDataset, get_intersecting_files, AIADataset
from itipy.data.editor import BrightestPixelPatchEditor, sdo_norms
import os
import json
from multiprocessing import Pool
from tqdm import tqdm


def load_config():
    """Load the pipeline configuration from the ``PIPELINE_CONFIG`` environment variable.

    The variable is expected to contain a JSON document. There are no built-in
    defaults: when the variable is unset or does not hold valid JSON, a warning
    is printed and ``None`` is returned, so callers must handle ``None``.

    Returns:
        The decoded JSON value, or ``None`` if no usable configuration exists.
    """
    raw_config = os.environ.get('PIPELINE_CONFIG')
    if raw_config is None:
        print("Warning: PIPELINE_CONFIG is not set; no configuration loaded.")
        return None

    try:
        return json.loads(raw_config)
    except json.JSONDecodeError as e:
        # Only malformed JSON is tolerated here; anything else (e.g.
        # KeyboardInterrupt) must propagate instead of being swallowed.
        print(f"Warning: PIPELINE_CONFIG does not contain valid JSON: {e}")
        return None



class SDODataset_flaring(StackDataset):
    """
    Dataset for SDO data

    Creates one dataset per requested wavelength and stacks them through
    ``StackDataset``.

    Args:
        data: Either a list with one entry of files per wavelength (used as-is
            and paired with ``wavelengths`` by position), or a value handed to
            ``get_intersecting_files`` to look the files up.
        patch_shape (tuple): Patch shape; when given, a
            ``BrightestPixelPatchEditor`` with this shape is added.
        wavelengths (list): List of wavelengths. Required despite the ``None``
            default in the signature.
        resolution (int): Resolution
        ext (str): File extension
        allow_errors (bool): Forwarded unchanged to every per-wavelength dataset.
        **kwargs: Additional arguments; forwarded to ``get_intersecting_files``
            (only when ``data`` is not a list) and to ``StackDataset.__init__``.

    Raises:
        ValueError: If ``wavelengths`` is ``None``.
        KeyError: If a wavelength has no entry in the internal dataset mapping.
    """

    def __init__(self, data, patch_shape=None, wavelengths=None, resolution=2048, ext='.fits', allow_errors=False, **kwargs):
        # The wavelengths are iterated unconditionally below; fail early with a
        # clear message rather than an opaque "'NoneType' is not iterable".
        if wavelengths is None:
            raise ValueError("wavelengths must be provided as a list of wavelength ids")

        if isinstance(data, list):
            paths = data
        else:
            paths = get_intersecting_files(data, wavelengths, ext=ext, **kwargs)

        # NOTE(review): every supported wavelength, including 6173, is routed
        # to AIADataset -- confirm this is intended.
        ds_mapping = {94: AIADataset, 131: AIADataset, 171: AIADataset, 193: AIADataset, 211: AIADataset,
                      304: AIADataset, 335: AIADataset, 1600: AIADataset, 1700: AIADataset, 4500: AIADataset, 6173: AIADataset}
        data_sets = [ds_mapping[wl_id](files, wavelength=wl_id, resolution=resolution, ext=ext, allow_errors=allow_errors)
                     for wl_id, files in zip(wavelengths, paths)]
        super().__init__(data_sets, **kwargs)
        if patch_shape is not None:
            self.addEditor(BrightestPixelPatchEditor(patch_shape))


# Per-process state: filled in by _init_worker and read by save_sample.
_aia_dataset = None
_output_folder = None


def _init_worker(dataset, out_folder):
    """Store the dataset and output folder in this process's module globals.

    Args:
        dataset: Object that save_sample indexes and queries via ``getId``.
        out_folder: Folder that save_sample writes its ``.npy`` files into.
    """
    global _aia_dataset, _output_folder
    _aia_dataset, _output_folder = dataset, out_folder


def save_sample(i):
    """Load sample ``i`` from the worker's dataset and save it as ``<id>.npy``.

    Reads the module globals ``_aia_dataset`` and ``_output_folder``. This is a
    best-effort operation: any failure is reported as a printed warning and
    the function returns normally, so one bad sample does not abort the run.

    Args:
        i (int): Index of the sample in ``_aia_dataset``.
    """
    sample_id = None
    try:
        data = _aia_dataset[i]
        sample_id = _aia_dataset.getId(i)
        file_path = os.path.join(_output_folder, sample_id) + '.npy'
        np.save(file_path, data)
    except Exception as e:
        if sample_id is None:
            # Loading failed before the ID was known. Resolving it may fail as
            # well (it can be the very call that raised), and the handler must
            # never raise, otherwise the error escapes the worker.
            try:
                sample_id = _aia_dataset.getId(i)
            except Exception:
                sample_id = '<unavailable>'
        print(f"Warning: Could not process sample {i} (ID: {sample_id}): {e}")


def check_existing_files(base_input_folder, wavelengths, output_folder):
    """Count outputs that are already on disk, without building the dataset.

    Only the file list of the first wavelength is inspected. For each input
    file the expected output name is its basename without extension and
    without its last ``_``-separated token, plus ``.npy``.

    Args:
        base_input_folder: Passed to ``get_intersecting_files``.
        wavelengths: Passed to ``get_intersecting_files``.
        output_folder: Folder that is searched for existing ``.npy`` files.

    Returns:
        tuple: ``(existing_count, total_expected)``; ``(0, 0)`` when no file
        lists are returned.
    """
    per_wavelength = get_intersecting_files(base_input_folder, wavelengths, ext='.fits')
    if not per_wavelength:
        return 0, 0

    reference_files = per_wavelength[0]

    def _expected_output(source_file):
        stem = os.path.splitext(os.path.basename(source_file))[0]
        # Dropping the last '_' token leaves a stem without '_' unchanged.
        return os.path.join(output_folder, stem.rsplit('_', 1)[0]) + '.npy'

    already_done = sum(1 for source_file in reference_files
                       if os.path.exists(_expected_output(source_file)))
    return already_done, len(reference_files)


if __name__ == '__main__':
    # Script entry point: convert every not-yet-converted sample to a .npy
    # file in the configured output folder, using one worker per CPU.
    config = load_config()
    # load_config() yields None when PIPELINE_CONFIG is missing or invalid;
    # stop with a clear message instead of an opaque TypeError/KeyError on the
    # subscripts below.
    if not isinstance(config, dict) or 'iti' not in config:
        raise SystemExit("PIPELINE_CONFIG must be set to a JSON object containing an 'iti' section.")

    wavelengths = config['iti']['wavelengths']
    base_input_folder = config['iti']['input_folder']
    output_folder = config['iti']['output_folder']
    os.makedirs(output_folder, exist_ok=True)

    # Cheap pre-check based on file names only, so a fully converted folder
    # never pays for constructing the dataset.
    existing_files, total_expected = check_existing_files(base_input_folder, wavelengths, output_folder)
    print(f"Found {existing_files} existing files out of {total_expected} expected files")

    if existing_files >= total_expected:
        print("All files already processed. Nothing to do.")
    else:
        print(f"Need to process {total_expected - existing_files} remaining files")

        aia_dataset = SDODataset_flaring(data=base_input_folder, wavelengths=wavelengths, resolution=512, allow_errors=True)

        # Authoritative check: compare against the IDs the dataset itself
        # reports, since those are the names save_sample writes.
        unprocessed_indices = [
            i for i in range(len(aia_dataset))
            if not os.path.exists(os.path.join(output_folder, aia_dataset.getId(i)) + '.npy')
        ]
        print(f"Processing {len(unprocessed_indices)} unprocessed samples")

        if unprocessed_indices:
            # Each worker receives the dataset once through the initializer
            # instead of with every task.
            with Pool(processes=os.cpu_count(), initializer=_init_worker, initargs=(aia_dataset, output_folder)) as pool:
                # list(...) drains the lazy imap iterator so all tasks finish
                # before the pool is torn down; tqdm shows the progress.
                list(tqdm(pool.imap(save_sample, unprocessed_indices), total=len(unprocessed_indices)))
            print("AIA data processing completed.")
        else:
            print("All samples already processed. Nothing to do.")