import asyncio
import json
import os
import re
import shutil
import subprocess
import tempfile
import threading
import uuid
from datetime import datetime
from pathlib import Path

import edge_tts
import gradio as gr
import mutagen.mp3
import requests
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import HfApi, upload_file

# Configuration
HF_TOKEN = os.environ.get("HF_TOKEN")
DATASET_REPO = os.environ.get("DATASET_REPO", "yukee1992/video-project-images")
TRACKER_URL = os.environ.get("TRACKER_URL", "https://yukee1992-status-tracker.hf.space")

print("=" * 60)
print("🚀 STARTING TTS SERVICE WITH API AND SRT CAPTIONS")
print("=" * 60)
print(f"📦 HF Dataset: {DATASET_REPO}")
print(f"🔑 HF Token: {'✅ Set' if HF_TOKEN else '❌ Missing'}")
print(f"📡 Tracker URL: {TRACKER_URL}")

# Initialize Hugging Face API
hf_api = HfApi(token=HF_TOKEN)

# edge-tts word-boundary offsets/durations are divided by this to get seconds
# (i.e. they are expressed in 100-nanosecond ticks).
_TICKS_PER_SECOND = 1e7


# =============================================
# Helper function to get audio duration
# =============================================
def get_audio_duration(audio_path):
    """Return the duration of an MP3 file in seconds, or None if it cannot be measured.

    mutagen is tried first; if it cannot parse the file, the ``ffprobe`` CLI is
    used as a fallback (when installed). Any failure of both methods yields None
    so callers can fall back to an estimate.
    """
    try:
        return mutagen.mp3.MP3(audio_path).info.length
    except Exception as exc:  # mutagen raises several unrelated error types
        print(f"⚠️ mutagen could not read duration ({exc}); trying ffprobe")

    try:
        result = subprocess.run(
            ['ffprobe', '-v', 'error', '-show_entries', 'format=duration',
             '-of', 'default=noprint_wrappers=1:nokey=1', audio_path],
            capture_output=True,
            text=True,
            timeout=30,  # never let a hung ffprobe stall a request
        )
        return float(result.stdout.strip())
    except (OSError, ValueError, subprocess.SubprocessError):
        # OSError: ffprobe missing; ValueError: empty/garbled output;
        # SubprocessError: timeout.
        return None


# =============================================
# SRT Generation Functions
# =============================================
def format_timestamp(seconds):
    """Convert seconds to an SRT timestamp ``HH:MM:SS,mmm``.

    Negative values are clamped to zero. The value is rounded to the nearest
    millisecond in integer arithmetic, so e.g. 2.3 s yields ``00:00:02,300``
    (plain float truncation would give ``,299``). Durations of an hour or more
    are represented correctly instead of being capped.
    """
    total_ms = int(round(max(0.0, float(seconds)) * 1000))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, milliseconds = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"


def create_short_srt_from_words(words_data, total_duration, words_per_subtitle=2, separator=""):
    """Build an SRT document with short, readable cues from edge-tts word boundaries.

    Args:
        words_data: list of dicts with ``'text'``, ``'offset'`` and ``'duration'``
            keys; offset/duration are edge-tts ticks (divided by 1e7 for seconds).
        total_duration: measured audio length in seconds, or None if unknown.
            When the last word ends more than 0.5 s away from it, all timestamps
            are scaled linearly so captions span the real audio; cue end times
            are also clamped to it.
        words_per_subtitle: words grouped into one cue (default 2; values below
            1 are treated as 1).
        separator: text placed between the words of a cue. The default ``""``
            suits Chinese, which is written without spaces.

    Returns:
        The SRT text (numbered cues separated by blank lines), or ``""`` when
        ``words_data`` is empty.
    """
    if not words_data:
        return ""

    words_per_subtitle = max(1, int(words_per_subtitle))
    known_duration = total_duration is not None and total_duration > 0

    # Latest end time reported by the TTS engine for any word.
    max_word_time = max(
        0.0,
        max((word['offset'] + word['duration']) / _TICKS_PER_SECOND for word in words_data),
    )

    if max_word_time > 0 and known_duration and abs(max_word_time - total_duration) > 0.5:
        # Scale factor to match actual audio duration
        scale_factor = total_duration / max_word_time
        print(f"📊 Scaling timestamps: max_word_time={max_word_time:.2f}s, audio_duration={total_duration:.2f}s, scale={scale_factor:.3f}")
    else:
        scale_factor = 1.0

    srt_entries = []
    last_end = 0.0
    for counter, first in enumerate(range(0, len(words_data), words_per_subtitle), start=1):
        chunk = words_data[first:first + words_per_subtitle]
        start_time = (chunk[0]['offset'] / _TICKS_PER_SECOND) * scale_factor
        end_time = ((chunk[-1]['offset'] + chunk[-1]['duration']) / _TICKS_PER_SECOND) * scale_factor
        if known_duration:
            end_time = min(end_time, total_duration)
        # A cue must never end before it starts (possible after clamping).
        end_time = max(end_time, start_time)
        last_end = end_time

        srt_entries.extend([
            str(counter),
            f"{format_timestamp(start_time)} --> {format_timestamp(end_time)}",
            separator.join(word['text'] for word in chunk),
            "",
        ])

    print(f"✅ Created {len(srt_entries) // 4} short subtitle entries")
    print(f" First subtitle: {srt_entries[1]} -> {srt_entries[2]}")
    print(f" Last subtitle ends at: {format_timestamp(last_end)}")
    return "\n".join(srt_entries)


# Phrase separators for the fallback splitter: ASCII and full-width Chinese
# punctuation (comma, ideographic full stop, !, ?, semicolons, enumeration comma)
# plus newlines.
_PHRASE_BREAK_RE = re.compile(r"[,,。!?.!?;;、\n]")


def create_fallback_srt(text, total_duration, max_chars=15):
    """Build an SRT by splitting ``text`` into short phrases timed evenly.

    Used when the TTS engine produced no word boundaries. The text is split on
    punctuation, long phrases are hard-wrapped at ``max_chars`` characters, and
    ``total_duration`` seconds are divided equally among the resulting phrases.

    Args:
        text: the synthesized text.
        total_duration: audio length in seconds (None/negative treated as 0).
        max_chars: maximum characters per cue (default 15, minimum 1).

    Returns:
        The SRT text, or ``""`` when ``text`` contains nothing to caption.
    """
    max_chars = max(1, int(max_chars))
    source = text or ""

    phrases = []
    for phrase in _PHRASE_BREAK_RE.split(source):
        phrase = phrase.strip()
        # Hard-wrap long phrases; short ones come through as a single slice.
        phrases.extend(phrase[j:j + max_chars] for j in range(0, len(phrase), max_chars))

    if not phrases:
        # Text made only of separators: caption the raw characters instead.
        phrases = [source[i:i + max_chars] for i in range(0, len(source), max_chars)]
    if not phrases:
        print("⚠️ Fallback SRT: no text to caption")
        return ""

    total_duration = max(float(total_duration or 0.0), 0.0)
    duration_per_phrase = total_duration / len(phrases)

    srt_entries = []
    for index, phrase in enumerate(phrases):
        start_time = index * duration_per_phrase
        end_time = start_time + duration_per_phrase
        srt_entries.extend([
            str(index + 1),
            f"{format_timestamp(start_time)} --> {format_timestamp(end_time)}",
            phrase,
            "",
        ])

    print(f"⚠️ Using fallback SRT: {len(phrases)} phrases over {total_duration:.1f}s")
    return "\n".join(srt_entries)


# =============================================
# Status Reporting Function
# =============================================
def report_to_tracker(project_id, service_type, status, file_urls=None, error=None):
    """Send a status update for ``project_id`` to the central tracker, best-effort.

    The POST to ``{TRACKER_URL}/update`` runs on a daemon thread, so callers
    never wait on the tracker; failures are only logged. Does nothing when
    ``project_id`` is falsy.
    """
    if not project_id:
        return

    # Build the payload now so later mutation of file_urls by the caller
    # cannot affect what is sent.
    payload = {
        "project_id": project_id,
        "service_type": service_type,
        "status": status,
        "file_urls": list(file_urls or []),
    }
    if error:
        payload["error"] = error

    def _report():
        try:
            response = requests.post(
                f"{TRACKER_URL}/update",
                json=payload,
                timeout=5,
            )
            print(f"📤 Reported to tracker: {response.status_code}")
        except Exception as e:  # background best-effort: log, never crash
            print(f"⚠️ Failed to report to tracker: {e}")

    threading.Thread(target=_report, daemon=True).start()
# =============================================
# Chinese voice options (both female and male)
# =============================================
VOICE_MAPPING = {
    # Female voices
    0: "zh-CN-XiaoxiaoNeural",
    1: "zh-CN-XiaoyiNeural",
    2: "zh-CN-XiaomengNeural",
    3: "zh-CN-XiaoxuanNeural",
    4: "zh-CN-XiaohanNeural",
    5: "zh-CN-XiaomoNeural",
    6: "zh-CN-XiaoruiNeural",
    # Male voices
    7: "zh-CN-YunxiNeural",
    8: "zh-CN-YunjianNeural",
    9: "zh-CN-YunyangNeural",
    10: "zh-CN-YunxiaNeural",
    11: "zh-CN-YunhaoNeural",
    12: "zh-CN-YunfengNeural",
    # Regional/dialect voices
    13: "zh-CN-liaoning-XiaobeiNeural",
    14: "zh-CN-shaanxi-XiaoniNeural",
    15: "zh-HK-HiuGaaiNeural",
    16: "zh-HK-HiuMaanNeural",
    17: "zh-HK-WanLungNeural",
    18: "zh-TW-HsiaoChenNeural",
    19: "zh-TW-HsiaoYuNeural",
    20: "zh-TW-YunJheNeural",
}

VOICE_DESCRIPTIONS = {
    0: "Xiaoxiao (Female) - Warm, caring sister",
    1: "Xiaoyi (Female) - Lively, cute sweet voice",
    2: "Xiaomeng (Female) - Childish, energetic loli voice",
    3: "Xiaoxuan (Female) - Mature, professional",
    4: "Xiaohan (Female) - Gentle, warm",
    5: "Xiaomo (Female) - Youthful, friendly",
    6: "Xiaorui (Female) - Kind, gentle senior",
    7: "Yunxi (Male) - Clear, professional broadcast",
    8: "Yunjian (Male) - Cool, calm, sports commentary",
    9: "Yunyang (Male) - Authoritative news anchor",
    10: "Yunxia (Male) - Lively, sunshine, anime style",
    11: "Yunhao (Male) - Warm, friendly, optimistic",
    12: "Yunfeng (Male) - Deep, mature, serious",
    13: "Xiaobei (Female) - Cheerful Liaoning dialect",
    14: "Xiaoni (Female) - Bright Shaanxi dialect",
    15: "HiuGaai (Female Cantonese) - Hong Kong style",
    16: "HiuMaan (Female Cantonese) - Hong Kong style",
    17: "WanLung (Male Cantonese) - Hong Kong style",
    18: "HsiaoChen (Female Taiwanese) - Taiwan Mandarin",
    19: "HsiaoYu (Female Taiwanese) - Taiwan Mandarin",
    20: "YunJhe (Male Taiwanese) - Taiwan Mandarin",
}

# Voice used for unknown voice IDs and as the API default (Yunxi).
DEFAULT_VOICE_ID = 7

# Emotion labels indexed by emotion_id; must stay aligned with get_emotion_params.
EMOTION_NAMES = ["neutral", "happy", "sad", "excited", "frustrated"]

# Create FastAPI app
fastapi_app = FastAPI(title="TTS API")

# Add CORS middleware
fastapi_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def sanitize_folder_name(title):
    """Convert a free-form video title into a safe single-segment folder name.

    Characters other than word characters (Unicode letters/digits/underscore),
    whitespace and hyphens are removed, and runs of whitespace/hyphens become a
    single underscore. Returns ``"untitled"`` when nothing usable remains, so
    the dataset path never contains an empty segment.
    """
    safe_name = re.sub(r'[^\w\s-]', '', title or "")
    safe_name = re.sub(r'[-\s]+', '_', safe_name)
    return safe_name.strip('_') or "untitled"


def _safe_path_segment(value):
    """Make a caller-supplied project ID safe to use as one repo path segment.

    Path separators become ``_`` and leading/trailing dots and whitespace are
    stripped, so values like ``"../x"`` cannot escape ``data/projects/``.
    Ordinary IDs (letters, digits, ``-``, ``_``) pass through unchanged.
    """
    segment = re.sub(r"[\\/]+", "_", str(value)).strip().strip(".")
    return segment or "untitled"


def get_emotion_params(emotion_id):
    """Map an emotion ID (0-4) to edge-tts prosody settings; unknown IDs use neutral (0)."""
    emotions = {
        0: {"rate": "+0%", "pitch": "+0Hz", "volume": "+0%"},
        1: {"rate": "+15%", "pitch": "+30Hz", "volume": "+10%"},
        2: {"rate": "-10%", "pitch": "-20Hz", "volume": "-10%"},
        3: {"rate": "+25%", "pitch": "+50Hz", "volume": "+15%"},
        4: {"rate": "+5%", "pitch": "+15Hz", "volume": "+5%"},
    }
    return emotions.get(emotion_id, emotions[0])


def _emotion_name(emotion_id):
    """Return the label for ``emotion_id``, or ``"neutral"`` for unknown IDs (like get_emotion_params)."""
    try:
        index = int(emotion_id)
    except (TypeError, ValueError):
        return EMOTION_NAMES[0]
    return EMOTION_NAMES[index] if 0 <= index < len(EMOTION_NAMES) else EMOTION_NAMES[0]


def upload_to_dataset(audio_path, srt_path, metadata, video_title, project_id=None):
    """Upload the audio, SRT and a JSON metadata record to the Hugging Face dataset.

    Files go under ``data/projects/<folder>/{audio,subtitles,metadata}/``, where
    ``<folder>`` is the (path-sanitized) ``project_id`` or, when absent, the
    sanitized ``video_title``.

    Args:
        audio_path, srt_path: local files to upload.
        metadata: dict with at least ``text``, ``voice_id``, ``emotion_id``,
            ``speed`` and ``parameters`` (optionally ``duration_seconds`` and
            ``word_count``).
        video_title: title recorded in the metadata and used for the folder
            name when no project_id is given.
        project_id: optional project folder name.

    Returns:
        ``{"success": True, "audio_file_url", "srt_file_url", "project_id",
        "metadata"}`` or ``{"success": False, "error": str}``. Never raises.
    """
    try:
        folder_name = _safe_path_segment(project_id) if project_id else sanitize_folder_name(video_title)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        file_id = str(uuid.uuid4())[:8]

        voice_description = VOICE_DESCRIPTIONS.get(
            metadata["voice_id"], VOICE_DESCRIPTIONS[DEFAULT_VOICE_ID]
        )
        voice_name = voice_description.split(" ")[0]
        emotion_name = _emotion_name(metadata["emotion_id"])

        base_name = f"{timestamp}_{voice_name}_{emotion_name}_{file_id}"
        audio_filename = f"{base_name}.mp3"
        srt_filename = f"{base_name}.srt"
        audio_dataset_path = f"data/projects/{folder_name}/audio/{audio_filename}"
        srt_dataset_path = f"data/projects/{folder_name}/subtitles/{srt_filename}"

        for local_path, repo_path in ((audio_path, audio_dataset_path), (srt_path, srt_dataset_path)):
            upload_file(
                path_or_fileobj=local_path,
                path_in_repo=repo_path,
                repo_id=DATASET_REPO,
                repo_type="dataset",
                token=HF_TOKEN,
            )

        # NOTE(review): /blob/ URLs open the Hub's HTML file viewer; raw
        # downloads use /resolve/. Kept as-is because downstream consumers may
        # rely on this format - confirm before changing.
        audio_file_url = f"https://huggingface.co/datasets/{DATASET_REPO}/blob/main/{audio_dataset_path}"
        srt_file_url = f"https://huggingface.co/datasets/{DATASET_REPO}/blob/main/{srt_dataset_path}"

        metadata_entry = {
            "file_id": file_id,
            "type": "audio_with_subtitles",
            "audio_filename": audio_filename,
            "srt_filename": srt_filename,
            "audio_dataset_path": audio_dataset_path,
            "srt_dataset_path": srt_dataset_path,
            "audio_file_url": audio_file_url,
            "srt_file_url": srt_file_url,
            "video_title": video_title,
            "project_id": folder_name,
            "timestamp": timestamp,
            "text": metadata["text"],
            "voice_id": metadata["voice_id"],
            "voice_name": voice_name,
            "emotion_id": metadata["emotion_id"],
            "emotion_name": emotion_name,
            "speed": metadata["speed"],
            "duration_seconds": metadata.get("duration_seconds"),
            "word_count": metadata.get("word_count"),
            "parameters": metadata["parameters"],
        }

        # Upload the JSON straight from memory: no temp file to leak if the
        # upload fails, and Chinese text stays human-readable.
        upload_file(
            path_or_fileobj=json.dumps(metadata_entry, indent=2, ensure_ascii=False).encode("utf-8"),
            path_in_repo=f"data/projects/{folder_name}/metadata/audio_{file_id}.json",
            repo_id=DATASET_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
        )

        return {
            "success": True,
            "audio_file_url": audio_file_url,
            "srt_file_url": srt_file_url,
            "project_id": folder_name,
            "metadata": metadata_entry,
        }
    except Exception as e:
        return {
            "success": False,
            "error": str(e),
        }


def _compute_rate(base_rate, speed):
    """Combine an emotion's base rate (e.g. ``"+15%"``) with a speed multiplier.

    Each 1.0 of speed above/below 1.0 adds/subtracts 50 percentage points.
    ``round`` avoids float truncation (speed 0.9 -> -5%, not -4%).
    """
    base = int(base_rate.replace("%", "").replace("+", ""))
    return f"{base + round((speed - 1.0) * 50):+d}%"


def _make_communicate(text, voice, rate, emotion_params):
    """Create an edge-tts Communicate that emits per-word boundary events.

    NOTE(review): newer edge-tts releases accept a ``boundary`` argument and
    may default to sentence boundaries; word boundaries are requested
    explicitly. Releases without that argument raise TypeError, in which case
    we retry without it (those emit WordBoundary events). Confirm against the
    installed edge-tts version.
    """
    kwargs = {
        "rate": rate,
        "pitch": emotion_params["pitch"],
        "volume": emotion_params["volume"],
    }
    try:
        return edge_tts.Communicate(text, voice, boundary="WordBoundary", **kwargs)
    except TypeError:
        return edge_tts.Communicate(text, voice, **kwargs)


async def _synthesize(text, voice, rate, emotion_params):
    """Stream TTS output; return ``(mp3_bytes, words)`` where words hold text/offset/duration."""
    communicate = _make_communicate(text, voice, rate, emotion_params)
    audio_data = bytearray()
    words_data = []
    async for chunk in communicate.stream():
        if chunk["type"] == "audio":
            audio_data.extend(chunk["data"])
        elif chunk["type"] == "WordBoundary":
            word_info = {
                'text': chunk['text'],
                'offset': chunk['offset'],
                'duration': chunk['duration'],
            }
            words_data.append(word_info)
            if len(words_data) <= 5:
                start_sec = word_info['offset'] / 1e7
                print(f" Word {len(words_data)}: '{word_info['text']}' at {start_sec:.2f}s")
    return bytes(audio_data), words_data


def _resolve_duration(audio_path, words_data, text):
    """Return the audio length in seconds, estimating it when the file cannot be measured."""
    total_duration = get_audio_duration(audio_path)
    if total_duration is not None:
        print(f"🎵 Actual audio duration: {total_duration:.2f} seconds")
        return total_duration

    if words_data:
        last_word = words_data[-1]
        total_duration = (last_word['offset'] + last_word['duration']) / 1e7
        print(f"📊 Estimated duration from words: {total_duration:.2f}s")
    else:
        # ~3.5 characters per second, never shorter than 5 seconds.
        total_duration = max(len(text) / 3.5, 5)
        print(f"📊 Estimated duration from text length: {total_duration:.2f}s")
    return total_duration


def _copy_outputs(output_dir, audio_path, srt_path):
    """Copy the generated files into ``output_dir``; return their paths keyed for the result dict."""
    os.makedirs(output_dir, exist_ok=True)
    return {
        "local_audio_path": shutil.copy(audio_path, os.path.join(output_dir, "audio.mp3")),
        "local_srt_path": shutil.copy(srt_path, os.path.join(output_dir, "subtitles.srt")),
    }


async def generate_speech(text, voice_id, emotion_id, speed, video_title, project_id=None,
                          subtitle_style="short", output_dir=None):
    """Synthesize ``text``, build matching SRT captions, and upload both to the dataset.

    Args:
        text: text to speak.
        voice_id: key of VOICE_MAPPING; unknown IDs use DEFAULT_VOICE_ID.
        emotion_id: 0-4 (see get_emotion_params); unknown IDs are neutral.
        speed: rate multiplier (1.0 = normal).
        video_title: used for metadata and, without project_id, the folder name.
        project_id: optional dataset folder; when set, the outcome is reported
            to the tracker.
        subtitle_style: accepted for API compatibility; only the "short" style
            exists and the value is currently unused.
        output_dir: optional directory. When given, copies of the audio and SRT
            are kept there and returned as ``local_audio_path`` /
            ``local_srt_path`` (the temp working dir is always deleted).

    Returns:
        On success a dict with ``success=True``, URLs, duration, counts and
        metadata; otherwise ``{"success": False, "error": str}``. Errors are
        returned, not raised.
    """
    temp_dir = None
    try:
        voice = VOICE_MAPPING.get(voice_id, VOICE_MAPPING[DEFAULT_VOICE_ID])
        emotion_params = get_emotion_params(emotion_id)
        rate = _compute_rate(emotion_params["rate"], speed)

        temp_dir = tempfile.mkdtemp()
        local_audio_path = os.path.join(temp_dir, "temp_audio.mp3")
        local_srt_path = os.path.join(temp_dir, "temp_subtitles.srt")

        print(f"🎤 Generating TTS and collecting word boundaries...")
        audio_data, words_data = await _synthesize(text, voice, rate, emotion_params)
        with open(local_audio_path, "wb") as f:
            f.write(audio_data)
        print(f"✅ Audio saved. Total words collected: {len(words_data)}")

        total_duration = _resolve_duration(local_audio_path, words_data, text)

        if words_data:
            srt_content = create_short_srt_from_words(words_data, total_duration)
            caption_count = srt_content.count('-->')
            print(f"📝 Generated {caption_count} subtitle entries")
            print(f" Last subtitle timestamp: {srt_content.split('-->')[-1].strip() if '-->' in srt_content else 'N/A'}")
        else:
            print("⚠️ No word boundaries collected - using fallback")
            srt_content = create_fallback_srt(text, total_duration)
            caption_count = srt_content.count('-->')

        with open(local_srt_path, "w", encoding="utf-8") as f:
            f.write(srt_content)

        local_copies = _copy_outputs(output_dir, local_audio_path, local_srt_path) if output_dir else {}

        metadata = {
            "text": text,
            "voice_id": voice_id,
            "voice_description": VOICE_DESCRIPTIONS.get(voice_id, VOICE_DESCRIPTIONS[DEFAULT_VOICE_ID]),
            "emotion_id": emotion_id,
            "speed": speed,
            "duration_seconds": total_duration,
            "word_count": len(words_data),
            "parameters": {
                "rate": rate,
                "pitch": emotion_params["pitch"],
                "volume": emotion_params["volume"],
            },
        }

        # upload_to_dataset performs blocking HTTP uploads; keep them off the
        # event loop so other requests are not stalled.
        loop = asyncio.get_running_loop()
        upload_result = await loop.run_in_executor(
            None, upload_to_dataset, local_audio_path, local_srt_path, metadata, video_title, project_id
        )

        if not upload_result["success"]:
            if project_id:
                report_to_tracker(
                    project_id=project_id,
                    service_type="tts",
                    status="failed",
                    error=upload_result["error"],
                )
            return {"success": False, "error": upload_result["error"], **local_copies}

        result = {
            "success": True,
            "message": f"Audio ({total_duration:.1f}s) and SRT captions generated",
            "video_title": video_title,
            "project_id": upload_result["project_id"],
            "audio_url": upload_result["audio_file_url"],
            "srt_url": upload_result["srt_file_url"],
            "audio_duration": total_duration,
            "word_count": len(words_data),
            "subtitle_count": caption_count,
            "metadata": upload_result["metadata"],
            **local_copies,
        }
        if project_id:
            report_to_tracker(
                project_id=project_id,
                service_type="tts",
                status="completed",
                file_urls=[upload_result["audio_file_url"], upload_result["srt_file_url"]],
            )
        return result

    except Exception as e:
        print(f"❌ Error in generate_speech: {str(e)}")
        if project_id:
            report_to_tracker(
                project_id=project_id,
                service_type="tts",
                status="failed",
                error=str(e),
            )
        return {
            "success": False,
            "error": str(e),
        }
    finally:
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)


# =============================================
# FASTAPI ENDPOINTS
# =============================================
@fastapi_app.get("/")
async def root():
    """Describe the service and its endpoints."""
    return {
        "name": "TTS API with SRT Captions",
        "version": "2.1",
        "endpoints": {
            "generate": "POST /api/generate",
            "health": "GET /api/health",
        },
    }


@fastapi_app.get("/api/health")
async def health():
    """Liveness probe."""
    return {"status": "healthy", "service": "tts"}


@fastapi_app.post("/api/generate")
async def generate_tts(request: dict):
    """API endpoint - returns permanent dataset URLs for audio and SRT.

    JSON body fields: ``text`` (required), ``voice_id`` (default 7),
    ``emotion_id`` (default 0), ``speed`` (default 1.0), ``video_title``,
    ``project_id``, ``subtitle_style``. Validation errors return
    ``{"status": "error", "success": False, "error": ...}``; otherwise the
    generate_speech result is returned unchanged.
    """
    try:
        text = request.get("text", "")
        voice_id = int(request.get("voice_id", DEFAULT_VOICE_ID))
        emotion_id = int(request.get("emotion_id", 0))
        speed = float(request.get("speed", 1.0))
        video_title = request.get("video_title", "Untitled Video")
        project_id = request.get("project_id")
        subtitle_style = request.get("subtitle_style", "short")

        if voice_id not in VOICE_MAPPING:
            return {"status": "error", "success": False, "error": f"Invalid voice_id: {voice_id}"}
        if not text:
            return {"status": "error", "success": False, "error": "No text provided"}

        return await generate_speech(text, voice_id, emotion_id, speed, video_title, project_id, subtitle_style)
    except Exception as e:
        return {"status": "error", "success": False, "error": str(e)}


# =============================================
# GRADIO INTERFACE
# =============================================
async def _generate_from_ui(text, voice_id, emotion_id, speed, video_title, project_id, style):
    """Gradio click handler: run generate_speech and map the result onto (audio, srt, json)."""
    if not text or not text.strip():
        return None, None, {"success": False, "error": "No text provided"}

    voice_id = DEFAULT_VOICE_ID if voice_id is None else int(voice_id)
    # generate_speech deletes its working dir, so ask it to keep copies here
    # for Gradio to serve. TODO: these per-request dirs are never cleaned up.
    output_dir = tempfile.mkdtemp(prefix="tts_ui_")
    result = await generate_speech(
        text,
        voice_id,
        int(emotion_id),
        float(speed),
        video_title,
        project_id or None,
        style,
        output_dir=output_dir,
    )
    return result.get("local_audio_path"), result.get("local_srt_path"), result


with gr.Blocks(title="TTS with SRT Captions") as demo:
    gr.Markdown("# 🎙️ TTS API with Accurate SRT Captions")
    gr.Markdown("Generates short, readable subtitles that exactly match the audio duration")

    with gr.Row():
        with gr.Column(scale=1):
            video_title_input = gr.Textbox(
                label="🎬 Video Title",
                placeholder="Enter video title...",
                value="My Video",
            )
            project_id_input = gr.Textbox(
                label="📁 Project ID (optional)",
                placeholder="Enter project ID if known...",
            )
            text_input = gr.Textbox(
                label="📝 Text to synthesize",
                placeholder="输入中文...",
                lines=4,
                value="乌鲁登嘉楼发生致命车祸,一辆休旅车疑未察觉前方路段因土崩已封闭,失控坠入约61公尺深山沟,导致一名男教师与其未婚妻被抛出车外,当场身亡。",
            )
            # type="value" returns the voice ID attached to each choice
            # (0/7/8); type="index" would return the list position (0/1/2).
            voice_dropdown = gr.Dropdown(
                label="🎤 Voice Selection",
                choices=[
                    ("👩 Xiaoxiao - Warm Sister (Female)", 0),
                    ("👨 Yunxi - Professional (Male)", 7),
                    ("👨 Yunjian - Cool Commentary (Male)", 8),
                ],
                value=7,
                type="value",
            )
            subtitle_style = gr.Radio(
                label="📝 Subtitle Style",
                choices=[("Short (2 words per line)", "short")],
                value="short",
                type="value",
            )
            emotion_slider = gr.Slider(
                minimum=0, maximum=4, step=1, value=0,
                label="😊 Emotion",
                info="0:Neutral 1:Happy 2:Sad 3:Excited 4:Frustrated",
            )
            speed_slider = gr.Slider(
                minimum=0.5, maximum=2.0, step=0.1, value=1.0,
                label="⚡ Speed",
            )
            generate_btn = gr.Button("🎵 Generate Audio + SRT", variant="primary")

        with gr.Column(scale=1):
            audio_output = gr.Audio(label="Generated Audio", type="filepath")
            srt_output = gr.File(label="SRT Captions", file_types=[".srt"])
            json_output = gr.JSON(label="Response Data")

    generate_btn.click(
        fn=_generate_from_ui,
        inputs=[text_input, voice_dropdown, emotion_slider, speed_slider,
                video_title_input, project_id_input, subtitle_style],
        outputs=[audio_output, srt_output, json_output],
    )

# =============================================
# MAIN
# =============================================
# NOTE(review): the FastAPI GET "/" route is registered before this mount, so
# it is likely matched first and may shadow the Gradio page at "/" - confirm
# this is intended.
app = gr.mount_gradio_app(fastapi_app, demo, path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)