Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
"""
Test script for the enhanced Chatterbox TTS Modal API
This script demonstrates how to interact with all the new endpoints
"""
import requests
import base64
import json
import os
from pathlib import Path
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# Base URLs for the deployed endpoints.  A value is None when its
# environment variable is unset; each test skips or warns on missing URLs.
# The full-text entries also accept the legacy variable names that the
# generated sample .env used (FULL_TEXT_TTS_ENDPOINT /
# FULL_TEXT_JSON_ENDPOINT), so older .env files keep working.
ENDPOINTS = {
    "health": os.getenv("HEALTH_ENDPOINT"),
    "generate_audio": os.getenv("GENERATE_AUDIO_ENDPOINT"),
    "generate_json": os.getenv("GENERATE_JSON_ENDPOINT"),
    "generate_with_file": os.getenv("GENERATE_WITH_FILE_ENDPOINT"),
    "generate": os.getenv("GENERATE_ENDPOINT"),
    "generate_full_text_audio": os.getenv("GENERATE_FULL_TEXT_AUDIO_ENDPOINT")
        or os.getenv("FULL_TEXT_TTS_ENDPOINT"),
    "generate_full_text_json": os.getenv("GENERATE_FULL_TEXT_JSON_ENDPOINT")
        or os.getenv("FULL_TEXT_JSON_ENDPOINT"),
}
def test_health_check():
    """Hit the health endpoint and report whether it answered 200 OK."""
    print("Testing health check...")
    try:
        resp = requests.get(ENDPOINTS["health"])
        print(f"Status: {resp.status_code}")
        print(f"Response: {resp.json()}")
        healthy = resp.status_code == 200
    except Exception as exc:
        print(f"Health check failed: {exc}")
        return False
    return healthy
def test_basic_generation():
    """Exercise plain text-to-speech through the generate_audio endpoint."""
    print("\nTesting basic audio generation...")
    payload = {"text": "Hello, this is Chatterbox TTS running on Modal!"}
    try:
        resp = requests.post(ENDPOINTS["generate_audio"], json=payload)
        if resp.status_code != 200:
            print(f"β Basic generation failed: {resp.status_code}")
            print(f"Response: {resp.text}")
            return False
        Path("output").mkdir(exist_ok=True)
        with open("output/basic_output.wav", "wb") as f:
            f.write(resp.content)
        print("β Basic generation successful - saved as output/basic_output.wav")
        return True
    except Exception as e:
        print(f"β Basic generation error: {e}")
        return False
def test_json_generation():
    """Request a JSON response and decode its embedded base64 WAV data."""
    print("\nTesting JSON audio generation...")
    try:
        resp = requests.post(
            ENDPOINTS["generate_json"],
            json={"text": "This returns JSON with base64 audio data"},
        )
        if resp.status_code != 200:
            print(f"β JSON generation failed: {resp.status_code}")
            print(f"Response: {resp.text}")
            return False
        payload = resp.json()
        if not (payload['success'] and payload['audio_base64']):
            print(f"β JSON generation failed: {payload['message']}")
            return False
        # Decode base64 audio and save
        Path("output").mkdir(exist_ok=True)
        with open("output/json_output.wav", "wb") as f:
            f.write(base64.b64decode(payload['audio_base64']))
        print(f"β JSON generation successful - Duration: {payload['duration_seconds']:.2f}s")
        print(" Saved as output/json_output.wav")
        return True
    except Exception as e:
        print(f"β JSON generation error: {e}")
        return False
def test_voice_cloning():
    """Clone a voice from voice_sample.wav when that file is available."""
    print("\nTesting voice cloning...")
    sample_file = Path("voice_sample.wav")
    # Without a sample to clone from, the test is a no-op success.
    if not sample_file.exists():
        print("β No voice_sample.wav found - skipping voice cloning test")
        print(" To test voice cloning, add a voice_sample.wav file")
        return True
    try:
        # Ship the voice sample as base64 alongside the text prompt.
        encoded_voice = base64.b64encode(sample_file.read_bytes()).decode()
        resp = requests.post(
            ENDPOINTS["generate_audio"],
            json={
                "text": "This should sound like the provided voice sample!",
                "voice_prompt_base64": encoded_voice,
            },
        )
        if resp.status_code != 200:
            print(f"β Voice cloning failed: {resp.status_code}")
            print(f"Response: {resp.text}")
            return False
        Path("output").mkdir(exist_ok=True)
        with open("output/cloned_output.wav", "wb") as f:
            f.write(resp.content)
        print("β Voice cloning successful - saved as output/cloned_output.wav")
        return True
    except Exception as e:
        print(f"β Voice cloning error: {e}")
        return False
def test_file_upload():
    """Exercise the multipart file-upload endpoint.

    Uploads voice_sample.wav as the voice prompt when it exists,
    otherwise posts text only.

    Fix: the opened file handle was only closed on the success path, so
    an exception from requests.post() leaked it.  The handle is now
    released in a finally block on every path.
    """
    print("\nTesting file upload...")
    sample_file = Path("voice_sample.wav")
    if not sample_file.exists():
        print("β No voice_sample.wav found - testing without voice prompt")
        files = None
    else:
        files = {"voice_prompt": open(sample_file, "rb")}
    try:
        data = {"text": "Testing the file upload endpoint!"}
        response = requests.post(ENDPOINTS["generate_with_file"], data=data, files=files)
        if response.status_code == 200:
            Path("output").mkdir(exist_ok=True)
            with open("output/upload_output.wav", "wb") as f:
                f.write(response.content)
            print("β File upload successful - saved as output/upload_output.wav")
            return True
        else:
            print(f"β File upload failed: {response.status_code}")
            print(f"Response: {response.text}")
            return False
    except Exception as e:
        print(f"β File upload error: {e}")
        return False
    finally:
        # Always release the sample-file handle, even when the request raises.
        if files:
            files["voice_prompt"].close()
def test_legacy_endpoint():
    """Confirm the legacy /generate endpoint still accepts query-string prompts."""
    print("\nTesting legacy endpoint...")
    try:
        # The legacy route reads its prompt from query parameters, not form data.
        resp = requests.post(
            ENDPOINTS["generate"],
            params={"prompt": "Testing the legacy endpoint for backward compatibility"},
        )
        if resp.status_code != 200:
            print(f"β Legacy endpoint failed: {resp.status_code}")
            print(f"Response: {resp.text}")
            return False
        Path("output").mkdir(exist_ok=True)
        with open("output/legacy_output.wav", "wb") as f:
            f.write(resp.content)
        print("β Legacy endpoint successful - saved as output/legacy_output.wav")
        return True
    except Exception as e:
        print(f"β Legacy endpoint error: {e}")
        return False
def test_full_text_generation():
    """Test full-text audio generation with server-side chunking.

    Sends a deliberately long text so the server must chunk it, saves the
    concatenated WAV, and reports the processing metadata returned in the
    response headers.

    Fixes: the fallback character count used len(long_text) although the
    request sends long_text.strip(), over-counting surrounding whitespace;
    the skip message now names the environment variable the ENDPOINTS
    table actually reads.
    """
    print("\nTesting full-text audio generation...")
    # Create a long text that will require chunking
    long_text = """
This is a comprehensive test of the full-text audio generation endpoint.
The text is intentionally long to demonstrate the server-side chunking capabilities.
The enhanced API will automatically split this text into appropriate chunks,
process them in parallel using GPU acceleration, and then concatenate the
resulting audio segments with proper transitions and fade effects.
This approach significantly improves performance for long documents while
maintaining high audio quality and natural speech flow. The server handles
all the complex processing, allowing the client to simply send the full text
and receive the final audio file.
The chunking algorithm respects sentence and paragraph boundaries to ensure
natural speech patterns and maintains proper context across chunk boundaries.
This results in more natural-sounding speech for long-form content.
"""
    # The request sends the stripped text; report counts for the same string.
    text_to_send = long_text.strip()
    try:
        if not ENDPOINTS["generate_full_text_audio"]:
            print("β GENERATE_FULL_TEXT_AUDIO_ENDPOINT not configured - skipping full-text test")
            return True
        response = requests.post(
            ENDPOINTS["generate_full_text_audio"],
            json={
                "text": text_to_send,
                "max_chunk_size": 400,  # Smaller chunks for testing
                "silence_duration": 0.3,
                "fade_duration": 0.1,
                "overlap_sentences": 0
            },
            timeout=120  # Longer timeout for processing
        )
        if response.status_code == 200:
            Path("output").mkdir(exist_ok=True)
            with open("output/full_text_output.wav", "wb") as f:
                f.write(response.content)
            # Check response headers for processing info
            duration = response.headers.get('X-Audio-Duration', 'unknown')
            chunks = response.headers.get('X-Chunks-Processed', 'unknown')
            characters = response.headers.get('X-Total-Characters', len(text_to_send))
            print(f"β Full-text generation successful")
            print(f" Duration: {duration}s")
            print(f" Chunks processed: {chunks}")
            print(f" Characters: {characters}")
            print(" Saved as output/full_text_output.wav")
            return True
        else:
            print(f"β Full-text generation failed: {response.status_code}")
            print(f"Response: {response.text}")
            return False
    except requests.exceptions.Timeout:
        print("β Full-text generation timed out (this may be normal for very long texts)")
        return False
    except Exception as e:
        print(f"β Full-text generation error: {e}")
        return False
def test_full_text_json():
    """Test full-text JSON response with processing information.

    Fixes: (1) avg_chunk_size was formatted with :.0f even when the key
    was absent and the 'unknown' fallback string was substituted, which
    raised ValueError and made the test fail spuriously — numeric
    formatting is now applied only to numeric values; (2) the skip
    message names the environment variable the ENDPOINTS table actually
    reads.
    """
    print("\nTesting full-text JSON response...")
    test_text = """
This is a test of the full-text JSON endpoint that returns detailed
processing information along with the base64 encoded audio data.
The response includes chunk information, processing parameters,
and timing details that can be useful for monitoring and debugging.
"""
    try:
        if not ENDPOINTS["generate_full_text_json"]:
            print("β GENERATE_FULL_TEXT_JSON_ENDPOINT not configured - skipping test")
            return True
        response = requests.post(
            ENDPOINTS["generate_full_text_json"],
            json={
                "text": test_text.strip(),
                "max_chunk_size": 300,
                "silence_duration": 0.4,
                "fade_duration": 0.15
            },
            timeout=60
        )
        if response.status_code == 200:
            data = response.json()
            if data['success'] and data['audio_base64']:
                # Decode and save audio
                Path("output").mkdir(exist_ok=True)
                audio_data = base64.b64decode(data['audio_base64'])
                with open("output/full_text_json_output.wav", "wb") as f:
                    f.write(audio_data)
                # Display processing information
                print(f"β Full-text JSON generation successful")
                print(f" Duration: {data['duration_seconds']:.2f}s")
                if 'processing_info' in data:
                    info = data['processing_info']
                    if 'chunk_info' in info:
                        chunk_info = info['chunk_info']
                        print(f" Chunks: {chunk_info.get('total_chunks', 'unknown')}")
                        print(f" Characters: {chunk_info.get('total_characters', 'unknown')}")
                        # Only numeric values may take the :.0f format spec.
                        avg_size = chunk_info.get('avg_chunk_size')
                        if isinstance(avg_size, (int, float)):
                            print(f" Avg chunk size: {avg_size:.0f}")
                        else:
                            print(" Avg chunk size: unknown")
                print(" Saved as output/full_text_json_output.wav")
                return True
            else:
                print(f"β Full-text JSON generation failed: {data['message']}")
                return False
        else:
            print(f"β Full-text JSON generation failed: {response.status_code}")
            print(f"Response: {response.text}")
            return False
    except Exception as e:
        print(f"β Full-text JSON generation error: {e}")
        return False
def test_performance_comparison():
    """Time the standard endpoint against the chunking full-text endpoint."""
    print("\nTesting performance comparison...")
    # Short text for standard endpoint
    short_text = "This is a short text for performance comparison testing."
    # Medium text that benefits from chunking
    medium_text = """
This is a medium-length text designed to test the performance differences
between the standard endpoint and the enhanced full-text endpoint.
The full-text endpoint should show its advantages when processing longer
texts that require intelligent chunking and parallel processing.
This text is long enough to require multiple chunks but not so long
that it becomes unwieldy for testing purposes.
"""
    results = {}
    try:
        import time
        # Time the standard endpoint on the short text.
        started = time.time()
        resp = requests.post(
            ENDPOINTS["generate_audio"],
            json={"text": short_text},
            timeout=30,
        )
        if resp.status_code == 200:
            results['standard_short'] = time.time() - started
            print(f"β Standard endpoint (short): {results['standard_short']:.2f}s")
        # Time the full-text endpoint on the medium text, if configured.
        if ENDPOINTS["generate_full_text_audio"]:
            started = time.time()
            resp = requests.post(
                ENDPOINTS["generate_full_text_audio"],
                json={"text": medium_text.strip(), "max_chunk_size": 300},
                timeout=60,
            )
            if resp.status_code == 200:
                results['fulltext_medium'] = time.time() - started
                chunks = resp.headers.get('X-Chunks-Processed', 'unknown')
                print(f"β Full-text endpoint (medium, {chunks} chunks): {results['fulltext_medium']:.2f}s")
        if results:
            print(" Performance comparison complete!")
            return True
        print(" Could not complete performance comparison")
        return False
    except Exception as e:
        print(f"β Performance comparison error: {e}")
        return False
def main():
    """Run every API test and print a pass/fail summary.

    Fix: test_full_text_json was defined but never included in the test
    list, so the full-text JSON endpoint was silently untested; it is
    now run with the rest of the suite.
    """
    print("Enhanced Chatterbox TTS API Test Suite")
    print("=" * 50)
    # Check if required endpoints are configured
    missing_endpoints = [name for name, url in ENDPOINTS.items() if not url]
    if missing_endpoints:
        print("β Warning: Some endpoints not configured:")
        for endpoint in missing_endpoints:
            print(f" {endpoint}")
        print(" Set environment variables in .env file")
        print()
    tests = [
        test_health_check,
        test_basic_generation,
        test_json_generation,
        test_voice_cloning,
        test_file_upload,
        test_legacy_endpoint,
        test_full_text_generation,
        test_full_text_json,
        test_performance_comparison
    ]
    results = [test() for test in tests]
    print("\n" + "=" * 50)
    print("Test Results:")
    passed = sum(results)
    total = len(results)
    print(f"β {passed}/{total} tests passed")
    if passed == total:
        print("π All tests passed!")
        print("\nGenerated files in output/ directory:")
        output_dir = Path("output")
        if output_dir.exists():
            for file in output_dir.glob("*.wav"):
                size_kb = file.stat().st_size / 1024
                print(f" {file.name} ({size_kb:.1f} KB)")
    else:
        print("β Some tests failed - check your Modal deployment")
    print(f"\nAPI Endpoints tested:")
    for name, url in ENDPOINTS.items():
        status = "β" if url else "β"
        print(f" {status} {name}: {url or 'Not configured'}")
def create_sample_env_file():
    """Create a sample .env file with endpoint placeholders.

    Does nothing when a .env already exists.

    Fix: the template's full-text variable names now match what the
    ENDPOINTS table actually reads (GENERATE_FULL_TEXT_AUDIO_ENDPOINT /
    GENERATE_FULL_TEXT_JSON_ENDPOINT); the previous FULL_TEXT_TTS_ENDPOINT
    and FULL_TEXT_JSON_ENDPOINT names were never looked up, so the
    full-text endpoints configured from this template were silently
    ignored.
    """
    env_content = """# Enhanced Chatterbox TTS API Endpoints
# Replace YOUR-MODAL-ENDPOINT with your actual Modal deployment URL
HEALTH_ENDPOINT=https://YOUR-MODAL-ENDPOINT.modal.run/health
GENERATE_AUDIO_ENDPOINT=https://YOUR-MODAL-ENDPOINT.modal.run/generate_audio
GENERATE_JSON_ENDPOINT=https://YOUR-MODAL-ENDPOINT.modal.run/generate_json
GENERATE_WITH_FILE_ENDPOINT=https://YOUR-MODAL-ENDPOINT.modal.run/generate_with_file
GENERATE_ENDPOINT=https://YOUR-MODAL-ENDPOINT.modal.run/generate
# New enhanced endpoints
GENERATE_FULL_TEXT_AUDIO_ENDPOINT=https://YOUR-MODAL-ENDPOINT.modal.run/generate_full_text_audio
GENERATE_FULL_TEXT_JSON_ENDPOINT=https://YOUR-MODAL-ENDPOINT.modal.run/generate_full_text_json
"""
    if not Path(".env").exists():
        with open(".env", "w") as f:
            f.write(env_content)
        print("Created sample .env file - please update with your actual endpoints")
| if __name__ == "__main__": | |
| # Create sample .env if it doesn't exist | |
| create_sample_env_file() | |
| main() | |