Rivalcoder commited on
Commit
1600c41
·
1 Parent(s): 4d43a59
Files changed (3) hide show
  1. Dockerfile +22 -0
  2. app.py +216 -0
  3. requirements.txt +8 -0
Dockerfile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install system dependencies
6
+ RUN apt-get update && apt-get install -y \
7
+ ffmpeg \
8
+ && rm -rf /var/lib/apt/lists/*
9
+
10
+ # Copy requirements and install Python dependencies
11
+ COPY requirements.txt .
12
+ RUN pip install --no-cache-dir -r requirements.txt
13
+
14
+ # Copy application code
15
+ COPY app.py .
16
+
17
+ # Expose port
18
+ EXPOSE 7860
19
+
20
+ # Run the application
21
+ CMD ["python", "app.py"]
22
+
app.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import warnings
4
+ import io
5
+ import tempfile
6
+ from pathlib import Path
7
+
8
+ warnings.filterwarnings('ignore')
9
+ os.environ['PYTHONWARNINGS'] = 'ignore'
10
+
11
+ class SuppressStderr:
12
+ def __enter__(self):
13
+ self.original_stderr = sys.stderr
14
+ sys.stderr = io.StringIO()
15
+ return self
16
+
17
+ def __exit__(self, *args):
18
+ sys.stderr = self.original_stderr
19
+
20
+ with warnings.catch_warnings():
21
+ warnings.simplefilter("ignore")
22
+ with SuppressStderr():
23
+ import torch
24
+ import whisper
25
+ import soundfile as sf
26
+ from pyannote.audio import Pipeline
27
+ from fastapi import FastAPI, File, UploadFile, HTTPException
28
+ from fastapi.responses import JSONResponse
29
+ from fastapi.middleware.cors import CORSMiddleware
30
+
31
+ warnings.filterwarnings('ignore', category=UserWarning)
32
+ warnings.filterwarnings('ignore', category=FutureWarning)
33
+ warnings.filterwarnings('ignore', message='.*torchcodec.*')
34
+ warnings.filterwarnings('ignore', message='.*FP16.*')
35
+ warnings.filterwarnings('ignore', message='.*degrees of freedom.*')
36
+ warnings.filterwarnings('ignore', module='pyannote.audio.core.io')
37
+ warnings.filterwarnings('ignore', module='whisper.transcribe')
38
+ warnings.filterwarnings('ignore', module='whisper')
39
+
40
+ _original_torch_load = torch.load
41
+
42
+ def _patched_torch_load(*args, **kwargs):
43
+ kwargs['weights_only'] = False
44
+ return _original_torch_load(*args, **kwargs)
45
+
46
+ torch.load = _patched_torch_load
47
+
48
+ # Get HF token from environment variable (set in HF Space settings)
49
+ HF_TOKEN = os.getenv("HF_TOKEN")
50
+ if not HF_TOKEN:
51
+ raise ValueError("HF_TOKEN environment variable is required. Please set it in your Hugging Face Space settings.")
52
+
53
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
+
55
+ # Initialize FastAPI app
56
+ app = FastAPI(
57
+ title="Speaker Diarization & Transcription API",
58
+ description="API for speaker diarization and transcription using pyannote.audio and Whisper",
59
+ version="1.0.0"
60
+ )
61
+
62
+ # Add CORS middleware
63
+ app.add_middleware(
64
+ CORSMiddleware,
65
+ allow_origins=["*"],
66
+ allow_credentials=True,
67
+ allow_methods=["*"],
68
+ allow_headers=["*"],
69
+ )
70
+
71
+ # Global variables for models
72
+ pipeline = None
73
+ whisper_model = None
74
+
75
+ @app.on_event("startup")
76
+ async def load_models():
77
+ """Load models on startup"""
78
+ global pipeline, whisper_model
79
+
80
+ print(f"Using device: {device}")
81
+
82
+ print("Loading diarization model...")
83
+ with SuppressStderr():
84
+ pipeline = Pipeline.from_pretrained(
85
+ "pyannote/speaker-diarization-community-1",
86
+ token=HF_TOKEN,
87
+ )
88
+ pipeline.to(device)
89
+
90
+ print("Loading Whisper small model...")
91
+ with SuppressStderr():
92
+ whisper_model = whisper.load_model("small", device=device)
93
+ print("Models loaded successfully!\n")
94
+
95
+ def process_audio(audio_path):
96
+ """Process audio file with diarization and transcription"""
97
+ if not os.path.exists(audio_path):
98
+ raise FileNotFoundError(f"Audio file not found: {audio_path}")
99
+
100
+ print(f"Processing: {audio_path}")
101
+
102
+ print("Loading audio file...")
103
+ waveform, sample_rate = sf.read(audio_path)
104
+ waveform = torch.from_numpy(waveform).float()
105
+
106
+ if waveform.ndim == 1:
107
+ waveform = waveform.unsqueeze(0)
108
+ elif waveform.shape[0] > waveform.shape[1]:
109
+ waveform = waveform.T
110
+
111
+ audio_dict = {
112
+ 'waveform': waveform,
113
+ 'sample_rate': sample_rate
114
+ }
115
+
116
+ print("Running speaker diarization...")
117
+ diarization = pipeline(audio_dict)
118
+
119
+ print("Running transcription...")
120
+ transcription_result = whisper_model.transcribe(audio_path)
121
+
122
+ results = []
123
+
124
+ for turn, speaker in diarization.speaker_diarization:
125
+ text = ""
126
+ for trans_seg in transcription_result["segments"]:
127
+ if (trans_seg["start"] <= turn.end and trans_seg["end"] >= turn.start):
128
+ overlap_start = max(turn.start, trans_seg["start"])
129
+ overlap_end = min(turn.end, trans_seg["end"])
130
+ if overlap_end > overlap_start:
131
+ if (overlap_end - overlap_start) / (turn.end - turn.start) > 0.5:
132
+ text = trans_seg["text"].strip()
133
+ break
134
+
135
+ results.append({
136
+ "start": round(turn.start, 2),
137
+ "end": round(turn.end, 2),
138
+ "speaker": speaker,
139
+ "text": text
140
+ })
141
+
142
+ return {
143
+ "segments": results,
144
+ "full_transcription": transcription_result["text"]
145
+ }
146
+
147
+ @app.get("/")
148
+ async def root():
149
+ """Root endpoint with API information"""
150
+ return {
151
+ "message": "Speaker Diarization & Transcription API",
152
+ "version": "1.0.0",
153
+ "endpoints": {
154
+ "/": "API information",
155
+ "/health": "Health check",
156
+ "/process": "Process audio file (POST)"
157
+ }
158
+ }
159
+
160
+ @app.get("/health")
161
+ async def health_check():
162
+ """Health check endpoint"""
163
+ return {
164
+ "status": "healthy",
165
+ "device": str(device),
166
+ "models_loaded": pipeline is not None and whisper_model is not None
167
+ }
168
+
169
+ @app.post("/process")
170
+ async def process_audio_endpoint(file: UploadFile = File(...)):
171
+ """
172
+ Process audio file for speaker diarization and transcription
173
+
174
+ Args:
175
+ file: Audio file (wav, mp3, etc.)
176
+
177
+ Returns:
178
+ JSON response with segments and full transcription
179
+ """
180
+ if pipeline is None or whisper_model is None:
181
+ raise HTTPException(status_code=503, detail="Models are still loading. Please try again in a moment.")
182
+
183
+ # Validate file type
184
+ allowed_extensions = {'.wav', '.mp3', '.m4a', '.flac', '.ogg', '.webm'}
185
+ file_ext = Path(file.filename).suffix.lower()
186
+
187
+ if file_ext not in allowed_extensions:
188
+ raise HTTPException(
189
+ status_code=400,
190
+ detail=f"Unsupported file type. Allowed: {', '.join(allowed_extensions)}"
191
+ )
192
+
193
+ # Save uploaded file temporarily
194
+ with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
195
+ try:
196
+ content = await file.read()
197
+ tmp_file.write(content)
198
+ tmp_file_path = tmp_file.name
199
+
200
+ # Process audio
201
+ result = process_audio(tmp_file_path)
202
+
203
+ return JSONResponse(content=result)
204
+
205
+ except Exception as e:
206
+ raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")
207
+
208
+ finally:
209
+ # Clean up temporary file
210
+ if os.path.exists(tmp_file_path):
211
+ os.unlink(tmp_file_path)
212
+
213
+ if __name__ == "__main__":
214
+ import uvicorn
215
+ uvicorn.run(app, host="0.0.0.0", port=7860)
216
+
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ pyannote.audio
4
+ torch
5
+ openai-whisper
6
+ python-multipart
7
+ soundfile
8
+