"""Speech-to-text HTTP service built on FastAPI and a Whisper ASR pipeline.

Endpoints:
    GET  /status                   -- capacity of this instance, or a redirect.
    GET  /langs                    -- languages the loaded model can transcribe.
    POST /transcribe               -- queue one audio chunk for transcription.
    GET  /transcript/{session_id}  -- transcript collected so far for a session.

When this instance is at capacity, callers receive ``{"status": "full",
"redirect_to": <url>}`` where the URL is a sibling space whose ``/status``
reports ``"empty"`` (``None`` when no such space could be found).
"""

import asyncio
import contextlib
import functools
import io
import logging
import os
import tempfile
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor

import httpx
import torch
import torchaudio
from fastapi import FastAPI, UploadFile, File, Form, Request, BackgroundTasks
from transformers import pipeline

logger = logging.getLogger(__name__)

app = FastAPI()

# -------------------------
# DEVICE
# -------------------------

# Pipeline device index: first CUDA device when available, otherwise CPU (-1).
device = 0 if torch.cuda.is_available() else -1

asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=device,
)

# Worker threads that run the blocking ASR calls off the event loop.
executor = ThreadPoolExecutor(max_workers=2)

# session_id -> {"transcript": [{"chunk": int, "text": str}, ...],
#                "last_seen": float (time.time())}
sessions = {}

USERNAME = "velost"
BASE = "stt"

MAX_USERS = 2
SESSION_TIMEOUT = 20
BATCH_SIZE = 50

# Upper bound on how many sibling spaces discovery will probe before giving
# up, so a fully saturated fleet cannot make a request hang forever.
MAX_DISCOVERY_SPACES = BATCH_SIZE * 4

# Track active transcriptions
active_transcriptions = 0
transcription_lock = asyncio.Lock()

# -------------------------
# RATE LIMITING FOR /langs (GLOBAL)
# -------------------------

langs_request_times = []  # Timestamps (time.time()) of ALL recent requests
langs_lock = asyncio.Lock()
LANGS_RATE_LIMIT = 10000  # max total requests per second (global)

# Strong references to the long-running maintenance tasks; the event loop
# itself only keeps weak references to tasks.
_maintenance_tasks = set()

# Hardcoded Whisper language list, used only when neither the tokenizer nor
# the model config exposes the languages.
WHISPER_LANGUAGES = {
    "en": "English", "zh": "Chinese", "de": "German", "es": "Spanish",
    "ru": "Russian", "ko": "Korean", "fr": "French", "ja": "Japanese",
    "pt": "Portuguese", "tr": "Turkish", "pl": "Polish", "ca": "Catalan",
    "nl": "Dutch", "ar": "Arabic", "sv": "Swedish", "it": "Italian",
    "id": "Indonesian", "hi": "Hindi", "fi": "Finnish", "vi": "Vietnamese",
    "he": "Hebrew", "uk": "Ukrainian", "el": "Greek", "ms": "Malay",
    "cs": "Czech", "ro": "Romanian", "da": "Danish", "hu": "Hungarian",
    "ta": "Tamil", "no": "Norwegian", "th": "Thai", "ur": "Urdu",
    "hr": "Croatian", "bg": "Bulgarian", "lt": "Lithuanian", "la": "Latin",
    "mi": "Maori", "ml": "Malayalam", "cy": "Welsh", "sk": "Slovak",
    "te": "Telugu", "fa": "Persian", "lv": "Latvian", "bn": "Bengali",
    "sr": "Serbian", "az": "Azerbaijani", "sl": "Slovenian", "kn": "Kannada",
    "et": "Estonian", "mk": "Macedonian", "br": "Breton", "eu": "Basque",
    "is": "Icelandic", "hy": "Armenian", "ne": "Nepali", "mn": "Mongolian",
    "bs": "Bosnian", "kk": "Kazakh", "sq": "Albanian", "sw": "Swahili",
    "gl": "Galician", "mr": "Marathi", "pa": "Punjabi", "si": "Sinhala",
    "km": "Khmer", "sn": "Shona", "yo": "Yoruba", "so": "Somali",
    "af": "Afrikaans", "oc": "Occitan", "ka": "Georgian", "be": "Belarusian",
    "tg": "Tajik", "sd": "Sindhi", "gu": "Gujarati", "am": "Amharic",
    "yi": "Yiddish", "lo": "Lao", "uz": "Uzbek", "fo": "Faroese",
    "ht": "Haitian Creole", "ps": "Pashto", "tk": "Turkmen", "nn": "Nynorsk",
    "mt": "Maltese", "sa": "Sanskrit", "lb": "Luxembourgish", "my": "Myanmar",
    "bo": "Tibetan", "tl": "Tagalog", "mg": "Malagasy", "tt": "Tatar",
    "haw": "Hawaiian", "ln": "Lingala", "ha": "Hausa", "ba": "Bashkir",
    "jw": "Javanese", "su": "Sundanese",
}


def _prune_langs_window(now):
    """Drop /langs timestamps older than one second, in place.

    Must be called with ``langs_lock`` held.
    """
    langs_request_times[:] = [t for t in langs_request_times if now - t < 1.0]


async def reset_langs_counter():
    """Prune the global /langs request window once per second, forever."""
    while True:
        await asyncio.sleep(1)
        async with langs_lock:
            _prune_langs_window(time.time())


# -------------------------
# CLEANUP
# -------------------------

async def cleanup_sessions():
    """Every 5 seconds, drop sessions idle for more than SESSION_TIMEOUT."""
    while True:
        now = time.time()
        expired = [
            sid
            for sid, session in sessions.items()
            if now - session["last_seen"] > SESSION_TIMEOUT
        ]
        for sid in expired:
            # pop() rather than del: tolerate the entry already being gone.
            sessions.pop(sid, None)
        await asyncio.sleep(5)


def _spawn_maintenance(coro):
    """Schedule *coro* as a task and keep it referenced until it finishes."""
    task = asyncio.create_task(coro)
    _maintenance_tasks.add(task)
    task.add_done_callback(_maintenance_tasks.discard)
    return task


@app.on_event("startup")
async def startup():
    """Start the session-cleanup and rate-window maintenance loops."""
    _spawn_maintenance(cleanup_sessions())
    _spawn_maintenance(reset_langs_counter())


# -------------------------
# SPACE URL
# -------------------------

def space_url(i: int):
    """Return the URL of sibling space number *i*.

    Index 0 maps to the bare ``BASE`` name; every other index is appended to
    it (``stt1``, ``stt2``, ...).
    """
    name = BASE if i == 0 else f"{BASE}{i}"
    return f"https://{USERNAME}-{name}.hf.space"


# -------------------------
# CHECK SPACE
# -------------------------

async def check(client, url):
    """Probe ``{url}/status`` and return ``(state, url)``.

    *state* is ``"404"`` or ``"429"`` for those HTTP codes, ``"bad"`` for any
    other non-200 code, ``"error"`` when the request or JSON decoding fails,
    and otherwise the ``"status"`` field of the JSON body (``None`` if the
    field is absent).
    """
    try:
        r = await client.get(f"{url}/status", timeout=3)

        if r.status_code == 404:
            return "404", url

        if r.status_code == 429:
            return "429", url

        if r.status_code != 200:
            return "bad", url

        return r.json().get("status"), url

    except Exception:
        # Best-effort probe: any failure just marks this space as unusable.
        return "error", url


# -------------------------
# SPACE DISCOVERY
# -------------------------

async def find_empty_space(max_spaces: int = MAX_DISCOVERY_SPACES):
    """Search sibling spaces, in index order, for one reporting "empty".

    Spaces are probed concurrently in batches of ``BATCH_SIZE``. The search
    stops at the first hit, after ``max_spaces`` spaces have been probed, or
    as soon as a whole batch answers 404 (taken to mean the numbering has run
    out), so it always terminates.

    Returns:
        ``{"url": <space url>}`` for the lowest-indexed empty space, or
        ``{"url": None}`` when no empty space was found.
    """
    async with httpx.AsyncClient() as client:
        for start in range(0, max_spaces, BATCH_SIZE):
            stop = min(start + BATCH_SIZE, max_spaces)
            results = await asyncio.gather(
                *(check(client, space_url(j)) for j in range(start, stop))
            )

            for state, url in results:
                if state == "empty":
                    return {"url": url}

            if all(state == "404" for state, _ in results):
                break

    logger.warning("No empty space found among the first %d spaces", max_spaces)
    return {"url": None}


# -------------------------
# STATUS
# -------------------------

@app.get("/status")
async def status():
    """Report whether this instance has a free session slot.

    Returns the "empty" payload with the current and maximum session counts
    while below ``MAX_USERS``; otherwise the "full" payload with a redirect.
    """
    active = len(sessions)

    if active < MAX_USERS:
        return {
            "status": "empty",
            "active": active,
            "max": MAX_USERS
        }

    result = await find_empty_space()

    return {
        "status": "full",
        "redirect_to": result["url"]
    }


# -------------------------
# LANGS ENDPOINT
# -------------------------

def _language_codes_and_names():
    """Return parallel ``(codes, names)`` sequences for the loaded model.

    Sources are tried in order: tokenizer attributes, the model config's
    language mapping, then the hardcoded ``WHISPER_LANGUAGES`` table.
    """
    tokenizer = asr.tokenizer

    if hasattr(tokenizer, 'all_language_codes') and hasattr(tokenizer, 'all_language_tokens'):
        codes = tokenizer.all_language_codes
        names = [tokenizer.decode([tok]) for tok in tokenizer.all_language_tokens]
        return codes, names

    language_dict = {}
    if hasattr(asr.model, 'config'):
        config = asr.model.config
        if hasattr(config, 'lang_to_id'):
            language_dict = config.lang_to_id
        elif hasattr(config, 'language_to_id'):
            language_dict = config.language_to_id

    if language_dict:
        codes = list(language_dict.keys())
        return codes, codes  # fallback: names are not available here

    return list(WHISPER_LANGUAGES.keys()), list(WHISPER_LANGUAGES.values())


@functools.lru_cache(maxsize=1)
def _supported_languages():
    """Return the formatted, code-sorted language entries as a tuple.

    Cached because the result depends only on the model loaded at import
    time, so it is computed once instead of on every request.
    """
    codes, names = _language_codes_and_names()
    entries = [
        {
            "code": code,
            "name": name.strip().title() if name else code.upper()
        }
        for code, name in zip(codes, names)
    ]
    # Sort by code for consistency
    entries.sort(key=lambda entry: entry["code"])
    return tuple(entries)


@app.get("/langs")
async def get_langs(request: Request):
    """List the languages the model supports, subject to a global rate limit.

    At most ``LANGS_RATE_LIMIT`` requests are served per one-second window
    across all clients; requests beyond that get the "full" redirect payload.
    """
    now = time.time()

    async with langs_lock:
        _prune_langs_window(now)
        limited = len(langs_request_times) >= LANGS_RATE_LIMIT
        if not limited:
            # Record this request
            langs_request_times.append(now)

    if limited:
        # Discovery does network I/O, so it runs only after the lock is
        # released; otherwise every other /langs request would queue behind it.
        result = await find_empty_space()
        return {
            "status": "full",
            "redirect_to": result["url"]
        }

    languages = list(_supported_languages())

    return {
        "status": "ok",
        "count": len(languages),
        "languages": languages
    }


# -------------------------
# TRANSCRIBE
# -------------------------

def _generate_kwargs(language):
    """Build the generation kwargs; "auto" or an empty value sets no language."""
    kwargs = {"task": "transcribe"}
    if language and language != "auto":
        kwargs["language"] = language
    return kwargs


def _decode_mono_16k(audio_bytes):
    """Decode *audio_bytes* in memory into a 1-D mono 16 kHz numpy array."""
    waveform, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))

    # Convert to mono if multi-channel
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample to 16kHz if needed (Whisper expects 16kHz)
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate,
            new_freq=16000
        )
        waveform = resampler(waveform)

    # Drop only the channel axis so a one-sample clip still yields a 1-D array.
    return waveform.squeeze(0).numpy()


def _transcribe_via_tempfile(audio_bytes, generate_kwargs):
    """Write the audio to a temporary .wav file and let the pipeline decode it."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(audio_bytes)
        path = tmp.name

    try:
        return asr(path, generate_kwargs=generate_kwargs)["text"]
    finally:
        with contextlib.suppress(FileNotFoundError):
            os.remove(path)


def transcribe_file(audio_bytes, language):
    """Transcribe one audio chunk and return the recognized text.

    The audio is decoded in memory with torchaudio first. If that path fails
    for any reason, the failure is logged and the bytes are retried through a
    temporary file.

    Args:
        audio_bytes: Raw bytes of the uploaded audio.
        language: Language code to force, or "auto"/empty to pass none.

    Returns:
        The "text" field of the pipeline result.

    Raises:
        Exception: Whatever the temp-file fallback raises if it fails too.
    """
    generate_kwargs = _generate_kwargs(language)

    try:
        audio_array = _decode_mono_16k(audio_bytes)
        return asr(audio_array, generate_kwargs=generate_kwargs)["text"]
    except Exception:
        # Deliberate best-effort: keep the fallback, but leave a trace of why.
        logger.exception(
            "In-memory transcription failed; retrying via temporary file"
        )

    return _transcribe_via_tempfile(audio_bytes, generate_kwargs)


async def process_transcription(session_id, chunk_index, audio_bytes, language):
    """Background task: transcribe one chunk and append it to its session.

    The result is dropped if the session no longer exists when transcription
    finishes. The active-transcription slot is always released.
    """
    global active_transcriptions

    try:
        loop = asyncio.get_running_loop()
        text = await loop.run_in_executor(
            executor, transcribe_file, audio_bytes, language
        )

        session = sessions.get(session_id)
        if session is not None:
            session["transcript"].append({
                "chunk": chunk_index,
                "text": text
            })
            session["last_seen"] = time.time()

    finally:
        # Decrease active transcriptions counter
        async with transcription_lock:
            active_transcriptions -= 1


@app.post("/transcribe")
async def transcribe(
    background_tasks: BackgroundTasks,
    session_id: str = Form(...),
    chunk_index: int = Form(...),
    language: str = Form("auto"),
    audio: UploadFile = File(...)
):
    """Accept one audio chunk and queue it for background transcription.

    The request is rejected with the "full" redirect payload when it would
    open a new session while ``MAX_USERS`` sessions exist, or when
    ``MAX_USERS`` transcriptions are already in flight. Otherwise the session
    is created or refreshed and the call returns immediately with
    ``"status": "processing"``.
    """
    global active_transcriptions

    async with transcription_lock:
        is_new_session = session_id not in sessions
        at_capacity = (
            (is_new_session and len(sessions) >= MAX_USERS)
            or active_transcriptions >= MAX_USERS
        )

        if not at_capacity:
            # Accept the transcription and create/refresh the session.
            active_transcriptions += 1
            now = time.time()
            session = sessions.setdefault(
                session_id,
                {"transcript": [], "last_seen": now}
            )
            session["last_seen"] = now

    if at_capacity:
        # Discovery runs outside the lock so it cannot stall other uploads
        # or the background tasks waiting to release their slot.
        result = await find_empty_space()
        return {
            "status": "full",
            "redirect_to": result["url"]
        }

    # Read audio into RAM
    try:
        audio_bytes = await audio.read()
    except BaseException:
        # No background task will run for this request, so release the slot
        # reserved above instead of leaking it.
        async with transcription_lock:
            active_transcriptions -= 1
        raise

    # Add transcription as background task
    background_tasks.add_task(
        process_transcription,
        session_id,
        chunk_index,
        audio_bytes,
        language
    )

    # Return IMMEDIATELY before transcription starts
    return {
        "status": "processing",
        "message": "Transcription is in progress",
        "chunk_index": chunk_index,
        "session_id": session_id
    }


# -------------------------
# TRANSCRIPT
# -------------------------

@app.get("/transcript/{session_id}")
def transcript(session_id: str):
    """Return the transcript collected so far, ordered by chunk index.

    Returns ``{"error": "not found"}`` for an unknown or expired session.
    """
    # Single lookup: the session may be removed by the cleanup loop between
    # a separate membership test and the subsequent indexing.
    session = sessions.get(session_id)
    if session is None:
        return {"error": "not found"}

    ordered = sorted(
        session["transcript"],
        key=lambda x: x["chunk"]
    )

    return {
        "text": " ".join(x["text"] for x in ordered),
        "chunks": ordered
    }