# TTS / app.py
# Last commit: github-actions[bot] - "Auto-deploy from GitHub: c1cbfa3a37f6853e24d067af55ebc1ab447d9fc0" (68a99fc)
# Size: 8.61 kB (hosting-page links: raw, history, blame)
import math
import os
import shutil
import sqlite3
import subprocess
import threading
import time
import uuid
from contextlib import closing, suppress
from datetime import datetime

from flask import Flask, request, jsonify, send_from_directory, send_file
from flask_cors import CORS
from werkzeug.utils import secure_filename
app = Flask(__name__)
CORS(app)  # enable CORS for the app's routes (flask-cors defaults)

# Generated WAV files are moved here by the worker and served by /api/download.
UPLOAD_FOLDER = 'uploads'

# Make sure the working directories exist before any request is handled.
# NOTE(review): 'temp_dir' is not referenced anywhere else in the visible code.
for _required_dir in (UPLOAD_FOLDER, 'temp_dir'):
    os.makedirs(_required_dir, exist_ok=True)

# Worker state: set by start_worker(), read by worker_loop() and /health.
worker_thread = None
worker_running = False
def init_db():
    """Create the ``tasks`` table in ``tts_tasks.db`` if it is not there yet.

    Columns: id (TEXT primary key), text (TEXT, required), voice (TEXT),
    speed (REAL), status (TEXT, required), output_file (TEXT),
    created_at (TEXT, required), processed_at (TEXT), error (TEXT).
    Safe to call repeatedly thanks to ``IF NOT EXISTS``.
    """
    schema = '''CREATE TABLE IF NOT EXISTS tasks
                (id TEXT PRIMARY KEY,
                 text TEXT NOT NULL,
                 voice TEXT,
                 speed REAL,
                 status TEXT NOT NULL,
                 output_file TEXT,
                 created_at TEXT NOT NULL,
                 processed_at TEXT,
                 error TEXT)'''
    db = sqlite3.connect('tts_tasks.db')
    db.execute(schema)
    db.commit()
    db.close()
# Serializes the check-and-start in start_worker(). app.run() uses a threaded
# server, so two first requests can call start_worker() at the same time.
_worker_lock = threading.Lock()


def start_worker():
    """Start the background TTS worker thread if it is not already running.

    The check of ``worker_running`` and the thread start happen under
    ``_worker_lock``, so concurrent callers create at most one worker thread.
    Two workers would otherwise pick up the same task and overwrite each
    other's content.txt / output_audio.wav.
    """
    global worker_thread, worker_running
    with _worker_lock:
        if worker_running:
            return
        worker_running = True
        worker_thread = threading.Thread(target=worker_loop, daemon=True)
        worker_thread.start()
    print("βœ… Worker thread started")
def _fetch_next_task():
    """Return the oldest task row with status 'not_started', or None if none."""
    with closing(sqlite3.connect('tts_tasks.db')) as conn:
        conn.row_factory = sqlite3.Row
        return conn.execute(
            '''SELECT * FROM tasks
               WHERE status = 'not_started'
               ORDER BY created_at ASC
               LIMIT 1'''
        ).fetchone()


def _process_task(row, python_path, cwd):
    """Run the TTS subprocess for one task row and record the outcome.

    Any failure while generating audio is stored on the task as status
    'failed' with the error text. A failure to mark the task 'processing'
    propagates to the caller, which leaves the task queued for a retry.
    """
    task_id = row['id']
    text = row['text']
    voice = row['voice'] or '8'  # Default voice
    speed = row['speed'] or 1.0
    print(f"\n{'='*60}")
    print(f"🎡 Processing Task: {task_id}")
    print(f"πŸ“ Text: {text[:50]}...")
    print(f"{'='*60}")

    update_status(task_id, 'processing')

    # Both files live in the directory the runner is started in.
    # Assumes tts_runner reads content.txt and writes output_audio.wav in its
    # working directory. TODO(review): confirm against tts_runner.runner.
    content_path = os.path.join(cwd, 'content.txt')
    output_path = os.path.join(cwd, 'output_audio.wav')
    try:
        with open(content_path, 'w', encoding='utf-8') as f:
            f.write(text)

        # Drop any output left behind by an earlier failed run. Otherwise a
        # run that exits 0 without writing audio would "succeed" with a stale file.
        with suppress(FileNotFoundError):
            os.remove(output_path)

        # python3 -m tts_runner.runner --model kokoro --voice <voice> --speed <speed>
        print("πŸ”„ Running TTS...")
        command = [
            python_path, "-m", "tts_runner.runner",
            "--model", "kokoro",
            "--voice", str(voice),
            "--speed", str(speed),
        ]
        subprocess.run(
            command,
            check=True,  # non-zero exit raises CalledProcessError -> 'failed'
            cwd=cwd,
            env={
                **os.environ,
                'PYTHONUNBUFFERED': '1',
                'CUDA_LAUNCH_BLOCKING': '1',
            },
        )

        if not os.path.exists(output_path):
            raise Exception("Output audio file not found")

        target_filename = f"{task_id}.wav"
        shutil.move(output_path, os.path.join(UPLOAD_FOLDER, target_filename))
        print(f"βœ… Successfully processed: {target_filename}")
        update_status(task_id, 'completed', output_file=target_filename)
    except Exception as e:
        print(f"❌ Failed to process: {task_id}")
        print(f"Error: {str(e)}")
        update_status(task_id, 'failed', error=str(e))


def worker_loop():
    """Poll the tasks table and run the TTS subprocess for each queued task.

    Runs until the module-level ``worker_running`` flag becomes False.
    Tasks are taken one at a time, oldest ``created_at`` first. When the
    queue is empty, or an unexpected error occurs, the loop sleeps for
    POLL_INTERVAL seconds before polling again.
    """
    print("πŸ€– TTS Worker started. Monitoring for new tasks...")
    CWD = "./"
    PYTHON_PATH = "python3"  # Or just python
    POLL_INTERVAL = 2  # seconds
    while worker_running:
        try:
            row = _fetch_next_task()
            if row is None:
                # No tasks to process, sleep for a bit
                time.sleep(POLL_INTERVAL)
                continue
            _process_task(row, PYTHON_PATH, CWD)
        except Exception as e:
            # Last-resort guard so a DB hiccup doesn't kill the worker thread.
            print(f"⚠️ Worker error: {str(e)}")
            time.sleep(POLL_INTERVAL)
def update_status(task_id, status, output_file=None, error=None):
    """Set the status of task ``task_id`` in tts_tasks.db.

    Args:
        task_id: Primary key of the task row.
        status: New status. 'completed' also stores ``output_file`` and
            ``processed_at``; 'failed' also stores ``error`` and
            ``processed_at``; any other value only updates ``status``.
        output_file: Name of the generated file (used for 'completed').
        error: Failure description (used for 'failed'). It is stored as text,
            or as NULL when None, rather than the literal string "None".

    The connection is always closed. The update is committed on success and
    rolled back if the statement raises.
    """
    now = datetime.now().isoformat()
    if status == 'completed':
        sql = '''UPDATE tasks
                 SET status = ?, output_file = ?, processed_at = ?
                 WHERE id = ?'''
        params = (status, output_file, now, task_id)
    elif status == 'failed':
        sql = '''UPDATE tasks
                 SET status = ?, error = ?, processed_at = ?
                 WHERE id = ?'''
        params = (status, None if error is None else str(error), now, task_id)
    else:
        sql = 'UPDATE tasks SET status = ? WHERE id = ?'
        params = (status, task_id)

    with closing(sqlite3.connect('tts_tasks.db')) as conn:
        with conn:  # commit on success, rollback on exception
            conn.execute(sql, params)
@app.route('/')
def index():
    """Serve ``index.html`` from the '.' directory (via send_from_directory)."""
    page = 'index.html'
    return send_from_directory('.', page)
@app.route('/api/generate', methods=['POST'])
def generate_audio():
    """Queue a text-to-speech task and make sure the worker is running.

    Expects a JSON object body:
        text  (str, required, non-blank)
        voice (str or number, optional, default '8'; null lets the worker
               fall back to its default)
        speed (number or numeric string, optional, default 1.0; must be finite)

    Returns:
        201 with ``{'id', 'status', 'message'}`` on success.
        400 with ``{'error': ...}`` when the body is missing, is not a JSON
        object, or a field has an invalid value or type.
    """
    # silent=True: a missing/invalid JSON body or non-JSON Content-Type gives
    # None (-> JSON 400 below) instead of Flask's HTML 400/415 error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'text' not in data:
        return jsonify({'error': 'No text provided'}), 400

    text = data['text']
    voice = data.get('voice', '8')
    speed = data.get('speed', 1.0)

    if not isinstance(text, str):
        return jsonify({'error': 'Text must be a string'}), 400
    if not text.strip():
        return jsonify({'error': 'Text cannot be empty'}), 400

    # sqlite3 cannot bind dicts/lists, which previously surfaced as a 500.
    if voice is not None and not isinstance(voice, (str, int, float)):
        return jsonify({'error': 'Voice must be a string or number'}), 400

    # Validate speed now rather than letting the TTS run fail later on it.
    if speed is not None:
        try:
            speed = float(speed)
        except (TypeError, ValueError):
            return jsonify({'error': 'Speed must be a finite number'}), 400
        if not math.isfinite(speed):
            return jsonify({'error': 'Speed must be a finite number'}), 400

    task_id = str(uuid.uuid4())
    with closing(sqlite3.connect('tts_tasks.db')) as conn:
        with conn:  # commit on success, rollback on exception
            conn.execute(
                '''INSERT INTO tasks
                   (id, text, voice, speed, status, created_at)
                   VALUES (?, ?, ?, ?, ?, ?)''',
                (task_id, text, voice, speed, 'not_started', datetime.now().isoformat()))

    # Start worker on first request
    start_worker()

    return jsonify({
        'id': task_id,
        'status': 'not_started',
        'message': 'Task queued successfully'
    }), 201
@app.route('/api/files', methods=['GET'])
def get_files():
    """Return every task as a JSON list, newest ``created_at`` first.

    Each entry carries id, text, status, output_file, created_at,
    processed_at and error. The voice and speed columns are not exposed.
    """
    db = sqlite3.connect('tts_tasks.db')
    db.row_factory = sqlite3.Row
    records = db.execute('SELECT * FROM tasks ORDER BY created_at DESC').fetchall()
    db.close()

    exposed_fields = ('id', 'text', 'status', 'output_file',
                      'created_at', 'processed_at', 'error')
    return jsonify([{field: record[field] for field in exposed_fields}
                    for record in records])
@app.route('/api/download/<task_id>', methods=['GET'])
def download_file(task_id):
    """Send the generated WAV for ``task_id`` as an attachment.

    The file is served as ``tts_<task_id>.wav``. Responds 404 with a JSON
    ``error`` when the task is unknown or has no output file recorded
    ('File not found'), or when the recorded file is absent from
    UPLOAD_FOLDER ('File missing on server').
    """
    db = sqlite3.connect('tts_tasks.db')
    db.row_factory = sqlite3.Row
    record = db.execute('SELECT * FROM tasks WHERE id = ?', (task_id,)).fetchone()
    db.close()

    stored_name = None if record is None else record['output_file']
    if not stored_name:
        return jsonify({'error': 'File not found'}), 404

    audio_path = os.path.join(UPLOAD_FOLDER, stored_name)
    if os.path.exists(audio_path):
        return send_file(audio_path, as_attachment=True, download_name=f"tts_{task_id}.wav")
    return jsonify({'error': 'File missing on server'}), 404
@app.route('/health', methods=['GET'])
def health():
    """Report service liveness and whether start_worker() has started the worker."""
    payload = dict(
        status='healthy',
        service='tts-generator',
        worker_running=worker_running,
    )
    return jsonify(payload)
if __name__ == '__main__':
    # Ensure the tasks table exists before serving requests.
    init_db()

    rule = "=" * 60
    print("\n" + rule)
    print("πŸš€ TTS Generator API Server")
    print(rule)
    print("πŸ“Œ Worker will start automatically on first request")
    print(rule + "\n")

    # Use PORT environment variable for Hugging Face compatibility
    port = int(os.environ.get('PORT', 7860))
    app.run(debug=False, host='0.0.0.0', port=port)