# Whisper / app.py
# Anchal23's picture
# Update app.py
# 2f6fd0b verified
import os
import shutil
import tempfile
import threading
import traceback
from typing import Optional

import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
# Point the Hugging Face / PyTorch caches at /tmp before transformers is
# imported (it is only imported later, inside load_model). Presumably the
# default home-directory cache is not writable in the deployment environment
# — TODO confirm.
# NOTE(review): recent transformers releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME; it is kept here for older versions.
_CACHE_DIRS = {
    "TRANSFORMERS_CACHE": "/tmp/transformers_cache",
    "HF_HOME": "/tmp/hf_home",
    "TORCH_HOME": "/tmp/torch_home",
}
# Create every directory first, then export all variables in one go.
for _cache_path in _CACHE_DIRS.values():
    os.makedirs(_cache_path, exist_ok=True)
os.environ.update(_CACHE_DIRS)
# The FastAPI application; title/description/version are exposed as the
# service's OpenAPI metadata.
app = FastAPI(
    version="1.0.0",
    title="Speech Transcription API",
    description="API for transcribing speech using Whisper model",
)
# Shared model state: load_model() fills these in from its background thread;
# model_loaded flips to True only after transcriber has been assigned.
model_loaded, transcriber = False, None
def load_model():
    """Load the Whisper ASR pipeline; intended to run in a background thread.

    On success, stores the pipeline in the module-level ``transcriber`` and
    then sets ``model_loaded`` to True. On failure, prints the error and its
    full traceback and leaves ``model_loaded`` untouched, so ``/transcribe``
    keeps answering 503.
    """
    global transcriber, model_loaded
    try:
        # Imported here (presumably) so the heavy transformers import happens
        # in the background thread rather than at module import time.
        from transformers import pipeline
        transcriber = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-small",
            # Whisper works on 30-second windows; chunking (as shown in the
            # openai/whisper-* model cards) lets the pipeline transcribe
            # recordings longer than a single window.
            chunk_length_s=30,
        )
        # Set only after `transcriber` is assigned so readers never see
        # model_loaded=True with transcriber=None.
        model_loaded = True
        print("Model loaded successfully!")
    except Exception as e:
        # Best-effort: the server keeps running, but print the traceback so
        # the cause of the failed load is actually diagnosable from the logs.
        print(f"Error loading model: {e}")
        traceback.print_exc()
# Kick off model loading at import time on a daemon thread, so the server can
# answer requests (health checks, 503s) while the model is still loading and
# the thread never blocks interpreter shutdown.
_model_loader = threading.Thread(target=load_model, daemon=True)
_model_loader.start()
@app.get("/")
def read_root():
    """Return a short greeting pointing clients at the /transcribe endpoint."""
    greeting = ("Welcome to the Speech Transcription API. "
                "Use /transcribe endpoint to transcribe audio.")
    return {"message": greeting}
@app.get("/health")
def health_check():
    """Report liveness and whether the Whisper model has finished loading."""
    payload = {"status": "healthy"}
    payload["model_loaded"] = model_loaded
    return payload
# Serializes calls into the model. Inference now runs in worker threads, and
# this keeps at most one transcription on the model at a time — the same
# one-at-a-time behaviour the old event-loop-blocking handler had, without
# freezing every other endpoint meanwhile.
_inference_lock = threading.Lock()


def _save_upload_to_temp(upload_file):
    """Copy a readable file object into a new temp file in /tmp; return its path.

    If copying fails, the partially written temp file is removed before the
    exception propagates, so nothing is leaked on disk.
    """
    temp_file = tempfile.NamedTemporaryFile(delete=False, dir="/tmp")
    try:
        with temp_file:
            shutil.copyfileobj(upload_file, temp_file)
    except Exception:
        os.unlink(temp_file.name)
        raise
    return temp_file.name


def _transcribe_file(path):
    """Run the loaded pipeline on *path* under the inference lock; return its text."""
    with _inference_lock:
        result = transcriber(path)
    return result["text"]


@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...),
                           return_format: Optional[str] = "json"):
    """Transcribe an uploaded audio file with the Whisper pipeline.

    Args:
        file: The uploaded audio. It is accepted by filename extension
            (.wav, .mp3, .m4a, .ogg, case-insensitive).
        return_format: ``"text"`` (case-insensitive) returns the bare
            transcript string; any other value returns
            ``{"transcript": ...}`` as a JSONResponse.

    Raises:
        HTTPException: 503 while the model has not finished loading, 400 for
            a missing or unsupported filename, and 500 if transcription fails.
    """
    if not model_loaded:
        raise HTTPException(
            status_code=503,
            detail="Model is still loading, please try again in a few minutes"
        )
    # Starlette types UploadFile.filename as Optional[str]; a missing name is
    # rejected as unsupported instead of crashing with an AttributeError.
    filename = file.filename or ""
    if not filename.lower().endswith(('.wav', '.mp3', '.m4a', '.ogg')):
        raise HTTPException(
            status_code=400,
            detail="Unsupported file format. Please upload WAV, MP3, M4A, or OGG file."
        )
    # The disk copy and model inference both block. Run them in the threadpool
    # so the event loop keeps serving other requests (e.g. /health) meanwhile.
    temp_path = await run_in_threadpool(_save_upload_to_temp, file.file)
    try:
        transcript = await run_in_threadpool(_transcribe_file, temp_path)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error processing audio: {e}"
        ) from e
    finally:
        # Remove the temp file whether transcription succeeded or failed.
        if os.path.exists(temp_path):
            os.unlink(temp_path)
    if return_format and return_format.lower() == "text":
        # NOTE(review): FastAPI JSON-encodes a returned str, so clients get a
        # quoted JSON string; use PlainTextResponse if raw text/plain is intended.
        return transcript
    return JSONResponse(content={"transcript": transcript})
if __name__ == "__main__":
    # When executed as a script, serve the app with uvicorn on port 7860
    # on all network interfaces.
    uvicorn.run(app, port=7860, host="0.0.0.0")