Spaces:
Runtime error
Runtime error
| """Full evaluation suite — verifies every sample gets the RIGHT diagnosis. | |
| Run: python -m pytest tests/test_eval.py -v | |
| """ | |
| import os, sys, json, tempfile | |
| import numpy as np | |
| import pytest | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) | |
| from audio_analyzer import extract_features, AudioFeatures | |
| from fault_rules import rank_candidates, RULES | |
| from feature_prompt import build_diagnosis_prompt | |
| from json_guard import validate, DiagnosisResult | |
| from mock_model import mock_generate as _mock_generate | |
# Sample WAVs live in the ``assets`` directory one level above this file's
# directory (i.e. alongside ``tests/``).
ASSETS = os.path.join(os.path.dirname(os.path.dirname(__file__)), "assets")

# =============================================================================
# EVALUATION 1: Each sample gets correct diagnosis with correct appliance
# =============================================================================


def _spec(appliance, fault_keywords, urgencies, **weight_bounds):
    """Build one expectation record for ``SAMPLE_DIAGNOSES``.

    Args:
        appliance: Appliance label passed to ``rank_candidates``.
        fault_keywords: Substrings, any one of which must appear in the
            lower-cased name of the top-ranked fault.
        urgencies: Urgency labels the top-ranked candidate may carry.
        **weight_bounds: Optional ``expected_min_weight`` and/or
            ``expected_max_weight`` bounds on the top candidate's weight.

    Returns:
        A dict with the keys the sample-diagnosis tests read.
    """
    record = {
        "appliance": appliance,
        "expected_fault_contains": list(fault_keywords),
        "expected_urgency": list(urgencies),
    }
    record.update(weight_bounds)
    return record


# Expected top-ranked diagnosis for each bundled sample, keyed by file name.
SAMPLE_DIAGNOSES = {
    "sample_washer_bearing.wav": _spec(
        "Washing machine",
        ["bearing", "drum"],
        ["HIGH", "CRITICAL"],
        expected_min_weight=0.6,
    ),
    "sample_fan_imbalanced.wav": _spec(
        "Electric fan",
        ["imbalance", "blade"],
        ["MEDIUM", "HIGH"],
        expected_min_weight=0.5,
    ),
    "sample_motor_squeal.wav": _spec(
        "Electric motor (generic)",
        ["squeal", "whine", "bearing", "belt"],
        ["MEDIUM", "HIGH"],
        expected_min_weight=0.5,
    ),
    # A healthy washer is expected to land on the "Inconclusive" candidate.
    "sample_washer_good.wav": _spec(
        "Washing machine",
        ["inconclusive"],
        ["LOW"],
        expected_max_weight=0.1,
    ),
}
class TestSampleDiagnoses:
    """Each sample WAV should produce a specific, correct diagnosis."""

    # pytest supplies ``wav_name`` / ``expected`` from the expectation table:
    # one test case per sample, with the WAV file name as the case id.
    @pytest.mark.parametrize(
        "wav_name,expected",
        list(SAMPLE_DIAGNOSES.items()),
        ids=list(SAMPLE_DIAGNOSES),
    )
    def test_sample_diagnosis(self, wav_name, expected):
        """Check the top candidate's fault keyword, urgency and weight bounds.

        Skips when the sample file is not present under ``ASSETS``.
        """
        path = os.path.join(ASSETS, wav_name)
        if not os.path.exists(path):
            pytest.skip(f"Sample {wav_name} not found")
        features = extract_features(path)
        candidates = rank_candidates(features, expected["appliance"])
        assert len(candidates) >= 1
        top = candidates[0]
        # Check fault name contains expected keyword
        fault_lower = top.name.lower()
        assert any(kw in fault_lower for kw in expected["expected_fault_contains"]), \
            f"Expected fault containing {expected['expected_fault_contains']}, got '{top.name}'"
        # Check urgency
        assert top.urgency in expected["expected_urgency"], \
            f"Expected urgency in {expected['expected_urgency']}, got '{top.urgency}'"
        # Weight bounds are optional; each sample declares only the ones it needs.
        if "expected_min_weight" in expected:
            assert top.weight >= expected["expected_min_weight"], \
                f"Expected weight >= {expected['expected_min_weight']}, got {top.weight}"
        if "expected_max_weight" in expected:
            assert top.weight <= expected["expected_max_weight"], \
                f"Expected weight <= {expected['expected_max_weight']}, got {top.weight}"

    def test_all_four_samples_give_different_diagnoses(self):
        """The 4 samples should produce 3+ distinct fault names.

        Missing samples are ignored. The test skips when fewer than three
        samples are available, because three distinct faults would then be
        impossible and a failure would be meaningless.
        """
        faults = []
        for wav_name, info in SAMPLE_DIAGNOSES.items():
            path = os.path.join(ASSETS, wav_name)
            if not os.path.exists(path):
                continue
            features = extract_features(path)
            candidates = rank_candidates(features, info["appliance"])
            assert candidates, f"No candidates returned for {wav_name}"
            faults.append(candidates[0].name)
        if len(faults) < 3:
            pytest.skip(f"Need at least 3 samples to compare, found {len(faults)}")
        assert len(set(faults)) >= 3, f"Expected 3+ distinct faults, got: {faults}"
# =============================================================================
# EVALUATION 2: Mock generate uses correct candidate
# =============================================================================
class TestMockGenerate:
    """The mock should return the top candidate's fault name, not 'Inconclusive'."""

    @staticmethod
    def _sample_path(wav_name):
        """Return the path of an asset WAV, skipping the test if it is absent."""
        path = os.path.join(ASSETS, wav_name)
        if not os.path.exists(path):
            pytest.skip("Sample not found")
        return path

    @staticmethod
    def _run_mock(path, appliance):
        """Run features -> rules -> prompt -> mock model on one WAV file.

        Returns:
            ``(candidates, parsed)``: the ranked candidates and the mock's raw
            output decoded with ``json.loads`` (which raises on invalid JSON).
        """
        features = extract_features(path)
        candidates = rank_candidates(features, appliance)
        prompt = build_diagnosis_prompt(features, candidates, appliance)
        raw = _mock_generate(prompt, candidates, features)
        return candidates, json.loads(raw)

    def test_bearing_sample_returns_bearing(self):
        path = self._sample_path("sample_washer_bearing.wav")
        candidates, parsed = self._run_mock(path, "Washing machine")
        fault = parsed["fault"].lower()
        top = candidates[0].name.lower()
        # "" is a substring of every string, so an empty fault would satisfy
        # the containment check below without actually naming the fault.
        assert fault, "Mock returned an empty fault name"
        assert fault in top or top in fault, \
            f"Mock should return '{candidates[0].name}', got '{parsed['fault']}'"
        assert parsed["confidence"] >= 60, f"Bearing should have high confidence, got {parsed['confidence']}"

    def test_fan_sample_returns_fan_fault(self):
        path = self._sample_path("sample_fan_imbalanced.wav")
        _, parsed = self._run_mock(path, "Electric fan")
        assert "inconclusive" not in parsed["fault"].lower(), \
            f"Fan sample should NOT be Inconclusive, got '{parsed['fault']}'"

    def test_good_sample_returns_inconclusive(self):
        path = self._sample_path("sample_washer_good.wav")
        _, parsed = self._run_mock(path, "Washing machine")
        assert parsed["fault"].lower() == "inconclusive", \
            f"Good sample should be Inconclusive, got '{parsed['fault']}'"
        assert parsed["urgency"] == "LOW"

    def test_mock_always_returns_valid_json(self):
        """Every sample should produce parseable JSON with required fields."""
        required = {"fault", "urgency", "checks", "safety", "confidence"}
        checked = 0
        for wav_name, info in SAMPLE_DIAGNOSES.items():
            path = os.path.join(ASSETS, wav_name)
            if not os.path.exists(path):
                continue
            _, parsed = self._run_mock(path, info["appliance"])
            missing = required - parsed.keys()
            assert not missing, f"{wav_name}: missing fields {sorted(missing)}"
            assert len(parsed["checks"]) >= 1, f"{wav_name}: no checks returned"
            assert 0 <= parsed["confidence"] <= 100, \
                f"{wav_name}: confidence out of range: {parsed['confidence']}"
            checked += 1
        # Without this, a checkout with no assets would pass vacuously.
        if not checked:
            pytest.skip("No samples found")
# =============================================================================
# EVALUATION 3: Full pipeline end-to-end
# =============================================================================
class TestFullPipeline:
    """End-to-end: audio file -> features -> rules -> mock -> validate -> result."""

    @staticmethod
    def _sample_path(wav_name):
        """Return the path of an asset WAV, skipping the test if it is absent."""
        path = os.path.join(ASSETS, wav_name)
        if not os.path.exists(path):
            pytest.skip("Sample not found")
        return path

    @staticmethod
    def _run_pipeline(path, appliance):
        """Run the whole pipeline on one WAV and return ``validate``'s result."""
        features = extract_features(path)
        candidates = rank_candidates(features, appliance)
        prompt = build_diagnosis_prompt(features, candidates, appliance)
        raw = _mock_generate(prompt, candidates, features)
        return validate(raw, candidates)

    def test_bearing_pipeline(self):
        path = self._sample_path("sample_washer_bearing.wav")
        result = self._run_pipeline(path, "Washing machine")
        assert isinstance(result, DiagnosisResult)
        assert result.grounded
        assert "bearing" in result.fault.lower() or "drum" in result.fault.lower()
        assert result.urgency in ("HIGH", "CRITICAL")
        assert result.confidence >= 60
        assert len(result.checks) >= 1

    def test_fan_pipeline(self):
        path = self._sample_path("sample_fan_imbalanced.wav")
        result = self._run_pipeline(path, "Electric fan")
        assert isinstance(result, DiagnosisResult)
        assert result.grounded
        assert "inconclusive" not in result.fault.lower()
        assert result.confidence >= 50

    def test_good_sample_pipeline(self):
        path = self._sample_path("sample_washer_good.wav")
        result = self._run_pipeline(path, "Washing machine")
        assert isinstance(result, DiagnosisResult)
        assert result.fault.lower() == "inconclusive"
        assert result.urgency == "LOW"

    def test_compare_produces_different_results(self):
        """Comparing bearing vs good washer should show clear improvement."""
        # Resolve both paths first so a missing asset skips (like every other
        # sample-based test) instead of erroring inside extract_features.
        bearing_path = self._sample_path("sample_washer_bearing.wav")
        good_path = self._sample_path("sample_washer_good.wav")
        f_bearing = extract_features(bearing_path)
        f_good = extract_features(good_path)
        c_bearing = rank_candidates(f_bearing, "Washing machine")
        c_good = rank_candidates(f_good, "Washing machine")
        # Bearing should have high anomaly, good should have low
        assert f_bearing.anomaly_score > f_good.anomaly_score, \
            f"Bearing ({f_bearing.anomaly_score}) should be more anomalous than good ({f_good.anomaly_score})"
        # Bearing should fire rules, good should not (or fire fewer)
        assert c_bearing[0].name != "Inconclusive" or len(c_bearing) > len(c_good)
# =============================================================================
# EVALUATION 4: Rule engine correctness for all appliances
# =============================================================================
class TestRuleEngine:
    """Verify each appliance has rules and they fire correctly."""

    def test_all_appliances_have_rules(self):
        expected = [
            "Washing machine", "Tumble dryer", "Refrigerator/Freezer",
            "Electric fan", "Air conditioner", "Vacuum cleaner",
            "Dishwasher", "Microwave", "Electric motor (generic)",
            "Car engine", "Bicycle (chain/gears)", "Power drill",
        ]
        for appliance in expected:
            assert appliance in RULES, f"Missing rules for '{appliance}'"
            assert len(RULES[appliance]) >= 2, \
                f"Expected >= 2 rules for '{appliance}', got {len(RULES[appliance])}"

    def test_typical_bad_input_fires_rules_for_every_appliance(self):
        """A 'typical bad' feature set should fire at least one rule per appliance.

        ``rank_candidates`` returns an "Inconclusive" candidate when no rule
        matches (the normal-input test below relies on this), so a non-empty
        list on its own does not show that any rule fired. Require at least
        one candidate other than "Inconclusive".
        """
        bad = AudioFeatures(
            duration_s=8.0, rms_db=-25.0, rms_variance=0.03,
            zero_crossing_rate=0.1, spectral_centroid_hz=2000,
            spectral_bandwidth_hz=2000, spectral_rolloff_hz=4500,
            dominant_frequency_hz=150.0, harmonic_ratio=0.5,
            onset_rate_per_sec=3.0, has_regular_pattern=True,
            pattern_interval_ms=120.0, peak_db=-18.0, anomaly_score=0.7,
        )
        for appliance in RULES:
            cands = rank_candidates(bad, appliance)
            assert len(cands) >= 1, f"No candidates returned for {appliance}"
            assert any(c.name != "Inconclusive" for c in cands), \
                f"No rules fired for {appliance}: {[c.name for c in cands]}"

    def test_normal_input_returns_inconclusive(self):
        """A quiet, normal-sounding input should be Inconclusive for most appliances."""
        normal = AudioFeatures(
            duration_s=8.0, rms_db=-45.0, rms_variance=0.002,
            zero_crossing_rate=0.02, spectral_centroid_hz=600,
            spectral_bandwidth_hz=500, spectral_rolloff_hz=1200,
            dominant_frequency_hz=50.0, harmonic_ratio=0.4,
            onset_rate_per_sec=0.1, has_regular_pattern=False,
            pattern_interval_ms=0.0, peak_db=-40.0, anomaly_score=0.05,
        )
        for appliance in ["Washing machine", "Electric fan", "Car engine"]:
            cands = rank_candidates(normal, appliance)
            assert cands[0].name == "Inconclusive", \
                f"Normal input should be Inconclusive for {appliance}, got {cands[0].name}"
# =============================================================================
# EVALUATION 5: Edge cases
# =============================================================================
class TestEdgeCases:
    """Defensive checks on degenerate inputs."""

    def test_empty_audio(self):
        """0.1 s of digital silence (1600 zero samples at 16 kHz) is Inconclusive."""
        import soundfile as sf
        # TemporaryDirectory avoids the race in the deprecated tempfile.mktemp
        # and removes the file even if writing or analysis raises.
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "silence.wav")
            sf.write(path, np.zeros(1600, dtype="float32"), 16000)
            f = extract_features(path)
            cands = rank_candidates(f, "Washing machine")
            assert cands[0].name == "Inconclusive"

    def test_garbage_audio(self):
        """2 s of low-level Gaussian noise at 22.05 kHz still yields a candidate."""
        import soundfile as sf
        # A fixed seed makes the noise, and therefore the test, reproducible.
        rng = np.random.default_rng(0)
        noise = rng.standard_normal(22050 * 2).astype(np.float32) * 0.01
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "noise.wav")
            sf.write(path, noise, 22050)
            f = extract_features(path)
            cands = rank_candidates(f, "Electric fan")
            assert len(cands) >= 1

    def test_validate_malformed_json(self):
        """Non-JSON model output falls back to the top-ranked candidate."""
        f = AudioFeatures(
            duration_s=8.0, rms_db=-30.0, rms_variance=0.01,
            zero_crossing_rate=0.05, spectral_centroid_hz=500,
            spectral_bandwidth_hz=800, spectral_rolloff_hz=1500,
            dominant_frequency_hz=60.0, harmonic_ratio=0.5,
            onset_rate_per_sec=0.8, has_regular_pattern=False,
            pattern_interval_ms=0.0, peak_db=-24.0, anomaly_score=0.45,
        )
        cands = rank_candidates(f, "Electric fan")
        result = validate("not json at all", cands)
        assert result.grounded
        assert result.fault == cands[0].name

    def test_validate_ungrounded_output(self):
        """A fault name not among the candidates must not be passed through."""
        f = AudioFeatures(
            duration_s=8.0, rms_db=-25.0, rms_variance=0.03,
            zero_crossing_rate=0.08, spectral_centroid_hz=2200,
            spectral_bandwidth_hz=1800, spectral_rolloff_hz=4500,
            dominant_frequency_hz=180.0, harmonic_ratio=0.65,
            onset_rate_per_sec=3.5, has_regular_pattern=True,
            pattern_interval_ms=150.0, peak_db=-18.0, anomaly_score=0.75,
        )
        cands = rank_candidates(f, "Washing machine")
        response = json.dumps({
            "fault": "Exploding capacitor", "urgency": "CRITICAL",
            "checks": ["Check it"], "safety": "Unplug", "confidence": 95,
        })
        result = validate(response, cands)
        assert result.grounded
        assert result.fault != "Exploding capacitor"

    def test_candidates_always_returned(self):
        """All-zero features still yield at least one candidate for every appliance."""
        extreme = AudioFeatures(
            duration_s=0.0, rms_db=-80.0, rms_variance=0.0,
            zero_crossing_rate=0.0, spectral_centroid_hz=0.0,
            spectral_bandwidth_hz=0.0, spectral_rolloff_hz=0.0,
            dominant_frequency_hz=0.0, harmonic_ratio=0.0,
            onset_rate_per_sec=0.0, has_regular_pattern=False,
            pattern_interval_ms=0.0, peak_db=-80.0, anomaly_score=0.0,
        )
        for appliance in RULES:
            cands = rank_candidates(extreme, appliance)
            assert len(cands) >= 1