johnbridges committed on
Commit
f680106
·
verified ·
1 Parent(s): be4606a

Update app.py

Browse files

now using fixed dir for audio files

Files changed (1) hide show
  1. app.py +54 -44
app.py CHANGED
@@ -16,6 +16,39 @@ import tempfile
16
  from huggingface_hub import snapshot_download
17
  from tts_processor import preprocess_all
18
  import hashlib
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  # Configure logging
21
  logging.basicConfig(level=logging.INFO)
@@ -130,7 +163,7 @@ def initialize_models():
130
  raise FileNotFoundError(f"ONNX file not found after redownload at {kokoro_dir}")
131
 
132
  logger.info("Loading ONNX session...")
133
- sess = InferenceSession(onnx_path)
134
  logger.info(f"ONNX session loaded successfully from {onnx_path}")
135
 
136
  # Load the voice style vector
@@ -174,73 +207,49 @@ def health_check():
174
  @app.route('/generate_audio', methods=['POST'])
175
  def generate_audio():
176
  """Text-to-Speech (T2S) Endpoint"""
177
- with global_lock: # Acquire global lock to ensure only one instance runs
178
  try:
179
  logger.debug("Received request to /generate_audio")
180
  data = request.json
181
  text = data['text']
182
- output_dir = data.get('output_dir')
183
 
184
  validate_text_input(text)
185
- logger.debug(f"Text: {text}")
186
- if not output_dir:
187
- raise ValueError("Output directory is required but not provided")
188
 
189
- # Ensure output_dir is an absolute path and valid
190
- if not os.path.isabs(output_dir):
191
- raise ValueError("Output directory must be an absolute path")
192
- if not os.path.exists(output_dir):
193
- raise ValueError(f"Output directory does not exist: {output_dir}")
194
-
195
- # Generate a unique hash for the text
196
  text = preprocess_all(text)
197
- logger.debug(f"Processed Text {text}")
198
  text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest()
199
- hashed_file_name = f"{text_hash}.wav"
200
- cached_file_path = os.path.join(output_dir, hashed_file_name)
201
- logger.debug(f"Generated hash for processed text: {text_hash}")
202
- logger.debug(f"Output directory: {output_dir}")
203
- logger.debug(f"Cached file path: {cached_file_path}")
204
 
205
- # Check if cached file exists
206
  if is_cached(cached_file_path):
207
- logger.info(f"Returning cached audio for text: {text}")
208
- return jsonify({"status": "success", "output_path": cached_file_path})
209
 
210
- # Tokenize text
211
- logger.debug("Tokenizing text...")
212
- from kokoro import phonemize, tokenize # Import dynamically
213
  tokens = tokenize(phonemize(text, 'a'))
214
- logger.debug(f"Initial tokens: {tokens}")
215
  if len(tokens) > 510:
216
  logger.warning("Text too long; truncating to 510 tokens.")
217
  tokens = tokens[:510]
218
- tokens = [[0, *tokens, 0]] # Add pad tokens
219
- logger.debug(f"Final tokens: {tokens}")
220
 
221
- # Get style vector based on token length
222
- logger.debug("Fetching style vector...")
223
- ref_s = voice_style[len(tokens[0]) - 2] # Shape: (1, 256)
224
- logger.debug(f"Style vector shape: {ref_s.shape}")
225
 
226
- # Run ONNX inference
227
- logger.debug("Running ONNX inference...")
228
  audio = sess.run(None, dict(
229
  input_ids=np.array(tokens, dtype=np.int64),
230
  style=ref_s,
231
  speed=np.ones(1, dtype=np.float32),
232
  ))[0]
233
- logger.debug(f"Audio generated with shape: {audio.shape}")
234
 
235
- # Fix audio data for saving
236
- audio = np.squeeze(audio) # Remove extra dimension
237
- audio = audio.astype(np.float32) # Ensure correct data type
238
 
239
- # Save audio
240
- logger.debug(f"Saving audio to {cached_file_path}...")
241
- sf.write(cached_file_path, audio, 24000) # Save with 24 kHz sample rate
242
- logger.info(f"Audio saved successfully to {cached_file_path}")
243
- return jsonify({"status": "success", "output_path": cached_file_path})
244
  except Exception as e:
245
  logger.error(f"Error generating audio: {str(e)}")
246
  return jsonify({"status": "error", "message": str(e)}), 500
@@ -364,4 +373,5 @@ def internal_error(error):
364
  return {"error": "Internal Server Error", "message": "An unexpected error occurred."}, 500
365
 
366
  if __name__ == "__main__":
367
- app.run(host="0.0.0.0", port=7860)
 
 
16
  from huggingface_hub import snapshot_download
17
  from tts_processor import preprocess_all
18
  import hashlib
19
import os

# ---------------------------
# THREAD LIMIT CONFIG
# ---------------------------
MAX_THREADS = 2  # <-- change this number to control all thread usage

# Limit NumPy / BLAS / MKL threads.
# NOTE: BLAS / OpenMP / MKL runtimes read these variables at library load
# time, so they MUST be exported before numpy / torch / onnxruntime are
# imported — setting them afterwards has no effect.
os.environ["OMP_NUM_THREADS"] = str(MAX_THREADS)
os.environ["OPENBLAS_NUM_THREADS"] = str(MAX_THREADS)
os.environ["MKL_NUM_THREADS"] = str(MAX_THREADS)
os.environ["VECLIB_MAXIMUM_THREADS"] = str(MAX_THREADS)
os.environ["NUMEXPR_NUM_THREADS"] = str(MAX_THREADS)

import numpy as np
import onnxruntime as ort
import torch

# ---------------------------
# STORAGE ROOT
# ---------------------------
# Fixed directory all generated WAV files are written to / served from.
SERVE_DIR = "/home/user/app/files"
os.makedirs(SERVE_DIR, exist_ok=True)

# Torch thread limits
torch.set_num_threads(MAX_THREADS)
torch.set_num_interop_threads(1)  # keep inter-op small to avoid overhead

# ONNXRuntime session options (use when creating the session)
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = MAX_THREADS
sess_options.inter_op_num_threads = 1
52
 
53
  # Configure logging
54
  logging.basicConfig(level=logging.INFO)
 
163
  raise FileNotFoundError(f"ONNX file not found after redownload at {kokoro_dir}")
164
 
165
  logger.info("Loading ONNX session...")
166
+ sess = InferenceSession(onnx_path, sess_options)
167
  logger.info(f"ONNX session loaded successfully from {onnx_path}")
168
 
169
  # Load the voice style vector
 
207
@app.route('/generate_audio', methods=['POST'])
def generate_audio():
    """Text-to-Speech (T2S) endpoint.

    Expects JSON ``{"text": ...}``. Synthesizes speech with the Kokoro ONNX
    model and caches the result as a WAV under ``SERVE_DIR``, keyed by the
    SHA-256 of the preprocessed text so identical requests hit the cache.

    Returns JSON ``{"status": "success", "filename": <hash>.wav}`` on
    success, 400 on invalid client input, 500 on server-side failure.
    """
    with global_lock:  # serialize: only one inference runs at a time
        try:
            logger.debug("Received request to /generate_audio")
            data = request.json
            # Missing/absent 'text' is a client error: report 400 with a
            # clear message instead of letting KeyError surface as a 500.
            if not data or 'text' not in data:
                return jsonify({"status": "error",
                                "message": "Missing required field: text"}), 400
            text = data['text']

            validate_text_input(text)

            # Preprocess & stable hash
            text = preprocess_all(text)
            text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest()
            filename = f"{text_hash}.wav"
            cached_file_path = os.path.join(SERVE_DIR, filename)

            # Cache hit
            if is_cached(cached_file_path):
                logger.info("Returning cached audio")
                return jsonify({"status": "success", "filename": filename})

            # Tokenize
            from kokoro import phonemize, tokenize  # lazy import is fine
            tokens = tokenize(phonemize(text, 'a'))
            if len(tokens) > 510:
                logger.warning("Text too long; truncating to 510 tokens.")
                tokens = tokens[:510]
            tokens = [[0, *tokens, 0]]  # pad token on each end

            # Style vector, selected by unpadded token count — presumably
            # shape (1, 256); TODO confirm against voice_style loading code.
            ref_s = voice_style[len(tokens[0]) - 2]

            # ONNX inference
            audio = sess.run(None, dict(
                input_ids=np.array(tokens, dtype=np.int64),
                style=ref_s,
                speed=np.ones(1, dtype=np.float32),
            ))[0]

            # Save as float32 at 24 kHz
            audio = np.squeeze(audio).astype(np.float32)
            sf.write(cached_file_path, audio, 24000)

            logger.info(f"Audio saved: {cached_file_path}")
            return jsonify({"status": "success", "filename": filename})
        except ValueError as e:
            # validate_text_input rejections are client errors -> 400
            logger.error(f"Invalid input: {str(e)}")
            return jsonify({"status": "error", "message": str(e)}), 400
        except Exception as e:
            logger.error(f"Error generating audio: {str(e)}")
            return jsonify({"status": "error", "message": str(e)}), 500
 
373
  return {"error": "Internal Server Error", "message": "An unexpected error occurred."}, 500
374
 
375
if __name__ == "__main__":
    # Serve single-threaded, single-process: inference is serialized by the
    # global lock anyway, and the thread limits configured above assume one
    # request at a time.
    app.run(host="0.0.0.0", port=7860, threaded=False, processes=1)