Spaces:
Running
Running
File size: 8,854 Bytes
750bbe6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
"""
A test script to test the gRPC service for VibeVoice TTS and ASR
"""
import unittest
import subprocess
import time
import os
import tempfile
import shutil
import backend_pb2
import backend_pb2_grpc
import grpc
# The ASR tests need very large models (~14B parameters in total), so they can
# be opted out of -- e.g. in CI -- by setting the SKIP_ASR_TESTS environment
# variable to a truthy value ("true", "1", "yes" or "on"; case-insensitive,
# surrounding whitespace ignored). Unset or any other value runs them.
SKIP_ASR_TESTS = os.environ.get("SKIP_ASR_TESTS", "false").strip().lower() in ("true", "1", "yes", "on")
class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service.

    unittest runs setUp() before and tearDown() after every test method, so
    each test talks to its own freshly started backend process. Test methods
    must not call setUp()/tearDown() themselves: that would spawn a second
    server on the same port and leak the first one.
    """

    # Address passed to backend.py via --addr and used by every client channel.
    SERVER_ADDRESS = "localhost:50051"
    # Upper bound (seconds) on how long setUp() waits for the backend to accept
    # gRPC connections; it returns as soon as the channel becomes ready.
    SERVER_STARTUP_TIMEOUT = 120
    # Seconds tearDown() waits after SIGTERM before falling back to SIGKILL.
    SERVER_SHUTDOWN_TIMEOUT = 30
    TTS_MODEL = "microsoft/VibeVoice-Realtime-0.5B"
    ASR_MODEL = "microsoft/VibeVoice-ASR"

    def setUp(self):
        """
        Start the backend server and block until it accepts gRPC connections.
        """
        self.service = subprocess.Popen(["python3", "backend.py", "--addr", self.SERVER_ADDRESS])
        try:
            self._wait_for_server()
        except BaseException:
            # unittest does not call tearDown() when setUp() raises, so stop the
            # process here rather than leaking it (and the port) into later tests.
            self.tearDown()
            raise

    def tearDown(self) -> None:
        """
        Stop the backend server, escalating to SIGKILL if SIGTERM is ignored.
        """
        self.service.terminate()
        try:
            self.service.wait(timeout=self.SERVER_SHUTDOWN_TIMEOUT)
        except subprocess.TimeoutExpired:
            self.service.kill()
            self.service.wait()

    def _wait_for_server(self):
        """
        Block until a gRPC channel to SERVER_ADDRESS reaches the READY state.

        Raises:
            RuntimeError: the server process exited before becoming ready.
            TimeoutError: the server was not ready within SERVER_STARTUP_TIMEOUT seconds.
        """
        deadline = time.monotonic() + self.SERVER_STARTUP_TIMEOUT
        # Cap gRPC's exponential reconnect backoff so that readiness is noticed
        # within about a second of the server starting to listen.
        options = [("grpc.max_reconnect_backoff_ms", 1000)]
        with grpc.insecure_channel(self.SERVER_ADDRESS, options=options) as channel:
            ready = grpc.channel_ready_future(channel)
            try:
                while True:
                    try:
                        ready.result(timeout=1)
                        return
                    except grpc.FutureTimeoutError:
                        pass
                    # Fail fast instead of waiting out the timeout if the server crashed.
                    if self.service.poll() is not None:
                        raise RuntimeError(
                            f"backend.py exited with code {self.service.returncode} "
                            "before accepting connections"
                        )
                    if time.monotonic() >= deadline:
                        raise TimeoutError(
                            f"backend.py did not accept connections on {self.SERVER_ADDRESS} "
                            f"within {self.SERVER_STARTUP_TIMEOUT} seconds"
                        )
            finally:
                # Unsubscribes from connectivity updates; a no-op once READY.
                ready.cancel()

    def test_server_startup(self):
        """
        This method tests if the server starts up successfully
        """
        # Only gRPC errors are converted into failures; assertion failures
        # propagate unchanged so their messages are not lost.
        try:
            with grpc.insecure_channel(self.SERVER_ADDRESS) as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
        except grpc.RpcError as err:
            self.fail(f"Server failed to start: {err}")
        self.assertEqual(response.message, b'OK')

    def test_load_tts_model(self):
        """
        This method tests if the TTS model is loaded successfully
        """
        try:
            with grpc.insecure_channel(self.SERVER_ADDRESS) as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model=self.TTS_MODEL))
        except grpc.RpcError as err:
            self.fail(f"LoadModel service failed: {err}")
        print(response)
        self.assertTrue(response.success, f"LoadModel failed: {response.message}")
        self.assertEqual(response.message, "Model loaded successfully")

    @unittest.skipIf(SKIP_ASR_TESTS, "ASR tests require large models (~14B parameters) and are skipped in CI")
    def test_load_asr_model(self):
        """
        This method tests if the ASR model is loaded successfully with asr_mode option
        """
        try:
            with grpc.insecure_channel(self.SERVER_ADDRESS) as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(
                    Model=self.ASR_MODEL,
                    Options=["asr_mode:true"],
                ))
        except grpc.RpcError as err:
            self.fail(f"LoadModel service failed for ASR mode: {err}")
        print(f"LoadModel response: {response}")
        self.assertTrue(response.success, f"LoadModel failed: {response.message}")
        self.assertEqual(response.message, "Model loaded successfully")

    def test_tts(self):
        """
        This method tests if TTS generation works successfully
        """
        # Temporary directory for the output audio file. Cleanups registered
        # with addCleanup run after tearDown(), i.e. once the server is stopped.
        temp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
        output_file = os.path.join(temp_dir, 'output.wav')
        try:
            with grpc.insecure_channel(self.SERVER_ADDRESS) as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                # Load TTS model
                response = stub.LoadModel(backend_pb2.ModelOptions(Model=self.TTS_MODEL))
                self.assertTrue(response.success, f"LoadModel failed: {response.message}")
                # Generate TTS
                tts_response = stub.TTS(backend_pb2.TTSRequest(
                    text="Hello, this is a test of the VibeVoice text to speech system.",
                    dst=output_file,
                ))
        except grpc.RpcError as err:
            self.fail(f"TTS service failed: {err}")
        self.assertTrue(tts_response.success, f"TTS failed: {tts_response.message}")
        # Verify output file was created
        self.assertTrue(os.path.exists(output_file), f"Output file was not created: {output_file}")
        self.assertGreater(os.path.getsize(output_file), 0, "Output file is empty")

    @unittest.skipIf(SKIP_ASR_TESTS, "ASR tests require large models (~14B parameters) and are skipped in CI")
    def test_audio_transcription(self):
        """
        This method tests if audio transcription works successfully
        """
        # Temporary directory for the input audio file; removed after tearDown().
        temp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
        audio_file = os.path.join(temp_dir, 'audio.wav')

        print(f"Downloading audio file to {audio_file}...")
        url = "https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav"
        result = subprocess.run(
            ["wget", "-q", url, "-O", audio_file],
            capture_output=True,
            text=True,
        )
        if result.returncode != 0:
            self.fail(f"Failed to download audio file: {result.stderr}")
        # Make sure something was actually written before handing it to the server.
        if not os.path.exists(audio_file) or os.path.getsize(audio_file) == 0:
            self.fail(f"Audio file was not downloaded to {audio_file}")

        try:
            with grpc.insecure_channel(self.SERVER_ADDRESS) as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                # Load the ASR model first
                load_response = stub.LoadModel(backend_pb2.ModelOptions(
                    Model=self.ASR_MODEL,
                    Options=["asr_mode:true"],
                ))
                print(f"LoadModel response: {load_response}")
                self.assertTrue(load_response.success, f"LoadModel failed: {load_response.message}")
                # Perform transcription
                transcript_response = stub.AudioTranscription(
                    backend_pb2.TranscriptRequest(dst=audio_file)
                )
        except grpc.RpcError as err:
            self.fail(f"AudioTranscription service failed: {err}")

        print(f"Transcribed text: {transcript_response.text}")
        print(f"Number of segments: {len(transcript_response.segments)}")
        self.assertGreater(len(transcript_response.text), 0, "Transcription should not be empty")
        # Segments are optional; when present, spot-check the first one's fields.
        if transcript_response.segments:
            segment = transcript_response.segments[0]
            self.assertIsInstance(segment.id, int)
            self.assertIsInstance(segment.text, str)