import os
from huggingface_hub import HfApi, hf_hub_download
from apscheduler.schedulers.background import BackgroundScheduler
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
import threading  # Added for locking
from sqlalchemy import or_  # Added for vote counting query

# Captured once at import time.
year = datetime.now().year
month = datetime.now().month

# Check if running in a Hugging Face Space
IS_SPACES = False
PERSISTENT_DATA_DIR = None  # Will be set if persistent storage is available
DATABASE_REPO_ID = "channelcorp/ko-tts-arena-db"  # Single source for DB

if os.getenv("SPACE_REPO_NAME"):
    print("Running in a Hugging Face Space ๐Ÿค—")
    IS_SPACES = True

    # Check for persistent storage availability (/data directory)
    # HuggingFace Spaces provides /data as persistent storage
    if os.path.exists("/data") and os.access("/data", os.W_OK):
        PERSISTENT_DATA_DIR = "/data"
        print("Persistent storage available at /data โœ…")
    else:
        # Fallback to instance directory (non-persistent)
        PERSISTENT_DATA_DIR = "instance"
        print("โš ๏ธ Warning: Persistent storage (/data) not available. Using 'instance/' (data may be lost on restart)")

    # Define database path
    db_path = os.path.join(PERSISTENT_DATA_DIR, "tts_arena.db")

    # Setup database - download only if it doesn't exist in persistent storage
    if not os.path.exists(db_path):
        os.makedirs(PERSISTENT_DATA_DIR, exist_ok=True)
        try:
            print(f"Database not found at {db_path}, downloading from HF dataset ({DATABASE_REPO_ID})...")
            hf_hub_download(
                repo_id=DATABASE_REPO_ID,
                filename="tts_arena.db",
                repo_type="dataset",
                local_dir=PERSISTENT_DATA_DIR,
                token=os.getenv("HF_TOKEN"),
            )
            print(f"Database downloaded successfully to {db_path} โœ…")
        except Exception as e:
            # Best-effort: a missing seed DB is not fatal, a fresh one is created.
            print(f"Error downloading database from HF dataset: {str(e)} โš ๏ธ")
            print("A new database will be created.")
    else:
        print(f"Database found at {db_path} (persistent storage) โœ…")

from flask import (
    Flask,
    render_template,
    g,
    request,
    jsonify,
    send_file,
    redirect,
    url_for,
    session,
    abort,
)
from werkzeug.middleware.proxy_fix import ProxyFix
from flask_login import LoginManager, current_user
from models import *
from auth import auth, init_oauth, is_admin
from admin import admin
from security import is_vote_allowed, check_user_security_score, detect_coordinated_voting
import os
from dotenv import load_dotenv
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
import uuid
import tempfile
import shutil
from tts import predict_tts
import random
import json
from datetime import datetime, timedelta
from flask_migrate import Migrate
import requests
import functools
import time  # Added for potential retries
def is_korean_text(text: str, threshold: float = 0.3) -> bool:
    """
    Decide whether *text* is sufficiently Korean.

    Counts alphabetic characters and, among them, those in the Hangul
    Unicode blocks. Returns True when the Korean ratio is at least
    *threshold* (default 30%), and also for empty text or text that
    contains no letters at all (numbers, punctuation, etc.).
    """
    if not text:
        return True

    def _is_hangul(ch: str) -> bool:
        # Hangul Syllables, Hangul Jamo, Hangul Compatibility Jamo
        return (
            '\uAC00' <= ch <= '\uD7AF'
            or '\u1100' <= ch <= '\u11FF'
            or '\u3130' <= ch <= '\u318F'
        )

    letters = [ch for ch in text if ch.isalpha()]
    if not letters:
        # No letters at all: allow (might be numbers, punctuation, etc.)
        return True

    hangul_total = sum(1 for ch in letters if _is_hangul(ch))
    return hangul_total / len(letters) >= threshold


def get_client_ip():
    """Return the client's IP address, accounting for proxies/load balancers."""
    headers = request.headers

    # Prefer explicit forwarding headers set by reverse proxies.
    forwarded = headers.get('X-Forwarded-For')
    if forwarded:
        # X-Forwarded-For may hold a comma-separated chain; the first
        # entry is the originating client.
        return forwarded.split(',')[0].strip()

    real_ip = headers.get('X-Real-IP')
    if real_ip:
        return real_ip

    cf_ip = headers.get('CF-Connecting-IP')  # Cloudflare
    if cf_ip:
        return cf_ip

    if request.access_route:
        # Use access_route for HF Spaces (falls back to proxy chain)
        return request.access_route[0]

    return request.remote_addr
# Load environment variables
if not IS_SPACES:
    load_dotenv()  # Only load .env if not running in a Hugging Face Space

app = Flask(__name__)

# Apply ProxyFix to handle X-Forwarded-For headers from HF Spaces proxy
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1)

app.config["SECRET_KEY"] = os.getenv("SECRET_KEY", os.urandom(24))

# Configure database path - use persistent storage in HF Spaces
if IS_SPACES and PERSISTENT_DATA_DIR:
    # Use persistent storage path for HF Spaces
    app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{PERSISTENT_DATA_DIR}/tts_arena.db"
    print(f"Using database at: {PERSISTENT_DATA_DIR}/tts_arena.db")
else:
    app.config["SQLALCHEMY_DATABASE_URI"] = os.getenv(
        "DATABASE_URI", "sqlite:///tts_arena.db"
    )
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
app.config["SESSION_COOKIE_SECURE"] = True
app.config["SESSION_COOKIE_SAMESITE"] = (
    "None" if IS_SPACES else "Lax"
)  # HF Spaces uses iframes to load the app, so we need to set SAMESITE to None
app.config["PERMANENT_SESSION_LIFETIME"] = timedelta(days=30)  # Set to desired duration

# Force HTTPS when running in HuggingFace Spaces
if IS_SPACES:
    app.config["PREFERRED_URL_SCHEME"] = "https"

# Cloudflare Turnstile settings
app.config["TURNSTILE_ENABLED"] = (
    os.getenv("TURNSTILE_ENABLED", "False").lower() == "true"
)
app.config["TURNSTILE_SITE_KEY"] = os.getenv("TURNSTILE_SITE_KEY", "")
app.config["TURNSTILE_SECRET_KEY"] = os.getenv("TURNSTILE_SECRET_KEY", "")
app.config["TURNSTILE_VERIFY_URL"] = (
    "https://challenges.cloudflare.com/turnstile/v0/siteverify"
)

migrate = Migrate(app, db)

# Initialize extensions
db.init_app(app)
login_manager = LoginManager()
login_manager.init_app(app)
login_manager.login_view = "auth.login"

# Initialize OAuth
init_oauth(app)

# Configure rate limits
limiter = Limiter(
    app=app,
    key_func=get_remote_address,
    default_limits=["2000 per day", "50 per minute"],
    storage_uri="memory://",
)

# TTS Cache Configuration - Read from environment
TTS_CACHE_SIZE = int(os.getenv("TTS_CACHE_SIZE", "10"))
CACHE_AUDIO_SUBDIR = "cache"
# In-memory cache of pre-generated comparison pairs, guarded by tts_cache_lock.
tts_cache = {}  # sentence -> {model_a, model_b, audio_a, audio_b, created_at}
tts_cache_lock = threading.Lock()
SMOOTHING_FACTOR_MODEL_SELECTION = 100  # For weighted random model selection (lower = more exposure for new models)
# Increased max_workers to 8 for concurrent generation/refill
cache_executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix='CacheReplacer')
all_harvard_sentences = []  # Keep the full list available

# Create temp directories
TEMP_AUDIO_DIR = os.path.join(tempfile.gettempdir(), "tts_arena_audio")
CACHE_AUDIO_DIR = os.path.join(TEMP_AUDIO_DIR, CACHE_AUDIO_SUBDIR)
os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
os.makedirs(CACHE_AUDIO_DIR, exist_ok=True)  # Ensure cache subdir exists

# Store active TTS sessions
app.tts_sessions = {}
tts_sessions = app.tts_sessions

# Register blueprints
app.register_blueprint(auth, url_prefix="/auth")
app.register_blueprint(admin)
@login_manager.user_loader
def load_user(user_id):
    """Flask-Login hook: resolve a session's stored user id to a User row."""
    return User.query.get(int(user_id))


@app.before_request
def before_request():
    """Per-request setup: expose user/admin flags on `g`, force HTTPS on
    HF Spaces, and enforce Cloudflare Turnstile verification (with expiry)."""
    g.user = current_user
    g.is_admin = is_admin(current_user)

    # Ensure HTTPS for HuggingFace Spaces environment
    if IS_SPACES and request.headers.get("X-Forwarded-Proto") == "http":
        url = request.url.replace("http://", "https://", 1)
        return redirect(url, code=301)

    # Check if Turnstile verification is required
    if app.config["TURNSTILE_ENABLED"]:
        # Exclude verification routes (otherwise the redirect would loop)
        excluded_routes = ["verify_turnstile", "turnstile_page", "static"]
        if request.endpoint not in excluded_routes:
            # Check if user is verified
            if not session.get("turnstile_verified"):
                # Save original URL for redirect after verification
                redirect_url = request.url
                # Force HTTPS in HuggingFace Spaces
                if IS_SPACES and redirect_url.startswith("http://"):
                    redirect_url = redirect_url.replace("http://", "https://", 1)
                # If it's an API request, return a JSON response
                if request.path.startswith("/api/"):
                    return jsonify({"error": "Turnstile verification required"}), 403
                # For regular requests, redirect to verification page
                return redirect(url_for("turnstile_page", redirect_url=redirect_url))
            else:
                # Check if verification has expired (default: 24 hours)
                verification_timeout = (
                    int(os.getenv("TURNSTILE_TIMEOUT_HOURS", "24")) * 3600
                )  # Convert hours to seconds
                verified_at = session.get("turnstile_verified_at", 0)
                current_time = datetime.utcnow().timestamp()
                if current_time - verified_at > verification_timeout:
                    # Verification expired, clear status and redirect to verification page
                    session.pop("turnstile_verified", None)
                    session.pop("turnstile_verified_at", None)
                    redirect_url = request.url
                    # Force HTTPS in HuggingFace Spaces
                    if IS_SPACES and redirect_url.startswith("http://"):
                        redirect_url = redirect_url.replace("http://", "https://", 1)
                    if request.path.startswith("/api/"):
                        return jsonify({"error": "Turnstile verification expired"}), 403
                    return redirect(
                        url_for("turnstile_page", redirect_url=redirect_url)
                    )
@app.route("/turnstile", methods=["GET"])
def turnstile_page():
    """Display Cloudflare Turnstile verification page."""
    redirect_url = request.args.get("redirect_url", url_for("arena", _external=True))
    # Force HTTPS in HuggingFace Spaces
    if IS_SPACES and redirect_url.startswith("http://"):
        redirect_url = redirect_url.replace("http://", "https://", 1)
    return render_template(
        "turnstile.html",
        turnstile_site_key=app.config["TURNSTILE_SITE_KEY"],
        redirect_url=redirect_url,
    )


@app.route("/verify-turnstile", methods=["POST"])
def verify_turnstile():
    """Verify a Cloudflare Turnstile token and mark the session verified.

    Answers with JSON for AJAX callers and with redirects for plain form posts.
    """
    token = request.form.get("cf-turnstile-response")
    redirect_url = request.form.get("redirect_url", url_for("arena", _external=True))
    # Force HTTPS in HuggingFace Spaces
    if IS_SPACES and redirect_url.startswith("http://"):
        redirect_url = redirect_url.replace("http://", "https://", 1)

    if not token:
        # If AJAX request, return JSON error
        if request.headers.get("X-Requested-With") == "XMLHttpRequest":
            return (
                jsonify({"success": False, "error": "Missing verification token"}),
                400,
            )
        # Otherwise redirect back to turnstile page
        return redirect(url_for("turnstile_page", redirect_url=redirect_url))

    # Verify token with Cloudflare
    data = {
        "secret": app.config["TURNSTILE_SECRET_KEY"],
        "response": token,
        "remoteip": request.remote_addr,
    }
    try:
        response = requests.post(app.config["TURNSTILE_VERIFY_URL"], data=data)
        result = response.json()
        if result.get("success"):
            # Set verification status in session
            session["turnstile_verified"] = True
            session["turnstile_verified_at"] = datetime.utcnow().timestamp()
            # Determine response type based on request
            is_xhr = request.headers.get("X-Requested-With") == "XMLHttpRequest"
            accepts_json = "application/json" in request.headers.get("Accept", "")
            # If AJAX or JSON request, return success JSON
            if is_xhr or accepts_json:
                return jsonify({"success": True, "redirect": redirect_url})
            # For regular form submissions, redirect to the target URL
            return redirect(redirect_url)
        else:
            # Verification failed
            app.logger.warning(f"Turnstile verification failed: {result}")
            # If AJAX request, return JSON error
            if request.headers.get("X-Requested-With") == "XMLHttpRequest":
                return jsonify({"success": False, "error": "Verification failed"}), 403
            # Otherwise redirect back to turnstile page
            return redirect(url_for("turnstile_page", redirect_url=redirect_url))
    except Exception as e:
        app.logger.error(f"Turnstile verification error: {str(e)}")
        # If AJAX request, return JSON error
        if request.headers.get("X-Requested-With") == "XMLHttpRequest":
            return (
                jsonify(
                    {"success": False, "error": "Server error during verification"}
                ),
                500,
            )
        # Otherwise redirect back to turnstile page
        return redirect(url_for("turnstile_page", redirect_url=redirect_url))
# Load Korean prompts from local JSON file
print("Loading Korean TTS prompts from ko_prompts.json...")
_prompts_path = os.path.join(os.path.dirname(__file__), "ko_prompts.json")
with open(_prompts_path, "r", encoding="utf-8") as f:
    _prompts_data = json.load(f)
all_harvard_sentences = _prompts_data.get("prompts", [])
print(f"Loaded {len(all_harvard_sentences)} Korean prompts")

# Initialize initial_sentences as empty - will be populated with unconsumed sentences only
initial_sentences = []


@app.route("/")
def arena():
    """Main comparison page."""
    # Pass a subset of sentences for the random button fallback
    return render_template("arena.html", harvard_sentences=json.dumps(initial_sentences))


@app.route("/leaderboard")
def leaderboard():
    """Render the leaderboard page with per-category and personal data."""
    # Per-category leaderboard data
    tts_leaderboard_all = get_leaderboard_data(ModelType.TTS, 'all')
    tts_leaderboard_mixed_lang = get_leaderboard_data(ModelType.TTS, 'mixed_lang')
    tts_leaderboard_with_numbers = get_leaderboard_data(ModelType.TTS, 'with_numbers')
    tts_leaderboard_long_text = get_leaderboard_data(ModelType.TTS, 'long_text')
    voting_stats = get_voting_statistics()  # Get voting statistics

    # Initialize personal leaderboard data
    tts_personal_leaderboard = None
    user_leaderboard_visibility = None
    # If user is logged in, get their personal leaderboard and visibility setting
    if current_user.is_authenticated:
        tts_personal_leaderboard = get_user_leaderboard(current_user.id, ModelType.TTS)
        user_leaderboard_visibility = current_user.show_in_leaderboard

    # Get key dates for the timeline
    tts_key_dates = get_key_historical_dates(ModelType.TTS)
    # Format dates for display in the dropdown
    formatted_tts_dates = [date.strftime("%B %Y") for date in tts_key_dates]

    return render_template(
        "leaderboard.html",
        tts_leaderboard=tts_leaderboard_all,
        tts_leaderboard_all=tts_leaderboard_all,
        tts_leaderboard_mixed_lang=tts_leaderboard_mixed_lang,
        tts_leaderboard_with_numbers=tts_leaderboard_with_numbers,
        tts_leaderboard_long_text=tts_leaderboard_long_text,
        tts_personal_leaderboard=tts_personal_leaderboard,
        tts_key_dates=tts_key_dates,
        formatted_tts_dates=formatted_tts_dates,
        voting_stats=voting_stats,
        user_leaderboard_visibility=user_leaderboard_visibility
    )
@app.route("/api/historical-leaderboard/<model_type>")
def historical_leaderboard(model_type):
    """Get historical leaderboard data for a specific date.

    Args:
        model_type: model category segment from the URL (only TTS is supported).

    Query params:
        date: target date in YYYY-MM-DD format (required).

    Returns:
        JSON with the formatted date and leaderboard rows, or a 400 error
        for an invalid model type / missing or malformed date.
    """
    # FIX: the route previously read "/api/historical-leaderboard/" with no
    # <model_type> converter, so Flask could never bind the view's
    # `model_type` argument and the endpoint failed to register correctly.
    # NOTE(review): the comparison below assumes ModelType.TTS compares
    # equal to the raw URL segment (e.g. a str-valued constant) — confirm
    # against models.ModelType.
    if model_type != ModelType.TTS:
        return jsonify({"error": "Invalid model type"}), 400

    # Get date from query parameter
    date_str = request.args.get("date")
    if not date_str:
        return jsonify({"error": "Date parameter is required"}), 400

    try:
        # Parse date from URL parameter (format: YYYY-MM-DD)
        target_date = datetime.strptime(date_str, "%Y-%m-%d")
        # Get historical leaderboard data
        leaderboard_data = get_historical_leaderboard_data(model_type, target_date)
        return jsonify(
            {"date": target_date.strftime("%B %d, %Y"), "leaderboard": leaderboard_data}
        )
    except ValueError:
        return jsonify({"error": "Invalid date format. Use YYYY-MM-DD"}), 400
@app.route("/about")
def about():
    """Static about page."""
    return render_template("about.html")


# --- TTS Caching Functions ---

def generate_and_save_tts(text, model_id, output_dir):
    """Generates TTS and saves it to a specific directory, returning the full path.

    Returns the destination path on success, or None on any failure
    (the temp file produced by predict_tts is cleaned up on error).
    """
    temp_audio_path = None  # Initialize to None
    try:
        app.logger.debug(f"[TTS Gen {model_id}] Starting generation for: '{text[:30]}...'")
        # If predict_tts saves file itself and returns path:
        temp_audio_path = predict_tts(text, model_id)
        app.logger.debug(f"[TTS Gen {model_id}] predict_tts returned: {temp_audio_path}")

        if not temp_audio_path or not os.path.exists(temp_audio_path):
            app.logger.warning(f"[TTS Gen {model_id}] predict_tts failed or returned invalid path: {temp_audio_path}")
            raise ValueError("predict_tts did not return a valid path or file does not exist")

        # Give the file a unique name in the target directory.
        file_uuid = str(uuid.uuid4())
        dest_path = os.path.join(output_dir, f"{file_uuid}.wav")
        app.logger.debug(f"[TTS Gen {model_id}] Moving {temp_audio_path} to {dest_path}")
        # Move the file generated by predict_tts to the target cache directory
        shutil.move(temp_audio_path, dest_path)
        app.logger.debug(f"[TTS Gen {model_id}] Move successful. Returning {dest_path}")
        return dest_path
    except Exception as e:
        app.logger.error(f"Error generating/saving TTS for model {model_id} and text '{text[:30]}...': {str(e)}")
        # Ensure temporary file from predict_tts (if any) is cleaned up
        if temp_audio_path and os.path.exists(temp_audio_path):
            try:
                app.logger.debug(f"[TTS Gen {model_id}] Cleaning up temporary file {temp_audio_path} after error.")
                os.remove(temp_audio_path)
            except OSError:
                pass  # Ignore error if file couldn't be removed
        return None
def _generate_cache_entry_task(sentence):
    """Task function to generate audio for a sentence and add to cache.

    If `sentence` is falsy, a replacement sentence is chosen from the
    prompt pool (preferring ones not currently cached).
    """
    # Wrap the entire task in an application context
    with app.app_context():
        if not sentence:
            # Select a new sentence if not provided (for replacement)
            with tts_cache_lock:
                cached_keys = set(tts_cache.keys())
            # Get sentences that are not already cached
            available_sentences = [s for s in all_harvard_sentences if s not in cached_keys]
            if not available_sentences:
                # All sentences are cached, pick any random one
                available_sentences = all_harvard_sentences
            sentence = random.choice(available_sentences)

        # app.logger.info removed duplicate log
        print(f"[Cache Task] Querying models for: '{sentence[:50]}...'")
        available_models = Model.query.filter_by(
            model_type=ModelType.TTS, is_active=True
        ).all()
        if len(available_models) < 2:
            app.logger.error("Not enough active TTS models to generate cache entry.")
            return

        try:
            models = get_weighted_random_models(available_models, 2, ModelType.TTS)
            model_a_id = models[0].id
            model_b_id = models[1].id

            # Generate audio concurrently using a local executor for clarity within the task
            with ThreadPoolExecutor(max_workers=2, thread_name_prefix='AudioGen') as audio_executor:
                future_a = audio_executor.submit(generate_and_save_tts, sentence, model_a_id, CACHE_AUDIO_DIR)
                future_b = audio_executor.submit(generate_and_save_tts, sentence, model_b_id, CACHE_AUDIO_DIR)
                timeout_seconds = 120
                audio_a_path = future_a.result(timeout=timeout_seconds)
                audio_b_path = future_b.result(timeout=timeout_seconds)

            if audio_a_path and audio_b_path:
                with tts_cache_lock:
                    # Only add if the sentence isn't already back in the cache
                    # And ensure cache size doesn't exceed limit
                    if sentence not in tts_cache and len(tts_cache) < TTS_CACHE_SIZE:
                        tts_cache[sentence] = {
                            "model_a": model_a_id,
                            "model_b": model_b_id,
                            "audio_a": audio_a_path,
                            "audio_b": audio_b_path,
                            "created_at": datetime.utcnow(),
                        }
                        app.logger.info(f"Successfully cached entry for: '{sentence[:50]}...'")
                    elif sentence in tts_cache:
                        app.logger.warning(f"Sentence '{sentence[:50]}...' already re-cached. Discarding new generation.")
                        # Clean up the newly generated files if not added
                        if os.path.exists(audio_a_path):
                            os.remove(audio_a_path)
                        if os.path.exists(audio_b_path):
                            os.remove(audio_b_path)
                    else:
                        # Cache is full
                        app.logger.warning(f"Cache is full ({len(tts_cache)} entries). Discarding new generation for '{sentence[:50]}...'.")
                        # Clean up the newly generated files if not added
                        if os.path.exists(audio_a_path):
                            os.remove(audio_a_path)
                        if os.path.exists(audio_b_path):
                            os.remove(audio_b_path)
            else:
                app.logger.error(f"Failed to generate one or both audio files for cache: '{sentence[:50]}...'")
                # Clean up whichever file might have been created
                if audio_a_path and os.path.exists(audio_a_path):
                    os.remove(audio_a_path)
                if audio_b_path and os.path.exists(audio_b_path):
                    os.remove(audio_b_path)
        except Exception as e:
            # Log the exception within the app context
            app.logger.error(f"Exception in _generate_cache_entry_task for '{sentence[:50]}...': {str(e)}", exc_info=True)


def update_initial_sentences():
    """Update initial sentences for random selection."""
    global initial_sentences
    try:
        if all_harvard_sentences:
            # Cap the page payload at 500 prompts.
            initial_sentences = random.sample(all_harvard_sentences, min(len(all_harvard_sentences), 500))
            print(f"Updated initial sentences with {len(initial_sentences)} sentences")
        else:
            print("Warning: No sentences available for initial selection")
            initial_sentences = []
    except Exception as e:
        print(f"Error updating initial sentences: {e}")
        initial_sentences = []
def initialize_tts_cache():
    """Selects initial sentences and starts cache generation tasks."""
    print("Initializing TTS cache")
    with app.app_context():  # Ensure access to models
        if not all_harvard_sentences:
            app.logger.error("Harvard sentences not loaded. Cannot initialize cache.")
            return
        # Update initial sentences
        update_initial_sentences()
        # Select random sentences for initial cache population
        initial_selection = random.sample(all_harvard_sentences, min(len(all_harvard_sentences), TTS_CACHE_SIZE))
        app.logger.info(f"Initializing TTS cache with {len(initial_selection)} sentences...")
        for sentence in initial_selection:
            # Use the main cache_executor for initial population too
            cache_executor.submit(_generate_cache_entry_task, sentence)
        app.logger.info("Submitted initial cache generation tasks.")

# --- End TTS Caching Functions ---


@app.route("/api/tts/generate", methods=["POST"])
@limiter.limit("10 per minute")  # Keep limit, cached responses are still requests
def generate_tts():
    """Generate (or serve from cache) a pair of TTS renditions for a text.

    Returns a session id plus the two audio URLs; the actual model
    identities stay hidden until the vote is submitted.
    """
    # If verification not setup, handle it first
    if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
        return jsonify({"error": "Turnstile verification required"}), 403

    data = request.json
    text = data.get("text", "").strip()  # Ensure text is stripped
    if not text or len(text) > 1000:
        return jsonify({"error": "Invalid or too long text"}), 400

    # Check if text contains Korean (at least 30% Korean characters)
    if not is_korean_text(text):
        return jsonify({"error": "ํ•œ๊ตญ์–ด ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”. ์ตœ์†Œ 30% ์ด์ƒ์˜ ํ•œ๊ตญ์–ด๊ฐ€ ํฌํ•จ๋˜์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค."}), 400

    # --- Cache Check ---
    cache_hit = False
    session_data_from_cache = None
    with tts_cache_lock:
        if text in tts_cache:
            cache_hit = True
            cached_entry = tts_cache.pop(text)  # Remove from cache immediately
            app.logger.info(f"TTS Cache HIT for: '{text[:50]}...'")

            # Prepare session data using cached info
            session_id = str(uuid.uuid4())
            session_data_from_cache = {
                "model_a": cached_entry["model_a"],
                "model_b": cached_entry["model_b"],
                "audio_a": cached_entry["audio_a"],  # Paths are now from cache_dir
                "audio_b": cached_entry["audio_b"],
                "text": text,
                "created_at": datetime.utcnow(),
                "expires_at": datetime.utcnow() + timedelta(minutes=30),
                "voted": False,
                "cache_hit": True,
            }
            app.tts_sessions[session_id] = session_data_from_cache
            # Note: Sentence was already marked as consumed when it was cached
            # No need to mark it again here

            # --- Trigger background tasks to refill the cache ---
            # Calculate how many slots need refilling
            current_cache_size = len(tts_cache)  # Size *before* adding potentially new items
            needed_refills = TTS_CACHE_SIZE - current_cache_size
            # Limit concurrent refills to 8 or the actual need
            refills_to_submit = min(needed_refills, 8)
            if refills_to_submit > 0:
                app.logger.info(f"Cache hit: Submitting {refills_to_submit} background task(s) to refill cache (current size: {current_cache_size}, target: {TTS_CACHE_SIZE}).")
                for _ in range(refills_to_submit):
                    # Pass None to signal replacement selection within the task
                    cache_executor.submit(_generate_cache_entry_task, None)
            else:
                app.logger.info(f"Cache hit: Cache is already full or at target size ({current_cache_size}/{TTS_CACHE_SIZE}). No refill tasks submitted.")
            # --- End Refill Trigger ---

    if cache_hit and session_data_from_cache:
        # Return response using cached data
        # Note: The files are now managed by the session lifecycle (cleanup_session)
        return jsonify(
            {
                "session_id": session_id,
                "audio_a": f"/api/tts/audio/{session_id}/a",
                "audio_b": f"/api/tts/audio/{session_id}/b",
                "expires_in": 1800,  # 30 minutes in seconds
                "cache_hit": True,
            }
        )
    # --- End Cache Check ---

    # --- Cache Miss: Generate on the fly ---
    app.logger.info(f"TTS Cache MISS for: '{text[:50]}...'. Generating on the fly.")
    available_models = Model.query.filter_by(
        model_type=ModelType.TTS, is_active=True
    ).all()
    if len(available_models) < 2:
        return jsonify({"error": "Not enough TTS models available"}), 500

    selected_models = get_weighted_random_models(available_models, 2, ModelType.TTS)

    # Track failed models and successful model IDs for fallback logic
    failed_model_ids = set()
    assigned_model_ids = set()  # Track models currently assigned to any slot (including successful ones)
    max_retries_per_slot = 3  # Maximum retries for each model slot

    try:
        audio_files = []
        model_ids = []

        # Function to process a single model with fallback on error
        def process_model_with_fallback(initial_model, slot_index):
            current_model = initial_model
            # Check if initial model is already used by another slot, pick a different one
            if current_model.id in assigned_model_ids:
                remaining_models = [m for m in available_models if m.id not in failed_model_ids and m.id not in assigned_model_ids]
                if not remaining_models:
                    raise ValueError(f"No available models for slot {slot_index}")
                current_model = get_weighted_random_models(remaining_models, 1, ModelType.TTS)[0]
                app.logger.info(f"Slot {slot_index}: Initial model already used, switched to {current_model.id}")
            # Mark this model as assigned to this slot
            assigned_model_ids.add(current_model.id)

            retries = 0
            while retries < max_retries_per_slot:
                try:
                    # Generate and save directly to the main temp dir
                    temp_audio_path = predict_tts(text, current_model.id)
                    if not temp_audio_path or not os.path.exists(temp_audio_path):
                        raise ValueError(f"predict_tts failed for model {current_model.id}")
                    # Create a unique name in the main TEMP_AUDIO_DIR for the session
                    file_uuid = str(uuid.uuid4())
                    dest_path = os.path.join(TEMP_AUDIO_DIR, f"{file_uuid}.wav")
                    shutil.move(temp_audio_path, dest_path)
                    return {"model_id": current_model.id, "audio_path": dest_path}
                except Exception as e:
                    app.logger.warning(f"TTS generation failed for model {current_model.id}: {str(e)}")
                    failed_model_ids.add(current_model.id)
                    retries += 1
                    # Try to select a different model (exclude failed and already assigned models)
                    remaining_models = [m for m in available_models if m.id not in failed_model_ids and m.id not in assigned_model_ids]
                    if not remaining_models:
                        app.logger.error(f"No more models available after {retries} retries for slot {slot_index}")
                        raise ValueError(f"No more models available for fallback")
                    # Select a new model weighted randomly
                    current_model = get_weighted_random_models(remaining_models, 1, ModelType.TTS)[0]
                    assigned_model_ids.add(current_model.id)  # Mark new model as assigned
                    app.logger.info(f"Retrying with fallback model: {current_model.id} (attempt {retries + 1})")
            raise ValueError(f"All {max_retries_per_slot} retries exhausted for slot {slot_index}")

        # Process each model slot sequentially to allow proper fallback
        # (Sequential processing needed to track failed/assigned models across slots)
        for i, model in enumerate(selected_models):
            result = process_model_with_fallback(model, i)
            model_ids.append(result["model_id"])
            audio_files.append(result["audio_path"])

        # Create session
        session_id = str(uuid.uuid4())
        app.tts_sessions[session_id] = {
            "model_a": model_ids[0],
            "model_b": model_ids[1],
            "audio_a": audio_files[0],  # Paths are now from TEMP_AUDIO_DIR directly
            "audio_b": audio_files[1],
            "text": text,
            "created_at": datetime.utcnow(),
            "expires_at": datetime.utcnow() + timedelta(minutes=30),
            "voted": False,
            "cache_hit": False,
        }
        # Don't mark as consumed yet - wait until vote is submitted to maintain security
        # while allowing legitimate votes to count for ELO

        # Return audio file paths and session
        return jsonify(
            {
                "session_id": session_id,
                "audio_a": f"/api/tts/audio/{session_id}/a",
                "audio_b": f"/api/tts/audio/{session_id}/b",
                "expires_in": 1800,
                "cache_hit": False,
            }
        )
    except Exception as e:
        app.logger.error(f"TTS on-the-fly generation error: {str(e)}", exc_info=True)
        # Cleanup any files created before the failure.
        # FIX: the original guarded this with `if 'results' in locals():`, but
        # no variable named `results` is ever assigned in this function, so the
        # cleanup branch was dead code and partially generated audio files
        # leaked in TEMP_AUDIO_DIR whenever a later slot failed.
        for path in audio_files:
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                except OSError:
                    pass
        return jsonify({"error": "Failed to generate TTS"}), 500
    # --- End Cache Miss ---
@app.route("/api/tts/audio/<session_id>/<model_key>")
def get_audio(session_id, model_key):
    """Serve one of a session's two generated audio files ("a" or "b").

    FIX: the route previously read "/api/tts/audio//" with no URL
    converters, so Flask could not bind `session_id`/`model_key`; the
    generate endpoint emits URLs of the form /api/tts/audio/<id>/a.
    """
    # If verification not setup, handle it first
    if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
        return jsonify({"error": "Turnstile verification required"}), 403

    if session_id not in app.tts_sessions:
        return jsonify({"error": "Invalid or expired session"}), 404

    session_data = app.tts_sessions[session_id]
    # Check if session expired
    if datetime.utcnow() > session_data["expires_at"]:
        cleanup_session(session_id)
        return jsonify({"error": "Session expired"}), 410

    if model_key == "a":
        audio_path = session_data["audio_a"]
    elif model_key == "b":
        audio_path = session_data["audio_b"]
    else:
        return jsonify({"error": "Invalid model key"}), 400

    # Check if file exists
    if not os.path.exists(audio_path):
        return jsonify({"error": "Audio file not found"}), 404

    return send_file(audio_path, mimetype="audio/wav")
@app.route("/api/tts/vote", methods=["POST"])
@limiter.limit("30 per minute")
def submit_vote():
    """Record a preference vote for a TTS session and reveal the model names."""
    # If verification not setup, handle it first
    if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
        return jsonify({"error": "Turnstile verification required"}), 403

    # Security checks for vote manipulation prevention
    client_ip = get_client_ip()
    user_id = current_user.id if current_user.is_authenticated else None
    vote_allowed, security_reason, security_score = is_vote_allowed(user_id, client_ip)
    if not vote_allowed:
        username = current_user.username if current_user.is_authenticated else "anonymous"
        app.logger.warning(f"Vote blocked for user {username} (ID: {user_id}): {security_reason} (Score: {security_score})")
        return jsonify({"error": f"Vote not allowed: {security_reason}"}), 403

    data = request.json
    session_id = data.get("session_id")
    chosen_model_key = data.get("chosen_model")  # "a" or "b"

    if not session_id or session_id not in app.tts_sessions:
        return jsonify({"error": "Invalid or expired session"}), 404
    if not chosen_model_key or chosen_model_key not in ["a", "b"]:
        return jsonify({"error": "Invalid chosen model"}), 400

    session_data = app.tts_sessions[session_id]
    # Check if session expired
    if datetime.utcnow() > session_data["expires_at"]:
        cleanup_session(session_id)
        return jsonify({"error": "Session expired"}), 410
    # Check if already voted
    if session_data["voted"]:
        return jsonify({"error": "Vote already submitted for this session"}), 400

    # Get model IDs and audio paths
    chosen_id = (
        session_data["model_a"] if chosen_model_key == "a" else session_data["model_b"]
    )
    rejected_id = (
        session_data["model_b"] if chosen_model_key == "a" else session_data["model_a"]
    )
    chosen_audio_path = (
        session_data["audio_a"] if chosen_model_key == "a" else session_data["audio_b"]
    )
    rejected_audio_path = (
        session_data["audio_b"] if chosen_model_key == "a" else session_data["audio_a"]
    )

    # Calculate session duration and gather analytics data
    vote_time = datetime.utcnow()
    session_duration = (vote_time - session_data["created_at"]).total_seconds()
    client_ip = get_client_ip()
    user_agent = request.headers.get('User-Agent')
    cache_hit = session_data.get("cache_hit", False)

    # Record vote in database with analytics data
    vote, error = record_vote(
        user_id,
        session_data["text"],
        chosen_id,
        rejected_id,
        ModelType.TTS,
        session_duration=session_duration,
        ip_address=client_ip,
        user_agent=user_agent,
        generation_date=session_data["created_at"],
        cache_hit=cache_hit,
        all_dataset_sentences=all_harvard_sentences
    )
    if error:
        return jsonify({"error": error}), 500
    # Sentence consumption is now handled within record_vote function

    # --- Save preference data ---
    try:
        vote_uuid = str(uuid.uuid4())
        vote_dir = os.path.join("./votes", vote_uuid)
        os.makedirs(vote_dir, exist_ok=True)
        # Copy audio files
        shutil.copy(chosen_audio_path, os.path.join(vote_dir, "chosen.wav"))
        shutil.copy(rejected_audio_path, os.path.join(vote_dir, "rejected.wav"))
        # Create metadata
        chosen_model_obj = Model.query.get(chosen_id)
        rejected_model_obj = Model.query.get(rejected_id)
        metadata = {
            "text": session_data["text"],
            "chosen_model": chosen_model_obj.name if chosen_model_obj else "Unknown",
            "chosen_model_id": chosen_model_obj.id if chosen_model_obj else "Unknown",
            "rejected_model": rejected_model_obj.name if rejected_model_obj else "Unknown",
            "rejected_model_id": rejected_model_obj.id if rejected_model_obj else "Unknown",
            "session_id": session_id,
            "timestamp": datetime.utcnow().isoformat(),
            "username": current_user.username if current_user.is_authenticated else "anonymous",
            "model_type": "TTS"
        }
        with open(os.path.join(vote_dir, "metadata.json"), "w") as f:
            json.dump(metadata, f, indent=2)
    except Exception as e:
        app.logger.error(f"Error saving preference data for vote {session_id}: {str(e)}")
        # Continue even if saving preference data fails, vote is already recorded
        # NOTE(review): if this block raises before Model.query.get() runs,
        # chosen_model_obj/rejected_model_obj are unbound when the response
        # below is built — confirm and consider hoisting the lookups.

    # Mark session as voted
    session_data["voted"] = True

    # Check for coordinated voting campaigns (async to not slow down response)
    try:
        from threading import Thread
        campaign_check_thread = Thread(target=check_for_coordinated_campaigns)
        campaign_check_thread.daemon = True
        campaign_check_thread.start()
    except Exception as e:
        app.logger.error(f"Error starting coordinated campaign check thread: {str(e)}")

    # Return updated models (use previously fetched objects)
    return jsonify(
        {
            "success": True,
            "chosen_model": {"id": chosen_id, "name": chosen_model_obj.name if chosen_model_obj else "Unknown"},
            "rejected_model": {
                "id": rejected_id,
                "name": rejected_model_obj.name if rejected_model_obj else "Unknown",
            },
            "names": {
                "a": (
                    chosen_model_obj.name
                    if chosen_model_key == "a"
                    else rejected_model_obj.name if chosen_model_obj and rejected_model_obj else "Unknown"
                ),
                "b": (
                    rejected_model_obj.name
                    if chosen_model_key == "a"
                    else chosen_model_obj.name if chosen_model_obj and rejected_model_obj else "Unknown"
                ),
            },
        }
    )


def cleanup_session(session_id):
    """Remove session and its audio files"""
    if session_id in app.tts_sessions:
        session = app.tts_sessions[session_id]
        # Remove audio files
        for audio_file in [session["audio_a"], session["audio_b"]]:
            if os.path.exists(audio_file):
                try:
                    os.remove(audio_file)
                except Exception as e:
                    app.logger.error(f"Error removing audio file: {str(e)}")
        # Remove session
        del app.tts_sessions[session_id]


# Schedule periodic cleanup
def setup_cleanup():
    """Start a background scheduler that prunes expired TTS sessions."""
    def cleanup_expired_sessions():
        with app.app_context():  # Ensure app context for logging
            current_time = datetime.utcnow()
            # Cleanup TTS sessions
            expired_tts_sessions = [
                sid
                for sid, session_data in app.tts_sessions.items()
                if current_time > session_data["expires_at"]
            ]
            for sid in expired_tts_sessions:
                cleanup_session(sid)
            app.logger.info(f"Cleaned up {len(expired_tts_sessions)} TTS sessions.")
            # Also cleanup potentially expired cache entries (e.g., > 1 hour old)
            # This prevents stale cache entries if generation is slow or failing
            # cleanup_stale_cache_entries()

    # Run cleanup every 15 minutes
    scheduler = BackgroundScheduler(daemon=True)  # Run scheduler as daemon thread
    scheduler.add_job(cleanup_expired_sessions, "interval", minutes=15)
    scheduler.start()
    print("Cleanup scheduler started")  # Use print for startup messages
database synchronization and preference data upload for Spaces""" if not IS_SPACES: return # Get database path from config (handles both persistent storage and fallback) db_path = app.config["SQLALCHEMY_DATABASE_URI"].replace("sqlite:///", "") preferences_repo_id = "channelcorp/ko-tts-arena-db" # Use same repo as database # Use the same repo for download and upload (consistency) database_repo_id = DATABASE_REPO_ID votes_dir = "./votes" def sync_database(): """Uploads the database to HF dataset (backup to cloud)""" with app.app_context(): # Ensure app context for logging try: if not os.path.exists(db_path): app.logger.warning(f"Database file not found at {db_path}, skipping sync.") return api = HfApi(token=os.getenv("HF_TOKEN")) api.upload_file( path_or_fileobj=db_path, path_in_repo="tts_arena.db", repo_id=database_repo_id, repo_type="dataset", ) app.logger.info(f"Database backed up to {database_repo_id} at {datetime.utcnow()}") except Exception as e: app.logger.error(f"Error uploading database to {database_repo_id}: {str(e)}") def sync_preferences_data(): """Zips and uploads preference data folders in batches to HF dataset""" with app.app_context(): # Ensure app context for logging if not os.path.isdir(votes_dir): return # Don't log every 5 mins if dir doesn't exist yet temp_batch_dir = None # Initialize to manage cleanup temp_individual_zip_dir = None # Initialize for individual zips local_batch_zip_path = None # Initialize for batch zip path try: api = HfApi(token=os.getenv("HF_TOKEN")) vote_uuids = [d for d in os.listdir(votes_dir) if os.path.isdir(os.path.join(votes_dir, d))] if not vote_uuids: return # No data to process app.logger.info(f"Found {len(vote_uuids)} vote directories to process.") # Create temporary directories temp_batch_dir = tempfile.mkdtemp(prefix="hf_batch_") temp_individual_zip_dir = tempfile.mkdtemp(prefix="hf_indiv_zips_") app.logger.debug(f"Created temp directories: {temp_batch_dir}, {temp_individual_zip_dir}") processed_vote_dirs = [] 
                individual_zips_in_batch = []

                # 1. Create individual zips and move them to the batch directory
                for vote_uuid in vote_uuids:
                    dir_path = os.path.join(votes_dir, vote_uuid)
                    individual_zip_base_path = os.path.join(temp_individual_zip_dir, vote_uuid)
                    individual_zip_path = f"{individual_zip_base_path}.zip"
                    try:
                        shutil.make_archive(individual_zip_base_path, 'zip', dir_path)
                        app.logger.debug(f"Created individual zip: {individual_zip_path}")

                        # Move the created zip into the batch directory
                        final_individual_zip_path = os.path.join(temp_batch_dir, f"{vote_uuid}.zip")
                        shutil.move(individual_zip_path, final_individual_zip_path)
                        app.logger.debug(f"Moved individual zip to batch dir: {final_individual_zip_path}")

                        processed_vote_dirs.append(dir_path)  # Mark original dir for later cleanup
                        individual_zips_in_batch.append(final_individual_zip_path)
                    except Exception as zip_err:
                        app.logger.error(f"Error creating or moving zip for {vote_uuid}: {str(zip_err)}")
                        # Clean up partial zip if it exists
                        if os.path.exists(individual_zip_path):
                            try:
                                os.remove(individual_zip_path)
                            except OSError:
                                pass
                        # Continue processing other votes

                # Clean up the temporary dir used for creating individual zips
                shutil.rmtree(temp_individual_zip_dir)
                temp_individual_zip_dir = None  # Mark as cleaned
                app.logger.debug("Cleaned up temporary individual zip directory.")

                if not individual_zips_in_batch:
                    app.logger.warning("No individual zips were successfully created for batching.")
                    # Clean up batch dir if it's empty or only contains failed attempts
                    if temp_batch_dir and os.path.exists(temp_batch_dir):
                        shutil.rmtree(temp_batch_dir)
                        temp_batch_dir = None
                    return

                # 2. Create the batch zip file (one archive containing all
                # of this run's individual vote zips)
                batch_timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
                batch_uuid_short = str(uuid.uuid4())[:8]
                batch_zip_filename = f"{batch_timestamp}_batch_{batch_uuid_short}.zip"
                # Create batch zip in a standard temp location first
                local_batch_zip_base = os.path.join(tempfile.gettempdir(), batch_zip_filename.replace('.zip', ''))
                local_batch_zip_path = f"{local_batch_zip_base}.zip"

                app.logger.info(f"Creating batch zip: {local_batch_zip_path} with {len(individual_zips_in_batch)} individual zips.")
                shutil.make_archive(local_batch_zip_base, 'zip', temp_batch_dir)
                app.logger.info(f"Batch zip created successfully: {local_batch_zip_path}")

                # 3. Upload the batch zip file, partitioned by year/month in the repo
                hf_repo_path = f"votes/{year}/{month}/{batch_zip_filename}"
                app.logger.info(f"Uploading batch zip to HF Hub: {preferences_repo_id}/{hf_repo_path}")
                api.upload_file(
                    path_or_fileobj=local_batch_zip_path,
                    path_in_repo=hf_repo_path,
                    repo_id=preferences_repo_id,
                    repo_type="dataset",
                    commit_message=f"Add batch preference data {batch_zip_filename} ({len(individual_zips_in_batch)} votes)"
                )
                app.logger.info(f"Successfully uploaded batch {batch_zip_filename} to {preferences_repo_id}")

                # 4. Cleanup after successful upload
                app.logger.info("Cleaning up local files after successful upload.")
                # Remove original vote directories that were successfully zipped and uploaded
                for dir_path in processed_vote_dirs:
                    try:
                        shutil.rmtree(dir_path)
                        app.logger.debug(f"Removed original vote directory: {dir_path}")
                    except OSError as e:
                        app.logger.error(f"Error removing processed vote directory {dir_path}: {str(e)}")
                # Remove the temporary batch directory (containing the individual zips)
                shutil.rmtree(temp_batch_dir)
                temp_batch_dir = None
                app.logger.debug("Removed temporary batch directory.")
                # Remove the local batch zip file
                os.remove(local_batch_zip_path)
                local_batch_zip_path = None
                app.logger.debug("Removed local batch zip file.")

                app.logger.info(f"Finished preference data sync. Uploaded batch {batch_zip_filename}.")
            except Exception as e:
                app.logger.error(f"Error during preference data batch sync: {str(e)}", exc_info=True)
                # If upload failed, the local batch zip might exist, clean it up.
                if local_batch_zip_path and os.path.exists(local_batch_zip_path):
                    try:
                        os.remove(local_batch_zip_path)
                        app.logger.debug("Cleaned up local batch zip after failed upload.")
                    except OSError as clean_err:
                        app.logger.error(f"Error cleaning up batch zip after failed upload: {clean_err}")
                # Do NOT remove temp_batch_dir if it exists; its contents will be retried next time.
                # Do NOT remove original vote directories if upload failed.
            finally:
                # Final cleanup for temporary directories in case of unexpected exits
                if temp_individual_zip_dir and os.path.exists(temp_individual_zip_dir):
                    try:
                        shutil.rmtree(temp_individual_zip_dir)
                    except Exception as final_clean_err:
                        app.logger.error(f"Error in final cleanup (indiv zips): {final_clean_err}")
                # Only clean up batch dir in finally block if it *wasn't* kept intentionally after upload failure
                if temp_batch_dir and os.path.exists(temp_batch_dir):
                    # Check if an upload attempt happened and failed
                    # NOTE(review): BUG — Python deletes the `except ... as e` variable
                    # when the handler block exits, so `'e' in locals()` is always False
                    # here and upload_failed is never True; the batch dir is therefore
                    # always removed and the "keep for retry" branch below is dead code.
                    # Fix by setting an explicit `sync_failed` flag inside the except block.
                    upload_failed = 'e' in locals() and isinstance(e, Exception)  # Crude check if exception occurred
                    if not upload_failed:  # If no upload error or upload succeeded, clean up
                        try:
                            shutil.rmtree(temp_batch_dir)
                        except Exception as final_clean_err:
                            app.logger.error(f"Error in final cleanup (batch dir): {final_clean_err}")
                    else:
                        app.logger.warning("Keeping temporary batch directory due to upload failure for next attempt.")

    # Schedule periodic tasks
    scheduler = BackgroundScheduler()
    # Sync database less frequently if needed, e.g., every 15 minutes
    scheduler.add_job(sync_database, "interval", minutes=15, id="sync_db_job")
    # Sync preferences more frequently
    scheduler.add_job(sync_preferences_data, "interval", minutes=5, id="sync_pref_job")
    scheduler.start()
    print("Periodic tasks scheduler started (DB sync and Preferences upload)")  # Use print for startup


@app.cli.command("init-db")
def init_db():
    """Initialize the database."""
    with app.app_context():
        db.create_all()
        print("Database initialized!")


@app.route("/api/toggle-leaderboard-visibility", methods=["POST"])
def toggle_leaderboard_visibility():
    """Toggle whether the current user appears in the top voters leaderboard"""
    if not current_user.is_authenticated:
        return jsonify({"error": "You must be logged in to change this setting"}), 401

    new_status = toggle_user_leaderboard_visibility(current_user.id)
    if new_status is None:
        return jsonify({"error": "User not found"}), 404

    return jsonify({
        "success": True,
        "visible": new_status,
        "message": "You are now visible in the voters leaderboard" if new_status else "You are now hidden from the voters leaderboard"
    })


@app.route("/api/tts/cached-sentences")
def get_cached_sentences():
    """Returns a list of sentences available for random selection."""
    sentences = all_harvard_sentences.copy()
    # Limit the response size to avoid overwhelming the frontend
    max_sentences = 1000
    if len(sentences) > max_sentences:
        sentences = random.sample(sentences, max_sentences)
    return jsonify(sentences)


@app.route("/api/tts/sentence-stats")
def get_sentence_stats():
    """Returns statistics about available sentences."""
    total_sentences = len(all_harvard_sentences)
    return jsonify({
        "total_sentences": total_sentences,
        "available_sentences": total_sentences
    })


@app.route("/api/tts/random-sentence")
def get_random_sentence():
    """Returns a random sentence."""
    if all_harvard_sentences:
        return jsonify({"sentence": random.choice(all_harvard_sentences)})
    else:
        return jsonify({"error": "No sentences available"}), 404


def get_weighted_random_models(
    applicable_models: list[Model], num_to_select: int, model_type: ModelType
) -> list[Model]:
    """
    Selects a specified number of models randomly from a list of applicable_models,
    weighting models with fewer votes higher.
    A smoothing factor is used to ensure the preference is slight and to prevent
    models with zero votes from being overwhelmingly favored. Models are selected
    without replacement.

    This ensures new models and models with fewer votes get more exposure,
    while still allowing matchups between models with different vote counts
    for better evaluation of new models against established ones.

    Assumes len(applicable_models) >= num_to_select, which should be checked by the caller.
    """
    # Count each model's total participation (chosen OR rejected) for this model type
    model_votes_counts = {}
    for model in applicable_models:
        votes = (
            Vote.query.filter(Vote.model_type == model_type)
            .filter(or_(Vote.model_chosen == model.id, Vote.model_rejected == model.id))
            .count()
        )
        model_votes_counts[model.id] = votes

    # Inverse-count weighting; the smoothing constant keeps zero-vote models
    # from dominating the draw.
    weights = [
        1.0 / (model_votes_counts[model.id] + SMOOTHING_FACTOR_MODEL_SELECTION)
        for model in applicable_models
    ]

    selected_models_list = []
    # Create copies to modify during selection process
    current_candidates = list(applicable_models)
    current_weights = list(weights)

    # Assumes num_to_select is positive and less than or equal to len(current_candidates)
    # Callers should ensure this (e.g., len(available_models) >= 2).
    for i in range(num_to_select):
        if not current_candidates:  # Safety break
            app.logger.warning("Not enough candidates left for weighted selection.")
            break
        # Select first model with weighted random
        chosen_model = random.choices(current_candidates, weights=current_weights, k=1)[0]
        selected_models_list.append(chosen_model)
        # Remove the pick (and its weight) so selection is without replacement
        try:
            idx_to_remove = current_candidates.index(chosen_model)
            current_candidates.pop(idx_to_remove)
            current_weights.pop(idx_to_remove)
        except ValueError:
            # This should ideally not happen if chosen_model came from current_candidates.
            app.logger.error(f"Error removing model {chosen_model.id} from weighted selection candidates.")
            break  # Avoid potential issues

    return selected_models_list


def check_for_coordinated_campaigns():
    """Check all active models for potential coordinated voting campaigns"""
    # Runs on a background thread after each vote; all failures are logged
    # and swallowed so the voting response path is never affected.
    try:
        from security import detect_coordinated_voting
        from models import Model, ModelType

        # Check TTS models
        tts_models = Model.query.filter_by(model_type=ModelType.TTS, is_active=True).all()
        for model in tts_models:
            try:
                detect_coordinated_voting(model.id)
            except Exception as e:
                app.logger.error(f"Error checking coordinated voting for TTS model {model.id}: {str(e)}")
    except Exception as e:
        app.logger.error(f"Error in coordinated campaign check: {str(e)}")


if __name__ == "__main__":
    with app.app_context():
        # Ensure directories exist
        if IS_SPACES and PERSISTENT_DATA_DIR:
            os.makedirs(PERSISTENT_DATA_DIR, exist_ok=True)
        else:
            os.makedirs("instance", exist_ok=True)
        os.makedirs("./votes", exist_ok=True)  # Create votes directory if it doesn't exist
        os.makedirs(CACHE_AUDIO_DIR, exist_ok=True)  # Ensure cache audio dir exists

        # Clean up old cache audio files on startup
        try:
            app.logger.info(f"Clearing old cache audio files from {CACHE_AUDIO_DIR}")
            for filename in os.listdir(CACHE_AUDIO_DIR):
                file_path = os.path.join(CACHE_AUDIO_DIR, filename)
                try:
                    # Remove files and symlinks directly; recurse into subdirectories
                    if os.path.isfile(file_path) or os.path.islink(file_path):
                        os.unlink(file_path)
                    elif os.path.isdir(file_path):
                        shutil.rmtree(file_path)
                except Exception as e:
                    app.logger.error(f'Failed to delete {file_path}. Reason: {e}')
        except Exception as e:
            app.logger.error(f"Error clearing cache directory {CACHE_AUDIO_DIR}: {e}")

        # Note: Database download is handled at module load time for HF Spaces
        # This ensures DB is ready before app initialization

        db.create_all()  # Create tables if they don't exist
        insert_initial_models()
        # Setup background tasks
        initialize_tts_cache()  # Start populating the cache
        setup_cleanup()
        setup_periodic_tasks()  # Renamed function call

    # Configure Flask to recognize HTTPS when behind a reverse proxy
    from werkzeug.middleware.proxy_fix import ProxyFix

    # Apply ProxyFix middleware to handle reverse proxy headers
    # This ensures Flask generates correct URLs with https scheme
    # X-Forwarded-Proto header will be used to detect the original protocol
    app.wsgi_app = ProxyFix(app.wsgi_app, x_proto=1, x_host=1)
    # Force Flask to prefer HTTPS for generated URLs
    app.config["PREFERRED_URL_SCHEME"] = "https"

    from waitress import serve

    # Configuration for 2 vCPUs:
    # - threads: typically 4-8 threads per CPU core is a good balance
    # - connection_limit: maximum concurrent connections
    # - channel_timeout: prevent hanging connections
    threads = 12  # 6 threads per vCPU is a good balance for mixed IO/CPU workloads

    if IS_SPACES:
        # HF Spaces expects the server on $PORT (default 7860) behind an HTTPS proxy
        serve(
            app,
            host="0.0.0.0",
            port=int(os.environ.get("PORT", 7860)),
            threads=threads,
            connection_limit=100,
            channel_timeout=30,
            url_scheme='https'
        )
    else:
        port = int(os.environ.get("PORT", 5001))
        print(f"Starting Waitress server with {threads} threads on port {port}")
        serve(
            app,
            host="0.0.0.0",
            port=port,
            threads=threads,
            connection_limit=100,
            channel_timeout=30,
            url_scheme='http'  # Local dev uses http
        )