File size: 6,763 Bytes
4e9fa0a
 
 
 
 
 
8f68bdf
4e9fa0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1018e0f
4e9fa0a
 
1018e0f
4e9fa0a
1018e0f
4e9fa0a
8998978
1018e0f
4e9fa0a
 
 
50e2e93
 
8f68bdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50e2e93
 
8f68bdf
1018e0f
8f68bdf
 
 
 
8998978
8f68bdf
 
8998978
8f68bdf
 
 
 
 
 
1018e0f
4e9fa0a
8f68bdf
1018e0f
 
4e9fa0a
1018e0f
 
4e9fa0a
8f68bdf
 
 
 
50e2e93
4e9fa0a
1018e0f
4e9fa0a
8998978
50e2e93
8f68bdf
 
 
50e2e93
8998978
 
 
 
 
4e9fa0a
 
 
 
 
 
 
 
 
 
 
50e2e93
4e9fa0a
 
 
 
50e2e93
 
 
 
 
 
 
 
 
 
 
4e9fa0a
 
50e2e93
 
 
4e9fa0a
 
 
 
 
 
 
 
 
50e2e93
 
4e9fa0a
 
 
 
 
 
 
 
50e2e93
 
4e9fa0a
50e2e93
 
4e9fa0a
 
50e2e93
4e9fa0a
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import os
import glob
import shutil
import logging
import torch
import torch.nn.functional as F
import concurrent.futures

from src.config.settings import Settings
from src.data.fetchers.goes_fetcher import GOESFetcher
from src.data.fetchers.himawari_fetcher import HimawariFetcher
from src.data.standardizer import UniversalStandardizer

logger = logging.getLogger(__name__)


class DataManager:
    """
    Universal multi-satellite data pipeline manager.

    Streams the frames of a chunk through the configured fetcher, normalises
    them, cuts one motion-guided crop out of every (first, middle, last) frame
    triplet and saves each triplet as a ``.pt`` tensor in the download
    directory.
    """

    def __init__(self, settings: "Settings"):
        """Create the output directories and select the satellite fetcher.

        Args:
            settings: Project settings. ``settings.data`` must provide
                ``download_dir`` and ``s3_bucket``; ``satellite_type`` is
                optional and defaults to ``"goes"`` (case-insensitive).

        Raises:
            ValueError: If the satellite type is neither ``"goes"`` nor
                ``"himawari"``.
        """
        self.settings = settings
        self.pt_dir = settings.data.download_dir
        self.raw_dir = os.path.join(self.pt_dir, "raw_data")

        os.makedirs(self.pt_dir, exist_ok=True)
        os.makedirs(self.raw_dir, exist_ok=True)

        sat_type = getattr(settings.data, "satellite_type", "goes").lower()

        # Both fetchers take the same constructor argument, so a lookup table
        # replaces the duplicated if/elif branches.
        fetcher_classes = {
            "goes": GOESFetcher,
            "himawari": HimawariFetcher,
        }
        fetcher_cls = fetcher_classes.get(sat_type)
        if fetcher_cls is None:
            raise ValueError(f"Unsupported satellite type: {sat_type}")

        self.fetcher = fetcher_cls(bucket_name=settings.data.s3_bucket)

    @staticmethod
    def _as_cache_key(key):
        """Return *key* in a hashable form (lists become tuples)."""
        return tuple(key) if isinstance(key, list) else key

    def _stream_frame(self, cache_key):
        """Stream one frame through the fetcher and return it normalised."""
        # Tuple cache keys stand for list keys (multi-segment frames), so hand
        # the fetcher a list again.
        fetch_key = list(cache_key) if isinstance(cache_key, tuple) else cache_key
        raw = self.fetcher.stream_and_apply_planck(fetch_key)
        return UniversalStandardizer.normalize_bt(
            raw, self.settings.data.min_bt, self.settings.data.max_bt
        )

    def _load_missing_frames(self, executor, tensor_cache, keys):
        """Stream, in parallel, the frames of *keys* that are not cached yet."""
        # dict.fromkeys drops duplicate keys while keeping their order.
        missing = [key for key in dict.fromkeys(keys) if key not in tensor_cache]
        # Executor.map yields results in input order, so zip pairs every
        # tensor with the key it was requested for.
        for key, tensor in zip(missing, executor.map(self._stream_frame, missing)):
            tensor_cache[key] = tensor

    def process_chunk(self, chunk_prefix: str) -> None:
        """Build and save every frame triplet of one chunk.

        Triplet ``i`` is made of frames ``i``, ``i + frame_step`` and
        ``i + 2 * frame_step``. The three frames are cropped to the same
        window and stacked in the order (first, middle, last) before being
        written to ``triplet_<prefix>_<i>.pt`` in ``self.pt_dir``. A triplet
        that fails (for example a rejected static crop) is logged and skipped.

        Args:
            chunk_prefix: Prefix passed to the fetcher's ``fetch_chunk``; its
                ``/`` characters are replaced by ``_`` in the file names.

        Raises:
            ValueError: If ``settings.data.frame_step`` is smaller than 1.
        """
        logger.info(f"Processing chunk {chunk_prefix}")

        frame_step = self.settings.data.frame_step
        if frame_step < 1:
            # A zero or negative step can never form a valid triplet; fail
            # before anything is downloaded.
            raise ValueError(f"frame_step must be >= 1, got {frame_step}")

        frame_keys = self.fetcher.fetch_chunk(chunk_prefix)

        # A single triplet already spans 2 * frame_step + 1 frames.
        triplet_count = len(frame_keys) - 2 * frame_step
        if triplet_count < 1:
            logger.warning("Not enough frames for triplets.")
            return

        cache_keys = [self._as_cache_key(key) for key in frame_keys]
        safe_prefix = chunk_prefix.replace("/", "_")
        tensor_cache = {}

        logger.info("🔥 Starting Zero-Disk In-Memory Streaming...")

        # One pool for the whole chunk instead of a new pool per triplet.
        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
            for i in range(triplet_count):
                k0 = cache_keys[i]
                k1 = cache_keys[i + frame_step]
                k2 = cache_keys[i + 2 * frame_step]

                try:
                    # Only the frames not already held in RAM are streamed.
                    self._load_missing_frames(executor, tensor_cache, (k0, k1, k2))

                    img0 = tensor_cache[k0]
                    gt = tensor_cache[k1]
                    img1 = tensor_cache[k2]

                    img0_crop, img1_crop, gt_crop = self._motion_guided_argmax_crop(
                        img0, img1, gt
                    )

                    pt_filename = os.path.join(
                        self.pt_dir, f"triplet_{safe_prefix}_{i:03d}.pt"
                    )
                    triplet_tensor = torch.stack(
                        [img0_crop, gt_crop, img1_crop], dim=0
                    )
                    torch.save(triplet_tensor, pt_filename)
                except Exception as e:
                    logger.error(f"Triplet failed ({i}): {e}")
                finally:
                    # Later triplets only use frames after index i, so the
                    # first frame is dropped even when the triplet failed;
                    # otherwise every rejected triplet would keep a full
                    # frame in memory until the chunk ends.
                    tensor_cache.pop(k0, None)

        tensor_cache.clear()
        self.purge_raw_files()  # Kept for fallback cleanup
        logger.info("Chunk processing complete. Zero raw files written to disk!")

    def _delete_temp(self, path: str):
        """Delete *path*, whether it is a file, a symlink or a directory tree.

        Paths that do not exist are ignored.
        """
        # Symlinks are unlinked rather than followed: shutil.rmtree refuses
        # symlinks to directories, and a dangling link is neither a file nor
        # a directory.
        if os.path.islink(path) or os.path.isfile(path):
            os.remove(path)
        elif os.path.isdir(path):
            shutil.rmtree(path)

    def _motion_guided_argmax_crop(
        self,
        img0: torch.Tensor,
        img1: torch.Tensor,
        gt: torch.Tensor
    ):
        """Crop the three frames to the window with the most motion.

        Motion is ``|img1 - img0|``, zeroed wherever ``img0`` is not positive
        (presumably off-disk space pixels; confirm against the normaliser).
        The map is downsampled 8x, averaged over sliding windows, and the
        window with the highest mean motion in any channel is chosen.

        Args:
            img0: First frame, shape ``(C, H, W)``.
            img1: Last frame, same shape as ``img0``.
            gt: Middle frame, cropped with the same window.

        Returns:
            Tuple ``(img0_crop, img1_crop, gt_crop)``, each of spatial size
            ``crop_size`` x ``crop_size``.

        Raises:
            ValueError: If the image is smaller than ``crop_size`` or the mean
                motion inside the chosen crop is below
                ``static_motion_threshold`` (default 0.005).
        """
        crop_size = self.settings.data.crop_size
        _, h, w = img0.shape

        if h < crop_size or w < crop_size:
            raise ValueError(f"Image smaller than crop size: {h}x{w}")

        motion_map = torch.abs(img1 - img0)
        space_mask = (img0 > 0.0).float()
        motion_map = motion_map * space_mask

        # Search on an 8x downsampled map to keep the window scan cheap.
        scale_factor = 8
        small_motion = F.avg_pool2d(
            motion_map.unsqueeze(0),
            kernel_size=scale_factor,
            stride=scale_factor
        )

        # At least 1 so that crop sizes below the scale factor still give a
        # valid pooling kernel.
        small_crop_size = max(1, crop_size // scale_factor)
        divisor = getattr(self.settings.data, 'crop_stride_divisor', 8)
        small_stride = max(1, small_crop_size // divisor)

        pooled_motion = F.avg_pool2d(
            small_motion,
            kernel_size=small_crop_size,
            stride=small_stride
        )

        _, _, h_out, w_out = pooled_motion.shape

        # argmax runs over the flattened (1, C, h_out, w_out) map; strip the
        # channel offset so only the spatial position of the window remains.
        flat_idx = torch.argmax(pooled_motion).item() % (h_out * w_out)
        y_out, x_out = divmod(flat_idx, w_out)

        # Back to full-resolution pixels, clamped so the crop stays inside.
        y = y_out * small_stride * scale_factor
        x = x_out * small_stride * scale_factor

        y = max(0, min(y, h - crop_size))
        x = max(0, min(x, w - crop_size))

        img0_crop = img0[:, y:y+crop_size, x:x+crop_size]
        img1_crop = img1[:, y:y+crop_size, x:x+crop_size]
        gt_crop = gt[:, y:y+crop_size, x:x+crop_size]

        crop_motion = torch.abs(img1_crop - img0_crop).mean().item()
        static_threshold = getattr(self.settings.data, 'static_motion_threshold', 0.005)

        if crop_motion < static_threshold:
            raise ValueError(f"Static crop rejected: {crop_motion:.5f}")

        return img0_crop, img1_crop, gt_crop

    def purge_raw_files(self):
        """Delete every non-hidden entry directly inside ``self.raw_dir``."""
        logger.info("Purging raw files...")
        for entry in glob.glob(os.path.join(self.raw_dir, "*")):
            self._delete_temp(entry)