|
|
""" |
|
|
Subtitle Extractor Module

Extracts subtitles from videos using OCR and generates SRT files.
|
|
""" |
|
|
|
|
|
import cv2 |
|
|
import sys |
|
|
import os |
|
|
from pathlib import Path |
|
|
from collections import defaultdict |
|
|
|
|
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
|
|
|
from backend.main import SubtitleDetect |
|
|
|
|
|
|
|
|
class SubtitleExtractor:
    """Extract subtitles from a video with OCR and generate SRT files.

    The instance owns an OpenCV ``VideoCapture`` handle. Call :meth:`close`
    when finished, or use the instance as a context manager::

        with SubtitleExtractor(path) as extractor:
            extractor.extract_to_srt()
    """

    # An OCR result is kept only when its confidence is strictly above this.
    MIN_CONFIDENCE = 0.5

    # Two consecutive ranges with identical text are merged into one subtitle
    # when the gap between them is smaller than this many seconds.
    MERGE_GAP_SECONDS = 1.0

    def __init__(self, video_path, sub_area=None):
        """
        Initialize subtitle extractor

        Args:
            video_path: Path to video file
            sub_area: Optional subtitle area (ymin, ymax, xmin, xmax)
        """
        self.video_path = video_path
        self.sub_area = sub_area
        self.detector = SubtitleDetect(video_path, sub_area)

        self.video_cap = cv2.VideoCapture(video_path)
        self.fps = self.video_cap.get(cv2.CAP_PROP_FPS)
        self.frame_count = int(self.video_cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def close(self):
        """Release the video capture handle, if one was opened."""
        capture = getattr(self, 'video_cap', None)
        if capture is not None:
            capture.release()

    def __enter__(self):
        """Return the extractor itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Release the capture on exit; exceptions are never suppressed."""
        self.close()
        return False

    @property
    def text_recognizer(self):
        """Lazy load PaddleOCR text recognizer"""
        if not hasattr(self, '_text_recognizer'):
            # The heavy OCR imports are deferred until the recognizer is
            # first requested; the result is cached on the instance.
            import paddle
            paddle.disable_signal_handler()
            from paddleocr.tools.infer import utility
            from paddleocr.tools.infer.predict_rec import TextRecognizer
            import importlib
            import config
            # Re-execute the config module before reading its values below.
            importlib.reload(config)

            # NOTE(review): parse_args() presumably reads sys.argv; confirm it
            # tolerates this script's positional <video_path> argument.
            args = utility.parse_args()
            args.rec_algorithm = 'CRNN'
            # Prefer an explicit REC_MODEL_PATH; otherwise derive the path
            # from the detection model base directory and model version.
            if hasattr(config, 'REC_MODEL_PATH'):
                args.rec_model_dir = config.REC_MODEL_PATH
            else:
                args.rec_model_dir = os.path.join(
                    config.DET_MODEL_BASE, config.MODEL_VERSION, 'ch_rec'
                )
            args.use_onnx = len(config.ONNX_PROVIDERS) > 0
            args.onnx_providers = config.ONNX_PROVIDERS

            self._text_recognizer = TextRecognizer(args)
        return self._text_recognizer

    def extract_text_from_frame(self, frame, boxes):
        """
        Extract text from frame using OCR

        Box coordinates are converted to ``int`` and clipped to the frame
        bounds before cropping, so a box that partly leaves the frame is
        trimmed instead of wrapping around through negative indices.

        Args:
            frame: Video frame (numpy array)
            boxes: List of detected text boxes [(xmin, xmax, ymin, ymax), ...]

        Returns:
            List of extracted text strings (one per box whose recognition
            confidence exceeds ``MIN_CONFIDENCE``)
        """
        texts = []

        for box in boxes:
            xmin, xmax, ymin, ymax = box

            frame_height, frame_width = frame.shape[:2]
            left = min(max(int(xmin), 0), frame_width)
            right = min(max(int(xmax), 0), frame_width)
            top = min(max(int(ymin), 0), frame_height)
            bottom = min(max(int(ymax), 0), frame_height)

            text_region = frame[top:bottom, left:right]

            # Degenerate or fully out-of-frame boxes yield an empty crop.
            if text_region.size == 0:
                continue

            try:
                rec_result, _ = self.text_recognizer([text_region])
                if rec_result:
                    text, confidence = rec_result[0]
                    if confidence > self.MIN_CONFIDENCE:
                        texts.append(text)
            except Exception as e:
                # Best effort: one failing box must not abort the whole frame.
                print(f"Warning: OCR failed for box {box}: {e}")
                continue

        return texts

    def format_timestamp(self, seconds):
        """
        Convert seconds to SRT timestamp format (HH:MM:SS,mmm)

        The value is rounded to the nearest millisecond before being split
        into fields, so float representation error cannot drop a millisecond
        (2.3 -> "00:00:02,300") and a round-up carries into the seconds,
        minutes and hours fields. Negative input is clamped to zero.

        Args:
            seconds: Time in seconds (float)

        Returns:
            Formatted timestamp string
        """
        total_millis = max(0, int(round(seconds * 1000)))
        total_seconds, millis = divmod(total_millis, 1000)
        total_minutes, secs = divmod(total_seconds, 60)
        hours, minutes = divmod(total_minutes, 60)

        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    def extract_subtitles(self, progress_callback=None):
        """
        Extract subtitles with OCR and timestamps

        For every continuous frame range reported by the detector, the frame
        at the start of the range is read and OCR'd using the boxes recorded
        for that start frame. Consecutive ranges with identical text that are
        less than ``MERGE_GAP_SECONDS`` apart are merged into one subtitle.

        Args:
            progress_callback: Optional callback function for progress
                updates, called as ``progress_callback(fraction, message)``

        Returns:
            List of subtitle dictionaries with 'start', 'end', 'text' keys
        """
        print("[Subtitle Extractor] Starting subtitle extraction...")

        subtitle_frame_dict = self.detector.find_subtitle_frame_no()

        if not subtitle_frame_dict:
            print("[Subtitle Extractor] No subtitles detected!")
            return []

        print(f"[Subtitle Extractor] Found subtitles in {len(subtitle_frame_dict)} frames")

        subtitles = []
        current_subtitle = None

        continuous_ranges = self.detector.find_continuous_ranges_with_same_mask(subtitle_frame_dict)

        for start_frame, end_frame in continuous_ranges:
            # Frame numbers are presumably 1-based (both the seek and the
            # start time use start_frame - 1) -- verify against SubtitleDetect.
            self.video_cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame - 1)
            ret, frame = self.video_cap.read()

            if not ret:
                continue

            boxes = subtitle_frame_dict.get(start_frame, [])

            texts = self.extract_text_from_frame(frame, boxes)
            combined_text = " ".join(texts).strip()

            if not combined_text:
                continue

            start_time = (start_frame - 1) / self.fps
            end_time = end_frame / self.fps

            same_as_previous = (
                current_subtitle is not None
                and current_subtitle['text'] == combined_text
                and abs(start_time - current_subtitle['end']) < self.MERGE_GAP_SECONDS
            )

            if same_as_previous:
                # Same text continuing after a short gap: extend the entry.
                current_subtitle['end'] = end_time
            else:
                if current_subtitle:
                    subtitles.append(current_subtitle)

                current_subtitle = {
                    'start': start_time,
                    'end': end_time,
                    'text': combined_text
                }

            if progress_callback:
                # The container may report 0 frames; never divide by it and
                # never report more than 100 %.
                if self.frame_count > 0:
                    progress = min(end_frame / self.frame_count, 1.0)
                else:
                    progress = 0.0
                # +1 accounts for the pending subtitle that is not appended yet.
                progress_callback(progress, f"Extracting subtitles: {len(subtitles)+1} found")

        if current_subtitle:
            subtitles.append(current_subtitle)

        print(f"[Subtitle Extractor] Extracted {len(subtitles)} subtitle segments")
        return subtitles

    def generate_srt(self, subtitles, output_path):
        """
        Generate SRT file from subtitles

        Each entry is written as a 1-based index line, a
        ``start --> end`` timestamp line, the text, and a blank separator.

        Args:
            subtitles: List of subtitle dictionaries
            output_path: Path to save SRT file

        Returns:
            Path to generated SRT file
        """
        print(f"[Subtitle Extractor] Generating SRT file: {output_path}")

        with open(output_path, 'w', encoding='utf-8') as f:
            for index, sub in enumerate(subtitles, 1):
                start_ts = self.format_timestamp(sub['start'])
                end_ts = self.format_timestamp(sub['end'])
                f.write(f"{index}\n{start_ts} --> {end_ts}\n{sub['text']}\n\n")

        print(f"[Subtitle Extractor] SRT file saved: {output_path}")
        return output_path

    def extract_to_srt(self, output_path=None, progress_callback=None):
        """
        Complete extraction pipeline: detect -> OCR -> generate SRT

        Args:
            output_path: Optional custom output path for SRT file; defaults
                to ``<video stem>_subtitles.srt`` next to the video
            progress_callback: Optional callback for progress updates

        Returns:
            Path to generated SRT file
        """
        if output_path is None:
            video_file = Path(self.video_path)
            output_path = video_file.parent / f"{video_file.stem}_subtitles.srt"

        subtitles = self.extract_subtitles(progress_callback)

        if not subtitles:
            # Still produce a file so callers always get a path back.
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write("# No subtitles detected\n")
            return str(output_path)

        return self.generate_srt(subtitles, str(output_path))
|
|
|
|
|
|
|
|
def main(argv=None):
    """Command-line entry point: extract subtitles from one video to SRT.

    Args:
        argv: Argument list including the program name; defaults to
            ``sys.argv`` when omitted.

    Returns:
        Process exit code: 0 on success, 1 when the video path is missing.
    """
    args = sys.argv if argv is None else argv
    if len(args) < 2:
        print("Usage: python subtitle_extractor.py <video_path>")
        return 1

    video_path = args[1]
    extractor = SubtitleExtractor(video_path)
    try:
        srt_path = extractor.extract_to_srt()
    finally:
        # Release the OpenCV capture even when extraction fails.
        extractor.video_cap.release()
    print(f"Subtitles extracted to: {srt_path}")
    return 0


if __name__ == '__main__':
    sys.exit(main())
|
|
|