xTHExBEASTx committed on
Commit
0a1e2fb
·
verified ·
1 Parent(s): 82d594e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -78
app.py CHANGED
@@ -5,11 +5,13 @@ import torch
5
  import os
6
  import math
7
  from datetime import timedelta
8
- import subprocess # Needed to run FFMPEG commands
9
 
10
  # --- Configuration ---
11
  TRANSLATION_MODEL = "facebook/nllb-200-distilled-1.3B"
12
- WHISPER_MODEL = "distil-whisper/distil-large-v3"
 
 
13
 
14
  print("Loading Models...")
15
 
@@ -30,30 +32,89 @@ whisper_pipe = pipeline(
30
  print("Models Loaded Successfully!")
31
 
32
  # ---------------------------------------------------------
33
- # Helper: Extract Audio from Video
34
  # ---------------------------------------------------------
35
  def extract_audio(video_path):
36
- """
37
- Converts video to mp3 using ffmpeg.
38
- Returns the path to the generated audio file.
39
- """
40
  output_audio_path = "temp_audio.mp3"
41
-
42
- # Check if previous temp file exists and remove it
43
  if os.path.exists(output_audio_path):
44
  os.remove(output_audio_path)
45
-
46
- # Run ffmpeg command to extract audio
47
- # -i: input, -vn: no video, -acodec: audio codec, -y: overwrite
48
  command = [
49
  "ffmpeg", "-i", video_path,
50
  "-vn", "-acodec", "libmp3lame",
51
  "-y", output_audio_path
52
  ]
53
-
54
  subprocess.run(command, check=True)
55
  return output_audio_path
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # ---------------------------------------------------------
58
  # Logic 1: Translation (NLLB)
59
  # ---------------------------------------------------------
@@ -61,31 +122,22 @@ def batch_translate(texts, src_lang, tgt_lang, batch_size=8, progress=gr.Progres
61
  results = []
62
  tokenizer_nllb.src_lang = src_lang
63
 
64
- total_batches = (len(texts) + batch_size - 1) // batch_size
65
-
66
  for i, start_idx in enumerate(range(0, len(texts), batch_size)):
67
  batch = texts[start_idx : start_idx + batch_size]
68
-
69
  inputs = tokenizer_nllb(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
70
  forced_bos_token_id = tokenizer_nllb.convert_tokens_to_ids(tgt_lang)
71
 
72
  with torch.no_grad():
73
- generated_tokens = model_nllb.generate(
74
- **inputs,
75
- forced_bos_token_id=forced_bos_token_id,
76
- max_length=512
77
- )
78
 
79
- batch_results = tokenizer_nllb.batch_decode(generated_tokens, skip_special_tokens=True)
80
- results.extend(batch_results)
81
  return results
82
 
83
  def process_translation(filepath, src_lang_code, tgt_lang_code):
84
  if filepath is None: return None
85
  try:
86
  with open(filepath, 'r', encoding='utf-8') as f:
87
- content = f.read()
88
- subtitles = list(srt.parse(content))
89
  except Exception as e:
90
  return f"Error: {str(e)}"
91
 
@@ -106,15 +158,14 @@ def process_translation(filepath, src_lang_code, tgt_lang_code):
106
  def video_to_srt(video_path, progress=gr.Progress()):
107
  if video_path is None: return None
108
 
109
- # 1. Extract Audio First (Fixes the ValueError)
110
- progress(0.1, desc="Extracting Audio from Video...")
111
  try:
112
  audio_path = extract_audio(video_path)
113
  except Exception as e:
114
  return f"Error extracting audio: {str(e)}"
115
 
116
- # 2. Run Transcription on the Audio file
117
- progress(0.3, desc="Transcribing Audio (this may take a while)...")
118
  outputs = whisper_pipe(audio_path, return_timestamps=True, generate_kwargs={"language": "english"})
119
 
120
  chunks = outputs.get("chunks", [])
@@ -123,24 +174,8 @@ def video_to_srt(video_path, progress=gr.Progress()):
123
 
124
  progress(0.8, desc="Formatting SRT...")
125
 
126
- srt_subtitles = []
127
- for i, chunk in enumerate(chunks):
128
- text = chunk['text'].strip()
129
- timestamp = chunk['timestamp']
130
-
131
- # Handle cases where timestamp might be None or single value
132
- if isinstance(timestamp, (list, tuple)):
133
- start, end = timestamp
134
- else:
135
- start = 0.0
136
- end = 5.0
137
-
138
- if end is None:
139
- end = start + 5.0
140
-
141
- srt_subtitles.append(
142
- srt.Subtitle(index=i+1, start=timedelta(seconds=start), end=timedelta(seconds=end), content=text)
143
- )
144
 
145
  out_path = "generated_captions.srt"
146
  with open(out_path, 'w', encoding='utf-8') as f:
@@ -151,45 +186,28 @@ def video_to_srt(video_path, progress=gr.Progress()):
151
  # ---------------------------------------------------------
152
  # Gradio Interface
153
  # ---------------------------------------------------------
154
- with gr.Blocks(title="The Ultimate Subtitler") as demo:
155
- gr.Markdown("# 🎥 The Ultimate Subtitle Tool")
156
 
157
  with gr.Tabs():
158
- # Tab 1: Video to SRT
159
- with gr.TabItem("Step 1: Video to SRT (Whisper)"):
160
- gr.Markdown("### Upload a video to generate English captions")
161
  with gr.Row():
162
  video_input = gr.Video(label="Upload Video")
163
- srt_output_gen = gr.File(label="Generated English SRT")
164
-
165
- gen_btn = gr.Button("Generate SRT", variant="primary")
166
- gen_btn.click(video_to_srt, inputs=video_input, outputs=srt_output_gen)
167
 
168
- # Tab 2: Translate SRT
169
- with gr.TabItem("Step 2: Translate SRT (NLLB)"):
170
- gr.Markdown("### Translate any SRT file")
171
-
172
  with gr.Row():
173
- srt_input = gr.File(label="Upload SRT File")
174
-
175
  with gr.Column():
176
- src_lang = gr.Dropdown(
177
- ["eng_Latn", "spa_Latn", "fra_Latn", "deu_Latn"],
178
- label="Source Language", value="eng_Latn"
179
- )
180
- tgt_lang = gr.Dropdown(
181
- ["arb_Arab", "arz_Arab", "eng_Latn", "fra_Latn"],
182
- label="Target Language", value="arb_Arab"
183
- )
184
-
185
  srt_output_trans = gr.File(label="Translated SRT")
186
-
187
- trans_btn = gr.Button("Translate", variant="primary")
188
- trans_btn.click(
189
- process_translation,
190
- inputs=[srt_input, src_lang, tgt_lang],
191
- outputs=srt_output_trans
192
- )
193
 
194
  if __name__ == "__main__":
195
  demo.launch()
 
5
  import os
6
  import math
7
  from datetime import timedelta
8
+ import subprocess
9
 
10
  # --- Configuration ---
11
  TRANSLATION_MODEL = "facebook/nllb-200-distilled-1.3B"
12
+ # We use OpenAI's original small model for better segmentation on CPU
13
+ # It is often better at splitting sentences than Distil-Large for subtitles
14
+ WHISPER_MODEL = "openai/whisper-small"
15
 
16
  print("Loading Models...")
17
 
 
32
  print("Models Loaded Successfully!")
33
 
34
# ---------------------------------------------------------
# Helper: Extract Audio
# ---------------------------------------------------------
def extract_audio(video_path):
    """Extract the audio track of *video_path* into ``temp_audio.mp3``.

    Invokes ffmpeg (``-vn`` drops the video stream, ``libmp3lame``
    encodes the audio, ``-y`` allows overwriting) and returns the path
    of the generated MP3 file.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero
            status (``check=True``).
    """
    target = "temp_audio.mp3"

    # Remove any stale file from a previous run before re-encoding.
    if os.path.exists(target):
        os.remove(target)

    ffmpeg_cmd = [
        "ffmpeg", "-i", video_path,
        "-vn", "-acodec", "libmp3lame",
        "-y", target,
    ]
    subprocess.run(ffmpeg_cmd, check=True)
    return target
49
 
50
+ # ---------------------------------------------------------
51
+ # Helper: Smart SRT Splitter (The Fix!)
52
+ # ---------------------------------------------------------
53
+ def split_text_into_lines(text, max_chars=80):
54
+ """Breaks long text into smaller lines based on character limit."""
55
+ words = text.split()
56
+ lines = []
57
+ current_line = []
58
+ current_length = 0
59
+
60
+ for word in words:
61
+ if current_length + len(word) + 1 > max_chars:
62
+ lines.append(" ".join(current_line))
63
+ current_line = [word]
64
+ current_length = len(word)
65
+ else:
66
+ current_line.append(word)
67
+ current_length += len(word) + 1
68
+
69
+ if current_line:
70
+ lines.append(" ".join(current_line))
71
+ return lines
72
+
73
def create_srt_segments(chunks):
    """
    Takes raw Whisper chunks and breaks them down into clean SRT subtitles.

    Long chunk texts (>80 chars) are wrapped via split_text_into_lines()
    and the chunk's time span is distributed proportionally across the
    resulting lines. Chunks with unusable timestamps or empty text are
    skipped rather than crashing the whole transcription.

    Args:
        chunks: list of dicts with 'text' and 'timestamp' keys, where
            'timestamp' is a (start, end) pair in seconds — assumed from
            the Whisper pipeline's return_timestamps output.

    Returns:
        list of srt.Subtitle entries, indexed consecutively from 1.
    """
    srt_subtitles = []
    index_counter = 1

    for chunk in chunks:
        text = chunk['text'].strip()
        timestamp = chunk['timestamp']

        # Safe unpacking of timestamps; skip malformed chunks.
        if isinstance(timestamp, (list, tuple)):
            start_time, end_time = timestamp
        else:
            continue

        # Robustness fix: Whisper can emit None for either bound; the
        # old code only guarded end_time, so a None start crashed
        # timedelta(seconds=None).
        if start_time is None:
            start_time = 0.0
        if end_time is None:
            end_time = start_time + 5.0

        # Smart split: if text is too long (>80 chars), wrap it.
        lines = split_text_into_lines(text, max_chars=80)
        if not lines:
            continue  # nothing to show for an empty/whitespace chunk

        # Distribute the chunk duration evenly over its wrapped lines.
        duration_per_line = (end_time - start_time) / len(lines)

        current_start = start_time
        for line in lines:
            current_end = current_start + duration_per_line
            srt_subtitles.append(
                srt.Subtitle(
                    index=index_counter,
                    start=timedelta(seconds=current_start),
                    end=timedelta(seconds=current_end),
                    content=line,
                )
            )
            index_counter += 1
            current_start = current_end  # next line starts where this ended

    return srt_subtitles
117
+
118
  # ---------------------------------------------------------
119
  # Logic 1: Translation (NLLB)
120
  # ---------------------------------------------------------
 
122
  results = []
123
  tokenizer_nllb.src_lang = src_lang
124
 
 
 
125
  for i, start_idx in enumerate(range(0, len(texts), batch_size)):
126
  batch = texts[start_idx : start_idx + batch_size]
 
127
  inputs = tokenizer_nllb(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
128
  forced_bos_token_id = tokenizer_nllb.convert_tokens_to_ids(tgt_lang)
129
 
130
  with torch.no_grad():
131
+ generated_tokens = model_nllb.generate(**inputs, forced_bos_token_id=forced_bos_token_id, max_length=512)
 
 
 
 
132
 
133
+ results.extend(tokenizer_nllb.batch_decode(generated_tokens, skip_special_tokens=True))
 
134
  return results
135
 
136
  def process_translation(filepath, src_lang_code, tgt_lang_code):
137
  if filepath is None: return None
138
  try:
139
  with open(filepath, 'r', encoding='utf-8') as f:
140
+ subtitles = list(srt.parse(f.read()))
 
141
  except Exception as e:
142
  return f"Error: {str(e)}"
143
 
 
158
  def video_to_srt(video_path, progress=gr.Progress()):
159
  if video_path is None: return None
160
 
161
+ progress(0.1, desc="Extracting Audio...")
 
162
  try:
163
  audio_path = extract_audio(video_path)
164
  except Exception as e:
165
  return f"Error extracting audio: {str(e)}"
166
 
167
+ progress(0.3, desc="Transcribing...")
168
+ # We enable return_timestamps=True to get segment-level timing
169
  outputs = whisper_pipe(audio_path, return_timestamps=True, generate_kwargs={"language": "english"})
170
 
171
  chunks = outputs.get("chunks", [])
 
174
 
175
  progress(0.8, desc="Formatting SRT...")
176
 
177
+ # Use the new Smart Splitter function
178
+ srt_subtitles = create_srt_segments(chunks)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  out_path = "generated_captions.srt"
181
  with open(out_path, 'w', encoding='utf-8') as f:
 
186
# ---------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------
# Two-tab UI: tab 1 transcribes an uploaded video to an English SRT
# (Whisper); tab 2 translates an existing SRT file (NLLB).
# NOTE(review): the widget nesting below is reconstructed from a diff
# view — confirm the Column actually sits inside the Row in tab 2.
with gr.Blocks(title="SRT Master Tool") as demo:
    gr.Markdown("# 🎬 Auto Subtitle & Translator")

    with gr.Tabs():
        with gr.TabItem("Step 1: Video to SRT"):
            gr.Markdown("### Convert Video to English Subtitles")
            with gr.Row():
                video_input = gr.Video(label="Upload Video")
                srt_output_gen = gr.File(label="Generated SRT")
            btn1 = gr.Button("Generate SRT", variant="primary")
            # video_to_srt: extract audio -> transcribe -> write SRT file
            btn1.click(video_to_srt, inputs=video_input, outputs=srt_output_gen)

        with gr.TabItem("Step 2: Translate SRT"):
            gr.Markdown("### Translate Subtitles to Arabic")
            with gr.Row():
                srt_input = gr.File(label="Upload SRT")
                with gr.Column():
                    # NLLB language codes (FLORES-200 style identifiers)
                    src_l = gr.Dropdown(["eng_Latn", "fra_Latn"], label="From", value="eng_Latn")
                    tgt_l = gr.Dropdown(["arb_Arab", "arz_Arab"], label="To", value="arb_Arab")
            srt_output_trans = gr.File(label="Translated SRT")
            btn2 = gr.Button("Translate", variant="primary")
            btn2.click(process_translation, inputs=[srt_input, src_l, tgt_l], outputs=srt_output_trans)

if __name__ == "__main__":
    demo.launch()