import os
import json
import torch
import librosa
import numpy as np
from pathlib import Path
from typing import Dict, Optional, List
from datetime import datetime

from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
from transformers import pipeline

# Configuration (overridable via environment variables)
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "small")
WHISPER_PORT = int(os.getenv("WHISPER_PORT", "8000"))
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

# Lazily initialized pipeline, shared across requests
_whisper_pipeline = None
_model_info = {
    "model_name": WHISPER_MODEL,
    "device": DEVICE,
    "dtype": str(TORCH_DTYPE),
    "cuda_available": torch.cuda.is_available()
}


class TranscriptionResponse(BaseModel):
    text: str
    language: str = "en"
    confidence: Optional[float] = None
    duration: float = 0.0
    timestamp: str = ""


def get_whisper_pipeline():
    """Get or initialize the Whisper pipeline (cached)."""
    global _whisper_pipeline
    if _whisper_pipeline is not None:
        return _whisper_pipeline

    print(f"Loading Whisper model: {WHISPER_MODEL} on {DEVICE} with dtype {TORCH_DTYPE}")

    _whisper_pipeline = pipeline(
        "automatic-speech-recognition",
        model=f"openai/whisper-{WHISPER_MODEL}",
        device=DEVICE,
        torch_dtype=TORCH_DTYPE
    )

    print("Whisper model loaded successfully")
    return _whisper_pipeline


def load_and_resample_audio(audio_path: str, target_sr: int = 16000) -> tuple:
    """Load audio file and resample to 16kHz (required by Whisper)."""
    try:
        audio, sr = librosa.load(audio_path, sr=target_sr, mono=True)
        duration = librosa.get_duration(y=audio, sr=sr)
        print(f"Loaded audio: {Path(audio_path).name} | Duration: {duration:.2f}s | SR: {sr}Hz")
        return audio, sr, duration
    except Exception as e:
        print(f"Error loading audio: {e}")
        raise


async def transcribe_audio(audio_path: str) -> Dict:
    """Transcribe an audio file using Whisper."""
    try:
        audio, sr, duration = load_and_resample_audio(audio_path)
        pipeline_model = get_whisper_pipeline()

        print(f"Transcribing {Path(audio_path).name}...")

        # Chunk long audio into 30 s windows with overlapping strides so the
        # model can stitch chunk boundaries back together; batch more
        # aggressively on GPU.
        result = pipeline_model(
            audio,
            chunk_length_s=30,
            stride_length_s=(4, 2),
            batch_size=24 if torch.cuda.is_available() else 4
        )

        print("Transcription complete")

        return {
            "text": result.get("text", "").strip(),
            "language": "en",
            "confidence": None,
            "duration": duration,
            "timestamp": datetime.now().isoformat()
        }

    except Exception as e:
        print(f"Transcription error: {e}")
        raise


app = FastAPI(
    title="Whisper Transcription Server",
    description="FastAPI server for audio transcription using OpenAI Whisper",
    version="1.0.0"
)


@app.on_event("startup")
async def startup():
    print(f"Whisper Server starting on port {WHISPER_PORT}")
    print("Configuration:")
    print(f"  - Model: {WHISPER_MODEL}")
    print(f"  - Device: {DEVICE}")
    print(f"  - CUDA Available: {torch.cuda.is_available()}")
    print(f"  - Torch Dtype: {TORCH_DTYPE}")

    # Load the model at startup so the first request doesn't pay the load cost
    get_whisper_pipeline()


@app.get("/health")
async def health_check():
    """Check server health and model status."""
    return {
        "status": "healthy",
        "model_info": _model_info,
        "cuda_available": torch.cuda.is_available(),
        "device": DEVICE
    }


@app.get("/")
async def root():
    """Root endpoint with server info."""
    return {
        "server": "Whisper Transcription Backend",
        "model": WHISPER_MODEL,
        "device": DEVICE,
        "endpoints": {
            "/health": "Server health check",
            "/transcribe": "POST - Transcribe audio file",
            "/transcribe_file": "POST - Alternative transcribe endpoint",
            "/transcribe_batch": "POST - Transcribe multiple audio files"
        }
    }


@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """
    Transcribe an uploaded audio file.

    Accepts: mp3, wav, m4a, flac, ogg, aac
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file provided")

    allowed_extensions = {'.mp3', '.wav', '.m4a', '.flac', '.ogg', '.aac'}
    file_ext = Path(file.filename).suffix.lower()

    if file_ext not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file format: {file_ext}. Allowed: {allowed_extensions}"
        )

    temp_file = None
    try:
        # Strip any directory components from the client-supplied filename
        # before using it to build the temp path, and record the path before
        # writing so cleanup runs even if the write fails.
        temp_path = Path(f"temp_{Path(file.filename).name}")
        temp_file = temp_path
        with open(temp_path, 'wb') as f:
            content = await file.read()
            f.write(content)

        print(f"Processing uploaded file: {file.filename} ({len(content)} bytes)")

        result = await transcribe_audio(str(temp_path))

        return {
            "audio_file": file.filename,
            "text": result["text"],
            "language": result["language"],
            "duration": result["duration"],
            "timestamp": result["timestamp"]
        }

    except Exception as e:
        print(f"Transcription failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))

    finally:
        if temp_file and temp_file.exists():
            temp_file.unlink()
            print(f"Cleaned up temp file: {temp_file}")


@app.post("/transcribe_file")
async def transcribe_file(file: UploadFile = File(...)):
    """Alternative endpoint name for transcription."""
    return await transcribe(file)


@app.post("/transcribe_batch")
async def transcribe_batch(files: List[UploadFile] = File(...)):
    """
    Transcribe multiple audio files sequentially, collecting per-file results.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")

    results = []
    for file in files:
        try:
            result = await transcribe(file)
            results.append({
                "status": "success",
                "data": result
            })
        except Exception as e:
            results.append({
                "status": "error",
                "filename": file.filename,
                "error": str(e)
            })

    return {
        "total": len(files),
        "results": results
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=WHISPER_PORT)
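
# Example requests against the endpoints above, as a rough sketch. This
# assumes the server is running on the default port 8000; the audio filenames
# are placeholders. FastAPI's UploadFile expects multipart form data, so each
# file goes in a `-F` field named after the endpoint's parameter ("file" for
# /transcribe, repeated "files" fields for /transcribe_batch):
#
#   curl http://localhost:8000/health
#   curl -X POST -F "file=@recording.mp3" http://localhost:8000/transcribe
#   curl -X POST -F "files=@a.wav" -F "files=@b.wav" \
#       http://localhost:8000/transcribe_batch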