File size: 16,817 Bytes
6cdb091
 
809b0b5
 
6cdb091
 
 
 
809b0b5
 
6cdb091
0d6804f
6cdb091
0d6804f
 
 
809b0b5
 
6cdb091
809b0b5
 
 
6cdb091
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
809b0b5
 
 
6cdb091
809b0b5
 
 
 
 
6cdb091
809b0b5
 
 
 
 
6cdb091
809b0b5
 
 
 
 
 
 
 
6cdb091
809b0b5
 
 
6cdb091
809b0b5
 
 
 
6cdb091
 
809b0b5
 
 
6cdb091
809b0b5
 
 
 
 
 
 
 
 
6cdb091
809b0b5
 
 
 
 
 
 
 
6cdb091
809b0b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cdb091
 
809b0b5
 
 
6cdb091
809b0b5
 
 
 
 
 
 
 
 
 
 
 
 
6cdb091
809b0b5
 
 
 
 
 
 
 
6cdb091
809b0b5
 
 
 
 
 
 
 
 
 
 
 
 
 
6cdb091
809b0b5
 
 
 
 
 
6cdb091
 
809b0b5
 
 
6cdb091
809b0b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ce56b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
809b0b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
"""Integration tests for session-based API endpoints.

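Requires the app to be running (localhost or live Space).
Start locally with: python app.py
The server URL is read from the TEST_SERVER_URL environment variable;
when it is unset, the tests run against the live Space.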

Run with: python -m pytest tests/test_session_api.py -v -s
"""

import os

import pytest
from gradio_client import Client, handle_file

# Base URL of the server under test. Taken from the TEST_SERVER_URL
# environment variable; falls back to the hosted Space URL when unset.
SERVER_URL = os.environ.get("TEST_SERVER_URL", "https://hetchyy-quran-multi-aligner.hf.space")
_AUDIO_PATH = "data/112.mp3"  # Surah Al-Ikhlas (~15s)
# File handle for the sample recording, built once at import time and passed
# to every endpoint call that takes an audio file.
AUDIO_FILE = handle_file(_AUDIO_PATH)
# All-zero audio_id used by the invalid-session tests; presumably never
# matches a session the server has issued.
FAKE_ID = "00000000000000000000000000000000"


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

@pytest.fixture(scope="module")
def client():
    """Provide one gradio client, connected to ``SERVER_URL``, per test module."""
    gradio_client = Client(SERVER_URL)
    return gradio_client


@pytest.fixture(scope="module")
def session(client):
    """Call ``/process_audio_session`` once and share the response module-wide.

    The raw endpoint response is returned so tests can read its
    ``audio_id`` and ``segments`` entries.
    """
    request_args = (AUDIO_FILE, 200, 1000, 100, "Base", "CPU")
    response = client.predict(*request_args, api_name="/process_audio_session")

    # Fail early, with the full payload in the message, if no id came back.
    assert "audio_id" in response, f"Missing audio_id: {response}"
    assert response["audio_id"] is not None
    return response


# ---------------------------------------------------------------------------
# 1. process_audio_session
# ---------------------------------------------------------------------------

class TestProcessAudioSession:
    """Checks on the response produced by the shared ``session`` fixture."""

    def test_creates_session(self, session):
        """The response has at least one segment and a 32-character string id."""
        segments = session["segments"]
        assert len(segments) > 0, "Expected at least one segment"
        audio_id = session["audio_id"]
        assert isinstance(audio_id, str) and len(audio_id) == 32

    def test_all_response_fields_present(self, session):
        """The first segment carries every expected key."""
        first = session["segments"][0]
        expected_fields = (
            "segment", "time_from", "time_to", "ref_from", "ref_to",
            "matched_text", "confidence", "has_missing_words", "error",
        )
        for name in expected_fields:
            assert name in first, f"Missing field: {name}"

    def test_segment_field_types(self, session):
        """The first segment's values have the expected Python types."""
        first = session["segments"][0]
        numeric = (int, float)
        assert isinstance(first["segment"], int)
        for key in ("time_from", "time_to", "confidence"):
            assert isinstance(first[key], numeric)
        # Confidence must lie in the closed interval [0, 1].
        assert 0 <= first["confidence"] <= 1
        assert isinstance(first["has_missing_words"], bool)

    def test_segments_ordered(self, session):
        """Segment numbers are returned in non-decreasing order."""
        numbers = [entry["segment"] for entry in session["segments"]]
        in_order = sorted(numbers)
        assert numbers == in_order

    def test_time_ordering(self, session):
        """Every segment starts at or after zero and ends after it starts."""
        for entry in session["segments"]:
            start = entry["time_from"]
            assert start >= 0
            end = entry["time_to"]
            assert end > start


# ---------------------------------------------------------------------------
# 2. resegment_session
# ---------------------------------------------------------------------------

class TestResegmentSession:
    """Tests for ``/resegment_session`` on the shared session."""

    def test_resegment_basic(self, client, session):
        """Resegmenting keeps the audio_id and returns a non-empty segment list."""
        audio_id = session["audio_id"]
        response = client.predict(
            audio_id, 600, 1500, 300, "Base", "CPU",
            api_name="/resegment_session",
        )
        assert response["audio_id"] == audio_id
        assert "segments" in response
        assert len(response["segments"]) > 0

    def test_resegment_merges_with_high_silence(self, client, session):
        """min_silence=2000ms should produce fewer (or equal) segments."""
        baseline = len(session["segments"])
        response = client.predict(
            session["audio_id"], 2000, 500, 100, "Base", "CPU",
            api_name="/resegment_session",
        )
        merged = len(response["segments"])
        assert merged <= baseline

    def test_resegment_updates_session(self, client, session):
        """A resegment with Base makes a following Base retranscribe hit the
        guard (resegment already re-ran ASR with that model)."""
        audio_id = session["audio_id"]

        # Step 1: resegment using the Base model.
        client.predict(
            audio_id, 400, 1000, 150, "Base", "CPU",
            api_name="/resegment_session",
        )

        # Step 2: ask for the same model again. The guard is expected to
        # trigger because resegment stored model=Base together with the new
        # intervals_hash.
        response = client.predict(
            audio_id, "Base", "CPU",
            api_name="/retranscribe_session",
        )
        assert "error" in response
        assert response["segments"] == []


# ---------------------------------------------------------------------------
# 3. retranscribe_session
# ---------------------------------------------------------------------------

class TestRetranscribeSession:
    """Tests for ``/retranscribe_session`` and its same-model guard.

    All tests operate on the module-scoped ``session`` fixture, so the
    server-side state (stored model and segment boundaries) is shared with
    every other test in the module. Each test therefore sets up the state it
    needs itself rather than relying on what a previous test left behind.
    """

    def test_retranscribe_different_model(self, client, session):
        """Switching from Base to Large returns segments for the same session."""
        # First ensure we're on Base by resegmenting
        client.predict(
            session["audio_id"], 200, 1000, 100, "Base", "CPU",
            api_name="/resegment_session",
        )
        result = client.predict(
            session["audio_id"], "Large", "CPU",
            api_name="/retranscribe_session",
        )
        assert result["audio_id"] == session["audio_id"]
        assert len(result["segments"]) > 0

    def test_retranscribe_guard_same_model(self, client, session):
        """Same model + same boundaries -> error."""
        # Prime the session so this test does not depend on execution order:
        # previously it only passed when test_retranscribe_different_model had
        # just left the session on Large. The response is deliberately
        # ignored -- it is either a fresh Large transcription or the guard
        # error; in both cases the stored model is expected to be Large
        # afterwards while the boundaries stay unchanged.
        client.predict(
            session["audio_id"], "Large", "CPU",
            api_name="/retranscribe_session",
        )
        # Repeating the identical request must now be rejected by the guard.
        result = client.predict(
            session["audio_id"], "Large", "CPU",
            api_name="/retranscribe_session",
        )
        assert "error" in result
        assert result["segments"] == []

    def test_retranscribe_allowed_after_resegment(self, client, session):
        """After a resegment with Large, retranscribing with Large is rejected
        by the guard (resegment stores the model it ran with), while
        switching to Base is allowed."""
        # Resegment with different params
        client.predict(
            session["audio_id"], 300, 1200, 200, "Large", "CPU",
            api_name="/resegment_session",
        )
        # Same model as resegment used — guard triggers
        result = client.predict(
            session["audio_id"], "Large", "CPU",
            api_name="/retranscribe_session",
        )
        assert "error" in result

        # But switching model works
        result2 = client.predict(
            session["audio_id"], "Base", "CPU",
            api_name="/retranscribe_session",
        )
        assert len(result2["segments"]) > 0


# ---------------------------------------------------------------------------
# 4. realign_from_timestamps
# ---------------------------------------------------------------------------

class TestRealignFromTimestamps:
    """Tests for ``/realign_from_timestamps`` with caller-supplied windows."""

    def test_custom_timestamps(self, client, session):
        """Three supplied windows come back as three segments."""
        windows = ((0.5, 3.0), (3.5, 6.0), (6.5, 10.0))
        timestamps = [{"start": begin, "end": finish} for begin, finish in windows]
        audio_id = session["audio_id"]
        response = client.predict(
            audio_id, timestamps, "Base", "CPU",
            api_name="/realign_from_timestamps",
        )
        assert response["audio_id"] == audio_id
        assert len(response["segments"]) == 3

    def test_realign_updates_boundaries(self, client, session):
        """Realigning with Base makes a Base retranscribe hit the guard,
        while a retranscribe with another model goes through."""
        audio_id = session["audio_id"]
        windows = ((0.5, 4.0), (4.5, 9.0))
        timestamps = [{"start": begin, "end": finish} for begin, finish in windows]
        client.predict(
            audio_id, timestamps, "Base", "CPU",
            api_name="/realign_from_timestamps",
        )

        # Identical model: the guard should answer with an error.
        same_model = client.predict(
            audio_id, "Base", "CPU",
            api_name="/retranscribe_session",
        )
        assert "error" in same_model

        # Another model: the request is accepted and yields segments.
        other_model = client.predict(
            audio_id, "Large", "CPU",
            api_name="/retranscribe_session",
        )
        assert len(other_model["segments"]) > 0


# ---------------------------------------------------------------------------
# 5. Consecutive calls / full workflow
# ---------------------------------------------------------------------------

class TestWorkflow:
    """Multi-call scenarios run against the shared session."""

    def test_consecutive_resegments(self, client, session):
        """Two back-to-back resegments with different settings both succeed."""
        audio_id = session["audio_id"]
        first = client.predict(
            audio_id, 200, 1000, 100, "Base", "CPU",
            api_name="/resegment_session",
        )
        second = client.predict(
            audio_id, 600, 1500, 300, "Base", "CPU",
            api_name="/resegment_session",
        )
        # The segment counts are intentionally not compared with each other:
        # they usually differ, but only success of both calls is required.
        assert len(first["segments"]) > 0
        assert len(second["segments"]) > 0

    def test_full_workflow(self, client, session):
        """Resegment, retranscribe, realign and resegment again on one session."""
        audio_id = session["audio_id"]

        # Step 1 -- resegment.
        resegmented = client.predict(
            audio_id, 200, 1000, 100, "Base", "CPU",
            api_name="/resegment_session",
        )
        assert len(resegmented["segments"]) > 0

        # Step 2 -- retranscribe using a different model.
        retranscribed = client.predict(
            audio_id, "Large", "CPU",
            api_name="/retranscribe_session",
        )
        assert len(retranscribed["segments"]) > 0

        # Step 3 -- realign against two hand-picked windows.
        timestamps = [
            {"start": 0.5, "end": 5.0},
            {"start": 5.5, "end": 10.0},
        ]
        realigned = client.predict(
            audio_id, timestamps, "Base", "CPU",
            api_name="/realign_from_timestamps",
        )
        assert len(realigned["segments"]) == 2

        # Step 4 -- resegment once more; the session must still be usable.
        final = client.predict(
            audio_id, 400, 1200, 150, "Base", "CPU",
            api_name="/resegment_session",
        )
        assert len(final["segments"]) > 0


# ---------------------------------------------------------------------------
# 6. Error handling — see section 10 (TestErrorHandling) below
# ---------------------------------------------------------------------------

# ---------------------------------------------------------------------------
# 7. MFA timestamps — session-based
# ---------------------------------------------------------------------------

class TestMfaTimestampsSession:
    """Tests for the session-based ``/mfa_timestamps_session`` endpoint."""

    def test_basic_words_only(self, client, session):
        """Stored segments with ``words`` granularity yield 3-element word arrays."""
        audio_id = session["audio_id"]
        response = client.predict(
            audio_id, None, "words",
            api_name="/mfa_timestamps_session",
        )
        assert response["audio_id"] == audio_id
        segments = response["segments"]
        assert len(segments) > 0
        with_words = [seg for seg in segments if "words" in seg]
        assert with_words, "Expected at least one segment with words"
        # Without letters each word entry is [location, start, end].
        all_words = (word for seg in segments for word in seg.get("words", []))
        for word in all_words:
            assert len(word) == 3, f"words granularity should give 3-element arrays, got {len(word)}"

    def test_words_plus_chars(self, client, session):
        """``words+chars`` granularity adds a 4th (letters) element to some word."""
        response = client.predict(
            session["audio_id"], None, "words+chars",
            api_name="/mfa_timestamps_session",
        )
        found_letters = False
        for seg in response["segments"]:
            if any(len(word) == 4 for word in seg.get("words", [])):
                found_letters = True
                break
        assert found_letters, "words+chars should include letter arrays (4th element)"

    def test_with_segments_override(self, client, session):
        """Explicitly passed segments are used instead of the stored ones."""
        audio_id = session["audio_id"]
        first_two = session["segments"][:2]
        response = client.predict(
            audio_id, first_two, "words",
            api_name="/mfa_timestamps_session",
        )
        assert response["audio_id"] == audio_id
        assert len(response["segments"]) == 2

    def test_word_timestamp_fields(self, client, session):
        """Word arrays are shaped [location, start, end, ?letters]."""
        response = client.predict(
            session["audio_id"], None, "words+chars",
            api_name="/mfa_timestamps_session",
        )
        number = (int, float)
        time_checks = (
            (1, "word[1] should be start time"),
            (2, "word[2] should be end time"),
        )
        for seg in response["segments"]:
            for word in seg.get("words", []):
                assert isinstance(word[0], str), "word[0] should be location string"
                for index, message in time_checks:
                    assert isinstance(word[index], number), message
                assert word[2] > word[1], "end should be > start"
                if len(word) != 4:
                    continue
                # Optional 4th element: a list of [char, start, end] entries.
                for letter in word[3]:
                    assert len(letter) == 3
                    assert isinstance(letter[0], str)

    def test_invalid_session(self, client):
        """An unknown audio_id gives an error and an empty segment list."""
        response = client.predict(
            FAKE_ID, None, "words",
            api_name="/mfa_timestamps_session",
        )
        assert "error" in response
        assert response["segments"] == []

    def test_default_granularity(self, client, session):
        """An empty granularity string behaves like ``words`` (no letters)."""
        response = client.predict(
            session["audio_id"], None, "",
            api_name="/mfa_timestamps_session",
        )
        segments = response["segments"]
        assert len(segments) > 0
        all_words = (word for seg in segments for word in seg.get("words", []))
        for word in all_words:
            assert len(word) == 3, "default granularity should not include letters"


# ---------------------------------------------------------------------------
# 8. MFA timestamps — direct
# ---------------------------------------------------------------------------

class TestMfaTimestampsDirect:
    """Tests for ``/mfa_timestamps_direct``, which takes audio plus segments."""

    def test_basic(self, client, session):
        """Audio file plus segments returns segments, some of them with words."""
        response = client.predict(
            AUDIO_FILE, session["segments"], "words",
            api_name="/mfa_timestamps_direct",
        )
        assert "segments" in response
        segments = response["segments"]
        assert len(segments) > 0
        with_words = [seg for seg in segments if "words" in seg]
        assert with_words

    def test_words_plus_chars(self, client, session):
        """``words+chars`` granularity yields at least one 4-element word array."""
        response = client.predict(
            AUDIO_FILE, session["segments"], "words+chars",
            api_name="/mfa_timestamps_direct",
        )
        found_letters = False
        for seg in response["segments"]:
            if any(len(word) == 4 for word in seg.get("words", [])):
                found_letters = True
                break
        assert found_letters

    def test_no_audio_id_in_response(self, client, session):
        """The direct endpoint's response carries no ``audio_id`` key."""
        response = client.predict(
            AUDIO_FILE, session["segments"], "words",
            api_name="/mfa_timestamps_direct",
        )
        assert not ("audio_id" in response)

    def test_empty_segments_error(self, client):
        """An empty segment list gives an error and an empty result list."""
        no_segments = []
        response = client.predict(
            AUDIO_FILE, no_segments, "words",
            api_name="/mfa_timestamps_direct",
        )
        assert "error" in response
        assert response["segments"] == []


# ---------------------------------------------------------------------------
# 9. Segments stored in session after alignment
# ---------------------------------------------------------------------------

class TestSegmentStorage:
    """Checks that a new session can be used by the MFA session endpoint."""

    def test_segments_stored_after_process(self, client):
        """Segments from ``/process_audio_session`` are available to MFA later."""
        processed = client.predict(
            AUDIO_FILE, 200, 1000, 100, "Base", "CPU",
            api_name="/process_audio_session",
        )
        audio_id = processed["audio_id"]

        # No segments are passed here, so the endpoint has to rely on the
        # ones stored for this session.
        response = client.predict(
            audio_id, None, "words",
            api_name="/mfa_timestamps_session",
        )
        # Only an error that comes without any segments counts as a failure.
        assert not ("error" in response and not response.get("segments"))
        assert response["audio_id"] == audio_id


# ---------------------------------------------------------------------------
# 10. Error handling
# ---------------------------------------------------------------------------

class TestErrorHandling:
    """Every session endpoint must reject an unknown ``audio_id``."""

    def test_invalid_audio_id_retranscribe(self, client):
        """Retranscribe reports a not-found/expired error and no segments."""
        response = client.predict(
            FAKE_ID, "Base", "CPU",
            api_name="/retranscribe_session",
        )
        assert "error" in response
        message = response["error"].lower()
        assert any(hint in message for hint in ("not found", "expired"))
        assert response["segments"] == []

    def test_invalid_audio_id_resegment(self, client):
        """Resegment reports an error and no segments."""
        response = client.predict(
            FAKE_ID, 200, 1000, 100, "Base", "CPU",
            api_name="/resegment_session",
        )
        assert "error" in response
        assert response["segments"] == []

    def test_invalid_audio_id_realign(self, client):
        """Realign reports an error and no segments."""
        single_window = [{"start": 0.0, "end": 1.0}]
        response = client.predict(
            FAKE_ID, single_window, "Base", "CPU",
            api_name="/realign_from_timestamps",
        )
        assert "error" in response
        assert response["segments"] == []