"""Gradio demo: long-form text-to-speech with suno/bark.

Splits a text prompt into sentences, synthesizes each sentence with the
Bark model, writes one WAV file per sentence, concatenates them into
``full_story.wav``, and returns that file to the Gradio audio widget.
"""

import wave

import numpy as np
import torch
import gradio as gr
from scipy.io.wavfile import write as write_wav
from transformers import AutoProcessor, BarkModel
from optimum.bettertransformer import BetterTransformer

device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = BarkModel.from_pretrained("suno/bark", torch_dtype=torch.float16)
model = model.to(device)
processor = AutoProcessor.from_pretrained("suno/bark")

# Use BetterTransformer for flash attention.
model = BetterTransformer.transform(model, keep_original_model=False)
# Enable CPU offload (moves idle Bark sub-models off the GPU between steps).
model.enable_cpu_offload()


def split_text_into_sentences(text):
    """Split *text* into sentences at word-final '.', '!' or '?'.

    A trailing fragment with no terminal punctuation is kept as its own
    sentence. Returns a list of whitespace-normalized sentence strings.
    """
    sentences = []
    current_words = []
    for word in text.split():
        current_words.append(word)
        # Original only handled '.'; '!' and '?' also end sentences.
        if word.endswith(('.', '!', '?')):
            sentences.append(' '.join(current_words))
            current_words = []
    if current_words:  # keep any unterminated tail
        sentences.append(' '.join(current_words))
    return sentences


def join_wav_files(input_files, output_file):
    """Concatenate the WAV files *input_files* into *output_file*.

    Assumes every input shares the audio parameters (channels, sample
    width, frame rate) of the first file; frames are copied verbatim.
    """
    # Take the output parameters from the first input file.
    with wave.open(input_files[0], 'rb') as first_file:
        params = first_file.getparams()
    with wave.open(output_file, 'wb') as output:
        output.setparams(params)
        for path in input_files:
            # NOTE: renamed from `input` — don't shadow the builtin.
            with wave.open(path, 'rb') as part:
                output.writeframes(part.readframes(part.getnframes()))


def infer(text_prompt):
    """Gradio callback: synthesize *text_prompt* and return the joined WAV path."""
    print(""" — Cutting text in chunks — """)
    text_chunks = split_text_into_sentences(text_prompt)
    wav_files = generate(text_chunks, "wav")
    print(wav_files)
    output_wav = 'full_story.wav'
    join_wav_files(wav_files, output_wav)
    return output_wav


def generate(text_prompt, out_type):
    """Run Bark over *text_prompt* (a string or list of sentence strings).

    out_type == "wav":   write ``output_{i}.wav`` per sentence, return the paths.
    out_type == "numpy": return ``(sampling_rate, audio_array)`` tuples.
    """
    inputs = processor(text_prompt, voice_preset="v2/en_speaker_6").to(device)
    with torch.inference_mode():
        speech_output = model.generate(**inputs)

    # Loop-invariant: hoisted out of the per-sentence loop.
    sampling_rate = model.generation_config.sample_rate
    results = []
    for i, speech_out in enumerate(speech_output):
        audio_array = speech_out.cpu().numpy().squeeze()
        print(f'AUDIO_ARRAY: {audio_array}')
        print(f'sampling_rate: {sampling_rate}')
        if out_type == "numpy":
            # BUG FIX: list.append takes one argument — the original
            # `append(sampling_rate, audio_array)` raised TypeError.
            results.append((sampling_rate, audio_array))
        elif out_type == "wav":
            # Scale float audio (assumed in [-1, 1]) to 16-bit signed PCM.
            audio_data = np.int16(audio_array * 32767)
            write_wav(f"output_{i}.wav", sampling_rate, audio_data)
            results.append(f"output_{i}.wav")
    return results


with gr.Blocks() as demo:
    with gr.Column():
        prompt = gr.Textbox(label="prompt")
        submit_btn = gr.Button("Submit")
        audio_out = gr.Audio()
        submit_btn.click(fn=infer, inputs=[prompt], outputs=[audio_out])

demo.launch()