"""Integration tests for session-based API endpoints.

Requires the app to be running (localhost or live Space).
Start locally with:  python app.py
Run with:  python -m pytest tests/test_session_api.py -v -s
"""

import os

import pytest
from gradio_client import Client, handle_file

# Target server; override with the TEST_SERVER_URL environment variable.
SERVER_URL = os.environ.get(
    "TEST_SERVER_URL",
    "https://hetchyy-quran-multi-aligner.hf.space",
)
_AUDIO_PATH = "data/112.mp3"  # Surah Al-Ikhlas (~15s)
AUDIO_FILE = handle_file(_AUDIO_PATH)
# Well-formed (32 characters) id used by the invalid-session tests.
FAKE_ID = "00000000000000000000000000000000"


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

@pytest.fixture(scope="module")
def client():
    """Return one gradio ``Client`` bound to ``SERVER_URL`` for the module."""
    return Client(SERVER_URL)


@pytest.fixture(scope="module")
def session(client):
    """Run process_audio_session once, share audio_id across tests.

    Returns the raw response dict of ``/process_audio_session``; the fixture
    fails if the response has no usable ``audio_id``.
    """
    response = client.predict(
        AUDIO_FILE,
        200,
        1000,
        100,
        "Base",
        "CPU",
        api_name="/process_audio_session",
    )
    assert "audio_id" in response, f"Missing audio_id: {response}"
    assert response["audio_id"] is not None
    return response


# ---------------------------------------------------------------------------
# 1.
# process_audio_session
# ---------------------------------------------------------------------------

class TestProcessAudioSession:
    """Checks on the response held by the ``session`` fixture.

    Every test here only reads the fixture's response dict; none of them
    issues a request of its own.
    """

    def test_creates_session(self, session):
        """The response has at least one segment and a 32-character str id."""
        segments = session["segments"]
        assert len(segments) > 0, "Expected at least one segment"
        audio_id = session["audio_id"]
        assert isinstance(audio_id, str) and len(audio_id) == 32

    def test_all_response_fields_present(self, session):
        """The first segment exposes every expected key."""
        first = session["segments"][0]
        expected_keys = (
            "segment",
            "time_from",
            "time_to",
            "ref_from",
            "ref_to",
            "matched_text",
            "confidence",
            "has_missing_words",
            "error",
        )
        for field in expected_keys:
            assert field in first, f"Missing field: {field}"

    def test_segment_field_types(self, session):
        """The first segment's values have the expected types and ranges."""
        first = session["segments"][0]
        number = (int, float)
        assert isinstance(first["segment"], int)
        for key in ("time_from", "time_to", "confidence"):
            assert isinstance(first[key], number)
        assert 0 <= first["confidence"] <= 1
        assert isinstance(first["has_missing_words"], bool)

    def test_segments_ordered(self, session):
        """Segment numbers are returned in non-decreasing order."""
        numbers = [seg["segment"] for seg in session["segments"]]
        assert numbers == sorted(numbers)

    def test_time_ordering(self, session):
        """Each segment starts at or after zero and ends after it starts."""
        for seg in session["segments"]:
            start = seg["time_from"]
            assert start >= 0
            assert seg["time_to"] > start


# ---------------------------------------------------------------------------
# 2.
# resegment_session
# ---------------------------------------------------------------------------

class TestResegmentSession:
    """Tests for ``/resegment_session`` against the shared session."""

    def test_resegment_basic(self, client, session):
        """Resegmenting keeps the audio_id and returns a non-empty list."""
        audio_id = session["audio_id"]
        response = client.predict(
            audio_id,
            600,
            1500,
            300,
            "Base",
            "CPU",
            api_name="/resegment_session",
        )
        assert response["audio_id"] == audio_id
        assert "segments" in response
        assert len(response["segments"]) > 0

    def test_resegment_merges_with_high_silence(self, client, session):
        """min_silence=2000ms should produce fewer (or equal) segments."""
        baseline = len(session["segments"])
        response = client.predict(
            session["audio_id"],
            2000,
            500,
            100,
            "Base",
            "CPU",
            api_name="/resegment_session",
        )
        assert len(response["segments"]) <= baseline

    def test_resegment_updates_session(self, client, session):
        """After resegment, retranscribe with same model should still
        trigger guard (resegment already re-ran ASR with that model)."""
        audio_id = session["audio_id"]

        # Step 1: resegment with the Base model; the response is not needed.
        client.predict(
            audio_id,
            400,
            1000,
            150,
            "Base",
            "CPU",
            api_name="/resegment_session",
        )

        # Step 2: retranscribe with the same model. The guard is expected to
        # trigger because resegment already stored model=Base and the new
        # intervals_hash.
        guarded = client.predict(
            audio_id,
            "Base",
            "CPU",
            api_name="/retranscribe_session",
        )
        assert "error" in guarded
        assert guarded["segments"] == []


# ---------------------------------------------------------------------------
# 3.
# retranscribe_session
# ---------------------------------------------------------------------------

class TestRetranscribeSession:
    """Tests for ``/retranscribe_session`` and its same-model guard.

    The guard, as exercised below: retranscribing with the model that
    already produced the current boundaries returns an ``error`` and an
    empty ``segments`` list; switching to another model is accepted.
    """

    def test_retranscribe_different_model(self, client, session):
        """Switching from Base to Large returns fresh segments."""
        # First ensure we're on Base by resegmenting
        client.predict(
            session["audio_id"], 200, 1000, 100, "Base", "CPU",
            api_name="/resegment_session",
        )
        result = client.predict(
            session["audio_id"], "Large", "CPU",
            api_name="/retranscribe_session",
        )
        assert result["audio_id"] == session["audio_id"]
        assert len(result["segments"]) > 0

    def test_retranscribe_guard_same_model(self, client, session):
        """Same model + same boundaries -> error."""
        aid = session["audio_id"]
        # Prime the session onto "Large" so this test no longer depends on
        # test_retranscribe_different_model having run immediately before it.
        # The priming response is deliberately ignored: it is either a fresh
        # transcription or the guard error, and in both cases the session
        # ends up holding Large for the current boundaries.
        client.predict(
            aid, "Large", "CPU",
            api_name="/retranscribe_session",
        )
        result = client.predict(
            aid, "Large", "CPU",
            api_name="/retranscribe_session",
        )
        assert "error" in result
        assert result["segments"] == []

    def test_retranscribe_allowed_after_resegment(self, client, session):
        """Resegment changes boundaries, so retranscribe with same model
        should also trigger guard (resegment stores same model).

        Despite the test's name, the same-model call is expected to be
        rejected; only the switch to a different model is allowed.
        """
        # Resegment with different params
        client.predict(
            session["audio_id"], 300, 1200, 200, "Large", "CPU",
            api_name="/resegment_session",
        )
        # Same model as resegment used — guard triggers
        result = client.predict(
            session["audio_id"], "Large", "CPU",
            api_name="/retranscribe_session",
        )
        assert "error" in result
        # But switching model works
        result2 = client.predict(
            session["audio_id"], "Base", "CPU",
            api_name="/retranscribe_session",
        )
        assert len(result2["segments"]) > 0


# ---------------------------------------------------------------------------
# 4.
# realign_from_timestamps
# ---------------------------------------------------------------------------

class TestRealignFromTimestamps:
    """Tests for ``/realign_from_timestamps`` with caller-supplied windows."""

    def test_custom_timestamps(self, client, session):
        """Three supplied windows come back as exactly three segments."""
        audio_id = session["audio_id"]
        windows = [
            {"start": 0.5, "end": 3.0},
            {"start": 3.5, "end": 6.0},
            {"start": 6.5, "end": 10.0},
        ]
        response = client.predict(
            audio_id,
            windows,
            "Base",
            "CPU",
            api_name="/realign_from_timestamps",
        )
        assert response["audio_id"] == audio_id
        assert len(response["segments"]) == 3

    def test_realign_updates_boundaries(self, client, session):
        """After realign with Base, retranscribe with same model triggers
        guard, but switching model works."""
        audio_id = session["audio_id"]
        windows = [
            {"start": 0.5, "end": 4.0},
            {"start": 4.5, "end": 9.0},
        ]
        client.predict(
            audio_id,
            windows,
            "Base",
            "CPU",
            api_name="/realign_from_timestamps",
        )

        # Same model: the guard is expected to reject the call.
        same_model = client.predict(
            audio_id,
            "Base",
            "CPU",
            api_name="/retranscribe_session",
        )
        assert "error" in same_model

        # Different model: the call is expected to go through.
        other_model = client.predict(
            audio_id,
            "Large",
            "CPU",
            api_name="/retranscribe_session",
        )
        assert len(other_model["segments"]) > 0


# ---------------------------------------------------------------------------
# 5. Consecutive calls / full workflow
# ---------------------------------------------------------------------------

class TestWorkflow:
    """Multi-step sequences that reuse one session across endpoints."""

    def test_consecutive_resegments(self, client, session):
        """Two back-to-back resegment calls both return segments."""
        audio_id = session["audio_id"]
        first = client.predict(
            audio_id,
            200,
            1000,
            100,
            "Base",
            "CPU",
            api_name="/resegment_session",
        )
        second = client.predict(
            audio_id,
            600,
            1500,
            300,
            "Base",
            "CPU",
            api_name="/resegment_session",
        )
        # Different params usually yield different segment counts, but that
        # is not guaranteed, so only success of both calls is asserted.
        assert len(first["segments"]) > 0
        assert len(second["segments"]) > 0

    def test_full_workflow(self, client, session):
        """Resegment, retranscribe, realign and resegment again in one go."""
        aid = session["audio_id"]

        # 1. Resegment
        resegmented = client.predict(
            aid,
            200,
            1000,
            100,
            "Base",
            "CPU",
            api_name="/resegment_session",
        )
        assert len(resegmented["segments"]) > 0

        # 2. Retranscribe with different model
        retranscribed = client.predict(
            aid,
            "Large",
            "CPU",
            api_name="/retranscribe_session",
        )
        assert len(retranscribed["segments"]) > 0

        # 3. Realign with custom timestamps
        windows = [{"start": 0.5, "end": 5.0}, {"start": 5.5, "end": 10.0}]
        realigned = client.predict(
            aid,
            windows,
            "Base",
            "CPU",
            api_name="/realign_from_timestamps",
        )
        assert len(realigned["segments"]) == 2

        # 4. Resegment again (session still valid)
        final = client.predict(
            aid,
            400,
            1200,
            150,
            "Base",
            "CPU",
            api_name="/resegment_session",
        )
        assert len(final["segments"]) > 0


# ---------------------------------------------------------------------------
# 6. Error handling
# ---------------------------------------------------------------------------
# (This section is empty here; an "Error handling" section appears again as
# section 10.)

# ---------------------------------------------------------------------------
# 7. MFA timestamps — session-based
# ---------------------------------------------------------------------------

class TestMfaTimestampsSession:
    """Tests for ``/mfa_timestamps_session``.

    Word entries are checked as arrays of ``[location, start, end]`` with an
    optional fourth element holding per-letter ``[char, start, end]`` arrays.
    """

    def test_basic_words_only(self, client, session):
        """Session endpoint with stored segments, words granularity."""
        audio_id = session["audio_id"]
        response = client.predict(
            audio_id,
            None,
            "words",
            api_name="/mfa_timestamps_session",
        )
        assert response["audio_id"] == audio_id
        segments = response["segments"]
        assert len(segments) > 0
        has_words = any("words" in seg for seg in segments)
        assert has_words, "Expected at least one segment with words"
        # Words-only: each word is [location, start, end] (3 elements)
        for seg in segments:
            for word in seg.get("words", []):
                assert len(word) == 3, f"words granularity should give 3-element arrays, got {len(word)}"

    def test_words_plus_chars(self, client, session):
        """Session endpoint with words+chars granularity."""
        response = client.predict(
            session["audio_id"],
            None,
            "words+chars",
            api_name="/mfa_timestamps_session",
        )
        has_letters = any(
            len(word) == 4
            for seg in response["segments"]
            for word in seg.get("words", [])
        )
        assert has_letters, "words+chars should include letter arrays (4th element)"

    def test_with_segments_override(self, client, session):
        """Session endpoint with explicit segments (override stored)."""
        audio_id = session["audio_id"]
        first_two = session["segments"][:2]
        response = client.predict(
            audio_id,
            first_two,
            "words",
            api_name="/mfa_timestamps_session",
        )
        assert response["audio_id"] == audio_id
        assert len(response["segments"]) == 2

    def test_word_timestamp_fields(self, client, session):
        """Verify word arrays have correct structure: [location, start, end, ?letters]."""
        response = client.predict(
            session["audio_id"],
            None,
            "words+chars",
            api_name="/mfa_timestamps_session",
        )
        number = (int, float)
        for seg in response["segments"]:
            for word in seg.get("words", []):
                assert isinstance(word[0], str), "word[0] should be location string"
                assert isinstance(word[1], number), "word[1] should be start time"
                assert isinstance(word[2], number), "word[2] should be end time"
                assert word[2] > word[1], "end should be > start"
                if len(word) != 4:
                    continue
                # Letters: list of [char, start, end]
                for letter in word[3]:
                    assert len(letter) == 3
                    assert isinstance(letter[0], str)

    def test_invalid_session(self, client):
        """An unknown audio_id yields an error and no segments."""
        response = client.predict(
            FAKE_ID,
            None,
            "words",
            api_name="/mfa_timestamps_session",
        )
        assert "error" in response
        assert response["segments"] == []

    def test_default_granularity(self, client, session):
        """Empty granularity should default to words."""
        response = client.predict(
            session["audio_id"],
            None,
            "",
            api_name="/mfa_timestamps_session",
        )
        segments = response["segments"]
        assert len(segments) > 0
        for seg in segments:
            for word in seg.get("words", []):
                assert len(word) == 3, "default granularity should not include letters"


# ---------------------------------------------------------------------------
# 8.
# MFA timestamps — direct
# ---------------------------------------------------------------------------

class TestMfaTimestampsDirect:
    """Tests for ``/mfa_timestamps_direct``, which takes the audio file and
    the segments in the request instead of an audio_id."""

    def test_basic(self, client, session):
        """Direct endpoint with audio file and segments."""
        response = client.predict(
            AUDIO_FILE,
            session["segments"],
            "words",
            api_name="/mfa_timestamps_direct",
        )
        assert "segments" in response
        segments = response["segments"]
        assert len(segments) > 0
        has_words = any("words" in seg for seg in segments)
        assert has_words

    def test_words_plus_chars(self, client, session):
        """words+chars granularity yields at least one 4-element word."""
        response = client.predict(
            AUDIO_FILE,
            session["segments"],
            "words+chars",
            api_name="/mfa_timestamps_direct",
        )
        has_letters = any(
            len(word) == 4
            for seg in response["segments"]
            for word in seg.get("words", [])
        )
        assert has_letters

    def test_no_audio_id_in_response(self, client, session):
        """Direct endpoint should not return audio_id."""
        response = client.predict(
            AUDIO_FILE,
            session["segments"],
            "words",
            api_name="/mfa_timestamps_direct",
        )
        assert "audio_id" not in response

    def test_empty_segments_error(self, client):
        """An empty segment list yields an error and no segments."""
        response = client.predict(
            AUDIO_FILE,
            [],
            "words",
            api_name="/mfa_timestamps_direct",
        )
        assert "error" in response
        assert response["segments"] == []


# ---------------------------------------------------------------------------
# 9. Segments stored in session after alignment
# ---------------------------------------------------------------------------

class TestSegmentStorage:
    """Checks that a freshly created session can feed the MFA endpoint."""

    def test_segments_stored_after_process(self, client):
        """process_audio_session should store segments for later MFA use."""
        created = client.predict(
            AUDIO_FILE,
            200,
            1000,
            100,
            "Base",
            "CPU",
            api_name="/process_audio_session",
        )
        audio_id = created["audio_id"]
        # MFA session endpoint should find stored segments
        response = client.predict(
            audio_id,
            None,
            "words",
            api_name="/mfa_timestamps_session",
        )
        # An "error" key is tolerated only when segments were still returned.
        assert "error" not in response or response.get("segments")
        assert response["audio_id"] == audio_id


# ---------------------------------------------------------------------------
# 10.
# Error handling
# ---------------------------------------------------------------------------

class TestErrorHandling:
    """Each session endpoint must answer an unknown audio_id with an
    ``error`` entry and an empty ``segments`` list."""

    def test_invalid_audio_id_retranscribe(self, client):
        """Retranscribe reports the session as not found or expired."""
        response = client.predict(
            FAKE_ID,
            "Base",
            "CPU",
            api_name="/retranscribe_session",
        )
        assert "error" in response
        assert "not found" in response["error"].lower() or "expired" in response["error"].lower()
        assert response["segments"] == []

    def test_invalid_audio_id_resegment(self, client):
        """Resegment rejects an unknown audio_id."""
        response = client.predict(
            FAKE_ID,
            200,
            1000,
            100,
            "Base",
            "CPU",
            api_name="/resegment_session",
        )
        assert "error" in response
        assert response["segments"] == []

    def test_invalid_audio_id_realign(self, client):
        """Realign rejects an unknown audio_id."""
        windows = [{"start": 0.0, "end": 1.0}]
        response = client.predict(
            FAKE_ID,
            windows,
            "Base",
            "CPU",
            api_name="/realign_from_timestamps",
        )
        assert "error" in response
        assert response["segments"] == []