"""Flask service that clones a voice with Coqui XTTS-v2.

POST /generate_voice registers a task and synthesizes speech from a
reference audio clip (base64 WAV or extracted from a video) in a
background thread; clients poll GET /task_status for the result URL.
"""

from flask import Flask, request, jsonify, render_template
from datetime import datetime
from flask_cors import CORS
from TTS.api import TTS
import os
import base64
import logging
import threading
import tempfile
import shutil
import textwrap  # For robust text chunking
import wave  # FIX: hoisted from inside process_vox (was a function-local import)
import torch  # For no_grad and empty_cache
from pydub import AudioSegment  # For WAV concat
import psutil  # For RAM check
import warnings  # For suppressing warnings
from helper import (
    save_audio,
    generate_random_filename,
    save_to_dataset_repo,
    video_to_audio,
    validate_audio_file,
    ensure_wav_format,
)

# ---------- Basic config ----------
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("app")

# Suppress warnings and logs
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
logging.getLogger("transformers").setLevel(logging.ERROR)

app = Flask(__name__)
CORS(app)

os.environ["COQUI_TOS_AGREED"] = "1"
device = "cpu"
MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"  # coqui model id
MAX_AUDIO_SIZE_MB = 15
MAX_TEXT_LEN = 150  # Aggressive chunk size for OOM safety

# Simplified TTS init: Direct from model name (handles download/config auto)
tts = None
try:
    log.info(f"⬇️ Initializing XTTS from {MODEL_NAME}...")
    tts = TTS(model_name=MODEL_NAME).to(device)  # Uses model_name kwarg for HF-style load
    log.info("✅ TTS ready (direct init).")
except Exception as exc:
    log.exception("Fatal: TTS init failed: %s", exc)
    raise

# ============================================================
# Application logic (routes & helpers)
# ============================================================

# task_id -> {"user_id", "status", "created_at", ...}; mutated from request
# handlers, the background worker thread, and Timer callbacks, so every
# access goes through _tasks_lock.
active_tasks = {}
_tasks_lock = threading.Lock()


@app.route("/")
def greet_html():
    """Serve the landing page."""
    return render_template("home.html")


@app.route("/sign-in")
def sign_in():
    """Serve the sign-in page."""
    return render_template("sign_in.html")


@app.route("/user_dash")
def user_dash():
    """Serve the user dashboard; requires a user_id query parameter."""
    user_id = request.args.get("user_id")
    if user_id:
        return render_template("u_dash.html", user_id=user_id)
    return jsonify({"error": "Missing user_id"}), 400


@app.route("/generate_voice", methods=["POST"])
def generate_voice():
    """Start a voice-cloning job.

    Expects a JSON body with user_id, task_id, text, and either "audio"
    (base64 WAV, optionally a data: URI) or "video" to extract audio from.
    Returns 202 immediately; clients poll /task_status for completion.
    """
    try:
        data = request.get_json()
        if not data:
            return jsonify({"error": "No JSON body"}), 400

        video = data.get("video")
        text = data.get("text")
        audio_base64 = data.get("audio")
        task_id = data.get("task_id")
        user_id = data.get("user_id")

        if not user_id:
            return jsonify({"error": "You must sign in before using this AI"}), 401
        if not text:
            return jsonify({"error": "Please input a prompt"}), 400
        if not task_id:
            return jsonify({"error": "task_id is required"}), 400

        # Register under the lock so two concurrent requests with the same
        # task_id cannot both pass the duplicate check (check-then-set race).
        with _tasks_lock:
            if task_id in active_tasks:
                return jsonify({"error": f"There is already an active task for {task_id}"}), 409
            active_tasks[task_id] = {
                "user_id": user_id,
                "status": "Processing",
                "created_at": datetime.now(),
            }

        # FIX: run the heavy pipeline in a background thread so the 202
        # "Processing started" response is truthful (it previously blocked
        # until the whole job finished). Consider Celery for prod scaling.
        threading.Thread(
            target=process_vox,
            args=(user_id, text, video, audio_base64, task_id),
            daemon=True,
        ).start()
        return jsonify({"message": "Processing started", "task_id": task_id}), 202
    except Exception as e:
        log.exception("generate_voice error: %s", e)
        return jsonify({"error": str(e)}), 500


def process_vox(user_id, text, video, audio_base64, task_id):
    """Run the full pipeline: prepare reference audio, clone, upload, record.

    Updates active_tasks[task_id] to "completed" (with audio_url) or
    "failed" (with error). Temp files are removed in all cases.
    """
    temp_audio_path = None
    temp_output_path = None
    try:
        # OOM guard: XTTS needs ~1.5 GB of free RAM at minimum.
        ram_gb = psutil.virtual_memory().available / (1024 ** 3)
        log.info(f"Available RAM: {ram_gb:.1f} GB")
        if ram_gb < 1.5:
            raise Exception("Low RAM: Please try a shorter text or later.")

        # 1) Prepare the reference audio from base64 upload or video track.
        if audio_base64:
            # Strip an optional data-URI prefix ("data:audio/...;base64,").
            if audio_base64.startswith("data:audio/"):
                audio_base64 = audio_base64.split(",", 1)[1]
            temp_audio_path = f"/tmp/temp_ref_{task_id}.wav"
            with open(temp_audio_path, "wb") as f:
                f.write(base64.b64decode(audio_base64))
        elif video:
            temp_audio_path = video_to_audio(video, output_path=None)
        else:
            # FIX: previously fell through with None and failed deep inside
            # ensure_wav_format with a confusing error.
            raise Exception("No reference audio: provide 'audio' or 'video'.")

        # 2) Normalize to WAV and validate size/content.
        temp_audio_path = ensure_wav_format(temp_audio_path)
        valid, msg = validate_audio_file(temp_audio_path, MAX_AUDIO_SIZE_MB)
        if not valid:
            raise Exception(f"Invalid audio file: {msg}")

        # 3) Generate the cloned speech (chunked for long text).
        temp_output_path = clone(text, temp_audio_path)

        # 4) Persist the output locally.
        out_dir = "user_audios"
        os.makedirs(out_dir, exist_ok=True)
        # FIX: clone() produces WAV data, so label the file .wav — it was
        # previously named .mp3 while containing WAV bytes.
        file_name = generate_random_filename("wav")
        file_path = os.path.join(out_dir, file_name)
        shutil.copyfile(temp_output_path, file_path)

        # 5) Duration metadata straight from the WAV header.
        with wave.open(file_path, "rb") as wf:
            dura = wf.getnframes() / float(wf.getframerate())
        duration = f"{dura:.2f}"
        title = text[:20]

        # 6) Upload and record in the DB (helper handles retries).
        audio_url = save_to_dataset_repo(file_path, f"user/data/audios/{file_name}", file_name)
        with _tasks_lock:
            active_tasks[task_id].update(
                {
                    "status": "completed",
                    "audio_url": audio_url,
                    "completion_time": datetime.now(),
                }
            )
        save_audio(user_id, audio_url, title or "Audio", text, duration)
    except Exception as e:
        log.exception("process_vox failed: %s", e)
        with _tasks_lock:
            active_tasks[task_id] = {
                "status": "failed",
                "error": str(e),
                "completion_time": datetime.now(),
            }
    finally:
        # Best-effort removal of the temp inputs/outputs.
        for path in (temp_audio_path, temp_output_path):
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                except OSError:  # FIX: was a bare except hiding real bugs
                    pass
        with _tasks_lock:
            task = active_tasks.get(task_id)
            status = task["status"] if task else None
        if status == "completed":
            remove_task_after_delay(task_id, delay_seconds=300)
        elif status == "failed":
            # Keep failed tasks visible for 60s so clients can read the error.
            # FIX: reuse remove_task_after_delay instead of an ad-hoc Timer.
            remove_task_after_delay(task_id, delay_seconds=60)


def clone(text, audio):
    """Generate cloned speech for *text* using *audio* as the reference voice.

    Long text is wrapped into chunks of <= MAX_TEXT_LEN characters to bound
    memory use; per-chunk WAVs are concatenated. Returns the path of the
    (possibly concatenated) output WAV.
    """
    # Crude language detection: Devanagari block -> Hindi, umlauts/ß ->
    # German, otherwise English. NOTE(review): heuristic only — confirm
    # it covers the languages users actually submit.
    lang = "en"
    if any(0x0900 <= ord(c) < 0x0980 for c in text):
        lang = "hi"
    elif any(c in "äöüß" for c in text):
        lang = "de"
    log.info(f"Cloning with lang: {lang}, text len: {len(text)}")

    # FIX: tempfile.mktemp is deprecated and race-prone; mkstemp actually
    # creates the file, closing the TOCTOU window.
    fd, out_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)

    # Aggressive chunk: wrap to MAX_TEXT_LEN; fall back to full text if short.
    wrapped = textwrap.wrap(text, width=MAX_TEXT_LEN, break_long_words=False)
    chunks = wrapped if len(wrapped) > 1 else [text]
    log.info(f"Split into {len(chunks)} chunks")

    chunk_files = []
    try:
        for i, chunk in enumerate(chunks):
            if not chunk.strip():
                continue
            fd, chunk_out = tempfile.mkstemp(suffix=f"_chunk{i}.wav")
            os.close(fd)
            with torch.no_grad():  # inference only: skip autograd bookkeeping
                tts.tts_to_file(
                    text=chunk.strip(),
                    speaker_wav=audio,
                    language=lang,
                    file_path=chunk_out,
                    split_sentences=True,  # Let TTS handle intra-chunk splits
                )
            chunk_files.append(chunk_out)

        if not chunk_files:
            raise Exception("No chunks generated—check text input.")

        combined = AudioSegment.empty()
        for f in chunk_files:
            combined += AudioSegment.from_wav(f)
        combined.export(out_path, format="wav")
    finally:
        # FIX: chunk temps are removed even when synthesis fails mid-loop.
        for f in chunk_files:
            try:
                os.remove(f)
            except OSError:
                pass

    if torch.cuda.is_available():  # harmless no-op on CPU
        torch.cuda.empty_cache()
    log.info("Clone complete.")
    return out_path


@app.route("/task_status")
def task_status():
    """Report the state of a task created by /generate_voice."""
    task_id = request.args.get("task_id")
    if not task_id:
        return jsonify({"error": "task_id parameter is required"}), 400
    with _tasks_lock:
        task = active_tasks.get(task_id)
    if task is None:
        return jsonify({"status": "not found"}), 404

    def _iso(key):
        # Timestamps are stored as datetime objects; expose ISO-8601 or None.
        value = task.get(key)
        return value.isoformat() if value else None

    response_data = {
        "status": task["status"],
        "start_time": _iso("created_at"),
    }
    if task["status"] == "completed":
        response_data["audio_url"] = task.get("audio_url")
        response_data["completion_time"] = _iso("completion_time")
    elif task["status"] == "failed":
        response_data["error"] = task.get("error")
        response_data["completion_time"] = _iso("completion_time")
    return jsonify(response_data)


def remove_task_after_delay(task_id, delay_seconds=300):
    """Schedule removal of *task_id* from active_tasks after *delay_seconds*."""

    def remove_task():
        with _tasks_lock:
            removed = active_tasks.pop(task_id, None) is not None
        if removed:
            log.info(f"Task {task_id} auto-deleted after {delay_seconds} seconds.")

    timer = threading.Timer(delay_seconds, remove_task)
    timer.daemon = True  # don't block interpreter shutdown on pending timers
    timer.start()


if __name__ == "__main__":
    app.run(debug=True, host="0.0.0.0", port=7860)