Commit 6216b68 (verified) · tantk · Parent: 6804bc6

Upload folder using huggingface_hub

Files changed (4):

1. Dockerfile (+17 / -17)
2. README.md (+84 / -84)
3. app.py (+273 / -273)
4. requirements.txt (+13 / -12)

Dockerfile, README.md, and app.py are re-uploaded with identical contents; requirements.txt adds `torchvision`.
Dockerfile CHANGED
```dockerfile
FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-runtime

RUN apt-get update && apt-get install -y --no-install-recommends \
    ffmpeg libsndfile1 git \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY app.py .

# HF Inference Endpoints require port 80
EXPOSE 80

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "80"]
```
README.md CHANGED
---
tags:
- audio
- speaker-diarization
- speaker-embedding
- pyannote
- funasr
- meetingmind
library_name: custom
pipeline_tag: audio-classification
---

# MeetingMind GPU Service

GPU-accelerated speaker diarization and embedding extraction for the MeetingMind pipeline. Runs as an HF Inference Endpoint on a T4 GPU with scale-to-zero.

## API

The service also exposes `/transcribe` and `/transcribe/stream` (SSE) endpoints; see `app.py` for details.

### `GET /health`

Returns service status and GPU availability.

```bash
curl -H "Authorization: Bearer $HF_TOKEN" $ENDPOINT_URL/health
```

```json
{"status": "ok", "gpu_available": true}
```
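
Because the endpoint scales to zero, the first request after an idle period hits a cold start. A minimal readiness poll, sketched with `requests` (the retry cadence and deadline are arbitrary choices, not part of the service):

```python
import os
import time

import requests

ENDPOINT_URL = os.environ["ENDPOINT_URL"]  # e.g. https://xxxx.endpoints.huggingface.cloud
HF_TOKEN = os.environ["HF_TOKEN"]


def wait_until_ready(max_wait_s: float = 300.0, poll_s: float = 5.0) -> None:
    """Poll /health until the endpoint reports a GPU or the deadline passes."""
    deadline = time.monotonic() + max_wait_s
    while time.monotonic() < deadline:
        try:
            r = requests.get(
                f"{ENDPOINT_URL}/health",
                headers={"Authorization": f"Bearer {HF_TOKEN}"},
                timeout=10,
            )
            if r.ok and r.json().get("gpu_available"):
                return
        except requests.RequestException:
            pass  # endpoint may still be scaling up
        time.sleep(poll_s)
    raise TimeoutError("GPU endpoint did not become ready in time")
```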

### `POST /diarize`

Speaker diarization using pyannote v4. Accepts any audio format (FLAC, WAV, MP3, etc.).

```bash
curl -X POST \
  -H "Authorization: Bearer $HF_TOKEN" \
  -F audio=@meeting.flac \
  -F min_speakers=2 \
  -F max_speakers=6 \
  $ENDPOINT_URL/diarize
```

```json
{
  "segments": [
    {"speaker": "SPEAKER_00", "start": 0.5, "end": 3.2, "duration": 2.7},
    {"speaker": "SPEAKER_01", "start": 3.4, "end": 7.1, "duration": 3.7}
  ]
}
```
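
For programmatic use, any multipart-capable HTTP client works. A minimal sketch using Python's `requests` (the `ENDPOINT_URL`/`HF_TOKEN` environment variables and the timeout are placeholders):

```python
import os

import requests

ENDPOINT_URL = os.environ["ENDPOINT_URL"]
HF_TOKEN = os.environ["HF_TOKEN"]


def diarize(path: str, min_speakers: int = 2, max_speakers: int = 6) -> list[dict]:
    """Upload an audio file and return the speaker segments."""
    with open(path, "rb") as f:
        resp = requests.post(
            f"{ENDPOINT_URL}/diarize",
            headers={"Authorization": f"Bearer {HF_TOKEN}"},
            files={"audio": f},
            data={"min_speakers": min_speakers, "max_speakers": max_speakers},
            timeout=600,  # long meetings can take a while to diarize
        )
    resp.raise_for_status()
    return resp.json()["segments"]


for seg in diarize("meeting.flac"):
    print(f'{seg["speaker"]}: {seg["start"]:.1f}-{seg["end"]:.1f}s')
```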

### `POST /embed`

Speaker embedding extraction using FunASR CAM++. Returns L2-normalized 192-dim vectors for voiceprint matching.

```bash
curl -X POST \
  -H "Authorization: Bearer $HF_TOKEN" \
  -F audio=@meeting.flac \
  -F start_time=1.0 \
  -F end_time=5.0 \
  $ENDPOINT_URL/embed
```

```json
{"embedding": [0.012, -0.034, ...], "dim": 192}
```
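
Because the vectors are L2-normalized, cosine similarity reduces to a plain dot product. A minimal matching sketch (the 0.6 threshold is an illustrative assumption, not a tuned value):

```python
import numpy as np


def best_match(query: list[float], voiceprints: dict[str, list[float]],
               threshold: float = 0.6) -> str | None:
    """Return the enrolled speaker whose voiceprint best matches `query`.

    Embeddings are L2-normalized, so the dot product equals cosine
    similarity. Tune the threshold on your own enrollment data.
    """
    q = np.asarray(query)
    best_name, best_score = None, threshold
    for name, emb in voiceprints.items():
        score = float(q @ np.asarray(emb))
        if score > best_score:
            best_name, best_score = name, score
    return best_name
```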

## Environment Variables

| Variable | Default | Description |
|---|---|---|
| `HF_TOKEN` | (required) | Hugging Face token for pyannote model access |
| `PYANNOTE_MIN_SPEAKERS` | `1` | Minimum speakers for diarization |
| `PYANNOTE_MAX_SPEAKERS` | `10` | Maximum speakers for diarization |

## Architecture

- **Base image**: `pytorch/pytorch:2.4.0-cuda12.4-cudnn9-runtime`
- **Diarization**: pyannote/speaker-diarization-community-1 (~2GB VRAM)
- **Embeddings**: FunASR CAM++ sv_zh-cn_16k-common (~200MB)
- **Transcription**: mistralai/Voxtral-Mini-4B-Realtime-2602, lazy-loaded in fp16 on the first `/transcribe` request
- **Total VRAM**: ~3GB for diarization + embeddings (fits T4 16GB with headroom); the 4B-parameter fp16 transcription model adds roughly 8GB more once loaded
- **Scale-to-zero**: 15 min idle timeout (~$0.60/hr when active)
app.py CHANGED
```python
"""
Slim GPU service for HF Inference Endpoints.
Exposes /diarize, /embed, /transcribe, and /transcribe/stream endpoints.
"""

import io
import json
import logging
import os
import re
import threading
from contextlib import asynccontextmanager

import numpy as np
import soundfile as sf
import librosa
import torch
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import JSONResponse
from pydub import AudioSegment
from sse_starlette.sse import EventSourceResponse

logger = logging.getLogger("gpu_service")

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
HF_TOKEN = os.environ.get("HF_TOKEN", "")
PYANNOTE_MODEL = "pyannote/speaker-diarization-community-1"
FUNASR_MODEL = "iic/speech_campplus_sv_zh-cn_16k-common"
PYANNOTE_MIN_SPEAKERS = int(os.environ.get("PYANNOTE_MIN_SPEAKERS", "1"))
PYANNOTE_MAX_SPEAKERS = int(os.environ.get("PYANNOTE_MAX_SPEAKERS", "10"))
TARGET_SR = 16000

# ---------------------------------------------------------------------------
# Singletons
# ---------------------------------------------------------------------------
_diarize_pipeline = None
_embed_model = None
_voxtral_model = None
_voxtral_processor = None

VOXTRAL_MODEL_ID = "mistralai/Voxtral-Mini-4B-Realtime-2602"

# Markers to strip from Voxtral output
_MARKER_RE = re.compile(r"\[STREAMING_PAD\]|\[STREAMING_WORD\]")


def _load_diarize_pipeline():
    global _diarize_pipeline
    if _diarize_pipeline is None:
        from pyannote.audio import Pipeline as PyannotePipeline

        _diarize_pipeline = PyannotePipeline.from_pretrained(
            PYANNOTE_MODEL, token=HF_TOKEN
        )
        _diarize_pipeline = _diarize_pipeline.to(torch.device("cuda"))
    return _diarize_pipeline


def _load_embed_model():
    global _embed_model
    if _embed_model is None:
        from funasr import AutoModel

        _embed_model = AutoModel(model=FUNASR_MODEL)
    return _embed_model


def _load_voxtral():
    """Lazy-load Voxtral model and processor (first call only)."""
    global _voxtral_model, _voxtral_processor
    if _voxtral_model is None:
        from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

        logger.info("Loading Voxtral model %s ...", VOXTRAL_MODEL_ID)
        _voxtral_processor = AutoProcessor.from_pretrained(
            VOXTRAL_MODEL_ID, trust_remote_code=True
        )
        _voxtral_model = AutoModelForSpeechSeq2Seq.from_pretrained(
            VOXTRAL_MODEL_ID, torch_dtype=torch.float16, trust_remote_code=True
        ).to("cuda")
        logger.info("Voxtral model loaded.")
    return _voxtral_model, _voxtral_processor


def _clean_voxtral_text(text: str) -> str:
    """Strip Voxtral streaming markers and collapse whitespace."""
    text = _MARKER_RE.sub("", text)
    return " ".join(text.split()).strip()


# ---------------------------------------------------------------------------
# Audio helpers
# ---------------------------------------------------------------------------
def prepare_audio(raw_bytes: bytes) -> np.ndarray:
    """Read any audio format -> float32 mono @ 16 kHz."""
    audio, sr = sf.read(io.BytesIO(raw_bytes), dtype="float32")
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    if sr != TARGET_SR:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=TARGET_SR)
    return audio


def prepare_audio_slice(raw_bytes: bytes, start_time: float, end_time: float) -> np.ndarray:
    """Read audio, slice by time, return float32 mono @ 16 kHz."""
    seg = AudioSegment.from_file(io.BytesIO(raw_bytes))
    seg = seg[int(start_time * 1000):int(end_time * 1000)]  # pydub slices in ms
    seg = seg.set_frame_rate(TARGET_SR).set_channels(1).set_sample_width(2)
    # 16-bit PCM -> float32 in [-1, 1)
    return np.array(seg.get_array_of_samples(), dtype=np.float32) / 32768.0


# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Warm up diarization pipeline at startup (embedding model lazy-loads)
    _load_diarize_pipeline()
    yield


app = FastAPI(title="GPU Service (HF Endpoint)", lifespan=lifespan)


@app.get("/health")
async def health():
    return {"status": "ok", "gpu_available": torch.cuda.is_available()}


@app.post("/diarize")
async def diarize(
    audio: UploadFile = File(...),
    min_speakers: int | None = Form(None),
    max_speakers: int | None = Form(None),
):
    try:
        raw = await audio.read()
        audio_16k = prepare_audio(raw)

        pipeline = _load_diarize_pipeline()
        waveform = torch.from_numpy(audio_16k).unsqueeze(0).float()
        input_data = {"waveform": waveform, "sample_rate": TARGET_SR}

        # Fall back to configured defaults when the form omits the fields
        result = pipeline(
            input_data,
            min_speakers=min_speakers or PYANNOTE_MIN_SPEAKERS,
            max_speakers=max_speakers or PYANNOTE_MAX_SPEAKERS,
        )
        # pyannote v4 wraps the annotation in a result object; fall back for v3
        diarization = getattr(result, "speaker_diarization", result)

        segments = []
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            segments.append(
                {
                    "speaker": speaker,
                    "start": round(turn.start, 3),
                    "end": round(turn.end, 3),
                    "duration": round(turn.end - turn.start, 3),
                }
            )
        segments.sort(key=lambda s: s["start"])
        return {"segments": segments}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})


@app.post("/embed")
async def embed(
    audio: UploadFile = File(...),
    start_time: float | None = Form(None),
    end_time: float | None = Form(None),
):
    try:
        raw = await audio.read()
        if start_time is not None and end_time is not None:
            audio_16k = prepare_audio_slice(raw, start_time, end_time)
        else:
            audio_16k = prepare_audio(raw)

        model = _load_embed_model()
        result = model.generate(input=audio_16k, output_dir=None)
        raw_emb = result[0]["spk_embedding"]
        if hasattr(raw_emb, "cpu"):
            raw_emb = raw_emb.cpu().numpy()
        emb = np.array(raw_emb).flatten()

        # L2-normalize so downstream cosine similarity is a plain dot product
        norm = np.linalg.norm(emb)
        if norm > 0:
            emb = emb / norm

        return {"embedding": emb.tolist(), "dim": len(emb)}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})


@app.post("/transcribe")
async def transcribe(
    audio: UploadFile = File(...),
    prompt: str = Form("Transcribe this audio."),
):
    try:
        raw = await audio.read()
        audio_16k = prepare_audio(raw)

        model, processor = _load_voxtral()
        inputs = processor(
            audios=audio_16k,
            sampling_rate=TARGET_SR,
            text=prompt,
            return_tensors="pt",
        ).to("cuda")

        output_ids = model.generate(**inputs, max_new_tokens=1024)
        text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
        text = _clean_voxtral_text(text)

        return {"text": text}
    except Exception as e:
        logger.exception("Transcription failed")
        return JSONResponse(status_code=500, content={"error": str(e)})


@app.post("/transcribe/stream")
async def transcribe_stream(
    audio: UploadFile = File(...),
    prompt: str = Form("Transcribe this audio."),
):
    try:
        raw = await audio.read()
        audio_16k = prepare_audio(raw)
    except Exception as e:
        logger.exception("Audio preparation failed")
        return JSONResponse(status_code=500, content={"error": str(e)})

    async def event_generator():
        try:
            from transformers import TextIteratorStreamer

            model, processor = _load_voxtral()
            inputs = processor(
                audios=audio_16k,
                sampling_rate=TARGET_SR,
                text=prompt,
                return_tensors="pt",
            ).to("cuda")

            streamer = TextIteratorStreamer(
                processor.tokenizer, skip_prompt=True, skip_special_tokens=True
            )
            gen_kwargs = {**inputs, "max_new_tokens": 1024, "streamer": streamer}

            # generate() blocks, so run it in a thread and drain the streamer here
            thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
            thread.start()

            full_text = ""
            for chunk in streamer:
                chunk = _MARKER_RE.sub("", chunk)
                if chunk:
                    full_text += chunk
                    yield {"event": "token", "data": json.dumps({"token": chunk})}

            thread.join()
            full_text = " ".join(full_text.split()).strip()
            yield {"event": "done", "data": json.dumps({"text": full_text})}
        except Exception as e:
            logger.exception("Streaming transcription failed")
            yield {"event": "error", "data": json.dumps({"error": str(e)})}

    return EventSourceResponse(event_generator())
```
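
The `/transcribe/stream` endpoint emits `token`, `done`, and `error` SSE events. A minimal consumer sketch using `httpx` (an assumed client dependency; any SSE-capable HTTP client works, and the URL/token variables are placeholders):

```python
import json
import os

import httpx

ENDPOINT_URL = os.environ["ENDPOINT_URL"]
HF_TOKEN = os.environ["HF_TOKEN"]


def stream_transcript(path: str) -> str:
    """Print tokens as they arrive and return the final cleaned transcript."""
    final_text = ""
    with open(path, "rb") as f, httpx.stream(
        "POST",
        f"{ENDPOINT_URL}/transcribe/stream",
        headers={"Authorization": f"Bearer {HF_TOKEN}"},
        files={"audio": f},
        timeout=None,
    ) as resp:
        resp.raise_for_status()
        event = None
        for line in resp.iter_lines():
            # Minimal SSE parsing: "event:" names the event, "data:" carries JSON
            if line.startswith("event:"):
                event = line.split(":", 1)[1].strip()
            elif line.startswith("data:"):
                payload = json.loads(line.split(":", 1)[1].strip())
                if event == "token":
                    print(payload["token"], end="", flush=True)
                elif event == "done":
                    final_text = payload["text"]
                elif event == "error":
                    raise RuntimeError(payload["error"])
    return final_text
```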
requirements.txt CHANGED
```diff
@@ -1,12 +1,13 @@
 fastapi>=0.115.0
 uvicorn[standard]>=0.30.0
 numpy>=1.26.0
 soundfile>=0.12.0
 librosa>=0.10.0
 pyannote.audio>=3.3.0
 funasr>=1.3.0
 python-multipart>=0.0.9
 pydub>=0.25.0
 transformers>=4.45.0
 accelerate>=0.34.0
 sse-starlette>=1.0.0
+torchvision>=0.19.0
```