sound-broken / tests /test_eval.py
mitvho09's picture
Upload Space app
edb671a verified
Raw
History Blame Contribute Delete
15.7 kB
"""Full evaluation suite — verifies every sample gets the RIGHT diagnosis.
Run: python -m pytest tests/test_eval.py -v
"""
import os, sys, json, tempfile
import numpy as np
import pytest
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from audio_analyzer import extract_features, AudioFeatures
from fault_rules import rank_candidates, RULES
from feature_prompt import build_diagnosis_prompt
from json_guard import validate, DiagnosisResult
from mock_model import mock_generate as _mock_generate
ASSETS = os.path.join(os.path.dirname(os.path.dirname(__file__)), "assets")
# =============================================================================
# EVALUATION 1: Each sample gets correct diagnosis with correct appliance
# =============================================================================
# Expectations for the bundled sample recordings, keyed by WAV file name.
# Keys read by TestSampleDiagnoses for the top-ranked candidate:
#   appliance                -- appliance label passed to rank_candidates
#   expected_fault_contains  -- keywords; at least one must occur in the
#                               lower-cased fault name
#   expected_urgency         -- accepted urgency labels
#   expected_min_weight      -- optional lower bound on the candidate weight
#   expected_max_weight      -- optional upper bound on the candidate weight
SAMPLE_DIAGNOSES = {
    "sample_washer_bearing.wav": {
        "appliance": "Washing machine",
        "expected_fault_contains": ["bearing", "drum"],
        "expected_urgency": ["HIGH", "CRITICAL"],
        "expected_min_weight": 0.6,
    },
    "sample_fan_imbalanced.wav": {
        "appliance": "Electric fan",
        "expected_fault_contains": ["imbalance", "blade"],
        "expected_urgency": ["MEDIUM", "HIGH"],
        "expected_min_weight": 0.5,
    },
    "sample_motor_squeal.wav": {
        "appliance": "Electric motor (generic)",
        "expected_fault_contains": ["squeal", "whine", "bearing", "belt"],
        "expected_urgency": ["MEDIUM", "HIGH"],
        "expected_min_weight": 0.5,
    },
    # The healthy recording is bounded from above instead of below.
    "sample_washer_good.wav": {
        "appliance": "Washing machine",
        "expected_fault_contains": ["inconclusive"],
        "expected_urgency": ["LOW"],
        "expected_max_weight": 0.1,
    },
}
class TestSampleDiagnoses:
    """Each sample WAV should produce a specific, correct diagnosis."""

    @pytest.mark.parametrize("wav_name,expected", list(SAMPLE_DIAGNOSES.items()))
    def test_sample_diagnosis(self, wav_name, expected):
        """Check the top-ranked candidate of one sample against its expectations.

        Args:
            wav_name: File name of the sample inside ASSETS.
            expected: The matching SAMPLE_DIAGNOSES entry.

        Skipped when the sample file is not present.
        """
        path = os.path.join(ASSETS, wav_name)
        if not os.path.exists(path):
            pytest.skip(f"Sample {wav_name} not found")
        features = extract_features(path)
        candidates = rank_candidates(features, expected["appliance"])
        assert len(candidates) >= 1
        top = candidates[0]

        # Check fault name contains expected keyword
        fault_lower = top.name.lower()
        assert any(kw in fault_lower for kw in expected["expected_fault_contains"]), \
            f"Expected fault containing {expected['expected_fault_contains']}, got '{top.name}'"

        # Check urgency
        assert top.urgency in expected["expected_urgency"], \
            f"Expected urgency in {expected['expected_urgency']}, got '{top.urgency}'"

        # Check weight bounds (each bound is optional per sample)
        if "expected_min_weight" in expected:
            assert top.weight >= expected["expected_min_weight"], \
                f"Expected weight >= {expected['expected_min_weight']}, got {top.weight}"
        if "expected_max_weight" in expected:
            assert top.weight <= expected["expected_max_weight"], \
                f"Expected weight <= {expected['expected_max_weight']}, got {top.weight}"

    def test_all_four_samples_give_different_diagnoses(self):
        """The 4 samples should produce 3+ distinct fault names.

        Missing sample files are ignored. When fewer than three samples are
        available the requirement cannot be met, so the test is skipped
        instead of failing on an asset problem.
        """
        faults = []
        for wav_name, info in SAMPLE_DIAGNOSES.items():
            path = os.path.join(ASSETS, wav_name)
            if not os.path.exists(path):
                continue
            features = extract_features(path)
            candidates = rank_candidates(features, info["appliance"])
            # Fail with a readable message rather than an IndexError below.
            assert len(candidates) >= 1, f"No candidates returned for {wav_name}"
            faults.append(candidates[0].name)
        if len(faults) < 3:
            pytest.skip(f"Need at least 3 sample WAVs, found {len(faults)}")
        assert len(set(faults)) >= 3, f"Expected 3+ distinct faults, got: {faults}"
# =============================================================================
# EVALUATION 2: Mock generate uses correct candidate
# =============================================================================
class TestMockGenerate:
    """The mock should return the top candidate's fault name, not 'Inconclusive'."""

    @staticmethod
    def _require_sample(wav_name):
        """Return the path of *wav_name* inside ASSETS, skipping the test if absent."""
        path = os.path.join(ASSETS, wav_name)
        if not os.path.exists(path):
            pytest.skip("Sample not found")
        return path

    @staticmethod
    def _mock_diagnosis(path, appliance):
        """Run features -> rules -> prompt -> mock for one audio file.

        Returns:
            A ``(candidates, parsed)`` pair: the ranked candidates and the
            mock's output decoded with ``json.loads``.
        """
        features = extract_features(path)
        candidates = rank_candidates(features, appliance)
        prompt = build_diagnosis_prompt(features, candidates, appliance)
        raw = _mock_generate(prompt, candidates, features)
        return candidates, json.loads(raw)

    def test_bearing_sample_returns_bearing(self):
        """The bearing sample's mock output names the top candidate with confidence >= 60."""
        path = self._require_sample("sample_washer_bearing.wav")
        candidates, parsed = self._mock_diagnosis(path, "Washing machine")
        fault = parsed["fault"].lower()
        top_name = candidates[0].name.lower()
        # An empty string is a substring of everything, so the containment
        # check below would pass vacuously without this guard.
        assert fault, "Mock returned an empty fault name"
        assert fault in top_name or top_name in fault, \
            f"Mock should return '{candidates[0].name}', got '{parsed['fault']}'"
        assert parsed["confidence"] >= 60, \
            f"Bearing should have high confidence, got {parsed['confidence']}"

    def test_fan_sample_returns_fan_fault(self):
        """The imbalanced-fan sample must not be reported as Inconclusive."""
        path = self._require_sample("sample_fan_imbalanced.wav")
        _, parsed = self._mock_diagnosis(path, "Electric fan")
        assert "inconclusive" not in parsed["fault"].lower(), \
            f"Fan sample should NOT be Inconclusive, got '{parsed['fault']}'"

    def test_good_sample_returns_inconclusive(self):
        """The healthy washer sample is reported as Inconclusive with LOW urgency."""
        path = self._require_sample("sample_washer_good.wav")
        _, parsed = self._mock_diagnosis(path, "Washing machine")
        assert parsed["fault"].lower() == "inconclusive", \
            f"Good sample should be Inconclusive, got '{parsed['fault']}'"
        assert parsed["urgency"] == "LOW"

    def test_mock_always_returns_valid_json(self):
        """Every sample should produce parseable JSON with required fields.

        Missing sample files are ignored; if none is available the test is
        skipped rather than passing without checking anything.
        """
        checked = 0
        for wav_name, info in SAMPLE_DIAGNOSES.items():
            path = os.path.join(ASSETS, wav_name)
            if not os.path.exists(path):
                continue
            _, parsed = self._mock_diagnosis(path, info["appliance"])
            assert "fault" in parsed
            assert "urgency" in parsed
            assert "checks" in parsed and len(parsed["checks"]) >= 1
            assert "safety" in parsed
            assert "confidence" in parsed
            assert 0 <= parsed["confidence"] <= 100
            checked += 1
        if not checked:
            pytest.skip("No sample WAVs found")
# =============================================================================
# EVALUATION 3: Full pipeline end-to-end
# =============================================================================
class TestFullPipeline:
    """End-to-end: audio file -> features -> rules -> mock -> validate -> result."""

    @staticmethod
    def _require_sample(wav_name):
        """Return the path of *wav_name* inside ASSETS, skipping the test if absent."""
        path = os.path.join(ASSETS, wav_name)
        if not os.path.exists(path):
            pytest.skip("Sample not found")
        return path

    @staticmethod
    def _run_pipeline(path, appliance):
        """Run the whole chain for one file and return what ``validate`` produces."""
        features = extract_features(path)
        candidates = rank_candidates(features, appliance)
        prompt = build_diagnosis_prompt(features, candidates, appliance)
        raw = _mock_generate(prompt, candidates, features)
        return validate(raw, candidates)

    def test_bearing_pipeline(self):
        """Bearing sample: grounded bearing/drum fault, HIGH or CRITICAL, confidence >= 60."""
        path = self._require_sample("sample_washer_bearing.wav")
        result = self._run_pipeline(path, "Washing machine")
        assert isinstance(result, DiagnosisResult)
        assert result.grounded
        assert "bearing" in result.fault.lower() or "drum" in result.fault.lower()
        assert result.urgency in ("HIGH", "CRITICAL")
        assert result.confidence >= 60
        assert len(result.checks) >= 1

    def test_fan_pipeline(self):
        """Fan sample: grounded, not Inconclusive, confidence >= 50."""
        path = self._require_sample("sample_fan_imbalanced.wav")
        result = self._run_pipeline(path, "Electric fan")
        assert isinstance(result, DiagnosisResult)
        assert result.grounded
        assert "inconclusive" not in result.fault.lower()
        assert result.confidence >= 50

    def test_good_sample_pipeline(self):
        """Healthy washer sample: Inconclusive with LOW urgency."""
        path = self._require_sample("sample_washer_good.wav")
        result = self._run_pipeline(path, "Washing machine")
        assert isinstance(result, DiagnosisResult)
        assert result.fault.lower() == "inconclusive"
        assert result.urgency == "LOW"

    def test_compare_produces_different_results(self):
        """Comparing bearing vs good washer should show clear improvement."""
        # Guard both inputs like the other tests do, so that missing assets
        # lead to a skip instead of an error raised from extract_features.
        bearing_path = self._require_sample("sample_washer_bearing.wav")
        good_path = self._require_sample("sample_washer_good.wav")
        f_bearing = extract_features(bearing_path)
        f_good = extract_features(good_path)
        c_bearing = rank_candidates(f_bearing, "Washing machine")
        c_good = rank_candidates(f_good, "Washing machine")
        # Bearing should have high anomaly, good should have low
        assert f_bearing.anomaly_score > f_good.anomaly_score, \
            f"Bearing ({f_bearing.anomaly_score}) should be more anomalous than good ({f_good.anomaly_score})"
        # Bearing should fire rules, good should not (or fire fewer)
        assert c_bearing[0].name != "Inconclusive" or len(c_bearing) > len(c_good)
# =============================================================================
# EVALUATION 4: Rule engine correctness for all appliances
# =============================================================================
class TestRuleEngine:
    """Verify each appliance has rules and they fire correctly."""

    def test_all_appliances_have_rules(self):
        """Every supported appliance is a key of RULES with at least two rules."""
        required_appliances = (
            "Washing machine",
            "Tumble dryer",
            "Refrigerator/Freezer",
            "Electric fan",
            "Air conditioner",
            "Vacuum cleaner",
            "Dishwasher",
            "Microwave",
            "Electric motor (generic)",
            "Car engine",
            "Bicycle (chain/gears)",
            "Power drill",
        )
        for appliance in required_appliances:
            assert appliance in RULES, f"Missing rules for '{appliance}'"
            rule_set = RULES[appliance]
            assert len(rule_set) >= 2

    def test_typical_bad_input_fires_rules_for_every_appliance(self):
        """A 'typical bad' feature set should fire at least one rule per appliance."""
        faulty_values = {
            "duration_s": 8.0,
            "rms_db": -25.0,
            "rms_variance": 0.03,
            "zero_crossing_rate": 0.1,
            "spectral_centroid_hz": 2000,
            "spectral_bandwidth_hz": 2000,
            "spectral_rolloff_hz": 4500,
            "dominant_frequency_hz": 150.0,
            "harmonic_ratio": 0.5,
            "onset_rate_per_sec": 3.0,
            "has_regular_pattern": True,
            "pattern_interval_ms": 120.0,
            "peak_db": -18.0,
            "anomaly_score": 0.7,
        }
        faulty = AudioFeatures(**faulty_values)
        for appliance in RULES:
            ranked = rank_candidates(faulty, appliance)
            assert len(ranked) >= 1, f"No rules fired for {appliance}"

    def test_normal_input_returns_inconclusive(self):
        """A quiet, normal-sounding input should be Inconclusive for most appliances."""
        quiet_values = {
            "duration_s": 8.0,
            "rms_db": -45.0,
            "rms_variance": 0.002,
            "zero_crossing_rate": 0.02,
            "spectral_centroid_hz": 600,
            "spectral_bandwidth_hz": 500,
            "spectral_rolloff_hz": 1200,
            "dominant_frequency_hz": 50.0,
            "harmonic_ratio": 0.4,
            "onset_rate_per_sec": 0.1,
            "has_regular_pattern": False,
            "pattern_interval_ms": 0.0,
            "peak_db": -40.0,
            "anomaly_score": 0.05,
        }
        quiet = AudioFeatures(**quiet_values)
        # Only a representative subset of appliances is checked here.
        for appliance in ("Washing machine", "Electric fan", "Car engine"):
            ranked = rank_candidates(quiet, appliance)
            top_name = ranked[0].name
            assert top_name == "Inconclusive", \
                f"Normal input should be Inconclusive for {appliance}, got {ranked[0].name}"
# =============================================================================
# EVALUATION 5: Edge cases
# =============================================================================
class TestEdgeCases:
    """Defensive checks on degenerate inputs."""

    def test_empty_audio(self):
        """An all-zero 1600-sample clip at 16 kHz is diagnosed as Inconclusive."""
        import soundfile as sf

        # TemporaryDirectory replaces the deprecated, race-prone
        # tempfile.mktemp and cleans up even when writing the file fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "silence.wav")
            sf.write(path, np.zeros(1600, dtype="float32"), 16000)
            f = extract_features(path)
            cands = rank_candidates(f, "Washing machine")
            assert cands[0].name == "Inconclusive"

    def test_garbage_audio(self):
        """Two seconds of low-level Gaussian noise still yields at least one candidate."""
        import soundfile as sf

        # Seeded generator so the noise, and thus the test, is reproducible.
        rng = np.random.default_rng(0)
        noise = rng.standard_normal(22050 * 2).astype(np.float32) * 0.01
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "noise.wav")
            sf.write(path, noise, 22050)
            f = extract_features(path)
            cands = rank_candidates(f, "Electric fan")
            assert len(cands) >= 1

    def test_validate_malformed_json(self):
        """Non-JSON model output must still give a grounded result naming the top candidate."""
        f = AudioFeatures(
            duration_s=8.0, rms_db=-30.0, rms_variance=0.01,
            zero_crossing_rate=0.05, spectral_centroid_hz=500,
            spectral_bandwidth_hz=800, spectral_rolloff_hz=1500,
            dominant_frequency_hz=60.0, harmonic_ratio=0.5,
            onset_rate_per_sec=0.8, has_regular_pattern=False,
            pattern_interval_ms=0.0, peak_db=-24.0, anomaly_score=0.45,
        )
        cands = rank_candidates(f, "Electric fan")
        result = validate("not json at all", cands)
        assert result.grounded
        assert result.fault == cands[0].name

    def test_validate_ungrounded_output(self):
        """A fault name that is not among the candidates must not survive validation."""
        f = AudioFeatures(
            duration_s=8.0, rms_db=-25.0, rms_variance=0.03,
            zero_crossing_rate=0.08, spectral_centroid_hz=2200,
            spectral_bandwidth_hz=1800, spectral_rolloff_hz=4500,
            dominant_frequency_hz=180.0, harmonic_ratio=0.65,
            onset_rate_per_sec=3.5, has_regular_pattern=True,
            pattern_interval_ms=150.0, peak_db=-18.0, anomaly_score=0.75,
        )
        cands = rank_candidates(f, "Washing machine")
        response = json.dumps({
            "fault": "Exploding capacitor", "urgency": "CRITICAL",
            "checks": ["Check it"], "safety": "Unplug", "confidence": 95,
        })
        result = validate(response, cands)
        assert result.grounded
        assert result.fault != "Exploding capacitor"

    def test_candidates_always_returned(self):
        """All-zero / minimum-level features still give >= 1 candidate for every appliance."""
        extreme = AudioFeatures(
            duration_s=0.0, rms_db=-80.0, rms_variance=0.0,
            zero_crossing_rate=0.0, spectral_centroid_hz=0.0,
            spectral_bandwidth_hz=0.0, spectral_rolloff_hz=0.0,
            dominant_frequency_hz=0.0, harmonic_ratio=0.0,
            onset_rate_per_sec=0.0, has_regular_pattern=False,
            pattern_interval_ms=0.0, peak_db=-80.0, anomaly_score=0.0,
        )
        for appliance in RULES:
            cands = rank_candidates(extreme, appliance)
            assert len(cands) >= 1