import os
import glob
import shutil
import logging
import torch
import torch.nn.functional as F
import concurrent.futures
from src.config.settings import Settings
from src.data.fetchers.goes_fetcher import GOESFetcher
from src.data.fetchers.himawari_fetcher import HimawariFetcher
from src.data.standardizer import UniversalStandardizer
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class DataManager:
    """Universal multi-satellite data pipeline manager.

    Selects a fetcher for the configured satellite type, streams the frames of
    a chunk into an in-memory cache, and saves motion-guided crops of frame
    triplets as ``.pt`` files under ``settings.data.download_dir``.
    """

    def __init__(self, settings: "Settings"):
        """Create the output directories and pick the satellite fetcher.

        Args:
            settings: Project settings object. This constructor reads
                ``settings.data.download_dir``, ``settings.data.s3_bucket``
                and, if present, ``settings.data.satellite_type``
                (defaults to ``"goes"``; compared case-insensitively).

        Raises:
            ValueError: If the satellite type is neither ``"goes"`` nor
                ``"himawari"``.
        """
        self.settings = settings
        self.pt_dir = settings.data.download_dir
        self.raw_dir = os.path.join(self.pt_dir, "raw_data")
        os.makedirs(self.pt_dir, exist_ok=True)
        os.makedirs(self.raw_dir, exist_ok=True)

        sat_type = getattr(settings.data, "satellite_type", "goes").lower()
        if sat_type == "goes":
            self.fetcher = GOESFetcher(bucket_name=settings.data.s3_bucket)
        elif sat_type == "himawari":
            self.fetcher = HimawariFetcher(bucket_name=settings.data.s3_bucket)
        else:
            raise ValueError(f"Unsupported satellite type: {sat_type}")

    def process_chunk(self, chunk_prefix: str) -> None:
        """Build and save cropped frame triplets for one chunk.

        For every index ``i`` the frames at ``i``, ``i + frame_step`` and
        ``i + 2 * frame_step`` (``frame_step`` comes from
        ``settings.data.frame_step``) are streamed into RAM, cropped with
        ``_motion_guided_argmax_crop`` and stacked in the order
        (first, middle, last) before being written with ``torch.save`` to
        ``<pt_dir>/triplet_<prefix>_<i>.pt``, where slashes in the prefix are
        replaced by underscores.

        A failure on an individual triplet (including a crop rejected as
        static) is logged and skipped; it does not abort the chunk.

        Args:
            chunk_prefix: Prefix handed to ``self.fetcher.fetch_chunk`` and
                embedded in the output file names.
        """
        logger.info("Processing chunk %s", chunk_prefix)
        frame_keys = self.fetcher.fetch_chunk(chunk_prefix)
        frame_step = self.settings.data.frame_step

        # A triplet spans indices i, i + step and i + 2 * step, so fewer than
        # 2 * step + 1 frames can never yield one. The floor of 3 keeps the
        # historical behaviour for step <= 1.
        min_frames = max(3, 2 * frame_step + 1)
        if len(frame_keys) < min_frames:
            logger.warning("Not enough frames for triplets.")
            return

        safe_prefix = chunk_prefix.replace("/", "_")
        tensor_cache = {}
        logger.info("🔥 Starting Zero-Disk In-Memory Streaming...")

        for i in range(len(frame_keys) - 2 * frame_step):
            # List keys are unhashable; turn them into tuples for the cache.
            k0 = self._as_cache_key(frame_keys[i])
            k1 = self._as_cache_key(frame_keys[i + frame_step])
            k2 = self._as_cache_key(frame_keys[i + 2 * frame_step])
            try:
                # Only frames not already cached are streamed.
                self._ensure_cached((k0, k1, k2), tensor_cache)

                img0_crop, img1_crop, gt_crop = self._motion_guided_argmax_crop(
                    tensor_cache[k0], tensor_cache[k2], tensor_cache[k1]
                )

                pt_filename = os.path.join(
                    self.pt_dir, f"triplet_{safe_prefix}_{i:03d}.pt"
                )
                triplet_tensor = torch.stack(
                    [img0_crop, gt_crop, img1_crop], dim=0
                )
                torch.save(triplet_tensor, pt_filename)
            except Exception as e:
                # Deliberate best-effort: one bad triplet must not stop the
                # rest of the chunk.
                logger.error("Triplet failed (%s): %s", i, e)
            finally:
                # Later iterations only touch indices greater than i, so the
                # first frame can always be dropped. Doing this in ``finally``
                # stops rejected or failed triplets from leaving their frame
                # in RAM until the end of the chunk.
                tensor_cache.pop(k0, None)

        tensor_cache.clear()
        self.purge_raw_files()  # Kept for fallback cleanup
        logger.info("Chunk processing complete. Zero raw files written to disk!")

    @staticmethod
    def _as_cache_key(frame_key):
        """Return ``frame_key`` in hashable form (a list becomes a tuple)."""
        return tuple(frame_key) if isinstance(frame_key, list) else frame_key

    def _stream_frame(self, cache_key):
        """Stream one frame through the fetcher and return it normalized.

        The raw result of ``fetcher.stream_and_apply_planck`` is passed to
        ``UniversalStandardizer.normalize_bt`` together with
        ``settings.data.min_bt`` and ``settings.data.max_bt``.
        """
        # Tuple cache keys are handed to the fetcher as lists again.
        # NOTE(review): the original comment says list keys are multi-segment
        # Himawari frames - confirm against HimawariFetcher.
        fetch_key = list(cache_key) if isinstance(cache_key, tuple) else cache_key
        raw = self.fetcher.stream_and_apply_planck(fetch_key)
        return UniversalStandardizer.normalize_bt(
            raw, self.settings.data.min_bt, self.settings.data.max_bt
        )

    def _ensure_cached(self, cache_keys, tensor_cache) -> None:
        """Stream every key of ``cache_keys`` missing from ``tensor_cache``.

        Missing frames are fetched concurrently on up to three worker threads
        and stored in ``tensor_cache`` under their cache key.
        """
        # dict.fromkeys de-duplicates while preserving order, so a repeated
        # key is streamed only once.
        missing = [k for k in dict.fromkeys(cache_keys) if k not in tensor_cache]
        if not missing:
            return
        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
            # Executor.map yields results in input order.
            for key, tensor in zip(missing, executor.map(self._stream_frame, missing)):
                tensor_cache[key] = tensor

    def _delete_temp(self, path: str) -> None:
        """Delete ``path`` if it is a file or a directory tree.

        Paths that are neither (for example, paths that do not exist) are
        ignored.
        """
        if os.path.isfile(path):
            os.remove(path)
        elif os.path.isdir(path):
            shutil.rmtree(path)

    def _motion_guided_argmax_crop(
        self,
        img0: torch.Tensor,
        img1: torch.Tensor,
        gt: torch.Tensor
    ):
        """Crop all three frames at the window with the most motion.

        Motion is the absolute difference ``|img1 - img0|``, zeroed wherever
        ``img0 <= 0``. The map is average-pooled by a factor of 8, then
        average-pooled again with a window of ``crop_size // 8`` to score
        candidate windows; the best-scoring window is mapped back to
        full-resolution coordinates and clamped to the image bounds.

        Args:
            img0: First frame, a 3-D tensor whose last two dimensions are
                height and width (the leading one is presumably channels).
            img1: Last frame, same shape as ``img0``.
            gt: Middle frame, cropped at the same location.

        Returns:
            Tuple ``(img0_crop, img1_crop, gt_crop)`` of
            ``crop_size`` x ``crop_size`` spatial crops.

        Raises:
            ValueError: If the image is smaller than
                ``settings.data.crop_size``, or if the mean absolute motion
                inside the chosen crop is below
                ``settings.data.static_motion_threshold`` (default 0.005).
        """
        crop_size = self.settings.data.crop_size
        _, h, w = img0.shape
        if h < crop_size or w < crop_size:
            raise ValueError(f"Image smaller than crop size: {h}x{w}")

        # Ignore motion where img0 is not positive.
        # NOTE(review): presumably off-disk "space" pixels normalize to 0 -
        # confirm against UniversalStandardizer.
        space_mask = (img0 > 0.0).float()
        motion_map = torch.abs(img1 - img0) * space_mask

        # Search on a downsampled map to keep the window scan cheap.
        scale_factor = 8
        small_motion = F.avg_pool2d(
            motion_map.unsqueeze(0),
            kernel_size=scale_factor,
            stride=scale_factor
        )
        small_crop_size = crop_size // scale_factor
        divisor = getattr(self.settings.data, 'crop_stride_divisor', 8)
        small_stride = max(1, small_crop_size // divisor)
        pooled_motion = F.avg_pool2d(
            small_motion,
            kernel_size=small_crop_size,
            stride=small_stride
        )

        _, _, h_out, w_out = pooled_motion.shape
        flat_idx = torch.argmax(pooled_motion).item()
        # argmax without ``dim`` indexes the flattened (1, C, h_out, w_out)
        # tensor, so the channel offset must be removed from the row with
        # ``% h_out``; otherwise any channel > 0 yields an out-of-range row.
        y_out = (flat_idx // w_out) % h_out
        x_out = flat_idx % w_out

        # Back to full-resolution pixels, clamped so the crop stays inside.
        y = max(0, min(y_out * small_stride * scale_factor, h - crop_size))
        x = max(0, min(x_out * small_stride * scale_factor, w - crop_size))

        img0_crop = img0[:, y:y + crop_size, x:x + crop_size]
        img1_crop = img1[:, y:y + crop_size, x:x + crop_size]
        gt_crop = gt[:, y:y + crop_size, x:x + crop_size]

        # Reject crops with too little motion between the outer frames.
        crop_motion = torch.abs(img1_crop - img0_crop).mean().item()
        static_threshold = getattr(
            self.settings.data, 'static_motion_threshold', 0.005
        )
        if crop_motion < static_threshold:
            raise ValueError(f"Static crop rejected: {crop_motion:.5f}")
        return img0_crop, img1_crop, gt_crop

    def purge_raw_files(self) -> None:
        """Delete every file and directory directly inside ``self.raw_dir``."""
        logger.info("Purging raw files...")
        for entry in glob.glob(os.path.join(self.raw_dir, "*")):
            self._delete_temp(entry)