# stt / app.py
# blaze-aura69's picture
# Update app.py
# 2160ae6 verified
from fastapi import FastAPI, UploadFile, File, Form, Request, BackgroundTasks
from transformers import pipeline
from concurrent.futures import ThreadPoolExecutor
import asyncio
import os
import time
import httpx
import torch
import torchaudio
from collections import defaultdict
import io
app = FastAPI()

# -------------------------
# DEVICE
# -------------------------
# The pipeline is given CUDA device index 0 when a GPU is visible, else -1.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=device,
)

# Thread pool on which the blocking transcription calls are executed.
executor = ThreadPoolExecutor(max_workers=2)

# session_id -> {"transcript": [{"chunk": ..., "text": ...}], "last_seen": <time.time()>}
sessions = {}

USERNAME = "velost"
BASE = "stt"
MAX_USERS = 2
SESSION_TIMEOUT = 20  # seconds without activity before a session is evicted
BATCH_SIZE = 50  # number of spaces probed concurrently per discovery round

# Count of transcriptions accepted by /transcribe that have not finished yet;
# it is only modified while holding transcription_lock.
active_transcriptions = 0
transcription_lock = asyncio.Lock()

# -------------------------
# RATE LIMITING FOR /langs (GLOBAL)
# -------------------------
langs_request_times = []  # time.time() stamps of recent /langs requests (all clients)
langs_lock = asyncio.Lock()
LANGS_RATE_LIMIT = 10000  # max total requests per second (global)
async def reset_langs_counter():
    """Prune the global /langs request log once per second, forever.

    After each one-second sleep, and while holding ``langs_lock``, the global
    ``langs_request_times`` is rebound to a new list holding only the
    timestamps that are less than one second old.
    """
    global langs_request_times
    while True:
        await asyncio.sleep(1)
        async with langs_lock:
            current = time.time()
            langs_request_times = list(
                filter(lambda stamp: current - stamp < 1.0, langs_request_times)
            )
# -------------------------
# CLEANUP
# -------------------------
async def cleanup_sessions():
    """Evict idle sessions in an endless loop, pausing five seconds per pass.

    A session is removed from ``sessions`` once more than ``SESSION_TIMEOUT``
    seconds have elapsed since its ``"last_seen"`` timestamp.
    """
    while True:
        current = time.time()
        # Collect first, delete second, so the dict is never resized while
        # it is being iterated.
        stale = [
            sid
            for sid, data in list(sessions.items())
            if current - data["last_seen"] > SESSION_TIMEOUT
        ]
        for sid in stale:
            del sessions[sid]
        await asyncio.sleep(5)
@app.on_event("startup")
async def startup():
    """Start the background maintenance loops when the application boots.

    Two never-ending coroutines are scheduled: ``cleanup_sessions`` and
    ``reset_langs_counter``. The event loop holds only weak references to
    tasks, so a task whose handle is thrown away may be garbage-collected
    before it finishes. The handles are therefore kept on ``app.state`` for
    as long as the application object lives.
    """
    app.state.background_tasks = [
        asyncio.create_task(cleanup_sessions()),
        asyncio.create_task(reset_langs_counter()),
    ]
# -------------------------
# SPACE URL
# -------------------------
def space_url(i: int):
    """Build the ``hf.space`` URL for the space at index ``i``.

    Index 0 uses the bare ``BASE`` name; any other index is appended to
    ``BASE`` as a suffix.
    """
    suffix = "" if i == 0 else f"{i}"
    return f"https://{USERNAME}-{BASE}{suffix}.hf.space"
# -------------------------
# CHECK SPACE
# -------------------------
async def check(client, url, timeout=3):
    """Probe a space's ``/status`` endpoint and classify the outcome.

    Args:
        client: Object exposing an awaitable ``get(url, timeout=...)``; this
            file passes an ``httpx.AsyncClient``.
        url: Base URL of the space (``/status`` is appended here).
        timeout: Value forwarded as the request timeout. Defaults to 3, the
            value that was previously hard-coded.

    Returns:
        A ``(status, url)`` tuple where ``status`` is:

        * ``"404"`` or ``"429"`` for those HTTP status codes,
        * ``"bad"`` for any other non-200 status code,
        * ``"error"`` if the request or the JSON handling raised,
        * otherwise the ``"status"`` field of the JSON body (``None`` when
          the field is absent).
    """
    try:
        response = await client.get(f"{url}/status", timeout=timeout)
        code = response.status_code
        if code in (404, 429):
            return str(code), url
        if code != 200:
            return "bad", url
        return response.json().get("status"), url
    except Exception:
        # Best-effort probe: any failure simply marks this space as unusable
        # so that discovery can move on to the next candidate.
        return "error", url
# -------------------------
# SPACE DISCOVERY
# -------------------------
async def find_empty_space(max_batches: int = 20):
    """Search the pool of spaces for one that reports itself as ``"empty"``.

    Spaces are probed concurrently in batches of ``BATCH_SIZE`` consecutive
    indices, starting from index 0. The search is bounded so that a request
    handler awaiting it cannot hang forever when no space is free.

    Args:
        max_batches: Maximum number of batches to probe before giving up.

    Returns:
        ``{"url": <space url>}`` for the first empty space found (lowest
        index within the batch that contained it), or ``{"url": None}`` when
        the search ended without finding one.
    """
    # NOTE(review): if this app is itself one of the pool spaces, it also
    # probes its own /status here, which re-enters discovery while this space
    # is full -- confirm and consider a probe-only status mode.
    async with httpx.AsyncClient() as client:
        for batch in range(max_batches):
            start = batch * BATCH_SIZE
            results = await asyncio.gather(
                *(
                    check(client, space_url(index))
                    for index in range(start, start + BATCH_SIZE)
                )
            )
            for status, url in results:
                if status == "empty":
                    return {"url": url}
            if all(status == "404" for status, _ in results):
                # Every index in this batch is missing, so we have walked
                # past the end of the pool; probing further is pointless.
                break
    return {"url": None}
# -------------------------
# STATUS
# -------------------------
@app.get("/status")
async def status():
    """Report whether this space can take another user.

    While fewer than ``MAX_USERS`` sessions exist the space reports
    ``"empty"`` together with the current and maximum counts. Otherwise it
    reports ``"full"`` and includes the URL returned by ``find_empty_space``.
    """
    active = len(sessions)
    if active >= MAX_USERS:
        target = await find_empty_space()
        return {"status": "full", "redirect_to": target["url"]}
    return {"status": "empty", "active": active, "max": MAX_USERS}
# -------------------------
# LANGS ENDPOINT
# -------------------------
# Hard-coded Whisper language list (code -> display name), used only when
# neither the tokenizer nor the model config exposes a language table.
_WHISPER_LANGS = {
    "en": "English", "zh": "Chinese", "de": "German", "es": "Spanish",
    "ru": "Russian", "ko": "Korean", "fr": "French", "ja": "Japanese",
    "pt": "Portuguese", "tr": "Turkish", "pl": "Polish", "ca": "Catalan",
    "nl": "Dutch", "ar": "Arabic", "sv": "Swedish", "it": "Italian",
    "id": "Indonesian", "hi": "Hindi", "fi": "Finnish", "vi": "Vietnamese",
    "he": "Hebrew", "uk": "Ukrainian", "el": "Greek", "ms": "Malay",
    "cs": "Czech", "ro": "Romanian", "da": "Danish", "hu": "Hungarian",
    "ta": "Tamil", "no": "Norwegian", "th": "Thai", "ur": "Urdu",
    "hr": "Croatian", "bg": "Bulgarian", "lt": "Lithuanian", "la": "Latin",
    "mi": "Maori", "ml": "Malayalam", "cy": "Welsh", "sk": "Slovak",
    "te": "Telugu", "fa": "Persian", "lv": "Latvian", "bn": "Bengali",
    "sr": "Serbian", "az": "Azerbaijani", "sl": "Slovenian", "kn": "Kannada",
    "et": "Estonian", "mk": "Macedonian", "br": "Breton", "eu": "Basque",
    "is": "Icelandic", "hy": "Armenian", "ne": "Nepali", "mn": "Mongolian",
    "bs": "Bosnian", "kk": "Kazakh", "sq": "Albanian", "sw": "Swahili",
    "gl": "Galician", "mr": "Marathi", "pa": "Punjabi", "si": "Sinhala",
    "km": "Khmer", "sn": "Shona", "yo": "Yoruba", "so": "Somali",
    "af": "Afrikaans", "oc": "Occitan", "ka": "Georgian", "be": "Belarusian",
    "tg": "Tajik", "sd": "Sindhi", "gu": "Gujarati", "am": "Amharic",
    "yi": "Yiddish", "lo": "Lao", "uz": "Uzbek", "fo": "Faroese",
    "ht": "Haitian Creole", "ps": "Pashto", "tk": "Turkmen", "nn": "Nynorsk",
    "mt": "Maltese", "sa": "Sanskrit", "lb": "Luxembourgish", "my": "Myanmar",
    "bo": "Tibetan", "tl": "Tagalog", "mg": "Malagasy", "tt": "Tatar",
    "haw": "Hawaiian", "ln": "Lingala", "ha": "Hausa", "ba": "Bashkir",
    "jw": "Javanese", "su": "Sundanese",
}


def _language_codes_and_names():
    """Return parallel sequences of language codes and display names.

    Sources are tried in order: language attributes on the pipeline's
    tokenizer, then a ``lang_to_id`` / ``language_to_id`` mapping on the model
    config (codes double as names there), then the hard-coded table.
    """
    tokenizer = asr.tokenizer
    if hasattr(tokenizer, "all_language_codes") and hasattr(
        tokenizer, "all_language_tokens"
    ):
        codes = tokenizer.all_language_codes
        names = [tokenizer.decode([tok]) for tok in tokenizer.all_language_tokens]
        return codes, names

    language_dict = {}
    if hasattr(asr.model, "config"):
        config = asr.model.config
        if hasattr(config, "lang_to_id"):
            language_dict = config.lang_to_id
        elif hasattr(config, "language_to_id"):
            language_dict = config.language_to_id

    if language_dict:
        codes = list(language_dict.keys())
        return codes, codes  # no separate names available: reuse the codes

    return list(_WHISPER_LANGS.keys()), list(_WHISPER_LANGS.values())


@app.get("/langs")
async def get_langs(request: Request):
    """List the languages the Whisper pipeline supports.

    A global sliding one-second window limits the endpoint to
    ``LANGS_RATE_LIMIT`` requests per second across all clients.

    Args:
        request: The incoming request (currently unused).

    Returns:
        ``{"status": "ok", "count": ..., "languages": [{"code", "name"}, ...]}``
        with the languages sorted by code, or
        ``{"status": "full", "redirect_to": <url>}`` when the rate limit has
        been reached.
    """
    now = time.time()
    async with langs_lock:
        # Drop timestamps older than one second, then test the window.
        langs_request_times[:] = [t for t in langs_request_times if now - t < 1.0]
        limited = len(langs_request_times) >= LANGS_RATE_LIMIT
        if not limited:
            langs_request_times.append(now)

    if limited:
        # Discovery performs network requests, so it runs only after the lock
        # is released; otherwise every other /langs request (and the periodic
        # cleanup) would be stalled for the duration of the search.
        result = await find_empty_space()
        return {
            "status": "full",
            "redirect_to": result["url"]
        }

    language_codes, language_names = _language_codes_and_names()
    languages = [
        {
            "code": code,
            "name": name.strip().title() if name else code.upper(),
        }
        for code, name in zip(language_codes, language_names)
    ]
    # Sort by code so the response order is stable.
    languages.sort(key=lambda entry: entry["code"])
    return {
        "status": "ok",
        "count": len(languages),
        "languages": languages
    }
# -------------------------
# TRANSCRIBE
# -------------------------
def transcribe_file(audio_bytes, language):
    """Transcribe an encoded audio clip with the module-level ASR pipeline.

    The clip is first decoded in memory with torchaudio, down-mixed to mono
    and resampled to 16 kHz before being handed to the pipeline as an array.
    If anything on that path raises, the raw bytes are written to a temporary
    ``.wav`` file and the pipeline is given the file path instead; the file
    is removed afterwards.

    Args:
        audio_bytes: Content of the uploaded audio file.
        language: Language forwarded to generation. A falsy value or
            ``"auto"`` leaves the language unspecified.

    Returns:
        The ``"text"`` entry of the pipeline result.

    Raises:
        Exception: Whatever the temp-file fallback raises if it fails too.
    """
    import logging
    import tempfile

    def run_pipeline(audio_input):
        # A fresh kwargs dict per call, so nothing the pipeline might do to
        # it can leak from the in-memory attempt into the fallback attempt.
        generate_kwargs = {"task": "transcribe"}
        if language and language != "auto":
            generate_kwargs["language"] = language
        return asr(audio_input, generate_kwargs=generate_kwargs)["text"]

    try:
        # torchaudio.load returns (waveform, sample_rate).
        waveform, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))
        # Average the channels when there is more than one.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=16000
            )
            waveform = resampler(waveform)
        # Drop only the channel axis: a bare squeeze() would also collapse a
        # one-sample clip into a 0-d array.
        audio_array = waveform.squeeze(0).numpy()
        return run_pipeline(audio_array)
    except Exception:
        # Deliberate best-effort: record why the in-memory path failed, then
        # fall through to the temp-file path below.
        logging.getLogger(__name__).warning(
            "In-memory transcription failed; retrying via temporary file",
            exc_info=True,
        )

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(audio_bytes)
        path = tmp.name
    try:
        return run_pipeline(path)
    finally:
        if os.path.exists(path):
            os.remove(path)
async def process_transcription(session_id, chunk_index, audio_bytes, language):
    """Transcribe one chunk off the event loop and record the result.

    The blocking ``transcribe_file`` call runs on ``executor``. If the session
    still exists when it returns, the text is appended to the session's
    transcript and its ``"last_seen"`` stamp is refreshed. Whether or not
    transcription succeeds, ``active_transcriptions`` is decremented under
    ``transcription_lock``.
    """
    global active_transcriptions
    try:
        text = await asyncio.get_running_loop().run_in_executor(
            executor, transcribe_file, audio_bytes, language
        )
        session = sessions.get(session_id)
        if session is not None:
            session["transcript"].append({"chunk": chunk_index, "text": text})
            session["last_seen"] = time.time()
    finally:
        # Give back the slot that /transcribe reserved for this chunk.
        async with transcription_lock:
            active_transcriptions -= 1
@app.post("/transcribe")
async def transcribe(
    background_tasks: BackgroundTasks,
    session_id: str = Form(...),
    chunk_index: int = Form(...),
    language: str = Form("auto"),
    audio: UploadFile = File(...)
):
    """Accept an audio chunk and schedule its transcription.

    Args:
        background_tasks: Receives the transcription task for this chunk.
        session_id: Client-chosen session identifier; a session is created on
            first use.
        chunk_index: Position of this chunk, used later to order the text.
        language: Language forwarded to the transcriber (``"auto"`` default).
        audio: The uploaded audio file.

    Returns:
        ``{"status": "processing", ...}`` once the chunk is queued, or
        ``{"status": "full", "redirect_to": <url>}`` when the session would
        exceed ``MAX_USERS`` sessions or ``MAX_USERS`` transcriptions are
        already in flight.
    """
    global active_transcriptions

    # Read the upload before reserving a slot: if reading fails or the request
    # is cancelled here, nothing has been counted yet, so the counter cannot
    # leak.
    audio_bytes = await audio.read()

    async with transcription_lock:
        is_new_session = session_id not in sessions
        rejected = (
            (is_new_session and len(sessions) >= MAX_USERS)
            or active_transcriptions >= MAX_USERS
        )
        if not rejected:
            active_transcriptions += 1
            if is_new_session:
                sessions[session_id] = {
                    "transcript": [],
                    "last_seen": time.time()
                }
            sessions[session_id]["last_seen"] = time.time()

    if rejected:
        # Discovery does network I/O, so it runs only after the lock is
        # released; holding the lock would block every other /transcribe
        # request and the decrement in process_transcription.
        result = await find_empty_space()
        return {
            "status": "full",
            "redirect_to": result["url"]
        }

    background_tasks.add_task(
        process_transcription,
        session_id,
        chunk_index,
        audio_bytes,
        language
    )
    # Respond immediately; the transcription task was only queued above.
    return {
        "status": "processing",
        "message": "Transcription is in progress",
        "chunk_index": chunk_index,
        "session_id": session_id
    }
# -------------------------
# TRANSCRIPT
# -------------------------
@app.get("/transcript/{session_id}")
def transcript(session_id: str):
    """Return the transcript collected so far for a session.

    Args:
        session_id: Identifier of the session to read.

    Returns:
        ``{"text": <chunk texts joined by spaces>, "chunks": <chunk list>}``
        with chunks ordered by their ``"chunk"`` index, or
        ``{"error": "not found"}`` when the session does not exist.
    """
    # Single lookup instead of "in" followed by indexing: this handler is a
    # plain ``def``, which FastAPI runs in a worker thread, so the session
    # could be evicted by the cleanup loop between the two steps.
    session = sessions.get(session_id)
    if session is None:
        return {"error": "not found"}
    ordered = sorted(session["transcript"], key=lambda item: item["chunk"])
    return {
        "text": " ".join(item["text"] for item in ordered),
        "chunks": ordered
    }