File size: 3,722 Bytes
76db545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
Unit tests for the data pipeline: augmentation, feature extractor, agri dictionary.
"""
from __future__ import annotations

import numpy as np
import pytest


class TestFieldNoiseAugmenter:
    """Tests for ``FieldNoiseAugmenter`` constructed over an empty noise directory."""

    def test_augmenter_without_noise_files(self, tmp_path):
        """Augmenter with empty noise_dir should fall back to Gaussian-only and still be ready."""
        config = {"audio": {"noise_snr_db_range": [5, 20], "augmentation_prob": 0.6}}
        from src.data.augmentation import FieldNoiseAugmenter

        # tmp_path is a fresh, empty directory, so there are no noise files to load.
        augmenter = FieldNoiseAugmenter(str(tmp_path), config)
        assert augmenter.is_ready()
        assert augmenter._gaussian_only

    def test_augmenter_output_shape(self, tmp_path):
        """Augmented audio should have the same length as input."""
        # augmentation_prob=1.0 so the augmentation path is always taken.
        config = {"audio": {"noise_snr_db_range": [5, 20], "augmentation_prob": 1.0}}
        from src.data.augmentation import FieldNoiseAugmenter

        augmenter = FieldNoiseAugmenter(str(tmp_path), config)
        # Seeded generator keeps the test input reproducible across runs.
        rng = np.random.default_rng(0)
        audio = (rng.standard_normal(16000) * 0.01).astype(np.float32)
        augmented = augmenter.augment(audio, 16000)
        assert augmented.shape == audio.shape

    def test_augmenter_no_crash_on_silent_audio(self, tmp_path):
        """Silent audio (all zeros) should not crash the augmenter or yield NaN/inf samples."""
        # augmentation_prob=1.0 (previously 0.5) so the silent-input edge case is exercised
        # on every run, not only when the random draw happens to trigger augmentation.
        config = {"audio": {"noise_snr_db_range": [5, 20], "augmentation_prob": 1.0}}
        from src.data.augmentation import FieldNoiseAugmenter

        augmenter = FieldNoiseAugmenter(str(tmp_path), config)
        audio = np.zeros(16000, dtype=np.float32)
        result = augmenter.augment(audio, 16000)
        assert result is not None
        assert result.shape == audio.shape
        # NOTE(review): SNR-based noise scaling against zero signal power is the likely
        # way silent input goes wrong (division by zero -> NaN/inf), so check finiteness.
        assert np.all(np.isfinite(result))


class TestAgriculturalDictionary:
    """Sanity checks for the agricultural vocabularies and the prompt builder."""

    def test_bambara_vocab_not_empty(self):
        """The Bambara vocabulary must contain at least one entry."""
        from src.data.agri_dictionary import BAMBARA_VOCAB as vocab

        size = len(vocab)
        assert size > 0

    def test_fula_vocab_not_empty(self):
        """The Fula vocabulary must contain at least one entry."""
        from src.data.agri_dictionary import FULA_VOCAB as vocab

        size = len(vocab)
        assert size > 0

    def test_get_vocab_invalid_language(self):
        """Requesting an unknown language code must raise ValueError."""
        from src.data.agri_dictionary import AgriculturalDictionary

        dictionary = AgriculturalDictionary()
        unknown_code = "xyz"
        with pytest.raises(ValueError):
            dictionary.get_vocab(unknown_code)

    def test_prompt_text_contains_terms(self):
        """The Bambara prompt text should mention known agricultural terms."""
        from src.data.agri_dictionary import AgriculturalDictionary

        prompt = AgriculturalDictionary().get_prompt_text("bam")
        for expected_term in ("sɛnɛ", "kaba"):
            assert expected_term in prompt


class TestDataCollator:
    """Tests for ``DataCollatorSpeechSeq2SeqWithPadding`` driven by a mocked processor."""

    def test_collator_pads_labels(self):
        """DataCollator should pad labels and replace pad tokens with -100."""
        from unittest.mock import MagicMock

        import torch

        from src.data.feature_extractor import DataCollatorSpeechSeq2SeqWithPadding

        # Mock processor: feature padding returns a fixed (2, 80, 3000) batch.
        processor = MagicMock()
        processor.feature_extractor.pad.return_value = {
            "input_features": torch.zeros(2, 80, 3000)
        }
        # Simulate the tokenizer's padded labels batch; pad id 0 sits where the mask is 0.
        label_fields = {
            "input_ids": torch.tensor([[1, 2, 3, 0], [1, 4, 0, 0]]),
            "attention_mask": torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]),
        }
        padded_labels = MagicMock()
        padded_labels.input_ids = label_fields["input_ids"]
        padded_labels.attention_mask = label_fields["attention_mask"]
        # Also support dict-style access (labels_batch["input_ids"]); without this,
        # MagicMock.__getitem__ silently returns another mock instead of the tensor.
        padded_labels.__getitem__.side_effect = label_fields.__getitem__
        processor.tokenizer.pad.return_value = padded_labels

        collator = DataCollatorSpeechSeq2SeqWithPadding(
            processor=processor,
            decoder_start_token_id=1,
        )

        features = [
            {"input_features": np.zeros((80, 3000)), "labels": [1, 2, 3]},
            {"input_features": np.zeros((80, 3000)), "labels": [1, 4]},
        ]
        batch = collator(features)
        assert "labels" in batch

        labels = batch["labels"].tolist()
        # Both rows contain padding, so both must be masked (the old `or` check let a
        # collator that masked only one row pass). Column 0 is unmasked in every row,
        # so these hold whether or not a leading decoder-start column is stripped.
        assert sum(row.count(-100) for row in labels) == 3
        assert labels[0][-1] == -100
        assert labels[1][-2:] == [-100, -100]