| from fastapi import FastAPI, File, UploadFile, HTTPException |
| from fastapi.responses import JSONResponse |
| import os |
| import tempfile |
| import shutil |
| from typing import Optional |
| import threading |
| import uvicorn |
|
|
# Create the writable cache directories under /tmp before anything tries to
# use them.
for _cache_dir in ("/tmp/transformers_cache", "/tmp/hf_home", "/tmp/torch_home"):
    os.makedirs(_cache_dir, exist_ok=True)


# Point the model libraries' cache environment variables at those
# directories. This runs at import time, before load_model() imports
# transformers.
os.environ.update(
    TRANSFORMERS_CACHE="/tmp/transformers_cache",
    HF_HOME="/tmp/hf_home",
    TORCH_HOME="/tmp/torch_home",
)
|
|
# Application object that the route decorators below register against.
# title/description/version are the API metadata handed to FastAPI.
app = FastAPI(
    title="Speech Transcription API",
    description="API for transcribing speech using Whisper model",
    version="1.0.0"
)


# Shared model state: written by load_model(), read by the endpoints.
# model_loaded stays False until the pipeline has been assigned to
# transcriber, so a True flag implies transcriber is usable.
model_loaded = False
transcriber = None
|
|
def load_model():
    """Create the Whisper ASR pipeline and publish it through module globals.

    On success, ``transcriber`` holds the pipeline and ``model_loaded`` is
    set to True (in that order, so the flag never precedes the object).
    On any failure the error is printed and both globals are left as they
    were; nothing is raised to the caller.
    """
    global transcriber, model_loaded
    try:
        # Imported here rather than at module top, so the import cost is paid
        # by whichever thread runs this function.
        from transformers import pipeline

        asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-small",
        )
        transcriber = asr_pipeline
        model_loaded = True
        print("Model loaded successfully!")
    except Exception as err:
        print(f"Error loading model: {err}")
|
|
# Start loading the model at import time on a daemon thread, so importing
# this module does not wait for load_model() to finish.
threading.Thread(target=load_model, daemon=True).start()
|
|
@app.get("/")
def read_root():
    """Return the welcome message that points clients at /transcribe."""
    welcome = "Welcome to the Speech Transcription API. Use /transcribe endpoint to transcribe audio."
    return {"message": welcome}
|
|
@app.get("/health")
def health_check():
    """Report service status together with the model-loading flag.

    ``status`` is always "healthy"; ``model_loaded`` mirrors the module-level
    flag set by load_model().
    """
    report = {"status": "healthy"}
    report["model_loaded"] = model_loaded
    return report
|
|
@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...),
                           return_format: Optional[str] = "json"):
    """Transcribe an uploaded audio file with the loaded pipeline.

    The upload is copied to a temporary file under /tmp, passed to the
    module-level ``transcriber`` by path, and the temporary file is removed
    again whether or not transcription succeeds.

    Args:
        file: Uploaded audio. Its filename must end in .wav, .mp3, .m4a or
            .ogg (case-insensitive); a missing filename is rejected the same
            way as an unsupported one.
        return_format: "text" (any casing) returns the bare transcript
            string; any other value returns ``{"transcript": ...}``.

    Returns:
        The transcript ``str`` for the "text" format, otherwise a
        ``JSONResponse`` with a ``transcript`` key.

    Raises:
        HTTPException: 503 while the model is not loaded, 400 for a missing
            or unsupported filename, 500 if saving or transcribing the audio
            fails.
    """
    # Both globals are only read here, so no ``global`` declaration is needed.
    if not model_loaded:
        raise HTTPException(
            status_code=503,
            detail="Model is still loading, please try again in a few minutes"
        )

    # An upload can arrive without a filename; treat that as unsupported
    # instead of failing on None.lower().
    filename = file.filename or ""
    if not filename.lower().endswith(('.wav', '.mp3', '.m4a', '.ogg')):
        raise HTTPException(
            status_code=400,
            detail="Unsupported file format. Please upload WAV, MP3, M4A, or OGG file."
        )

    temp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, dir="/tmp") as temp_file:
            # Record the path before copying so a failed copy is still
            # cleaned up (delete=False means nothing removes it for us).
            temp_path = temp_file.name
            shutil.copyfileobj(file.file, temp_file)

        # NOTE(review): this is a synchronous call inside an async handler,
        # so the event loop is blocked for the whole inference. Consider
        # offloading to a worker thread once it is confirmed the pipeline is
        # safe to call concurrently.
        result = transcriber(temp_path)
        transcript = result["text"]
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing audio: {str(e)}"
        ) from e
    finally:
        # Single cleanup point for both the success and the failure path.
        if temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)

    # NOTE(review): the "text" branch returns a plain str and leaves the
    # response encoding to the framework's default response handling.
    if return_format and return_format.lower() == "text":
        return transcript
    return JSONResponse(content={"transcript": transcript})
|
|
# Script entry point: when executed directly, serve the app with uvicorn on
# all interfaces (0.0.0.0), port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)