| from fastapi import FastAPI, UploadFile, File, Form, Request, BackgroundTasks |
| from transformers import pipeline |
| from concurrent.futures import ThreadPoolExecutor |
| import asyncio |
| import os |
| import time |
| import httpx |
| import torch |
| import torchaudio |
| from collections import defaultdict |
| import io |
|
|
# FastAPI application object; the routes and the startup hook attach to it.
app = FastAPI()

# Device argument for the Transformers pipeline: GPU index 0 when CUDA is
# available, otherwise -1.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

# Whisper speech-recognition pipeline, built once at import time and shared
# by every request.
asr = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-small",
    device=device,
)

# Thread pool used to run blocking transcription calls off the event loop.
executor = ThreadPoolExecutor(max_workers=2)

# session_id -> {"transcript": [chunk dicts], "last_seen": unix timestamp}
sessions = {}

# Naming scheme of the sibling spaces (see space_url) and how many of them
# are probed per round (see find_empty_space).
USERNAME = "velosts"
BASE = "stt"
BATCH_SIZE = 50

# Capacity limits: concurrent sessions / in-flight transcriptions, and the
# idle time in seconds after which cleanup_sessions drops a session.
MAX_USERS = 2
SESSION_TIMEOUT = 20

# Number of transcriptions currently accepted but not yet finished.
active_transcriptions = 0
transcription_lock = asyncio.Lock()

# Timestamps of recent /langs requests (one-second sliding window) and the
# maximum number of requests allowed inside that window.
langs_request_times = []
langs_lock = asyncio.Lock()
LANGS_RATE_LIMIT = 10000
|
|
async def reset_langs_counter():
    """Prune the /langs request log once per second, forever.

    After each one-second sleep, rebinds the module-level
    ``langs_request_times`` (under ``langs_lock``) to a new list holding
    only the timestamps that are less than one second old.
    """
    global langs_request_times
    while True:
        await asyncio.sleep(1)
        async with langs_lock:
            current = time.time()
            recent = []
            for stamp in langs_request_times:
                if current - stamp < 1.0:
                    recent.append(stamp)
            langs_request_times = recent
|
|
| |
| |
| |
async def cleanup_sessions():
    """Evict idle sessions every five seconds, forever.

    A session is removed from ``sessions`` when its ``last_seen`` timestamp
    is more than ``SESSION_TIMEOUT`` seconds in the past.
    """
    while True:
        current = time.time()
        # Collect first, delete afterwards, so the dict is never mutated
        # while it is being iterated.
        expired = [
            session_id
            for session_id, state in sessions.items()
            if current - state["last_seen"] > SESSION_TIMEOUT
        ]
        for session_id in expired:
            del sessions[session_id]
        await asyncio.sleep(5)
|
|
@app.on_event("startup")
async def startup():
    """Launch the background maintenance loops when the app starts.

    Starts ``cleanup_sessions`` and ``reset_langs_counter`` as asyncio
    tasks on the running event loop.
    """
    # The event loop only keeps weak references to tasks, so a task whose
    # result is discarded may be garbage-collected while still running.
    # Holding the task objects on app.state keeps both loops alive for the
    # lifetime of the application.
    app.state.background_loops = [
        asyncio.create_task(cleanup_sessions()),
        asyncio.create_task(reset_langs_counter()),
    ]
|
|
| |
| |
| |
def space_url(i: int):
    """Return the base URL of the i-th sibling space.

    Index 0 maps to the bare ``BASE`` name; any other index is appended to
    ``BASE`` as a numeric suffix.
    """
    suffix = "" if i == 0 else str(i)
    return f"https://{USERNAME}-{BASE}{suffix}.hf.space"
|
|
| |
| |
| |
async def check(client, url):
    """Probe ``{url}/status`` and classify the answer.

    Args:
        client: object exposing an awaitable ``get(url, timeout=...)``
            (an ``httpx.AsyncClient`` in this file).
        url: base URL of the space to probe.

    Returns:
        A ``(status, url)`` tuple. ``status`` is ``"404"`` or ``"429"`` for
        those HTTP codes, ``"bad"`` for any other non-200 code, the
        ``"status"`` field of the JSON body on 200 (``None`` if absent),
        and ``"error"`` if anything raised along the way.
    """
    passthrough_codes = {404: "404", 429: "429"}
    try:
        response = await client.get(f"{url}/status", timeout=3)
        code = response.status_code
        label = passthrough_codes.get(code)
        if label is not None:
            return label, url
        if code != 200:
            return "bad", url
        payload = response.json()
        return payload.get("status"), url
    except Exception:
        # Best-effort probe: timeouts, connection failures and malformed
        # bodies are all reported as "error" rather than raised.
        return "error", url
|
|
| |
| |
| |
async def find_empty_space(max_batches: int = 20):
    """Search the sibling spaces for one that reports itself as empty.

    Spaces are probed concurrently in batches of ``BATCH_SIZE`` consecutive
    indices (0, 1, 2, ...), using ``check`` on each ``space_url``.

    Args:
        max_batches: upper bound on the number of batches probed before
            giving up. Previously the search had no bound and never
            returned when no space was empty.

    Returns:
        ``{"url": <base url>}`` for the first space in index order whose
        status is ``"empty"``, or ``{"url": None}`` when none was found.
    """
    async with httpx.AsyncClient() as client:
        for batch in range(max_batches):
            first = batch * BATCH_SIZE
            results = await asyncio.gather(
                *(check(client, space_url(j))
                  for j in range(first, first + BATCH_SIZE))
            )
            for status, url in results:
                if status == "empty":
                    return {"url": url}
            # A batch where every probe answered 404 presumably means we
            # have run past the last deployed space, so further batches
            # would only generate more 404s. NOTE(review): confirm spaces
            # are numbered contiguously.
            if all(status == "404" for status, _ in results):
                break
    return {"url": None}
|
|
| |
| |
| |
@app.get("/status")
async def status():
    """Report whether this instance can accept another session.

    Returns ``{"status": "empty", "active": n, "max": MAX_USERS}`` while
    fewer than ``MAX_USERS`` sessions exist. Otherwise looks up another
    space via ``find_empty_space`` and returns
    ``{"status": "full", "redirect_to": <url>}``.
    """
    active = len(sessions)
    if active >= MAX_USERS:
        alternative = await find_empty_space()
        return {"status": "full", "redirect_to": alternative["url"]}
    return {"status": "empty", "active": active, "max": MAX_USERS}
|
|
| |
| |
| |
# Last-resort language table, used only when neither the tokenizer nor the
# model config exposes a language list. Built once instead of per request.
_WHISPER_FALLBACK_LANGUAGES = {
    "en": "English", "zh": "Chinese", "de": "German", "es": "Spanish",
    "ru": "Russian", "ko": "Korean", "fr": "French", "ja": "Japanese",
    "pt": "Portuguese", "tr": "Turkish", "pl": "Polish", "ca": "Catalan",
    "nl": "Dutch", "ar": "Arabic", "sv": "Swedish", "it": "Italian",
    "id": "Indonesian", "hi": "Hindi", "fi": "Finnish", "vi": "Vietnamese",
    "he": "Hebrew", "uk": "Ukrainian", "el": "Greek", "ms": "Malay",
    "cs": "Czech", "ro": "Romanian", "da": "Danish", "hu": "Hungarian",
    "ta": "Tamil", "no": "Norwegian", "th": "Thai", "ur": "Urdu",
    "hr": "Croatian", "bg": "Bulgarian", "lt": "Lithuanian", "la": "Latin",
    "mi": "Maori", "ml": "Malayalam", "cy": "Welsh", "sk": "Slovak",
    "te": "Telugu", "fa": "Persian", "lv": "Latvian", "bn": "Bengali",
    "sr": "Serbian", "az": "Azerbaijani", "sl": "Slovenian", "kn": "Kannada",
    "et": "Estonian", "mk": "Macedonian", "br": "Breton", "eu": "Basque",
    "is": "Icelandic", "hy": "Armenian", "ne": "Nepali", "mn": "Mongolian",
    "bs": "Bosnian", "kk": "Kazakh", "sq": "Albanian", "sw": "Swahili",
    "gl": "Galician", "mr": "Marathi", "pa": "Punjabi", "si": "Sinhala",
    "km": "Khmer", "sn": "Shona", "yo": "Yoruba", "so": "Somali",
    "af": "Afrikaans", "oc": "Occitan", "ka": "Georgian", "be": "Belarusian",
    "tg": "Tajik", "sd": "Sindhi", "gu": "Gujarati", "am": "Amharic",
    "yi": "Yiddish", "lo": "Lao", "uz": "Uzbek", "fo": "Faroese",
    "ht": "Haitian Creole", "ps": "Pashto", "tk": "Turkmen", "nn": "Nynorsk",
    "mt": "Maltese", "sa": "Sanskrit", "lb": "Luxembourgish", "my": "Myanmar",
    "bo": "Tibetan", "tl": "Tagalog", "mg": "Malagasy", "tt": "Tatar",
    "haw": "Hawaiian", "ln": "Lingala", "ha": "Hausa", "ba": "Bashkir",
    "jw": "Javanese", "su": "Sundanese",
}


def _collect_language_codes_and_names():
    """Return parallel ``(codes, names)`` sequences for the loaded model.

    Sources are tried in order: the tokenizer's own language attributes,
    a language-to-id mapping on the model config, then the built-in
    fallback table.
    """
    tokenizer = asr.tokenizer
    if hasattr(tokenizer, "all_language_codes") and hasattr(tokenizer, "all_language_tokens"):
        codes = tokenizer.all_language_codes
        names = [tokenizer.decode([tok]) for tok in tokenizer.all_language_tokens]
        return codes, names

    language_dict = {}
    config = getattr(asr.model, "config", None)
    if hasattr(config, "lang_to_id"):
        language_dict = config.lang_to_id
    elif hasattr(config, "language_to_id"):
        language_dict = config.language_to_id

    if language_dict:
        # No human-readable names available here: the codes double as names.
        codes = list(language_dict.keys())
        return codes, codes

    return (
        list(_WHISPER_FALLBACK_LANGUAGES.keys()),
        list(_WHISPER_FALLBACK_LANGUAGES.values()),
    )


@app.get("/langs")
async def get_langs(request: Request):
    """List the languages the ASR model supports, with a global rate limit.

    At most ``LANGS_RATE_LIMIT`` requests are served per sliding one-second
    window. Requests over the limit are answered with
    ``{"status": "full", "redirect_to": <url>}``, where the URL comes from
    ``find_empty_space``.

    Args:
        request: the incoming request (unused; kept for the route signature).

    Returns:
        ``{"status": "ok", "count": n, "languages": [{"code", "name"}, ...]}``
        with the languages sorted by code.
    """
    now = time.time()

    # Only the bookkeeping happens under the lock. The redirect lookup does
    # network I/O, so it is awaited after the lock is released; otherwise
    # one over-limit request would stall every other /langs call.
    async with langs_lock:
        langs_request_times[:] = [t for t in langs_request_times if now - t < 1.0]
        over_limit = len(langs_request_times) >= LANGS_RATE_LIMIT
        if not over_limit:
            langs_request_times.append(now)

    if over_limit:
        result = await find_empty_space()
        return {
            "status": "full",
            "redirect_to": result["url"]
        }

    language_codes, language_names = _collect_language_codes_and_names()

    languages = [
        {
            "code": code,
            # Fall back to the upper-cased code when no name is available.
            "name": name.strip().title() if name else code.upper()
        }
        for code, name in zip(language_codes, language_names)
    ]
    languages.sort(key=lambda x: x["code"])

    return {
        "status": "ok",
        "count": len(languages),
        "languages": languages
    }
|
|
| |
| |
| |
def _transcribe_via_tempfile(audio_bytes, generate_kwargs):
    """Write the audio to a temporary .wav file, transcribe it by path, delete it."""
    import tempfile

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(audio_bytes)
        path = tmp.name
    try:
        return asr(path, generate_kwargs=dict(generate_kwargs))["text"]
    finally:
        if os.path.exists(path):
            os.remove(path)


def transcribe_file(audio_bytes, language):
    """Transcribe an audio clip and return the recognized text.

    The fast path decodes the bytes in memory with torchaudio, downmixes to
    mono, resamples to 16 kHz and feeds the samples to the ASR pipeline.
    If anything on that path raises, the bytes are written to a temporary
    file and the pipeline is asked to read that file instead; the file is
    removed afterwards.

    Args:
        audio_bytes: raw bytes of the uploaded audio file.
        language: language code to force, or ``"auto"`` / empty to let the
            model decide.

    Returns:
        The ``"text"`` field of the pipeline result.

    Raises:
        Exception: whatever the fallback path raises, if both attempts fail.
    """
    target_rate = 16000

    generate_kwargs = {"task": "transcribe"}
    if language and language != "auto":
        generate_kwargs["language"] = language

    try:
        waveform, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))

        # Average the channels down to a single mono channel.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        if sample_rate != target_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=target_rate
            )
            waveform = resampler(waveform)

        # Drop only the channel axis; a bare squeeze() would also collapse
        # a one-sample clip into a 0-d array.
        audio_array = waveform.squeeze(0).numpy()

        # Each call gets its own copy of the kwargs in case the pipeline
        # mutates the dict it is given.
        return asr(audio_array, generate_kwargs=dict(generate_kwargs))["text"]
    except Exception:
        # Deliberate best-effort: any failure of the in-memory path (e.g. a
        # container torchaudio cannot decode from a buffer) falls back to
        # letting the pipeline load the audio from disk.
        return _transcribe_via_tempfile(audio_bytes, generate_kwargs)
|
|
async def process_transcription(session_id, chunk_index, audio_bytes, language):
    """Transcribe one chunk in the background and store the result.

    Runs ``transcribe_file`` on the shared thread pool, then appends
    ``{"chunk": chunk_index, "text": ...}`` to the session's transcript and
    refreshes its ``last_seen``, provided the session still exists. The
    in-flight counter ``active_transcriptions`` is always decremented at
    the end, whether or not transcription succeeded.
    """
    global active_transcriptions

    try:
        text = await asyncio.get_running_loop().run_in_executor(
            executor, transcribe_file, audio_bytes, language
        )

        # The session may have been evicted while the chunk was processed.
        session = sessions.get(session_id)
        if session is not None:
            session["transcript"].append({"chunk": chunk_index, "text": text})
            session["last_seen"] = time.time()
    finally:
        async with transcription_lock:
            active_transcriptions -= 1
|
|
@app.post("/transcribe")
async def transcribe(
    background_tasks: BackgroundTasks,
    session_id: str = Form(...),
    chunk_index: int = Form(...),
    language: str = Form("auto"),
    audio: UploadFile = File(...)
):
    """Accept one audio chunk and queue it for background transcription.

    The request is refused with ``{"status": "full", "redirect_to": <url>}``
    when it would open a new session while ``MAX_USERS`` sessions already
    exist, or when ``MAX_USERS`` transcriptions are already in flight.
    Otherwise the session is created or refreshed, the upload is read, and
    ``process_transcription`` is scheduled as a background task.

    Args:
        background_tasks: FastAPI background task queue for this request.
        session_id: client-chosen session identifier.
        chunk_index: position of this chunk within the session.
        language: language code, or ``"auto"``.
        audio: the uploaded audio chunk.

    Returns:
        ``{"status": "processing", ...}`` on acceptance, or the "full"
        redirect payload described above.
    """
    global active_transcriptions

    # Capacity is decided and reserved under the lock, with no await inside:
    # the redirect lookup is slow network I/O and must not block other
    # uploads or the counter decrement in process_transcription.
    async with transcription_lock:
        is_new_session = session_id not in sessions
        at_capacity = (
            (is_new_session and len(sessions) >= MAX_USERS)
            or active_transcriptions >= MAX_USERS
        )
        if not at_capacity:
            active_transcriptions += 1
            now = time.time()
            session = sessions.setdefault(
                session_id,
                {"transcript": [], "last_seen": now}
            )
            session["last_seen"] = now

    if at_capacity:
        result = await find_empty_space()
        return {
            "status": "full",
            "redirect_to": result["url"]
        }

    try:
        audio_bytes = await audio.read()
    except BaseException:
        # The slot was reserved above but no background task will run to
        # release it, so give it back here. No await happens between the
        # read and write of the counter, so this is safe without the lock
        # (and stays safe if the request is being cancelled).
        active_transcriptions -= 1
        raise

    background_tasks.add_task(
        process_transcription,
        session_id,
        chunk_index,
        audio_bytes,
        language
    )

    return {
        "status": "processing",
        "message": "Transcription is in progress",
        "chunk_index": chunk_index,
        "session_id": session_id
    }
|
|
| |
| |
| |
@app.get("/transcript/{session_id}")
def transcript(session_id: str):
    """Return the transcript collected so far for a session.

    Returns ``{"error": "not found"}`` for an unknown session. Otherwise
    returns the chunk entries sorted by chunk index under ``"chunks"`` and
    their texts joined with single spaces under ``"text"``.
    """
    session = sessions.get(session_id)
    if session is None:
        return {"error": "not found"}

    chunks = sorted(session["transcript"], key=lambda entry: entry["chunk"])
    joined = " ".join(entry["text"] for entry in chunks)
    return {"text": joined, "chunks": chunks}