from flask import Flask, request, jsonify, render_template
from datetime import datetime
from flask_cors import CORS
from TTS.api import TTS
import os
import base64
import logging
import threading
import tempfile
import shutil
import textwrap # For robust text chunking
import torch # For no_grad and empty_cache
from pydub import AudioSegment # For WAV concat
import psutil # For RAM check
import warnings # For suppressing warnings
from helper import (
save_audio,
generate_random_filename,
save_to_dataset_repo,
video_to_audio,
validate_audio_file,
ensure_wav_format,
)
# ---------- Tunables ----------
MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"  # Coqui model id
MAX_AUDIO_SIZE_MB = 15  # size limit handed to validate_audio_file
MAX_TEXT_LEN = 150  # wrap width used when chunking text for synthesis
device = "cpu"

# ---------- Logging ----------
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("app")

# Silence UserWarnings raised from the transformers package and keep its
# logger at ERROR level.
warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    module="transformers",
)
logging.getLogger("transformers").setLevel(logging.ERROR)

# ---------- Flask application ----------
app = Flask(__name__)
CORS(app)

# ---------- TTS model ----------
# Set before the model is instantiated below.
os.environ["COQUI_TOS_AGREED"] = "1"

# The model is built straight from its name; any failure is logged and
# re-raised so the module does not finish importing without a model.
tts = None
try:
    log.info(f"⬇️ Initializing XTTS from {MODEL_NAME}...")
    tts = TTS(model_name=MODEL_NAME).to(device)
except Exception as exc:
    log.exception("Fatal: TTS init failed: %s", exc)
    raise
else:
    log.info("✅ TTS ready (direct init).")

# ============================================================
# Application logic (routes & helpers)
# ============================================================

# In-memory task registry: task_id -> dict describing the task's state.
active_tasks = {}
@app.route("/")
def greet_html():
    """Serve the landing page rendered from the ``home.html`` template."""
    landing_page = render_template("home.html")
    return landing_page
@app.route("/sign-in")
def sign_in():
    """Serve the sign-in page rendered from the ``sign_in.html`` template."""
    sign_in_page = render_template("sign_in.html")
    return sign_in_page
@app.route("/user_dash")
def user_dash():
    """Render the user dashboard for the ``user_id`` query parameter.

    Returns:
        The rendered ``u_dash.html`` template when ``user_id`` is present
        and non-empty; otherwise a JSON error body with HTTP status 400.
    """
    uid = request.args.get("user_id")
    if not uid:
        # Guard clause: nothing to render without a user.
        return jsonify({"error": "Missing user_id"}), 400
    return render_template("u_dash.html", user_id=uid)
@app.route("/generate_voice", methods=["POST"])
def generate_voice():
    """Validate a voice-generation request, register the task and run it.

    The JSON body must be an object containing ``user_id``, ``text`` and
    ``task_id``; ``audio`` and ``video`` are optional and are forwarded
    unchanged to ``process_vox``.

    Returns:
        A ``(json, status)`` pair:

        * 400 - body missing, not valid JSON, not a JSON object, or
          ``text`` / ``task_id`` missing
        * 401 - ``user_id`` missing
        * 409 - ``task_id`` is already present in ``active_tasks``
        * 202 - the task was registered and ``process_vox`` has returned
        * 500 - any unexpected exception

    Note:
        ``process_vox`` is called synchronously, so the 202 response is only
        sent once it has returned. The final outcome is stored in
        ``active_tasks`` and exposed through ``/task_status``.
    """
    try:
        # silent=True makes a wrong Content-Type or malformed JSON come back
        # as None instead of raising an HTTP error that the blanket handler
        # below would turn into a misleading 500.
        data = request.get_json(silent=True)
        # A JSON array/string/number has no .get(); treat it like no body.
        if not data or not isinstance(data, dict):
            return jsonify({"error": "No JSON body"}), 400

        video = data.get("video")
        text = data.get("text")
        audio_base64 = data.get("audio")
        task_id = data.get("task_id")
        user_id = data.get("user_id")

        if not user_id:
            return jsonify({"error": "You must sign in before using this AI"}), 401
        if not text:
            return jsonify({"error": "Please input a prompt"}), 400
        if not task_id:
            return jsonify({"error": "task_id is required"}), 400
        if task_id in active_tasks:
            return jsonify({"error": f"There is already an active task for {task_id}"}), 409

        active_tasks[task_id] = {
            "user_id": user_id,
            "status": "Processing",
            "created_at": datetime.now(),
        }

        # Run processing (synchronous; consider Celery for prod scaling)
        process_vox(user_id, text, video, audio_base64, task_id)
        return jsonify({"message": "Processing started", "task_id": task_id}), 202
    except Exception as e:
        log.exception("generate_voice error: %s", e)
        return jsonify({"error": str(e)}), 500
def process_vox(user_id, text, video, audio_base64, task_id):
    """Run one voice-cloning job and record its outcome in ``active_tasks``.

    Steps: check available RAM, materialise the reference audio (from
    ``audio_base64`` or, failing that, from ``video``), convert/validate it,
    synthesise speech with ``clone``, copy the result into ``user_audios/``,
    upload it with ``save_to_dataset_repo`` and persist it with
    ``save_audio``.

    Args:
        user_id: Identifier passed through to ``save_audio``.
        text: Text to synthesise; its first 20 characters become the title.
        video: Video reference handed to ``video_to_audio`` when no
            ``audio_base64`` is given.
        audio_base64: Base64-encoded reference audio, optionally prefixed
            with a ``data:audio/...`` URL header.
        task_id: Key of this job in ``active_tasks``.

    Returns:
        None. Exceptions are not propagated: on any failure the task entry
        is updated with ``status="failed"`` and the error message. On
        success it is updated with ``status="completed"`` and ``audio_url``.
        The entry is scheduled for removal afterwards (300 s after success,
        60 s after failure).
    """
    import wave  # only needed here, to read the duration of the generated WAV

    # Every temporary file produced along the way; all removed in ``finally``.
    temp_paths = []
    try:
        # RAM check (OOM guard - tightened threshold)
        ram_gb = psutil.virtual_memory().available / (1024 ** 3)
        log.info("Available RAM: %.1f GB", ram_gb)
        if ram_gb < 1.5:  # XTTS needs ~1.5GB free min
            raise RuntimeError("Low RAM: Please try a shorter text or later.")

        # 1) Prepare input audio
        if audio_base64:
            if audio_base64.startswith("data:audio/"):
                audio_base64 = audio_base64.split(",", 1)[1]
            # task_id comes from the client, so it must not be interpolated
            # into a filesystem path; let mkstemp choose a unique name.
            fd, ref_path = tempfile.mkstemp(prefix="temp_ref_", suffix=".wav")
            temp_paths.append(ref_path)
            with os.fdopen(fd, "wb") as ref_file:
                ref_file.write(base64.b64decode(audio_base64))
        elif video:
            ref_path = video_to_audio(video, output_path=None)
            temp_paths.append(ref_path)
        else:
            # Without this guard None would be passed to ensure_wav_format.
            raise ValueError("No reference audio or video provided.")

        # 2) Ensure WAV and validate
        wav_path = ensure_wav_format(ref_path)
        # NOTE(review): assumes ensure_wav_format may return a new file path;
        # track it separately so the pre-conversion file is cleaned up too.
        if wav_path and wav_path != ref_path:
            temp_paths.append(wav_path)
        valid, msg = validate_audio_file(wav_path, MAX_AUDIO_SIZE_MB)
        if not valid:
            raise ValueError(f"Invalid audio file: {msg}")

        # 3) Generate TTS (clone) with chunking for long text
        output_path = clone(text, wav_path)
        temp_paths.append(output_path)

        # 4) Save output to user_audios
        out_dir = "user_audios"
        os.makedirs(out_dir, exist_ok=True)
        # NOTE(review): the bytes copied here are the WAV produced by clone()
        # (they are read back with the wave module below), yet the file is
        # named ".mp3". Kept as-is because the uploaded name is visible to
        # clients - confirm which extension is actually wanted.
        file_name = generate_random_filename("mp3")
        file_path = os.path.join(out_dir, file_name)
        shutil.copyfile(output_path, file_path)

        # 5) Gather metadata
        with wave.open(file_path, "rb") as wf:
            dura = wf.getnframes() / float(wf.getframerate())
        duration = f"{dura:.2f}"
        title = text[:20]

        # 6) Upload and save (with DB retry in helper)
        audio_url = save_to_dataset_repo(file_path, f"user/data/audios/{file_name}", file_name)
        save_audio(user_id, audio_url, title or "Audio", text, duration)

        # Report "completed" only after the record has been saved, so a
        # poller never sees "completed" flip to "failed".
        active_tasks.setdefault(task_id, {}).update(
            {
                "status": "completed",
                "audio_url": audio_url,
                "completion_time": datetime.now(),
            }
        )
    except Exception as e:
        log.exception("process_vox failed: %s", e)
        # update() rather than replacing the dict: keeps user_id/created_at so
        # /task_status can still report the start time of a failed task.
        active_tasks.setdefault(task_id, {}).update(
            {
                "status": "failed",
                "error": str(e),
                "completion_time": datetime.now(),
            }
        )
    finally:
        # Best-effort cleanup: a file that cannot be removed is logged, never
        # allowed to mask the job's real outcome.
        for path in temp_paths:
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                except OSError as cleanup_err:
                    log.warning("Could not remove temp file %s: %s", path, cleanup_err)

        task = active_tasks.get(task_id)
        status = task.get("status") if task else None
        if status == "completed":
            remove_task_after_delay(task_id, delay_seconds=300)
        elif status == "failed":
            # Failed tasks are kept for 60 seconds, then dropped.
            remove_task_after_delay(task_id, delay_seconds=60)
def clone(text, audio):
    """Synthesise ``text`` in the voice of ``audio`` and return a WAV path.

    The text is wrapped into pieces of at most ``MAX_TEXT_LEN`` characters
    (words are never broken), each piece is synthesised to its own temporary
    WAV, and the pieces are concatenated into a single output file.

    Args:
        text: Text to speak. The language is guessed with simple heuristics:
            any Devanagari character selects ``"hi"``, otherwise any of
            ``äöüß`` selects ``"de"``, otherwise ``"en"``.
        audio: Path to the reference speaker WAV passed to the TTS model.

    Returns:
        Path to a newly created temporary WAV file. The caller owns it and
        is responsible for deleting it.

    Raises:
        ValueError: If ``text`` contains nothing but whitespace.
        Exception: Whatever the TTS model or the audio concatenation raises;
            temporary files are removed before the error propagates.
    """
    # Improved lang detect (simple heuristics)
    lang = "en"
    if any(0x0900 <= ord(c) < 0x0980 for c in text):  # Devanagari for Hindi
        lang = "hi"
    elif any(c in "äöüß" for c in text):  # German chars
        lang = "de"
    log.info(f"Cloning with lang: {lang}, text len: {len(text)}")

    # Aggressive chunk: wrap to MAX_TEXT_LEN; fall back to the full text when
    # wrapping yields at most one piece.
    wrapped = textwrap.wrap(text, width=MAX_TEXT_LEN, break_long_words=False)
    candidates = wrapped if len(wrapped) > 1 else [text]
    chunks = [piece.strip() for piece in candidates if piece.strip()]
    log.info(f"Split into {len(chunks)} chunks")
    if not chunks:
        raise ValueError("No chunks generated—check text input.")

    chunk_files = []
    try:
        for i, chunk in enumerate(chunks):
            # mkstemp creates the file atomically; the deprecated mktemp only
            # returns a name that another process could claim first.
            fd, chunk_out = tempfile.mkstemp(suffix=f"_chunk{i}.wav")
            os.close(fd)
            # Register before synthesis so a failure still cleans it up.
            chunk_files.append(chunk_out)
            with torch.no_grad():  # Mem save: no gradients
                tts.tts_to_file(
                    text=chunk,
                    speaker_wav=audio,
                    language=lang,
                    file_path=chunk_out,
                    split_sentences=True,  # Let TTS handle intra-chunk splits
                )

        # Concatenate the per-chunk WAVs in order.
        combined = AudioSegment.empty()
        for chunk_path in chunk_files:
            combined += AudioSegment.from_wav(chunk_path)

        fd, out_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            combined.export(out_path, format="wav")
        except Exception:
            # Do not leave a half-written output file behind.
            try:
                os.remove(out_path)
            except OSError:
                pass
            raise
    finally:
        # Chunk files are intermediate; remove them on success and failure.
        for chunk_path in chunk_files:
            try:
                os.remove(chunk_path)
            except OSError as cleanup_err:
                log.warning("Could not remove chunk file %s: %s", chunk_path, cleanup_err)
        # Clear cache (harmless on CPU)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    log.info("Clone complete.")
    return out_path
@app.route("/task_status")
def task_status():
    """Report the state of the task named by the ``task_id`` query parameter.

    Returns:
        * 400 with an error body when ``task_id`` is missing or empty.
        * 404 with ``{"status": "not found"}`` when the task is unknown.
        * Otherwise a JSON object with ``status`` and ``start_time``
          (ISO-8601 or ``None``). Completed tasks add ``audio_url`` and
          ``completion_time``; failed tasks add ``error`` and
          ``completion_time``.
    """
    task_id = request.args.get("task_id")
    if not task_id:
        return jsonify({"error": "task_id parameter is required"}), 400

    # One lookup instead of "in" followed by indexing: cleanup timers remove
    # entries from other threads, so the entry could vanish between the two
    # steps and raise KeyError.
    task = active_tasks.get(task_id)
    if task is None:
        return jsonify({"status": "not found"}), 404

    status = task["status"]
    created_at = task.get("created_at")
    response_data = {
        "status": status,
        "start_time": created_at.isoformat() if created_at else None,
    }

    if status in ("completed", "failed"):
        if status == "completed":
            response_data["audio_url"] = task.get("audio_url")
        else:
            response_data["error"] = task.get("error")
        completed_at = task.get("completion_time")
        response_data["completion_time"] = (
            completed_at.isoformat() if completed_at else None
        )

    return jsonify(response_data)
def remove_task_after_delay(task_id, delay_seconds=300):
    """Schedule removal of ``task_id`` from ``active_tasks``.

    Args:
        task_id: Key to remove from ``active_tasks``.
        delay_seconds: Seconds to wait before removing the entry.

    Returns:
        None. The removal happens later on a background timer thread. If the
        entry is already gone by then, nothing is removed or logged.
    """
    def remove_task():
        # A single pop() replaces "in" + "del", so an entry removed by
        # another thread in the meantime cannot raise KeyError.
        if active_tasks.pop(task_id, None) is not None:
            log.info(f"Task {task_id} auto-deleted after {delay_seconds} seconds.")

    timer = threading.Timer(delay_seconds, remove_task)
    # Daemon thread: a pending cleanup must not keep the interpreter alive
    # for up to ``delay_seconds`` at shutdown.
    timer.daemon = True
    timer.start()
if __name__ == "__main__":
    # Debug mode turns on the Werkzeug interactive debugger, which lets
    # anyone who can reach the server execute arbitrary code. The server
    # binds to 0.0.0.0, so debug is off unless FLASK_DEBUG opts in.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true")
    app.run(debug=debug_enabled, host="0.0.0.0", port=7860)