Fill-the-Frames / tests /test_data_manager.py
Siddhant Sharma
Added universal date and time for training
5e2228a
Raw
History Blame Contribute Delete
1.53 kB
import torch
import pytest
from src.data.data_manager import DataManager
from src.config.settings import Settings, TrainingConfig, DataConfig
@pytest.fixture
def settings():
    """Build a small ``Settings`` instance shared by the tests in this module.

    The training section uses one epoch with a batch size of one; the data
    section sets a crop size of 64 and a static-motion threshold of 0.005.
    """
    training_cfg = TrainingConfig(
        epochs=1,
        batch_size=1,
        learning_rate=1e-4,
        weight_decay=0.001,
        num_workers=2,
        checkpoints_dir="chkpt",
    )
    data_cfg = DataConfig(
        satellite_type="goes",
        s3_bucket="test",
        download_dir="tmp",
        prefix_type="ABI",
        year=2024,
        start_day=1,
        end_day=2,
        frame_step=1,
        crop_size=64,
        crop_stride_divisor=4,
        static_motion_threshold=0.005,
    )
    return Settings(training=training_cfg, data=data_cfg)
def test_motion_crop_detects_motion(settings):
    """Check that a localized frame change produces a crop with visible motion.

    The first frame and the ground truth are uniform 0.5; the second frame
    differs only inside a 64x64 patch set to 1.0. The returned first crop must
    have shape (1, 64, 64), and the mean absolute difference between the two
    frame crops must exceed the configured static-motion threshold.
    """
    data_manager = DataManager(settings)

    frame_a = torch.full((1, 256, 256), 0.5)
    frame_b = frame_a.clone()
    target = frame_a.clone()
    # Introduce the only difference between the two frames: a bright square.
    frame_b[:, 100:164, 100:164] = 1.0

    crop_a, crop_b, crop_target = data_manager._motion_guided_argmax_crop(
        frame_a, frame_b, target
    )

    mean_abs_diff = (crop_b - crop_a).abs().mean().item()
    threshold = settings.data.static_motion_threshold

    assert crop_a.shape == (1, 64, 64)
    assert mean_abs_diff > threshold
def test_static_crop_rejection(settings):
    """Check that cropping raises ``ValueError`` when all frames are zero.

    Both input frames and the ground truth are independent all-zero tensors,
    so there is no difference between the frames at all.
    """
    data_manager = DataManager(settings)
    # Three separate tensors (frame 0, frame 1, ground truth), not one shared.
    blank_frames = [torch.zeros((1, 256, 256)) for _ in range(3)]

    with pytest.raises(ValueError):
        data_manager._motion_guided_argmax_crop(*blank_frames)