# learnable-speech/speech/tools/S3Tokenizer/test/test_batch_efficiency.py
# (web file-view residue: author "Ubuntu", commit "update tokenizer" 24d0b1d,
#  file size 9.46 kB)
#!/usr/bin/env python3
"""
Batch processing efficiency test
Test the efficiency improvement of new batch processing functionality for mixed long and short audio
"""
import time
import torch
import pytest
import s3tokenizer
def create_test_audio(duration_seconds=20, sample_rate=16000):
    """Create a synthetic test clip: a 440 Hz tone, its 880 Hz harmonic and light noise.

    Args:
        duration_seconds: Clip length in seconds (may be fractional).
        sample_rate: Samples per second.

    Returns:
        1-D float tensor with ``int(duration_seconds * sample_rate)`` samples.
        The noise term draws from the global torch RNG, so seed it
        (``torch.manual_seed``) if reproducible audio is needed.
    """
    length = int(duration_seconds * sample_rate)
    # Sample times are k / sample_rate. torch.linspace(0, duration, length)
    # would include the end point, stretching the spacing to
    # duration / (length - 1) and slightly detuning the tones.
    t = torch.arange(length, dtype=torch.float32) / sample_rate
    audio = 0.5 * torch.sin(2 * torch.pi * 440 * t)  # 440 Hz fundamental
    audio += 0.3 * torch.sin(2 * torch.pi * 880 * t)  # 880 Hz second harmonic
    audio += 0.1 * torch.randn(length)  # add some noise
    return audio
@pytest.fixture
def test_audios():
    """Provide a mixed set of short, medium and long synthetic clips.

    The clips are generated in a fixed order so that every processing method
    under test receives exactly the same list.
    """
    # Durations in seconds; mixing lengths exercises padding in the batch path.
    clip_durations = (10, 20, 40, 60, 15, 35, 25, 50)
    return [create_test_audio(seconds) for seconds in clip_durations]
@pytest.fixture
def long_audios():
    """Provide synthetic clips that are all longer than 30 seconds."""
    # Includes fractional durations so clip lengths are not round numbers.
    long_durations = (45.5, 60, 91.2, 120)
    return [create_test_audio(seconds) for seconds in long_durations]
def _quantize_individually(model, audios):
    """Tokenize each clip on its own (batch of one); return one token list per clip."""
    results = []
    for audio in audios:
        mel = s3tokenizer.log_mel_spectrogram(audio)
        mels = mel.unsqueeze(0)  # batch of one
        mels_lens = torch.tensor([mel.size(1)])
        with torch.no_grad():
            codes, codes_lens = model.quantize(mels, mels_lens)
        results.append(codes[0, :codes_lens[0].item()].tolist())
    return results


def _quantize_batched(model, audios):
    """Tokenize all clips in one padded batch; return one token list per clip."""
    mels = [s3tokenizer.log_mel_spectrogram(audio) for audio in audios]
    # Pad the variable-length mels to a common length; padding() also returns
    # the per-clip lengths that quantize() needs.
    mels, mels_lens = s3tokenizer.padding(mels)
    with torch.no_grad():
        codes, codes_lens = model.quantize(mels, mels_lens)
    # Trim every row back to its own valid length so padded positions are dropped.
    return [codes[i, :codes_lens[i].item()].tolist() for i in range(len(audios))]


def _print_token_summary(audios, results, sample_rate=16000):
    """Print duration, token count and length class (>30 s is "long") per clip."""
    for i, (audio, tokens) in enumerate(zip(audios, results), start=1):
        duration = audio.shape[0] / sample_rate
        processing_type = "Long audio" if duration > 30 else "Short audio"
        print(
            f"Audio {i}: {duration:.1f}s, {len(tokens)} tokens, {processing_type}"
        )


def _tokens_match(individual_results, batch_results, max_miss_rate=0.2):
    """Compare per-clip token lists from the two methods.

    Returns True when every clip has the same token count in both results and
    a mismatch rate (in percent) strictly below ``max_miss_rate``.
    """
    all_ok = True
    pairs = zip(individual_results, batch_results)
    for i, (individual_tokens, batch_tokens) in enumerate(pairs, start=1):
        if len(individual_tokens) != len(batch_tokens):
            print(
                f"❌ Audio {i} length mismatch: individual={len(individual_tokens)}, batch={len(batch_tokens)}"
            )
            all_ok = False
            continue
        mismatches = sum(a != b for a, b in zip(individual_tokens, batch_tokens))
        miss_rate = (mismatches / len(individual_tokens) * 100
                     if individual_tokens else 0)
        if miss_rate < max_miss_rate:  # below 0.2 % by default is considered OK
            print(f"✅ Audio {i} miss rate: {miss_rate:.4f}% (OK)")
        else:
            print(f"❌ Audio {i} miss rate: {miss_rate:.4f}% (Too high)")
            all_ok = False
    return all_ok


@pytest.mark.parametrize("model_name", [
    "speech_tokenizer_v1_25hz", "speech_tokenizer_v1",
    "speech_tokenizer_v2_25hz"
])
def test_batch_efficiency(test_audios, model_name):
    """Compare one-by-one vs. batched tokenization for speed and token agreement.

    Timing uses ``time.perf_counter`` (monotonic, high resolution), and the
    per-clip reports are printed after each timed region so console I/O is not
    counted as processing time.
    """
    print(f"\n=== Batch Processing Efficiency Test for {model_name} ===")
    model = s3tokenizer.load_model(model_name)
    model.eval()

    # Method 1: individual processing.
    # NOTE(review): the first quantize() call may include one-time setup cost,
    # which is charged to this method; consider a warm-up call if that matters.
    print(f"\n--- Method 1: Individual Processing ({model_name}) ---")
    start_time = time.perf_counter()
    individual_results = _quantize_individually(model, test_audios)
    individual_time = time.perf_counter() - start_time
    _print_token_summary(test_audios, individual_results)
    print(f"Individual processing total time: {individual_time:.2f}s")

    # Method 2: batch processing (mel extraction + padding + one quantize call).
    print(f"\n--- Method 2: Batch Processing ({model_name}) ---")
    start_time = time.perf_counter()
    batch_results = _quantize_batched(model, test_audios)
    batch_time = time.perf_counter() - start_time
    _print_token_summary(test_audios, batch_results)
    print(f"Batch processing total time: {batch_time:.2f}s")

    # Verify result consistency.
    print(f"\n--- Result Verification for {model_name} ---")
    all_ok = _tokens_match(individual_results, batch_results)

    # Efficiency improvement (guard against a zero-length measurement).
    speedup = individual_time / batch_time if batch_time > 0 else float("inf")
    print(f"\n--- Efficiency Improvement for {model_name} ---")
    print(f"Batch processing speedup: {speedup:.2f}x")
    if speedup > 1:
        print("✅ Batch processing indeed improves efficiency!")
    else:
        print("⚠️ Batch processing doesn't significantly improve efficiency")

    # Assertions for pytest.
    assert all_ok, f"Results don't match for model {model_name}"
    assert len(individual_results) == len(
        batch_results), "Number of results don't match"
    assert all(
        len(individual_tokens) == len(batch_tokens)
        for individual_tokens, batch_tokens in zip(
            individual_results, batch_results)), "Token counts don't match"
    # NOTE: a wall-clock assertion (batch_time <= individual_time * 1.1) is
    # deliberately not enforced; timing depends on the execution environment.
@pytest.mark.parametrize("model_name", [
    "speech_tokenizer_v1_25hz", "speech_tokenizer_v1",
    "speech_tokenizer_v2_25hz"
])
def test_pure_long_audio_batch(long_audios, model_name):
    """Tokenize a batch made only of long clips and sanity-check the output.

    Only the ``model.quantize`` call is timed; mel extraction and padding
    happen before the clock starts.
    """
    print(f"\n=== Pure Long Audio Batch Processing Test for {model_name} ===")
    model = s3tokenizer.load_model(model_name)
    model.eval()

    # Prepare batch input: one mel per clip, padded to a common length.
    mels = [s3tokenizer.log_mel_spectrogram(audio) for audio in long_audios]
    mels, mels_lens = s3tokenizer.padding(mels)

    # perf_counter is monotonic and high-resolution; time.time() can be coarse
    # or jump with clock adjustments, which could make the `> 0` check flaky.
    start_time = time.perf_counter()
    with torch.no_grad():
        codes, codes_lens = model.quantize(mels, mels_lens)
    processing_time = time.perf_counter() - start_time
    print(
        f"Batch processing {len(long_audios)} long audios took: {processing_time:.2f}s"
    )

    results = []
    for i, audio in enumerate(long_audios):
        duration = audio.shape[0] / 16000  # test clips use a 16 kHz sample rate
        tokens_count = codes_lens[i].item()
        results.append((duration, tokens_count))
        print(f"Long audio {i+1}: {duration:.1f}s → {tokens_count} tokens")
    print(
        f"✅ Pure long audio batch processing test completed for {model_name}")

    # Assertions for pytest.
    assert codes is not None, f"Codes should not be None for model {model_name}"
    assert codes_lens is not None, f"Codes lengths should not be None for model {model_name}"
    assert len(results) == len(
        long_audios), "Number of results should match number of input audios"
    assert all(
        tokens_count > 0
        for _, tokens_count in results), "All audio should produce tokens"
    assert processing_time > 0, "Processing time should be positive"
@pytest.mark.parametrize(
    "model_name",
    ["speech_tokenizer_v1_25hz", "speech_tokenizer_v1", "speech_tokenizer_v2_25hz"],
)
def test_model_loading(model_name):
    """Every supported tokenizer checkpoint loads and can be switched to eval mode."""
    print(f"\n=== Model Loading Test for {model_name} ===")
    tokenizer = s3tokenizer.load_model(model_name)
    assert tokenizer is not None, f"Model {model_name} should load successfully"
    tokenizer.eval()  # switching to eval mode must not raise
    print(f"✅ Model {model_name} loaded and set to eval mode successfully")
@pytest.mark.parametrize(
    "model_name",
    ["speech_tokenizer_v1_25hz", "speech_tokenizer_v1", "speech_tokenizer_v2_25hz"],
)
def test_single_audio_processing(model_name):
    """Tokenize one 30-second clip as a batch of one and sanity-check the tokens."""
    print(f"\n=== Single Audio Processing Test for {model_name} ===")
    clip = create_test_audio(30)
    tokenizer = s3tokenizer.load_model(model_name)
    tokenizer.eval()

    mel = s3tokenizer.log_mel_spectrogram(clip)
    batch = mel.unsqueeze(0)  # add a leading batch dimension of size 1
    batch_lens = torch.tensor([mel.size(1)])
    with torch.no_grad():
        codes, codes_lens = tokenizer.quantize(batch, batch_lens)
    n_valid = codes_lens[0].item()
    tokens = codes[0, :n_valid].tolist()

    assert codes is not None, f"Codes should not be None for model {model_name}"
    assert codes_lens is not None, f"Codes lengths should not be None for model {model_name}"
    assert len(tokens) > 0, f"Should produce tokens for model {model_name}"
    assert n_valid == len(tokens), f"Codes length should match actual codes for model {model_name}"

    seconds = clip.shape[0] / 16000
    print(
        f"✅ Single audio processing test completed for {model_name}: {seconds:.1f}s → {len(tokens)} tokens"
    )
if __name__ == "__main__":
# Run tests with pytest
pytest.main([__file__, "-v"])