tantk committed on
Commit
5b2eaaf
·
verified ·
1 Parent(s): e472334

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +5 -112
app.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
  Slim GPU service for HF Inference Endpoints.
3
- Exposes /diarize, /embed, /transcribe, and /transcribe/stream endpoints.
4
  """
5
 
6
  import io
@@ -14,9 +14,8 @@ import numpy as np
14
  import soundfile as sf
15
  import librosa
16
  import torch
17
- from fastapi import FastAPI, File, Form, UploadFile
18
  from fastapi.responses import JSONResponse
19
- from pydub import AudioSegment
20
  from sse_starlette.sse import EventSourceResponse
21
 
22
  from voxtral_inference import VoxtralModel
@@ -26,44 +25,15 @@ logger = logging.getLogger("gpu_service")
26
  # ---------------------------------------------------------------------------
27
  # Config
28
  # ---------------------------------------------------------------------------
29
- HF_TOKEN = os.environ.get("HF_TOKEN", "")
30
- PYANNOTE_MODEL = "pyannote/speaker-diarization-community-1"
31
- FUNASR_MODEL = "iic/speech_campplus_sv_zh-cn_16k-common"
32
- PYANNOTE_MIN_SPEAKERS = int(os.environ.get("PYANNOTE_MIN_SPEAKERS", "1"))
33
- PYANNOTE_MAX_SPEAKERS = int(os.environ.get("PYANNOTE_MAX_SPEAKERS", "10"))
34
  TARGET_SR = 16000
35
-
36
  MODEL_DIR = os.environ.get("VOXTRAL_MODEL_DIR", "/repository/voxtral-model")
37
 
38
  # ---------------------------------------------------------------------------
39
- # Singletons
40
  # ---------------------------------------------------------------------------
41
- _diarize_pipeline = None
42
- _embed_model = None
43
  _voxtral: VoxtralModel | None = None
44
 
45
 
46
- def _load_diarize_pipeline():
47
- global _diarize_pipeline
48
- if _diarize_pipeline is None:
49
- from pyannote.audio import Pipeline as PyannotePipeline
50
-
51
- _diarize_pipeline = PyannotePipeline.from_pretrained(
52
- PYANNOTE_MODEL, token=HF_TOKEN
53
- )
54
- _diarize_pipeline = _diarize_pipeline.to(torch.device("cuda"))
55
- return _diarize_pipeline
56
-
57
-
58
- def _load_embed_model():
59
- global _embed_model
60
- if _embed_model is None:
61
- from funasr import AutoModel
62
-
63
- _embed_model = AutoModel(model=FUNASR_MODEL)
64
- return _embed_model
65
-
66
-
67
  def _load_voxtral():
68
  global _voxtral
69
  if _voxtral is None:
@@ -86,25 +56,16 @@ def prepare_audio(raw_bytes: bytes) -> np.ndarray:
86
  return audio
87
 
88
 
89
- def prepare_audio_slice(raw_bytes: bytes, start_time: float, end_time: float) -> np.ndarray:
90
- """Read audio, slice by time, return float32 mono @ 16 kHz."""
91
- seg = AudioSegment.from_file(io.BytesIO(raw_bytes))
92
- seg = seg[int(start_time * 1000):int(end_time * 1000)]
93
- seg = seg.set_frame_rate(TARGET_SR).set_channels(1).set_sample_width(2)
94
- return np.array(seg.get_array_of_samples(), dtype=np.float32) / 32768.0
95
-
96
-
97
  # ---------------------------------------------------------------------------
98
  # App
99
  # ---------------------------------------------------------------------------
100
  @asynccontextmanager
101
  async def lifespan(app: FastAPI):
102
- # Warm up diarization pipeline at startup (embedding model lazy-loads)
103
- _load_diarize_pipeline()
104
  yield
105
 
106
 
107
- app = FastAPI(title="GPU Service (HF Endpoint)", lifespan=lifespan)
108
 
109
 
110
  @app.get("/health")
@@ -112,74 +73,6 @@ async def health():
112
  return {"status": "ok", "gpu_available": torch.cuda.is_available()}
113
 
114
 
115
- @app.post("/diarize")
116
- async def diarize(
117
- audio: UploadFile = File(...),
118
- min_speakers: int | None = Form(None),
119
- max_speakers: int | None = Form(None),
120
- ):
121
- try:
122
- raw = await audio.read()
123
- audio_16k = prepare_audio(raw)
124
-
125
- pipeline = _load_diarize_pipeline()
126
- waveform = torch.from_numpy(audio_16k).unsqueeze(0).float()
127
- input_data = {"waveform": waveform, "sample_rate": TARGET_SR}
128
-
129
- result = pipeline(
130
- input_data,
131
- min_speakers=min_speakers or PYANNOTE_MIN_SPEAKERS,
132
- max_speakers=max_speakers or PYANNOTE_MAX_SPEAKERS,
133
- )
134
- # pyannote v4 compat
135
- diarization = getattr(result, "speaker_diarization", result)
136
-
137
- segments = []
138
- for turn, _, speaker in diarization.itertracks(yield_label=True):
139
- segments.append(
140
- {
141
- "speaker": speaker,
142
- "start": round(turn.start, 3),
143
- "end": round(turn.end, 3),
144
- "duration": round(turn.end - turn.start, 3),
145
- }
146
- )
147
- segments.sort(key=lambda s: s["start"])
148
- return {"segments": segments}
149
- except Exception as e:
150
- return JSONResponse(status_code=500, content={"error": str(e)})
151
-
152
-
153
- @app.post("/embed")
154
- async def embed(
155
- audio: UploadFile = File(...),
156
- start_time: float | None = Form(None),
157
- end_time: float | None = Form(None),
158
- ):
159
- try:
160
- raw = await audio.read()
161
- if start_time is not None and end_time is not None:
162
- audio_16k = prepare_audio_slice(raw, start_time, end_time)
163
- else:
164
- audio_16k = prepare_audio(raw)
165
-
166
- model = _load_embed_model()
167
- result = model.generate(input=audio_16k, output_dir=None)
168
- raw_emb = result[0]["spk_embedding"]
169
- if hasattr(raw_emb, "cpu"):
170
- raw_emb = raw_emb.cpu().numpy()
171
- emb = np.array(raw_emb).flatten()
172
-
173
- # L2-normalize
174
- norm = np.linalg.norm(emb)
175
- if norm > 0:
176
- emb = emb / norm
177
-
178
- return {"embedding": emb.tolist(), "dim": len(emb)}
179
- except Exception as e:
180
- return JSONResponse(status_code=500, content={"error": str(e)})
181
-
182
-
183
  @app.post("/transcribe")
184
  async def transcribe(audio: UploadFile = File(...)):
185
  try:
 
1
  """
2
  Slim GPU service for HF Inference Endpoints.
3
+ Exposes /transcribe and /transcribe/stream using Voxtral 4B.
4
  """
5
 
6
  import io
 
14
  import soundfile as sf
15
  import librosa
16
  import torch
17
+ from fastapi import FastAPI, File, UploadFile
18
  from fastapi.responses import JSONResponse
 
19
  from sse_starlette.sse import EventSourceResponse
20
 
21
  from voxtral_inference import VoxtralModel
 
25
  # ---------------------------------------------------------------------------
26
  # Config
27
  # ---------------------------------------------------------------------------
 
 
 
 
 
28
  TARGET_SR = 16000
 
29
  MODEL_DIR = os.environ.get("VOXTRAL_MODEL_DIR", "/repository/voxtral-model")
30
 
31
  # ---------------------------------------------------------------------------
32
+ # Singleton
33
  # ---------------------------------------------------------------------------
 
 
34
  _voxtral: VoxtralModel | None = None
35
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def _load_voxtral():
38
  global _voxtral
39
  if _voxtral is None:
 
56
  return audio
57
 
58
 
 
 
 
 
 
 
 
 
59
  # ---------------------------------------------------------------------------
60
  # App
61
  # ---------------------------------------------------------------------------
62
  @asynccontextmanager
63
  async def lifespan(app: FastAPI):
64
+ _load_voxtral()
 
65
  yield
66
 
67
 
68
+ app = FastAPI(title="Voxtral Transcription Service (HF Endpoint)", lifespan=lifespan)
69
 
70
 
71
  @app.get("/health")
 
73
  return {"status": "ok", "gpu_available": torch.cuda.is_available()}
74
 
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  @app.post("/transcribe")
77
  async def transcribe(audio: UploadFile = File(...)):
78
  try: