# Whisper-Transcriber / utils/formatters.py
# Whisper Transcriber Bot
# Group words into proper subtitle segments
# d33fc74
import json
import re
from typing import Dict, List, Any, Optional
from datetime import timedelta
import logging
# Module-level logger named after this module; used by the file-saving helpers.
logger = logging.getLogger(__name__)
# Configure the root logger at INFO level as a side effect of importing this
# module. basicConfig() does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
class SubtitleFormatter:
    """Format transcription results into various subtitle formats.

    Every method is a static method; the class is used as a namespace.
    A "result" is a dictionary that may carry a ``'text'`` string, a
    ``'language'`` string and a ``'chunks'`` list. Each chunk is a dictionary
    with a ``'text'`` entry and a ``'timestamp'`` entry holding a
    ``(start, end)`` pair in seconds, where either value may be ``None``.
    """

    # Settings for grouping words into subtitle segments
    MAX_WORDS_PER_SEGMENT = 12
    MAX_DURATION_SECONDS = 7.0
    # NOTE: not referenced by any method of this class; kept for callers.
    MIN_DURATION_SECONDS = 1.0

    @staticmethod
    def _chunk_times(chunk: Dict) -> tuple:
        """Return the ``(start, end)`` pair of a chunk.

        A missing ``'timestamp'`` key and an explicit ``None`` value are both
        treated as ``(None, None)`` so callers never index into ``None``.
        """
        timestamp = chunk.get('timestamp') or (None, None)
        return timestamp[0], timestamp[1]

    @staticmethod
    def _make_segment(words: List[str], start: float, end: float) -> Dict:
        """Build one output segment from the collected words and time span."""
        # Collapse any whitespace runs left inside individual chunks.
        text = re.sub(r'\s+', ' ', ' '.join(words)).strip()
        return {'timestamp': (start, end), 'text': text}

    @staticmethod
    def group_words_into_segments(chunks: List[Dict]) -> List[Dict]:
        """
        Group word-level chunks into proper subtitle segments.

        A segment is closed as soon as one of these holds:
        - the word ends with sentence punctuation (one of ``. ! ? ; :``)
        - the segment holds MAX_WORDS_PER_SEGMENT words
        - the segment spans at least MAX_DURATION_SECONDS

        Chunks whose text is empty or whitespace-only are skipped. A word
        without a start time starts where the previous word ended (0.0 for
        the first word); a word without an end time is assumed to last
        0.5 seconds.

        Args:
            chunks: List of word-level chunks with timestamps

        Returns:
            List of dictionaries with a ``'timestamp'`` ``(start, end)``
            tuple and the combined ``'text'`` of the grouped words
        """
        if not chunks:
            return []

        sentence_endings = re.compile(r'[.!?;:]$')
        segments: List[Dict] = []
        words: List[str] = []
        segment_start: Optional[float] = None
        segment_end: Optional[float] = None
        previous_end = 0.0

        for chunk in chunks:
            text = (chunk.get('text') or '').strip()
            if not text:
                continue

            word_start, word_end = SubtitleFormatter._chunk_times(chunk)
            if word_start is None:
                # Falling back to 0.0 here would drag a mid-stream segment
                # back to the start of the audio.
                word_start = previous_end
            if word_end is None:
                word_end = word_start + 0.5
            previous_end = word_end

            # The first word of a segment fixes the segment start time.
            if segment_start is None:
                segment_start = word_start
            words.append(text)
            segment_end = word_end

            should_end_segment = (
                sentence_endings.search(text) is not None
                or len(words) >= SubtitleFormatter.MAX_WORDS_PER_SEGMENT
                or segment_end - segment_start
                >= SubtitleFormatter.MAX_DURATION_SECONDS
            )
            if should_end_segment:
                segments.append(
                    SubtitleFormatter._make_segment(
                        words, segment_start, segment_end
                    )
                )
                words = []
                segment_start = None
                segment_end = None

        # Don't forget the last segment if there are remaining words
        if words:
            segments.append(
                SubtitleFormatter._make_segment(words, segment_start, segment_end)
            )

        return segments

    @staticmethod
    def _time_parts(seconds: Optional[float]) -> tuple:
        """Split a time in seconds into ``(hours, minutes, seconds, millis)``.

        The value is rounded to whole milliseconds once, up front, so that
        binary floating-point error cannot lose a millisecond (2.3 s must
        yield 300 ms, not 299). ``None`` and negative values map to zero.
        """
        if seconds is None:
            seconds = 0.0
        total_millis = max(0, int(round(seconds * 1000)))
        total_seconds, millis = divmod(total_millis, 1000)
        total_minutes, secs = divmod(total_seconds, 60)
        hours, minutes = divmod(total_minutes, 60)
        return hours, minutes, secs, millis

    @staticmethod
    def format_timestamp_srt(seconds: float) -> str:
        """
        Convert seconds to SRT timestamp format (HH:MM:SS,mmm)

        Args:
            seconds: Time in seconds; ``None`` is treated as 0

        Returns:
            Formatted timestamp string
        """
        hours, minutes, secs, millis = SubtitleFormatter._time_parts(seconds)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    @staticmethod
    def format_timestamp_vtt(seconds: float) -> str:
        """
        Convert seconds to VTT timestamp format (HH:MM:SS.mmm)

        Args:
            seconds: Time in seconds; ``None`` is treated as 0

        Returns:
            Formatted timestamp string
        """
        hours, minutes, secs, millis = SubtitleFormatter._time_parts(seconds)
        return f"{hours:02d}:{minutes:02d}:{secs:02d}.{millis:03d}"

    @staticmethod
    def _labelled_segments(
        result: Dict[str, Any],
        speaker_labels: Optional[Dict]
    ) -> List[tuple]:
        """Return ``(start, end, text, speaker)`` for each non-empty segment.

        ``speaker`` is ``None`` when no labels were given or none matched.
        """
        chunks = result.get('chunks') or []
        cues = []
        for segment in SubtitleFormatter.group_words_into_segments(chunks):
            text = segment.get('text', '').strip()
            if not text:
                continue
            timestamp = segment.get('timestamp') or (0, 0)
            start_time = timestamp[0] if timestamp[0] is not None else 0.0
            end_time = (
                timestamp[1] if timestamp[1] is not None else start_time + 1.0
            )
            speaker = None
            if speaker_labels:
                speaker = SubtitleFormatter._get_speaker_for_segment(
                    speaker_labels, chunks, start_time, end_time
                )
            cues.append((start_time, end_time, text, speaker))
        return cues

    @staticmethod
    def to_srt(result: Dict[str, Any], speaker_labels: Optional[Dict] = None) -> str:
        """
        Convert transcription result to SRT format

        Args:
            result: Transcription result dictionary
            speaker_labels: Optional speaker diarization labels, keyed by
                chunk index

        Returns:
            SRT formatted string
        """
        srt_content = []
        cues = SubtitleFormatter._labelled_segments(result, speaker_labels)
        # Number the cues that are actually written, so the sequence has no
        # gaps even if a segment was skipped.
        for number, (start_time, end_time, text, speaker) in enumerate(cues, 1):
            if speaker:
                text = f"[{speaker}]: {text}"
            # Format: index, timestamp, text
            srt_content.append(f"{number}")
            srt_content.append(
                f"{SubtitleFormatter.format_timestamp_srt(start_time)} --> "
                f"{SubtitleFormatter.format_timestamp_srt(end_time)}"
            )
            srt_content.append(text)
            srt_content.append("")  # Blank line between entries
        return "\n".join(srt_content)

    @staticmethod
    def _get_speaker_for_segment(
        speaker_labels: Dict,
        chunks: List[Dict],
        start_time: float,
        end_time: float
    ) -> Optional[str]:
        """Get the most common speaker for a time segment.

        A chunk belongs to the segment when its start time lies within
        ``[start_time, end_time]`` (both ends inclusive). Labels are looked up
        by the chunk's position in ``chunks``. On a tie the label seen first
        wins, which keeps the result identical from run to run.

        Returns:
            The winning label, or ``None`` when no chunk in range has a label
        """
        # Plain dicts keep insertion order, so max() below returns the
        # earliest label among those with the highest count.
        counts: Dict[Any, int] = {}
        for idx, chunk in enumerate(chunks):
            word_start = SubtitleFormatter._chunk_times(chunk)[0]
            if word_start is None or not start_time <= word_start <= end_time:
                continue
            if idx in speaker_labels:
                label = speaker_labels[idx]
                counts[label] = counts.get(label, 0) + 1
        if not counts:
            return None
        return max(counts, key=counts.get)

    @staticmethod
    def to_vtt(result: Dict[str, Any], speaker_labels: Optional[Dict] = None) -> str:
        """
        Convert transcription result to VTT (WebVTT) format

        Args:
            result: Transcription result dictionary
            speaker_labels: Optional speaker diarization labels, keyed by
                chunk index

        Returns:
            VTT formatted string
        """
        vtt_content = ["WEBVTT", ""]
        cues = SubtitleFormatter._labelled_segments(result, speaker_labels)
        for start_time, end_time, text, speaker in cues:
            if speaker:
                text = f"<v {speaker}>{text}</v>"
            # Format: timestamp, text
            vtt_content.append(
                f"{SubtitleFormatter.format_timestamp_vtt(start_time)} --> "
                f"{SubtitleFormatter.format_timestamp_vtt(end_time)}"
            )
            vtt_content.append(text)
            vtt_content.append("")  # Blank line between entries
        return "\n".join(vtt_content)

    @staticmethod
    def to_txt(result: Dict[str, Any], speaker_labels: Optional[Dict] = None) -> str:
        """
        Convert transcription result to plain text format

        Without speaker labels the result's ``'text'`` entry is returned
        as-is. With labels, consecutive chunks of the same speaker are merged
        into one ``[speaker]: text`` paragraph; chunks without a label are
        attributed to ``UNKNOWN``.

        Args:
            result: Transcription result dictionary
            speaker_labels: Optional speaker diarization labels, keyed by
                chunk index

        Returns:
            Plain text string
        """
        if not speaker_labels:
            # Simple format without speakers
            return result.get('text', '')

        txt_lines = []
        current_speaker = None
        current_text: List[str] = []
        for idx, chunk in enumerate(result.get('chunks') or []):
            text = (chunk.get('text') or '').strip()
            if not text:
                continue
            speaker = speaker_labels.get(idx, 'UNKNOWN')
            if speaker != current_speaker:
                # New speaker, write previous speaker's text
                if current_text:
                    txt_lines.append(
                        f"[{current_speaker}]: {' '.join(current_text)}"
                    )
                    current_text = []
                current_speaker = speaker
            current_text.append(text)

        # Add last speaker's text
        if current_text:
            txt_lines.append(f"[{current_speaker}]: {' '.join(current_text)}")
        return "\n\n".join(txt_lines)

    @staticmethod
    def to_json(result: Dict[str, Any], speaker_labels: Optional[Dict] = None) -> str:
        """
        Convert transcription result to JSON format with word-level timestamps

        Each non-empty chunk becomes one entry of ``'segments'`` carrying its
        original chunk index, start, end and text, plus ``'speaker'`` when a
        label exists for that index.

        Args:
            result: Transcription result dictionary
            speaker_labels: Optional speaker diarization labels, keyed by
                chunk index

        Returns:
            JSON formatted string
        """
        output = {
            'text': result.get('text', ''),
            'language': result.get('language', 'unknown'),
            'segments': []
        }
        for idx, chunk in enumerate(result.get('chunks') or []):
            text = (chunk.get('text') or '').strip()
            if not text:
                continue
            start, end = SubtitleFormatter._chunk_times(chunk)
            segment = {
                'index': idx,
                'start': start,
                'end': end,
                'text': text,
            }
            # Add speaker label if available
            if speaker_labels and idx in speaker_labels:
                segment['speaker'] = speaker_labels[idx]
            output['segments'].append(segment)
        return json.dumps(output, indent=2, ensure_ascii=False)

    @staticmethod
    def save_to_file(content: str, file_path: str) -> None:
        """
        Save formatted content to file

        Args:
            content: Formatted subtitle content
            file_path: Path to save file

        Raises:
            Exception: If the file cannot be written; the original error is
                attached as the cause
        """
        try:
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(content)
            logger.info(f"Saved output to: {file_path}")
        except Exception as e:
            logger.error(f"Failed to save file: {e}")
            raise Exception(f"Failed to save file: {str(e)}") from e

    @staticmethod
    def generate_all_formats(
        result: Dict[str, Any],
        output_prefix: str,
        speaker_labels: Optional[Dict] = None
    ) -> Dict[str, str]:
        """
        Generate all output formats and save to files

        Files are written as ``<output_prefix>.srt``, ``.vtt``, ``.txt`` and
        ``.json``, in that order.

        Args:
            result: Transcription result dictionary
            output_prefix: Prefix for output filenames
            speaker_labels: Optional speaker diarization labels

        Returns:
            Dictionary mapping format to file path
        """
        converters = {
            'srt': SubtitleFormatter.to_srt,
            'vtt': SubtitleFormatter.to_vtt,
            'txt': SubtitleFormatter.to_txt,
            'json': SubtitleFormatter.to_json,
        }
        outputs = {}
        for extension, convert in converters.items():
            path = f"{output_prefix}.{extension}"
            SubtitleFormatter.save_to_file(convert(result, speaker_labels), path)
            outputs[extension] = path
        logger.info(f"Generated all formats: {list(outputs.keys())}")
        return outputs