Spaces:
Running on Zero
Running on Zero
| """Integration tests for session-based API endpoints. | |
| Requires the app to be running (localhost or live Space). | |
| Start locally with: python app.py | |
| Run with: python -m pytest tests/test_session_api.py -v -s | |
| """ | |
| import os | |
| import pytest | |
| from gradio_client import Client, handle_file | |
SERVER_URL = os.environ.get("TEST_SERVER_URL", "https://hetchyy-quran-multi-aligner.hf.space")


def _resolve_audio_path(rel_path):
    """Return a usable path for the test audio file *rel_path*.

    The path is first tried relative to the current working directory, which
    matches the run command in the module docstring (repo root as CWD). If it
    is not found there, it is resolved against the parent of this file's
    directory. That assumes this file lives in ``<repo>/tests/``, as the
    docstring's ``tests/test_session_api.py`` suggests (NOTE(review): confirm
    the layout).

    This matters because ``handle_file`` below runs at import time, so a
    CWD-dependent path can break the whole module rather than a single test.
    """
    if os.path.exists(rel_path):
        return rel_path
    repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    return os.path.join(repo_root, rel_path)


_AUDIO_PATH = _resolve_audio_path("data/112.mp3")  # Surah Al-Ikhlas (~15s)
AUDIO_FILE = handle_file(_AUDIO_PATH)
# 32-character id that is not expected to correspond to any live session.
FAKE_ID = "00000000000000000000000000000000"
| # --------------------------------------------------------------------------- | |
| # Fixtures | |
| # --------------------------------------------------------------------------- | |
@pytest.fixture(scope="module")
def client():
    """Return a gradio ``Client`` for ``SERVER_URL``, shared by every test in the module.

    Without the ``@pytest.fixture`` decorator, pytest cannot resolve the
    ``client`` argument that the tests below request.
    """
    return Client(SERVER_URL)
@pytest.fixture(scope="module")
def session(client):
    """Run process_audio_session once, share audio_id across tests.

    ``scope="module"`` is what makes this a single server call for the whole
    module. Later tests mutate the server-side session (resegment,
    retranscribe, realign), but the dict returned here keeps the original
    processing result, including its ``segments`` list.
    """
    result = client.predict(
        AUDIO_FILE, 200, 1000, 100, "Base", "CPU",
        api_name="/process_audio_session",
    )
    assert "audio_id" in result, f"Missing audio_id: {result}"
    assert result["audio_id"] is not None
    return result
| # --------------------------------------------------------------------------- | |
| # 1. process_audio_session | |
| # --------------------------------------------------------------------------- | |
class TestProcessAudioSession:
    """Checks on the response produced by ``/process_audio_session``."""

    def test_creates_session(self, session):
        """A new session yields segments and a 32-character string audio_id."""
        assert len(session["segments"]) > 0, "Expected at least one segment"
        audio_id = session["audio_id"]
        assert isinstance(audio_id, str) and len(audio_id) == 32

    def test_all_response_fields_present(self, session):
        """The first segment exposes every expected response key."""
        first = session["segments"][0]
        expected_fields = (
            "segment", "time_from", "time_to", "ref_from", "ref_to",
            "matched_text", "confidence", "has_missing_words", "error",
        )
        for name in expected_fields:
            assert name in first, f"Missing field: {name}"

    def test_segment_field_types(self, session):
        """Typed fields of the first segment have the expected types and ranges."""
        first = session["segments"][0]
        number = (int, float)
        assert isinstance(first["segment"], int)
        assert isinstance(first["time_from"], number)
        assert isinstance(first["time_to"], number)
        assert isinstance(first["confidence"], number)
        assert 0 <= first["confidence"] <= 1
        assert isinstance(first["has_missing_words"], bool)

    def test_segments_ordered(self, session):
        """Segment numbers are returned in non-decreasing order."""
        numbers = [entry["segment"] for entry in session["segments"]]
        assert sorted(numbers) == numbers

    def test_time_ordering(self, session):
        """Each segment starts at or after 0 and ends strictly after it starts."""
        for entry in session["segments"]:
            start, end = entry["time_from"], entry["time_to"]
            assert start >= 0
            assert end > start
| # --------------------------------------------------------------------------- | |
| # 2. resegment_session | |
| # --------------------------------------------------------------------------- | |
class TestResegmentSession:
    """Tests for ``/resegment_session`` on the shared session."""

    def test_resegment_basic(self, client, session):
        """Resegmenting echoes the audio_id and returns a non-empty segment list."""
        audio_id = session["audio_id"]
        response = client.predict(
            audio_id, 600, 1500, 300, "Base", "CPU",
            api_name="/resegment_session",
        )
        assert response["audio_id"] == audio_id
        assert "segments" in response
        assert len(response["segments"]) > 0

    def test_resegment_merges_with_high_silence(self, client, session):
        """min_silence=2000ms should produce fewer (or equal) segments.

        The baseline is the segment count from the fixture's initial
        processing call, which used 200 as the first numeric parameter.
        """
        baseline = len(session["segments"])
        response = client.predict(
            session["audio_id"], 2000, 500, 100, "Base", "CPU",
            api_name="/resegment_session",
        )
        assert len(response["segments"]) <= baseline

    def test_resegment_updates_session(self, client, session):
        """A Base resegment leaves the session on model=Base, so an immediate
        Base retranscribe is rejected by the same-model guard."""
        audio_id = session["audio_id"]
        client.predict(
            audio_id, 400, 1000, 150, "Base", "CPU",
            api_name="/resegment_session",
        )
        # Resegment already re-ran ASR with Base and stored the new boundaries,
        # so nothing has changed from the guard's point of view.
        response = client.predict(
            audio_id, "Base", "CPU",
            api_name="/retranscribe_session",
        )
        assert "error" in response
        assert response["segments"] == []
| # --------------------------------------------------------------------------- | |
| # 3. retranscribe_session | |
| # --------------------------------------------------------------------------- | |
class TestRetranscribeSession:
    """Tests for ``/retranscribe_session`` and its same-model guard.

    As asserted below, a retranscribe is rejected with an ``error`` and an
    empty ``segments`` list when the session already holds the same model
    for the same boundaries. Switching to a different model is accepted.
    """

    def test_retranscribe_different_model(self, client, session):
        """Switching from Base to Large is accepted and returns segments."""
        aid = session["audio_id"]
        # First ensure we're on Base by resegmenting.
        client.predict(
            aid, 200, 1000, 100, "Base", "CPU",
            api_name="/resegment_session",
        )
        result = client.predict(
            aid, "Large", "CPU",
            api_name="/retranscribe_session",
        )
        assert result["audio_id"] == aid
        assert len(result["segments"]) > 0

    def test_retranscribe_guard_same_model(self, client, session):
        """Same model + same boundaries -> error.

        The test establishes its own precondition (a Base resegment followed
        by a Large retranscribe). Previously it relied on
        test_retranscribe_different_model having run first, so it failed when
        selected alone (``-k``) or run in a different order.
        """
        aid = session["audio_id"]
        client.predict(
            aid, 200, 1000, 100, "Base", "CPU",
            api_name="/resegment_session",
        )
        first = client.predict(
            aid, "Large", "CPU",
            api_name="/retranscribe_session",
        )
        assert len(first["segments"]) > 0
        # Repeat with the model the session now holds; the guard must fire.
        result = client.predict(
            aid, "Large", "CPU",
            api_name="/retranscribe_session",
        )
        assert "error" in result
        assert result["segments"] == []

    def test_retranscribe_allowed_after_resegment(self, client, session):
        """A resegment stores the model it ran with.

        Retranscribing with that same model is therefore rejected by the
        guard, while retranscribing with a different model is allowed.
        """
        aid = session["audio_id"]
        # Resegment with different params, using Large.
        client.predict(
            aid, 300, 1200, 200, "Large", "CPU",
            api_name="/resegment_session",
        )
        # Same model as the resegment used: guard triggers.
        result = client.predict(
            aid, "Large", "CPU",
            api_name="/retranscribe_session",
        )
        assert "error" in result
        assert result["segments"] == []
        # But switching model works.
        result2 = client.predict(
            aid, "Base", "CPU",
            api_name="/retranscribe_session",
        )
        assert len(result2["segments"]) > 0
| # --------------------------------------------------------------------------- | |
| # 4. realign_from_timestamps | |
| # --------------------------------------------------------------------------- | |
class TestRealignFromTimestamps:
    """Tests for ``/realign_from_timestamps`` with caller-supplied boundaries."""

    def test_custom_timestamps(self, client, session):
        """Three explicit intervals yield exactly three aligned segments."""
        audio_id = session["audio_id"]
        intervals = [
            {"start": start, "end": end}
            for start, end in ((0.5, 3.0), (3.5, 6.0), (6.5, 10.0))
        ]
        response = client.predict(
            audio_id, intervals, "Base", "CPU",
            api_name="/realign_from_timestamps",
        )
        assert response["audio_id"] == audio_id
        assert len(response["segments"]) == 3

    def test_realign_updates_boundaries(self, client, session):
        """After a Base realign, a Base retranscribe hits the guard, but a
        Large retranscribe is accepted."""
        audio_id = session["audio_id"]
        intervals = [{"start": 0.5, "end": 4.0}, {"start": 4.5, "end": 9.0}]
        client.predict(
            audio_id, intervals, "Base", "CPU",
            api_name="/realign_from_timestamps",
        )

        same_model = client.predict(
            audio_id, "Base", "CPU",
            api_name="/retranscribe_session",
        )
        assert "error" in same_model

        other_model = client.predict(
            audio_id, "Large", "CPU",
            api_name="/retranscribe_session",
        )
        assert len(other_model["segments"]) > 0
| # --------------------------------------------------------------------------- | |
| # 5. Consecutive calls / full workflow | |
| # --------------------------------------------------------------------------- | |
class TestWorkflow:
    """Chains several session endpoints against the same audio_id."""

    def test_consecutive_resegments(self, client, session):
        """Two back-to-back resegments with different parameters both succeed.

        Segment counts usually differ between the two parameter sets, but only
        success is asserted.
        """
        param_sets = ((200, 1000, 100), (600, 1500, 300))
        responses = [
            client.predict(
                session["audio_id"], *params, "Base", "CPU",
                api_name="/resegment_session",
            )
            for params in param_sets
        ]
        for response in responses:
            assert len(response["segments"]) > 0

    def test_full_workflow(self, client, session):
        """Resegment -> retranscribe (Large) -> realign -> resegment on one session."""
        audio_id = session["audio_id"]

        # Step 1: resegment with Base.
        step1 = client.predict(
            audio_id, 200, 1000, 100, "Base", "CPU",
            api_name="/resegment_session",
        )
        assert len(step1["segments"]) > 0

        # Step 2: retranscribe with a model different from step 1's.
        step2 = client.predict(
            audio_id, "Large", "CPU",
            api_name="/retranscribe_session",
        )
        assert len(step2["segments"]) > 0

        # Step 3: realign onto explicit intervals, one segment per interval.
        intervals = [{"start": 0.5, "end": 5.0}, {"start": 5.5, "end": 10.0}]
        step3 = client.predict(
            audio_id, intervals, "Base", "CPU",
            api_name="/realign_from_timestamps",
        )
        assert len(step3["segments"]) == len(intervals)

        # Step 4: the session is still usable for another resegment.
        step4 = client.predict(
            audio_id, 400, 1200, 150, "Base", "CPU",
            api_name="/resegment_session",
        )
        assert len(step4["segments"]) > 0
# ---------------------------------------------------------------------------
# 6. (intentionally empty — error-handling tests are in section 10 below)
# ---------------------------------------------------------------------------
| # --------------------------------------------------------------------------- | |
| # 7. MFA timestamps — session-based | |
| # --------------------------------------------------------------------------- | |
class TestMfaTimestampsSession:
    """Tests for ``/mfa_timestamps_session`` (forced-alignment timestamps for a session).

    As asserted below, each word entry is ``[location, start, end]``. With
    ``"words+chars"`` granularity it gains a fourth element, a list of
    ``[char, start, end]`` letter entries.
    """

    def test_basic_words_only(self, client, session):
        """Session endpoint with stored segments, words granularity."""
        result = client.predict(
            session["audio_id"], None, "words",
            api_name="/mfa_timestamps_session",
        )
        assert result["audio_id"] == session["audio_id"]
        assert len(result["segments"]) > 0
        has_words = any("words" in seg for seg in result["segments"])
        assert has_words, "Expected at least one segment with words"
        # Words-only: each word is [location, start, end] (3 elements).
        for seg in result["segments"]:
            for word in seg.get("words", []):
                assert len(word) == 3, f"words granularity should give 3-element arrays, got {len(word)}"

    def test_words_plus_chars(self, client, session):
        """Session endpoint with words+chars granularity."""
        result = client.predict(
            session["audio_id"], None, "words+chars",
            api_name="/mfa_timestamps_session",
        )
        has_letters = any(
            len(word) == 4
            for seg in result["segments"]
            for word in seg.get("words", [])
        )
        assert has_letters, "words+chars should include letter arrays (4th element)"

    def test_with_segments_override(self, client, session):
        """Session endpoint with explicit segments (override stored).

        The override is the first (up to) two fixture segments. The expected
        count is derived from the override itself, because the audio is not
        guaranteed to produce at least two segments.
        """
        segments_override = session["segments"][:2]
        assert segments_override, "Expected at least one segment to use as override"
        result = client.predict(
            session["audio_id"], segments_override, "words",
            api_name="/mfa_timestamps_session",
        )
        assert result["audio_id"] == session["audio_id"]
        assert len(result["segments"]) == len(segments_override)

    def test_word_timestamp_fields(self, client, session):
        """Verify word arrays have correct structure: [location, start, end, ?letters].

        Also requires at least one word to be present, so the structural
        checks cannot pass vacuously on an empty response.
        """
        result = client.predict(
            session["audio_id"], None, "words+chars",
            api_name="/mfa_timestamps_session",
        )
        checked = 0
        for seg in result["segments"]:
            for word in seg.get("words", []):
                checked += 1
                assert isinstance(word[0], str), "word[0] should be location string"
                assert isinstance(word[1], (int, float)), "word[1] should be start time"
                assert isinstance(word[2], (int, float)), "word[2] should be end time"
                assert word[2] > word[1], "end should be > start"
                if len(word) == 4:
                    # Letters: list of [char, start, end]
                    for letter in word[3]:
                        assert len(letter) == 3
                        assert isinstance(letter[0], str)
        assert checked > 0, "Expected at least one word to validate"

    def test_invalid_session(self, client):
        """An unknown audio_id yields an error and no segments."""
        result = client.predict(
            FAKE_ID, None, "words",
            api_name="/mfa_timestamps_session",
        )
        assert "error" in result
        assert result["segments"] == []

    def test_default_granularity(self, client, session):
        """Empty granularity should default to words."""
        result = client.predict(
            session["audio_id"], None, "",
            api_name="/mfa_timestamps_session",
        )
        assert len(result["segments"]) > 0
        for seg in result["segments"]:
            for word in seg.get("words", []):
                assert len(word) == 3, "default granularity should not include letters"
| # --------------------------------------------------------------------------- | |
| # 8. MFA timestamps — direct | |
| # --------------------------------------------------------------------------- | |
class TestMfaTimestampsDirect:
    """Tests for ``/mfa_timestamps_direct`` (audio file plus explicit segments, no session)."""

    @staticmethod
    def _run(client, segments, granularity):
        """Call the direct MFA endpoint with the shared test audio."""
        return client.predict(
            AUDIO_FILE, segments, granularity,
            api_name="/mfa_timestamps_direct",
        )

    def test_basic(self, client, session):
        """Direct endpoint with audio file and segments."""
        response = self._run(client, session["segments"], "words")
        assert "segments" in response
        assert len(response["segments"]) > 0
        assert any("words" in seg for seg in response["segments"])

    def test_words_plus_chars(self, client, session):
        """words+chars granularity yields at least one 4-element word entry."""
        response = self._run(client, session["segments"], "words+chars")
        word_lengths = (
            len(word)
            for seg in response["segments"]
            for word in seg.get("words", [])
        )
        assert any(length == 4 for length in word_lengths)

    def test_no_audio_id_in_response(self, client, session):
        """Direct endpoint should not return audio_id."""
        response = self._run(client, session["segments"], "words")
        assert "audio_id" not in response

    def test_empty_segments_error(self, client):
        """An empty segment list is rejected with an error and no segments."""
        response = self._run(client, [], "words")
        assert "error" in response
        assert response["segments"] == []
| # --------------------------------------------------------------------------- | |
| # 9. Segments stored in session after alignment | |
| # --------------------------------------------------------------------------- | |
class TestSegmentStorage:
    """Verify that processing persists segments in the server-side session."""

    def test_segments_stored_after_process(self, client):
        """process_audio_session should store segments for later MFA use.

        Uses a fresh session rather than the shared fixture, so no other test
        has resegmented or realigned it.
        """
        proc = client.predict(
            AUDIO_FILE, 200, 1000, 100, "Base", "CPU",
            api_name="/process_audio_session",
        )
        # Same preconditions as the ``session`` fixture: fail with a clear
        # message instead of a bare KeyError if no session was created.
        assert "audio_id" in proc, f"Missing audio_id: {proc}"
        assert proc["audio_id"] is not None
        # MFA session endpoint should find stored segments. The check is
        # deliberately permissive: it fails only when the response has an
        # error AND no segments (presumably the "nothing stored" case).
        result = client.predict(
            proc["audio_id"], None, "words",
            api_name="/mfa_timestamps_session",
        )
        assert "error" not in result or result.get("segments")
        assert result["audio_id"] == proc["audio_id"]
| # --------------------------------------------------------------------------- | |
| # 10. Error handling | |
| # --------------------------------------------------------------------------- | |
class TestErrorHandling:
    """Session endpoints must answer an unknown audio_id with an error payload."""

    def test_invalid_audio_id_retranscribe(self, client):
        """Unknown id -> error mentioning 'not found' or 'expired', and no segments."""
        response = client.predict(
            FAKE_ID, "Base", "CPU",
            api_name="/retranscribe_session",
        )
        assert "error" in response
        message = response["error"].lower()
        assert "not found" in message or "expired" in message
        assert response["segments"] == []

    def test_invalid_audio_id_resegment(self, client):
        """Unknown id on resegment -> error and no segments."""
        response = client.predict(
            FAKE_ID, 200, 1000, 100, "Base", "CPU",
            api_name="/resegment_session",
        )
        assert "error" in response
        assert response["segments"] == []

    def test_invalid_audio_id_realign(self, client):
        """Unknown id on realign -> error and no segments."""
        one_interval = [{"start": 0.0, "end": 1.0}]
        response = client.predict(
            FAKE_ID, one_interval, "Base", "CPU",
            api_name="/realign_from_timestamps",
        )
        assert "error" in response
        assert response["segments"] == []