lucamartinelli committed
Commit dd5bcef · 1 Parent(s): a7c3098
.gitattributes CHANGED
@@ -32,4 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,7 @@
+ *.vtt
+ *.mp3
+ *.wav
+ .venv
+ .env
+ tmp
+ __pycache__
README.md CHANGED
@@ -7,6 +7,7 @@ sdk: gradio
  sdk_version: 5.49.1
  app_file: app.py
  pinned: false
+ python_version: 3.13
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
main.py ADDED
@@ -0,0 +1,245 @@
+ """Whisper + Pyannote Transcription & Diarization Web Interface."""
+ import logging
+ import tempfile
+ from pathlib import Path
+
+ import gradio as gr
+
+ from src.audio_processor import AudioProcessor
+ from src.speaker_manager import SpeakerManager
+ from src.vtt_utils import clean_vtt, validate_vtt
+
+ logging.basicConfig(level=logging.INFO)
+
+
+ def process_audio(
+     audio_path: str,
+     openai_api_key: str,
+     hf_api_key: str,
+     transcription_model: str,
+     pyannote_model: str,
+     openai_whisper_prompt: str,
+     openai_whisper_language: str | None,
+     progress=gr.Progress()
+ ):
+     """
+     Process audio file with diarization and transcription.
+
+     Returns:
+         Tuple of (vtt_content, transcripts, audio_filename)
+     """
+     if not audio_path:
+         return "", [], ""
+
+     processor = AudioProcessor(
+         openai_api_key=openai_api_key,
+         hf_api_key=hf_api_key,
+         transcription_model=transcription_model,
+         pyannote_model=pyannote_model,
+         whisper_prompt=openai_whisper_prompt,
+         whisper_language=openai_whisper_language
+     )
+
+     return processor.process(
+         audio_path=audio_path,
+         progress_callback=lambda p, desc: progress(p, desc=desc)
+     )
+
+
+ def rename_speaker_in_vtt(vtt_content: str, transcripts_state, old_speaker: str, new_speaker: str):
+     """Rename speaker and regenerate VTT."""
+     if not vtt_content or not transcripts_state:
+         return vtt_content
+
+     return SpeakerManager.rename_speaker(transcripts_state, old_speaker, new_speaker)
+
+
+ def prepare_download(vtt_content: str, audio_filename: str) -> gr.File | None:
+     """
+     Prepare VTT file for download.
+
+     Args:
+         vtt_content: VTT content as string
+         audio_filename: Base filename for the audio
+
+     Returns:
+         gr.File update pointing at a temporary VTT file, or None if inputs are invalid
+     """
+     if not vtt_content or not audio_filename:
+         return None
+
+     download_path = Path(tempfile.gettempdir()) / f"{audio_filename}.vtt"
+
+     with open(download_path, 'w', encoding='utf-8') as f:
+         f.write(vtt_content)
+
+     # Reveal the (initially hidden) file component along with its new value
+     return gr.File(value=str(download_path), visible=True)
+
+
+ with gr.Blocks(title="Transcription & Diarization") as app:
+
+     gr.Markdown("""
+     # 🎙️ Transcription & Diarization
+     Fill in the required settings, upload an audio file, and start the transcription using Whisper and Pyannote!
+     """)
+
+     transcripts_state = gr.State([])
+     audio_filename_state = gr.State("")
+
+     with gr.Row():
+         with gr.Column():
+             with gr.Accordion("⚙️ Settings", open=True):
+                 openai_api_key = gr.Textbox(label="OpenAI API key", type="password")
+                 hf_api_key = gr.Textbox(label="Hugging Face API key", type="password")
+
+             with gr.Accordion("⚙️ Additional settings", open=False):
+                 transcription_model = gr.Dropdown(
+                     label="Transcription model",
+                     choices=[("Whisper", "whisper-1")],
+                     value="whisper-1"
+                 )
+                 pyannote_model = gr.Dropdown(
+                     label="Pyannote model",
+                     choices=[("Speaker diarization community 1", "pyannote/speaker-diarization-community-1")],
+                     value="pyannote/speaker-diarization-community-1"
+                 )
+
+                 openai_whisper_prompt = gr.Textbox(label="Additional Whisper prompt", value="")
+                 openai_whisper_language = gr.Dropdown(
+                     label="Whisper language",
+                     choices=[
+                         ("Default (Auto-detect)", None),
+                         ("🇮🇹 Italian", "it"),
+                         ("🇩🇪 German", "de"),
+                         ("🇬🇧 English", "en"),
+                         ("🇪🇸 Spanish", "es"),
+                         ("🇫🇷 French", "fr"),
+                     ],
+                     value=None
+                 )
+
+             audio_input = gr.Audio(type="filepath", label="Upload audio")
+             submit_btn = gr.Button("Transcribe", variant="primary", interactive=False)
+
+         with gr.Column():
+             with gr.Group():
+                 output_vtt = gr.Textbox(
+                     label="Transcription",
+                     lines=20,
+                     placeholder="Your transcription will appear here...",
+                     buttons=["copy"],
+                     container=False,
+                 )
+
+                 validation_status = gr.Markdown("⚪ No content", container=True)
+
+             with gr.Row():
+                 clean_btn = gr.Button("Clean & improve VTT", variant="secondary", interactive=False)
+                 download_file = gr.File(label="Download VTT", visible=False)
+                 download_btn = gr.Button("Download VTT", variant="secondary", interactive=False)
+
+             with gr.Accordion("🎭 Rename speakers", open=False):
+                 with gr.Row():
+                     old_speaker_name = gr.Textbox(label="Current speaker name (e.g., SPEAKER_00)", placeholder="SPEAKER_00", value="SPEAKER_00")
+                     new_speaker_name = gr.Textbox(label="New speaker name", placeholder="Davide")
+
+                 rename_btn = gr.Button("Rename")
+
+     def check_inputs(openai_key: str, hf_key: str, audio) -> gr.Button:
+         """
+         Enable submit button only if both API keys and audio are provided.
+
+         Args:
+             openai_key: OpenAI API key
+             hf_key: Hugging Face API key
+             audio: Audio file path
+
+         Returns:
+             Button component with updated interactive state
+         """
+         is_ready = bool(openai_key and hf_key and audio)
+         return gr.Button(interactive=is_ready)
+
+     def update_validation(vtt_content: str):
+         """
+         Update validation status and button states when VTT content changes.
+
+         Args:
+             vtt_content: VTT content to validate
+
+         Returns:
+             Tuple of (status_message, clean_button, download_button)
+         """
+         status, status_type = validate_vtt(vtt_content)
+
+         # Enable buttons only if VTT is valid
+         is_valid = status_type == "success"
+
+         return (
+             status,
+             gr.Button(interactive=is_valid),  # clean_btn
+             gr.Button(interactive=is_valid)   # download_btn
+         )
+
+     # Enable/disable submit button based on API keys and audio input
+     openai_api_key.change(
+         fn=check_inputs,
+         inputs=[openai_api_key, hf_api_key, audio_input],
+         outputs=submit_btn
+     )
+     hf_api_key.change(
+         fn=check_inputs,
+         inputs=[openai_api_key, hf_api_key, audio_input],
+         outputs=submit_btn
+     )
+     audio_input.change(
+         fn=check_inputs,
+         inputs=[openai_api_key, hf_api_key, audio_input],
+         outputs=submit_btn
+     )
+
+     # Main transcription process
+     submit_btn.click(
+         fn=process_audio,
+         inputs=[
+             audio_input,
+             openai_api_key,
+             hf_api_key,
+             transcription_model,
+             pyannote_model,
+             openai_whisper_prompt,
+             openai_whisper_language
+         ],
+         outputs=[output_vtt, transcripts_state, audio_filename_state],
+     )
+
+     # Real-time VTT validation and button state management
+     output_vtt.change(
+         fn=update_validation,
+         inputs=[output_vtt],
+         outputs=[validation_status, clean_btn, download_btn]
+     )
+
+     # VTT cleaning and improvement
+     clean_btn.click(
+         fn=clean_vtt,
+         inputs=[output_vtt],
+         outputs=[output_vtt]
+     )
+
+     # VTT file download
+     download_btn.click(
+         fn=prepare_download,
+         inputs=[output_vtt, audio_filename_state],
+         outputs=[download_file]
+     )
+
+     # Speaker renaming
+     rename_btn.click(
+         fn=rename_speaker_in_vtt,
+         inputs=[output_vtt, transcripts_state, old_speaker_name, new_speaker_name],
+         outputs=output_vtt
+     )
+
+ if __name__ == "__main__":
+     app.launch()
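
A quick local smoke test, assuming the dependencies from `pyproject.toml` are installed (e.g. `poetry install` then `poetry run python main.py`): the app starts a local Gradio server, and the Transcribe button stays disabled until both API keys and an audio file are provided, as wired up by `check_inputs` above.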
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,33 @@
+ [project]
+ name = "whisper-diarization"
+ version = "0.1.0"
+ description = ""
+ authors = [
+     {name = "Luca Martinelli", email = "martinelliluca98@gmail.com"}
+ ]
+ readme = "README.md"
+ requires-python = ">=3.11,<3.14"
+ dependencies = [
+     "openai (>=2.8.1,<3.0.0)",
+     "pydantic (>=2.12.4,<3.0.0)",
+     "pydub (>=0.25.1,<0.26.0)",
+     "pyannote-audio (>=4.0.2,<5.0.0)",
+     "audioop-lts (>=0.2.2,<0.3.0)",
+     "pydantic-settings (>=2.12.0,<3.0.0)",
+     "webvtt-py (>=0.5.1,<0.6.0)",
+     "numpy (>=2.2.2)",
+     "huggingface-hub (<1.0.0)",
+     "scipy (>=1.14.0)",
+     "gradio (>=6.0.0,<7.0.0)"
+ ]
+
+ [build-system]
+ requires = ["poetry-core>=2.0.0,<3.0.0"]
+ build-backend = "poetry.core.masonry.api"
+
+ [tool.poetry]
+ package-mode = false
+
+ [tool.poetry.dependencies]
+ audioop-lts = { version = ">=0.2.2,<0.3.0", python = ">=3.13" }
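
A note on the doubled `audioop-lts` pin: `pydub` still imports the stdlib `audioop` module, which was removed in Python 3.13 (PEP 594), so the backport is needed on the `python_version: 3.13` runtime the README now pins. The `[tool.poetry.dependencies]` marker restricts it to `python = ">=3.13"`, where it applies.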
src/__init__.py ADDED
File without changes
src/audio_processor.py ADDED
@@ -0,0 +1,128 @@
+ """Audio processing and transcription logic."""
+ import logging
+ import shutil
+ import tempfile
+ from pathlib import Path
+ from typing import Callable, List, Tuple
+
+ from src.diarization import get_pipeline
+ from src.vtt import create_vtt
+ from src.whisper import TranscriptSegment, get_transcripts
+
+ logger = logging.getLogger(__name__)
+
+
+ class AudioProcessor:
+     """Handles audio processing, diarization, and transcription."""
+
+     def __init__(
+         self,
+         openai_api_key: str,
+         hf_api_key: str,
+         transcription_model: str,
+         pyannote_model: str,
+         whisper_prompt: str = "",
+         whisper_language: str | None = None
+     ):
+         """
+         Initialize AudioProcessor.
+
+         Args:
+             openai_api_key: OpenAI API key for Whisper
+             hf_api_key: Hugging Face API key for Pyannote
+             transcription_model: Model name for transcription
+             pyannote_model: Model name for diarization
+             whisper_prompt: Optional prompt for Whisper
+             whisper_language: Optional language code for Whisper
+         """
+         self.openai_api_key = openai_api_key
+         self.hf_api_key = hf_api_key
+         self.transcription_model = transcription_model
+         self.pyannote_model = pyannote_model
+         self.whisper_prompt = whisper_prompt
+         self.whisper_language = whisper_language
+
+     def process(
+         self,
+         audio_path: str | Path,
+         progress_callback: Callable[[float, str], None] | None = None
+     ) -> Tuple[str, List[TranscriptSegment], str]:
+         """
+         Process audio file: diarization + transcription.
+
+         Args:
+             audio_path: Path to audio file
+             progress_callback: Optional callback for progress updates (progress, description)
+
+         Returns:
+             Tuple of (vtt_content, transcripts, audio_filename)
+         """
+         if not audio_path:
+             return "", [], ""
+
+         audio_path = Path(audio_path).absolute()
+         tmp_dir = Path(tempfile.mkdtemp(prefix="whisper_diarization_"))
+         logger.info(f"📁 Created temporary directory: {tmp_dir}")
+
+         try:
+             # Step 1: Diarization
+             if progress_callback:
+                 progress_callback(0, "Loading diarization model...")
+             logger.info("🔄 Starting diarization process")
+
+             audio_segment, diarization = get_pipeline(
+                 audio_path,
+                 self.hf_api_key,
+                 self.pyannote_model,
+                 tmp_dir
+             )
+
+             if progress_callback:
+                 progress_callback(0.3, "Diarization complete. Starting transcription...")
+             logger.info("✅ Diarization complete")
+
+             # Step 2: Transcription
+             total_segments = sum(1 for _ in diarization.speaker_diarization.itertracks())
+             logger.info(f"📊 Found {total_segments} segments to transcribe")
+
+             def transcription_progress(i: int, total: int):
+                 if progress_callback:
+                     progress_callback(
+                         0.3 + (0.6 * i / total),
+                         f"Transcribing segment {i}/{total}..."
+                     )
+
+             transcripts = get_transcripts(
+                 diarization,
+                 audio_segment,
+                 self.openai_api_key,
+                 self.transcription_model,
+                 self.whisper_prompt,
+                 self.whisper_language,
+                 tmp_dir,
+                 progress_callback=transcription_progress
+             )
+
+             # Step 3: Create VTT
+             if progress_callback:
+                 progress_callback(0.9, "Creating VTT file...")
+             logger.info("📝 Creating VTT file")
+
+             vtt = create_vtt(transcripts)
+
+             if progress_callback:
+                 progress_callback(1.0, "Complete!")
+             logger.info("✅ Process complete")
+
+             audio_filename = audio_path.stem
+             return vtt.content, transcripts, audio_filename
+
+         finally:
+             # Cleanup
+             if progress_callback:
+                 progress_callback(0.95, "Cleaning up temporary files...")
+             logger.info("🧹 Cleaning up")
+
+             if tmp_dir.exists():
+                 shutil.rmtree(tmp_dir)
+                 logger.info(f"🗑️ Removed temporary directory: {tmp_dir}")
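
For orientation, the progress callback carves the run into fixed phases: diarization owns 0–0.3, transcription owns 0.3–0.9 (segment i of n reports 0.3 + 0.6 · i/n, so segment 5 of 10 lands at 0.60), and VTT generation fills the final 0.9–1.0.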
src/diarization.py ADDED
@@ -0,0 +1,22 @@
+ from pathlib import Path
+ from typing import Tuple
+
+ import torch
+ from pyannote.audio import Pipeline
+ from pydub import AudioSegment
+
+
+ def get_pipeline(filename: str | Path, hf_api_key: str, pyannote_model: str, tmp_dir: Path) -> Tuple[AudioSegment, Pipeline]:
+     pipeline = Pipeline.from_pretrained(
+         pyannote_model,
+         token=hf_api_key,
+     )
+     # Use the GPU when one is available; fall back to CPU otherwise
+     if torch.cuda.is_available():
+         pipeline.to(torch.device("cuda"))
+
+     # Decode whatever format was uploaded and convert it to WAV for pyannote
+     audio_segment = AudioSegment.from_file(filename)
+     wav_audio = tmp_dir.joinpath(Path(filename).name).with_suffix(".wav")
+     audio_segment.export(wav_audio, format="wav")
+
+     return (audio_segment, pipeline(wav_audio))
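
For reference, a minimal usage sketch; the file name and token below are hypothetical placeholders, and the gated pyannote model must have been accepted on the Hub:

```python
from pathlib import Path

from src.diarization import get_pipeline

# Hypothetical inputs: any audio file pydub can decode, plus an HF token
# with access to the gated pyannote model.
audio, result = get_pipeline(
    "interview.mp3",                              # placeholder file
    "hf_xxx",                                     # placeholder token
    "pyannote/speaker-diarization-community-1",
    Path("/tmp"),
)
for turn, _, speaker in result.speaker_diarization.itertracks(yield_label=True):
    print(f"{speaker}: {turn.start:.1f}s - {turn.end:.1f}s")
```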
src/speaker_manager.py ADDED
@@ -0,0 +1,38 @@
+ """Speaker management utilities."""
+ from typing import List
+
+ from src.vtt import create_vtt
+ from src.whisper import TranscriptSegment
+
+
+ class SpeakerManager:
+     """Manages speaker renaming operations."""
+
+     @staticmethod
+     def rename_speaker(
+         transcripts: List[TranscriptSegment],
+         old_speaker: str,
+         new_speaker: str
+     ) -> str:
+         """
+         Rename a speaker in transcripts and return updated VTT.
+
+         Args:
+             transcripts: List of transcript segments
+             old_speaker: Current speaker name
+             new_speaker: New speaker name
+
+         Returns:
+             Updated VTT content as string
+         """
+         if not transcripts:
+             return ""
+
+         # Update speaker names in place
+         for transcript in transcripts:
+             if transcript.speaker == old_speaker:
+                 transcript.speaker = new_speaker
+
+         # Regenerate VTT with updated speakers
+         vtt = create_vtt(transcripts)
+         return vtt.content
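
As used by the Rename button in `main.py`; a one-line sketch, assuming `transcripts` came back from `AudioProcessor.process()`:

```python
updated_vtt = SpeakerManager.rename_speaker(transcripts, "SPEAKER_00", "Davide")
```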
src/vtt.py ADDED
@@ -0,0 +1,32 @@
+ from typing import List
+
+ from webvtt import Caption, WebVTT
+
+ from src.whisper import TranscriptSegment
+
+
+ def format_milliseconds(milliseconds):
+     seconds, milliseconds = divmod(milliseconds, 1000)
+     minutes, seconds = divmod(seconds, 60)
+     hours, minutes = divmod(minutes, 60)
+
+     return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}"
+
+
+ def create_vtt(transcripts: List[TranscriptSegment]) -> WebVTT:
+     vtt = WebVTT()
+
+     for transcript in transcripts:
+         for x in transcript.transcript.segments:
+             # Whisper segment offsets are seconds relative to the chunk;
+             # transcript.start is the chunk's absolute start in milliseconds
+             start = transcript.start + x.start * 1000
+             end = transcript.start + x.end * 1000
+
+             caption = Caption(
+                 format_milliseconds(start),
+                 format_milliseconds(end),
+                 f"<v {transcript.speaker}>" + x.text,
+             )
+
+             vtt.captions.append(caption)
+
+     return vtt
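
The unit mixing here is deliberate: `transcript.start` is in milliseconds (set as `turn.start * 1000` in `src/whisper.py`), while Whisper segment offsets are in seconds relative to the chunk. A worked example of the conversion:

```python
from src.vtt import format_milliseconds

chunk_start_ms = 83_500        # diarization turn began 83.5 s into the audio
whisper_offset_s = 2.24        # Whisper segment starts 2.24 s into that chunk

absolute_ms = chunk_start_ms + whisper_offset_s * 1000
print(format_milliseconds(absolute_ms))  # 00:01:25.740
```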
src/vtt_utils.py ADDED
@@ -0,0 +1,160 @@
+ """Utilities for VTT validation and cleaning."""
+ import re
+ from typing import Tuple
+
+
+ def parse_timestamp(timestamp_str: str) -> int | None:
+     """
+     Parse timestamp string to milliseconds.
+
+     Args:
+         timestamp_str: Timestamp in format HH:MM:SS.mmm
+
+     Returns:
+         Milliseconds as integer, or None if parsing fails
+     """
+     try:
+         parts = timestamp_str.strip().split(':')
+         hours = int(parts[0])
+         minutes = int(parts[1])
+         seconds_parts = parts[2].split('.')
+         seconds = int(seconds_parts[0])
+         milliseconds = int(seconds_parts[1])
+
+         total_ms = (hours * 3600 + minutes * 60 + seconds) * 1000 + milliseconds
+         return total_ms
+     except (ValueError, IndexError, AttributeError):
+         return None
+
+
+ def validate_vtt(vtt_content: str) -> Tuple[str, str]:
+     """
+     Validate VTT format and return status message.
+
+     Args:
+         vtt_content: VTT file content as string
+
+     Returns:
+         Tuple of (status_message, status_type) where status_type is "error", "warning", "success", or ""
+     """
+     if not vtt_content or vtt_content.strip() == "":
+         return "⚪ No content", ""
+
+     try:
+         # Check if the content starts with WEBVTT
+         if not vtt_content.strip().startswith("WEBVTT"):
+             return "❌ Invalid: Missing WEBVTT header", "error"
+
+         lines = vtt_content.split('\n')
+         has_timestamps = False
+         timestamps = []
+
+         for line in lines:
+             if '-->' not in line:
+                 continue
+
+             has_timestamps = True
+
+             # Validate timestamp format
+             match = re.match(r'(\d{2}:\d{2}:\d{2}\.\d{3})\s*-->\s*(\d{2}:\d{2}:\d{2}\.\d{3})', line)
+             if not match:
+                 return "⚠️ Warning: Malformed timestamp found", "warning"
+
+             # Parse and validate timestamps
+             start_str, end_str = match.groups()
+             start_ms = parse_timestamp(start_str)
+             end_ms = parse_timestamp(end_str)
+
+             if start_ms is None or end_ms is None:
+                 return "⚠️ Warning: Invalid timestamp values", "warning"
+
+             if start_ms >= end_ms:
+                 return "⚠️ Warning: Start timestamp >= end timestamp", "warning"
+
+             timestamps.append((start_ms, end_ms))
+
+         if not has_timestamps:
+             return "❌ Invalid: No timestamps found", "error"
+
+         # Check for overlapping timestamps (assumes cues appear in order)
+         for i in range(len(timestamps) - 1):
+             current_end = timestamps[i][1]
+             next_start = timestamps[i + 1][0]
+             if current_end > next_start:
+                 return "⚠️ Warning: Overlapping timestamps detected", "warning"
+
+         return "✅ Valid VTT format", "success"
+     except Exception as e:
+         return f"❌ Validation error: {str(e)}", "error"
+
+
+ def clean_vtt(vtt_content: str) -> str:
+     """
+     Clean and improve VTT content.
+
+     Improvements:
+     - Capitalizes after sentence-ending punctuation (. ! ?)
+     - Handles cross-segment capitalization intelligently
+     - Removes multiple spaces
+     - Preserves speaker tags
+
+     Args:
+         vtt_content: VTT file content as string
+
+     Returns:
+         Cleaned VTT content
+     """
+     if not vtt_content:
+         return vtt_content
+
+     lines = vtt_content.split('\n')
+     cleaned_lines = []
+     last_text_ended_with_sentence_end = False
+
+     for line in lines:
+         # Keep empty lines and the WEBVTT header untouched
+         if not line.strip() or line.startswith('WEBVTT'):
+             cleaned_lines.append(line)
+             continue
+
+         # Keep timestamp lines untouched
+         if '-->' in line:
+             cleaned_lines.append(line)
+             continue
+
+         # Extract speaker tag if present
+         speaker_tag = ""
+         text_content = line
+         speaker_match = re.match(r'^(<v [^>]+>)\s*(.*)', line)
+         if speaker_match:
+             speaker_tag = speaker_match.group(1)
+             text_content = speaker_match.group(2)
+
+         # Capitalize first letter if previous segment ended with sentence-ending punctuation
+         if last_text_ended_with_sentence_end and text_content and text_content[0].islower():
+             text_content = text_content[0].upper() + text_content[1:]
+
+         # Fix capitalization after punctuation within the same line
+         # (re-emit a single space: the matched \s+ would otherwise be dropped)
+         text_content = re.sub(
+             r'([.!?])\s+([a-z])',
+             lambda m: m.group(1) + ' ' + m.group(2).upper(),
+             text_content
+         )
+
+         # Remove multiple spaces
+         text_content = re.sub(r'\s{2,}', ' ', text_content)
+
+         # Trim leading/trailing spaces
+         text_content = text_content.strip()
+
+         # Rebuild line with speaker tag if it existed
+         cleaned_line = f"{speaker_tag} {text_content}" if speaker_tag else text_content
+
+         # Check if this line ends with sentence-ending punctuation
+         last_text_ended_with_sentence_end = bool(
+             text_content and re.search(r'[.!?]\s*$', text_content)
+         )
+
+         cleaned_lines.append(cleaned_line)
+
+     return '\n'.join(cleaned_lines)
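
A quick sanity check of both entry points; the expected outputs are shown as comments and follow directly from the code above:

```python
from src.vtt_utils import clean_vtt, validate_vtt

sample = """WEBVTT

00:00:00.000 --> 00:00:02.500
<v SPEAKER_00> hello there.  how are you doing today.

00:00:02.500 --> 00:00:04.000
<v SPEAKER_01> i am fine.
"""

print(validate_vtt(sample))  # ('✅ Valid VTT format', 'success')
print(clean_vtt(sample))
# Cue text becomes:
#   <v SPEAKER_00> hello there. How are you doing today.
#   <v SPEAKER_01> I am fine.
```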
src/whisper.py ADDED
@@ -0,0 +1,79 @@
+ from pathlib import Path
+ from typing import Callable, List
+
+ from openai import OpenAI
+ from openai.types.audio import TranscriptionVerbose
+ from pyannote.pipeline import Pipeline
+ from pydantic import BaseModel
+ from pydub import AudioSegment
+
+
+ class TranscriptSegment(BaseModel):
+     audio_file: str | Path
+     speaker: str
+     i: str
+     start: float
+     end: float
+     transcript: TranscriptionVerbose
+
+
+ def get_transcripts(
+     diarization: Pipeline,
+     audio_segment: AudioSegment,
+     openai_api_key: str,
+     whisper_model: str,
+     whisper_prompt: str,
+     whisper_language: str | None,
+     tmp_dir: Path,
+     progress_callback: Callable[[int, int], None] | None = None
+ ) -> List[TranscriptSegment]:
+     client = OpenAI(api_key=openai_api_key)
+
+     transcripts = []
+
+     # Count total segments
+     total_segments = sum(1 for _ in diarization.speaker_diarization.itertracks())
+     segment_index = 0
+
+     for turn, i, speaker in diarization.speaker_diarization.itertracks(yield_label=True):
+         segment_index += 1
+
+         if progress_callback:
+             progress_callback(segment_index, total_segments)
+
+         # Diarization turns are in seconds; pydub slices in milliseconds
+         start = turn.start * 1000
+         end = turn.end * 1000
+
+         chunk = audio_segment[start:end]
+
+         chunk_filename = tmp_dir.joinpath(f"segment-{start}.mp3")
+         chunk.export(chunk_filename, format="mp3")
+
+         # Close the file handle once the API call is done
+         with open(chunk_filename, "rb") as audio_chunk_segment:
+             params = {
+                 "file": audio_chunk_segment,
+                 "model": whisper_model,
+                 "response_format": "verbose_json",
+                 "timestamp_granularities": ["segment"],
+                 "prompt": whisper_prompt,
+             }
+
+             if whisper_language:
+                 params["language"] = whisper_language
+
+             transcript = client.audio.transcriptions.create(**params)
+
+         transcripts.append(
+             TranscriptSegment(
+                 audio_file=chunk_filename,
+                 speaker=speaker,
+                 i=i,
+                 start=start,
+                 end=end,
+                 transcript=transcript,
+             )
+         )
+
+     return transcripts
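
With `response_format="verbose_json"` and `timestamp_granularities=["segment"]`, the OpenAI API returns a `TranscriptionVerbose` whose `.segments` carry `start`/`end` offsets in seconds relative to the uploaded chunk; `src/vtt.py` shifts these by the chunk's absolute start to place the final captions.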