# Spaces: Runtime error
import base64
import os
import secrets
import warnings
from io import BytesIO

import gradio as gr
import numpy as np
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from pydub import AudioSegment
from transformers import AutoFeatureExtractor, AutoTokenizer, set_seed
# Run on the first GPU when one is available, otherwise on CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the TTS model, its tokenizer and its feature extractor from the
# Hugging Face Hub.
repo_id = "parler-tts/parler_tts_mini_v0.1"
model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
# Sample rate used when wrapping the generated waveform as audio; taken from
# the checkpoint's feature-extractor config.
SAMPLE_RATE = feature_extractor.sampling_rate
# Seed passed to transformers.set_seed before each generation so that
# sampling is reproducible for identical inputs.
SEED = 42

# Shared secret that callers must supply for a request to be served.
_DEFAULT_SECRET_TOKEN = 'default_secret'
SECRET_TOKEN = os.getenv('SECRET_TOKEN', _DEFAULT_SECRET_TOKEN)
if SECRET_TOKEN == _DEFAULT_SECRET_TOKEN:
    # The fallback value is visible in this source file, so relying on it
    # means anyone can call the endpoint. The fallback is kept for backward
    # compatibility, but the misconfiguration is made loud.
    warnings.warn(
        "SECRET_TOKEN environment variable is not set; falling back to the "
        "publicly known default token. Set SECRET_TOKEN to secure this app.",
        stacklevel=1,
    )
def generate_audio(secret_token, text, description):
    """Synthesize speech for ``text`` and return it as an MP3 data URI.

    Args:
        secret_token: Token supplied by the caller; must equal ``SECRET_TOKEN``.
        text: The words to speak (passed to the model as ``prompt_input_ids``).
        description: Free-text description passed to the model as
            ``input_ids``.

    Returns:
        A string of the form ``"data:audio/mp3;base64,<payload>"``.

    Raises:
        gr.Error: If the token does not match, or ``text`` is empty or
            whitespace-only.
    """
    # Token-based access. compare_digest runs in constant time, so response
    # timing does not reveal how much of the token matched. Compare as bytes
    # because compare_digest rejects non-ASCII str arguments.
    if not isinstance(secret_token, str) or not secrets.compare_digest(
        secret_token.encode("utf-8"), SECRET_TOKEN.encode("utf-8")
    ):
        raise gr.Error('Invalid secret token.')
    # Reject empty prompts up front rather than spending a generation pass.
    if not text or not text.strip():
        raise gr.Error('Text input must not be empty.')

    # Tokenize: the description goes in as input_ids, the text to speak as
    # prompt_input_ids.
    description_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)

    # Re-seed before every call so sampling is reproducible for equal inputs.
    set_seed(SEED)
    generation = model.generate(
        input_ids=description_ids,
        prompt_input_ids=prompt_ids,
        do_sample=True,
        temperature=1.0,
    )

    # Float waveform -> 16-bit PCM. Clip to [-1, 1] first: converting
    # out-of-range floats to int16 is undefined in NumPy (it typically wraps
    # around, producing loud clicks) instead of saturating. The .float()
    # guards against half-precision tensors, which .numpy() cannot handle
    # for bfloat16.
    audio_arr = generation.cpu().float().numpy().squeeze()
    samples = (np.clip(audio_arr, -1.0, 1.0) * (2**15 - 1)).astype(np.int16)

    # Wrap the mono PCM samples as an audio segment.
    sound = AudioSegment(
        samples.tobytes(),
        frame_rate=SAMPLE_RATE,
        sample_width=samples.dtype.itemsize,
        channels=1,
    )

    # Encode to MP3 in memory and return it as a self-contained data URI.
    with BytesIO() as buff_mp3:
        sound.export(buff_mp3, format="mp3")
        mp3_bytes = buff_mp3.getvalue()
    audio_base64 = base64.b64encode(mp3_bytes).decode('utf-8')
    return 'data:audio/mp3;base64,' + audio_base64
# Gradio interface
with gr.Blocks() as app:
    gr.HTML("""
    <div style="text-align: center;">
        <h3>TTS Audio Generator</h3>
    </div>
    """)
    # Mask the token so it is not shown on screen while typed.
    secret_token = gr.Textbox(label="Secret Token", type="password")
    input_text = gr.Textbox(label="Text Input")
    description = gr.Textbox(label="Description")
    run_button = gr.Button("Generate Audio")
    # generate_audio returns a "data:audio/mp3;base64,..." string.
    # gr.Audio accepts only type="numpy" or type="filepath". It raised at
    # startup with type="auto", and it does not accept a data-URI string as
    # a value. A Textbox carries the URI unchanged to UI and API callers.
    audio_output = gr.Textbox(label="Generated Audio (base64 data URI)")
    inputs = [secret_token, input_text, description]
    outputs = [audio_output]
    run_button.click(fn=generate_audio, inputs=inputs, outputs=outputs, queue=True)

app.queue()
app.launch()