File size: 3,785 Bytes
750bbe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""
Tests for the NEMO Toolkit ASR gRPC backend.
"""
import unittest
import subprocess
import time
import os
import tempfile
import shutil
import backend_pb2
import backend_pb2_grpc

import grpc

# The transcription test downloads a model and runs inference, which is too
# heavy for CI. Exporting SKIP_ASR_TESTS=true (any letter case) disables it;
# an unset variable or any other value leaves it enabled.
SKIP_ASR_TESTS = os.getenv("SKIP_ASR_TESTS", "false").lower() == "true"


class TestBackendServicer(unittest.TestCase):
    """Integration tests that exercise the backend over gRPC.

    unittest runs ``setUp`` before and ``tearDown`` after every test method,
    so each test talks to its own freshly started ``backend.py`` process.
    The tests must not call those hooks themselves: a second ``setUp`` would
    overwrite ``self.service`` and leave the first server running with
    nothing left to terminate it.
    """

    # Address the server is started on and the tests connect to.
    ADDR = "localhost:50051"
    # Model requested through the LoadModel RPC.
    MODEL = "nvidia/parakeet-tdt-0.6b-v3"

    def setUp(self):
        """Start the backend server as a subprocess and wait for it to boot."""
        self.service = subprocess.Popen(["python3", "backend.py", "--addr", self.ADDR])
        # Fixed delay: there is no readiness probe here, so the tests simply
        # give the server time to come up before connecting.
        time.sleep(15)

    def tearDown(self):
        """Stop the server started by ``setUp`` and reap the process."""
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        """The Health RPC answers with ``b'OK'`` once the server is running."""
        with grpc.insecure_channel(self.ADDR) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            try:
                response = stub.Health(backend_pb2.HealthMessage())
            except grpc.RpcError as err:
                # Only transport/RPC failures mean "the server is not up";
                # assertion failures below keep their own detailed message.
                self.fail(f"Server failed to start: {err}")
            self.assertEqual(response.message, b'OK')

    def test_load_model(self):
        """LoadModel reports success for the reference model."""
        with grpc.insecure_channel(self.ADDR) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            try:
                response = stub.LoadModel(backend_pb2.ModelOptions(Model=self.MODEL))
            except grpc.RpcError as err:
                self.fail(f"LoadModel service failed: {err}")
            self.assertTrue(response.success, response.message)
            self.assertEqual(response.message, "Model loaded successfully")

    @unittest.skipIf(SKIP_ASR_TESTS, "ASR transcription test skipped (SKIP_ASR_TESTS=true)")
    def test_audio_transcription(self):
        """Load the model, transcribe a downloaded sample and check the text.

        The test is skipped (not failed) when the sample cannot be fetched.
        """
        temp_dir = tempfile.mkdtemp()
        # Registered right away so the directory is removed on every exit
        # path, including skips and failures.
        self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
        audio_file = os.path.join(temp_dir, 'audio.wav')

        # Download a sample audio file for testing.
        # NOTE(review): the URL points at an .mp3 but it is stored as
        # audio.wav - confirm the backend detects the format from content.
        url = "https://audio-samples.github.io/samples/mp3/crowd-cheering-and-applause-sound-effect.mp3"
        result = subprocess.run(
            ["wget", "-q", url, "-O", audio_file],
            capture_output=True,
            text=True,
            timeout=30,
        )
        if result.returncode != 0:
            self.skipTest(f"Could not download sample audio: {result.stderr}")
        if not os.path.exists(audio_file):
            self.skipTest("Sample audio file not found after download")

        with grpc.insecure_channel(self.ADDR) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            load_response = stub.LoadModel(backend_pb2.ModelOptions(Model=self.MODEL))
            self.assertTrue(load_response.success, load_response.message)

            transcript_response = stub.AudioTranscription(
                backend_pb2.TranscriptRequest(dst=audio_file)
            )
            self.assertIsNotNone(transcript_response)
            self.assertIsNotNone(transcript_response.text)
            self.assertGreaterEqual(len(transcript_response.segments), 0)
            all_text = "".join(
                segment.text for segment in transcript_response.segments
            )
            print(f"Transcription result: {all_text}")
            self.assertIn("big", all_text)
            if transcript_response.segments:
                self.assertIsNotNone(transcript_response.segments[0].text)


# Allow running this file directly (python <file>) as well as via a test runner.
if __name__ == "__main__":
    unittest.main()