"""FastAPI server that exposes OpenAI Whisper (via the Hugging Face transformers
pipeline) for audio transcription."""

import os
import json
import torch
import librosa
import numpy as np
from pathlib import Path
from typing import Dict, Optional, List
from datetime import datetime

from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
from transformers import pipeline
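
# Configuration is read from environment variables. WHISPER_MODEL names one of the
# openai/whisper-* checkpoints on the Hugging Face Hub (tiny, base, small, medium, ...).
# Hedged launch example (the filename whisper_server.py is a hypothetical placeholder):
#   WHISPER_MODEL=medium WHISPER_PORT=9000 python whisper_server.py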
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "small")
WHISPER_PORT = int(os.getenv("WHISPER_PORT", 8000))
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

# Lazily-initialized ASR pipeline (loaded on first use) and static metadata
# reported by the /health endpoint.
_whisper_pipeline = None
_model_info = {
    "model_name": WHISPER_MODEL,
    "device": DEVICE,
    "dtype": str(TORCH_DTYPE),
    "cuda_available": torch.cuda.is_available()
}


class TranscriptionResponse(BaseModel):
    text: str
    language: str = "en"
    confidence: Optional[float] = None
    duration: float = 0.0
    timestamp: str = ""


def get_whisper_pipeline():
    """Get or initialize the Whisper pipeline (cached)."""
    global _whisper_pipeline
    if _whisper_pipeline is not None:
        return _whisper_pipeline

    print(f"Loading Whisper model: {WHISPER_MODEL} on {DEVICE} with dtype {TORCH_DTYPE}")

    _whisper_pipeline = pipeline(
        "automatic-speech-recognition",
        model=f"openai/whisper-{WHISPER_MODEL}",
        device=DEVICE,
        torch_dtype=TORCH_DTYPE
    )

    print("Whisper model loaded successfully")
    return _whisper_pipeline


def load_and_resample_audio(audio_path: str, target_sr: int = 16000) -> tuple:
    """Load an audio file and resample it to 16 kHz mono (required by Whisper)."""
    try:
        audio, sr = librosa.load(audio_path, sr=target_sr, mono=True)
        duration = librosa.get_duration(y=audio, sr=sr)
        print(f"Loaded audio: {Path(audio_path).name} | Duration: {duration:.2f}s | SR: {sr}Hz")
        return audio, sr, duration
    except Exception as e:
        print(f"Error loading audio: {e}")
        raise


async def transcribe_audio(audio_path: str) -> Dict:
    """Transcribe an audio file using the cached Whisper pipeline."""
    try:
        audio, sr, duration = load_and_resample_audio(audio_path)

        pipeline_model = get_whisper_pipeline()

        print(f"Transcribing {Path(audio_path).name}...")

        # Long-form transcription: process the audio in 30 s chunks with a small
        # overlap (stride) so words at chunk boundaries are not truncated.
        result = pipeline_model(
            audio,
            chunk_length_s=30,
            stride_length_s=(4, 2),
            batch_size=24 if torch.cuda.is_available() else 4
        )

        print("Transcription complete")

        return {
            "text": result.get("text", "").strip(),
            "language": "en",
            "confidence": None,
            "duration": duration,
            "timestamp": datetime.now().isoformat()
        }

    except Exception as e:
        print(f"Transcription error: {e}")
        raise


app = FastAPI(
    title="Whisper Transcription Server",
    description="FastAPI server for audio transcription using OpenAI Whisper",
    version="1.0.0"
)


@app.on_event("startup")
async def startup():
    print(f"Whisper Server starting on port {WHISPER_PORT}")
    print("Configuration:")
    print(f" - Model: {WHISPER_MODEL}")
    print(f" - Device: {DEVICE}")
    print(f" - CUDA Available: {torch.cuda.is_available()}")
    print(f" - Torch Dtype: {TORCH_DTYPE}")

    # Warm up the model at startup so the first request does not pay the load cost.
    get_whisper_pipeline()


@app.get("/health")
async def health_check():
    """Check server health and model status."""
    return {
        "status": "healthy",
        "model_info": _model_info,
        "cuda_available": torch.cuda.is_available(),
        "device": DEVICE
    }
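
# Quick liveness check (a sketch; assumes the default port on the local machine):
#   curl http://localhost:8000/health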


@app.get("/")
async def root():
    """Root endpoint with server info."""
    return {
        "server": "Whisper Transcription Backend",
        "model": WHISPER_MODEL,
        "device": DEVICE,
        "endpoints": {
            "/health": "Server health check",
            "/transcribe": "POST - Transcribe audio file",
            "/transcribe_file": "POST - Alternative transcribe endpoint",
            "/transcribe_batch": "POST - Transcribe multiple audio files"
        }
    }


@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """
    Transcribe an uploaded audio file.
    Accepts: mp3, wav, m4a, flac, ogg, aac
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file provided")

    allowed_extensions = {'.mp3', '.wav', '.m4a', '.flac', '.ogg', '.aac'}
    file_ext = Path(file.filename).suffix.lower()

    if file_ext not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file format: {file_ext}. Allowed: {allowed_extensions}"
        )

    temp_file = None
    try:
        # Write the upload to a temporary file so librosa can load it from disk.
        temp_path = Path(f"temp_{file.filename}")
        with open(temp_path, 'wb') as f:
            content = await file.read()
            f.write(content)

        temp_file = temp_path

        print(f"Processing uploaded file: {file.filename} ({len(content)} bytes)")

        result = await transcribe_audio(str(temp_path))

        return {
            "audio_file": file.filename,
            "text": result["text"],
            "language": result["language"],
            "duration": result["duration"],
            "timestamp": result["timestamp"]
        }

    except Exception as e:
        print(f"Transcription failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))

    finally:
        # Remove the temporary file regardless of success or failure.
        if temp_file and temp_file.exists():
            temp_file.unlink()
            print(f"Cleaned up temp file: {temp_file}")
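
# Example request (a sketch; sample.wav is a hypothetical file and the default
# port is assumed):
#   curl -X POST -F "file=@sample.wav" http://localhost:8000/transcribe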


@app.post("/transcribe_file")
async def transcribe_file(file: UploadFile = File(...)):
    """Alternative endpoint name for transcription."""
    return await transcribe(file)


@app.post("/transcribe_batch")
async def transcribe_batch(files: List[UploadFile] = File(...)):
    """Transcribe multiple uploaded audio files, one at a time."""
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")

    results = []
    for file in files:
        try:
            result = await transcribe(file)
            results.append({
                "status": "success",
                "data": result
            })
        except Exception as e:
            results.append({
                "status": "error",
                "filename": file.filename,
                "error": str(e)
            })

    return {
        "total": len(files),
        "results": results
    }
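
# Example batch request (a sketch; repeat the "files" form field once per upload,
# filenames and port are assumptions):
#   curl -X POST -F "files=@a.wav" -F "files=@b.wav" http://localhost:8000/transcribe_batch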


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=WHISPER_PORT)
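
# Alternative launch via the uvicorn CLI (a sketch; "whisper_server" is a hypothetical
# module name for this file):
#   uvicorn whisper_server:app --host 0.0.0.0 --port 8000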