Spaces:
Running
Running
Update app.py
Browse files — now using fixed dir for audio files
app.py
CHANGED
|
@@ -16,6 +16,39 @@ import tempfile
|
|
| 16 |
from huggingface_hub import snapshot_download
|
| 17 |
from tts_processor import preprocess_all
|
| 18 |
import hashlib
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
# Configure logging
|
| 21 |
logging.basicConfig(level=logging.INFO)
|
|
@@ -130,7 +163,7 @@ def initialize_models():
|
|
| 130 |
raise FileNotFoundError(f"ONNX file not found after redownload at {kokoro_dir}")
|
| 131 |
|
| 132 |
logger.info("Loading ONNX session...")
|
| 133 |
-
sess = InferenceSession(onnx_path)
|
| 134 |
logger.info(f"ONNX session loaded successfully from {onnx_path}")
|
| 135 |
|
| 136 |
# Load the voice style vector
|
|
@@ -174,73 +207,49 @@ def health_check():
|
|
| 174 |
@app.route('/generate_audio', methods=['POST'])
|
| 175 |
def generate_audio():
|
| 176 |
"""Text-to-Speech (T2S) Endpoint"""
|
| 177 |
-
with global_lock:
|
| 178 |
try:
|
| 179 |
logger.debug("Received request to /generate_audio")
|
| 180 |
data = request.json
|
| 181 |
text = data['text']
|
| 182 |
-
output_dir = data.get('output_dir')
|
| 183 |
|
| 184 |
validate_text_input(text)
|
| 185 |
-
logger.debug(f"Text: {text}")
|
| 186 |
-
if not output_dir:
|
| 187 |
-
raise ValueError("Output directory is required but not provided")
|
| 188 |
|
| 189 |
-
#
|
| 190 |
-
if not os.path.isabs(output_dir):
|
| 191 |
-
raise ValueError("Output directory must be an absolute path")
|
| 192 |
-
if not os.path.exists(output_dir):
|
| 193 |
-
raise ValueError(f"Output directory does not exist: {output_dir}")
|
| 194 |
-
|
| 195 |
-
# Generate a unique hash for the text
|
| 196 |
text = preprocess_all(text)
|
| 197 |
-
logger.debug(f"Processed Text {text}")
|
| 198 |
text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest()
|
| 199 |
-
|
| 200 |
-
cached_file_path = os.path.join(
|
| 201 |
-
logger.debug(f"Generated hash for processed text: {text_hash}")
|
| 202 |
-
logger.debug(f"Output directory: {output_dir}")
|
| 203 |
-
logger.debug(f"Cached file path: {cached_file_path}")
|
| 204 |
|
| 205 |
-
#
|
| 206 |
if is_cached(cached_file_path):
|
| 207 |
-
logger.info(
|
| 208 |
-
return jsonify({"status": "success", "
|
| 209 |
|
| 210 |
-
# Tokenize
|
| 211 |
-
|
| 212 |
-
from kokoro import phonemize, tokenize # Import dynamically
|
| 213 |
tokens = tokenize(phonemize(text, 'a'))
|
| 214 |
-
logger.debug(f"Initial tokens: {tokens}")
|
| 215 |
if len(tokens) > 510:
|
| 216 |
logger.warning("Text too long; truncating to 510 tokens.")
|
| 217 |
tokens = tokens[:510]
|
| 218 |
-
tokens = [[0, *tokens, 0]]
|
| 219 |
-
logger.debug(f"Final tokens: {tokens}")
|
| 220 |
|
| 221 |
-
#
|
| 222 |
-
|
| 223 |
-
ref_s = voice_style[len(tokens[0]) - 2] # Shape: (1, 256)
|
| 224 |
-
logger.debug(f"Style vector shape: {ref_s.shape}")
|
| 225 |
|
| 226 |
-
#
|
| 227 |
-
logger.debug("Running ONNX inference...")
|
| 228 |
audio = sess.run(None, dict(
|
| 229 |
input_ids=np.array(tokens, dtype=np.int64),
|
| 230 |
style=ref_s,
|
| 231 |
speed=np.ones(1, dtype=np.float32),
|
| 232 |
))[0]
|
| 233 |
-
logger.debug(f"Audio generated with shape: {audio.shape}")
|
| 234 |
|
| 235 |
-
#
|
| 236 |
-
audio = np.squeeze(audio)
|
| 237 |
-
|
| 238 |
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
sf.write(cached_file_path, audio, 24000) # Save with 24 kHz sample rate
|
| 242 |
-
logger.info(f"Audio saved successfully to {cached_file_path}")
|
| 243 |
-
return jsonify({"status": "success", "output_path": cached_file_path})
|
| 244 |
except Exception as e:
|
| 245 |
logger.error(f"Error generating audio: {str(e)}")
|
| 246 |
return jsonify({"status": "error", "message": str(e)}), 500
|
|
@@ -364,4 +373,5 @@ def internal_error(error):
|
|
| 364 |
return {"error": "Internal Server Error", "message": "An unexpected error occurred."}, 500
|
| 365 |
|
| 366 |
if __name__ == "__main__":
|
| 367 |
-
app.run(host="0.0.0.0", port=7860)
|
|
|
|
|
|
| 16 |
from huggingface_hub import snapshot_download
from tts_processor import preprocess_all
import hashlib
import os

# ---------------------------
# THREAD LIMIT CONFIG
# ---------------------------
MAX_THREADS = 2  # <-- change this number to control all thread usage

# Limit NumPy / BLAS / MKL threads.
# IMPORTANT: these environment variables are only honored if set BEFORE the
# BLAS backends are loaded, i.e. before `numpy`/`torch` are first imported —
# so they must come ahead of the imports below (previously they were set
# after, which had no effect). NOTE(review): if numpy is already imported
# earlier in this file, move this whole section to the very top of the file.
os.environ["OMP_NUM_THREADS"] = str(MAX_THREADS)
os.environ["OPENBLAS_NUM_THREADS"] = str(MAX_THREADS)
os.environ["MKL_NUM_THREADS"] = str(MAX_THREADS)
os.environ["VECLIB_MAXIMUM_THREADS"] = str(MAX_THREADS)
os.environ["NUMEXPR_NUM_THREADS"] = str(MAX_THREADS)

import torch
import numpy as np
import onnxruntime as ort

# ---------------------------
# STORAGE ROOT
# ---------------------------
# Fixed directory generated audio files are written to and served from.
SERVE_DIR = "/home/user/app/files"
os.makedirs(SERVE_DIR, exist_ok=True)

# Torch thread limits
torch.set_num_threads(MAX_THREADS)
torch.set_num_interop_threads(1)  # keep inter-op small to avoid overhead

# ONNXRuntime session options (use when creating the session)
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = MAX_THREADS
sess_options.inter_op_num_threads = 1

# Configure logging
logging.basicConfig(level=logging.INFO)
|
|
|
|
| 163 |
raise FileNotFoundError(f"ONNX file not found after redownload at {kokoro_dir}")
|
| 164 |
|
| 165 |
logger.info("Loading ONNX session...")
|
| 166 |
+
sess = InferenceSession(onnx_path, sess_options)
|
| 167 |
logger.info(f"ONNX session loaded successfully from {onnx_path}")
|
| 168 |
|
| 169 |
# Load the voice style vector
|
|
|
|
| 207 |
@app.route('/generate_audio', methods=['POST'])
def generate_audio():
    """Text-to-Speech (T2S) endpoint.

    Expects a JSON body with a ``"text"`` field. The text is preprocessed,
    then SHA-256 hashed to derive a stable cache filename under SERVE_DIR.
    On a cache hit the existing file is reused; otherwise the text is
    tokenized, synthesized through the shared ONNX session, and the audio is
    written to the cache atomically.

    Returns:
        JSON ``{"status": "success", "filename": "<hash>.wav"}`` on success,
        or ``{"status": "error", "message": ...}`` with HTTP 500 on failure.
    """
    # Serialize requests: the ONNX session and style vector are shared globals.
    with global_lock:
        try:
            logger.debug("Received request to /generate_audio")
            data = request.json
            # Fail with a clear message instead of a bare KeyError/TypeError
            # when the body is missing or has no 'text' field (still -> 500).
            if not data or 'text' not in data:
                raise ValueError("Request JSON must contain a 'text' field")
            text = data['text']

            validate_text_input(text)

            # Preprocess & stable hash
            text = preprocess_all(text)
            text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest()
            filename = f"{text_hash}.wav"
            cached_file_path = os.path.join(SERVE_DIR, filename)

            # Cache hit
            if is_cached(cached_file_path):
                logger.info("Returning cached audio")
                return jsonify({"status": "success", "filename": filename})

            # Tokenize
            from kokoro import phonemize, tokenize  # lazy import is fine
            tokens = tokenize(phonemize(text, 'a'))
            if len(tokens) > 510:
                logger.warning("Text too long; truncating to 510 tokens.")
                tokens = tokens[:510]
            tokens = [[0, *tokens, 0]]  # wrap with boundary tokens

            # Style vector indexed by token count (excluding the two boundary
            # tokens added above).
            ref_s = voice_style[len(tokens[0]) - 2]  # (1, 256)

            # ONNX inference
            audio = sess.run(None, dict(
                input_ids=np.array(tokens, dtype=np.int64),
                style=ref_s,
                speed=np.ones(1, dtype=np.float32),
            ))[0]

            # Save atomically: write to a temp path first, then rename, so an
            # interrupted/failed write can never leave a truncated file that
            # is_cached() would later treat as a valid cache hit.
            audio = np.squeeze(audio).astype(np.float32)
            tmp_path = cached_file_path + ".tmp"
            sf.write(tmp_path, audio, 24000, format="WAV")  # 24 kHz sample rate
            os.replace(tmp_path, cached_file_path)

            logger.info(f"Audio saved: {cached_file_path}")
            return jsonify({"status": "success", "filename": filename})
        except Exception as e:
            logger.error(f"Error generating audio: {str(e)}")
            return jsonify({"status": "error", "message": str(e)}), 500
|
|
|
|
| 373 |
return {"error": "Internal Server Error", "message": "An unexpected error occurred."}, 500
|
| 374 |
|
| 375 |
if __name__ == "__main__":
    # Single-threaded, single-process dev server: requests are handled one at
    # a time, matching the global_lock serialization inside the endpoints.
    # NOTE(review): Flask's built-in server is for development only; front it
    # with a production WSGI server (e.g. gunicorn) for real deployments.
    app.run(host="0.0.0.0", port=7860, threaded=False, processes=1)