Spaces:
Running
Running
import json
import re
from typing import Dict, List, Any, Optional
from datetime import timedelta
import logging

# NOTE(review): calling basicConfig() at import time configures the ROOT
# logger as a side effect and can override the host application's logging
# setup — consider moving this to the application entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SubtitleFormatter:
    """Format transcription results into various subtitle formats.

    All methods are stateless and invoked through the class itself,
    e.g. ``SubtitleFormatter.to_srt(result)``.
    """

    # Settings for grouping words into subtitle segments
    MAX_WORDS_PER_SEGMENT = 12   # hard cap on words per subtitle cue
    MAX_DURATION_SECONDS = 7.0   # hard cap on cue duration, in seconds
    # NOTE(review): defined but never referenced in this file — confirm
    # whether a minimum-duration rule was intended in segment grouping.
    MIN_DURATION_SECONDS = 1.0
| def group_words_into_segments(chunks: List[Dict]) -> List[Dict]: | |
| """ | |
| Group word-level chunks into proper subtitle segments. | |
| Groups by: | |
| - Sentence-ending punctuation (. ! ? etc.) | |
| - Maximum words per segment | |
| - Maximum duration per segment | |
| Args: | |
| chunks: List of word-level chunks with timestamps | |
| Returns: | |
| List of grouped segments with combined text and timestamps | |
| """ | |
| if not chunks: | |
| return [] | |
| segments = [] | |
| current_segment = { | |
| 'words': [], | |
| 'start': None, | |
| 'end': None, | |
| 'text': '' | |
| } | |
| sentence_endings = re.compile(r'[.!?;:]$') | |
| for chunk in chunks: | |
| text = chunk.get('text', '').strip() | |
| if not text: | |
| continue | |
| timestamp = chunk.get('timestamp', (None, None)) | |
| word_start = timestamp[0] if timestamp[0] is not None else 0.0 | |
| word_end = timestamp[1] if timestamp[1] is not None else word_start + 0.5 | |
| # Initialize segment start time | |
| if current_segment['start'] is None: | |
| current_segment['start'] = word_start | |
| current_segment['words'].append(text) | |
| current_segment['end'] = word_end | |
| # Calculate current segment duration | |
| duration = current_segment['end'] - current_segment['start'] | |
| word_count = len(current_segment['words']) | |
| # Check if we should end the current segment | |
| should_end_segment = ( | |
| sentence_endings.search(text) or # Sentence ending punctuation | |
| word_count >= SubtitleFormatter.MAX_WORDS_PER_SEGMENT or # Max words reached | |
| duration >= SubtitleFormatter.MAX_DURATION_SECONDS # Max duration reached | |
| ) | |
| if should_end_segment: | |
| # Finalize current segment | |
| current_segment['text'] = ' '.join(current_segment['words']) | |
| # Clean up double spaces | |
| current_segment['text'] = re.sub(r'\s+', ' ', current_segment['text']).strip() | |
| segments.append({ | |
| 'timestamp': (current_segment['start'], current_segment['end']), | |
| 'text': current_segment['text'] | |
| }) | |
| # Reset for next segment | |
| current_segment = { | |
| 'words': [], | |
| 'start': None, | |
| 'end': None, | |
| 'text': '' | |
| } | |
| # Don't forget the last segment if there are remaining words | |
| if current_segment['words']: | |
| current_segment['text'] = ' '.join(current_segment['words']) | |
| current_segment['text'] = re.sub(r'\s+', ' ', current_segment['text']).strip() | |
| segments.append({ | |
| 'timestamp': (current_segment['start'], current_segment['end']), | |
| 'text': current_segment['text'] | |
| }) | |
| return segments | |
| def format_timestamp_srt(seconds: float) -> str: | |
| """ | |
| Convert seconds to SRT timestamp format (HH:MM:SS,mmm) | |
| Args: | |
| seconds: Time in seconds | |
| Returns: | |
| Formatted timestamp string | |
| """ | |
| if seconds is None: | |
| seconds = 0.0 | |
| td = timedelta(seconds=seconds) | |
| hours = int(td.total_seconds() // 3600) | |
| minutes = int((td.total_seconds() % 3600) // 60) | |
| secs = int(td.total_seconds() % 60) | |
| millis = int((seconds - int(seconds)) * 1000) | |
| return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}" | |
| def format_timestamp_vtt(seconds: float) -> str: | |
| """ | |
| Convert seconds to VTT timestamp format (HH:MM:SS.mmm) | |
| Args: | |
| seconds: Time in seconds | |
| Returns: | |
| Formatted timestamp string | |
| """ | |
| if seconds is None: | |
| seconds = 0.0 | |
| td = timedelta(seconds=seconds) | |
| hours = int(td.total_seconds() // 3600) | |
| minutes = int((td.total_seconds() % 3600) // 60) | |
| secs = int(td.total_seconds() % 60) | |
| millis = int((seconds - int(seconds)) * 1000) | |
| return f"{hours:02d}:{minutes:02d}:{secs:02d}.{millis:03d}" | |
| def to_srt(result: Dict[str, Any], speaker_labels: Optional[Dict] = None) -> str: | |
| """ | |
| Convert transcription result to SRT format | |
| Args: | |
| result: Transcription result dictionary | |
| speaker_labels: Optional speaker diarization labels | |
| Returns: | |
| SRT formatted string | |
| """ | |
| srt_content = [] | |
| chunks = result.get('chunks', []) | |
| # Group words into proper subtitle segments | |
| segments = SubtitleFormatter.group_words_into_segments(chunks) | |
| for idx, segment in enumerate(segments, 1): | |
| timestamp = segment.get('timestamp', (0, 0)) | |
| text = segment.get('text', '').strip() | |
| if not text: | |
| continue | |
| start_time = timestamp[0] if timestamp[0] is not None else 0.0 | |
| end_time = timestamp[1] if timestamp[1] is not None else start_time + 1.0 | |
| # Add speaker label if available (simplified for grouped segments) | |
| if speaker_labels: | |
| # Find the most common speaker in this time range | |
| speaker = SubtitleFormatter._get_speaker_for_segment( | |
| speaker_labels, chunks, start_time, end_time | |
| ) | |
| if speaker: | |
| text = f"[{speaker}]: {text}" | |
| # Format: index, timestamp, text | |
| srt_content.append(f"{idx}") | |
| srt_content.append( | |
| f"{SubtitleFormatter.format_timestamp_srt(start_time)} --> " | |
| f"{SubtitleFormatter.format_timestamp_srt(end_time)}" | |
| ) | |
| srt_content.append(text) | |
| srt_content.append("") # Blank line between entries | |
| return "\n".join(srt_content) | |
| def _get_speaker_for_segment( | |
| speaker_labels: Dict, | |
| chunks: List[Dict], | |
| start_time: float, | |
| end_time: float | |
| ) -> Optional[str]: | |
| """Get the most common speaker for a time segment""" | |
| speakers = [] | |
| for idx, chunk in enumerate(chunks): | |
| ts = chunk.get('timestamp', (None, None)) | |
| if ts[0] is not None and start_time <= ts[0] <= end_time: | |
| if idx in speaker_labels: | |
| speakers.append(speaker_labels[idx]) | |
| if speakers: | |
| # Return most common speaker | |
| return max(set(speakers), key=speakers.count) | |
| return None | |
| def to_vtt(result: Dict[str, Any], speaker_labels: Optional[Dict] = None) -> str: | |
| """ | |
| Convert transcription result to VTT (WebVTT) format | |
| Args: | |
| result: Transcription result dictionary | |
| speaker_labels: Optional speaker diarization labels | |
| Returns: | |
| VTT formatted string | |
| """ | |
| vtt_content = ["WEBVTT", ""] | |
| chunks = result.get('chunks', []) | |
| # Group words into proper subtitle segments | |
| segments = SubtitleFormatter.group_words_into_segments(chunks) | |
| for idx, segment in enumerate(segments): | |
| timestamp = segment.get('timestamp', (0, 0)) | |
| text = segment.get('text', '').strip() | |
| if not text: | |
| continue | |
| start_time = timestamp[0] if timestamp[0] is not None else 0.0 | |
| end_time = timestamp[1] if timestamp[1] is not None else start_time + 1.0 | |
| # Add speaker label if available | |
| if speaker_labels: | |
| speaker = SubtitleFormatter._get_speaker_for_segment( | |
| speaker_labels, chunks, start_time, end_time | |
| ) | |
| if speaker: | |
| text = f"<v {speaker}>{text}</v>" | |
| # Format: timestamp, text | |
| vtt_content.append( | |
| f"{SubtitleFormatter.format_timestamp_vtt(start_time)} --> " | |
| f"{SubtitleFormatter.format_timestamp_vtt(end_time)}" | |
| ) | |
| vtt_content.append(text) | |
| vtt_content.append("") # Blank line between entries | |
| return "\n".join(vtt_content) | |
| def to_txt(result: Dict[str, Any], speaker_labels: Optional[Dict] = None) -> str: | |
| """ | |
| Convert transcription result to plain text format | |
| Args: | |
| result: Transcription result dictionary | |
| speaker_labels: Optional speaker diarization labels | |
| Returns: | |
| Plain text string | |
| """ | |
| if speaker_labels: | |
| # Format with speaker labels | |
| txt_lines = [] | |
| chunks = result.get('chunks', []) | |
| current_speaker = None | |
| current_text = [] | |
| for idx, chunk in enumerate(chunks): | |
| text = chunk.get('text', '').strip() | |
| if not text: | |
| continue | |
| speaker = speaker_labels.get(idx, 'UNKNOWN') | |
| if speaker != current_speaker: | |
| # New speaker, write previous speaker's text | |
| if current_text: | |
| txt_lines.append(f"[{current_speaker}]: {' '.join(current_text)}") | |
| current_text = [] | |
| current_speaker = speaker | |
| current_text.append(text) | |
| # Add last speaker's text | |
| if current_text: | |
| txt_lines.append(f"[{current_speaker}]: {' '.join(current_text)}") | |
| return "\n\n".join(txt_lines) | |
| else: | |
| # Simple format without speakers | |
| return result.get('text', '') | |
| def to_json(result: Dict[str, Any], speaker_labels: Optional[Dict] = None) -> str: | |
| """ | |
| Convert transcription result to JSON format with word-level timestamps | |
| Args: | |
| result: Transcription result dictionary | |
| speaker_labels: Optional speaker diarization labels | |
| Returns: | |
| JSON formatted string | |
| """ | |
| output = { | |
| 'text': result.get('text', ''), | |
| 'language': result.get('language', 'unknown'), | |
| 'segments': [] | |
| } | |
| chunks = result.get('chunks', []) | |
| # Process chunks into segments with word-level details | |
| for idx, chunk in enumerate(chunks): | |
| timestamp = chunk.get('timestamp', (None, None)) | |
| text = chunk.get('text', '').strip() | |
| if not text: | |
| continue | |
| segment = { | |
| 'index': idx, | |
| 'start': timestamp[0], | |
| 'end': timestamp[1], | |
| 'text': text, | |
| } | |
| # Add speaker label if available | |
| if speaker_labels and idx in speaker_labels: | |
| segment['speaker'] = speaker_labels[idx] | |
| output['segments'].append(segment) | |
| return json.dumps(output, indent=2, ensure_ascii=False) | |
| def save_to_file(content: str, file_path: str): | |
| """ | |
| Save formatted content to file | |
| Args: | |
| content: Formatted subtitle content | |
| file_path: Path to save file | |
| """ | |
| try: | |
| with open(file_path, 'w', encoding='utf-8') as f: | |
| f.write(content) | |
| logger.info(f"Saved output to: {file_path}") | |
| except Exception as e: | |
| logger.error(f"Failed to save file: {e}") | |
| raise Exception(f"Failed to save file: {str(e)}") | |
| def generate_all_formats( | |
| result: Dict[str, Any], | |
| output_prefix: str, | |
| speaker_labels: Optional[Dict] = None | |
| ) -> Dict[str, str]: | |
| """ | |
| Generate all output formats and save to files | |
| Args: | |
| result: Transcription result dictionary | |
| output_prefix: Prefix for output filenames | |
| speaker_labels: Optional speaker diarization labels | |
| Returns: | |
| Dictionary mapping format to file path | |
| """ | |
| outputs = {} | |
| # Generate SRT | |
| srt_content = SubtitleFormatter.to_srt(result, speaker_labels) | |
| srt_path = f"{output_prefix}.srt" | |
| SubtitleFormatter.save_to_file(srt_content, srt_path) | |
| outputs['srt'] = srt_path | |
| # Generate VTT | |
| vtt_content = SubtitleFormatter.to_vtt(result, speaker_labels) | |
| vtt_path = f"{output_prefix}.vtt" | |
| SubtitleFormatter.save_to_file(vtt_content, vtt_path) | |
| outputs['vtt'] = vtt_path | |
| # Generate TXT | |
| txt_content = SubtitleFormatter.to_txt(result, speaker_labels) | |
| txt_path = f"{output_prefix}.txt" | |
| SubtitleFormatter.save_to_file(txt_content, txt_path) | |
| outputs['txt'] = txt_path | |
| # Generate JSON | |
| json_content = SubtitleFormatter.to_json(result, speaker_labels) | |
| json_path = f"{output_prefix}.json" | |
| SubtitleFormatter.save_to_file(json_content, json_path) | |
| outputs['json'] = json_path | |
| logger.info(f"Generated all formats: {list(outputs.keys())}") | |
| return outputs | |