Quran-multi-aligner / tests /test_session_api.py
hetchyy's picture
Update timestamps API
0d6804f
"""Integration tests for session-based API endpoints.
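Requires the app to be running (localhost or live Space). The server URL is
read from the TEST_SERVER_URL environment variable and defaults to the live
Space.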
Start locally with: python app.py
Run with: python -m pytest tests/test_session_api.py -v -s
"""
import os
import pytest
from gradio_client import Client, handle_file
# Base URL of the app under test; the TEST_SERVER_URL environment variable
# overrides the default (the live Space).
SERVER_URL = os.getenv(
    "TEST_SERVER_URL",
    "https://hetchyy-quran-multi-aligner.hf.space",
)

# Sample recording sent to the upload endpoints: Surah Al-Ikhlas (~15s).
# The path is relative, so it is presumably resolved against the directory
# pytest is launched from -- confirm if running from elsewhere.
_AUDIO_PATH = "data/112.mp3"
AUDIO_FILE = handle_file(_AUDIO_PATH)

# Placeholder audio_id used by the invalid-session tests.
FAKE_ID = "00000000000000000000000000000000"
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def client():
    """Build one Gradio client for SERVER_URL and share it across the module."""
    gradio_client = Client(SERVER_URL)
    return gradio_client
@pytest.fixture(scope="module")
def session(client):
    """Call ``/process_audio_session`` once and share the response.

    The returned dict (including its ``audio_id``) is reused by every test
    in the module, so the upload happens a single time.
    """
    # Positional arguments after the file: three numeric segmentation
    # settings (the first looks like min-silence in ms -- TODO confirm the
    # other two against the app), then the model name and the device.
    request_args = (AUDIO_FILE, 200, 1000, 100, "Base", "CPU")
    response = client.predict(*request_args, api_name="/process_audio_session")

    assert "audio_id" in response, f"Missing audio_id: {response}"
    assert response["audio_id"] is not None
    return response
# ---------------------------------------------------------------------------
# 1. process_audio_session
# ---------------------------------------------------------------------------
class TestProcessAudioSession:
def test_creates_session(self, session):
assert len(session["segments"]) > 0, "Expected at least one segment"
aid = session["audio_id"]
assert isinstance(aid, str) and len(aid) == 32
def test_all_response_fields_present(self, session):
seg = session["segments"][0]
for field in ("segment", "time_from", "time_to", "ref_from", "ref_to",
"matched_text", "confidence", "has_missing_words", "error"):
assert field in seg, f"Missing field: {field}"
def test_segment_field_types(self, session):
seg = session["segments"][0]
assert isinstance(seg["segment"], int)
assert isinstance(seg["time_from"], (int, float))
assert isinstance(seg["time_to"], (int, float))
assert isinstance(seg["confidence"], (int, float))
assert 0 <= seg["confidence"] <= 1
assert isinstance(seg["has_missing_words"], bool)
def test_segments_ordered(self, session):
nums = [s["segment"] for s in session["segments"]]
assert nums == sorted(nums)
def test_time_ordering(self, session):
for seg in session["segments"]:
assert seg["time_from"] >= 0
assert seg["time_to"] > seg["time_from"]
# ---------------------------------------------------------------------------
# 2. resegment_session
# ---------------------------------------------------------------------------
class TestResegmentSession:
    """Tests for ``/resegment_session`` using the shared session."""

    def test_resegment_basic(self, client, session):
        """A resegment call echoes the audio_id and returns segments."""
        request_args = (session["audio_id"], 600, 1500, 300, "Base", "CPU")
        response = client.predict(*request_args, api_name="/resegment_session")

        assert response["audio_id"] == session["audio_id"]
        assert "segments" in response
        assert len(response["segments"]) > 0

    def test_resegment_merges_with_high_silence(self, client, session):
        """min_silence=2000ms should produce fewer (or equal) segments."""
        baseline_count = len(session["segments"])
        request_args = (session["audio_id"], 2000, 500, 100, "Base", "CPU")
        response = client.predict(*request_args, api_name="/resegment_session")

        assert len(response["segments"]) <= baseline_count

    def test_resegment_updates_session(self, client, session):
        """Retranscribing with the model a resegment just used is rejected.

        The resegment call already re-ran ASR with that model, so the
        same-model guard on retranscribe is expected to fire.
        """
        audio_id = session["audio_id"]

        # Resegment with the Base model.
        client.predict(
            audio_id, 400, 1000, 150, "Base", "CPU",
            api_name="/resegment_session",
        )

        # Same model again: the session now records model=Base together
        # with the new intervals hash, so the guard triggers.
        response = client.predict(
            audio_id, "Base", "CPU",
            api_name="/retranscribe_session",
        )
        assert "error" in response
        assert response["segments"] == []
# ---------------------------------------------------------------------------
# 3. retranscribe_session
# ---------------------------------------------------------------------------
class TestRetranscribeSession:
    """Tests for ``/retranscribe_session`` and its same-model guard.

    Each test establishes the session state it needs itself, so the tests
    do not depend on the order in which they run.
    """

    def test_retranscribe_different_model(self, client, session):
        """Switching from Base to Large returns segments for the same id."""
        # First ensure we're on Base by resegmenting
        client.predict(
            session["audio_id"], 200, 1000, 100, "Base", "CPU",
            api_name="/resegment_session",
        )
        result = client.predict(
            session["audio_id"], "Large", "CPU",
            api_name="/retranscribe_session",
        )
        assert result["audio_id"] == session["audio_id"]
        assert len(result["segments"]) > 0

    def test_retranscribe_guard_same_model(self, client, session):
        """Same model + same boundaries -> error."""
        audio_id = session["audio_id"]

        # Set up the precondition here instead of relying on another test
        # having left the session on Large: start from Base, then switch.
        client.predict(
            audio_id, 200, 1000, 100, "Base", "CPU",
            api_name="/resegment_session",
        )
        first = client.predict(
            audio_id, "Large", "CPU",
            api_name="/retranscribe_session",
        )
        assert len(first["segments"]) > 0

        # Repeating the request with unchanged model and boundaries is rejected.
        result = client.predict(
            audio_id, "Large", "CPU",
            api_name="/retranscribe_session",
        )
        assert "error" in result
        assert result["segments"] == []

    def test_retranscribe_allowed_after_resegment(self, client, session):
        """After a resegment, only a *different* model may retranscribe.

        The resegment call stores the model it ran with, so retranscribing
        with that same model still triggers the guard; switching to another
        model is accepted.
        """
        # Resegment with different params
        client.predict(
            session["audio_id"], 300, 1200, 200, "Large", "CPU",
            api_name="/resegment_session",
        )
        # Same model as resegment used — guard triggers
        result = client.predict(
            session["audio_id"], "Large", "CPU",
            api_name="/retranscribe_session",
        )
        assert "error" in result
        # But switching model works
        result2 = client.predict(
            session["audio_id"], "Base", "CPU",
            api_name="/retranscribe_session",
        )
        assert len(result2["segments"]) > 0
# ---------------------------------------------------------------------------
# 4. realign_from_timestamps
# ---------------------------------------------------------------------------
class TestRealignFromTimestamps:
    """Tests for ``/realign_from_timestamps`` with caller-supplied spans."""

    def test_custom_timestamps(self, client, session):
        """Three supplied spans come back as three segments."""
        spans = [
            {"start": 0.5, "end": 3.0},
            {"start": 3.5, "end": 6.0},
            {"start": 6.5, "end": 10.0},
        ]
        response = client.predict(
            session["audio_id"], spans, "Base", "CPU",
            api_name="/realign_from_timestamps",
        )
        assert response["audio_id"] == session["audio_id"]
        assert len(response["segments"]) == 3

    def test_realign_updates_boundaries(self, client, session):
        """After realign with Base, retranscribe with same model triggers guard,
        but switching model works."""
        audio_id = session["audio_id"]
        spans = [
            {"start": 0.5, "end": 4.0},
            {"start": 4.5, "end": 9.0},
        ]
        client.predict(
            audio_id, spans, "Base", "CPU",
            api_name="/realign_from_timestamps",
        )

        # Same model: the guard triggers.
        same_model = client.predict(
            audio_id, "Base", "CPU",
            api_name="/retranscribe_session",
        )
        assert "error" in same_model

        # Different model: allowed.
        other_model = client.predict(
            audio_id, "Large", "CPU",
            api_name="/retranscribe_session",
        )
        assert len(other_model["segments"]) > 0
# ---------------------------------------------------------------------------
# 5. Consecutive calls / full workflow
# ---------------------------------------------------------------------------
class TestWorkflow:
    """Multi-step scenarios that chain several endpoints on one session."""

    def test_consecutive_resegments(self, client, session):
        """Two resegment calls in a row both return segments."""
        audio_id = session["audio_id"]
        first = client.predict(
            audio_id, 200, 1000, 100, "Base", "CPU",
            api_name="/resegment_session",
        )
        second = client.predict(
            audio_id, 600, 1500, 300, "Base", "CPU",
            api_name="/resegment_session",
        )
        # Different params usually give different segment counts, but that
        # is not guaranteed, so only success of both calls is checked.
        assert len(first["segments"]) > 0
        assert len(second["segments"]) > 0

    def test_full_workflow(self, client, session):
        """Resegment, retranscribe, realign, then resegment again."""
        aid = session["audio_id"]

        # 1. Resegment
        resegmented = client.predict(
            aid, 200, 1000, 100, "Base", "CPU",
            api_name="/resegment_session",
        )
        assert len(resegmented["segments"]) > 0

        # 2. Retranscribe with a different model
        retranscribed = client.predict(
            aid, "Large", "CPU",
            api_name="/retranscribe_session",
        )
        assert len(retranscribed["segments"]) > 0

        # 3. Realign with custom timestamps
        spans = [{"start": 0.5, "end": 5.0}, {"start": 5.5, "end": 10.0}]
        realigned = client.predict(
            aid, spans, "Base", "CPU",
            api_name="/realign_from_timestamps",
        )
        assert len(realigned["segments"]) == 2

        # 4. Resegment again (the session is still valid)
        final = client.predict(
            aid, 400, 1200, 150, "Base", "CPU",
            api_name="/resegment_session",
        )
        assert len(final["segments"]) > 0
# ---------------------------------------------------------------------------
# 6. Error handling (the tests live in section 10, TestErrorHandling, below)
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# 7. MFA timestamps — session-based
# ---------------------------------------------------------------------------
class TestMfaTimestampsSession:
    """Tests for ``/mfa_timestamps_session``.

    Each request passes ``(audio_id, segments, granularity)``. ``segments``
    is ``None`` to use the segments stored in the session, or an explicit
    list to override them.
    """

    def test_basic_words_only(self, client, session):
        """Session endpoint with stored segments, words granularity."""
        result = client.predict(
            session["audio_id"], None, "words",
            api_name="/mfa_timestamps_session",
        )
        assert result["audio_id"] == session["audio_id"]
        assert len(result["segments"]) > 0
        has_words = any("words" in seg for seg in result["segments"])
        assert has_words, "Expected at least one segment with words"
        # Words-only: each word is [location, start, end] (3 elements)
        for seg in result["segments"]:
            for word in seg.get("words", []):
                assert len(word) == 3, f"words granularity should give 3-element arrays, got {len(word)}"

    def test_words_plus_chars(self, client, session):
        """Session endpoint with words+chars granularity."""
        result = client.predict(
            session["audio_id"], None, "words+chars",
            api_name="/mfa_timestamps_session",
        )
        has_letters = any(
            len(word) == 4
            for seg in result["segments"]
            for word in seg.get("words", [])
        )
        assert has_letters, "words+chars should include letter arrays (4th element)"

    def test_with_segments_override(self, client, session):
        """Session endpoint with explicit segments (override stored)."""
        segments_override = session["segments"][:2]
        result = client.predict(
            session["audio_id"], segments_override, "words",
            api_name="/mfa_timestamps_session",
        )
        assert result["audio_id"] == session["audio_id"]
        # Compare with what was actually sent: the slice holds fewer than
        # two items when the session produced a single segment.
        assert len(result["segments"]) == len(segments_override)

    def test_word_timestamp_fields(self, client, session):
        """Verify word arrays have correct structure: [location, start, end, ?letters]."""
        result = client.predict(
            session["audio_id"], None, "words+chars",
            api_name="/mfa_timestamps_session",
        )
        for seg in result["segments"]:
            for word in seg.get("words", []):
                assert isinstance(word[0], str), "word[0] should be location string"
                assert isinstance(word[1], (int, float)), "word[1] should be start time"
                assert isinstance(word[2], (int, float)), "word[2] should be end time"
                assert word[2] > word[1], "end should be > start"
                if len(word) == 4:
                    # Letters: list of [char, start, end]
                    for letter in word[3]:
                        assert len(letter) == 3
                        assert isinstance(letter[0], str)

    def test_invalid_session(self, client):
        """An unknown audio_id yields an error and no segments."""
        result = client.predict(
            FAKE_ID, None, "words",
            api_name="/mfa_timestamps_session",
        )
        assert "error" in result
        assert result["segments"] == []

    def test_default_granularity(self, client, session):
        """Empty granularity should default to words."""
        result = client.predict(
            session["audio_id"], None, "",
            api_name="/mfa_timestamps_session",
        )
        assert len(result["segments"]) > 0
        for seg in result["segments"]:
            for word in seg.get("words", []):
                assert len(word) == 3, "default granularity should not include letters"
# ---------------------------------------------------------------------------
# 8. MFA timestamps — direct
# ---------------------------------------------------------------------------
class TestMfaTimestampsDirect:
    """Tests for ``/mfa_timestamps_direct``, which takes the audio file itself."""

    def test_basic(self, client, session):
        """Direct endpoint with audio file and segments."""
        response = client.predict(
            AUDIO_FILE, session["segments"], "words",
            api_name="/mfa_timestamps_direct",
        )
        assert "segments" in response
        assert len(response["segments"]) > 0
        assert any("words" in seg for seg in response["segments"])

    def test_words_plus_chars(self, client, session):
        """words+chars granularity yields at least one 4-element word array."""
        response = client.predict(
            AUDIO_FILE, session["segments"], "words+chars",
            api_name="/mfa_timestamps_direct",
        )
        found_letters = False
        for seg in response["segments"]:
            for word in seg.get("words", []):
                if len(word) == 4:
                    found_letters = True
                    break
            if found_letters:
                break
        assert found_letters

    def test_no_audio_id_in_response(self, client, session):
        """Direct endpoint should not return audio_id."""
        response = client.predict(
            AUDIO_FILE, session["segments"], "words",
            api_name="/mfa_timestamps_direct",
        )
        assert "audio_id" not in response

    def test_empty_segments_error(self, client):
        """An empty segment list yields an error and no segments."""
        response = client.predict(
            AUDIO_FILE, [], "words",
            api_name="/mfa_timestamps_direct",
        )
        assert "error" in response
        assert response["segments"] == []
# ---------------------------------------------------------------------------
# 9. Segments stored in session after alignment
# ---------------------------------------------------------------------------
class TestSegmentStorage:
    """Checks that a freshly created session can serve MFA requests."""

    def test_segments_stored_after_process(self, client):
        """process_audio_session should store segments for later MFA use."""
        created = client.predict(
            AUDIO_FILE, 200, 1000, 100, "Base", "CPU",
            api_name="/process_audio_session",
        )
        new_id = created["audio_id"]

        # Passing None for segments makes the MFA endpoint rely on whatever
        # the session stored.
        mfa = client.predict(
            new_id, None, "words",
            api_name="/mfa_timestamps_session",
        )
        no_error = "error" not in mfa
        assert no_error or mfa.get("segments")
        assert mfa["audio_id"] == new_id
# ---------------------------------------------------------------------------
# 10. Error handling
# ---------------------------------------------------------------------------
class TestErrorHandling:
    """Session endpoints called with an audio_id that matches no session."""

    def test_invalid_audio_id_retranscribe(self, client):
        """Retranscribe reports a not-found/expired error and no segments."""
        response = client.predict(
            FAKE_ID, "Base", "CPU",
            api_name="/retranscribe_session",
        )
        assert "error" in response
        message = response["error"].lower()
        assert "not found" in message or "expired" in message
        assert response["segments"] == []

    def test_invalid_audio_id_resegment(self, client):
        """Resegment reports an error and no segments."""
        response = client.predict(
            FAKE_ID, 200, 1000, 100, "Base", "CPU",
            api_name="/resegment_session",
        )
        assert "error" in response
        assert response["segments"] == []

    def test_invalid_audio_id_realign(self, client):
        """Realign reports an error and no segments."""
        spans = [{"start": 0.0, "end": 1.0}]
        response = client.predict(
            FAKE_ID, spans, "Base", "CPU",
            api_name="/realign_from_timestamps",
        )
        assert "error" in response
        assert response["segments"] == []