Spaces:
Sleeping
Sleeping
File size: 6,711 Bytes
1600c41 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 |
import asyncio
import functools
import io
import os
import sys
import tempfile
import threading
import warnings
from pathlib import Path
# Silence Python-level warnings for the whole process before anything noisy is
# imported. PYTHONWARNINGS is only read when an interpreter starts, so setting
# it here affects Python subprocesses spawned later, not this process.
warnings.simplefilter('ignore')
os.environ.update(PYTHONWARNINGS='ignore')
class SuppressStderr:
    """Context manager that temporarily swaps ``sys.stderr`` for an in-memory buffer.

    Text written through ``sys.stderr`` inside the ``with`` block lands in a
    throwaway ``io.StringIO`` and is never read back, i.e. it is discarded. The
    previous stream is restored on exit whether or not the block raised, and
    exceptions are not suppressed. Only the Python-level ``sys.stderr`` object
    is replaced; output written straight to file descriptor 2 (e.g. by native
    code) is not intercepted.
    """

    def __enter__(self):
        saved, sys.stderr = sys.stderr, io.StringIO()
        self.original_stderr = saved
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always put the real stream back; returning None lets errors propagate.
        sys.stderr = self.original_stderr
# Import the heavy ML / web dependencies with all warnings ignored and anything
# they write to sys.stderr discarded (SuppressStderr swaps in a StringIO that is
# never read). catch_warnings() restores the warning-filter list when the block
# exits, so the simplefilter("ignore") below applies only while these imports run.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    with SuppressStderr():
        import torch
        import whisper
        import soundfile as sf
        from pyannote.audio import Pipeline
        from fastapi import FastAPI, File, UploadFile, HTTPException
        from fastapi.responses import JSONResponse
        from fastapi.middleware.cors import CORSMiddleware
# Targeted "ignore" filters for warnings known to be noisy in this app. They
# overlap with the blanket 'ignore' filter installed earlier in the module, but
# spell out which warnings are expected. Calls are issued in the same order as
# before: by category, then by message pattern, then by originating module.
for _category in (UserWarning, FutureWarning):
    warnings.filterwarnings('ignore', category=_category)
for _pattern in ('.*torchcodec.*', '.*FP16.*', '.*degrees of freedom.*'):
    warnings.filterwarnings('ignore', message=_pattern)
for _module in ('pyannote.audio.core.io', 'whisper.transcribe', 'whisper'):
    warnings.filterwarnings('ignore', module=_module)
# Don't leave loop variables behind as module globals.
del _category, _pattern, _module
# Monkey-patch torch.load so every call in this process -- including the ones
# made inside pyannote.audio and whisper while loading checkpoints -- runs with
# weights_only=False. Presumably this works around checkpoints that fail to load
# under torch's weights_only=True default (the default since PyTorch 2.6).
#
# SECURITY NOTE(review): weights_only=False lets torch.load unpickle arbitrary
# Python objects, i.e. run code embedded in a checkpoint. The override is
# unconditional (it also replaces an explicit weights_only=True), so only
# checkpoints from trusted sources should ever be loaded in this process.
_original_torch_load = torch.load


@functools.wraps(_original_torch_load)
def _patched_torch_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` forced to False.

    ``functools.wraps`` keeps the original ``__name__``/``__doc__`` and exposes
    the real signature via ``__wrapped__`` to code that introspects torch.load.
    """
    kwargs['weights_only'] = False
    return _original_torch_load(*args, **kwargs)


torch.load = _patched_torch_load
# Hugging Face access token, read from the environment (configured in the Space
# settings). It is handed to Pipeline.from_pretrained when the models load;
# importing this module fails immediately if it is unset or empty.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN in (None, ""):
    raise ValueError("HF_TOKEN environment variable is required. Please set it in your Hugging Face Space settings.")
# Run the models on the CUDA device when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# FastAPI application; title/description/version feed the generated OpenAPI docs.
app = FastAPI(
    version="1.0.0",
    title="Speaker Diarization & Transcription API",
    description="API for speaker diarization and transcription using pyannote.audio and Whisper",
)

# CORS is left fully open so browser front-ends on any origin can call the API.
# NOTE(review): with wildcard origins plus allow_credentials=True, Starlette
# answers credentialed requests by echoing the caller's Origin instead of "*",
# so any site may send cookie-bearing requests. This API does not appear to use
# cookies or auth, so allow_credentials=False is likely sufficient -- confirm.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
# Shared model handles, filled in by the startup hook. None means "not loaded
# yet"; the /health and /process handlers check for that before use.
pipeline = whisper_model = None
@app.on_event("startup")
async def load_models():
    """Load the diarization pipeline and the Whisper model at application startup.

    Both models are placed on the module-level ``device`` and published through
    the ``pipeline`` / ``whisper_model`` globals. The globals are assigned only
    after both models have loaded successfully.

    Raises:
        RuntimeError: if ``Pipeline.from_pretrained`` returns ``None`` instead of
            a pipeline. Startup then fails with a clear message rather than an
            ``AttributeError`` on ``None.to(...)``.
    """
    global pipeline, whisper_model
    print(f"Using device: {device}")
    print("Loading diarization model...")
    with SuppressStderr():
        diarization_pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-community-1",
            token=HF_TOKEN,
        )
    # NOTE(review): pyannote's from_pretrained may return None (rather than
    # raise) when the checkpoint cannot be fetched, e.g. without access to the
    # repository; make that failure explicit.
    if diarization_pipeline is None:
        raise RuntimeError(
            "Could not load the 'pyannote/speaker-diarization-community-1' pipeline; "
            "check that HF_TOKEN is valid and has access to this model."
        )
    diarization_pipeline.to(device)
    print("Loading Whisper small model...")
    with SuppressStderr():
        loaded_whisper = whisper.load_model("small", device=device)
    pipeline = diarization_pipeline
    whisper_model = loaded_whisper
    print("Models loaded successfully!\n")
def _load_waveform(audio_path):
    """Decode *audio_path* into a ``(channels, samples)`` float32 tensor and its sample rate.

    soundfile/libsndfile is tried first and keeps the file's native sample rate
    and channels. Formats libsndfile cannot decode (the /process endpoint also
    accepts e.g. ``.m4a`` and ``.webm``) fall back to Whisper's ffmpeg-based
    loader, which returns mono audio resampled to 16 kHz.
    """
    try:
        # always_2d yields (frames, channels) even for mono files, so a single
        # transpose gives (channels, frames) without any shape heuristics.
        samples, sample_rate = sf.read(audio_path, dtype="float32", always_2d=True)
    except RuntimeError:
        # soundfile reports unreadable formats as RuntimeError
        # (LibsndfileError subclasses it in recent soundfile versions).
        sample_rate = 16000
        samples = whisper.load_audio(audio_path, sr=sample_rate)
        return torch.from_numpy(samples).unsqueeze(0), sample_rate
    return torch.from_numpy(samples.T.copy()), sample_rate


def _text_for_turn(turn_start, turn_end, segments):
    """Join the text of every Whisper segment that belongs to one speaker turn.

    A segment belongs to the turn when their overlap covers more than half of
    the turn (handles turns shorter than a segment) or more than half of the
    segment (lets a long turn collect every segment it contains, instead of
    none of them). Matching segments are joined with single spaces, in order.
    """
    pieces = []
    for seg in segments:
        overlap = min(turn_end, seg["end"]) - max(turn_start, seg["start"])
        if overlap <= 0:
            continue
        # overlap > 0 implies both durations are > 0, so the divisions are safe.
        covers_turn = overlap / (turn_end - turn_start) > 0.5
        inside_turn = overlap / (seg["end"] - seg["start"]) > 0.5
        if covers_turn or inside_turn:
            pieces.append(seg["text"].strip())
    return " ".join(piece for piece in pieces if piece)


def process_audio(audio_path):
    """Run speaker diarization and transcription on one audio file.

    Args:
        audio_path: Path of the audio file on disk.

    Returns:
        dict: ``{"segments": [...], "full_transcription": str}``. Each segment is
        ``{"start", "end", "speaker", "text"}``: the diarization turn bounds
        rounded to 2 decimals, its speaker label, and the text of the Whisper
        segments aligned to that turn ("" when none match).

    Raises:
        FileNotFoundError: if *audio_path* does not exist.
    """
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")
    print(f"Processing: {audio_path}")
    print("Loading audio file...")
    waveform, sample_rate = _load_waveform(audio_path)
    print("Running speaker diarization...")
    diarization = pipeline({"waveform": waveform, "sample_rate": sample_rate})
    print("Running transcription...")
    transcription_result = whisper_model.transcribe(audio_path)
    segments = transcription_result["segments"]
    results = [
        {
            "start": round(turn.start, 2),
            "end": round(turn.end, 2),
            "speaker": speaker,
            "text": _text_for_turn(turn.start, turn.end, segments),
        }
        for turn, speaker in diarization.speaker_diarization
    ]
    return {
        "segments": results,
        "full_transcription": transcription_result["text"],
    }
@app.get("/")
async def root():
    """Root endpoint with API information"""
    # The docstring above doubles as the OpenAPI description, so it stays as-is.
    endpoint_summary = {
        "/": "API information",
        "/health": "Health check",
        "/process": "Process audio file (POST)",
    }
    return dict(
        message="Speaker Diarization & Transcription API",
        version="1.0.0",
        endpoints=endpoint_summary,
    )
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    # Always reports "healthy"; readiness is conveyed separately through
    # models_loaded, which is True only once both model globals are set.
    ready = all(model is not None for model in (pipeline, whisper_model))
    return {"status": "healthy", "device": str(device), "models_loaded": ready}
# The diarization pipeline and Whisper model are single shared instances that
# are not known to be safe for concurrent inference, so /process jobs hold this
# lock while running: requests are processed one at a time, off the event loop.
_inference_lock = threading.Lock()


def _process_audio_locked(audio_path):
    """Run ``process_audio`` on *audio_path* while holding ``_inference_lock``."""
    with _inference_lock:
        return process_audio(audio_path)


@app.post("/process")
async def process_audio_endpoint(file: UploadFile = File(...)):
    """
    Process audio file for speaker diarization and transcription
    Args:
        file: Audio file (wav, mp3, etc.)
    Returns:
        JSON response with segments and full transcription
    """
    # (The docstring above is also the OpenAPI description.)
    # Responses: 503 while models load, 400 for an unsupported extension,
    # 500 with the error text if saving or processing the upload fails.
    if pipeline is None or whisper_model is None:
        raise HTTPException(status_code=503, detail="Models are still loading. Please try again in a moment.")

    # Validate file type. UploadFile.filename may be None when the client sends
    # no filename; treat that as "no extension" so it is rejected with a 400.
    allowed_extensions = {'.wav', '.mp3', '.m4a', '.flac', '.ogg', '.webm'}
    file_ext = Path(file.filename or "").suffix.lower()
    if file_ext not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            # sorted() keeps the message stable; set iteration order is not.
            detail=f"Unsupported file type. Allowed: {', '.join(sorted(allowed_extensions))}",
        )

    # Create the temp file before the try so its path is always bound for the
    # cleanup in `finally`, even if reading the upload fails.
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=file_ext)
    tmp_file_path = tmp_file.name
    try:
        # Close (and thereby flush) the file before it is decoded by path:
        # readers must see every byte, and on Windows a NamedTemporaryFile
        # cannot be reopened by name while it is still open.
        with tmp_file:
            tmp_file.write(await file.read())
        # Inference is long, blocking work; run it in the default thread pool
        # so the event loop keeps serving other requests (e.g. /health).
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, _process_audio_locked, tmp_file_path)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}") from e
    finally:
        # Clean up temporary file
        if os.path.exists(tmp_file_path):
            os.unlink(tmp_file_path)
    return JSONResponse(content=result)
if __name__ == "__main__":
    # Serve on all interfaces, port 7860 (presumably the port the Hugging Face
    # Space expects the app on).
    from uvicorn import run
    run(app, port=7860, host="0.0.0.0")
|