Spaces:
Runtime error
Runtime error
| import os | |
| import tempfile | |
| import time | |
| import asyncio | |
| from typing import List, Dict, Any, Optional | |
| from concurrent.futures import ThreadPoolExecutor | |
| import torch | |
| from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import uvicorn | |
| from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq | |
| import librosa | |
| import numpy as np | |
| from fastapi.responses import JSONResponse | |
| import gc | |
# --- Service configuration -------------------------------------------------
# Hugging Face checkpoint served by this API.
MODEL_NAME = "nyrahealth/CrisperWhisper"
BATCH_SIZE = 8
# Maximum accepted upload size, in megabytes.
FILE_LIMIT_MB = 30
# Upload extensions accepted by the transcription endpoint.
FILE_EXTENSIONS = [
    ".mp3",
    ".wav",
    ".m4a",
    ".ogg",
    ".flac",
]

# Executor used to run blocking transcription work off the asyncio event
# loop; at most two jobs execute at the same time.
thread_pool = ThreadPoolExecutor(max_workers=2)
# FastAPI application that exposes the transcription endpoints.
app = FastAPI(
    title="Speech to Text API",
    description="API for transcribing audio files using the CrisperWhisper model",
    version="1.0.0",
)

# Allow browser clients on any origin to call the API.
# allow_credentials is deliberately False: no endpoint here reads cookies or
# other credentials, and pairing a wildcard origin with credentials makes
# Starlette echo back the caller's Origin, which would let any website issue
# credentialed cross-origin requests against this service.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Response models
# These pydantic models describe the transcription response schema. They are
# documented with '#' comments rather than docstrings because a class
# docstring would become part of the generated JSON/OpenAPI schema.


# One transcribed unit (a word) together with its timing.
class TranscriptionChunk(BaseModel):
    # Expected to be [start, end] in seconds, rounded to 2 decimals by the
    # transcription code — TODO confirm if other producers are added.
    timestamp: List[float]
    # Text of the word, with surrounding whitespace stripped by the producer.
    text: str


# Full transcription result: the complete transcript plus per-word chunks.
class TranscriptionResponse(BaseModel):
    # Complete transcript text.
    text: str
    # Per-word entries with their timestamps.
    chunks: List[TranscriptionChunk]
# Choose the compute device once, at import time: the GPU when CUDA is
# available, otherwise the CPU. Stored as a plain string ("cuda"/"cpu").
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
# Module-level placeholders for the loaded processor/model. They stay None
# until the startup handler below has run.
processor = None
model = None


# Registered as a startup handler. Without this registration the function
# never runs and `processor`/`model` are never created. `on_event` is
# deprecated in recent FastAPI in favour of lifespan handlers but is still
# supported; switching would require changing the FastAPI(...) construction.
@app.on_event("startup")
async def load_model():
    """Load the CrisperWhisper processor and model into module globals.

    Runs once when the application starts, before requests are served. The
    ``from_pretrained`` calls are synchronous and block the event loop while
    they run. That is acceptable here because no requests are in flight yet.
    """
    global processor, model
    print("Loading model and processor...")
    processor = AutoProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)
    model.to(device)
    print("Model loaded successfully!")
def load_audio(file_path: str) -> tuple:
    """Decode an audio file into a mono float32 waveform at 16 kHz.

    Args:
        file_path: Path of the audio file on disk.

    Returns:
        A ``(audio_array, sampling_rate)`` tuple. ``sampling_rate`` equals
        16000 on success.

    Raises:
        HTTPException: status 500 if decoding, resampling or conversion
            fails for any reason.
    """
    target_rate = 16000
    try:
        # Decode at the file's native rate (sr=None) so resampling work is
        # only done when the file is not already at the target rate.
        waveform, native_rate = librosa.load(file_path, sr=None, mono=True)

        sampling_rate = native_rate
        if native_rate != target_rate:
            waveform = librosa.resample(waveform, orig_sr=native_rate, target_sr=target_rate)
            sampling_rate = target_rate

        # The model's feature extractor expects float32 samples.
        if waveform.dtype != np.float32:
            waveform = waveform.astype(np.float32)
    except Exception as e:
        print(f"Error loading audio: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error loading audio: {str(e)}")
    return waveform, sampling_rate
# Cached ASR pipeline wrapping the globally loaded model/processor.
# Built lazily on first use by _get_asr_pipeline().
_asr_pipeline = None


def _get_asr_pipeline() -> Any:
    """Return a cached word-timestamp ASR pipeline for the loaded model.

    Built on first use from the module-level ``model``/``processor`` and
    rebuilt if ``model`` has been replaced. Two worker threads may race to
    build it the first time. Both would wrap the same model object, so no
    weights are duplicated.

    Raises:
        HTTPException: status 503 if the model has not been loaded yet.
    """
    global _asr_pipeline
    # transformers is already a dependency of this file; pipeline is only needed here.
    from transformers import pipeline

    # globals().get() tolerates both "never assigned" and "still None".
    loaded_model = globals().get("model")
    loaded_processor = globals().get("processor")
    if loaded_model is None or loaded_processor is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet")

    if _asr_pipeline is None or _asr_pipeline.model is not loaded_model:
        # NOTE(review): this instance is shared by the thread_pool workers;
        # HF does not document pipelines as thread-safe — confirm, or run
        # inference on a single worker.
        _asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=loaded_model,
            tokenizer=loaded_processor.tokenizer,
            feature_extractor=loaded_processor.feature_extractor,
            # Whisper only sees 30 s windows. Chunking transcribes longer
            # uploads in full instead of silently truncating them.
            chunk_length_s=30,
            batch_size=BATCH_SIZE,
            # Word-level timestamps via cross-attention alignment.
            return_timestamps="word",
            device=device,
        )
    return _asr_pipeline


def _round_seconds(value: Optional[float]) -> Optional[float]:
    """Round a timestamp to 2 decimals, passing a missing (None) bound through."""
    return None if value is None else round(value, 2)


def process_audio_file(file_path: str) -> Dict:
    """Transcribe an audio file and return its text with per-word timestamps.

    The previous implementation decoded with a CTC-only API
    (``output_word_offsets``), which Whisper's processor does not support, so
    every call failed. It also requested more new tokens than Whisper's
    decoder allows for audio of 10 s or longer. The transformers ASR pipeline
    is the supported way to get word-level timestamps from Whisper models.

    Args:
        file_path: Path of the audio file to transcribe.

    Returns:
        ``{"text": str, "chunks": [{"text": str, "timestamp": [start, end]}, ...]}``
        with times in seconds, rounded to 2 decimals.

    Raises:
        HTTPException: propagated unchanged from ``load_audio``; status 503
            if the model is not loaded; status 500 wrapping any other failure.
    """
    try:
        audio_array, sampling_rate = load_audio(file_path)
        asr = _get_asr_pipeline()
        # The pipeline handles feature extraction, chunked generation and
        # word alignment. It already runs its forward pass without gradients.
        output = asr({"raw": audio_array, "sampling_rate": sampling_rate})
        chunks = [
            {
                "text": chunk["text"].strip(),
                "timestamp": [_round_seconds(t) for t in chunk["timestamp"]],
            }
            for chunk in output.get("chunks", [])
        ]
        return {"text": output["text"].strip(), "chunks": chunks}
    except HTTPException:
        # Already carries the right status and detail. Re-wrapping it would
        # produce a nested, garbled 500 message.
        raise
    except Exception as e:
        print(f"Error processing audio: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}") from e
    finally:
        # Free cached GPU blocks and Python garbage even when inference fails.
        if device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
async def process_in_background(file_path: str):
    """Run ``process_audio_file`` on the shared thread pool and await it.

    Transcription is blocking work, so it is offloaded to ``thread_pool``.
    This keeps the event loop free to serve other requests meanwhile.

    Args:
        file_path: Path of the audio file to transcribe.

    Returns:
        Whatever ``process_audio_file`` returns.

    Raises:
        Any exception raised by ``process_audio_file`` is re-raised here.
    """
    # Inside a coroutine, get_running_loop() is the documented way to obtain
    # the current loop; get_event_loop() has legacy fallback semantics.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(thread_pool, process_audio_file, file_path)
# response_model documents the schema in OpenAPI. The JSONResponse returned
# below is sent as-is: FastAPI does not re-validate a returned Response.
@app.post("/transcribe", response_model=TranscriptionResponse)
async def transcribe_audio(file: UploadFile = File(...)):
    """
    Transcribe an audio file to text with timestamps for each word.
    Accepts .mp3, .wav, .m4a, .ogg or .flac files up to 30MB.
    """
    # The docstring above is rendered as the endpoint description in the
    # API docs, so implementation notes live in comments.
    start_time = time.perf_counter()

    # Validate the extension. An upload without a filename is rejected as
    # unsupported rather than crashing os.path.splitext(None).
    file_ext = os.path.splitext(file.filename or "")[1].lower()
    if file_ext not in FILE_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file format. Supported formats: {', '.join(FILE_EXTENSIONS)}",
        )

    # Enforce the size limit before anything touches the disk, so a rejected
    # upload cannot leave an orphaned temp file behind. Reading one byte past
    # the limit is enough to detect an oversized upload without buffering it
    # all.
    max_bytes = FILE_LIMIT_MB * 1024 * 1024
    content = await file.read(max_bytes + 1)
    if len(content) > max_bytes:
        raise HTTPException(
            status_code=400,
            detail=f"File too large. Maximum size: {FILE_LIMIT_MB}MB",
        )

    temp_file_path = None
    try:
        # delete=False: the worker thread reopens the file by name after this
        # handle is closed. The finally block below removes it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(content)
        del content  # release the upload bytes before inference starts

        # Transcription blocks, so it runs on the thread pool.
        result = await process_in_background(temp_file_path)
        processing_time = time.perf_counter() - start_time
        print(f"Processing completed in {processing_time:.2f} seconds")
        return JSONResponse(content=result)
    finally:
        # Best-effort cleanup. A failed delete is logged, not raised, so it
        # cannot mask the request's real outcome.
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.unlink(temp_file_path)
            except OSError as e:
                print(f"Error deleting temp file: {e}")
# Registered as GET /health, the path advertised by the root endpoint.
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    # Static liveness response; it does not check whether the model is loaded.
    return {"status": "healthy"}
# Simple root endpoint that shows the API is running. It is registered as
# GET / (previously it was defined but never attached to the app).
@app.get("/")
async def root():
    """Describe the API: available endpoints, served model and compute device."""
    return {
        "message": "Speech-to-Text API is running",
        "endpoints": {
            "transcribe": "/transcribe (POST)",
            "health": "/health (GET)",
            "docs": "/docs (GET)"
        },
        "model": MODEL_NAME,
        "device": device
    }
if __name__ == "__main__":
    # Serve the already-constructed app object. The "app:app" import string
    # would make uvicorn import this file a second time as module "app",
    # re-running all module-level setup. It also fails outright unless the
    # file is named app.py. (An import string is only needed for reload or
    # multiple workers.)
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)