import os
import json
import hashlib
import tempfile
import time
import subprocess
import ssl
import threading
import functools
from pathlib import Path

import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from werkzeug.utils import secure_filename
from allosaurus.app import read_recognizer

app = Flask(__name__)
CORS(app)

CACHE_DIR = "/tmp/cache"
# Must match the directory created below: uploads are written here by
# /api/viseme and pruned by cleanup_old_files().
UPLOAD_FOLDER = "/tmp/uploads"
ALLOWED_EXTENSIONS = {'wav', 'ogg', 'mp3', 'm4a'}

# Maximum number of entries kept in RESULT_CACHE; the least recently used
# entry is evicted first.
RESULT_CACHE_MAX_ITEMS = 100
# Files in UPLOAD_FOLDER older than this many seconds are deleted by cleanup.
UPLOAD_MAX_AGE_SECONDS = 3600
# Period of the background cleanup thread, in seconds.
CLEANUP_INTERVAL_SECONDS = 3600

os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(CACHE_DIR, exist_ok=True)

# Preload the model at server startup
print("Preloading Allosaurus model...")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# NOTE(review): `device` is only reported; it is not passed to read_recognizer,
# so whether inference actually uses the GPU depends on Allosaurus' own
# configuration — confirm.

# Certificate verification is disabled only while loading the model (the
# model download is why it was turned off). The verifying default is restored
# afterwards so other HTTPS requests from this process are still checked.
os.environ['PYTHONHTTPSVERIFY'] = '0'
_verified_https_context = ssl._create_default_https_context
ssl._create_default_https_context = ssl._create_unverified_context
try:
    MODEL = read_recognizer(alt_model_path=Path("/tmp/allosaurus_models"))
finally:
    ssl._create_default_https_context = _verified_https_context

# Phoneme inventories used to build the phoneme -> viseme lookup table.
vowels = ['a', 'e', 'i', 'o', 'u', 'æ', 'ɑ', 'ɒ', 'ɔ', 'ɛ', 'ɜ', 'ɪ', 'ʊ', 'ʌ', 'ə', 'ɐ']
bilabials = ['b', 'p', 'm']
labiodentals = ['f', 'v']
dentals = ['θ', 'ð']
# 'ɹ' is the IPA alveolar approximant commonly written for English "r".
alveolars = ['t', 'd', 'n', 's', 'z', 'l', 'r', 'ɹ']
palatals = ['ʃ', 'ʒ', 'j', 'tʃ', 'dʒ']
# 'ɡ' (U+0261) is the canonical IPA codepoint for the voiced velar stop; it is
# visually identical to ASCII 'g' but a different character.
velars = ['k', 'g', 'ɡ', 'ŋ', 'x']

# Vowel -> viseme groups (letters as used in the Rhubarb-compatible output).
_VOWEL_VISEMES = {
    'D': ['a', 'æ', 'ɑ', 'ɒ'],            # AI
    'C': ['e', 'ɛ', 'ɪ', 'i'],            # E
    'E': ['o', 'ɔ', 'ʌ', 'ə', 'ɐ', 'ɜ'],  # O
    'F': ['u', 'ʊ'],                      # U
}

# Precomputed phoneme -> viseme dictionary for O(1) lookups.
PHONEME_MAP = {}
PHONEME_MAP.update(dict.fromkeys(bilabials, 'A'))                       # MBP
PHONEME_MAP.update(dict.fromkeys(labiodentals + dentals, 'G'))          # FV
PHONEME_MAP.update(dict.fromkeys(alveolars + palatals + velars, 'B'))   # etc
PHONEME_MAP['l'] = 'H'                                                  # L
for _viseme, _phones in _VOWEL_VISEMES.items():
    PHONEME_MAP.update(dict.fromkeys(_phones, _viseme))
# Any vowel not assigned above falls back to the E shape.
for _vowel in vowels:
    PHONEME_MAP.setdefault(_vowel, 'C')

# LRU cache of processed results, keyed by SHA-256 of the uploaded audio.
# Plain dicts preserve insertion order, so the first key is the least
# recently used. Flask serves requests on multiple threads, so every access
# that mutates the cache goes through _RESULT_CACHE_LOCK.
RESULT_CACHE = {}
_RESULT_CACHE_LOCK = threading.Lock()


def allowed_file(filename):
    """Return True if *filename* has an extension listed in ALLOWED_EXTENSIONS."""
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


def convert_to_wav(input_file):
    """Transcode *input_file* to 16-bit PCM WAV at 16 kHz using ffmpeg.

    The output is written next to the input with a ``.wav`` extension.

    Returns:
        The output path, or None if ffmpeg failed or is not installed.
    """
    output_file = os.path.splitext(input_file)[0] + '.wav'
    try:
        # -y: overwrite an existing output instead of prompting (which would
        # otherwise abort the conversion); -nostdin: never read the terminal.
        subprocess.run(
            ['ffmpeg', '-nostdin', '-y', '-i', input_file,
             '-acodec', 'pcm_s16le', '-ar', '16000', output_file],
            check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except (subprocess.CalledProcessError, OSError):
        return None
    return output_file


def map_phoneme_to_viseme(phoneme):
    """Return the viseme letter for *phoneme*, or 'X' (rest) if it is unknown."""
    return PHONEME_MAP.get(phoneme, 'X')


def get_file_hash(filepath, chunk_size=1 << 20):
    """Return the SHA-256 hex digest of the contents of *filepath*.

    Hashing the content (rather than size/mtime) keeps cache hits working
    even though each upload is written to a fresh file.
    """
    digest = hashlib.sha256()
    with open(filepath, 'rb') as fh:
        for chunk in iter(functools.partial(fh.read, chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def _cache_get(key):
    """Return the cached result for *key* and mark it most recently used, else None."""
    with _RESULT_CACHE_LOCK:
        result = RESULT_CACHE.pop(key, None)
        if result is not None:
            RESULT_CACHE[key] = result
        return result


def _cache_put(key, result):
    """Store *result* under *key*, evicting least recently used entries beyond the limit."""
    with _RESULT_CACHE_LOCK:
        RESULT_CACHE.pop(key, None)
        while RESULT_CACHE and len(RESULT_CACHE) >= RESULT_CACHE_MAX_ITEMS:
            RESULT_CACHE.pop(next(iter(RESULT_CACHE)))
        RESULT_CACHE[key] = result


def _remove_quietly(path):
    """Delete *path*, ignoring errors such as the file already being gone."""
    try:
        os.remove(path)
    except OSError:
        pass


def _phonemes_to_mouth_cues(phoneme_output):
    """Convert Allosaurus timestamped output into a list of mouth-cue dicts.

    Each line is parsed as ``<start> <duration> <phone>``; lines with fewer
    than three fields are skipped. Times are rounded to 2 decimals.
    """
    mouth_cues = []
    for line in phoneme_output.strip().split('\n'):
        fields = line.split()
        if len(fields) < 3:
            continue
        start = float(fields[0])
        length = float(fields[1])
        mouth_cues.append({
            "start": round(start, 2),
            "end": round(start + length, 2),
            "value": map_phoneme_to_viseme(fields[2]),
        })
    return mouth_cues


def _probe_duration(audio_file):
    """Return the duration in seconds reported by ffprobe, or None on any failure."""
    try:
        completed = subprocess.run(
            ['ffprobe', '-v', 'error', '-show_entries', 'format=duration',
             '-of', 'default=noprint_wrappers=1:nokey=1', audio_file],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True)
        return float(completed.stdout.strip())
    except (subprocess.CalledProcessError, OSError, ValueError):
        # Non-zero exit, missing binary, or unparsable output such as "N/A".
        return None


def process_audio_with_allosaurus(audio_file):
    """Recognize phonemes in *audio_file* and return Rhubarb-style viseme data.

    Non-WAV input is first converted with ffmpeg; the converted file is
    deleted once processing finishes. Results are cached by content hash.

    Returns:
        ``{"metadata": {"soundFile", "duration"}, "mouthCues": [...]}``, or
        None if conversion to WAV failed.
    """
    cache_key = get_file_hash(audio_file)
    cached = _cache_get(cache_key)
    if cached is not None:
        return cached

    started = time.perf_counter()

    converted_file = None
    if not audio_file.lower().endswith('.wav'):
        converted_file = convert_to_wav(audio_file)
        if not converted_file:
            return None
        audio_file = converted_file

    try:
        # Inference only; no_grad is valid on both CPU and CUDA.
        with torch.no_grad():
            phonemes = MODEL.recognize(audio_file, timestamp=True)

        mouth_cues = _phonemes_to_mouth_cues(phonemes)

        # Rest position before the first detected phoneme.
        if mouth_cues and mouth_cues[0]["start"] > 0:
            mouth_cues.insert(0, {"start": 0, "end": mouth_cues[0]["start"], "value": "X"})

        duration = _probe_duration(audio_file)
        if duration is None:
            # Fall back to the end of the last phoneme.
            duration = mouth_cues[-1]["end"] if mouth_cues else 0

        # Rest position after the last detected phoneme.
        if mouth_cues and mouth_cues[-1]["end"] < duration:
            mouth_cues.append({"start": mouth_cues[-1]["end"], "end": duration, "value": "X"})
    finally:
        if converted_file:
            _remove_quietly(converted_file)

    # Same shape as Rhubarb output for client compatibility.
    result = {
        "metadata": {
            "soundFile": audio_file,
            "duration": duration,
        },
        "mouthCues": mouth_cues,
    }
    _cache_put(cache_key, result)

    print(f"Processing completed in {time.perf_counter() - started:.2f} seconds")
    return result


@app.route('/api/viseme', methods=['POST'])
def generate_viseme():
    """Accept an uploaded audio file (form field ``file``) and return its visemes as JSON."""
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400

    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400

    if not (file and allowed_file(file.filename)):
        return jsonify({'error': 'File type not allowed'}), 400

    # A unique name per request prevents concurrent uploads with the same
    # client filename from overwriting each other.
    extension = file.filename.rsplit('.', 1)[1].lower()
    fd, filepath = tempfile.mkstemp(suffix='.' + extension, dir=UPLOAD_FOLDER)
    os.close(fd)
    try:
        file.save(filepath)
        result = process_audio_with_allosaurus(filepath)
    finally:
        # The cache is keyed by content, so the upload is not needed afterwards.
        _remove_quietly(filepath)

    if result:
        return jsonify(result)
    return jsonify({'error': 'Failed to process audio file'}), 500


@app.route('/api/status', methods=['GET'])
def status():
    """Report model availability, cache size and accepted upload extensions."""
    return jsonify({
        'status': 'ok',
        'model_loaded': MODEL is not None,
        'cache_size': len(RESULT_CACHE),
        'supported_formats': sorted(ALLOWED_EXTENSIONS),
    })


@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe."""
    return jsonify({'status': 'ok'})


def cleanup_old_files(max_age=UPLOAD_MAX_AGE_SECONDS):
    """Delete regular files in UPLOAD_FOLDER last modified more than *max_age* seconds ago."""
    now = time.time()
    for filename in os.listdir(UPLOAD_FOLDER):
        filepath = os.path.join(UPLOAD_FOLDER, filename)
        try:
            if os.path.isfile(filepath) and now - os.path.getmtime(filepath) > max_age:
                os.unlink(filepath)
        except OSError:
            # The file may have been removed by a request handler meanwhile.
            continue


def _cleanup_loop(interval=CLEANUP_INTERVAL_SECONDS):
    """Run cleanup_old_files() every *interval* seconds, forever."""
    while True:
        time.sleep(interval)
        try:
            cleanup_old_files()
        except OSError as exc:
            print(f"Upload cleanup failed: {exc}")


if __name__ == '__main__':
    # Background thread that prunes stale uploads every hour.
    cleanup_thread = threading.Thread(target=_cleanup_loop, daemon=True)
    cleanup_thread.start()

    # Configure hot reload with increased watcher sensitivity.
    # NOTE(review): debug=True on 0.0.0.0 exposes the Werkzeug interactive
    # debugger to the network — use only for local development.
    app.run(host='0.0.0.0', port=7860, debug=True, use_reloader=True,
            reloader_type='stat', extra_files=['./requirements.txt'],
            reloader_interval=1)