File size: 8,854 Bytes
750bbe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
"""
A test script to test the gRPC service for VibeVoice TTS and ASR
"""
import unittest
import subprocess
import time
import os
import tempfile
import shutil
import backend_pb2
import backend_pb2_grpc

import grpc

# The ASR tests load very large models (~14B parameters in total). Export
# SKIP_ASR_TESTS=true (compared case-insensitively) to skip them; any other
# value, or leaving the variable unset, runs them. CI is not detected
# automatically, so a CI job that wants to skip them must export the variable.
SKIP_ASR_TESTS = os.getenv("SKIP_ASR_TESTS", "false").lower() == "true"


class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service.

    unittest invokes setUp() before and tearDown() after every test method, so
    each test runs against exactly one freshly started backend.py process.
    Test methods must not call setUp()/tearDown() themselves: a second setUp()
    would launch another server on the same address and overwrite
    self.service, leaving the first process running after the test.
    """

    # Address backend.py is told to listen on and that every test connects to.
    SERVER_ADDR = "localhost:50051"
    # Fixed delay that gives backend.py time to start serving before a test runs.
    STARTUP_WAIT_SECONDS = 30
    # Grace period for backend.py to exit after terminate() before it is killed.
    SHUTDOWN_TIMEOUT_SECONDS = 30

    TTS_MODEL = "microsoft/VibeVoice-Realtime-0.5B"
    ASR_MODEL = "microsoft/VibeVoice-ASR"
    # Audio file downloaded and used as input for the transcription test.
    AUDIO_URL = "https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav"

    def setUp(self):
        """
        Start backend.py as a subprocess and wait for it to come up.

        Fails the test immediately if the process has already exited when the
        startup wait is over, instead of letting the test hit a dead address.
        """
        self.service = subprocess.Popen(["python3", "backend.py", "--addr", self.SERVER_ADDR])
        time.sleep(self.STARTUP_WAIT_SECONDS)
        if self.service.poll() is not None:
            self.fail(f"backend.py exited during startup with code {self.service.returncode}")

    def tearDown(self) -> None:
        """
        Stop the backend started by setUp().

        Sends terminate(), waits up to SHUTDOWN_TIMEOUT_SECONDS, then kills the
        process if it is still running. It is a no-op when no process was
        started, and safe to call more than once.
        """
        service = getattr(self, "service", None)
        if service is None:
            return
        service.terminate()
        try:
            service.wait(timeout=self.SHUTDOWN_TIMEOUT_SECONDS)
        except subprocess.TimeoutExpired:
            # Server ignored the polite shutdown request; force it down so it
            # does not keep holding the port for the next test.
            service.kill()
            service.wait()

    def _download(self, url, dst):
        """
        Download url to dst with wget, failing the test with a specific reason on error.

        :param url: source URL.
        :param dst: destination file path.
        """
        print(f"Downloading audio file to {dst}...")
        try:
            result = subprocess.run(
                ["wget", "-q", url, "-O", dst],
                capture_output=True,
                text=True
            )
        except FileNotFoundError:
            self.fail("wget is not installed; cannot download the test audio file")
        if result.returncode != 0:
            self.fail(f"Failed to download audio file: {result.stderr}")
        if not os.path.exists(dst):
            self.fail(f"Audio file was not downloaded to {dst}")

    def test_server_startup(self):
        """
        This method tests if the server starts up successfully
        """
        try:
            with grpc.insecure_channel(self.SERVER_ADDR) as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            # Keep the underlying error (assertion or gRPC) in the failure message.
            self.fail(f"Server failed to start: {err}")

    def test_load_tts_model(self):
        """
        This method tests if the TTS model is loaded successfully
        """
        try:
            with grpc.insecure_channel(self.SERVER_ADDR) as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model=self.TTS_MODEL))
                print(response)
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            print(err)
            self.fail(f"LoadModel service failed: {err}")

    @unittest.skipIf(SKIP_ASR_TESTS, "ASR tests require large models (~14B parameters) and are skipped in CI")
    def test_load_asr_model(self):
        """
        This method tests if the ASR model is loaded successfully with asr_mode option
        """
        try:
            with grpc.insecure_channel(self.SERVER_ADDR) as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(
                    Model=self.ASR_MODEL,
                    Options=["asr_mode:true"]
                ))
                print(f"LoadModel response: {response}")
                if not response.success:
                    print(f"LoadModel failed with message: {response.message}")
                self.assertTrue(response.success, f"LoadModel failed: {response.message}")
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            print(f"Exception during LoadModel: {err}")
            import traceback
            traceback.print_exc()
            self.fail(f"LoadModel service failed for ASR mode: {err}")

    def test_tts(self):
        """
        This method tests if TTS generation works successfully
        """
        # Temporary directory for the output audio; removed after tearDown().
        temp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
        output_file = os.path.join(temp_dir, 'output.wav')

        try:
            with grpc.insecure_channel(self.SERVER_ADDR) as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                # Load TTS model
                response = stub.LoadModel(backend_pb2.ModelOptions(Model=self.TTS_MODEL))
                self.assertTrue(response.success, f"LoadModel failed: {response.message}")

                # Generate TTS
                tts_request = backend_pb2.TTSRequest(
                    text="Hello, this is a test of the VibeVoice text to speech system.",
                    dst=output_file
                )
                tts_response = stub.TTS(tts_request)

                # Verify response
                self.assertIsNotNone(tts_response)
                self.assertTrue(tts_response.success)

                # Verify output file was created
                self.assertTrue(os.path.exists(output_file), f"Output file was not created: {output_file}")
                self.assertGreater(os.path.getsize(output_file), 0, "Output file is empty")
        except Exception as err:
            print(err)
            self.fail(f"TTS service failed: {err}")

    @unittest.skipIf(SKIP_ASR_TESTS, "ASR tests require large models (~14B parameters) and are skipped in CI")
    def test_audio_transcription(self):
        """
        This method tests if audio transcription works successfully
        """
        # Temporary directory for the input audio; removed after tearDown().
        temp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
        audio_file = os.path.join(temp_dir, 'audio.wav')

        # Done outside the gRPC try-block so a download problem is reported as
        # such rather than as a transcription failure.
        self._download(self.AUDIO_URL, audio_file)

        try:
            with grpc.insecure_channel(self.SERVER_ADDR) as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                # Load the ASR model first
                load_response = stub.LoadModel(backend_pb2.ModelOptions(
                    Model=self.ASR_MODEL,
                    Options=["asr_mode:true"]
                ))
                print(f"LoadModel response: {load_response}")
                if not load_response.success:
                    print(f"LoadModel failed with message: {load_response.message}")
                self.assertTrue(load_response.success, f"LoadModel failed: {load_response.message}")

                # Perform transcription
                transcript_request = backend_pb2.TranscriptRequest(dst=audio_file)
                transcript_response = stub.AudioTranscription(transcript_request)

                # Print the transcribed text for debugging
                print(f"Transcribed text: {transcript_response.text}")
                print(f"Number of segments: {len(transcript_response.segments)}")

                # The transcription must always carry non-empty text, with or
                # without segments.
                self.assertIsNotNone(transcript_response)
                self.assertGreater(len(transcript_response.text), 0, "Transcription should not be empty")

                # When segments are returned, spot-check the first one's shape.
                if len(transcript_response.segments) > 0:
                    segment = transcript_response.segments[0]
                    self.assertIsNotNone(segment.text)
                    self.assertIsInstance(segment.id, int)
        except Exception as err:
            print(err)
            self.fail(f"AudioTranscription service failed: {err}")