# app.py — Speaker Diarization & Transcription API (Hugging Face Space)
# Author: Rivalcoder
# NOTE: removed Hugging Face web-UI residue ("Add Files", commit hash, "raw",
# "history blame", file size) that was accidentally pasted above the code and
# would have been a Python syntax error.
import os
import sys
import warnings
import io
import tempfile
from pathlib import Path

# Silence warnings as early as possible — both in this process and (via the
# environment variable) in any subprocess Python interpreters spawned later.
warnings.filterwarnings('ignore')
os.environ['PYTHONWARNINGS'] = 'ignore'
class SuppressStderr:
    """Context manager that temporarily swaps ``sys.stderr`` for an in-memory
    buffer, hiding noisy library output for the duration of the ``with`` block.

    The captured text is discarded; the real stream is always restored on
    exit, even when the body raises.
    """

    def __enter__(self):
        # Remember the real stream so __exit__ can put it back.
        self.original_stderr = sys.stderr
        sys.stderr = io.StringIO()
        return self

    def __exit__(self, *args):
        # Unconditional restore — runs on normal exit and on exceptions alike.
        sys.stderr = self.original_stderr
# Import the heavyweight third-party libraries with both the warnings module
# and stderr suppressed — torch/whisper/pyannote emit a lot of noise at
# import time on some versions.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    with SuppressStderr():
        import torch
        import whisper
        import soundfile as sf
        from pyannote.audio import Pipeline
        from fastapi import FastAPI, File, UploadFile, HTTPException
        from fastapi.responses import JSONResponse
        from fastapi.middleware.cors import CORSMiddleware

# Re-arm targeted warning filters for the lifetime of the process
# (catch_warnings above restored the previous filter state on exit).
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', message='.*torchcodec.*')
warnings.filterwarnings('ignore', message='.*FP16.*')
warnings.filterwarnings('ignore', message='.*degrees of freedom.*')
warnings.filterwarnings('ignore', module='pyannote.audio.core.io')
warnings.filterwarnings('ignore', module='whisper.transcribe')
warnings.filterwarnings('ignore', module='whisper')
# torch >= 2.6 defaults torch.load(weights_only=True), which breaks loading
# of older pyannote/whisper checkpoints that pickle arbitrary classes.
_original_torch_load = torch.load


def _patched_torch_load(*args, **kwargs):
    """Wrapper around ``torch.load`` defaulting ``weights_only`` to False.

    Uses ``setdefault`` so a caller that explicitly passes ``weights_only=...``
    is honored (the previous version clobbered it unconditionally).

    SECURITY NOTE: ``weights_only=False`` allows arbitrary code execution via
    pickle — only load checkpoints from trusted sources (here: model weights
    fetched from Hugging Face by pyannote/whisper).
    """
    kwargs.setdefault('weights_only', False)
    return _original_torch_load(*args, **kwargs)


torch.load = _patched_torch_load
# Get HF token from environment variable (set in HF Space settings).
# Required by pyannote's gated speaker-diarization model; fail fast at import
# time rather than failing later inside the startup hook.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable is required. Please set it in your Hugging Face Space settings.")

# Prefer GPU when available; both pyannote and Whisper are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize FastAPI app
app = FastAPI(
    title="Speaker Diarization & Transcription API",
    description="API for speaker diarization and transcription using pyannote.audio and Whisper",
    version="1.0.0"
)

# Add CORS middleware — fully open so any frontend can call this Space.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# disallowed by the CORS spec; browsers ignore the wildcard when credentials
# are enabled — confirm whether credentialed requests are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model handles; None until the startup hook finishes loading them.
pipeline = None
whisper_model = None
@app.on_event("startup")
async def load_models():
"""Load models on startup"""
global pipeline, whisper_model
print(f"Using device: {device}")
print("Loading diarization model...")
with SuppressStderr():
pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-community-1",
token=HF_TOKEN,
)
pipeline.to(device)
print("Loading Whisper small model...")
with SuppressStderr():
whisper_model = whisper.load_model("small", device=device)
print("Models loaded successfully!\n")
def process_audio(audio_path):
    """Run speaker diarization and transcription over one audio file.

    Args:
        audio_path: Path to an audio file readable by ``soundfile``.

    Returns:
        dict with:
          - "segments": list of {"start", "end", "speaker", "text"} per
            diarization turn; "text" is the Whisper segment whose overlap
            covers more than half of the turn, or "" if none does.
          - "full_transcription": Whisper's transcript of the entire file.

    Raises:
        FileNotFoundError: if ``audio_path`` does not exist.
    """
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")
    print(f"Processing: {audio_path}")
    print("Loading audio file...")
    waveform, sample_rate = sf.read(audio_path)
    waveform = torch.from_numpy(waveform).float()
    # pyannote expects a (channels, samples) tensor.
    if waveform.ndim == 1:
        # Mono: add a leading channel axis.
        waveform = waveform.unsqueeze(0)
    elif waveform.shape[0] > waveform.shape[1]:
        # soundfile returns (frames, channels); frames almost always exceed
        # channel count, so transpose to (channels, frames). Heuristic —
        # breaks only for clips shorter than the channel count.
        waveform = waveform.T
    audio_dict = {
        'waveform': waveform,
        'sample_rate': sample_rate
    }
    print("Running speaker diarization...")
    diarization = pipeline(audio_dict)
    print("Running transcription...")
    # Whisper re-reads the file from disk itself (decoupled from audio_dict).
    transcription_result = whisper_model.transcribe(audio_path)
    results = []
    # NOTE(review): assumes the community-1 pipeline output exposes a
    # .speaker_diarization attribute iterable as (turn, speaker) pairs —
    # confirm against the installed pyannote.audio version.
    for turn, speaker in diarization.speaker_diarization:
        text = ""
        for trans_seg in transcription_result["segments"]:
            # Quick reject: keep only Whisper segments that touch this turn.
            if (trans_seg["start"] <= turn.end and trans_seg["end"] >= turn.start):
                overlap_start = max(turn.start, trans_seg["start"])
                overlap_end = min(turn.end, trans_seg["end"])
                if overlap_end > overlap_start:
                    # Accept the first segment covering >50% of the turn.
                    # (overlap_end > overlap_start implies turn.end > turn.start,
                    # so the division below cannot be by zero.)
                    if (overlap_end - overlap_start) / (turn.end - turn.start) > 0.5:
                        text = trans_seg["text"].strip()
                        break
        results.append({
            "start": round(turn.start, 2),
            "end": round(turn.end, 2),
            "speaker": speaker,
            "text": text
        })
    return {
        "segments": results,
        "full_transcription": transcription_result["text"]
    }
@app.get("/")
async def root():
"""Root endpoint with API information"""
return {
"message": "Speaker Diarization & Transcription API",
"version": "1.0.0",
"endpoints": {
"/": "API information",
"/health": "Health check",
"/process": "Process audio file (POST)"
}
}
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"device": str(device),
"models_loaded": pipeline is not None and whisper_model is not None
}
@app.post("/process")
async def process_audio_endpoint(file: UploadFile = File(...)):
"""
Process audio file for speaker diarization and transcription
Args:
file: Audio file (wav, mp3, etc.)
Returns:
JSON response with segments and full transcription
"""
if pipeline is None or whisper_model is None:
raise HTTPException(status_code=503, detail="Models are still loading. Please try again in a moment.")
# Validate file type
allowed_extensions = {'.wav', '.mp3', '.m4a', '.flac', '.ogg', '.webm'}
file_ext = Path(file.filename).suffix.lower()
if file_ext not in allowed_extensions:
raise HTTPException(
status_code=400,
detail=f"Unsupported file type. Allowed: {', '.join(allowed_extensions)}"
)
# Save uploaded file temporarily
with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
try:
content = await file.read()
tmp_file.write(content)
tmp_file_path = tmp_file.name
# Process audio
result = process_audio(tmp_file_path)
return JSONResponse(content=result)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")
finally:
# Clean up temporary file
if os.path.exists(tmp_file_path):
os.unlink(tmp_file_path)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)