# vox-beta / app.py
# EllenBeta's picture
# Update app.py
# 1fe6185 verified
from flask import Flask, request, jsonify, render_template
from datetime import datetime
from flask_cors import CORS
from TTS.api import TTS
import os
import base64
import logging
import threading
import tempfile
import shutil
import textwrap # For robust text chunking
import torch # For no_grad and empty_cache
from pydub import AudioSegment # For WAV concat
import psutil # For RAM check
import warnings # For suppressing warnings
from helper import (
save_audio,
generate_random_filename,
save_to_dataset_repo,
video_to_audio,
validate_audio_file,
ensure_wav_format,
)
# ---------- Basic config ----------
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("app")

# Keep the transformers library quiet: ignore its UserWarnings and only
# let ERROR-level records through its logger.
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
logging.getLogger("transformers").setLevel(logging.ERROR)

# Tunables used by the request handlers and the synthesis code below.
MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"  # coqui model id
MAX_AUDIO_SIZE_MB = 15
MAX_TEXT_LEN = 150  # small chunk width, chosen to limit memory use per synthesis call
device = "cpu"

# Must be set before the model is initialized further down.
os.environ["COQUI_TOS_AGREED"] = "1"

app = Flask(__name__)
CORS(app)

# ---------- Model init ----------
# The model is built once at import time directly from its model name.
# Any failure is logged with its traceback and re-raised, so the module
# never finishes importing with an unusable `tts`.
tts = None
try:
    log.info(f"⬇️ Initializing XTTS from {MODEL_NAME}...")
    tts = TTS(model_name=MODEL_NAME).to(device)
except Exception as exc:
    log.exception("Fatal: TTS init failed: %s", exc)
    raise
else:
    log.info("✅ TTS ready (direct init).")

# ============================================================
# Application logic (routes & helpers)
# ============================================================

# In-memory task registry: task_id -> dict holding at least a "status" key.
active_tasks = {}
@app.route("/")
def greet_html():
    """Render the ``home.html`` template for the site root."""
    home_page = render_template("home.html")
    return home_page
@app.route("/sign-in")
def sign_in():
    """Render the ``sign_in.html`` template."""
    sign_in_page = render_template("sign_in.html")
    return sign_in_page
@app.route("/user_dash")
def user_dash():
    """Render the dashboard template for the ``user_id`` query parameter.

    Responds with a 400 JSON error when ``user_id`` is absent or empty.
    """
    user_id = request.args.get("user_id")
    # Guard clause: without a user id there is nothing to render.
    if not user_id:
        return jsonify({"error": "Missing user_id"}), 400
    return render_template("u_dash.html", user_id=user_id)
@app.route("/generate_voice", methods=["POST"])
def generate_voice():
    """Validate a voice-generation request, register the task and run it.

    Expects a JSON object with ``user_id``, ``text`` and ``task_id``, plus an
    optional ``audio`` (base64) or ``video`` reference that is handed to
    ``process_vox``.

    Responses:
        400 - body is missing, not valid JSON, not a JSON object, or lacks
              ``text`` / ``task_id``.
        401 - ``user_id`` is missing.
        409 - ``task_id`` is already present in ``active_tasks``.
        202 - the task was registered and ``process_vox`` has returned.
        500 - any unexpected error (its message is returned to the client).

    Note: processing is synchronous, so the 202 response is only sent after
    ``process_vox`` has finished; the outcome is read via ``/task_status``.
    """
    try:
        # silent=True makes a malformed or non-JSON body come back as None
        # instead of raising, so the client gets a 400 rather than the 500
        # produced by the blanket handler below.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"error": "No JSON body"}), 400
        # A JSON list/string/number has no .get(); reject it explicitly.
        if not isinstance(data, dict):
            return jsonify({"error": "JSON body must be an object"}), 400

        video = data.get("video")
        text = data.get("text")
        audio_base64 = data.get("audio")
        task_id = data.get("task_id")
        user_id = data.get("user_id")

        if not user_id:
            return jsonify({"error": "You must sign in before using this AI"}), 401
        if not text:
            return jsonify({"error": "Please input a prompt"}), 400
        if not task_id:
            return jsonify({"error": "task_id is required"}), 400
        if task_id in active_tasks:
            return jsonify({"error": f"There is already an active task for {task_id}"}), 409

        active_tasks[task_id] = {
            "user_id": user_id,
            "status": "Processing",
            "created_at": datetime.now(),
        }

        # Run processing (synchronous; consider Celery for prod scaling)
        process_vox(user_id, text, video, audio_base64, task_id)
        return jsonify({"message": "Processing started", "task_id": task_id}), 202
    except Exception as e:
        log.exception("generate_voice error: %s", e)
        return jsonify({"error": str(e)}), 500
def process_vox(user_id, text, video, audio_base64, task_id):
    """Run one voice-cloning task end to end and record its outcome.

    Steps: check available RAM, materialize the reference audio (from the
    base64 payload or from ``video``), normalize/validate it, synthesize the
    speech with ``clone``, copy the result into ``user_audios/``, upload it,
    persist the record with ``save_audio`` and finally update
    ``active_tasks[task_id]``.

    Args:
        user_id: Identifier passed through to ``save_audio``.
        text: Text to synthesize; its first 20 characters become the title.
        video: Video reference handed to ``video_to_audio`` when no base64
            audio is supplied.
        audio_base64: Base64 reference audio, optionally prefixed with a
            ``data:audio/...;base64,`` header.
        task_id: Key of the task entry in ``active_tasks``.

    Returns:
        None. This function never raises: every failure is logged and stored
        on the task entry as ``status == "failed"`` with an ``error`` message.
    """
    import wave  # local import: `wave` is not imported at file level

    # Every temporary file this task produces, removed in `finally`.
    cleanup_paths = []
    try:
        # RAM check (OOM guard)
        ram_gb = psutil.virtual_memory().available / (1024 ** 3)
        log.info(f"Available RAM: {ram_gb:.1f} GB")
        if ram_gb < 1.5:  # XTTS needs ~1.5GB free min
            raise Exception("Low RAM: Please try a shorter text or later.")

        # 1) Prepare input audio
        if audio_base64:
            if audio_base64.startswith("data:audio/"):
                # Drop the data-URL header; partition() yields "" (instead of
                # an IndexError) when the comma is missing.
                audio_base64 = audio_base64.partition(",")[2]
            raw_audio = base64.b64decode(audio_base64)
            # Let tempfile choose the name: task_id is client-supplied and
            # must not be interpolated into a filesystem path.
            with tempfile.NamedTemporaryFile(
                prefix="temp_ref_", suffix=".wav", delete=False
            ) as ref_file:
                cleanup_paths.append(ref_file.name)
                ref_file.write(raw_audio)
            temp_audio_path = ref_file.name
        elif video:
            temp_audio_path = video_to_audio(video, output_path=None)
            cleanup_paths.append(temp_audio_path)
        else:
            raise Exception("No reference audio or video provided.")

        # 2) Ensure WAV and validate
        wav_path = ensure_wav_format(temp_audio_path)
        # NOTE(review): if ensure_wav_format returns a new file, the
        # pre-conversion file is also removed below — presumably both are
        # temporaries; confirm against helper.py.
        if wav_path not in cleanup_paths:
            cleanup_paths.append(wav_path)
        valid, msg = validate_audio_file(wav_path, MAX_AUDIO_SIZE_MB)
        if not valid:
            raise Exception(f"Invalid audio file: {msg}")

        # 3) Generate TTS (clone) with chunking for long text
        temp_output_path = clone(text, wav_path)
        cleanup_paths.append(temp_output_path)

        # 4) Save output to user_audios
        out_dir = "user_audios"
        os.makedirs(out_dir, exist_ok=True)
        # NOTE(review): the name carries an "mp3" extension but the bytes
        # copied here are the WAV produced by clone() (they are parsed with
        # the `wave` module just below). Left unchanged in case stored URLs
        # depend on the extension.
        file_name = generate_random_filename("mp3")
        file_path = os.path.join(out_dir, file_name)
        shutil.copyfile(temp_output_path, file_path)  # streams, no full read

        # 5) Gather metadata
        with wave.open(file_path, "rb") as wf:
            dura = wf.getnframes() / float(wf.getframerate())
        duration = f"{dura:.2f}"
        title = text[:20]

        # 6) Upload and save
        audio_url = save_to_dataset_repo(file_path, f"user/data/audios/{file_name}", file_name)
        # Persist first and mark the task completed last, so a poller never
        # observes "completed" followed by "failed".
        save_audio(user_id, audio_url, title or "Audio", text, duration)
        active_tasks[task_id].update(
            {
                "status": "completed",
                "audio_url": audio_url,
                "completion_time": datetime.now(),
            }
        )
    except Exception as e:
        log.exception("process_vox failed: %s", e)
        # Update in place (rather than replacing the entry) so user_id and
        # created_at survive and /task_status can still report start_time.
        active_tasks.setdefault(task_id, {}).update(
            {
                "status": "failed",
                "error": str(e),
                "completion_time": datetime.now(),
            }
        )
    finally:
        # Best-effort removal of temporary files.
        for path in cleanup_paths:
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                except OSError as err:
                    log.warning("Could not remove temp file %s: %s", path, err)

        # Finished tasks stay queryable for a while, then are dropped:
        # 300s after success, 60s after failure.
        task = active_tasks.get(task_id)
        if task and task["status"] == "completed":
            remove_task_after_delay(task_id, delay_seconds=300)
        elif task and task["status"] == "failed":
            remove_task_after_delay(task_id, delay_seconds=60)
def clone(text, audio):
    """
    Generate cloned speech for ``text`` using ``audio`` as the speaker reference.

    The text is wrapped into pieces of roughly MAX_TEXT_LEN characters (long
    words are never broken, so a piece can exceed that width). Each piece is
    synthesized to its own temporary WAV and the pieces are concatenated into
    a single output WAV. Text that fits in one piece is synthesized as-is.

    Args:
        text: Text to synthesize.
        audio: Path of the reference speaker WAV, passed to the model as
            ``speaker_wav``.

    Returns:
        Path to the combined output WAV. The file is temporary; the caller is
        responsible for removing it.

    Raises:
        Exception: If the text contains no non-blank chunk. Errors raised by
            the TTS model or by pydub propagate unchanged.
    """
    lang = _detect_language(text)
    log.info(f"Cloning with lang: {lang}, text len: {len(text)}")

    # Wrap on whitespace to MAX_TEXT_LEN columns. textwrap does not look for
    # sentence boundaries; split_sentences=True below lets the model split
    # sentences inside each chunk.
    wrapped = textwrap.wrap(text, width=MAX_TEXT_LEN, break_long_words=False)
    chunks = wrapped if len(wrapped) > 1 else [text]  # Fallback to full if short
    log.info(f"Split into {len(chunks)} chunks")

    chunk_files = []
    try:
        for i, chunk in enumerate(chunks):
            piece = chunk.strip()
            if not piece:
                continue
            chunk_out = _new_temp_wav(f"_chunk{i}.wav")
            # Register before synthesis so the file is removed even when
            # tts_to_file raises.
            chunk_files.append(chunk_out)
            with torch.no_grad():  # Mem save: no gradients
                tts.tts_to_file(
                    text=piece,
                    speaker_wav=audio,
                    language=lang,
                    file_path=chunk_out,
                    split_sentences=True,  # Let TTS handle intra-chunk splits
                )

        if not chunk_files:
            raise Exception("No chunks generated—check text input.")

        # Concatenate the chunk WAVs in order.
        combined = AudioSegment.empty()
        for chunk_path in chunk_files:
            combined += AudioSegment.from_wav(chunk_path)

        out_path = _new_temp_wav(".wav")
        try:
            combined.export(out_path, format="wav")
        except Exception:
            # Do not leave a half-written output file behind.
            _remove_quietly(out_path)
            raise
    finally:
        # Chunk files are intermediates: remove them on success and failure.
        for chunk_path in chunk_files:
            _remove_quietly(chunk_path)

    # Clear cache (harmless on CPU)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    log.info("Clone complete.")
    return out_path


def _detect_language(text):
    """Guess the language code with simple character heuristics (default "en")."""
    if any(0x0900 <= ord(c) < 0x0980 for c in text):  # Devanagari block -> Hindi
        return "hi"
    if any(c in "äöüßÄÖÜ" for c in text):  # German umlauts / eszett, either case
        return "de"
    return "en"


def _new_temp_wav(suffix):
    """Create an empty temp file ending in ``suffix`` and return its path."""
    # mkstemp creates the file atomically, unlike the deprecated mktemp,
    # which only returns a name that another process could claim first.
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    return path


def _remove_quietly(path):
    """Delete ``path``, logging (not raising) when the removal fails."""
    try:
        os.remove(path)
    except OSError as err:
        log.debug("Could not remove temp file %s: %s", path, err)
@app.route("/task_status")
def task_status():
    """Report the state of the task named by the ``task_id`` query parameter.

    Responses:
        400 - ``task_id`` is missing.
        404 - no such task in ``active_tasks``.
        200 - JSON with ``status`` and ``start_time``; finished tasks also
              carry ``completion_time`` plus ``audio_url`` (completed) or
              ``error`` (failed). Timestamps are ISO-8601 strings or null.
    """
    task_id = request.args.get("task_id")
    if not task_id:
        return jsonify({"error": "task_id parameter is required"}), 400

    # Single lookup instead of "in" followed by indexing: timer threads
    # remove finished tasks, so the entry could vanish between the two steps.
    task = active_tasks.get(task_id)
    if task is None:
        return jsonify({"status": "not found"}), 404

    status = task["status"]
    created_at = task.get("created_at")
    response_data = {
        "status": status,
        "start_time": created_at.isoformat() if created_at else None,
    }

    if status in ("completed", "failed"):
        if status == "completed":
            response_data["audio_url"] = task.get("audio_url")
        else:
            response_data["error"] = task.get("error")
        completed_at = task.get("completion_time")
        response_data["completion_time"] = (
            completed_at.isoformat() if completed_at else None
        )

    return jsonify(response_data)
def remove_task_after_delay(task_id, delay_seconds=300):
    """Schedule removal of ``task_id`` from ``active_tasks``.

    Args:
        task_id: Key of the entry to drop.
        delay_seconds: Seconds to wait before dropping it (default 300).

    Returns:
        None. The removal runs later on a background timer thread; a task
        that is already gone by then is ignored silently.
    """

    def remove_task():
        # pop() tests and deletes in one step, so a concurrent removal of the
        # same key cannot raise KeyError here.
        if active_tasks.pop(task_id, None) is not None:
            log.info(f"Task {task_id} auto-deleted after {delay_seconds} seconds.")

    timer = threading.Timer(delay_seconds, remove_task)
    # Daemon thread: a pending cleanup must not keep the process alive for up
    # to `delay_seconds` at shutdown (the registry dies with the process).
    timer.daemon = True
    timer.start()
if __name__ == "__main__":
    # Debug mode turns on the Werkzeug interactive debugger, which lets
    # anyone who can reach the server execute arbitrary code. Since this
    # binds to all interfaces, debug is opt-in via FLASK_DEBUG instead of
    # being hard-coded on.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in (
        "1",
        "true",
        "yes",
        "on",
    )
    app.run(debug=debug_enabled, host="0.0.0.0", port=7860)