| from __future__ import annotations | |
| import warnings | |
| import pandas as pd | |
| from sepsis_mcp.modeling import ProbabilityEstimator | |
def test_probability_estimator_supports_logistic_regression() -> None:
    """Fit the logistic-regression backend on a tiny separable frame and sanity-check it.

    Every predicted positive-class probability must fall inside the closed
    interval [0, 1], and exactly one probability must come back per input row.
    """
    train_labels = pd.Series([0, 0, 1, 1, 0, 1], dtype=int)
    train_frame = pd.DataFrame(
        {
            "feature_a": [0.0, 0.2, 0.8, 1.0, 0.1, 0.9],
            "feature_b": [0.0, 0.1, 0.9, 1.0, 0.2, 0.8],
        }
    )

    model = ProbabilityEstimator(model_type="logistic_regression", random_state=0)
    model.fit(train_frame, train_labels)
    scores = model.predict_positive_proba(train_frame)

    n_rows = len(train_frame)
    assert scores.between(0.0, 1.0).all()
    assert len(scores) == n_rows
def test_probability_estimator_logistic_regression_handles_nan_features() -> None:
    """Logistic regression must tolerate missing values without emitting any warning.

    The frame mixes sporadic NaNs (``feature_a``, ``feature_b``) with a column
    that is entirely NaN (``feature_c``). Construction, fitting and prediction
    must all complete without a single warning. The result must contain one
    probability per input row, and every probability must lie in [0, 1].
    ``Series.between`` is False for NaN, so NaN outputs also fail this check.
    """
    features = pd.DataFrame(
        {
            "feature_a": [0.0, float("nan"), 0.8, 1.0, 0.1, 0.9],
            "feature_b": [0.0, 0.1, float("nan"), 1.0, 0.2, 0.8],
            "feature_c": [float("nan")] * 6,
        }
    )
    labels = pd.Series([0, 0, 1, 1, 0, 1], dtype=int)

    with warnings.catch_warnings(record=True) as caught:
        # "always" makes every warning raised inside this block get recorded,
        # instead of being deduplicated against warnings seen earlier.
        warnings.simplefilter("always")
        estimator = ProbabilityEstimator(random_state=0, model_type="logistic_regression")
        estimator.fit(features, labels)
        probabilities = estimator.predict_positive_proba(features)

    # Render category + message so a failure names the offending warning
    # instead of showing an opaque list of WarningMessage objects.
    emitted = [f"{w.category.__name__}: {w.message}" for w in caught]
    assert not emitted, f"unexpected warnings during fit/predict: {emitted}"
    assert probabilities.between(0.0, 1.0).all()
    # Keep parity with the sibling tests: one probability per input row.
    assert len(probabilities) == len(features)
def test_probability_estimator_supports_mlp() -> None:
    """Exercise the ``"mlp"`` backend end to end on a frame that contains NaNs.

    ``feature_c`` is missing two entries. The estimator must still fit and
    score all six rows, and every returned probability must lie in [0, 1].
    """
    target = pd.Series([0, 0, 1, 1, 0, 1], dtype=int)
    nan = float("nan")
    design = pd.DataFrame(
        {
            "feature_a": [0.0, 0.2, 0.8, 1.0, 0.1, 0.9],
            "feature_b": [0.0, 0.1, 0.9, 1.0, 0.2, 0.8],
            "feature_c": [nan, 0.3, 0.7, nan, 0.4, 0.6],
        }
    )

    mlp = ProbabilityEstimator(model_type="mlp", random_state=0)
    mlp.fit(design, target)
    positive_proba = mlp.predict_positive_proba(design)

    assert positive_proba.between(0.0, 1.0).all()
    assert len(positive_proba) == len(design)