# Tts-api-new / app.py
# Page-header residue from the hosting site (uploader: yukee1992; commit 0e5714a, "Update app.py", verified).
import gradio as gr
import asyncio
import edge_tts
import tempfile
import os
import json
from pathlib import Path
from huggingface_hub import HfApi, upload_file
import uuid
from datetime import datetime
import shutil
import re
import requests
import threading
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import subprocess
import mutagen.mp3
# Configuration
# All three settings are read from environment variables. DATASET_REPO and
# TRACKER_URL fall back to the defaults below; HF_TOKEN has no default and is
# None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
DATASET_REPO = os.environ.get("DATASET_REPO", "yukee1992/video-project-images")
TRACKER_URL = os.environ.get("TRACKER_URL", "https://yukee1992-status-tracker.hf.space")
# Startup banner: printed once at import time. Only whether the token is set
# is shown, never the token value itself.
print("=" * 60)
print("🚀 STARTING TTS SERVICE WITH API AND SRT CAPTIONS")
print("=" * 60)
print(f"📦 HF Dataset: {DATASET_REPO}")
print(f"🔑 HF Token: {'✅ Set' if HF_TOKEN else '❌ Missing'}")
print(f"📡 Tracker URL: {TRACKER_URL}")
# Initialize Hugging Face API
# The client is created even when HF_TOKEN is missing (token=None).
hf_api = HfApi(token=HF_TOKEN)
# =============================================
# Helper function to get audio duration
# =============================================
def get_audio_duration(audio_path):
    """Return the duration of an MP3 file in seconds, or None if it cannot be read.

    mutagen is tried first. If that fails for any reason, ``ffprobe`` is run
    as a fallback and its printed duration is parsed.

    Args:
        audio_path: Path of the audio file to inspect.

    Returns:
        The duration in seconds as a float, or None when both methods fail.
    """
    try:
        return mutagen.mp3.MP3(audio_path).info.length
    except Exception as exc:
        # Best effort: any mutagen failure simply moves on to ffprobe.
        # `Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
        print(f"⚠️ mutagen could not read duration: {exc}")

    # Fallback: use ffprobe if available
    try:
        result = subprocess.run(
            ['ffprobe', '-v', 'error', '-show_entries', 'format=duration',
             '-of', 'default=noprint_wrappers=1:nokey=1', audio_path],
            capture_output=True, text=True,
            timeout=15,  # never let a hung ffprobe block the request forever
        )
        # Empty or non-numeric output (e.g. unreadable file) raises ValueError
        return float(result.stdout.strip())
    except (OSError, subprocess.SubprocessError, ValueError, TypeError) as exc:
        # OSError: ffprobe missing; SubprocessError: timeout;
        # TypeError: audio_path is not a valid path argument.
        print(f"⚠️ ffprobe could not read duration: {exc}")
        return None
# =============================================
# SRT Generation Functions
# =============================================
def format_timestamp(seconds, max_seconds=3600):
    """Convert seconds to SRT timestamp format (HH:MM:SS,mmm).

    Args:
        seconds: Time offset in seconds. Values below 0 are clamped to 0 and
            values above ``max_seconds`` are clamped to ``max_seconds``.
        max_seconds: Upper clamp for the timestamp; defaults to one hour.

    Returns:
        The formatted timestamp string, e.g. ``"00:01:15,250"``.
    """
    # Ensure seconds is within reasonable range
    clamped = max(0, min(seconds, max_seconds))
    # Work in whole milliseconds: splitting the float into integer and
    # fractional parts truncates values such as 1.001 to ",000".
    total_ms = int(round(clamped * 1000))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, milliseconds = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"
def create_short_srt_from_words(words_data, total_duration, words_per_subtitle=2):
    """
    Create SRT with short, readable subtitles (a few words per cue).

    Word timestamps are rescaled when their overall span differs from the
    measured audio duration by more than half a second, so the last cue ends
    with the audio.

    Args:
        words_data: List of dicts with 'text', 'offset' and 'duration' keys.
            Offsets and durations are divided by 1e7 to obtain seconds.
        total_duration: Audio length in seconds. None or a non-positive value
            disables both scaling and clamping.
        words_per_subtitle: Number of words grouped into one cue (default 2).

    Returns:
        The SRT document as a string, or "" when there are no words.
    """
    if not words_data:
        return ""

    words_per_subtitle = max(1, int(words_per_subtitle))
    # Without a usable duration there is nothing to scale or clamp against
    has_duration = bool(total_duration) and total_duration > 0

    # Find the maximum timestamp from words data
    max_word_time = max(
        0,
        max((word['offset'] + word['duration']) / 1e7 for word in words_data),
    )

    # If the max word time is close to total_duration, use it as is;
    # otherwise scale timestamps to match the actual audio duration
    if max_word_time > 0 and has_duration and abs(max_word_time - total_duration) > 0.5:
        scale_factor = total_duration / max_word_time
        print(f"📊 Scaling timestamps: max_word_time={max_word_time:.2f}s, audio_duration={total_duration:.2f}s, scale={scale_factor:.3f}")
    else:
        scale_factor = 1.0

    srt_entries = []
    chunk_starts = range(0, len(words_data), words_per_subtitle)
    for counter, i in enumerate(chunk_starts, start=1):
        chunk = words_data[i:i + words_per_subtitle]

        # Start of the first word and end of the last word (both scaled)
        start_time = (chunk[0]['offset'] / 1e7) * scale_factor
        end_time = ((chunk[-1]['offset'] + chunk[-1]['duration']) / 1e7) * scale_factor

        # Ensure end_time doesn't exceed total_duration
        if has_duration:
            end_time = min(end_time, total_duration)
        # A cue must never start after it ends (possible once the end is clamped)
        start_time = min(start_time, end_time)

        # Combine text without spaces (Chinese)
        text = ''.join(word['text'] for word in chunk)

        srt_entries.extend([
            str(counter),
            f"{format_timestamp(start_time)} --> {format_timestamp(end_time)}",
            text,
            "",
        ])

    print(f"✅ Created {len(srt_entries) // 4} short subtitle entries")
    print(f" First subtitle: {srt_entries[1]} -> {srt_entries[2]}")
    if has_duration:
        print(f" Last subtitle ends at: {format_timestamp(total_duration)}")
    return "\n".join(srt_entries)
def create_fallback_srt(text, total_duration):
    """
    Fallback method: split text into smaller phrases based on punctuation and
    character count, and give every phrase an equal share of the duration.

    Args:
        text: The text that was synthesized.
        total_duration: Audio length in seconds, divided evenly between cues.

    Returns:
        The SRT document as a string, or "" when the text is empty or blank.
    """
    # Nothing to caption: avoid dividing the duration by zero phrases
    if not text or not text.strip():
        print("⚠️ Fallback SRT skipped: no text to caption")
        return ""

    # Split into smaller chunks (max 15 characters per subtitle for readability)
    max_chars = 15
    phrases = []

    # First split by punctuation. Both the ASCII and the full-width forms are
    # listed; Chinese text uses the full-width comma/exclamation/question mark.
    for phrase in re.split(r'[,,。!!??.]', text):
        phrase = phrase.strip()
        if not phrase:
            continue
        # Further split long phrases by character count
        phrases.extend(
            phrase[j:j + max_chars] for j in range(0, len(phrase), max_chars)
        )

    if not phrases:
        # Text consisted only of separators: chunk it as is
        phrases = [text[i:i + max_chars] for i in range(0, len(text), max_chars)]

    # Calculate duration per phrase
    duration_per_phrase = total_duration / len(phrases)

    srt_entries = []
    for i, phrase in enumerate(phrases):
        # Multiply instead of accumulating to avoid floating-point drift
        start_time = i * duration_per_phrase
        end_time = (i + 1) * duration_per_phrase
        srt_entries.extend([
            str(i + 1),
            f"{format_timestamp(start_time)} --> {format_timestamp(end_time)}",
            phrase,
            "",
        ])

    print(f"⚠️ Using fallback SRT: {len(phrases)} phrases over {total_duration:.1f}s")
    return "\n".join(srt_entries)
# =============================================
# Status Reporting Function
# =============================================
def report_to_tracker(project_id, service_type, status, file_urls=None, error=None):
    """Send a status update for a project to the tracker service.

    The POST to ``{TRACKER_URL}/update`` is made from a daemon thread, so the
    caller is never blocked; any failure is printed and otherwise ignored.
    Nothing is sent when ``project_id`` is empty.
    """
    if not project_id:
        return

    def _send_update():
        try:
            body = {
                "project_id": project_id,
                "service_type": service_type,
                "status": status,
                "file_urls": file_urls if file_urls else [],
            }
            # The error field is only included when there is one to report
            if error:
                body["error"] = error
            reply = requests.post(f"{TRACKER_URL}/update", json=body, timeout=5)
            print(f"📤 Reported to tracker: {reply.status_code}")
        except Exception as exc:
            print(f"⚠️ Failed to report to tracker: {exc}")

    threading.Thread(target=_send_update, daemon=True).start()
# =============================================
# Chinese voice options (both female and male)
# =============================================
# Maps the numeric voice ID accepted by the API to the edge-tts voice name.
VOICE_MAPPING = dict((
    # Female voices
    (0, "zh-CN-XiaoxiaoNeural"),
    (1, "zh-CN-XiaoyiNeural"),
    (2, "zh-CN-XiaomengNeural"),
    (3, "zh-CN-XiaoxuanNeural"),
    (4, "zh-CN-XiaohanNeural"),
    (5, "zh-CN-XiaomoNeural"),
    (6, "zh-CN-XiaoruiNeural"),
    # Male voices
    (7, "zh-CN-YunxiNeural"),
    (8, "zh-CN-YunjianNeural"),
    (9, "zh-CN-YunyangNeural"),
    (10, "zh-CN-YunxiaNeural"),
    (11, "zh-CN-YunhaoNeural"),
    (12, "zh-CN-YunfengNeural"),
    # Regional/dialect voices
    (13, "zh-CN-liaoning-XiaobeiNeural"),
    (14, "zh-CN-shaanxi-XiaoniNeural"),
    (15, "zh-HK-HiuGaaiNeural"),
    (16, "zh-HK-HiuMaanNeural"),
    (17, "zh-HK-WanLungNeural"),
    (18, "zh-TW-HsiaoChenNeural"),
    (19, "zh-TW-HsiaoYuNeural"),
    (20, "zh-TW-YunJheNeural"),
))
# Human-readable description for every voice ID; the first word of each
# description is the short voice name.
VOICE_DESCRIPTIONS = dict((
    (0, "Xiaoxiao (Female) - Warm, caring sister"),
    (1, "Xiaoyi (Female) - Lively, cute sweet voice"),
    (2, "Xiaomeng (Female) - Childish, energetic loli voice"),
    (3, "Xiaoxuan (Female) - Mature, professional"),
    (4, "Xiaohan (Female) - Gentle, warm"),
    (5, "Xiaomo (Female) - Youthful, friendly"),
    (6, "Xiaorui (Female) - Kind, gentle senior"),
    (7, "Yunxi (Male) - Clear, professional broadcast"),
    (8, "Yunjian (Male) - Cool, calm, sports commentary"),
    (9, "Yunyang (Male) - Authoritative news anchor"),
    (10, "Yunxia (Male) - Lively, sunshine, anime style"),
    (11, "Yunhao (Male) - Warm, friendly, optimistic"),
    (12, "Yunfeng (Male) - Deep, mature, serious"),
    (13, "Xiaobei (Female) - Cheerful Liaoning dialect"),
    (14, "Xiaoni (Female) - Bright Shaanxi dialect"),
    (15, "HiuGaai (Female Cantonese) - Hong Kong style"),
    (16, "HiuMaan (Female Cantonese) - Hong Kong style"),
    (17, "WanLung (Male Cantonese) - Hong Kong style"),
    (18, "HsiaoChen (Female Taiwanese) - Taiwan Mandarin"),
    (19, "HsiaoYu (Female Taiwanese) - Taiwan Mandarin"),
    (20, "YunJhe (Male Taiwanese) - Taiwan Mandarin"),
))
# Create FastAPI app
# This is the application object the API routes in this file are registered on.
fastapi_app = FastAPI(title="TTS API")
# Add CORS middleware
# Every origin, method and header is accepted.
# NOTE(review): allow_credentials=True combined with allow_origins=["*"] is a
# pairing the FastAPI CORS documentation advises against - confirm whether
# credentialed cross-origin requests are actually required here.
fastapi_app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
def sanitize_folder_name(title, default="untitled"):
    """Convert video title to safe folder name.

    Characters other than word characters, whitespace and hyphens are
    removed; runs of whitespace/hyphens become a single underscore.

    Args:
        title: The video title. None is treated as an empty title.
        default: Name returned when nothing usable is left, so callers never
            build a path with an empty folder segment.

    Returns:
        The sanitized folder name, or ``default`` when it would be empty.
    """
    raw = "" if title is None else str(title)
    safe_name = re.sub(r'[^\w\s-]', '', raw)
    safe_name = re.sub(r'[-\s]+', '_', safe_name).strip('_')
    return safe_name or default
def get_emotion_params(emotion_id):
    """Look up the rate/pitch/volume adjustments for an emotion ID.

    IDs 0-4 have their own preset; any other ID gets the preset of ID 0.
    """
    fields = ("rate", "pitch", "volume")
    # One row per emotion ID, in ID order
    presets = (
        ("+0%", "+0Hz", "+0%"),
        ("+15%", "+30Hz", "+10%"),
        ("-10%", "-20Hz", "-10%"),
        ("+25%", "+50Hz", "+15%"),
        ("+5%", "+15Hz", "+5%"),
    )
    table = {idx: dict(zip(fields, row)) for idx, row in enumerate(presets)}
    chosen = emotion_id if emotion_id in table else 0
    return table[chosen]
def _push_to_dataset(local_path, repo_path):
    """Upload one local file to ``repo_path`` inside the configured dataset repo."""
    upload_file(
        path_or_fileobj=local_path,
        path_in_repo=repo_path,
        repo_id=DATASET_REPO,
        repo_type="dataset",
        token=HF_TOKEN
    )


def upload_to_dataset(audio_path, srt_path, metadata, video_title, project_id=None):
    """Upload audio and SRT files to Hugging Face dataset.

    The audio, the subtitles and a JSON metadata record are uploaded under
    ``data/projects/<folder>/``, where the folder is ``project_id`` when given
    and the sanitized video title otherwise.

    Args:
        audio_path: Local path of the generated MP3 file.
        srt_path: Local path of the generated SRT file.
        metadata: Dict with "text", "voice_id", "emotion_id", "speed" and
            "parameters"; "duration_seconds" and "word_count" are optional.
        video_title: Title stored in the metadata and used for the folder
            name when no project ID is given.
        project_id: Optional project identifier used as the folder name.

    Returns:
        ``{"success": True, ...}`` with the file URLs, project ID and metadata
        record, or ``{"success": False, "error": <message>}`` on any failure.
    """
    try:
        folder_name = project_id if project_id else sanitize_folder_name(video_title)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        file_id = str(uuid.uuid4())[:8]
        voice_name = VOICE_DESCRIPTIONS[metadata["voice_id"]].split(" ")[0]

        # Unknown emotion IDs are synthesized with the neutral preset, so they
        # are named "neutral" here too instead of raising IndexError (or
        # silently picking a name from the end of the list for negative IDs).
        emotion_names = ["neutral", "happy", "sad", "excited", "frustrated"]
        try:
            emotion_index = int(metadata["emotion_id"])
        except (TypeError, ValueError):
            emotion_index = 0
        if not 0 <= emotion_index < len(emotion_names):
            emotion_index = 0
        emotion_name = emotion_names[emotion_index]

        base_name = f"{timestamp}_{voice_name}_{emotion_name}_{file_id}"
        audio_filename = f"{base_name}.mp3"
        srt_filename = f"{base_name}.srt"
        audio_dataset_path = f"data/projects/{folder_name}/audio/{audio_filename}"
        srt_dataset_path = f"data/projects/{folder_name}/subtitles/{srt_filename}"

        # Upload files
        _push_to_dataset(audio_path, audio_dataset_path)
        _push_to_dataset(srt_path, srt_dataset_path)

        # NOTE(review): these are "/blob/" URLs - confirm that consumers do not
        # need "/resolve/" links for direct file download.
        audio_file_url = f"https://huggingface.co/datasets/{DATASET_REPO}/blob/main/{audio_dataset_path}"
        srt_file_url = f"https://huggingface.co/datasets/{DATASET_REPO}/blob/main/{srt_dataset_path}"

        metadata_entry = {
            "file_id": file_id,
            "type": "audio_with_subtitles",
            "audio_filename": audio_filename,
            "srt_filename": srt_filename,
            "audio_dataset_path": audio_dataset_path,
            "srt_dataset_path": srt_dataset_path,
            "audio_file_url": audio_file_url,
            "srt_file_url": srt_file_url,
            "video_title": video_title,
            "project_id": folder_name,
            "timestamp": timestamp,
            "text": metadata["text"],
            "voice_id": metadata["voice_id"],
            "voice_name": voice_name,
            "emotion_id": metadata["emotion_id"],
            "emotion_name": emotion_name,
            "speed": metadata["speed"],
            "duration_seconds": metadata.get("duration_seconds"),
            "word_count": metadata.get("word_count"),
            "parameters": metadata["parameters"]
        }

        audio_metadata_path = f"data/projects/{folder_name}/metadata/audio_{file_id}.json"
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False, encoding='utf-8') as f:
            json.dump(metadata_entry, f, indent=2)
            temp_meta_path = f.name
        try:
            _push_to_dataset(temp_meta_path, audio_metadata_path)
        finally:
            # Remove the temp file even when the metadata upload fails
            os.unlink(temp_meta_path)

        return {
            "success": True,
            "audio_file_url": audio_file_url,
            "srt_file_url": srt_file_url,
            "project_id": folder_name,
            "metadata": metadata_entry
        }
    except Exception as e:
        # Callers rely on the result dict, never on an exception
        return {
            "success": False,
            "error": str(e)
        }
async def generate_speech(text, voice_id, emotion_id, speed, video_title, project_id=None, subtitle_style="short"):
    """Generate speech with accurate subtitles that match audio duration.

    The text is synthesized with edge-tts while word boundaries are collected,
    an SRT file is built from them (or from the text alone when none arrive),
    and both files are uploaded to the dataset. When ``project_id`` is given
    the outcome is also reported to the tracker.

    Args:
        text: Text to synthesize.
        voice_id: Key of VOICE_MAPPING; unknown IDs fall back to voice 7.
        emotion_id: Emotion preset ID passed to get_emotion_params.
        speed: Speed multiplier; every 0.1 above/below 1.0 shifts the rate by 5%.
        video_title: Title stored with the upload.
        project_id: Optional project identifier (upload folder and tracker key).
        subtitle_style: Accepted for interface compatibility; not read here.

    Returns:
        A dict with ``"success": True`` plus URLs, duration and counts, or
        ``{"success": False, "error": <message>}``. Exceptions are not raised.
    """
    default_voice_id = 7
    temp_dir = None
    try:
        # Keep the ID in step with the voice actually used, so the metadata
        # lookups below cannot fail for an ID the synthesis step accepted.
        if voice_id not in VOICE_MAPPING:
            print(f"⚠️ Unknown voice_id {voice_id!r} - using default voice {default_voice_id}")
            voice_id = default_voice_id
        voice = VOICE_MAPPING[voice_id]
        emotion_params = get_emotion_params(emotion_id)

        rate_percentage = int(emotion_params["rate"].replace("%", "").replace("+", ""))
        # round(), not int(): (0.9 - 1.0) * 50 is -4.999... and would truncate to -4
        adjusted_rate = rate_percentage + round((speed - 1.0) * 50)
        rate = f"{adjusted_rate:+d}%"

        temp_dir = tempfile.mkdtemp()
        local_audio_path = os.path.join(temp_dir, "temp_audio.mp3")
        local_srt_path = os.path.join(temp_dir, "temp_subtitles.srt")

        # Create communicate instance
        communicate = edge_tts.Communicate(
            text,
            voice,
            rate=rate,
            pitch=emotion_params["pitch"],
            volume=emotion_params["volume"]
        )

        # Collect word boundaries and audio bytes from the same stream
        words_data = []
        audio_data = bytearray()
        print(f"🎤 Generating TTS and collecting word boundaries...")
        async for chunk in communicate.stream():
            if chunk["type"] == "audio":
                audio_data.extend(chunk["data"])
            elif chunk["type"] == "WordBoundary":
                word_info = {
                    'text': chunk['text'],
                    'offset': chunk['offset'],
                    'duration': chunk['duration']
                }
                words_data.append(word_info)
                # Log only the first few words to keep the output short
                if len(words_data) <= 5:
                    start_sec = word_info['offset'] / 1e7
                    print(f" Word {len(words_data)}: '{word_info['text']}' at {start_sec:.2f}s")

        if not audio_data:
            raise RuntimeError("No audio data received from TTS service")

        # Save audio file
        with open(local_audio_path, "wb") as f:
            f.write(audio_data)
        print(f"✅ Audio saved. Total words collected: {len(words_data)}")

        # Get actual audio duration using mutagen
        total_duration = get_audio_duration(local_audio_path)
        if total_duration is None:
            if words_data:
                # Fallback: estimate from last word
                last_word = words_data[-1]
                total_duration = (last_word['offset'] + last_word['duration']) / 1e7
                print(f"📊 Estimated duration from words: {total_duration:.2f}s")
            else:
                total_duration = max(len(text) / 3.5, 5)
                print(f"📊 Estimated duration from text length: {total_duration:.2f}s")
        else:
            print(f"🎵 Actual audio duration: {total_duration:.2f} seconds")

        # Create SRT with proper timing
        if words_data:
            srt_content = create_short_srt_from_words(words_data, total_duration)
        else:
            print("⚠️ No word boundaries collected - using fallback")
            srt_content = create_fallback_srt(text, total_duration)
        # Counted for both paths so the fallback no longer reports 0 subtitles
        caption_count = srt_content.count('-->')
        print(f"📝 Generated {caption_count} subtitle entries")
        print(f" Last subtitle timestamp: {srt_content.split('-->')[-1].strip() if '-->' in srt_content else 'N/A'}")

        # Save SRT file
        with open(local_srt_path, "w", encoding="utf-8") as f:
            f.write(srt_content)

        metadata = {
            "text": text,
            "voice_id": voice_id,
            "voice_description": VOICE_DESCRIPTIONS[voice_id],
            "emotion_id": emotion_id,
            "speed": speed,
            "duration_seconds": total_duration,
            "word_count": len(words_data),
            "parameters": {
                "rate": rate,
                "pitch": emotion_params["pitch"],
                "volume": emotion_params["volume"]
            }
        }

        # Upload both files. The upload is blocking network I/O, so it runs in
        # a worker thread to keep the event loop free for other requests.
        loop = asyncio.get_running_loop()
        upload_result = await loop.run_in_executor(
            None,
            lambda: upload_to_dataset(local_audio_path, local_srt_path, metadata, video_title, project_id),
        )

        if not upload_result["success"]:
            if project_id:
                report_to_tracker(
                    project_id=project_id,
                    service_type="tts",
                    status="failed",
                    error=upload_result["error"]
                )
            return {
                "success": False,
                "error": upload_result["error"]
            }

        result = {
            "success": True,
            "message": f"Audio ({total_duration:.1f}s) and SRT captions generated",
            "video_title": video_title,
            "project_id": upload_result["project_id"],
            "audio_url": upload_result["audio_file_url"],
            "srt_url": upload_result["srt_file_url"],
            "audio_duration": total_duration,
            "word_count": len(words_data),
            "subtitle_count": caption_count,
            "metadata": upload_result["metadata"]
        }
        if project_id:
            report_to_tracker(
                project_id=project_id,
                service_type="tts",
                status="completed",
                file_urls=[upload_result["audio_file_url"], upload_result["srt_file_url"]]
            )
        return result
    except Exception as e:
        print(f"❌ Error in generate_speech: {str(e)}")
        if project_id:
            report_to_tracker(
                project_id=project_id,
                service_type="tts",
                status="failed",
                error=str(e)
            )
        return {
            "success": False,
            "error": str(e)
        }
    finally:
        # Clean up the working directory on every exit path
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)
# =============================================
# FASTAPI ENDPOINTS
# =============================================
@fastapi_app.get("/")
async def root():
    """Return the service name, version and the list of API endpoints."""
    endpoint_index = dict(
        generate="POST /api/generate",
        health="GET /api/health",
    )
    return dict(
        name="TTS API with SRT Captions",
        version="2.1",
        endpoints=endpoint_index,
    )
@fastapi_app.get("/api/health")
async def health():
    """Return a fixed payload that marks the TTS service as healthy."""
    return dict(status="healthy", service="tts")
@fastapi_app.post("/api/generate")
async def generate_tts(request: dict):
    """API endpoint - returns permanent dataset URLs for audio and SRT.

    Request body keys: "text" (required), "voice_id" (default 7),
    "emotion_id" (default 0), "speed" (default 1.0), "video_title",
    "project_id" and "subtitle_style".

    Returns:
        The result dict of generate_speech, or an error payload. Error
        payloads carry ``"success": False`` like every other failure response,
        plus the ``"status": "error"`` marker for existing clients.
    """
    def _error(message):
        return {"success": False, "status": "error", "error": message}

    try:
        text = request.get("text", "")
        voice_id = int(request.get("voice_id", 7))
        emotion_id = int(request.get("emotion_id", 0))
        speed = float(request.get("speed", 1.0))
        video_title = request.get("video_title", "Untitled Video")
        project_id = request.get("project_id")
        subtitle_style = request.get("subtitle_style", "short")

        if voice_id not in VOICE_MAPPING:
            return _error(f"Invalid voice_id: {voice_id}")
        # Reject non-string and whitespace-only text before calling the TTS service
        if not isinstance(text, str) or not text.strip():
            return _error("No text provided")

        return await generate_speech(text, voice_id, emotion_id, speed, video_title, project_id, subtitle_style)
    except Exception as e:
        # Includes bad numeric fields, e.g. a non-numeric "speed"
        return _error(str(e))
# =============================================
# GRADIO INTERFACE
# =============================================
def _download_preview(repo_path):
    """Fetch an uploaded file back from the dataset for preview; None on failure."""
    if not repo_path:
        return None
    try:
        # Downloaded below the system temp directory so Gradio may serve it.
        # Preview files are not removed afterwards.
        return hf_api.hf_hub_download(
            repo_id=DATASET_REPO,
            filename=repo_path,
            repo_type="dataset",
            local_dir=os.path.join(tempfile.gettempdir(), "tts_previews"),
        )
    except Exception as exc:
        print(f"⚠️ Could not fetch preview for {repo_path}: {exc}")
        return None


def _gradio_generate(text, voice_id, emotion_id, speed, video_title, project_id, style):
    """Run generate_speech for the UI and return one value per output component.

    Returns:
        ``(audio_path, srt_path, result)``; the two paths are None when
        generation failed or the preview files could not be fetched.
    """
    result = asyncio.run(generate_speech(
        text,
        7 if voice_id is None else int(voice_id),
        int(emotion_id),
        float(speed),
        video_title,
        (project_id or "").strip() or None,  # blank textbox means "no project"
        style,
    ))
    if not result.get("success"):
        return None, None, result
    uploaded = result.get("metadata") or {}
    return (
        _download_preview(uploaded.get("audio_dataset_path")),
        _download_preview(uploaded.get("srt_dataset_path")),
        result,
    )


with gr.Blocks(title="TTS with SRT Captions") as demo:
    gr.Markdown("# 🎙️ TTS API with Accurate SRT Captions")
    gr.Markdown("Generates short, readable subtitles that exactly match the audio duration")
    with gr.Row():
        with gr.Column(scale=1):
            video_title_input = gr.Textbox(
                label="🎬 Video Title",
                placeholder="Enter video title...",
                value="My Video"
            )
            project_id_input = gr.Textbox(
                label="📁 Project ID (optional)",
                placeholder="Enter project ID if known..."
            )
            text_input = gr.Textbox(
                label="📝 Text to synthesize",
                placeholder="输入中文...",
                lines=4,
                value="乌鲁登嘉楼发生致命车祸,一辆休旅车疑未察觉前方路段因土崩已封闭,失控坠入约61公尺深山沟,导致一名男教师与其未婚妻被抛出车外,当场身亡。"
            )
            voice_dropdown = gr.Dropdown(
                label="🎤 Voice Selection",
                choices=[
                    ("👩 Xiaoxiao - Warm Sister (Female)", 0),
                    ("👨 Yunxi - Professional (Male)", 7),
                    ("👨 Yunjian - Cool Commentary (Male)", 8),
                ],
                value=7,
                # "value" passes the voice ID of the choice (0/7/8);
                # "index" would pass its list position (0/1/2) instead.
                type="value"
            )
            subtitle_style = gr.Radio(
                label="📝 Subtitle Style",
                choices=[("Short (2 words per line)", "short")],
                value="short",
                type="value"
            )
            emotion_slider = gr.Slider(
                minimum=0, maximum=4, step=1, value=0,
                label="😊 Emotion",
                info="0:Neutral 1:Happy 2:Sad 3:Excited 4:Frustrated"
            )
            speed_slider = gr.Slider(
                minimum=0.5, maximum=2.0, step=0.1, value=1.0,
                label="⚡ Speed"
            )
            generate_btn = gr.Button("🎵 Generate Audio + SRT", variant="primary")
        with gr.Column(scale=1):
            audio_output = gr.Audio(label="Generated Audio", type="filepath")
            srt_output = gr.File(label="SRT Captions", file_types=[".srt"])
            json_output = gr.JSON(label="Response Data")
    # The handler returns exactly one value for each of the three outputs
    generate_btn.click(
        fn=_gradio_generate,
        inputs=[text_input, voice_dropdown, emotion_slider, speed_slider, video_title_input, project_id_input, subtitle_style],
        outputs=[audio_output, srt_output, json_output]
    )
# =============================================
# MAIN
# =============================================
# Attach the Gradio UI to the FastAPI application; `app` is the combined ASGI app.
# NOTE(review): a GET "/" route is also registered on fastapi_app earlier in
# this file - confirm the UI mounted at "/" is still reachable.
app = gr.mount_gradio_app(app=fastapi_app, blocks=demo, path="/")

if __name__ == "__main__":
    # Serve on all interfaces, port 7860, when run as a script
    uvicorn.run(app, port=7860, host="0.0.0.0")