File size: 3,392 Bytes
7e6ee13
 
 
f6f0443
7e6ee13
d9716fe
7e6ee13
 
 
c432d05
d9716fe
c432d05
7e6ee13
f6f0443
7e6ee13
 
f6f0443
7e6ee13
0e6aaa2
f6f0443
3b2c590
a2e218b
d483b19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e6ee13
 
d483b19
 
 
 
 
c77ae43
d483b19
 
c77ae43
 
 
 
d483b19
 
 
c77ae43
d483b19
 
 
 
c77ae43
7e6ee13
d483b19
eae360d
a31a724
7e6ee13
a2e218b
c77ae43
 
a2e218b
c77ae43
d483b19
c77ae43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ace071
7e6ee13
 
 
 
c7d8eef
5ace071
c7d8eef
7e6ee13
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
import torch
from transformers import BarkModel
from optimum.bettertransformer import BetterTransformer

# Pick the device first so the weight precision can follow it.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
use_cuda = device.startswith("cuda")

# Half precision saves GPU memory; on CPU many float16 kernels are missing or
# slow, so keep the default float32 weights there.
model = BarkModel.from_pretrained(
    "suno/bark",
    torch_dtype=torch.float16 if use_cuda else torch.float32,
)
model = model.to(device)

from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained("suno/bark")

# Use BetterTransformer so attention runs through PyTorch's fused
# scaled-dot-product kernels (flash attention on supported GPUs).
model = BetterTransformer.transform(model, keep_original_model=False)

# CPU offload keeps sub-models in CPU RAM and moves each one to the CUDA
# device only while it runs; it targets a CUDA device, so only enable it
# when one exists.
if use_cuda:
    model.enable_cpu_offload()

import numpy as np
from scipy.io.wavfile import write as write_wav
import wave

def split_text_into_sentences(text, *, terminators=(".", "!", "?")):
    """Split *text* into sentence-sized chunks for speech synthesis.

    The text is split on runs of whitespace and the words of each chunk are
    re-joined with single spaces, so whitespace is normalized. A chunk ends
    at a word whose last character, ignoring trailing closing quotes or
    brackets, is one of *terminators*. Any trailing words without a
    terminator form a final chunk.

    Args:
        text: Input text.
        terminators: Tuple of strings that end a sentence, e.g. ``(".",)``
            to split on periods only.

    Returns:
        List of non-empty sentence strings in input order; ``[]`` for empty
        or whitespace-only input.

    >>> split_text_into_sentences('Hi there. How are you? "Great!" Bye')
    ['Hi there.', 'How are you?', '"Great!"', 'Bye']
    """
    sentences = []
    current_words = []

    for word in text.split():
        current_words.append(word)
        # Look past closing quotes/brackets so `"stop."` or `(see above.)`
        # still end a sentence.
        if word.rstrip("\"')]").endswith(terminators):
            sentences.append(" ".join(current_words))
            current_words = []

    # Words after the last terminator still need to be spoken.
    if current_words:
        sentences.append(" ".join(current_words))

    return sentences

def join_wav_files(input_files, output_file):
    """Concatenate WAV files end to end into *output_file*.

    The output takes its header parameters from the first input. Every
    other input must share its channel count, sample width, frame rate and
    compression type. Otherwise its raw frames would be reinterpreted in
    the wrong format and produce noise.

    Args:
        input_files: Sequence of WAV file paths, in playback order.
        output_file: Path of the WAV file to write (overwritten if present).

    Raises:
        ValueError: if *input_files* is empty or the inputs' sample formats
            differ.
    """
    input_files = list(input_files)
    if not input_files:
        raise ValueError("join_wav_files() needs at least one input file")

    def _sample_format(params):
        # nframes legitimately differs per file; compare only the format.
        return (params.nchannels, params.sampwidth, params.framerate, params.comptype)

    with wave.open(input_files[0], "rb") as first_file:
        params = first_file.getparams()
    expected = _sample_format(params)

    # Validate everything before creating the output so a mismatch does not
    # leave a half-written file behind.
    for path in input_files[1:]:
        with wave.open(path, "rb") as src:
            found = _sample_format(src.getparams())
        if found != expected:
            raise ValueError(
                f"{path!r} has format {found}, but {input_files[0]!r} has {expected}"
            )

    with wave.open(output_file, "wb") as output:
        # nframes copied from the first file is only a provisional header
        # value; the wave module patches it with the real count on close.
        output.setparams(params)
        for path in input_files:
            with wave.open(path, "rb") as src:
                output.writeframes(src.readframes(src.getnframes()))


def infer(text_prompt):
    """Synthesize *text_prompt* into a single WAV file and return its path.

    The text is split into sentence chunks with split_text_into_sentences().
    All chunks are rendered by generate() into per-chunk WAV files. The
    pieces are then concatenated into 'full_story.wav' with join_wav_files().

    Args:
        text_prompt: Text entered in the Gradio textbox.

    Returns:
        Path of the joined WAV file, used as the gr.Audio output.

    Raises:
        gr.Error: if the prompt is empty or whitespace only; Gradio shows
            the message to the user.
    """
    print("""

    Cutting text in chunks

    """)

    # `or ""` also tolerates a None value from the textbox.
    text_chunks = split_text_into_sentences(text_prompt or "")
    if not text_chunks:
        # Without this guard an empty prompt fails deep inside the processor
        # or join_wav_files() with an unreadable error.
        raise gr.Error("Please enter some text to synthesize.")

    result = generate(text_chunks, "wav")
    print(result)

    output_wav = 'full_story.wav'
    join_wav_files(result, output_wav)

    return output_wav


def generate(text_prompt, out_type):
    """Run Bark on the prompt(s) and return one result per generated clip.

    Args:
        text_prompt: A string or a list of strings, passed to the processor
            in one call (a list is generated as a batch).
        out_type: "numpy" returns a list of ``(sampling_rate, audio_array)``
            tuples. "wav" writes ``output_0.wav``, ``output_1.wav``, ...
            in the current directory and returns their paths.

    Returns:
        List of results as described above, in batch order.

    Raises:
        ValueError: if *out_type* is neither "numpy" nor "wav".
    """
    # Fail fast instead of running the model and silently returning [].
    if out_type not in ("numpy", "wav"):
        raise ValueError(f"out_type must be 'numpy' or 'wav', got {out_type!r}")

    inputs = processor(text_prompt, voice_preset="v2/en_speaker_6").to(device)

    with torch.inference_mode():
        speech_output = model.generate(**inputs)

    # The rate is a model setting, identical for every clip; read it once.
    sampling_rate = model.generation_config.sample_rate
    print(f'sampling_rate: {sampling_rate}')

    # NOTE(review): prompts of different lengths are generated as one padded
    # batch; shorter clips may carry trailing padding audio. Consider
    # per-sentence generation if that is audible. TODO confirm.
    input_waves = []

    for i, speech_out in enumerate(speech_output):
        audio_array = speech_out.cpu().numpy().squeeze()
        print(f'AUDIO_ARRAY: {audio_array}')

        if out_type == "numpy":
            # list.append takes one object, so the pair must be a tuple.
            input_waves.append((sampling_rate, audio_array))
        else:
            # The array may be float16 (the model can be loaded in half
            # precision) and is not guaranteed to lie in [-1, 1]. Widen and
            # clip before scaling so peaks saturate instead of wrapping.
            clipped = np.clip(audio_array.astype(np.float32), -1.0, 1.0)
            audio_data = (clipped * 32767).astype(np.int16)
            wav_path = f"output_{i}.wav"
            write_wav(wav_path, sampling_rate, audio_data)
            input_waves.append(wav_path)

    return input_waves


# Single-page UI: a text prompt in, the synthesized story audio out.
with gr.Blocks() as demo:
    with gr.Column():
        prompt = gr.Textbox(label="prompt")
        submit_btn = gr.Button("Submit")
        audio_out = gr.Audio()
    submit_btn.click(fn=infer, inputs=[prompt], outputs=[audio_out])

if __name__ == "__main__":
    # Generating a multi-sentence story can take a long time; running events
    # through the queue keeps long requests from hitting HTTP timeouts.
    demo.queue().launch()