ConvxO2 committed
Commit d7a2919 · 0 parent(s)

Initial commit: Speaker Diarization System
.gitattributes ADDED
*.py text eol=lf
*.txt text eol=lf
*.md text eol=lf
*.yaml text eol=lf
*.yml text eol=lf
*.json text eol=lf
*.html text eol=lf
Dockerfile text eol=lf
.gitignore ADDED
.env
model_cache/
wires/
.wires/
__pycache__/
*.pyc
*.wav
*.mp3
*.flac
.venv/
venv/
Dockerfile ADDED
FROM python:3.10-slim

# Install system dependencies including ffmpeg
RUN apt-get update && apt-get install -y \
    ffmpeg \
    git \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy project files
COPY . .

# HuggingFace Spaces uses port 7860
EXPOSE 7860

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
# 🎙 Speaker Diarization System
### *Who Spoke When — Multi-Speaker Audio Segmentation*

> **Tech Stack:** Python · PyTorch · SpeechBrain · Pyannote.audio · Transformers · FastAPI

---

## Architecture

```
Audio Input
     │
     ▼
┌─────────────────────────────┐
│  Voice Activity Detection   │ ← pyannote/voice-activity-detection
│          (VAD)              │   fallback: energy-based VAD
└────────────┬────────────────┘
             │ speech regions (start, end)
             ▼
┌─────────────────────────────┐
│ Sliding Window Segmentation │ ← 1.5s windows, 50% overlap
└────────────┬────────────────┘
             │ segment list
             ▼
┌─────────────────────────────┐
│    ECAPA-TDNN Embedding     │ ← speechbrain/spkrec-ecapa-voxceleb
│        Extraction           │   192-dim L2-normalized vectors
└────────────┬────────────────┘
             │ embeddings (N × 192)
             ▼
┌─────────────────────────────┐
│ Agglomerative Hierarchical  │ ← cosine distance metric
│      Clustering (AHC)       │   silhouette-based auto k-selection
└────────────┬────────────────┘
             │ speaker labels
             ▼
┌─────────────────────────────┐
│      Post-processing        │ ← merge consecutive same-speaker segs
│   & Output Formatting       │   timestamped JSON / RTTM / SRT
└─────────────────────────────┘
```
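The segmentation stage above can be sketched as a small pure function. The window, step, and minimum-duration values come from the pipeline defaults (1.5 s windows, 0.75 s step, 0.5 s minimum); the function name is illustrative:

```python
def sliding_windows(speech_regions, win=1.5, step=0.75, min_dur=0.5):
    """Split VAD speech regions into overlapping embedding windows (seconds)."""
    windows = []
    for start, end in speech_regions:
        if end - start < min_dur:        # region too short to embed reliably
            continue
        t = start
        while t + win <= end:
            windows.append((t, t + win))
            t += step
        if end - t >= min_dur:           # keep a shorter tail window
            windows.append((t, end))
    return windows
```

With 50% overlap, most instants of speech are covered by two windows, which stabilizes the downstream clustering.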
---

## Project Structure

```
speaker-diarization/
├── app/
│   ├── main.py              # FastAPI app — REST + WebSocket endpoints
│   └── pipeline.py          # Core end-to-end diarization pipeline
├── models/
│   ├── embedder.py          # ECAPA-TDNN speaker embedding extractor
│   └── clusterer.py         # Agglomerative Hierarchical Clustering (AHC)
├── utils/
│   └── audio.py             # Audio loading, chunking, RTTM/SRT export
├── tests/
│   └── test_diarization.py  # Unit + integration tests
├── static/
│   └── index.html           # Web demo UI
├── demo.py                  # CLI interface
└── requirements.txt
```

---

## Installation

```bash
# 1. Clone / navigate to project
cd speaker-diarization

# 2. Create virtual environment
python -m venv .venv
source .venv/bin/activate   # Windows: .venv\Scripts\activate

# 3. Install dependencies
pip install -r requirements.txt

# 4. (Optional) Set HuggingFace token for pyannote VAD
#    Accept terms at: https://huggingface.co/pyannote/voice-activity-detection
export HF_TOKEN=your_token_here
```

---

## Usage

### CLI Demo

```bash
# Basic usage (auto-detect speaker count)
python demo.py --audio meeting.wav

# Specify 3 speakers
python demo.py --audio call.wav --speakers 3

# Export all formats
python demo.py --audio audio.mp3 \
    --output result.json \
    --rttm output.rttm \
    --srt subtitles.srt
```

**Example output:**
```
✅ Done in 4.83s
   Speakers found : 3
   Audio duration : 120.50s
   Segments       : 42

   START      END    DUR SPEAKER
────────────────────────────────────
   0.000    3.250  3.250 SPEAKER_00
   3.500    8.120  4.620 SPEAKER_01
   8.200   11.800  3.600 SPEAKER_00
  12.000   17.340  5.340 SPEAKER_02
  ...
```

### FastAPI Server

```bash
# Start the API server
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload

# Open the web UI
open http://localhost:8000

# Swagger documentation
open http://localhost:8000/docs
```

### REST API

**POST /diarize** — Upload audio file
```bash
curl -X POST http://localhost:8000/diarize \
  -F "file=@meeting.wav" \
  -F "num_speakers=3"
```

**Response:**
```json
{
  "status": "success",
  "num_speakers": 3,
  "audio_duration": 120.5,
  "processing_time": 4.83,
  "sample_rate": 16000,
  "speakers": ["SPEAKER_00", "SPEAKER_01", "SPEAKER_02"],
  "segments": [
    { "start": 0.000, "end": 3.250, "duration": 3.250, "speaker": "SPEAKER_00" },
    { "start": 3.500, "end": 8.120, "duration": 4.620, "speaker": "SPEAKER_01" }
  ]
}
```

**GET /health** — Service health
```bash
curl http://localhost:8000/health
# {"status":"healthy","device":"cuda","version":"1.0.0"}
```

### WebSocket Streaming

```python
import asyncio, websockets, json

async def stream_audio():
    async with websockets.connect("ws://localhost:8000/ws/stream") as ws:
        # Send config
        await ws.send(json.dumps({"sample_rate": 16000, "num_speakers": 2}))

        # Send audio chunks (raw float32 PCM)
        with open("audio.raw", "rb") as f:
            while chunk := f.read(4096):
                await ws.send(chunk)

        # Signal end
        await ws.send(json.dumps({"type": "eof"}))

        # Receive results
        async for msg in ws:
            data = json.loads(msg)
            if data["type"] == "segment":
                seg = data["data"]
                print(f"[{seg['speaker']}] {seg['start']:.2f}s – {seg['end']:.2f}s")
            elif data["type"] == "done":
                break

asyncio.run(stream_audio())
```

---

## Key Design Decisions

| Component | Choice | Rationale |
|-----------|--------|-----------|
| Speaker embeddings | ECAPA-TDNN (SpeechBrain) | State-of-the-art speaker verification accuracy on VoxCeleb |
| Clustering | AHC + cosine distance | No predefined k required; works well with L2-normalized embeddings |
| k-selection | Silhouette analysis | Unsupervised, parameter-free speaker count estimation |
| VAD | pyannote (energy fallback) | pyannote VAD reduces false embeddings on silence/noise |
| Embedding window | 1.5s, 50% overlap | Balances temporal resolution vs. embedding stability |
| Post-processing | Merge consecutive same-speaker segments | Reduces over-segmentation artifacts |
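The k-selection row above can be condensed into a short sketch mirroring `models/clusterer.py` (assumes numpy, scipy, and scikit-learn; embeddings are L2-normalized, so `1 − E·Eᵀ` is cosine distance):

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform
from sklearn.metrics import silhouette_score

def estimate_num_speakers(emb, max_k=10):
    """Pick k by silhouette score over AHC cuts of the cosine-distance tree."""
    emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)  # L2-normalize rows
    dist = np.clip(1.0 - emb @ emb.T, 0.0, 2.0)             # cosine distance
    np.fill_diagonal(dist, 0.0)
    Z = linkage(squareform(dist, checks=False), method="average")
    best_k, best_score = 1, -1.0
    for k in range(2, min(max_k, len(emb) - 1) + 1):
        labels = fcluster(Z, k, criterion="maxclust")
        if len(np.unique(labels)) < 2:
            continue
        score = silhouette_score(emb, labels, metric="cosine")
        if score > best_score:
            best_k, best_score = k, score
    return best_k
```

The silhouette score peaks when clusters are tight and well separated, which makes the cut with the right speaker count stand out without any tuned threshold.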

---

## Evaluation Metrics

Standard speaker diarization evaluation uses **Diarization Error Rate (DER)**:

```
DER = (Miss + False Alarm + Speaker Error) / Total Speech Duration
```
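The formula reduces to simple arithmetic over scored durations; a quick sanity-check helper (names are illustrative, all arguments in seconds):

```python
def der(miss, false_alarm, speaker_error, total_speech):
    """Diarization Error Rate: fraction of scored speech time that is wrong."""
    return (miss + false_alarm + speaker_error) / total_speech

# 3 s missed speech + 2 s false alarm + 5 s speaker confusion over 100 s of speech
print(f"DER = {der(3.0, 2.0, 5.0, 100.0):.2%}")  # → DER = 10.00%
```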

Export RTTM files for evaluation with `md-eval` or `dscore`:
```bash
python demo.py --audio test.wav --rttm hypothesis.rttm
dscore -r reference.rttm -s hypothesis.rttm
```
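The repo's `utils.audio.segments_to_rttm` is not shown in this commit view; a minimal writer for the standard 10-field RTTM `SPEAKER` line could look like this (tuple input is an assumption — the actual helper takes segment objects):

```python
def to_rttm(segments, file_id):
    """Render (start, end, speaker) tuples as RTTM SPEAKER records."""
    lines = [
        # SPEAKER <file> <chan> <onset> <dur> <NA> <NA> <label> <NA> <NA>
        f"SPEAKER {file_id} 1 {start:.3f} {end - start:.3f} <NA> <NA> {speaker} <NA> <NA>"
        for start, end, speaker in segments
    ]
    return "\n".join(lines) + "\n"
```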

---

## Running Tests

```bash
pytest tests/ -v
pytest tests/ -v -k "clusterer"   # run tests matching "clusterer"
```

---

## Limitations & Future Work

- Long audio (>1 hr) should use chunked processing (`utils.audio.chunk_audio`)
- Real-time streaming requires low-latency VAD (not yet implemented in the WS endpoint)
- Speaker overlap (cross-talk) is assigned to a single speaker
- Consider fine-tuning ECAPA-TDNN on domain-specific data for call analytics
app/__init__.py ADDED
File without changes
app/main.py ADDED
"""
Speaker Diarization API - FastAPI Application
"""

import os
import json
import asyncio
import tempfile
import traceback
from pathlib import Path
from typing import Optional, List

import numpy as np
import torch
from fastapi import (
    FastAPI, File, UploadFile, Form, WebSocket,
    WebSocketDisconnect, HTTPException, Query,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from loguru import logger

# ---------------------------------------------------------------------------
# Schemas
# ---------------------------------------------------------------------------

class SegmentOut(BaseModel):
    start: float
    end: float
    duration: float
    speaker: str


class DiarizationResponse(BaseModel):
    status: str = "success"
    num_speakers: int
    audio_duration: float
    processing_time: float
    sample_rate: int
    speakers: List[str]
    segments: List[SegmentOut]


class HealthResponse(BaseModel):
    status: str
    device: str
    version: str = "1.0.0"


# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------

app = FastAPI(
    title="Speaker Diarization API",
    description="Who Spoke When — Speaker diarization using ECAPA-TDNN + AHC Clustering",
    version="1.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

_pipeline = None


def get_pipeline():
    """Lazily build the diarization pipeline on first use (model load is slow)."""
    global _pipeline
    if _pipeline is None:
        from app.pipeline import DiarizationPipeline
        _pipeline = DiarizationPipeline(
            device="auto",
            use_pyannote_vad=True,
            hf_token=os.getenv("HF_TOKEN"),
            max_speakers=10,
            cache_dir="./model_cache",
        )
    return _pipeline


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------

@app.get("/health", response_model=HealthResponse, tags=["System"])
async def health_check():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return HealthResponse(status="healthy", device=device)


@app.post("/diarize", response_model=DiarizationResponse, tags=["Diarization"])
async def diarize_audio(
    file: UploadFile = File(...),
    num_speakers: Optional[int] = Form(None, ge=1, le=20),
):
    """Diarize an uploaded audio file. Returns timestamped speaker labels."""
    allowed = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".webm"}
    suffix = Path(file.filename).suffix.lower()
    if suffix not in allowed:
        raise HTTPException(status_code=415, detail=f"Unsupported format '{suffix}'")

    audio_bytes = await file.read()
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(audio_bytes)
        tmp_path = tmp.name

    try:
        pipeline = get_pipeline()
        result = pipeline.process(tmp_path, num_speakers=num_speakers)
        return DiarizationResponse(
            num_speakers=result.num_speakers,
            audio_duration=result.audio_duration,
            processing_time=result.processing_time,
            sample_rate=result.sample_rate,
            speakers=sorted(set(s.speaker for s in result.segments)),
            segments=[SegmentOut(**s.to_dict()) for s in result.segments],
        )
    except Exception as e:
        logger.error(f"Diarization failed: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        Path(tmp_path).unlink(missing_ok=True)


@app.post("/diarize/url", response_model=DiarizationResponse, tags=["Diarization"])
async def diarize_from_url(
    audio_url: str = Query(...),
    num_speakers: Optional[int] = Query(None, ge=1, le=20),
):
    """Diarize audio fetched from a URL."""
    import httpx
    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.get(audio_url)
            resp.raise_for_status()
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to fetch audio: {e}")

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp.write(resp.content)
        tmp_path = tmp.name

    try:
        pipeline = get_pipeline()
        result = pipeline.process(tmp_path, num_speakers=num_speakers)
        return DiarizationResponse(
            num_speakers=result.num_speakers,
            audio_duration=result.audio_duration,
            processing_time=result.processing_time,
            sample_rate=result.sample_rate,
            speakers=sorted(set(s.speaker for s in result.segments)),
            segments=[SegmentOut(**s.to_dict()) for s in result.segments],
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        Path(tmp_path).unlink(missing_ok=True)


@app.websocket("/ws/stream")
async def stream_diarization(websocket: WebSocket):
    """Real-time streaming diarization via WebSocket."""
    await websocket.accept()
    audio_buffer = bytearray()
    sample_rate = 16000
    num_speakers = None
    chunk_count = 0

    try:
        config_msg = await websocket.receive_json()
        sample_rate = config_msg.get("sample_rate", 16000)
        num_speakers = config_msg.get("num_speakers", None)

        await websocket.send_json({
            "type": "progress",
            "data": {"message": "Config received. Send audio chunks.", "chunks_received": 0},
        })

        while True:
            try:
                msg = await asyncio.wait_for(websocket.receive(), timeout=30.0)
            except asyncio.TimeoutError:
                await websocket.send_json({"type": "error", "data": {"message": "Timeout"}})
                break

            if "bytes" in msg:
                audio_buffer.extend(msg["bytes"])
                chunk_count += 1
                await websocket.send_json({
                    "type": "progress",
                    "data": {"message": f"Received chunk {chunk_count}", "chunks_received": chunk_count},
                })
            elif "text" in msg:
                data = json.loads(msg["text"])
                if data.get("type") == "eof":
                    break

        if not audio_buffer:
            await websocket.send_json({"type": "error", "data": {"message": "No audio received"}})
            return

        audio_np = np.frombuffer(audio_buffer, dtype=np.float32).copy()
        audio_tensor = torch.from_numpy(audio_np)

        await websocket.send_json({
            "type": "progress",
            "data": {"message": "Running diarization pipeline..."},
        })

        loop = asyncio.get_running_loop()
        pipeline = get_pipeline()
        result = await loop.run_in_executor(
            None,
            lambda: pipeline.process(audio_tensor, sample_rate=sample_rate, num_speakers=num_speakers),
        )

        for seg in result.segments:
            await websocket.send_json({"type": "segment", "data": seg.to_dict()})

        await websocket.send_json({
            "type": "done",
            "data": {
                "num_speakers": result.num_speakers,
                "total_segments": len(result.segments),
                "audio_duration": result.audio_duration,
                "processing_time": result.processing_time,
            },
        })

    except WebSocketDisconnect:
        logger.info("WebSocket client disconnected")
    except Exception as e:
        logger.error(f"WebSocket error: {traceback.format_exc()}")
        try:
            await websocket.send_json({"type": "error", "data": {"message": str(e)}})
        except Exception:
            pass


@app.get("/", response_class=HTMLResponse, include_in_schema=False)
async def serve_ui():
    ui_path = Path("static/index.html")
    if ui_path.exists():
        return HTMLResponse(ui_path.read_text())
    return HTMLResponse("<h1>Speaker Diarization API</h1><p><a href='/docs'>API Docs →</a></p>")


static_dir = Path("static")
if static_dir.exists():
    app.mount("/static", StaticFiles(directory="static"), name="static")
app/pipeline.py ADDED
"""
Speaker Diarization Pipeline
Combines: Voice Activity Detection → Segmentation → ECAPA-TDNN Embeddings → AHC Clustering
"""

import time

import torch
import torchaudio
import numpy as np
from pathlib import Path
from typing import Optional, List, Tuple, Union, BinaryIO
from dataclasses import dataclass, field
from loguru import logger

from models.embedder import EcapaTDNNEmbedder
from models.clusterer import SpeakerClusterer


@dataclass
class DiarizationSegment:
    start: float
    end: float
    speaker: str
    duration: float = field(init=False)

    def __post_init__(self):
        self.duration = round(self.end - self.start, 3)

    def to_dict(self) -> dict:
        return {
            "start": round(self.start, 3),
            "end": round(self.end, 3),
            "duration": self.duration,
            "speaker": self.speaker,
        }


@dataclass
class DiarizationResult:
    segments: List[DiarizationSegment]
    num_speakers: int
    audio_duration: float
    processing_time: float
    sample_rate: int

    def to_dict(self) -> dict:
        speakers = sorted(set(s.speaker for s in self.segments))
        return {
            "num_speakers": self.num_speakers,
            "audio_duration": round(self.audio_duration, 3),
            "processing_time": round(self.processing_time, 3),
            "sample_rate": self.sample_rate,
            "speakers": speakers,
            "segments": [s.to_dict() for s in self.segments],
        }


class DiarizationPipeline:
    """
    End-to-end speaker diarization pipeline.
    1. Audio loading & preprocessing
    2. Voice Activity Detection (VAD) via pyannote or energy-based fallback
    3. Sliding-window segmentation of speech regions
    4. ECAPA-TDNN speaker embedding extraction per segment
    5. Agglomerative Hierarchical Clustering
    6. Post-processing: merge consecutive same-speaker segments
    """

    SAMPLE_RATE = 16000
    WINDOW_DURATION = 1.5
    WINDOW_STEP = 0.75
    MIN_SEGMENT_DURATION = 0.5

    def __init__(
        self,
        device: str = "auto",
        use_pyannote_vad: bool = True,
        hf_token: Optional[str] = None,
        num_speakers: Optional[int] = None,
        max_speakers: int = 10,
        cache_dir: str = "./model_cache",
    ):
        self.device = self._resolve_device(device)
        self.use_pyannote_vad = use_pyannote_vad
        self.hf_token = hf_token
        self.num_speakers = num_speakers
        self.max_speakers = max_speakers
        self.cache_dir = Path(cache_dir)

        self.embedder = EcapaTDNNEmbedder(device=self.device, cache_dir=str(cache_dir))
        self.clusterer = SpeakerClusterer(max_speakers=max_speakers)

        self._vad_pipeline = None
        logger.info(f"DiarizationPipeline ready | device={self.device}")

    def _resolve_device(self, device: str) -> str:
        if device == "auto":
            return "cuda" if torch.cuda.is_available() else "cpu"
        return device

    def _load_vad(self):
        if self._vad_pipeline is not None:
            return
        try:
            from pyannote.audio import Pipeline
            logger.info("Loading pyannote VAD pipeline...")
            self._vad_pipeline = Pipeline.from_pretrained(
                "pyannote/voice-activity-detection",
                use_auth_token=self.hf_token,
            )
            self._vad_pipeline.to(torch.device(self.device))
            logger.success("Pyannote VAD loaded.")
        except Exception as e:
            logger.warning(f"Could not load pyannote VAD: {e}. Falling back to energy-based VAD.")
            self._vad_pipeline = "energy"

    def _energy_vad(
        self, audio: torch.Tensor, frame_duration: float = 0.02, threshold_db: float = -40.0
    ) -> List[Tuple[float, float]]:
        """Simple energy-based VAD as fallback."""
        frame_samples = int(frame_duration * self.SAMPLE_RATE)
        audio_np = audio.detach().cpu().numpy()  # ensure CPU ndarray regardless of device
        frames = [
            audio_np[i : i + frame_samples]
            for i in range(0, len(audio_np) - frame_samples, frame_samples)
        ]

        energies_db = []
        for f in frames:
            rms = np.sqrt(np.mean(f ** 2) + 1e-10)
            energies_db.append(20 * np.log10(rms))

        is_speech = np.array(energies_db) > threshold_db

        speech_regions = []
        in_speech = False
        start = 0.0

        for i, active in enumerate(is_speech):
            t = i * frame_duration
            if active and not in_speech:
                start = t
                in_speech = True
            elif not active and in_speech:
                speech_regions.append((start, t))
                in_speech = False

        if in_speech:
            speech_regions.append((start, len(audio_np) / self.SAMPLE_RATE))

        return speech_regions

    def _get_speech_regions(self, audio: torch.Tensor) -> List[Tuple[float, float]]:
        if self.use_pyannote_vad:
            self._load_vad()

        if self._vad_pipeline == "energy" or not self.use_pyannote_vad:
            return self._energy_vad(audio)

        try:
            audio_dict = {
                "waveform": audio.unsqueeze(0).to(self.device),
                "sample_rate": self.SAMPLE_RATE,
            }
            vad_output = self._vad_pipeline(audio_dict)
            regions = [(seg.start, seg.end) for seg in vad_output.get_timeline().support()]
            logger.info(f"Pyannote VAD: {len(regions)} speech regions found")
            return regions
        except Exception as e:
            logger.warning(f"Pyannote VAD failed: {e}. Using energy VAD.")
            return self._energy_vad(audio)

    def _sliding_window_segments(self, speech_regions: List[Tuple[float, float]]) -> List[Tuple[float, float]]:
        segments = []
        for region_start, region_end in speech_regions:
            duration = region_end - region_start
            if duration < self.MIN_SEGMENT_DURATION:
                continue

            t = region_start
            while t + self.WINDOW_DURATION <= region_end:
                segments.append((t, t + self.WINDOW_DURATION))
                t += self.WINDOW_STEP

            if region_end - t >= self.MIN_SEGMENT_DURATION:
                segments.append((t, region_end))

        return segments

    def load_audio(self, path: Union[str, Path, BinaryIO]) -> Tuple[torch.Tensor, int]:
        waveform, sample_rate = torchaudio.load(path)
        return waveform, sample_rate

    def process(
        self,
        audio: Union[str, Path, torch.Tensor],
        sample_rate: Optional[int] = None,
        num_speakers: Optional[int] = None,
    ) -> DiarizationResult:
        """Run full diarization pipeline on audio."""
        t_start = time.time()

        if isinstance(audio, (str, Path)):
            waveform, sample_rate = self.load_audio(audio)
            audio_tensor = waveform.squeeze(0)
        else:
            assert sample_rate is not None, "sample_rate required when passing tensor"
            audio_tensor = audio.squeeze(0) if audio.dim() > 1 else audio

        audio_duration = len(audio_tensor) / sample_rate
        logger.info(f"Processing {audio_duration:.1f}s audio at {sample_rate}Hz")

        processed = self.embedder.preprocess_audio(audio_tensor, sample_rate)

        speech_regions = self._get_speech_regions(processed)
        if not speech_regions:
            logger.warning("No speech detected in audio.")
            return DiarizationResult(
                segments=[], num_speakers=0,
                audio_duration=audio_duration,
                processing_time=time.time() - t_start,
                sample_rate=sample_rate,
            )

        windows = self._sliding_window_segments(speech_regions)
        logger.info(f"Generated {len(windows)} embedding windows")

        embeddings, valid_windows = self.embedder.extract_embeddings_from_segments(
            processed, self.SAMPLE_RATE, windows
        )

        if len(embeddings) == 0:
            logger.warning("No valid embeddings extracted.")
            return DiarizationResult(
                segments=[], num_speakers=0,
                audio_duration=audio_duration,
                processing_time=time.time() - t_start,
                sample_rate=sample_rate,
            )

        k = num_speakers or self.num_speakers
        labels = self.clusterer.cluster(embeddings, num_speakers=k)

        merged = self.clusterer.merge_consecutive_same_speaker(valid_windows, labels)

        # Name speakers directly from the cluster id, so ids beyond max_speakers
        # (possible when a caller requests more speakers) cannot raise a KeyError.
        segments = [
            DiarizationSegment(start=start, end=end, speaker=f"SPEAKER_{spk_id:02d}")
            for start, end, spk_id in merged
        ]

        num_unique = len(set(labels))
        processing_time = time.time() - t_start

        logger.success(
            f"Diarization complete: {num_unique} speakers, "
            f"{len(segments)} segments, {processing_time:.2f}s"
        )

        return DiarizationResult(
            segments=segments,
            num_speakers=num_unique,
            audio_duration=audio_duration,
            processing_time=processing_time,
            sample_rate=sample_rate,
        )
demo.py ADDED
"""
CLI Demo: Run speaker diarization on a local audio file.
Usage:
    python demo.py --audio path/to/audio.wav
    python demo.py --audio path/to/audio.wav --speakers 3
    python demo.py --audio path/to/audio.wav --output result.json
"""

import argparse
import json
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))


def main():
    parser = argparse.ArgumentParser(description="Speaker Diarization CLI")
    parser.add_argument("--audio", required=True, help="Path to audio file")
    parser.add_argument("--speakers", type=int, default=None)
    parser.add_argument("--output", default=None, help="Save JSON result")
    parser.add_argument("--rttm", default=None, help="Save RTTM output")
    parser.add_argument("--srt", default=None, help="Save SRT subtitle file")
    parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"])
    args = parser.parse_args()

    audio_path = Path(args.audio)
    if not audio_path.exists():
        print(f"❌ Audio file not found: {audio_path}")
        sys.exit(1)

    print("🎙 Speaker Diarization Pipeline")
    print(f"   Audio   : {audio_path}")
    print(f"   Speakers: {'auto-detect' if args.speakers is None else args.speakers}")
    print()

    from app.pipeline import DiarizationPipeline
    from utils.audio import segments_to_rttm, segments_to_srt

    pipeline = DiarizationPipeline(device=args.device, num_speakers=args.speakers)

    print("⏳ Running diarization...")
    result = pipeline.process(audio_path, num_speakers=args.speakers)

    print(f"\n✅ Done in {result.processing_time:.2f}s")
    print(f"   Speakers : {result.num_speakers}")
    print(f"   Duration : {result.audio_duration:.2f}s")
    print(f"   Segments : {len(result.segments)}")
    print()
    print(f"{'START':>8} {'END':>8} {'DUR':>6} SPEAKER")
    print("─" * 42)
    for seg in result.segments:
        print(f"{seg.start:8.3f} {seg.end:8.3f} {seg.duration:6.3f} {seg.speaker}")

    if args.output:
        Path(args.output).write_text(json.dumps(result.to_dict(), indent=2))
        print(f"\n💾 JSON saved to: {args.output}")
    if args.rttm:
        Path(args.rttm).write_text(segments_to_rttm(result.segments, audio_path.stem))
        print(f"💾 RTTM saved to: {args.rttm}")
    if args.srt:
        Path(args.srt).write_text(segments_to_srt(result.segments))
        print(f"💾 SRT saved to: {args.srt}")


if __name__ == "__main__":
    main()
models/__init__.py ADDED
File without changes
models/clusterer.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Agglomerative Hierarchical Clustering (AHC) for speaker identity assignment.
3
+ Uses cosine similarity on ECAPA-TDNN embeddings to cluster segments by speaker.
4
+ """
5
+
6
+ import numpy as np
7
+ from typing import List, Tuple, Optional
8
+ from scipy.cluster.hierarchy import linkage, fcluster
9
+ from scipy.spatial.distance import squareform
10
+ from sklearn.metrics import silhouette_score
11
+ from loguru import logger
12
+
13
+
14
+ class SpeakerClusterer:
15
+ """
16
+ Agglomerative Hierarchical Clustering for speaker diarization.
17
+ Supports automatic speaker count estimation via silhouette analysis.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ linkage_method: str = "average",
23
+ distance_threshold: float = 0.7,
24
+ min_speakers: int = 1,
25
+ max_speakers: int = 10,
26
+ ):
27
+ self.linkage_method = linkage_method
28
+ self.distance_threshold = distance_threshold
29
+ self.min_speakers = min_speakers
30
+ self.max_speakers = max_speakers
31
+
32
+ def _cosine_distance_matrix(self, embeddings: np.ndarray) -> np.ndarray:
33
+ similarity = embeddings @ embeddings.T
34
+ distance = np.clip(1.0 - similarity, 0.0, 2.0)
35
+ return distance
36
+
37
+ def _estimate_num_speakers(self, embeddings: np.ndarray, linkage_matrix: np.ndarray) -> int:
38
+ n = len(embeddings)
39
+ if n <= 2:
40
+ return n
41
+
42
+ best_k = self.min_speakers
43
+ best_score = -1.0
44
+ upper_k = min(self.max_speakers, n - 1)
45
+
46
+ for k in range(max(2, self.min_speakers), upper_k + 1):
47
+ labels = fcluster(linkage_matrix, k, criterion="maxclust")
48
+ if len(np.unique(labels)) < 2:
49
+ continue
50
+ try:
51
+ score = silhouette_score(embeddings, labels, metric="cosine")
52
+ if score > best_score:
53
+ best_score = score
54
+ best_k = k
55
+ except Exception:
56
+ continue
57
+
58
+ logger.info(f"Optimal speaker count: {best_k} (silhouette={best_score:.4f})")
59
+ return best_k
60
+
61
+ def cluster(
62
+ self,
63
+ embeddings: np.ndarray,
64
+ num_speakers: Optional[int] = None,
65
+ ) -> np.ndarray:
66
+ """Cluster embeddings into speaker identities."""
67
+ n = len(embeddings)
68
+
69
+ if n == 0:
70
+ return np.array([], dtype=int)
71
+ if n == 1:
72
+ return np.array([0], dtype=int)
73
+
74
+ dist_matrix = self._cosine_distance_matrix(embeddings)
75
+ condensed = squareform(dist_matrix, checks=False)
76
+ Z = linkage(condensed, method=self.linkage_method)
77
+
78
+ if num_speakers is not None:
79
+ k = max(1, min(num_speakers, n))
80
+ else:
81
+ k = self._estimate_num_speakers(embeddings, Z)
82
+
83
+ labels = fcluster(Z, k, criterion="maxclust") - 1
84
+ return labels.astype(int)
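Read in isolation, `cluster()` is: pairwise cosine distances → condensed matrix → average-linkage dendrogram → cut at k clusters. A minimal self-contained sketch of those same steps on synthetic 3-dim "embeddings" (not real ECAPA-TDNN output):

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform

rng = np.random.default_rng(0)
# Two synthetic "speakers": unit vectors scattered around orthogonal axes.
a = rng.normal([1.0, 0.0, 0.0], 0.05, size=(5, 3))
b = rng.normal([0.0, 1.0, 0.0], 0.05, size=(5, 3))
emb = np.vstack([a, b])
emb /= np.linalg.norm(emb, axis=1, keepdims=True)  # L2-normalize rows

# Same pipeline as SpeakerClusterer.cluster with num_speakers=2:
dist = np.clip(1.0 - emb @ emb.T, 0.0, 2.0)        # cosine distance matrix
Z = linkage(squareform(dist, checks=False), method="average")
labels = fcluster(Z, 2, criterion="maxclust") - 1   # 0-based speaker ids

# labels[:5] share one id and labels[5:] the other: the two
# synthetic speakers are recovered exactly.
```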
85
+
86
+ def merge_consecutive_same_speaker(
87
+ self,
88
+ segments: List[Tuple[float, float]],
89
+ labels: np.ndarray,
90
+ gap_tolerance: float = 0.3,
91
+ ) -> List[Tuple[float, float, int]]:
92
+ """Merge consecutive segments assigned to the same speaker."""
93
+ if not segments:
94
+ return []
95
+
96
+ merged = []
97
+ current_start, current_end = segments[0]
98
+ current_label = labels[0]
99
+
100
+ for i in range(1, len(segments)):
101
+ seg_start, seg_end = segments[i]
102
+ seg_label = labels[i]
103
+ gap = seg_start - current_end
104
+
105
+ if seg_label == current_label and gap <= gap_tolerance:
106
+ current_end = seg_end
107
+ else:
108
+ merged.append((current_start, current_end, int(current_label)))
109
+ current_start, current_end = seg_start, seg_end
110
+ current_label = seg_label
111
+
112
+ merged.append((current_start, current_end, int(current_label)))
113
+ return merged
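The merge pass is a single linear scan; on plain Python lists (no numpy needed, hypothetical segment values) it behaves like this:

```python
# Sketch of merge_consecutive_same_speaker: extend the current turn while
# the label matches and the silence gap stays within tolerance.
segments = [(0.0, 1.0), (1.1, 2.0), (2.8, 3.5), (3.6, 4.0)]
labels = [0, 0, 0, 1]
gap_tolerance = 0.3

merged = []
cur_s, cur_e = segments[0]
cur_l = labels[0]
for (s, e), l in zip(segments[1:], labels[1:]):
    if l == cur_l and s - cur_e <= gap_tolerance:
        cur_e = e  # same speaker, small gap: extend current turn
    else:
        merged.append((cur_s, cur_e, cur_l))
        cur_s, cur_e, cur_l = s, e, l
merged.append((cur_s, cur_e, cur_l))

print(merged)  # [(0.0, 2.0, 0), (2.8, 3.5, 0), (3.6, 4.0, 1)]
```

The second segment is absorbed (gap 0.1 s ≤ 0.3 s), the third stays separate despite the same label (gap 0.8 s), and the label change splits the last one.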
models/embedder.py ADDED
@@ -0,0 +1,148 @@
1
+ """
2
+ Speaker Embedding Extraction using ECAPA-TDNN architecture via SpeechBrain.
3
+ Handles audio preprocessing, feature extraction, and L2-normalized embeddings.
4
+ """
5
+
6
+ import os
7
+ import torch
8
+ import torchaudio
9
+ import numpy as np
10
+ from pathlib import Path
11
+ from typing import Union, List, Tuple
12
+ from loguru import logger
13
+
14
+
15
+ class EcapaTDNNEmbedder:
16
+ """
17
+ Speaker embedding extractor using ECAPA-TDNN architecture.
18
+ Produces 192-dim L2-normalized speaker embeddings per audio segment.
19
+ """
20
+
21
+ MODEL_SOURCE = "speechbrain/spkrec-ecapa-voxceleb"
22
+ SAMPLE_RATE = 16000
23
+ EMBEDDING_DIM = 192
24
+
25
+ def __init__(self, device: str = "auto", cache_dir: str = "./model_cache"):
26
+ self.device = self._resolve_device(device)
27
+ self.cache_dir = Path(cache_dir)
28
+ self.cache_dir.mkdir(parents=True, exist_ok=True)
29
+ self._model = None
30
+ logger.info(f"EcapaTDNNEmbedder initialized on device: {self.device}")
31
+
32
+ def _resolve_device(self, device: str) -> str:
33
+ if device == "auto":
34
+ return "cuda" if torch.cuda.is_available() else "cpu"
35
+ return device
36
+
37
+ def _load_model(self):
38
+ if self._model is not None:
39
+ return
40
+
41
+ try:
42
+ import speechbrain.utils.fetching as _fetching
43
+ import shutil as _shutil
44
+ from pathlib import Path as _Path
45
+
46
+ def _patched_link(src, dst, local_strategy):  # replace symlinking with copying (symlinks can fail on read-only Spaces filesystems)
47
+ dst = _Path(dst)
48
+ src = _Path(src)
49
+ dst.parent.mkdir(parents=True, exist_ok=True)
50
+ if dst.exists() or dst.is_symlink():
51
+ dst.unlink()
52
+ _shutil.copy2(str(src), str(dst))
53
+
54
+ _fetching.link_with_strategy = _patched_link
55
+
56
+ from speechbrain.inference.classifiers import EncoderClassifier
57
+ logger.info(f"Loading ECAPA-TDNN from {self.MODEL_SOURCE}...")
58
+
59
+ savedir = str(self.cache_dir / "ecapa_tdnn")
60
+ import os
61
+ os.makedirs(savedir, exist_ok=True)
62
+
63
+ self._model = EncoderClassifier.from_hparams(
64
+ source=self.MODEL_SOURCE,
65
+ savedir=savedir,
66
+ run_opts={"device": self.device},
67
+ )
68
+ self._model.eval()
69
+ logger.success("ECAPA-TDNN model loaded successfully.")
70
+ except ImportError:
71
+ raise ImportError("SpeechBrain not installed. Run: pip install speechbrain")
72
+
73
+ def preprocess_audio(
74
+ self, audio: Union[np.ndarray, torch.Tensor], sample_rate: int
75
+ ) -> torch.Tensor:
76
+ """Resample and normalize audio to 16kHz mono float32 tensor."""
77
+ if isinstance(audio, np.ndarray):
78
+ audio = torch.from_numpy(audio).float()
79
+
80
+ if audio.dim() == 1:
81
+ audio = audio.unsqueeze(0)
82
+
83
+ if audio.shape[0] > 1:
84
+ audio = audio.mean(dim=0, keepdim=True)
85
+
86
+ if sample_rate != self.SAMPLE_RATE:
87
+ resampler = torchaudio.transforms.Resample(
88
+ orig_freq=sample_rate, new_freq=self.SAMPLE_RATE
89
+ )
90
+ audio = resampler(audio)
91
+
92
+ max_val = audio.abs().max()
93
+ if max_val > 0:
94
+ audio = audio / max_val
95
+
96
+ return audio.squeeze(0)
97
+
98
+ def extract_embedding(self, audio: torch.Tensor) -> np.ndarray:
99
+ """
100
+ Extract an ECAPA-TDNN embedding from a preprocessed audio tensor.
101
+ Returns L2-normalized embedding of shape (192,)
102
+ """
103
+ self._load_model()
104
+
105
+ with torch.no_grad():
106
+ audio_batch = audio.unsqueeze(0).to(self.device)
107
+ lengths = torch.tensor([1.0]).to(self.device)
108
+ embedding = self._model.encode_batch(audio_batch, lengths)
109
+ embedding = embedding.squeeze().cpu().numpy()
110
+
111
+ norm = np.linalg.norm(embedding)
112
+ if norm > 0:
113
+ embedding = embedding / norm
114
+
115
+ return embedding
116
+
117
+ def extract_embeddings_from_segments(
118
+ self,
119
+ audio: torch.Tensor,
120
+ sample_rate: int,
121
+ segments: List[Tuple[float, float]],
122
+ min_duration: float = 0.5,
123
+ ) -> Tuple[np.ndarray, List[Tuple[float, float]]]:
124
+ """Extract embeddings for a list of (start, end) time segments."""
125
+ processed = self.preprocess_audio(audio, sample_rate)
126
+ embeddings = []
127
+ valid_segments = []
128
+
129
+ for start, end in segments:
130
+ duration = end - start
131
+ if duration < min_duration:
132
+ continue
133
+
134
+ start_sample = int(start * self.SAMPLE_RATE)
135
+ end_sample = int(end * self.SAMPLE_RATE)
136
+ segment_audio = processed[start_sample:end_sample]
137
+
138
+ if segment_audio.shape[0] == 0:
139
+ continue
140
+
141
+ emb = self.extract_embedding(segment_audio)
142
+ embeddings.append(emb)
143
+ valid_segments.append((start, end))
144
+
145
+ if not embeddings:
146
+ return np.empty((0, self.EMBEDDING_DIM)), []
147
+
148
+ return np.stack(embeddings), valid_segments
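`extract_embeddings_from_segments` drops segments shorter than `min_duration` and converts the survivors from seconds to sample ranges; a standalone sketch of that bookkeeping (hypothetical segment values):

```python
SAMPLE_RATE = 16000
min_duration = 0.5
segments = [(0.2, 1.5), (1.5, 1.8), (2.0, 6.0)]  # (start, end) in seconds

kept = []
for start, end in segments:
    if end - start < min_duration:
        continue  # too short for a stable speaker embedding
    s, e = int(start * SAMPLE_RATE), int(end * SAMPLE_RATE)
    kept.append((start, end, e - s))  # sample count per kept segment

print(kept)  # [(0.2, 1.5, 20800), (2.0, 6.0, 64000)]
```

The 0.3 s middle segment is filtered out, so the clustering stage only ever sees embeddings paired with the surviving `(start, end)` tuples.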
requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cpu
2
+ torch==2.4.0+cpu
3
+ torchaudio==2.4.0+cpu
4
+ speechbrain>=1.0.0
5
+ pyannote.audio>=3.1.0
6
+ transformers>=4.35.0
7
+ fastapi>=0.104.0
8
+ uvicorn[standard]>=0.24.0
9
+ python-multipart>=0.0.6
10
+ websockets>=12.0
11
+ numpy==1.26.4
12
+ scipy>=1.11.0
13
+ scikit-learn>=1.3.0
14
+ soundfile>=0.12.1
15
+ httpx>=0.25.0
16
+ python-dotenv>=1.0.0
17
+ loguru>=0.7.0
18
+ huggingface_hub==0.23.0
static/index.html ADDED
@@ -0,0 +1,623 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0"/>
6
+ <title>Speaker Diarization System</title>
7
+ <style>
8
+ @import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@300;400;500;700&family=Space+Grotesk:wght@300;400;600;700&display=swap');
9
+
10
+ :root {
11
+ --bg: #090c10;
12
+ --surface: #0f1318;
13
+ --surface2: #151b23;
14
+ --border: #1e2730;
15
+ --accent: #00d4ff;
16
+ --accent2: #7c3aed;
17
+ --green: #22d3a0;
18
+ --yellow: #f59e0b;
19
+ --red: #ef4444;
20
+ --text: #e2e8f0;
21
+ --muted: #64748b;
22
+ --font-mono: 'JetBrains Mono', monospace;
23
+ --font-sans: 'Space Grotesk', sans-serif;
24
+ }
25
+
26
+ * { margin: 0; padding: 0; box-sizing: border-box; }
27
+
28
+ body {
29
+ font-family: var(--font-sans);
30
+ background: var(--bg);
31
+ color: var(--text);
32
+ min-height: 100vh;
33
+ overflow-x: hidden;
34
+ }
35
+
36
+ /* Grid bg */
37
+ body::before {
38
+ content: '';
39
+ position: fixed;
40
+ inset: 0;
41
+ background-image:
42
+ linear-gradient(rgba(0, 212, 255, 0.03) 1px, transparent 1px),
43
+ linear-gradient(90deg, rgba(0, 212, 255, 0.03) 1px, transparent 1px);
44
+ background-size: 40px 40px;
45
+ pointer-events: none;
46
+ z-index: 0;
47
+ }
48
+
49
+ .container {
50
+ position: relative;
51
+ z-index: 1;
52
+ max-width: 1100px;
53
+ margin: 0 auto;
54
+ padding: 2rem 1.5rem;
55
+ }
56
+
57
+ header {
58
+ text-align: center;
59
+ margin-bottom: 3rem;
60
+ }
61
+
62
+ .badge {
63
+ display: inline-block;
64
+ background: rgba(0, 212, 255, 0.1);
65
+ border: 1px solid rgba(0, 212, 255, 0.3);
66
+ color: var(--accent);
67
+ font-family: var(--font-mono);
68
+ font-size: 0.72rem;
69
+ letter-spacing: 0.15em;
70
+ padding: 4px 12px;
71
+ border-radius: 100px;
72
+ margin-bottom: 1rem;
73
+ }
74
+
75
+ h1 {
76
+ font-size: clamp(2rem, 5vw, 3.2rem);
77
+ font-weight: 700;
78
+ letter-spacing: -0.02em;
79
+ background: linear-gradient(135deg, #fff 30%, var(--accent));
80
+ -webkit-background-clip: text;
81
+ -webkit-text-fill-color: transparent;
82
+ line-height: 1.15;
83
+ }
84
+
85
+ .subtitle {
86
+ color: var(--muted);
87
+ font-size: 1rem;
88
+ margin-top: 0.75rem;
89
+ font-weight: 300;
90
+ }
91
+
92
+ /* Cards */
93
+ .card {
94
+ background: var(--surface);
95
+ border: 1px solid var(--border);
96
+ border-radius: 12px;
97
+ padding: 1.5rem;
98
+ margin-bottom: 1.5rem;
99
+ }
100
+
101
+ .card-title {
102
+ font-size: 0.8rem;
103
+ font-family: var(--font-mono);
104
+ letter-spacing: 0.12em;
105
+ color: var(--accent);
106
+ text-transform: uppercase;
107
+ margin-bottom: 1.2rem;
108
+ display: flex;
109
+ align-items: center;
110
+ gap: 8px;
111
+ }
112
+
113
+ .card-title::before {
114
+ content: '▸';
115
+ font-size: 0.9rem;
116
+ }
117
+
118
+ /* Upload zone */
119
+ .upload-zone {
120
+ border: 2px dashed var(--border);
121
+ border-radius: 10px;
122
+ padding: 2.5rem;
123
+ text-align: center;
124
+ cursor: pointer;
125
+ transition: all 0.25s;
126
+ position: relative;
127
+ }
128
+
129
+ .upload-zone:hover, .upload-zone.drag-over {
130
+ border-color: var(--accent);
131
+ background: rgba(0, 212, 255, 0.04);
132
+ }
133
+
134
+ .upload-zone input[type="file"] {
135
+ position: absolute;
136
+ inset: 0;
137
+ opacity: 0;
138
+ cursor: pointer;
139
+ }
140
+
141
+ .upload-icon {
142
+ font-size: 2.5rem;
143
+ margin-bottom: 0.75rem;
144
+ opacity: 0.6;
145
+ }
146
+
147
+ .upload-text {
148
+ color: var(--muted);
149
+ font-size: 0.9rem;
150
+ }
151
+
152
+ .upload-text strong {
153
+ color: var(--accent);
154
+ }
155
+
156
+ /* Controls */
157
+ .controls {
158
+ display: grid;
159
+ grid-template-columns: 1fr 1fr auto;
160
+ gap: 1rem;
161
+ margin-top: 1rem;
162
+ align-items: end;
163
+ }
164
+
165
+ .field label {
166
+ display: block;
167
+ font-size: 0.75rem;
168
+ font-family: var(--font-mono);
169
+ color: var(--muted);
170
+ margin-bottom: 6px;
171
+ letter-spacing: 0.08em;
172
+ }
173
+
174
+ .field input, .field select {
175
+ width: 100%;
176
+ background: var(--surface2);
177
+ border: 1px solid var(--border);
178
+ color: var(--text);
179
+ font-family: var(--font-mono);
180
+ font-size: 0.9rem;
181
+ padding: 10px 12px;
182
+ border-radius: 8px;
183
+ outline: none;
184
+ transition: border-color 0.2s;
185
+ }
186
+
187
+ .field input:focus, .field select:focus {
188
+ border-color: var(--accent);
189
+ }
190
+
191
+ .btn-primary {
192
+ background: var(--accent);
193
+ color: #000;
194
+ font-family: var(--font-sans);
195
+ font-weight: 700;
196
+ font-size: 0.9rem;
197
+ border: none;
198
+ padding: 10px 24px;
199
+ border-radius: 8px;
200
+ cursor: pointer;
201
+ transition: all 0.2s;
202
+ white-space: nowrap;
203
+ }
204
+
205
+ .btn-primary:hover { filter: brightness(1.1); transform: translateY(-1px); }
206
+ .btn-primary:disabled { opacity: 0.4; cursor: not-allowed; transform: none; }
207
+
208
+ /* Progress */
209
+ .progress-bar {
210
+ height: 4px;
211
+ background: var(--border);
212
+ border-radius: 99px;
213
+ overflow: hidden;
214
+ margin-top: 1rem;
215
+ display: none;
216
+ }
217
+
218
+ .progress-fill {
219
+ height: 100%;
220
+ background: linear-gradient(90deg, var(--accent), var(--accent2));
221
+ width: 0%;
222
+ transition: width 0.4s;
223
+ animation: progress-pulse 1.5s ease-in-out infinite;
224
+ }
225
+
226
+ @keyframes progress-pulse {
227
+ 0%, 100% { opacity: 1; }
228
+ 50% { opacity: 0.6; }
229
+ }
230
+
231
+ /* Stats row */
232
+ .stats-row {
233
+ display: grid;
234
+ grid-template-columns: repeat(4, 1fr);
235
+ gap: 1rem;
236
+ margin-bottom: 1.5rem;
237
+ }
238
+
239
+ .stat {
240
+ background: var(--surface);
241
+ border: 1px solid var(--border);
242
+ border-radius: 10px;
243
+ padding: 1rem 1.2rem;
244
+ }
245
+
246
+ .stat-val {
247
+ font-family: var(--font-mono);
248
+ font-size: 1.8rem;
249
+ font-weight: 700;
250
+ color: var(--accent);
251
+ }
252
+
253
+ .stat-label {
254
+ font-size: 0.73rem;
255
+ color: var(--muted);
256
+ margin-top: 4px;
257
+ letter-spacing: 0.06em;
258
+ }
259
+
260
+ /* Timeline */
261
+ #timeline-container {
262
+ margin-bottom: 1rem;
263
+ }
264
+
265
+ .timeline-ruler {
266
+ display: flex;
267
+ justify-content: space-between;
268
+ font-family: var(--font-mono);
269
+ font-size: 0.68rem;
270
+ color: var(--muted);
271
+ margin-bottom: 6px;
272
+ padding: 0 2px;
273
+ }
274
+
275
+ .timeline-track {
276
+ height: 48px;
277
+ background: var(--surface2);
278
+ border-radius: 8px;
279
+ position: relative;
280
+ overflow: hidden;
281
+ border: 1px solid var(--border);
282
+ margin-bottom: 8px;
283
+ }
284
+
285
+ .track-label {
286
+ font-family: var(--font-mono);
287
+ font-size: 0.68rem;
288
+ color: var(--muted);
289
+ position: absolute;
290
+ left: 8px;
291
+ top: 50%;
292
+ transform: translateY(-50%);
293
+ z-index: 2;
294
+ text-shadow: 0 0 8px var(--bg);
295
+ }
296
+
297
+ .timeline-segment {
298
+ position: absolute;
299
+ height: 100%;
300
+ border-radius: 4px;
301
+ opacity: 0.9;
302
+ cursor: pointer;
303
+ transition: opacity 0.15s, filter 0.15s;
304
+ display: flex;
305
+ align-items: center;
306
+ justify-content: center;
307
+ font-family: var(--font-mono);
308
+ font-size: 0.65rem;
309
+ color: rgba(0,0,0,0.85);
310
+ font-weight: 700;
311
+ overflow: hidden;
312
+ white-space: nowrap;
313
+ }
314
+
315
+ .timeline-segment:hover {
316
+ opacity: 1;
317
+ filter: brightness(1.15);
318
+ z-index: 5;
319
+ }
320
+
321
+ /* Segment table */
322
+ .seg-table {
323
+ width: 100%;
324
+ border-collapse: collapse;
325
+ font-family: var(--font-mono);
326
+ font-size: 0.82rem;
327
+ }
328
+
329
+ .seg-table th {
330
+ text-align: left;
331
+ padding: 8px 12px;
332
+ font-size: 0.7rem;
333
+ letter-spacing: 0.1em;
334
+ color: var(--muted);
335
+ border-bottom: 1px solid var(--border);
336
+ }
337
+
338
+ .seg-table td {
339
+ padding: 9px 12px;
340
+ border-bottom: 1px solid rgba(255,255,255,0.04);
341
+ vertical-align: middle;
342
+ }
343
+
344
+ .seg-table tr:last-child td { border-bottom: none; }
345
+ .seg-table tr:hover td { background: rgba(255,255,255,0.02); }
346
+
347
+ .speaker-dot {
348
+ display: inline-block;
349
+ width: 8px;
350
+ height: 8px;
351
+ border-radius: 50%;
352
+ margin-right: 8px;
353
+ }
354
+
355
+ /* Log */
356
+ #log {
357
+ font-family: var(--font-mono);
358
+ font-size: 0.78rem;
359
+ color: var(--muted);
360
+ background: var(--surface2);
361
+ border-radius: 8px;
362
+ padding: 1rem;
363
+ max-height: 160px;
364
+ overflow-y: auto;
365
+ line-height: 1.7;
366
+ }
367
+
368
+ .log-info { color: var(--accent); }
369
+ .log-success{ color: var(--green); }
370
+ .log-error { color: var(--red); }
371
+ .log-warn { color: var(--yellow); }
372
+
373
+ .hidden { display: none !important; }
374
+
375
+ @media (max-width: 640px) {
376
+ .controls { grid-template-columns: 1fr; }
377
+ .stats-row { grid-template-columns: 1fr 1fr; }
378
+ }
379
+ </style>
380
+ </head>
381
+ <body>
382
+ <div class="container">
383
+
384
+ <header>
385
+ <div class="badge">ECAPA-TDNN + AHC · FASTAPI</div>
386
+ <h1>Speaker Diarization System</h1>
387
+ <p class="subtitle">Who spoke when — multi-speaker audio segmentation & labeling</p>
388
+ </header>
389
+
390
+ <!-- Upload Card -->
391
+ <div class="card">
392
+ <div class="card-title">Audio Input</div>
393
+ <div class="upload-zone" id="dropzone">
394
+ <input type="file" id="audioFile" accept=".wav,.mp3,.flac,.ogg,.m4a,.webm" />
395
+ <div class="upload-icon">🎙</div>
396
+ <div class="upload-text"><strong>Drop audio file</strong> or click to browse</div>
397
+ <div class="upload-text" style="margin-top:4px;font-size:0.78rem;" id="filename-display">WAV · MP3 · FLAC · OGG · M4A · WEBM</div>
398
+ </div>
399
+
400
+ <div class="controls">
401
+ <div class="field">
402
+ <label>API ENDPOINT</label>
403
+ <input type="text" id="apiUrl" value="http://localhost:8000/diarize" />
404
+ </div>
405
+ <div class="field">
406
+ <label>SPEAKERS (blank = auto)</label>
407
+ <input type="number" id="numSpeakers" min="1" max="20" placeholder="auto-detect" />
408
+ </div>
409
+ <button class="btn-primary" id="runBtn" onclick="runDiarization()" disabled>
410
+ ▶ Run
411
+ </button>
412
+ </div>
413
+
414
+ <div class="progress-bar" id="progressBar">
415
+ <div class="progress-fill" id="progressFill"></div>
416
+ </div>
417
+ </div>
418
+
419
+ <!-- Results (hidden until run) -->
420
+ <div id="results" class="hidden">
421
+
422
+ <div class="stats-row" id="statsRow"></div>
423
+
424
+ <!-- Timeline -->
425
+ <div class="card">
426
+ <div class="card-title">Speaker Timeline</div>
427
+ <div class="timeline-ruler" id="timelineRuler"></div>
428
+ <div id="timelineTracks"></div>
429
+ </div>
430
+
431
+ <!-- Segment Table -->
432
+ <div class="card">
433
+ <div class="card-title">Segments</div>
434
+ <div style="overflow-x:auto;">
435
+ <table class="seg-table">
436
+ <thead>
437
+ <tr>
438
+ <th>#</th><th>SPEAKER</th><th>START</th><th>END</th><th>DURATION</th>
439
+ </tr>
440
+ </thead>
441
+ <tbody id="segTableBody"></tbody>
442
+ </table>
443
+ </div>
444
+ </div>
445
+
446
+ </div>
447
+
448
+ <!-- Log -->
449
+ <div class="card">
450
+ <div class="card-title">Log</div>
451
+ <div id="log"><span class="log-info">// Ready. Upload an audio file to begin.</span></div>
452
+ </div>
453
+
454
+ </div>
455
+
456
+ <script>
457
+ const SPEAKER_COLORS = [
458
+ '#00d4ff','#7c3aed','#22d3a0','#f59e0b',
459
+ '#ec4899','#3b82f6','#84cc16','#f97316',
460
+ '#06b6d4','#a855f7',
461
+ ];
462
+
463
+ let selectedFile = null;
464
+
465
+ // ── File Handling ──────────────────────────────────────────────────────────
466
+ document.getElementById('audioFile').addEventListener('change', (e) => {
467
+ const file = e.target.files[0];
468
+ if (!file) return;
469
+ selectedFile = file;
470
+ document.getElementById('filename-display').textContent = `📁 ${file.name} (${(file.size/1024/1024).toFixed(2)} MB)`;
471
+ document.getElementById('runBtn').disabled = false;
472
+ log(`File selected: ${file.name}`, 'info');
473
+ });
474
+
475
+ // Drag & Drop
476
+ const dz = document.getElementById('dropzone');
477
+ dz.addEventListener('dragover', e => { e.preventDefault(); dz.classList.add('drag-over'); });
478
+ dz.addEventListener('dragleave', () => dz.classList.remove('drag-over'));
479
+ dz.addEventListener('drop', e => {
480
+ e.preventDefault();
481
+ dz.classList.remove('drag-over');
482
+ const file = e.dataTransfer.files[0];
483
+ if (file) {
484
+ document.getElementById('audioFile').files = e.dataTransfer.files;
485
+ document.getElementById('audioFile').dispatchEvent(new Event('change'));
486
+ }
487
+ });
488
+
489
+ // ── Log ────────────────────────────────────────────────────────────────────
490
+ function log(msg, type = '') {
491
+ const el = document.getElementById('log');
492
+ const cls = type ? `log-${type}` : '';
493
+ const ts = new Date().toLocaleTimeString('en', { hour12: false });
494
+ el.innerHTML += `<br><span class="${cls}">[${ts}] ${msg}</span>`;
495
+ el.scrollTop = el.scrollHeight;
496
+ }
497
+
498
+ // ── Run Diarization ────────────────────────────────────────────────────────
499
+ async function runDiarization() {
500
+ if (!selectedFile) return;
501
+
502
+ const btn = document.getElementById('runBtn');
503
+ const pb = document.getElementById('progressBar');
504
+ const pf = document.getElementById('progressFill');
505
+
506
+ btn.disabled = true;
507
+ pb.style.display = 'block';
508
+ pf.style.width = '20%';
509
+ document.getElementById('results').classList.add('hidden');
510
+ log('Uploading audio and running diarization...', 'info');
511
+
512
+ const formData = new FormData();
513
+ formData.append('file', selectedFile);
514
+ const ns = document.getElementById('numSpeakers').value;
515
+ if (ns) formData.append('num_speakers', ns);
516
+
517
+ const url = document.getElementById('apiUrl').value;
518
+
519
+ try {
520
+ pf.style.width = '50%';
521
+ const resp = await fetch(url, { method: 'POST', body: formData });
522
+
523
+ pf.style.width = '90%';
524
+
525
+ if (!resp.ok) {
526
+ const err = await resp.json().catch(() => ({ detail: resp.statusText }));
527
+ throw new Error(err.detail || `HTTP ${resp.status}`);
528
+ }
529
+
530
+ const data = await resp.json();
531
+ pf.style.width = '100%';
532
+
533
+ log(`Done — ${data.num_speakers} speaker(s), ${data.segments.length} segments, ${data.processing_time.toFixed(2)}s`, 'success');
534
+ renderResults(data);
535
+
536
+ } catch (e) {
537
+ log(`Error: ${e.message}`, 'error');
538
+ } finally {
539
+ setTimeout(() => { pb.style.display = 'none'; pf.style.width = '0%'; }, 800);
540
+ btn.disabled = false;
541
+ }
542
+ }
543
+
544
+ // ── Render Results ─────────────────────────────────────────────────────────
545
+ function renderResults(data) {
546
+ document.getElementById('results').classList.remove('hidden');
547
+
548
+ // Stats
549
+ const stats = [
550
+ { val: data.num_speakers, label: 'SPEAKERS' },
551
+ { val: data.segments.length, label: 'SEGMENTS' },
552
+ { val: data.audio_duration.toFixed(1) + 's', label: 'DURATION' },
553
+ { val: data.processing_time.toFixed(2) + 's', label: 'PROC TIME' },
554
+ ];
555
+ document.getElementById('statsRow').innerHTML = stats.map(s =>
556
+ `<div class="stat">
557
+ <div class="stat-val">${s.val}</div>
558
+ <div class="stat-label">${s.label}</div>
559
+ </div>`
560
+ ).join('');
561
+
562
+ // Build speaker→color map
563
+ const colorMap = {};
564
+ data.speakers.forEach((sp, i) => {
565
+ colorMap[sp] = SPEAKER_COLORS[i % SPEAKER_COLORS.length];
566
+ });
567
+
568
+ // Timeline
569
+ const duration = data.audio_duration;
570
+ const ruler = document.getElementById('timelineRuler');
571
+ const ticks = 8;
572
+ ruler.innerHTML = Array.from({ length: ticks + 1 }, (_, i) =>
573
+ `<span>${fmtTime(duration * i / ticks)}</span>`
574
+ ).join('');
575
+
576
+ // One track per speaker
577
+ const tracksEl = document.getElementById('timelineTracks');
578
+ tracksEl.innerHTML = '';
579
+
580
+ data.speakers.forEach(sp => {
581
+ const track = document.createElement('div');
582
+ track.className = 'timeline-track';
583
+ track.innerHTML = `<span class="track-label">${sp}</span>`;
584
+
585
+ const spSegs = data.segments.filter(s => s.speaker === sp);
586
+ spSegs.forEach(seg => {
587
+ const left = (seg.start / duration) * 100;
588
+ const width = (seg.duration / duration) * 100;
589
+ const seg_el = document.createElement('div');
590
+ seg_el.className = 'timeline-segment';
591
+ seg_el.style.cssText = `left:${left}%;width:${Math.max(width, 0.3)}%;background:${colorMap[sp]};`;
592
+ seg_el.title = `${sp}: ${seg.start.toFixed(2)}s – ${seg.end.toFixed(2)}s`;
593
+ if (width > 3) seg_el.textContent = fmtTime(seg.duration);
594
+ track.appendChild(seg_el);
595
+ });
596
+
597
+ tracksEl.appendChild(track);
598
+ });
599
+
600
+ // Segment table
601
+ const tbody = document.getElementById('segTableBody');
602
+ tbody.innerHTML = data.segments.map((seg, i) =>
603
+ `<tr>
604
+ <td style="color:var(--muted)">${i + 1}</td>
605
+ <td>
606
+ <span class="speaker-dot" style="background:${colorMap[seg.speaker]}"></span>
607
+ ${seg.speaker}
608
+ </td>
609
+ <td>${seg.start.toFixed(3)}</td>
610
+ <td>${seg.end.toFixed(3)}</td>
611
+ <td>${seg.duration.toFixed(3)}s</td>
612
+ </tr>`
613
+ ).join('');
614
+ }
615
+
616
+ function fmtTime(sec) {
617
+ const m = Math.floor(sec / 60);
618
+ const s = (sec % 60).toFixed(1).padStart(4, '0');
619
+ return `${m}:${s}`;
620
+ }
621
+ </script>
622
+ </body>
623
+ </html>
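The timeline rendering in `renderResults` places each segment as an absolutely-positioned div whose CSS `left`/`width` are percentages of the track relative to total audio duration. The arithmetic, sketched in Python with hypothetical values:

```python
# A 15 s segment starting at 30 s in a 120 s clip maps to
# left:25% width:12.5% of the speaker's timeline track.
duration = 120.0                      # total clip length, seconds
seg = {"start": 30.0, "end": 45.0}

left = seg["start"] / duration * 100.0
width = (seg["end"] - seg["start"]) / duration * 100.0

print(f"left:{left}%;width:{width}%")  # left:25.0%;width:12.5%
```

The frontend additionally clamps `width` to a 0.3% minimum so very short segments stay visible and clickable.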
utils/__init__.py ADDED
File without changes
utils/audio.py ADDED
@@ -0,0 +1,68 @@
1
+ """Audio utility functions for the diarization pipeline."""
2
+
3
+ import io
4
+ import numpy as np
5
+ import torch
6
+ import torchaudio
7
+ from pathlib import Path
8
+ from typing import Union, Tuple, Iterator
9
+ from loguru import logger
10
+
11
+ SUPPORTED_FORMATS = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".webm"}
12
+ TARGET_SAMPLE_RATE = 16000
13
+
14
+
15
+ def load_audio(source: Union[str, Path, bytes, io.BytesIO], target_sr: int = TARGET_SAMPLE_RATE) -> Tuple[torch.Tensor, int]:
16
+ if isinstance(source, bytes):
17
+ source = io.BytesIO(source)
18
+ waveform, sr = torchaudio.load(source)
19
+ if waveform.shape[0] > 1:
20
+ waveform = waveform.mean(dim=0, keepdim=True)
21
+ if sr != target_sr:
22
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
23
+ waveform = resampler(waveform)
24
+ sr = target_sr
25
+ return waveform.squeeze(0), sr
26
+
27
+
28
+ def pcm_bytes_to_tensor(data: bytes, dtype=np.float32) -> torch.Tensor:
29
+ arr = np.frombuffer(data, dtype=dtype).copy()
30
+ return torch.from_numpy(arr)
31
+
32
+
33
+ def chunk_audio(audio: torch.Tensor, sample_rate: int, chunk_duration: float = 30.0, overlap: float = 1.0) -> Iterator[Tuple[torch.Tensor, float]]:
34
+ chunk_samples = int(chunk_duration * sample_rate)
35
+ step_samples = int((chunk_duration - overlap) * sample_rate)
36
+ n = len(audio)
37
+ for start in range(0, n, step_samples):
38
+ end = min(start + chunk_samples, n)
39
+ yield audio[start:end], start / sample_rate
40
+ if end == n:
41
+ break
42
+
43
+
44
+ def format_timestamp(seconds: float) -> str:
45
+ hours = int(seconds // 3600)
46
+ minutes = int((seconds % 3600) // 60)
47
+ secs = seconds % 60
48
+ return f"{hours:02d}:{minutes:02d}:{secs:06.3f}"
49
+
50
+
51
+ def segments_to_rttm(segments, audio_name: str = "audio") -> str:
52
+ lines = []
53
+ for seg in segments:
54
+ duration = seg.end - seg.start
55
+ lines.append(
56
+ f"SPEAKER {audio_name} 1 {seg.start:.3f} {duration:.3f} "
57
+ f"<NA> <NA> {seg.speaker} <NA> <NA>"
58
+ )
59
+ return "\n".join(lines)
60
+
61
+
62
+ def segments_to_srt(segments) -> str:
63
+ lines = []
64
+ for i, seg in enumerate(segments, 1):
65
+ start = format_timestamp(seg.start).replace(".", ",")
66
+ end = format_timestamp(seg.end).replace(".", ",")
67
+ lines.append(f"{i}\n{start} --> {end}\n[{seg.speaker}]\n")
68
+ return "\n".join(lines)
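`chunk_audio` slides a fixed window with a small overlap so that speech spanning a chunk boundary appears in both neighbors. A standalone mirror of its arithmetic that yields sample bounds instead of tensors:

```python
def chunk_bounds(n_samples, sample_rate, chunk_duration=30.0, overlap=1.0):
    # Mirrors chunk_audio: step forward by (chunk - overlap) each iteration
    # and stop once the window reaches the end of the signal.
    chunk = int(chunk_duration * sample_rate)
    step = int((chunk_duration - overlap) * sample_rate)
    for start in range(0, n_samples, step):
        end = min(start + chunk, n_samples)
        yield start, end, start / sample_rate
        if end == n_samples:
            break

# 70 s of 16 kHz audio with 30 s chunks and 1 s overlap:
spans = [(s / 16000, e / 16000) for s, e, _ in chunk_bounds(70 * 16000, 16000)]
print(spans)  # [(0.0, 30.0), (29.0, 59.0), (58.0, 70.0)]
```

Each chunk starts 29 s after the previous one, and the final chunk is simply truncated at the end of the audio.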