EllenBeta commited on
Commit
8fa01de
·
verified ·
1 Parent(s): e2d7c7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -152
app.py CHANGED
@@ -2,13 +2,13 @@ from flask import Flask, request, jsonify, render_template
2
  from datetime import datetime
3
  from flask_cors import CORS
4
  from TTS.api import TTS
5
- from TTS.utils.manage import ModelManager
6
  import os
7
  import base64
8
- import shutil
9
- import wave
10
  import logging
11
  import threading
 
 
 
12
 
13
  from helper import (
14
  save_audio,
@@ -17,6 +17,7 @@ from helper import (
17
  video_to_audio,
18
  validate_audio_file,
19
  ensure_wav_format,
 
20
  )
21
 
22
  # ---------- Basic config ----------
@@ -28,138 +29,19 @@ CORS(app)
28
  os.environ["COQUI_TOS_AGREED"] = "1"
29
 
30
  device = "cpu"
31
-
32
- # ============================================================
33
- # MODEL STORAGE PATHS & NAMES
34
- # ============================================================
35
- DATASET_MODEL_DIR = "/datasets/EllenBeta/Xtts_2/model" # dataset mount (destination)
36
- LOCAL_CACHE_DIR = os.path.expanduser("~/.local/share/tts/xtts_v2_cache") # local cache
37
  MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2" # coqui model id
38
- # Maximum audio (MB)
39
  MAX_AUDIO_SIZE_MB = 15
 
40
 
41
- # ============================================================
42
- # Utilities for resolving model download path (defensive)
43
- # ============================================================
44
- def resolve_model_path(raw):
45
- """
46
- Given the return value from ModelManager.download_model(...) try to
47
- return a filesystem path (string) pointing at the downloaded model folder.
48
- Handles strings, tuples/lists, or dict-like returns.
49
- """
50
- # If already a path string
51
- if isinstance(raw, str):
52
- return raw
53
-
54
- # If a list/tuple, try first string-like element
55
- if isinstance(raw, (list, tuple)):
56
- for element in raw:
57
- if isinstance(element, str) and os.path.exists(element):
58
- return element
59
- # fallback: try to join tuple items into a path if meaningful
60
- try:
61
- cand = os.path.join(*[str(x) for x in raw])
62
- if os.path.exists(cand):
63
- return cand
64
- except Exception:
65
- pass
66
-
67
- # If dict-like, try common keys
68
- if isinstance(raw, dict):
69
- for key in ("model_path", "path", "directory"):
70
- val = raw.get(key)
71
- if isinstance(val, str) and os.path.exists(val):
72
- return val
73
-
74
- # final fallback: try to find the typical download directory
75
- fallback = os.path.expanduser("~/.local/share/tts")
76
- if os.path.exists(fallback):
77
- # find matching folder
78
- for root, dirs, files in os.walk(fallback):
79
- if MODEL_NAME.split("/")[-1] in root:
80
- return root
81
-
82
- # Nothing found
83
- return None
84
-
85
-
86
- # ============================================================
87
- # Ensure model is present (download once and copy into dataset)
88
- # ============================================================
89
  tts = None
90
  try:
91
- if os.path.exists(DATASET_MODEL_DIR) and os.listdir(DATASET_MODEL_DIR):
92
- log.info("✅ Loading XTTS model directly from dataset mount: %s", DATASET_MODEL_DIR)
93
- tts = TTS(model_path=DATASET_MODEL_DIR).to(device)
94
- else:
95
- log.info("⬇️ Dataset model not found — downloading XTTS model (first run)...")
96
- manager = ModelManager()
97
- raw_path = manager.download_model(MODEL_NAME)
98
- model_path = resolve_model_path(raw_path)
99
-
100
- if not model_path or not os.path.exists(model_path):
101
- # As a robust fallback, call TTS() with model id then try to locate typical folder
102
- log.warning("Could not resolve model path from ModelManager result; falling back to direct TTS init.")
103
- tts_tmp = TTS(MODEL_NAME).to(device)
104
- # try to locate in default coqui location
105
- candidate = os.path.expanduser("~/.local/share/tts")
106
- model_path = None
107
- if os.path.exists(candidate):
108
- # pick the directory that contains the xtts_v2 name
109
- for root, dirs, files in os.walk(candidate):
110
- if "xtts_v2" in root or "xtts" in root:
111
- model_path = root
112
- break
113
- # if still None, set model_path to candidate root
114
- if not model_path:
115
- model_path = candidate
116
- # assign tts from tts_tmp
117
- tts = tts_tmp
118
-
119
- # Ensure model_path now points to a directory
120
- if model_path and os.path.exists(model_path):
121
- # create local cache dir and copy files (ensure string)
122
- os.makedirs(LOCAL_CACHE_DIR, exist_ok=True)
123
- try:
124
- shutil.copytree(model_path, LOCAL_CACHE_DIR, dirs_exist_ok=True)
125
- except Exception as e:
126
- # if copytree fails (we still continue)
127
- log.warning("Copy to LOCAL_CACHE_DIR failed: %s", e)
128
-
129
- # Copy into dataset mount for persistence (if writable)
130
- try:
131
- os.makedirs(DATASET_MODEL_DIR, exist_ok=True)
132
- for item in os.listdir(model_path):
133
- s = os.path.join(model_path, item)
134
- d = os.path.join(DATASET_MODEL_DIR, item)
135
- if os.path.isdir(s):
136
- shutil.copytree(s, d, dirs_exist_ok=True)
137
- else:
138
- shutil.copy2(s, d)
139
- log.info("📦 Model copied into dataset mount: %s", DATASET_MODEL_DIR)
140
- except Exception as e:
141
- log.warning("Could not copy model into dataset mount (may be read-only or missing perms): %s", e)
142
-
143
- # If tts not already set (from fallback), initialize from model_path or dataset mount
144
- if tts is None:
145
- # prefer dataset dir if copy succeeded, otherwise local cache
146
- init_path = DATASET_MODEL_DIR if os.path.exists(DATASET_MODEL_DIR) and os.listdir(DATASET_MODEL_DIR) else LOCAL_CACHE_DIR
147
- tts = TTS(model_path=init_path).to(device)
148
- else:
149
- # final fallback: initialize directly from model name (internet)
150
- log.warning("Could not find downloaded model folder; initializing TTS from model id directly.")
151
- tts = TTS(MODEL_NAME).to(device)
152
-
153
- log.info("✅ TTS ready.")
154
  except Exception as exc:
155
- log.exception("Failed to prepare TTS model: %s", exc)
156
- # Try a minimal fallback to avoid crash - attempt to init directly.
157
- try:
158
- tts = TTS(MODEL_NAME).to(device)
159
- except Exception as exc2:
160
- log.exception("Fatal: TTS could not be initialized: %s", exc2)
161
- # re-raise so app startup fails loudly (preferred)
162
- raise
163
 
164
  # ============================================================
165
  # Application logic (routes & helpers)
@@ -213,7 +95,7 @@ def generate_voice():
213
  "created_at": datetime.now(),
214
  }
215
 
216
- # Run processing (synchronous here - see note below about background processing)
217
  process_vox(user_id, text, video, audio_base64, task_id)
218
  return jsonify({"message": "Processing started", "task_id": task_id}), 202
219
 
@@ -224,7 +106,13 @@ def generate_voice():
224
 
225
  def process_vox(user_id, text, video, audio_base64, task_id):
226
  temp_audio_path = None
 
227
  try:
 
 
 
 
 
228
  # 1) Prepare input audio
229
  if audio_base64:
230
  if audio_base64.startswith("data:audio/"):
@@ -241,8 +129,8 @@ def process_vox(user_id, text, video, audio_base64, task_id):
241
  if not valid:
242
  raise Exception(f"Invalid audio file: {msg}")
243
 
244
- # 3) Generate TTS (clone)
245
- result_file = clone(text, temp_audio_path)
246
 
247
  # 4) Save output to user_audios
248
  out_dir = "user_audios"
@@ -250,16 +138,17 @@ def process_vox(user_id, text, video, audio_base64, task_id):
250
  file_name = generate_random_filename("mp3")
251
  file_path = os.path.join(out_dir, file_name)
252
 
253
- with open(result_file, "rb") as src, open(file_path, "wb") as dst:
254
  dst.write(src.read())
255
 
256
  # 5) Gather metadata
 
257
  with wave.open(file_path, "rb") as wf:
258
  dura = wf.getnframes() / float(wf.getframerate())
259
  duration = f"{dura:.2f}"
260
  title = text[:20]
261
 
262
- # 6) Upload and save
263
  audio_url = save_to_dataset_repo(file_path, f"user/data/audios/{file_name}", file_name)
264
  active_tasks[task_id].update(
265
  {
@@ -279,28 +168,70 @@ def process_vox(user_id, text, video, audio_base64, task_id):
279
  }
280
 
281
  finally:
282
- # cleanup
283
- try:
284
- if temp_audio_path and os.path.exists(temp_audio_path):
285
- os.remove(temp_audio_path)
286
- task = active_tasks.get(task_id)
287
- if task:
288
- if task["status"]== "completed":
289
- remove_task_after_delay(task_id, delay_seconds=300)
290
- elif task["status"] == "failed":
291
- del active_tasks[task_id]
292
- except Exception:
293
- # ignore cleanup issues
294
- pass
295
 
296
 
297
  def clone(text, audio):
298
  """
299
- Use the TTS instance to produce an output file. Returns the path to the output file.
 
300
  """
301
- out_path = "./output.wav"
302
- # use tts to write audio; let TTS manage model specifics
303
- tts.tts_to_file(text=text, speaker_wav=audio, language="en", file_path=out_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
  return out_path
305
 
306
 
@@ -342,4 +273,5 @@ def remove_task_after_delay(task_id, delay_seconds=300):
342
  timer.start()
343
 
344
 
345
- # Run only when invoked directly (Gunicorn will ignore this block)
 
 
2
  from datetime import datetime
3
  from flask_cors import CORS
4
  from TTS.api import TTS
 
5
  import os
6
  import base64
 
 
7
  import logging
8
  import threading
9
+ import tempfile # for better temp handling
10
+ from pydub import AudioSegment # for WAV concat (OOM fix)
11
+ import psutil # for RAM check
12
 
13
  from helper import (
14
  save_audio,
 
17
  video_to_audio,
18
  validate_audio_file,
19
  ensure_wav_format,
20
+ # Assume you add: create_connection (with retry below)
21
  )
22
 
23
  # ---------- Basic config ----------
 
29
  os.environ["COQUI_TOS_AGREED"] = "1"
30
 
31
  device = "cpu"
 
 
 
 
 
 
32
  MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2" # coqui model id
 
33
  MAX_AUDIO_SIZE_MB = 15
34
+ MAX_TEXT_LEN = 250 # per chunk for OOM safety
35
 
36
+ # Simplified TTS init: Direct from model name (handles download/config auto)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  tts = None
38
  try:
39
+ log.info(f"⬇️ Initializing XTTS from {MODEL_NAME}...")
40
+ tts = TTS(model_name=MODEL_NAME).to(device) # Uses model_name kwarg for HF-style load
41
+ log.info("✅ TTS ready (direct init).")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  except Exception as exc:
43
+ log.exception("Fatal: TTS init failed: %s", exc)
44
+ raise
 
 
 
 
 
 
45
 
46
  # ============================================================
47
  # Application logic (routes & helpers)
 
95
  "created_at": datetime.now(),
96
  }
97
 
98
+ # Run processing (synchronous; consider Celery for prod scaling)
99
  process_vox(user_id, text, video, audio_base64, task_id)
100
  return jsonify({"message": "Processing started", "task_id": task_id}), 202
101
 
 
106
 
107
  def process_vox(user_id, text, video, audio_base64, task_id):
108
  temp_audio_path = None
109
+ temp_output_path = None
110
  try:
111
+ # RAM check (OOM guard)
112
+ ram_gb = psutil.virtual_memory().available / (1024 ** 3)
113
+ if ram_gb < 2: # XTTS needs ~2GB free
114
+ raise Exception("Low RAM: Please try a shorter text.")
115
+
116
  # 1) Prepare input audio
117
  if audio_base64:
118
  if audio_base64.startswith("data:audio/"):
 
129
  if not valid:
130
  raise Exception(f"Invalid audio file: {msg}")
131
 
132
+ # 3) Generate TTS (clone) with chunking for long text
133
+ temp_output_path = clone(text, temp_audio_path) # now returns possibly concatenated path
134
 
135
  # 4) Save output to user_audios
136
  out_dir = "user_audios"
 
138
  file_name = generate_random_filename("mp3")
139
  file_path = os.path.join(out_dir, file_name)
140
 
141
+ with open(temp_output_path, "rb") as src, open(file_path, "wb") as dst:
142
  dst.write(src.read())
143
 
144
  # 5) Gather metadata
145
+ import wave
146
  with wave.open(file_path, "rb") as wf:
147
  dura = wf.getnframes() / float(wf.getframerate())
148
  duration = f"{dura:.2f}"
149
  title = text[:20]
150
 
151
+ # 6) Upload and save (with DB retry in helper)
152
  audio_url = save_to_dataset_repo(file_path, f"user/data/audios/{file_name}", file_name)
153
  active_tasks[task_id].update(
154
  {
 
168
  }
169
 
170
  finally:
171
+ # Better cleanup with tempfile
172
+ for path in [temp_audio_path, temp_output_path]:
173
+ if path and os.path.exists(path):
174
+ try:
175
+ os.remove(path)
176
+ except:
177
+ pass
178
+ task = active_tasks.get(task_id)
179
+ if task and task["status"] == "completed":
180
+ remove_task_after_delay(task_id, delay_seconds=300)
181
+ elif task and task["status"] == "failed":
182
+ # Keep failed for 60s then del
183
+ threading.Timer(60, lambda: active_tasks.pop(task_id, None)).start()
184
 
185
 
186
  def clone(text, audio):
187
  """
188
+ Generate cloned audio; chunk long text to avoid OOM.
189
+ Returns path to (possibly concatenated) output WAV.
190
  """
191
+ # Simple lang detect (improve with langdetect lib if needed)
192
+ lang = "en" # default
193
+ if any(c in text for c in "अइउ"): lang = "hi" # Hindi example
194
+ elif any(c in text for c in "äöü"): lang = "de" # German
195
+
196
+ out_path = tempfile.mktemp(suffix=".wav")
197
+ chunks = []
198
+ sentences = text.split(". ") # Basic split
199
+ current_chunk = ""
200
+ for sent in sentences + ["."]: # Add final
201
+ if len(current_chunk + sent) < MAX_TEXT_LEN:
202
+ current_chunk += sent + ". "
203
+ else:
204
+ if current_chunk:
205
+ chunks.append(current_chunk.strip())
206
+ current_chunk = sent + ". "
207
+ if current_chunk:
208
+ chunks.append(current_chunk.strip())
209
+
210
+ chunk_files = []
211
+ for chunk in chunks:
212
+ if not chunk: continue
213
+ chunk_out = tempfile.mktemp(suffix=".wav")
214
+ tts.tts_to_file(
215
+ text=chunk,
216
+ speaker_wav=audio,
217
+ language=lang,
218
+ file_path=chunk_out,
219
+ split_sentences=False # Avoid double-split
220
+ )
221
+ chunk_files.append(chunk_out)
222
+
223
+ # Concat if multi-chunk
224
+ if len(chunk_files) > 1:
225
+ combined = AudioSegment.empty()
226
+ for f in chunk_files:
227
+ combined += AudioSegment.from_wav(f)
228
+ combined.export(out_path, format="wav")
229
+ # Clean chunk temps
230
+ for f in chunk_files: os.remove(f)
231
+ else:
232
+ shutil.move(chunk_files[0] if chunk_files else out_path, out_path)
233
+ os.remove(chunk_files[0]) if chunk_files else None
234
+
235
  return out_path
236
 
237
 
 
273
  timer.start()
274
 
275
 
276
+ if __name__ == "__main__":
277
+ app.run(debug=True, host="0.0.0.0", port=7860)