Joe6636564 commited on
Commit
a582f4a
·
verified ·
1 Parent(s): 86693c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -271
app.py CHANGED
@@ -1,14 +1,13 @@
1
  from flask import Flask, request, jsonify, render_template
2
- from datetime import datetime
3
  from flask_cors import CORS
 
4
  from TTS.api import TTS
5
- from TTS.utils.manage import ModelManager
6
  import os
7
  import base64
8
- import shutil
9
  import wave
10
  import logging
11
  import threading
 
12
 
13
  from helper import (
14
  save_audio,
@@ -19,327 +18,199 @@ from helper import (
19
  ensure_wav_format,
20
  )
21
 
22
- # ---------- Basic config ----------
 
 
23
  logging.basicConfig(level=logging.INFO)
24
  log = logging.getLogger("app")
25
 
26
  app = Flask(__name__)
27
  CORS(app)
28
- os.environ["COQUI_TOS_AGREED"] = "1"
29
 
 
30
  device = "cpu"
31
-
32
- # ============================================================
33
- # MODEL STORAGE PATHS & NAMES
34
- # ============================================================
35
- DATASET_MODEL_DIR = "/datasets/EllenBeta/Xtts_2/model" # dataset mount (destination)
36
- LOCAL_CACHE_DIR = os.path.expanduser("~/.local/share/tts/xtts_v2_cache") # local cache
37
- MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2" # coqui model id
38
- # Maximum audio (MB)
39
  MAX_AUDIO_SIZE_MB = 15
40
 
41
- # ============================================================
42
- # Utilities for resolving model download path (defensive)
43
- # ============================================================
44
- def resolve_model_path(raw):
45
- """
46
- Given the return value from ModelManager.download_model(...) try to
47
- return a filesystem path (string) pointing at the downloaded model folder.
48
- Handles strings, tuples/lists, or dict-like returns.
49
- """
50
- # If already a path string
51
- if isinstance(raw, str):
52
- return raw
53
-
54
- # If a list/tuple, try first string-like element
55
- if isinstance(raw, (list, tuple)):
56
- for element in raw:
57
- if isinstance(element, str) and os.path.exists(element):
58
- return element
59
- # fallback: try to join tuple items into a path if meaningful
60
- try:
61
- cand = os.path.join(*[str(x) for x in raw])
62
- if os.path.exists(cand):
63
- return cand
64
- except Exception:
65
- pass
66
-
67
- # If dict-like, try common keys
68
- if isinstance(raw, dict):
69
- for key in ("model_path", "path", "directory"):
70
- val = raw.get(key)
71
- if isinstance(val, str) and os.path.exists(val):
72
- return val
73
-
74
- # final fallback: try to find the typical download directory
75
- fallback = os.path.expanduser("~/.local/share/tts")
76
- if os.path.exists(fallback):
77
- # find matching folder
78
- for root, dirs, files in os.walk(fallback):
79
- if MODEL_NAME.split("/")[-1] in root:
80
- return root
81
-
82
- # Nothing found
83
- return None
84
-
85
-
86
- # ============================================================
87
- # Ensure model is present (download once and copy into dataset)
88
- # ============================================================
89
  tts = None
90
- try:
91
- if os.path.exists(DATASET_MODEL_DIR) and os.listdir(DATASET_MODEL_DIR):
92
- log.info("✅ Loading XTTS model directly from dataset mount: %s", DATASET_MODEL_DIR)
93
- tts = TTS(model_path=DATASET_MODEL_DIR).to(device)
94
- else:
95
- log.info("⬇️ Dataset model not found — downloading XTTS model (first run)...")
96
- manager = ModelManager()
97
- raw_path = manager.download_model(MODEL_NAME)
98
- model_path = resolve_model_path(raw_path)
99
-
100
- if not model_path or not os.path.exists(model_path):
101
- # As a robust fallback, call TTS() with model id then try to locate typical folder
102
- log.warning("Could not resolve model path from ModelManager result; falling back to direct TTS init.")
103
- tts_tmp = TTS(MODEL_NAME).to(device)
104
- # try to locate in default coqui location
105
- candidate = os.path.expanduser("~/.local/share/tts")
106
- model_path = None
107
- if os.path.exists(candidate):
108
- # pick the directory that contains the xtts_v2 name
109
- for root, dirs, files in os.walk(candidate):
110
- if "xtts_v2" in root or "xtts" in root:
111
- model_path = root
112
- break
113
- # if still None, set model_path to candidate root
114
- if not model_path:
115
- model_path = candidate
116
- # assign tts from tts_tmp
117
- tts = tts_tmp
118
-
119
- # Ensure model_path now points to a directory
120
- if model_path and os.path.exists(model_path):
121
- # create local cache dir and copy files (ensure string)
122
- os.makedirs(LOCAL_CACHE_DIR, exist_ok=True)
123
- try:
124
- shutil.copytree(model_path, LOCAL_CACHE_DIR, dirs_exist_ok=True)
125
- except Exception as e:
126
- # if copytree fails (we still continue)
127
- log.warning("Copy to LOCAL_CACHE_DIR failed: %s", e)
128
-
129
- # Copy into dataset mount for persistence (if writable)
130
- try:
131
- os.makedirs(DATASET_MODEL_DIR, exist_ok=True)
132
- for item in os.listdir(model_path):
133
- s = os.path.join(model_path, item)
134
- d = os.path.join(DATASET_MODEL_DIR, item)
135
- if os.path.isdir(s):
136
- shutil.copytree(s, d, dirs_exist_ok=True)
137
- else:
138
- shutil.copy2(s, d)
139
- log.info("📦 Model copied into dataset mount: %s", DATASET_MODEL_DIR)
140
- except Exception as e:
141
- log.warning("Could not copy model into dataset mount (may be read-only or missing perms): %s", e)
142
-
143
- # If tts not already set (from fallback), initialize from model_path or dataset mount
144
- if tts is None:
145
- # prefer dataset dir if copy succeeded, otherwise local cache
146
- init_path = DATASET_MODEL_DIR if os.path.exists(DATASET_MODEL_DIR) and os.listdir(DATASET_MODEL_DIR) else LOCAL_CACHE_DIR
147
- tts = TTS(model_path=init_path).to(device)
148
- else:
149
- # final fallback: initialize directly from model name (internet)
150
- log.warning("Could not find downloaded model folder; initializing TTS from model id directly.")
151
- tts = TTS(MODEL_NAME).to(device)
152
-
153
- log.info("✅ TTS ready.")
154
- except Exception as exc:
155
- log.exception("Failed to prepare TTS model: %s", exc)
156
- # Try a minimal fallback to avoid crash - attempt to init directly.
157
- try:
158
- tts = TTS(MODEL_NAME).to(device)
159
- except Exception as exc2:
160
- log.exception("Fatal: TTS could not be initialized: %s", exc2)
161
- # re-raise so app startup fails loudly (preferred)
162
- raise
163
-
164
- # ============================================================
165
- # Application logic (routes & helpers)
166
- # ============================================================
167
- active_tasks = {}
168
-
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  @app.route("/")
171
  def greet_html():
172
  return render_template("home.html")
173
 
174
-
175
  @app.route("/sign-in")
176
  def sign_in():
177
  return render_template("sign_in.html")
178
 
179
-
180
  @app.route("/user_dash")
181
  def user_dash():
182
  user_id = request.args.get("user_id")
183
- if user_id:
184
- return render_template("u_dash.html", user_id=user_id)
185
- return jsonify({"error": "Missing user_id"}), 400
186
-
187
 
 
 
 
188
  @app.route("/generate_voice", methods=["POST"])
189
  def generate_voice():
190
- try:
191
- data = request.get_json()
192
- if not data:
193
- return jsonify({"error": "No JSON body"}), 400
194
-
195
- video = data.get("video")
196
- text = data.get("text")
197
- audio_base64 = data.get("audio")
198
- task_id = data.get("task_id")
199
- user_id = data.get("user_id")
200
-
201
- if not user_id:
202
- return jsonify({"error": "You must sign in before using this AI"}), 401
203
- if not text:
204
- return jsonify({"error": "Please input a prompt"}), 400
205
- if not task_id:
206
- return jsonify({"error": "task_id is required"}), 400
 
207
  if task_id in active_tasks:
208
- return jsonify({"error": f"There is already an active task for {task_id}"}), 409
209
 
210
  active_tasks[task_id] = {
211
- "user_id": user_id,
212
- "status": "Processing",
213
  "created_at": datetime.now(),
214
  }
215
 
216
- # Run processing (synchronous here - see note below about background processing)
217
- process_vox(user_id, text, video, audio_base64, task_id)
218
- return jsonify({"message": "Processing started", "task_id": task_id}), 202
219
-
220
- except Exception as e:
221
- log.exception("generate_voice error: %s", e)
222
- return jsonify({"error": str(e)}), 500
223
 
 
224
 
 
 
 
225
  def process_vox(user_id, text, video, audio_base64, task_id):
226
- temp_audio_path = None
 
 
227
  try:
228
- # 1) Prepare input audio
229
  if audio_base64:
230
- if audio_base64.startswith("data:audio/"):
231
  audio_base64 = audio_base64.split(",", 1)[1]
232
- temp_audio_path = f"/tmp/temp_ref_{task_id}.wav"
233
- with open(temp_audio_path, "wb") as f:
234
  f.write(base64.b64decode(audio_base64))
235
  elif video:
236
- temp_audio_path = video_to_audio(video, output_path=None)
237
 
238
- # 2) Ensure WAV and validate
239
- temp_audio_path = ensure_wav_format(temp_audio_path)
240
- valid, msg = validate_audio_file(temp_audio_path, MAX_AUDIO_SIZE_MB)
241
  if not valid:
242
- raise Exception(f"Invalid audio file: {msg}")
243
-
244
- # 3) Generate TTS (clone)
245
- result_file = clone(text, temp_audio_path)
246
-
247
- # 4) Save output to user_audios
248
- out_dir = "user_audios"
249
- os.makedirs(out_dir, exist_ok=True)
250
- file_name = generate_random_filename("mp3")
251
- file_path = os.path.join(out_dir, file_name)
252
-
253
- with open(result_file, "rb") as src, open(file_path, "wb") as dst:
254
- dst.write(src.read())
255
-
256
- # 5) Gather metadata
257
- with wave.open(file_path, "rb") as wf:
258
- dura = wf.getnframes() / float(wf.getframerate())
259
- duration = f"{dura:.2f}"
260
- title = text[:20]
261
-
262
- # 6) Upload and save
263
- audio_url = save_to_dataset_repo(file_path, f"user/data/audios/{file_name}", file_name)
264
- active_tasks[task_id].update(
265
- {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  "status": "completed",
267
  "audio_url": audio_url,
268
- "completion_time": datetime.now(),
269
- }
270
- )
271
- save_audio(user_id, audio_url, title or "Audio", text, duration)
272
 
273
  except Exception as e:
274
- log.exception("process_vox failed: %s", e)
275
- active_tasks[task_id] = {
276
- "status": "failed",
277
- "error": str(e),
278
- "completion_time": datetime.now(),
279
- }
280
 
281
  finally:
282
- # cleanup
283
- try:
284
- if temp_audio_path and os.path.exists(temp_audio_path):
285
- os.remove(temp_audio_path)
286
- task = active_tasks.get(task_id)
287
- if task:
288
- if task["status"]== "completed":
289
- remove_task_after_delay(task_id, delay_seconds=300)
290
- elif task["status"] == "failed":
291
- del active_tasks[task_id]
292
- except Exception:
293
- # ignore cleanup issues
294
- pass
295
-
296
-
297
- def clone(text, audio):
298
- """
299
- Use the TTS instance to produce an output file. Returns the path to the output file.
300
- """
301
- out_path = "./output.wav"
302
- # use tts to write audio; let TTS manage model specifics
303
- tts.tts_to_file(text=text, speaker_wav=audio, language="en", file_path=out_path)
304
- return out_path
305
-
306
 
 
 
 
307
  @app.route("/task_status")
308
  def task_status():
309
  task_id = request.args.get("task_id")
310
  if not task_id:
311
- return jsonify({"error": "task_id parameter is required"}), 400
312
-
313
- if task_id not in active_tasks:
314
- return jsonify({"status": "not found"}), 404
315
-
316
- task = active_tasks[task_id]
317
- response_data = {
318
- "status": task["status"],
319
- "start_time": task.get("created_at").isoformat() if task.get("created_at") else None,
320
- }
321
-
322
- if task["status"] == "completed":
323
- response_data["audio_url"] = task.get("audio_url")
324
- response_data["completion_time"] = (
325
- task.get("completion_time").isoformat() if task.get("completion_time") else None
326
- )
327
- elif task["status"] == "failed":
328
- response_data["error"] = task.get("error")
329
- response_data["completion_time"] = (
330
- task.get("completion_time").isoformat() if task.get("completion_time") else None
331
- )
332
-
333
- return jsonify(response_data)
334
 
 
 
335
 
336
- def remove_task_after_delay(task_id, delay_seconds=300):
337
- def remove_task():
338
- if task_id in active_tasks:
339
- del active_tasks[task_id]
340
- log.info(f"Task {task_id} auto-deleted after {delay_seconds} seconds.")
341
- timer = threading.Timer(delay_seconds, remove_task)
342
- timer.start()
343
 
 
344
 
345
- # Run only when invoked directly (Gunicorn will ignore this block)
 
 
 
 
 
 
 
 
1
  from flask import Flask, request, jsonify, render_template
 
2
  from flask_cors import CORS
3
+ from datetime import datetime
4
  from TTS.api import TTS
 
5
  import os
6
  import base64
 
7
  import wave
8
  import logging
9
  import threading
10
+ from uuid import uuid4
11
 
12
  from helper import (
13
  save_audio,
 
18
  ensure_wav_format,
19
  )
20
 
21
+ # --------------------------------------------------
22
+ # Basic config
23
+ # --------------------------------------------------
24
  logging.basicConfig(level=logging.INFO)
25
  log = logging.getLogger("app")
26
 
27
  app = Flask(__name__)
28
  CORS(app)
 
29
 
30
+ os.environ["COQUI_TOS_AGREED"] = "1"
31
  device = "cpu"
32
+ MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"
 
 
 
 
 
 
 
33
  MAX_AUDIO_SIZE_MB = 15
34
 
35
+ # --------------------------------------------------
36
+ # Global state (thread-safe)
37
+ # --------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  tts = None
39
+ tts_lock = threading.Lock()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ active_tasks = {}
42
+ tasks_lock = threading.Lock()
43
+
44
+ # --------------------------------------------------
45
+ # Lazy-load XTTS (ONLY ONCE)
46
+ # --------------------------------------------------
47
+ def get_tts():
48
+ global tts
49
+ with tts_lock:
50
+ if tts is None:
51
+ log.info("🔊 Loading XTTS model (this takes time)...")
52
+ tts = TTS(model_name=MODEL_NAME).to(device)
53
+ log.info("✅ XTTS loaded")
54
+ return tts
55
+
56
+ # --------------------------------------------------
57
+ # Routes
58
+ # --------------------------------------------------
59
  @app.route("/")
60
  def greet_html():
61
  return render_template("home.html")
62
 
 
63
  @app.route("/sign-in")
64
  def sign_in():
65
  return render_template("sign_in.html")
66
 
 
67
  @app.route("/user_dash")
68
  def user_dash():
69
  user_id = request.args.get("user_id")
70
+ if not user_id:
71
+ return jsonify({"error": "Missing user_id"}), 400
72
+ return render_template("u_dash.html", user_id=user_id)
 
73
 
74
+ # --------------------------------------------------
75
+ # Generate Voice (NON-BLOCKING)
76
+ # --------------------------------------------------
77
  @app.route("/generate_voice", methods=["POST"])
78
  def generate_voice():
79
+ data = request.get_json()
80
+ if not data:
81
+ return jsonify({"error": "No JSON body"}), 400
82
+
83
+ user_id = data.get("user_id")
84
+ text = data.get("text")
85
+ audio_base64 = data.get("audio")
86
+ video = data.get("video")
87
+ task_id = data.get("task_id")
88
+
89
+ if not user_id:
90
+ return jsonify({"error": "You must sign in"}), 401
91
+ if not text:
92
+ return jsonify({"error": "Text is required"}), 400
93
+ if not task_id:
94
+ return jsonify({"error": "task_id required"}), 400
95
+
96
+ with tasks_lock:
97
  if task_id in active_tasks:
98
+ return jsonify({"error": "Task already running"}), 409
99
 
100
  active_tasks[task_id] = {
101
+ "status": "processing",
 
102
  "created_at": datetime.now(),
103
  }
104
 
105
+ threading.Thread(
106
+ target=process_vox,
107
+ args=(user_id, text, video, audio_base64, task_id),
108
+ daemon=True
109
+ ).start()
 
 
110
 
111
+ return jsonify({"message": "Processing started", "task_id": task_id}), 202
112
 
113
+ # --------------------------------------------------
114
+ # Background Processor
115
+ # --------------------------------------------------
116
  def process_vox(user_id, text, video, audio_base64, task_id):
117
+ ref_audio = None
118
+ out_file = None
119
+
120
  try:
121
+ # 1️⃣ Prepare reference audio
122
  if audio_base64:
123
+ if audio_base64.startswith("data:audio"):
124
  audio_base64 = audio_base64.split(",", 1)[1]
125
+ ref_audio = f"/tmp/ref_{uuid4().hex}.wav"
126
+ with open(ref_audio, "wb") as f:
127
  f.write(base64.b64decode(audio_base64))
128
  elif video:
129
+ ref_audio = video_to_audio(video)
130
 
131
+ ref_audio = ensure_wav_format(ref_audio)
132
+ valid, msg = validate_audio_file(ref_audio, MAX_AUDIO_SIZE_MB)
 
133
  if not valid:
134
+ raise Exception(msg)
135
+
136
+ # 2️⃣ Generate TTS
137
+ out_file = f"/tmp/tts_{uuid4().hex}.wav"
138
+ tts = get_tts()
139
+ tts.tts_to_file(
140
+ text=text,
141
+ speaker_wav=ref_audio,
142
+ language="en",
143
+ file_path=out_file
144
+ )
145
+
146
+ # 3️⃣ Duration
147
+ with wave.open(out_file, "rb") as wf:
148
+ duration = wf.getnframes() / wf.getframerate()
149
+
150
+ # 4️⃣ Save & upload
151
+ os.makedirs("user_audios", exist_ok=True)
152
+ file_name = generate_random_filename("wav")
153
+ final_path = os.path.join("user_audios", file_name)
154
+ os.rename(out_file, final_path)
155
+
156
+ audio_url = save_to_dataset_repo(
157
+ final_path,
158
+ f"user/data/audios/{file_name}",
159
+ file_name
160
+ )
161
+
162
+ save_audio(
163
+ user_id,
164
+ audio_url,
165
+ text[:20],
166
+ text,
167
+ f"{duration:.2f}"
168
+ )
169
+
170
+ with tasks_lock:
171
+ active_tasks[task_id].update({
172
  "status": "completed",
173
  "audio_url": audio_url,
174
+ "completed_at": datetime.now()
175
+ })
176
+
177
+ remove_task_after_delay(task_id)
178
 
179
  except Exception as e:
180
+ log.exception("TTS failed")
181
+ with tasks_lock:
182
+ active_tasks[task_id] = {
183
+ "status": "failed",
184
+ "error": str(e)
185
+ }
186
 
187
  finally:
188
+ for f in (ref_audio, out_file):
189
+ if f and os.path.exists(f):
190
+ os.remove(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
+ # --------------------------------------------------
193
+ # Task Status
194
+ # --------------------------------------------------
195
  @app.route("/task_status")
196
  def task_status():
197
  task_id = request.args.get("task_id")
198
  if not task_id:
199
+ return jsonify({"error": "task_id required"}), 400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
+ with tasks_lock:
202
+ task = active_tasks.get(task_id)
203
 
204
+ if not task:
205
+ return jsonify({"status": "not found"}), 404
 
 
 
 
 
206
 
207
+ return jsonify(task)
208
 
209
+ # --------------------------------------------------
210
+ # Auto-clean tasks
211
+ # --------------------------------------------------
212
+ def remove_task_after_delay(task_id, delay=300):
213
+ def cleanup():
214
+ with tasks_lock:
215
+ active_tasks.pop(task_id, None)
216
+ threading.Timer(delay, cleanup).start()