Spaces:
Sleeping
Sleeping
# Standard-library imports.
import os
import sys
import warnings
import io
import tempfile
from pathlib import Path

# Silence all Python warnings process-wide; the env var extends the
# suppression to child processes. Keeps the Space logs readable.
warnings.filterwarnings('ignore')
os.environ['PYTHONWARNINGS'] = 'ignore'
class SuppressStderr:
    """Context manager that temporarily swaps ``sys.stderr`` for a StringIO.

    Everything written to stderr inside the ``with`` block is captured in
    memory and silently discarded when the block exits.
    """

    def __enter__(self):
        # Stash the real stream and install the in-memory replacement in one step.
        self.original_stderr, sys.stderr = sys.stderr, io.StringIO()
        return self

    def __exit__(self, *exc_info):
        # Always put the real stderr back; exceptions still propagate.
        sys.stderr = self.original_stderr
# Import the heavy ML/web dependencies with both warnings and stderr
# suppressed: torch/whisper/pyannote emit noisy deprecation chatter at
# import time.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    with SuppressStderr():
        import torch
        import whisper
        import soundfile as sf
        from pyannote.audio import Pipeline
        from fastapi import FastAPI, File, UploadFile, HTTPException
        from fastapi.responses import JSONResponse
        from fastapi.middleware.cors import CORSMiddleware

# catch_warnings() restored the filter state on exit, so re-apply targeted
# filters for warnings these libraries raise at runtime.
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', message='.*torchcodec.*')
warnings.filterwarnings('ignore', message='.*FP16.*')
warnings.filterwarnings('ignore', message='.*degrees of freedom.*')
warnings.filterwarnings('ignore', module='pyannote.audio.core.io')
warnings.filterwarnings('ignore', module='whisper.transcribe')
warnings.filterwarnings('ignore', module='whisper')
# Monkey-patch torch.load so checkpoint loading defaults to the legacy
# full-unpickling behavior (presumably to work around torch's newer
# weights_only default rejecting pyannote/whisper checkpoints — confirm
# against the installed torch version).
_original_torch_load = torch.load


def _patched_torch_load(*args, **kwargs):
    """``torch.load`` wrapper that defaults ``weights_only`` to ``False``.

    Uses ``setdefault`` so a caller that explicitly passes ``weights_only``
    keeps its choice instead of being silently overridden.

    NOTE: ``weights_only=False`` unpickles arbitrary objects — only load
    trusted checkpoint files.
    """
    kwargs.setdefault('weights_only', False)
    return _original_torch_load(*args, **kwargs)


torch.load = _patched_torch_load
# Get HF token from environment variable (set in HF Space settings)
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    # Fail fast at import: the gated pyannote pipeline cannot be fetched without it.
    raise ValueError("HF_TOKEN environment variable is required. Please set it in your Hugging Face Space settings.")

# Prefer GPU when available; both models are moved to this device at startup.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize FastAPI app
app = FastAPI(
    title="Speaker Diarization & Transcription API",
    description="API for speaker diarization and transcription using pyannote.audio and Whisper",
    version="1.0.0"
)

# Add CORS middleware
# NOTE(review): browsers reject the combination of allow_origins=["*"] with
# allow_credentials=True per the CORS spec — confirm whether credentialed
# requests are actually needed, or pin the allowed origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global model handles, populated by load_models() at startup; endpoints
# treat None as "still loading".
pipeline = None
whisper_model = None


@app.on_event("startup")
async def load_models():
    """Load models on startup.

    Registered as a FastAPI startup hook — without this registration the
    function is never invoked, the globals stay None, and /process would
    permanently answer 503. Populates the module-level ``pipeline`` and
    ``whisper_model`` globals.
    """
    global pipeline, whisper_model
    print(f"Using device: {device}")
    print("Loading diarization model...")
    with SuppressStderr():
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-community-1",
            token=HF_TOKEN,
        )
        pipeline.to(device)
    print("Loading Whisper small model...")
    with SuppressStderr():
        whisper_model = whisper.load_model("small", device=device)
    print("Models loaded successfully!\n")
def process_audio(audio_path):
    """Run speaker diarization and Whisper transcription on one audio file.

    Args:
        audio_path: Path to an audio file readable by soundfile.

    Returns:
        dict with:
            "segments": list of {"start", "end", "speaker", "text"} dicts,
                one per diarization turn (times rounded to 2 decimals);
            "full_transcription": Whisper's full transcript text.

    Raises:
        FileNotFoundError: if ``audio_path`` does not exist.
    """
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")
    print(f"Processing: {audio_path}")
    print("Loading audio file...")
    # soundfile yields (frames,) for mono or (frames, channels) otherwise.
    waveform, sample_rate = sf.read(audio_path)
    waveform = torch.from_numpy(waveform).float()
    if waveform.ndim == 1:
        # Mono: add a channel axis -> (1, frames).
        waveform = waveform.unsqueeze(0)
    elif waveform.shape[0] > waveform.shape[1]:
        # (frames, channels) -> (channels, frames) layout for pyannote.
        waveform = waveform.T
    audio_dict = {
        'waveform': waveform,
        'sample_rate': sample_rate
    }
    print("Running speaker diarization...")
    diarization = pipeline(audio_dict)
    print("Running transcription...")
    transcription_result = whisper_model.transcribe(audio_path)
    results = []
    # NOTE(review): assumes iterating .speaker_diarization yields
    # (segment, speaker_label) pairs — confirm against the installed
    # pyannote.audio version.
    for turn, speaker in diarization.speaker_diarization:
        results.append({
            "start": round(turn.start, 2),
            "end": round(turn.end, 2),
            "speaker": speaker,
            "text": _match_transcript_text(turn, transcription_result["segments"]),
        })
    return {
        "segments": results,
        "full_transcription": transcription_result["text"]
    }


def _match_transcript_text(turn, segments):
    """Return the text of the first Whisper segment covering >50% of *turn*.

    Guards against zero-length turns, which previously raised
    ZeroDivisionError in the overlap-ratio computation. Returns "" when no
    segment overlaps enough.
    """
    duration = turn.end - turn.start
    if duration <= 0:
        return ""
    for seg in segments:
        # Quick reject: segment must intersect the turn at all.
        if seg["start"] <= turn.end and seg["end"] >= turn.start:
            overlap = min(turn.end, seg["end"]) - max(turn.start, seg["start"])
            if overlap > 0 and overlap / duration > 0.5:
                return seg["text"].strip()
    return ""
@app.get("/")
async def root():
    """Root endpoint with API information.

    Registered at "/" — without the route decorator this handler is never
    reachable even though it advertises itself in its own endpoint listing.
    """
    return {
        "message": "Speaker Diarization & Transcription API",
        "version": "1.0.0",
        "endpoints": {
            "/": "API information",
            "/health": "Health check",
            "/process": "Process audio file (POST)"
        }
    }
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Registered at "/health"; reports the compute device and whether both
    models have finished loading (they start as None until the startup
    hook populates them).
    """
    return {
        "status": "healthy",
        "device": str(device),
        "models_loaded": pipeline is not None and whisper_model is not None
    }
@app.post("/process")
async def process_audio_endpoint(file: UploadFile = File(...)):
    """
    Process audio file for speaker diarization and transcription

    Args:
        file: Audio file (wav, mp3, etc.)

    Returns:
        JSON response with segments and full transcription

    Raises:
        HTTPException: 503 while models load, 400 for unsupported types,
            500 for processing failures.
    """
    if pipeline is None or whisper_model is None:
        raise HTTPException(status_code=503, detail="Models are still loading. Please try again in a moment.")
    # Validate file type; `or ""` guards against a missing filename, which
    # would make Path() raise TypeError instead of returning a clean 400.
    allowed_extensions = {'.wav', '.mp3', '.m4a', '.flac', '.ogg', '.webm'}
    file_ext = Path(file.filename or "").suffix.lower()
    if file_ext not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type. Allowed: {', '.join(allowed_extensions)}"
        )
    # Save uploaded file temporarily. tmp_file_path is pre-initialized so the
    # finally block cannot hit a NameError when file.read()/write fails
    # before the path is assigned (which previously masked the real error).
    tmp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
            tmp_file_path = tmp_file.name
            tmp_file.write(await file.read())
        # Process after the handle is closed so the file is fully flushed
        # (and readable on platforms that lock open files).
        result = process_audio(tmp_file_path)
        return JSONResponse(content=result)
    except HTTPException:
        # Don't re-wrap deliberate HTTP errors as generic 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")
    finally:
        # Clean up temporary file
        if tmp_file_path and os.path.exists(tmp_file_path):
            os.unlink(tmp_file_path)
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run(app, host="0.0.0.0", port=7860) | |