Spaces:
Running
Running
| """ | |
| IndicConformer STT API for Hugging Face Spaces | |
| """ | |
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from transformers import AutoModel | |
| import torch | |
| import librosa | |
| import io | |
| import time | |
| import numpy as np | |
| import asyncio | |
| from concurrent.futures import ThreadPoolExecutor | |
| import os | |
| from huggingface_hub import login | |
# Log in to the Hugging Face Hub when a token is supplied through the
# HF_TOKEN environment variable; otherwise only warn and carry on.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    print("⚠ Warning: HF_TOKEN not found. Model loading may fail for gated repos.")
else:
    login(token=hf_token)
    print("✓ Authenticated with Hugging Face")
# FastAPI application object; the metadata below is what the generated
# API documentation displays.
app = FastAPI(
    title="IndicConformer STT API",
    description="Speech-to-Text API for 22 Indian languages",
    version="1.0",
)
# --- Runtime state -------------------------------------------------------
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Filled in by load_model(); stays None until the model has been loaded.
MODEL = None
# Thread pool on which the blocking model calls are executed.
INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=10)

# --- Audio chunking (values are in seconds) ------------------------------
CHUNK_DURATION = 30
OVERLAP_DURATION = 2

# --- Language codes accepted by the transcription endpoint ---------------
SUPPORTED_LANGUAGES = [
    "as", "bn", "brx", "doi", "gu", "hi",
    "kn", "kok", "ks", "mai", "ml", "mni",
    "mr", "ne", "or", "pa", "sa", "sat",
    "sd", "ta", "te", "ur",
]
@app.on_event("startup")
async def load_model():
    """Load the IndicConformer model on application startup and warm it up.

    Registered as a FastAPI startup handler so that the module-level
    ``MODEL`` is populated before requests are served. The model is moved
    to ``DEVICE`` and called once on random audio so that first-call
    initialisation cost is paid here rather than by the first request.

    ``MODEL`` is only assigned after the warm-up call succeeds, so the
    health endpoint never reports a model that failed its first call.
    """
    global MODEL
    print("Loading IndicConformer model...")
    model = AutoModel.from_pretrained(
        "ai4bharat/indic-conformer-600m-multilingual",
        trust_remote_code=True,
    )
    model = model.to(DEVICE)

    # Warm-up the model with 16000 random samples (one second at the 16 kHz
    # rate the transcription endpoint resamples to).
    print("Warming up model...")
    dummy_audio = torch.randn(1, 16000).to(DEVICE)
    # No gradients are needed for inference; skip autograd bookkeeping.
    with torch.no_grad():
        _ = model(dummy_audio, "hi", "rnnt")

    MODEL = model
    print(f"Model loaded successfully on {DEVICE}")
def split_audio_into_chunks(wav_np, sample_rate=16000, chunk_duration=30, overlap_duration=2):
    """Split audio into overlapping chunks.

    Args:
        wav_np: Sliceable sequence of samples (anything supporting ``len``
            and ``[start:end]`` slicing).
        sample_rate: Samples per second, used to convert durations to
            sample counts and offsets back to seconds.
        chunk_duration: Length of each chunk in seconds.
        overlap_duration: Seconds shared between consecutive chunks. Must
            be non-negative and smaller than ``chunk_duration``.

    Returns:
        A list of dicts with keys ``'audio'`` (the slice), ``'start_time'``
        and ``'end_time'`` (offsets in seconds). Empty input yields ``[]``.
        The last chunk may be shorter than ``chunk_duration``.

    Raises:
        ValueError: If the chunk length is not positive, or the overlap is
            negative or not smaller than the chunk (the window would never
            advance, or would skip audio).
    """
    chunk_samples = int(chunk_duration * sample_rate)
    overlap_samples = int(overlap_duration * sample_rate)
    step_samples = chunk_samples - overlap_samples

    if chunk_samples <= 0:
        raise ValueError("chunk_duration * sample_rate must be positive")
    # A non-positive step would loop forever; a negative overlap would make
    # the step larger than the chunk and silently drop samples in between.
    if overlap_samples < 0 or step_samples <= 0:
        raise ValueError(
            "overlap_duration must be non-negative and smaller than chunk_duration"
        )

    total_samples = len(wav_np)
    chunks = []
    for start in range(0, total_samples, step_samples):
        end = min(start + chunk_samples, total_samples)
        chunks.append({
            'audio': wav_np[start:end],
            'start_time': start / sample_rate,
            'end_time': end / sample_rate,
        })
        # Once a chunk reaches the end of the audio, any further window
        # would contain only samples that are already covered.
        if end >= total_samples:
            break
    return chunks
def merge_transcriptions_smart(transcriptions, max_overlap_words=10):
    """Merge chunk transcriptions with smart overlap removal.

    Consecutive chunks share some audio, so the end of one transcription
    may repeat at the start of the next. For each new chunk the longest run
    of words (at most ``max_overlap_words``) that is both a suffix of the
    text merged so far and a prefix of the new chunk is dropped from the
    new chunk before it is appended.

    Args:
        transcriptions: Sequence of transcription strings in chunk order.
        max_overlap_words: Upper bound on the number of words compared when
            searching for a duplicated boundary.

    Returns:
        The merged text with words separated by single spaces and no
        leading or trailing whitespace. Empty or whitespace-only chunks are
        ignored; an empty input yields ``""``.
    """
    # Keep the merged text as a word list: this avoids re-splitting the
    # growing result for every chunk and guarantees no stray spaces when a
    # chunk is empty or is entirely covered by the overlap.
    merged_words = []
    for text in transcriptions:
        words = text.split()
        if not words:
            continue

        limit = min(len(merged_words), len(words), max_overlap_words)
        # Try the longest candidate first so the largest duplicate wins.
        overlap = next(
            (
                size
                for size in range(limit, 0, -1)
                if merged_words[-size:] == words[:size]
            ),
            0,
        )
        merged_words.extend(words[overlap:])
    return " ".join(merged_words)
def run_inference(wav, language):
    """Run model inference on one audio tensor.

    Calls the global ``MODEL`` with the ``"rnnt"`` decoder argument. When
    ``DEVICE`` is ``'cuda'`` the device is synchronised before and after
    the call, so the function returns only once GPU work has finished.

    Args:
        wav: Audio tensor passed straight to the model.
        language: Language code passed straight to the model.

    Returns:
        Whatever the model call returns for this input.

    Raises:
        RuntimeError: If the model has not been loaded yet.
    """
    if MODEL is None:
        # Fail with a clear message instead of "'NoneType' is not callable".
        raise RuntimeError("Model is not loaded")

    on_gpu = DEVICE == 'cuda'
    if on_gpu:
        torch.cuda.synchronize()
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        transcription = MODEL(wav, language, "rnnt")
    if on_gpu:
        torch.cuda.synchronize()
    return transcription
async def process_chunk(chunk_data, language, loop):
    """Transcribe one chunk by running inference on the executor.

    Args:
        chunk_data: Dict whose ``'audio'`` entry holds the chunk's samples.
        language: Language code forwarded to ``run_inference``.
        loop: Event loop used to schedule the blocking call on
            ``INFERENCE_EXECUTOR``.

    Returns:
        The value returned by ``run_inference`` for this chunk.
    """
    samples = chunk_data['audio']
    # unsqueeze(0) adds a leading dimension of size 1 in front of the samples.
    batch = torch.tensor(samples).unsqueeze(0)
    if DEVICE == 'cuda':
        batch = batch.to(DEVICE)
    return await loop.run_in_executor(
        INFERENCE_EXECUTOR, run_inference, batch, language
    )
@app.get("/")
async def root():
    """Root endpoint with API information.

    Returns a static description of the service: model and decoder names,
    chunking configuration, supported language codes, the device in use
    and the paths of the other endpoints.
    """
    return {
        "message": "IndicConformer STT API",
        "version": "1.0",
        "model": "ai4bharat/indic-conformer-600m-multilingual",
        "decoder": "RNNT",
        "parallel_workers": 10,
        "chunk_processing": True,
        "chunk_duration": CHUNK_DURATION,
        "overlap_duration": OVERLAP_DURATION,
        "max_audio_duration": "30 minutes",
        "supported_languages": SUPPORTED_LANGUAGES,
        "device": DEVICE,
        "endpoints": {
            "transcribe": "/transcribe",
            "health": "/health",
            "docs": "/docs",
        },
    }
@app.get("/health")
async def health():
    """Health check endpoint.

    Reports whether the global ``MODEL`` has been populated, together with
    the device in use and the parallelism settings.
    """
    return {
        "status": "healthy",
        "model_loaded": MODEL is not None,
        "device": DEVICE,
        "parallel_enabled": True,
        "max_workers": 10,
    }
@app.post("/transcribe")
async def transcribe_audio(
    file: UploadFile = File(...),
    language: str = Form(default="hi")
):
    """
    Transcribe audio file (supports up to 30 minutes)

    Parameters:
    - file: Audio file (WAV, MP3, FLAC, M4A)
    - language: Language code (hi=Hindi, te=Telugu, bn=Bengali, etc.)

    Returns:
    - transcription: Transcribed text
    - metadata: Processing information

    Errors:
    - 400: unsupported file extension or language, empty audio, or audio
      longer than 30 minutes
    - 500: any other failure while decoding or transcribing
    """
    sample_rate = 16000
    try:
        # Validate file format. The filename may be missing, and the
        # extension check should not depend on letter case.
        filename = (file.filename or "").lower()
        if not filename.endswith(('.wav', '.mp3', '.flac', '.m4a')):
            raise HTTPException(
                status_code=400,
                detail="Invalid file format. Supported: WAV, MP3, FLAC, M4A"
            )

        # Validate language
        if language not in SUPPORTED_LANGUAGES:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported language: {language}. Supported: {', '.join(SUPPORTED_LANGUAGES)}"
            )

        # Read and decode the audio. Decoding/resampling is blocking work,
        # so run it on the default executor to keep the event loop free.
        audio_bytes = await file.read()
        loop = asyncio.get_running_loop()
        wav_np, _ = await loop.run_in_executor(
            None,
            lambda: librosa.load(io.BytesIO(audio_bytes), sr=sample_rate, mono=True),
        )
        audio_duration = len(wav_np) / sample_rate
        print(f"Processing audio: {audio_duration:.2f}s ({audio_duration/60:.1f} minutes)")

        # Reject empty audio up front; it would otherwise cause a division
        # by zero when the real-time factor is computed.
        if len(wav_np) == 0:
            raise HTTPException(
                status_code=400,
                detail="Audio file contains no samples"
            )

        # Check duration limit
        if audio_duration > 1800:  # 30 minutes
            raise HTTPException(
                status_code=400,
                detail=f"Audio too long: {audio_duration/60:.1f} minutes. Maximum: 30 minutes"
            )

        # Split audio into chunks
        chunks = split_audio_into_chunks(
            wav_np,
            sample_rate=sample_rate,
            chunk_duration=CHUNK_DURATION,
            overlap_duration=OVERLAP_DURATION
        )
        print(f"Split into {len(chunks)} chunks")

        # Process chunks in parallel; gather() keeps results in chunk order.
        start_time = time.time()
        tasks = [process_chunk(chunk, language, loop) for chunk in chunks]
        chunk_transcriptions = await asyncio.gather(*tasks)
        inference_time = time.time() - start_time
        # Real-time factor: processing seconds per second of audio.
        rtf = inference_time / audio_duration

        # Merge transcriptions
        full_transcription = merge_transcriptions_smart(chunk_transcriptions)
        print(f"Completed in {inference_time:.2f}s (RTF: {rtf:.4f})")

        return JSONResponse({
            "success": True,
            "transcription": full_transcription,
            "metadata": {
                "audio_duration": round(audio_duration, 2),
                "audio_duration_minutes": round(audio_duration / 60, 2),
                "inference_time": round(inference_time, 4),
                "rtf": round(rtf, 4),
                "language": language,
                "decoder": "rnnt",
                "num_chunks": len(chunks)
            }
        })
    except HTTPException:
        # Deliberate client errors raised above pass through unchanged.
        raise
    except Exception as e:
        print(f"Error: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Transcription failed: {str(e)}"
        ) from e
if __name__ == "__main__":
    # Only needed when the module is executed directly as a script.
    import uvicorn

    # Serve the app on all network interfaces, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)