import os
import sys
import warnings
import io
import tempfile
import subprocess
from pathlib import Path

# Silence noisy third-party warnings before any heavy imports run.
warnings.filterwarnings('ignore')
os.environ['PYTHONWARNINGS'] = 'ignore'


class SuppressStderr:
    """Temporarily redirect stderr to swallow import-time chatter."""

    def __enter__(self):
        self.original_stderr = sys.stderr
        sys.stderr = io.StringIO()
        return self

    def __exit__(self, *args):
        sys.stderr = self.original_stderr


with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    with SuppressStderr():
        import torch
        import whisper
        import soundfile as sf
        from pyannote.audio import Pipeline
        from fastapi import FastAPI, File, UploadFile, HTTPException
        from fastapi.responses import JSONResponse
        from fastapi.middleware.cors import CORSMiddleware

warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', message='.*torchcodec.*')
warnings.filterwarnings('ignore', message='.*FP16.*')
warnings.filterwarnings('ignore', message='.*degrees of freedom.*')
warnings.filterwarnings('ignore', module='pyannote.audio.core.io')
warnings.filterwarnings('ignore', module='whisper.transcribe')
warnings.filterwarnings('ignore', module='whisper')

# PyTorch 2.6+ defaults torch.load to weights_only=True, which breaks loading
# pyannote/Whisper checkpoints; patch the old behavior back in.
_original_torch_load = torch.load


def _patched_torch_load(*args, **kwargs):
    kwargs['weights_only'] = False
    return _original_torch_load(*args, **kwargs)


torch.load = _patched_torch_load

# Get the HF token from an environment variable (set in the HF Space settings).
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError(
        "HF_TOKEN environment variable is required. "
        "Please set it in your Hugging Face Space settings."
    )

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize FastAPI app
app = FastAPI(
    title="Speaker Diarization & Transcription API",
    description="API for speaker diarization and transcription using pyannote.audio and Whisper",
    version="1.0.0",
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def convert_to_wav(input_path):
    """Convert any audio file to 16 kHz mono WAV using FFmpeg."""
    output_path = input_path + "_converted.wav"
    command = [
        "ffmpeg", "-y",
        "-i", input_path,
        "-ac", "1",      # mono
        "-ar", "16000",  # 16 kHz
        output_path,
    ]
    # check=True surfaces conversion failures here instead of failing later
    # with a confusing "file not found" when the WAV is read.
    subprocess.run(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
    return output_path


# Global model handles, populated on startup
pipeline = None
whisper_model = None


@app.on_event("startup")
async def load_models():
    global pipeline, whisper_model
    print(f"Using device: {device}")

    print("Loading diarization model...")
    with SuppressStderr():
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-community-1",
            token=HF_TOKEN,
        )
        pipeline.to(device)

    print("Loading Whisper small model...")
    with SuppressStderr():
        whisper_model = whisper.load_model("small", device=device)

    print("Models loaded successfully!\n")


def process_audio(audio_path):
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    print(f"Processing: {audio_path}")
    print("Loading audio file...")
    waveform, sample_rate = sf.read(audio_path)
    waveform = torch.from_numpy(waveform).float()

    # pyannote expects a (channel, time) tensor.
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)
    elif waveform.shape[0] > waveform.shape[1]:
        # soundfile returns (frames, channels); transpose to (channels, frames).
        waveform = waveform.T

    audio_dict = {'waveform': waveform, 'sample_rate': sample_rate}

    print("Running speaker diarization...")
    diarization = pipeline(audio_dict)

    print("Running transcription...")
    transcription_result = whisper_model.transcribe(audio_path)

    # For each speaker turn, attach the Whisper segment that overlaps it the
    # most, but only if the overlap covers more than half of the turn.
    # itertracks(yield_label=True) yields (segment, track, speaker-label) triples.
    results = []
    for turn, _, speaker in diarization.speaker_diarization.itertracks(yield_label=True):
        text = ""
        for trans_seg in transcription_result["segments"]:
            if trans_seg["start"] <= turn.end and trans_seg["end"] >= turn.start:
                overlap_start = max(turn.start, trans_seg["start"])
                overlap_end = min(turn.end, trans_seg["end"])
                if overlap_end > overlap_start:
                    if (overlap_end - overlap_start) / (turn.end - turn.start) > 0.5:
                        text = trans_seg["text"].strip()
                        break
        results.append({
            "start": round(turn.start, 2),
            "end": round(turn.end, 2),
            "speaker": speaker,
            "text": text,
        })

    return {
        "segments": results,
        "full_transcription": transcription_result["text"],
    }
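# The dictionary returned by process_audio looks roughly like this (an
# illustrative sketch; the speaker labels such as "SPEAKER_00" and the
# timing values are made up):
#
#   {
#       "segments": [
#           {"start": 0.03, "end": 4.21, "speaker": "SPEAKER_00", "text": "Hello there."},
#           {"start": 4.21, "end": 6.87, "speaker": "SPEAKER_01", "text": "Hi!"},
#       ],
#       "full_transcription": "Hello there. Hi!",
#   }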
@app.get("/")
async def root():
    return {
        "message": "Speaker Diarization & Transcription API",
        "version": "1.0.0",
        "endpoints": {
            "/": "API information",
            "/health": "Health check",
            "/process": "Process audio file (POST)",
        },
    }


@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "device": str(device),
        "models_loaded": pipeline is not None and whisper_model is not None,
    }


@app.post("/process")
async def process_audio_endpoint(file: UploadFile = File(...)):
    if pipeline is None or whisper_model is None:
        raise HTTPException(
            status_code=503,
            detail="Models are still loading. Please try again in a moment.",
        )

    allowed_extensions = {'.wav', '.mp3', '.m4a', '.flac', '.ogg', '.webm'}
    file_ext = Path(file.filename).suffix.lower()
    if file_ext not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type. Allowed: {', '.join(allowed_extensions)}",
        )

    with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
        # Bind the path before reading so the finally block can always clean up.
        tmp_file_path = tmp_file.name
        try:
            content = await file.read()
            tmp_file.write(content)
            # Flush so FFmpeg sees the full file, not a partially buffered one.
            tmp_file.flush()

            # Convert ANY format to WAV, then process the WAV only.
            wav_path = convert_to_wav(tmp_file_path)
            result = process_audio(wav_path)
            return JSONResponse(content=result)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")
        finally:
            if os.path.exists(tmp_file_path):
                os.unlink(tmp_file_path)
            if os.path.exists(tmp_file_path + "_converted.wav"):
                os.unlink(tmp_file_path + "_converted.wav")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
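# Example client call (a sketch, assuming the server is running locally on
# port 7860 and that a file named audio.wav exists in the working directory):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/process",
#       files={"file": open("audio.wav", "rb")},
#   )
#   print(resp.json()["segments"])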