# Hugging Face Spaces page header ("Spaces: Running") captured with the
# source scrape; kept as a comment so the file remains valid Python.
import os

# Cap OpenMP/MKL threads BEFORE importing heavy numeric libraries so the
# limits take effect (faster-whisper/ctranslate2 read these at load time).
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import gradio as gr
import pysrt
import requests
import tempfile
import time
from faster_whisper import WhisperModel
from datetime import timedelta
from urllib.parse import urlparse

# Maximum words per subtitle (set to None to disable)
DEFAULT_MAX_WORDS = 18
# -----------------------------
# Core subtitle generator
# -----------------------------
class LinearSubtitleGenerator:
    """Generate SRT subtitles from audio using faster-whisper.

    Layout strategy: the first sentence becomes the first subtitle, the
    last sentence becomes the last subtitle, and everything in between is
    chunked with a growing "linear" pattern (1, 2, 3, ... words per cue),
    optionally capped at `max_words`.
    """

    def __init__(self, model_size="base"):
        # CPU + int8 quantization keeps memory low; pairs with the
        # OMP/MKL single-thread caps set at module import time.
        self.model = WhisperModel(
            model_size,
            device="cpu",
            compute_type="int8"
        )

    def transcribe(self, audio_path):
        """Run Whisper on `audio_path` and return its segment iterable
        (word-level timestamps enabled, VAD filtering on)."""
        segments, _ = self.model.transcribe(
            audio_path,
            word_timestamps=True,
            vad_filter=True
        )
        return segments

    def extract_words(self, segments):
        """Flatten segments into a list of {'word', 'start', 'end'} dicts,
        skipping any word that lacks timing information."""
        words = []
        for segment in segments:
            if not segment.words:
                continue
            for w in segment.words:
                # faster-whisper can emit words without timestamps.
                if w.start is None or w.end is None:
                    continue
                words.append({
                    "word": w.word.strip(),
                    "start": float(w.start),
                    "end": float(w.end)
                })
        return words

    def find_sentence_boundaries(self, words):
        """
        Find first and last sentence boundaries based on terminal
        punctuation ('.', '!' or '?').
        Returns: (first_period_idx, last_period_idx) — both None when no
        boundary exists.
        NOTE: abbreviations such as "Mr." are not special-cased.
        """
        first_period_idx = None
        last_period_idx = None
        for idx, word_data in enumerate(words):
            # str.endswith accepts a tuple: one call replaces the or-chain.
            if word_data["word"].endswith(('.', '!', '?')):
                if first_period_idx is None:
                    first_period_idx = idx
                last_period_idx = idx
        return first_period_idx, last_period_idx

    def create_linear_subtitles(self, words, max_words=None):
        """
        Create subtitles with:
        - First sentence as first subtitle
        - Middle content with linear pattern (1, 2, 3, 4... words)
        - Last sentence as last subtitle

        `max_words` caps the words per subtitle (None = no cap).
        Returns a pysrt.SubRipFile (empty when `words` is empty).
        """
        subs = pysrt.SubRipFile()
        if not words:
            return subs
        total_words = len(words)
        first_period_idx, last_period_idx = self.find_sentence_boundaries(words)
        # Edge case: no sentence boundary found - plain linear pattern.
        if first_period_idx is None:
            return self._create_basic_linear_subtitles(words, max_words=max_words)
        # Edge case: only one sentence - it becomes a single subtitle.
        if first_period_idx == last_period_idx:
            self._add_subtitle(subs, 1, words, 0, total_words)
            return subs
        subtitle_index = 1
        # 1. First sentence as first subtitle.
        first_sentence_words = words[0:first_period_idx + 1]
        self._add_subtitle(subs, subtitle_index, first_sentence_words, 0, len(first_sentence_words))
        subtitle_index += 1
        # 2. Middle content with the growing linear pattern.
        middle_start = first_period_idx + 1
        middle_end = last_period_idx
        if middle_start < middle_end:
            middle_words = words[middle_start:middle_end]
            subtitle_index = self._add_linear_pattern(
                subs, middle_words, subtitle_index, max_words=max_words
            )
        # 3. Last sentence (the final punctuated word plus any trailing
        #    unpunctuated words) as last subtitle.
        last_sentence_words = words[last_period_idx:total_words]
        if last_sentence_words:
            self._add_subtitle(subs, subtitle_index, last_sentence_words, 0, len(last_sentence_words))
        return subs

    def _add_subtitle(self, subs, index, words, start_idx, end_idx):
        """Append one subtitle covering words[start_idx:end_idx]; timing
        spans from the first word's start to the last word's end."""
        if start_idx >= end_idx or start_idx >= len(words):
            return
        subtitle_words = []
        start_time = None
        end_time = None
        for i in range(start_idx, min(end_idx, len(words))):
            w = words[i]
            subtitle_words.append(w["word"])
            if start_time is None:
                start_time = w["start"]
            end_time = w["end"]
        if subtitle_words:
            subs.append(
                pysrt.SubRipItem(
                    index=index,
                    start=self._to_time(start_time),
                    end=self._to_time(end_time),
                    text=" ".join(subtitle_words)
                )
            )

    def _add_linear_pattern(self, subs, words, start_index, max_words=None):
        """Apply linear pattern (1, 2, 3, 4... words) to words list.

        If `max_words` is provided, no subtitle will contain more than
        `max_words` words. Once the linear size reaches `max_words` it
        remains at that size for subsequent subtitles.
        Returns the next unused subtitle index.
        """
        total_words = len(words)
        index = 0
        subtitle_index = start_index
        current_size = 1
        while index < total_words:
            planned_size = current_size
            if max_words is not None:
                planned_size = min(planned_size, max_words)
            remaining = total_words - (index + planned_size)
            next_size = current_size + 1
            # Absorb leftovers to avoid a tiny last subtitle.
            if remaining > 0 and remaining < next_size:
                planned_size += remaining
            subtitle_words = []
            start_time = None
            end_time = None
            for _ in range(planned_size):
                if index >= total_words:
                    break
                w = words[index]
                subtitle_words.append(w["word"])
                if start_time is None:
                    start_time = w["start"]
                end_time = w["end"]
                index += 1
            if subtitle_words:
                subs.append(
                    pysrt.SubRipItem(
                        index=subtitle_index,
                        start=self._to_time(start_time),
                        end=self._to_time(end_time),
                        text=" ".join(subtitle_words)
                    )
                )
                subtitle_index += 1
            # Progress to the next size only if we didn't absorb leftovers
            # and we're not already at the configured maximum.
            if planned_size == current_size:
                if max_words is None or current_size < max_words:
                    current_size += 1
                else:
                    # Stay at max_words for following subtitles.
                    current_size = max_words
            else:
                # Leftovers were absorbed: everything has been consumed.
                break
        return subtitle_index

    def _create_basic_linear_subtitles(self, words, max_words=None):
        """Fallback: plain linear pattern when no sentence boundary exists.

        Previously duplicated _add_linear_pattern line-for-line; now it
        simply delegates (same rules, honors `max_words` identically).
        """
        subs = pysrt.SubRipFile()
        self._add_linear_pattern(subs, words, 1, max_words=max_words)
        return subs

    def _to_time(self, seconds):
        """Convert float seconds to a pysrt.SubRipTime.

        FIX: the old timedelta-based version read `td.seconds`, which
        wraps at 24 hours (days were dropped), so timestamps past one day
        would be wrong. Hours are now computed from the full duration.
        Milliseconds are truncated, matching the old microseconds // 1000.
        """
        total_ms = int(seconds * 1000)
        hours, rem = divmod(total_ms, 3600 * 1000)
        minutes, rem = divmod(rem, 60 * 1000)
        secs, ms = divmod(rem, 1000)
        return pysrt.SubRipTime(
            hours=hours,
            minutes=minutes,
            seconds=secs,
            milliseconds=ms
        )
# -----------------------------
# Helper: download audio from URL
# -----------------------------
def download_audio(url: str) -> str:
    """Download `url` to a temporary file and return the file path.

    Only http/https URLs are accepted. Raises ValueError for any other
    scheme and requests.HTTPError for non-2xx responses. The caller is
    responsible for removing the temp file (delete=False).

    FIX: the temp file handle and the HTTP connection used to leak if
    iter_content()/write() raised mid-download; both are now managed by
    `with` blocks so they are closed on any exit path.
    """
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https"):
        raise ValueError("Invalid URL scheme")
    # Stream so large files never sit fully in memory.
    with requests.get(url, stream=True, timeout=30) as response:
        response.raise_for_status()
        # Keep the original extension when present so downstream tools
        # can sniff the format; default to .wav otherwise.
        suffix = os.path.splitext(parsed.path)[1] or ".wav"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:  # skip keep-alive chunks
                    tmp.write(chunk)
            return tmp.name
# -----------------------------
# Helper: format elapsed time
# -----------------------------
def format_time(seconds):
    """Render a duration in seconds as a short human-readable string.

    < 1 minute -> "12.3s"; < 1 hour -> "5m 42s"; otherwise "2h 15m".
    """
    if seconds >= 3600:
        hours, remainder = divmod(int(seconds), 3600)
        return f"{hours}h {remainder // 60}m"
    if seconds >= 60:
        minutes, secs = divmod(int(seconds), 60)
        return f"{minutes}m {secs}s"
    return f"{seconds:.1f}s"
# -----------------------------
# Gradio callable function with status updates
# -----------------------------
def generate_srt(audio_file, audio_url, model_size):
    """Gradio streaming handler: yields (srt_path_or_None, status_text).

    Exactly one of `audio_file` (local path) or `audio_url` must be
    provided. Intermediate yields carry None for the file so the UI shows
    running progress; the final yield carries the path of the saved .srt.
    """
    start_time = time.time()
    status_messages = []
    try:
        # Validation: exactly one input source must be provided
        # (bool(a) == bool(b) catches both "neither" and "both").
        if bool(audio_file) == bool(audio_url):
            error_msg = "β Error: Please provide EITHER an audio file OR an audio URL (not both)."
            # FIX: this function is a generator (it contains `yield`), so
            # the old `return None, error_msg` was swallowed — it became
            # StopIteration.value and Gradio never displayed the message.
            # The error must be yielded, then the generator ended.
            yield None, error_msg
            return
        status_messages.append("π Starting subtitle generation...")
        yield None, "\n".join(status_messages)
        # Step 1: Get audio file
        if audio_url:
            status_messages.append("π₯ Downloading audio from URL...")
            yield None, "\n".join(status_messages)
            download_start = time.time()
            audio_path = download_audio(audio_url)
            download_time = time.time() - download_start
            status_messages.append(f"β Download completed in {format_time(download_time)}")
            yield None, "\n".join(status_messages)
        else:
            audio_path = audio_file
            status_messages.append("β Audio file loaded")
            yield None, "\n".join(status_messages)
        # Step 2: Load model (loaded per request — simple, but repeated
        # calls pay the load cost each time).
        status_messages.append(f"π§ Loading Whisper model ({model_size})...")
        yield None, "\n".join(status_messages)
        model_start = time.time()
        generator = LinearSubtitleGenerator(model_size)
        model_time = time.time() - model_start
        status_messages.append(f"β Model loaded in {format_time(model_time)}")
        yield None, "\n".join(status_messages)
        # Step 3: Transcribe
        status_messages.append("π€ Transcribing audio (this may take a while)...")
        yield None, "\n".join(status_messages)
        transcribe_start = time.time()
        segments = generator.transcribe(audio_path)
        words = generator.extract_words(segments)
        transcribe_time = time.time() - transcribe_start
        status_messages.append(f"β Transcription completed in {format_time(transcribe_time)}")
        status_messages.append(f"π Extracted {len(words)} words")
        yield None, "\n".join(status_messages)
        # Step 4: Generate subtitles
        status_messages.append("π Generating SRT subtitles...")
        yield None, "\n".join(status_messages)
        srt_start = time.time()
        subs = generator.create_linear_subtitles(words, max_words=DEFAULT_MAX_WORDS)
        srt_time = time.time() - srt_start
        status_messages.append(f"β Created {len(subs)} subtitle segments in {format_time(srt_time)}")
        yield None, "\n".join(status_messages)
        # Step 5: Save file (delete=False: Gradio serves it for download)
        status_messages.append("πΎ Saving SRT file...")
        yield None, "\n".join(status_messages)
        out = tempfile.NamedTemporaryFile(delete=False, suffix=".srt")
        subs.save(out.name, encoding="utf-8")
        # Final success message with total elapsed time.
        total_time = time.time() - start_time
        status_messages.append(f"β SUCCESS! Total time: {format_time(total_time)}")
        status_messages.append(f"π SRT file ready for download")
        yield out.name, "\n".join(status_messages)
    except requests.RequestException as e:
        error_msg = f"β Network Error: Failed to download audio\nDetails: {str(e)}"
        yield None, error_msg
    except ValueError as e:
        error_msg = f"β Validation Error: {str(e)}"
        yield None, error_msg
    except Exception as e:
        total_time = time.time() - start_time
        error_msg = f"β Error occurred after {format_time(total_time)}\nDetails: {str(e)}"
        yield None, error_msg
# -----------------------------
# Gradio UI with Status Bar
# -----------------------------
with gr.Blocks(title="Subtitle Generator") as demo:
    # Header / feature summary shown above the controls.
    gr.Markdown(
        """
        # SRT Generator with Smart Sentence Handling
        **Features:**
        - First sentence β First subtitle
        - Middle content β Linear pattern (1, 2, 3, 4... words)
        - Last sentence β Last subtitle
        """
    )
    # Input controls. NOTE(review): the scraped source lost indentation,
    # so it is unclear which components sat inside this Row; the dropdown
    # is assumed to share it — confirm against the original layout.
    with gr.Row():
        audio_file = gr.Audio(
            label="Upload Audio File",
            type="filepath"
        )
        audio_url = gr.Textbox(
            label="Audio URL (http/https)",
            placeholder="https://example.com/audio.wav"
        )
        model_choice = gr.Dropdown(
            choices=["tiny", "base", "small", "medium"],
            value="base",
            label="Whisper Model"
        )
    generate_btn = gr.Button("Generate SRT", variant="primary")
    # Status display: receives the streamed status text from generate_srt.
    status_box = gr.Textbox(
        label="Status",
        placeholder="Status updates will appear here...",
        lines=10,
        max_lines=15,
        interactive=False
    )
    output_file = gr.File(label="Download SRT")
    # Event handler: generate_srt is a generator, so each yield updates
    # (output_file, status_box) in place.
    generate_btn.click(
        fn=generate_srt,
        inputs=[audio_file, audio_url, model_choice],
        outputs=[output_file, status_box]
    )
    gr.Markdown(
        """
        ---
        **Tips:**
        - Larger models (small/medium) are more accurate but slower
        - For best results, use clear audio with minimal background noise
        - Processing time depends on audio length and model size
        """
    )

if __name__ == "__main__":
    demo.launch()