Joyboy-dy committed on
Commit
7bfdd1b
·
1 Parent(s): 6b8c228

Update server.py and requirements.txt

Browse files
Files changed (2) hide show
  1. requirements.txt +0 -1
  2. server.py +102 -3
requirements.txt CHANGED
@@ -3,4 +3,3 @@ uvicorn[standard]
3
  python-multipart
4
  whisperx
5
  torch
6
- ffmpeg
 
3
  python-multipart
4
  whisperx
5
  torch
 
server.py CHANGED
@@ -1,11 +1,12 @@
1
  import os
 
2
  import shutil
3
  import tempfile
4
  from contextlib import asynccontextmanager
5
  from pathlib import Path
6
 
7
  import whisperx
8
- from fastapi import BackgroundTasks, FastAPI, File, HTTPException, UploadFile
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.responses import FileResponse
11
 
@@ -131,14 +132,106 @@ def _write_srt_file(segments: list[dict], file_obj) -> None:
131
  index += 1
132
 
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  @app.post("/align")
135
- async def align_audio(background_tasks: BackgroundTasks, audio_file: UploadFile = File(...)):
 
 
 
 
136
  if model is None or model_a is None or metadata is None:
137
  raise HTTPException(status_code=503, detail="WhisperX models are not ready")
138
 
139
  temp_dir = tempfile.mkdtemp(prefix="lyric-sync-")
140
 
141
  try:
 
 
 
142
  source_name = audio_file.filename or "audio"
143
  audio_path = os.path.join(temp_dir, source_name)
144
  with open(audio_path, "wb") as f:
@@ -147,9 +240,15 @@ async def align_audio(background_tasks: BackgroundTasks, audio_file: UploadFile
147
  result = _transcribe_with_compat(model, audio_path)
148
  result = _align_with_compat(result["segments"], audio_path)
149
 
 
 
 
 
 
 
150
  srt_path = os.path.join(temp_dir, f"{Path(source_name).stem}.srt")
151
  with open(srt_path, "w", encoding="utf-8") as srt_file:
152
- _write_srt_file(result["segments"], srt_file)
153
 
154
  background_tasks.add_task(_cleanup_temp_dir, temp_dir)
155
 
 
1
  import os
2
+ import re
3
  import shutil
4
  import tempfile
5
  from contextlib import asynccontextmanager
6
  from pathlib import Path
7
 
8
  import whisperx
9
+ from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from fastapi.responses import FileResponse
12
 
 
132
  index += 1
133
 
134
 
135
+ _STRONG_PUNCT_RE = re.compile(r"[.!?]+$")
136
+ _SOFT_PUNCT_RE = re.compile(r"[,;:]+$")
137
+
138
+
139
+ def _cleanup_spacing(text: str) -> str:
140
+ text = re.sub(r"\s+([,.;:!?])", r"\1", text)
141
+ text = re.sub(r"\(\s+", "(", text)
142
+ text = re.sub(r"\s+\)", ")", text)
143
+ return text.strip()
144
+
145
+
146
+ def _extract_word_segments(aligned_segments: list[dict]) -> list[dict]:
147
+ words: list[dict] = []
148
+ for segment in aligned_segments:
149
+ for word in segment.get("words") or []:
150
+ token = (word.get("word") or word.get("text") or "").strip()
151
+ start = word.get("start")
152
+ end = word.get("end")
153
+ if not token or start is None or end is None:
154
+ continue
155
+ entry = {"word": token, "start": float(start), "end": float(end)}
156
+ score = word.get("score")
157
+ if score is None:
158
+ score = word.get("probability")
159
+ if score is not None:
160
+ entry["score"] = float(score)
161
+ words.append(entry)
162
+
163
+ words.sort(key=lambda w: (w["start"], w["end"]))
164
+ return words
165
+
166
+
167
def _paragraph_segments_from_aligned(aligned_segments: list[dict]) -> list[dict]:
    """Build one subtitle entry per aligned segment.

    For each input segment the ``"text"`` is stripped and passed through
    ``_cleanup_spacing``. Only words that carry both ``"start"`` and
    ``"end"`` are considered for timing. A segment is skipped when its
    cleaned text is empty or it has no timed words.

    Returns:
        A list of ``{"start", "end", "text"}`` dicts, where ``start`` is the
        first timed word's start and ``end`` is the last timed word's end,
        both as ``float``.
    """
    paragraphs: list[dict] = []
    for aligned in aligned_segments:
        body = _cleanup_spacing((aligned.get("text") or "").strip())
        timed = [
            w
            for w in (aligned.get("words") or [])
            if not (w.get("start") is None or w.get("end") is None)
        ]
        if body and timed:
            first, last = timed[0], timed[-1]
            paragraphs.append(
                {"start": float(first["start"]), "end": float(last["end"]), "text": body}
            )
    return paragraphs
178
+
179
+
180
def _sentence_segments_from_words(word_segments: list[dict], max_words: int = 8, gap_s: float = 0.4) -> list[dict]:
    """Group timed words into short, sentence-like subtitle entries.

    Words are accumulated in order into a pending group, which is closed
    when any of the following holds:

    * the gap between the next word's start and the previous word's end
      exceeds ``gap_s`` (the group is closed before the next word is added);
    * the word just added matches ``_STRONG_PUNCT_RE``;
    * the group has reached ``max_words`` words;
    * the word just added matches ``_SOFT_PUNCT_RE`` and the group holds at
      least four words.

    Args:
        word_segments: Dicts with ``"word"``, ``"start"`` and ``"end"`` keys.
        max_words: Group size at which a group is closed.
        gap_s: Gap between consecutive words above which a group is closed.

    Returns:
        A list of ``{"start", "end", "text"}`` dicts. ``text`` is the group's
        words joined by single spaces and passed through ``_cleanup_spacing``;
        groups whose cleaned text is empty are dropped.
    """
    out: list[dict] = []
    pending: list[dict] = []

    def emit() -> None:
        # Close the pending group (if any) and start a fresh one.
        if not pending:
            return
        joined = _cleanup_spacing(" ".join(item["word"] for item in pending))
        if joined:
            out.append({"start": pending[0]["start"], "end": pending[-1]["end"], "text": joined})
        pending.clear()

    for entry in word_segments:
        # A long enough silence ends the group before this word joins it.
        if pending and float(entry["start"]) - float(pending[-1]["end"]) > gap_s:
            emit()

        pending.append(entry)

        token = entry["word"]
        ends_sentence = _STRONG_PUNCT_RE.search(token) is not None
        ends_clause = _SOFT_PUNCT_RE.search(token) is not None

        if ends_sentence or len(pending) >= max_words or (ends_clause and len(pending) >= 4):
            emit()

    # Whatever is left over forms the final entry.
    emit()
    return out
218
+
219
+
220
  @app.post("/align")
221
+ async def align_audio(
222
+ background_tasks: BackgroundTasks,
223
+ audio_file: UploadFile = File(...),
224
+ srt_mode: str = Form("paragraph"),
225
+ ):
226
  if model is None or model_a is None or metadata is None:
227
  raise HTTPException(status_code=503, detail="WhisperX models are not ready")
228
 
229
  temp_dir = tempfile.mkdtemp(prefix="lyric-sync-")
230
 
231
  try:
232
+ if srt_mode not in ("paragraph", "sentence"):
233
+ raise HTTPException(status_code=400, detail="Invalid srt_mode (expected 'paragraph' or 'sentence')")
234
+
235
  source_name = audio_file.filename or "audio"
236
  audio_path = os.path.join(temp_dir, source_name)
237
  with open(audio_path, "wb") as f:
 
240
  result = _transcribe_with_compat(model, audio_path)
241
  result = _align_with_compat(result["segments"], audio_path)
242
 
243
+ word_segments = _extract_word_segments(result["segments"])
244
+ if srt_mode == "sentence":
245
+ srt_segments = _sentence_segments_from_words(word_segments)
246
+ else:
247
+ srt_segments = _paragraph_segments_from_aligned(result["segments"])
248
+
249
  srt_path = os.path.join(temp_dir, f"{Path(source_name).stem}.srt")
250
  with open(srt_path, "w", encoding="utf-8") as srt_file:
251
+ _write_srt_file(srt_segments, srt_file)
252
 
253
  background_tasks.add_task(_cleanup_temp_dir, temp_dir)
254