| """Tests for class imbalance helpers.""" | |
| from __future__ import annotations | |
| from dipauglib.sampling.class_imbalance import build_weighted_sampler, class_weights_from_counts, minority_class_names | |
def test_class_weights_from_counts():
    """The less frequent class is assigned a strictly larger weight."""
    class_counts = {"a": 10, "b": 20}
    computed = class_weights_from_counts(class_counts)
    rare_weight, common_weight = computed["a"], computed["b"]
    assert rare_weight > common_weight
def test_minority_class_names():
    """Only the class whose share falls under the threshold ratio is reported."""
    expected = {"minor"}
    result = minority_class_names(
        {"major": 90, "minor": 10},
        threshold_ratio=0.15,
    )
    assert result == expected
def test_weighted_sampler_builds():
    """The built sampler's ``num_samples`` equals the number of input labels."""
    labels = ["a", "a", "b", "b", "b"]
    # Tally per-class frequencies from the labels themselves so the two stay in
    # sync; this yields {"a": 2, "b": 3} in first-seen order.
    per_class = {name: labels.count(name) for name in dict.fromkeys(labels)}
    sampler = build_weighted_sampler(labels, per_class)
    expected_total = len(labels)
    assert sampler.num_samples == expected_total