Spaces:
Running
Running
| """Tests for LoRA checkpoint oxygen diagnostics helpers.""" | |
| from __future__ import annotations | |
| import importlib.util | |
| from pathlib import Path | |
| import numpy as np | |
| import pytest | |
def _load_module():
    """Import ``scripts/38_eval_lora_checkpoint.py`` as a module object.

    The script's filename starts with a digit, so it cannot be pulled in
    with a normal ``import`` statement; it is loaded from its file path
    under the module name ``eval_lora_checkpoint`` instead.

    Returns:
        The freshly executed module. Each call executes the script again.
    """
    # parents[1] is the directory one level above the one holding this file.
    script_path = Path(__file__).parents[1].joinpath(
        "scripts", "38_eval_lora_checkpoint.py"
    )
    module_spec = importlib.util.spec_from_file_location(
        "eval_lora_checkpoint", script_path
    )
    assert module_spec is not None
    loaded = importlib.util.module_from_spec(module_spec)
    loader = module_spec.loader
    assert loader is not None
    loader.exec_module(loaded)
    return loaded
def test_oxygen_diagnostics_reports_confusion_and_per_class_metrics() -> None:
    """Check summary, confusion matrix, per-class metrics and listed errors.

    Six samples over four oxygen classes are scored. The highest-probability
    class per row is 0, 1, 1, 1, 2, 2, so rows 1 and 5 disagree with their
    labels and four of six samples are correct.
    """
    diagnostics_mod = _load_module()
    class_names = ["aerobe", "anaerobe", "facultative_anaerobe", "microaerobe"]
    y_true = np.array([0, 0, 1, 1, 2, 3])
    y_prob = np.array(
        [
            [0.90, 0.05, 0.03, 0.02],
            [0.20, 0.70, 0.05, 0.05],
            [0.05, 0.85, 0.05, 0.05],
            [0.10, 0.60, 0.20, 0.10],
            [0.05, 0.10, 0.80, 0.05],
            [0.10, 0.20, 0.60, 0.10],
        ]
    )
    metadata = [
        {"bacdive_id": idx, "genome_accession": f"G{idx}"}
        for idx in range(len(y_true))
    ]

    result = diagnostics_mod.compute_oxygen_diagnostics(
        y_prob, y_true, metadata, class_names, top_n_errors=2
    )

    assert result["n"] == 6
    assert result["accuracy"] == pytest.approx(4 / 6)

    # Rows are true classes, columns are predicted classes.
    expected_confusion = [
        [1, 1, 0, 0],
        [0, 2, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0],
    ]
    assert result["confusion_matrix"] == expected_confusion

    per_class = result["per_class"]
    assert per_class["aerobe"]["recall"] == 0.5
    assert per_class["anaerobe"]["precision"] == pytest.approx(2 / 3)
    assert per_class["microaerobe"]["f1"] == 0.0

    # Every class has at least one true sample here, so both macro averages
    # are expected to agree.
    expected_macro_f1 = (2 / 3 + 0.8 + 2 / 3 + 0.0) / 4
    assert result["macro_f1"] == pytest.approx(expected_macro_f1)
    assert result["macro_f1_all_classes"] == pytest.approx(result["macro_f1"])

    # The first listed error is the aerobe sample scored 0.70 for anaerobe.
    first_error = result["wrong_predictions"][0]
    assert first_error["confidence"] == 0.7
    assert first_error["true"] == "aerobe"
    assert first_error["pred"] == "anaerobe"
def test_macro_f1_ignores_zero_support_classes() -> None:
    """Check both macro-F1 variants when two of four classes have no samples.

    Only ``aerobe`` and ``anaerobe`` appear in the labels and both samples are
    scored highest for their true class. ``macro_f1`` is expected to be 1.0,
    while ``macro_f1_all_classes`` is expected to be 0.5.
    """
    diagnostics_mod = _load_module()
    class_names = ["aerobe", "anaerobe", "facultative_anaerobe", "microaerobe"]
    y_true = np.array([0, 1])
    y_prob = np.array(
        [
            [0.90, 0.10, 0.00, 0.00],
            [0.20, 0.80, 0.00, 0.00],
        ]
    )
    metadata = [
        {"bacdive_id": idx, "genome_accession": f"G{idx}"}
        for idx in range(len(y_true))
    ]

    result = diagnostics_mod.compute_oxygen_diagnostics(
        y_prob, y_true, metadata, class_names
    )

    assert result["macro_f1"] == 1.0
    assert result["macro_f1_all_classes"] == 0.5
def test_render_markdown_includes_key_sections() -> None:
    """Check that the rendered Markdown contains the expected fragments.

    A hand-built two-class diagnostics dict is rendered, and the output must
    contain the report title, the confusion-matrix header row, the
    ``anaerobe`` matrix row and the genome accession of the listed error.
    """
    diagnostics_mod = _load_module()

    per_class_metrics = {
        "aerobe": {
            "precision": 0.5,
            "recall": 1.0,
            "f1": 0.666666667,
            "support": 1,
            "predicted": 2,
        },
        "anaerobe": {
            "precision": 0.0,
            "recall": 0.0,
            "f1": 0.0,
            "support": 1,
            "predicted": 0,
        },
    }
    misclassified = [
        {
            "bacdive_id": 42,
            "genome_accession": "GCA_42",
            "true": "anaerobe",
            "pred": "aerobe",
            "confidence": 0.9,
            "true_probability": 0.1,
            "margin": 0.8,
        }
    ]
    diagnostics = {
        "checkpoint": "artifacts/lora/fold0_best.pt",
        "n": 2,
        "accuracy": 0.5,
        "macro_f1": 0.333333333,
        "macro_f1_all_classes": 0.333333333,
        "confusion_matrix": [[1, 0], [1, 0]],
        "classes": ["aerobe", "anaerobe"],
        "per_class": per_class_metrics,
        "wrong_predictions": misclassified,
    }

    report = diagnostics_mod.render_markdown(diagnostics)

    expected_fragments = (
        "# LoRA Oxygen Diagnostics",
        "| True \\ Pred | aerobe | anaerobe |",
        "| anaerobe | 1 | 0 |",
        "GCA_42",
    )
    for fragment in expected_fragments:
        assert fragment in report