File size: 3,779 Bytes
fa2f0d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import atexit
import glob
import os
import time
from datetime import datetime, timedelta

import numpy as np
import torch
from apscheduler.schedulers.background import BackgroundScheduler
from scipy.io.wavfile import write as write_wav
from transformers import AutoProcessor, AutoModelForTextToWaveform, BarkModel

# Environment settings
# NOTE(review): SUNO_OFFLOAD_CPU / SUNO_USE_SMALL_MODELS appear to be read by the
# standalone `bark` package, which this file does not import; they likely have
# no effect on the transformers BarkModel used here (the small checkpoint and
# CPU offload are configured explicitly where the model is loaded). Confirm
# before removing, since child processes may inherit them.
os.environ["SUNO_OFFLOAD_CPU"] = "True"
os.environ["SUNO_USE_SMALL_MODELS"] = "True"

# Create output directory if it doesn't exist
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "output")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Create the Hugging Face cache directory if it doesn't exist.
# expanduser is required: neither os.environ.get nor os.makedirs expands "~",
# so the default would otherwise name a literal "./~" directory.
HF_DIR = os.path.expanduser(os.environ.get("HF_HOME", "~/.cache/huggingface"))
os.makedirs(HF_DIR, exist_ok=True)

def log_time(start_time, step_name):
    """Report how long a step took and hand back a new reference time.

    Prints ``"<step_name>: <seconds> seconds"`` (two decimals), measured from
    *start_time* (a ``time.time()`` value). It then returns a fresh
    ``time.time()`` reading so the caller can use it as the start of the
    next step.
    """
    seconds_taken = time.time() - start_time
    print(f"{step_name}: {seconds_taken:.2f} seconds")
    return time.time()

start = time.time()

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only pays off (and is well supported) on GPU; many CPU
# kernels lack float16 implementations, so use float32 on CPU.
dtype = torch.float16 if device == "cuda" else torch.float32

processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=dtype)
# NOTE(review): to_bettertransformer relies on the `optimum` package.
model = model.to_bettertransformer()
if device == "cuda":
    # CPU offload keeps idle sub-models on the CPU and moves each one to the
    # GPU only while it runs. It manages placement itself (and targets a CUDA
    # device), so the model is not moved to the GPU first and offload is
    # skipped on CPU-only machines, where the loaded model already resides.
    model.enable_cpu_offload()

start = log_time(start, "Model loading")

def cleanup_old_files(max_age_hours=24, directory=None):
    """Remove ``audio_*.wav`` files older than *max_age_hours* hours.

    Age is judged by file modification time against the local clock.

    Args:
        max_age_hours: Age threshold in hours (default 24).
        directory: Directory to scan; defaults to ``OUTPUT_DIR``.

    Failures on individual files are printed or skipped, so one bad file
    never aborts the rest of the sweep.
    """
    if directory is None:
        directory = OUTPUT_DIR
    cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
    for file in glob.glob(os.path.join(directory, "audio_*.wav")):
        try:
            # The file may disappear between glob() and stat (e.g. removed by
            # another process); outside a try this would abort the sweep.
            file_time = datetime.fromtimestamp(os.path.getmtime(file))
        except OSError:
            continue  # vanished or unreadable; the next sweep will retry
        if file_time < cutoff_time:
            try:
                os.remove(file)
                print(f"Removed old file: {file}")
            except Exception as e:
                print(f"Error removing file {file}: {e}")

# Initialize the background cleanup scheduler. The first sweep is scheduled
# immediately (next_run_time) instead of one hour after start-up, so processes
# that live less than an hour (such as running this file as a script) still
# get a chance to prune stale files.
scheduler = BackgroundScheduler()
scheduler.add_job(cleanup_old_files, "interval", hours=1, next_run_time=datetime.now())
scheduler.start()
# Stop the scheduler's worker thread at interpreter exit without blocking on
# a sweep that may be in progress.
atexit.register(scheduler.shutdown, wait=False)


def create_bark_audio(text, voice_preset, device):
    """Synthesize *text* with Bark and return ``(audio_array, sample_rate)``.

    Uses the module-level ``processor`` and ``model`` loaded at import time.

    Args:
        text: Text to synthesize.
        voice_preset: Bark voice preset name, e.g. ``"v2/en_speaker_6"``.
        device: Device string the processed inputs are moved to.

    Returns:
        The generated waveform as a squeezed NumPy array, and
        ``model.generation_config.sample_rate``.

    Raises:
        Exception: any preprocessing or generation error is printed and
            re-raised unchanged.
    """
    try:
        start = time.time()
        inputs = processor(text, voice_preset=voice_preset)
        # Move anything that supports .to() (tensors, nested feature
        # containers) to the target device; leave other values as they are.
        inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
        # log_time returns a fresh timestamp, so it also starts the next
        # measured step.
        start = log_time(start, "Input processing")

        audio_array = model.generate(**inputs)
        # squeeze() drops singleton dims (presumably the batch dim of a
        # single prompt) — TODO confirm output shape for batched input.
        audio_array = audio_array.cpu().numpy().squeeze()
        log_time(start, "Audio generation")

        return audio_array, model.generation_config.sample_rate

    except Exception as e:
        print(f"Error during audio generation: {str(e)}")
        raise

def save_audio(audio_array, sample_rate, prefix="audio"):
    """Write *audio_array* to a uniquely timestamped WAV file in ``OUTPUT_DIR``.

    Samples are converted to float32 and clipped to [-1, 1]. They are not
    rescaled or normalized.

    Args:
        audio_array: NumPy array of samples.
        sample_rate: Sample rate in Hz written to the WAV header.
        prefix: Filename prefix. cleanup_old_files only matches
            ``audio_*.wav``, so files saved with another prefix are not
            cleaned up automatically.

    Returns:
        Path of the written file, ``<OUTPUT_DIR>/<prefix>_<timestamp>.wav``.

    Raises:
        Exception: any conversion or write error is printed and re-raised.
    """
    try:
        start = time.time()
        # Clip (not normalize) into the nominal float-WAV range.
        audio_array = np.clip(audio_array.astype(np.float32), -1, 1)

        # Microsecond resolution: with whole seconds only, two saves within
        # the same second produced the same name and the later one silently
        # overwrote the earlier file.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = os.path.join(OUTPUT_DIR, f"{prefix}_{timestamp}.wav")
        write_wav(filename, sample_rate, audio_array)
        log_time(start, "Audio saving")
        return filename

    except Exception as e:
        print(f"Error saving audio file: {str(e)}")
        raise

def generate_speech(text, voice_preset="v2/en_speaker_6"):
    """Turn *text* into speech and return the path of the saved WAV file.

    Picks CUDA when available (CPU otherwise), generates audio with
    ``create_bark_audio`` and writes it with ``save_audio``. Any error is
    printed and re-raised.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"

    try:
        return save_audio(*create_bark_audio(text, voice_preset, target_device))
    except Exception as err:
        print(f"An error occurred: {str(err)}")
        raise

if __name__ == "__main__":
    # Quick manual smoke test: synthesize one sentence and report the file.
    sample_text = "my cat is very cute"
    saved_path = generate_speech(sample_text)
    print(f"Audio saved as: {saved_path}")