xTHExBEASTx committed on
Commit
0268049
·
verified ·
1 Parent(s): e6b8cce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -129
app.py CHANGED
@@ -1,54 +1,43 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 
3
  import srt
4
  import torch
5
  import os
6
- import math
7
  from datetime import timedelta
8
  import subprocess
9
  import re
10
 
11
  # --- Configuration ---
 
12
  TRANSLATION_MODEL = "facebook/nllb-200-distilled-1.3B"
13
- WHISPER_MODEL = "openai/whisper-small"
 
 
 
14
 
15
  print("Loading Models...")
16
 
17
- # --- Load Translation Model ---
18
  tokenizer_nllb = AutoTokenizer.from_pretrained(TRANSLATION_MODEL)
19
  model_nllb = AutoModelForSeq2SeqLM.from_pretrained(TRANSLATION_MODEL)
20
 
21
- # --- Load Audio Model ---
22
- whisper_pipe = pipeline(
23
- "automatic-speech-recognition",
24
- model=WHISPER_MODEL,
25
- torch_dtype=torch.float32,
26
- device="cpu",
27
- chunk_length_s=30,
28
- stride_length_s=5,
29
- )
30
 
31
  print("Models Loaded Successfully!")
32
 
33
  # ---------------------------------------------------------
34
- # Helper: Extract Audio & Duration
35
  # ---------------------------------------------------------
36
- def get_media_duration(filename):
37
- try:
38
- result = subprocess.run(
39
- ["ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename],
40
- stdout=subprocess.PIPE,
41
- stderr=subprocess.STDOUT
42
- )
43
- return float(result.stdout)
44
- except:
45
- return 30.0
46
-
47
  def extract_audio(video_path):
48
  output_audio_path = "temp_audio.mp3"
49
  if os.path.exists(output_audio_path):
50
  os.remove(output_audio_path)
51
 
 
52
  command = [
53
  "ffmpeg", "-i", video_path,
54
  "-vn", "-acodec", "libmp3lame",
@@ -66,85 +55,83 @@ def srt_to_vtt(srt_path):
66
  with open(srt_path, 'r', encoding='utf-8') as f:
67
  content = f.read()
68
 
69
- # VTT Header
70
  vtt_content = "WEBVTT\n\n"
71
-
72
- # Replace comma timestamps (00:00:01,000) with dot (00:00:01.000)
73
  vtt_content += re.sub(r'(\d{2}:\d{2}:\d{2}),(\d{3})', r'\1.\2', content)
74
 
75
  with open(vtt_path, 'w', encoding='utf-8') as f:
76
  f.write(vtt_content)
77
-
78
  return vtt_path
79
 
80
  # ---------------------------------------------------------
81
- # Helper: Smart Splitter logic
82
  # ---------------------------------------------------------
83
- def split_text_into_lines(text, max_chars=80):
84
- words = text.split()
85
- lines = []
86
- current_line = []
87
- current_length = 0
88
-
89
- for word in words:
90
- if current_length + len(word) + 1 > max_chars:
91
- lines.append(" ".join(current_line))
92
- current_line = [word]
93
- current_length = len(word)
94
- else:
95
- current_line.append(word)
96
- current_length += len(word) + 1
97
-
98
- if current_line:
99
- lines.append(" ".join(current_line))
100
- return lines
101
-
102
- def create_srt_segments(chunks, total_video_duration):
103
  srt_subtitles = []
104
- index_counter = 1
105
-
106
- for chunk in chunks:
107
- text = chunk['text'].strip()
108
- timestamp = chunk['timestamp']
109
-
110
- if isinstance(timestamp, (list, tuple)):
111
- start_time, end_time = timestamp
112
- else:
113
- start_time, end_time = 0.0, None
114
-
115
- if end_time is None:
116
- end_time = total_video_duration
117
-
118
- lines = split_text_into_lines(text, max_chars=80)
119
- duration = end_time - start_time
120
- if duration <= 0: duration = 5.0
121
-
122
- step = duration / len(lines) if lines else 0
123
- current_start = start_time
124
 
125
- for line in lines:
126
- current_end = current_start + step
127
- srt_subtitles.append(
128
- srt.Subtitle(index=index_counter, start=timedelta(seconds=current_start), end=timedelta(seconds=current_end), content=line)
 
 
129
  )
130
- index_counter += 1
131
- current_start = current_end
132
-
133
- return srt_subtitles
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  # ---------------------------------------------------------
136
- # Logic 1: Translation (NLLB)
137
  # ---------------------------------------------------------
138
- def batch_translate(texts, src_lang, tgt_lang, batch_size=8, progress=gr.Progress()):
139
  results = []
140
  tokenizer_nllb.src_lang = src_lang
141
 
142
- for i, start_idx in enumerate(range(0, len(texts), batch_size)):
143
- batch = texts[start_idx : start_idx + batch_size]
144
  inputs = tokenizer_nllb(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
145
  forced_bos_token_id = tokenizer_nllb.convert_tokens_to_ids(tgt_lang)
 
146
  with torch.no_grad():
147
  generated_tokens = model_nllb.generate(**inputs, forced_bos_token_id=forced_bos_token_id, max_length=512)
 
148
  results.extend(tokenizer_nllb.batch_decode(generated_tokens, skip_special_tokens=True))
149
  return results
150
 
@@ -165,58 +152,13 @@ def process_translation(filepath, src_lang_code, tgt_lang_code):
165
  out_path = "translated_subtitles.srt"
166
  with open(out_path, 'w', encoding='utf-8') as f:
167
  f.write(srt.compose(subtitles))
168
-
169
  return out_path
170
 
171
- # ---------------------------------------------------------
172
- # Logic 2: Video to SRT + Preview
173
- # ---------------------------------------------------------
174
- def video_to_srt(video_path, progress=gr.Progress()):
175
- if video_path is None: return None, None
176
-
177
- # 1. Audio & Duration
178
- progress(0.1, desc="Extracting Audio...")
179
- try:
180
- audio_path = extract_audio(video_path)
181
- duration = get_media_duration(audio_path)
182
- except Exception as e:
183
- return None, f"Error: {str(e)}"
184
-
185
- # 2. Transcribe
186
- progress(0.3, desc="Transcribing...")
187
- outputs = whisper_pipe(audio_path, return_timestamps=True, generate_kwargs={"language": "english"})
188
- chunks = outputs.get("chunks", [])
189
- if not chunks: chunks = [{"text": outputs.get("text", ""), "timestamp": (0.0, None)}]
190
-
191
- # 3. Format SRT
192
- progress(0.8, desc="Formatting...")
193
- srt_subtitles = create_srt_segments(chunks, duration)
194
-
195
- srt_path = "generated_captions.srt"
196
- with open(srt_path, 'w', encoding='utf-8') as f:
197
- f.write(srt.compose(srt_subtitles))
198
-
199
- # 4. Create Preview (HTML + VTT)
200
- vtt_path = srt_to_vtt(srt_path)
201
-
202
- # Create the HTML player
203
- html_preview = f"""
204
- <h3>Video Preview</h3>
205
- <video controls width="100%" height="400px" style="background:black">
206
- <source src="/file={video_path}" type="video/mp4">
207
- <track kind="captions" src="/file={vtt_path}" srclang="en" label="English" default>
208
- Your browser does not support the video tag.
209
- </video>
210
- <p style="margin-top:10px; color: #666;">Note: Subtitles are overlaid for preview only. They are not burned into the video.</p>
211
- """
212
-
213
- return srt_path, html_preview
214
-
215
  # ---------------------------------------------------------
216
  # Gradio Interface
217
  # ---------------------------------------------------------
218
  with gr.Blocks(title="SRT Master Tool") as demo:
219
- gr.Markdown("# 🎬 Auto Subtitle & Translator")
220
 
221
  with gr.Tabs():
222
  # --- TAB 1 ---
@@ -224,7 +166,6 @@ with gr.Blocks(title="SRT Master Tool") as demo:
224
  gr.Markdown("### 1. Upload Video -> 2. Check Preview -> 3. Download SRT")
225
  with gr.Row():
226
  video_input = gr.Video(label="Upload Video", sources=["upload"])
227
-
228
  with gr.Column():
229
  preview_output = gr.HTML(label="Preview Player")
230
  srt_output_gen = gr.File(label="Download Generated SRT")
 
1
  import gradio as gr
2
+ import whisper
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  import srt
5
  import torch
6
  import os
 
7
  from datetime import timedelta
8
  import subprocess
9
  import re
10
 
11
  # --- Configuration ---
12
+ # Translation Model (NLLB)
13
  TRANSLATION_MODEL = "facebook/nllb-200-distilled-1.3B"
14
+
15
+ # Whisper Model Size: "medium" is the best balance for CPU.
16
+ # You can change to "large" or "large-v3" but it will be 2x slower.
17
+ WHISPER_MODEL_SIZE = "medium"
18
 
19
  print("Loading Models...")
20
 
21
+ # --- Load Translation Model (NLLB) ---
22
  tokenizer_nllb = AutoTokenizer.from_pretrained(TRANSLATION_MODEL)
23
  model_nllb = AutoModelForSeq2SeqLM.from_pretrained(TRANSLATION_MODEL)
24
 
25
+ # --- Load Audio Model (Official OpenAI Whisper) ---
26
+ # This downloads the model to the container
27
+ print(f"Loading Whisper '{WHISPER_MODEL_SIZE}' model...")
28
+ whisper_model = whisper.load_model(WHISPER_MODEL_SIZE, device="cpu")
 
 
 
 
 
29
 
30
  print("Models Loaded Successfully!")
31
 
32
  # ---------------------------------------------------------
33
+ # Helper: Extract Audio
34
  # ---------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
35
  def extract_audio(video_path):
36
  output_audio_path = "temp_audio.mp3"
37
  if os.path.exists(output_audio_path):
38
  os.remove(output_audio_path)
39
 
40
+ # Simple FFMPEG extraction
41
  command = [
42
  "ffmpeg", "-i", video_path,
43
  "-vn", "-acodec", "libmp3lame",
 
55
  with open(srt_path, 'r', encoding='utf-8') as f:
56
  content = f.read()
57
 
 
58
  vtt_content = "WEBVTT\n\n"
59
+ # Regex to convert SRT comma timestamps to VTT dot timestamps
 
60
  vtt_content += re.sub(r'(\d{2}:\d{2}:\d{2}),(\d{3})', r'\1.\2', content)
61
 
62
  with open(vtt_path, 'w', encoding='utf-8') as f:
63
  f.write(vtt_content)
 
64
  return vtt_path
65
 
66
  # ---------------------------------------------------------
67
+ # Logic 1: Video to SRT (Using Native Whisper)
68
  # ---------------------------------------------------------
69
+ def video_to_srt(video_path, progress=gr.Progress()):
70
+ if video_path is None: return None, None
71
+
72
+ # 1. Extract Audio
73
+ progress(0.1, desc="Extracting Audio...")
74
+ try:
75
+ audio_path = extract_audio(video_path)
76
+ except Exception as e:
77
+ return None, f"Error: {str(e)}"
78
+
79
+ # 2. Transcribe using Native Whisper
80
+ progress(0.3, desc=f"Transcribing with Whisper {WHISPER_MODEL_SIZE}...")
81
+
82
+ # The native transcribe function handles segmentation automatically!
83
+ result = whisper_model.transcribe(audio_path, language="en")
84
+
85
+ # 3. Format to SRT
86
+ progress(0.8, desc="Formatting SRT...")
 
 
87
  srt_subtitles = []
88
+
89
+ for i, segment in enumerate(result["segments"]):
90
+ start_seconds = segment["start"]
91
+ end_seconds = segment["end"]
92
+ text = segment["text"].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
+ srt_subtitles.append(
95
+ srt.Subtitle(
96
+ index=i+1,
97
+ start=timedelta(seconds=start_seconds),
98
+ end=timedelta(seconds=end_seconds),
99
+ content=text
100
  )
101
+ )
102
+
103
+ srt_path = "generated_captions.srt"
104
+ with open(srt_path, 'w', encoding='utf-8') as f:
105
+ f.write(srt.compose(srt_subtitles))
106
+
107
+ # 4. Create Preview
108
+ vtt_path = srt_to_vtt(srt_path)
109
+
110
+ html_preview = f"""
111
+ <h3>Video Preview</h3>
112
+ <video controls width="100%" height="400px" style="background:black">
113
+ <source src="/file={video_path}" type="video/mp4">
114
+ <track kind="captions" src="/file={vtt_path}" srclang="en" label="English" default>
115
+ Your browser does not support the video tag.
116
+ </video>
117
+ """
118
+ return srt_path, html_preview
119
 
120
  # ---------------------------------------------------------
121
+ # Logic 2: Translation (NLLB)
122
  # ---------------------------------------------------------
123
+ def batch_translate(texts, src_lang, tgt_lang, batch_size=8):
124
  results = []
125
  tokenizer_nllb.src_lang = src_lang
126
 
127
+ for i in range(0, len(texts), batch_size):
128
+ batch = texts[i : i + batch_size]
129
  inputs = tokenizer_nllb(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
130
  forced_bos_token_id = tokenizer_nllb.convert_tokens_to_ids(tgt_lang)
131
+
132
  with torch.no_grad():
133
  generated_tokens = model_nllb.generate(**inputs, forced_bos_token_id=forced_bos_token_id, max_length=512)
134
+
135
  results.extend(tokenizer_nllb.batch_decode(generated_tokens, skip_special_tokens=True))
136
  return results
137
 
 
152
  out_path = "translated_subtitles.srt"
153
  with open(out_path, 'w', encoding='utf-8') as f:
154
  f.write(srt.compose(subtitles))
 
155
  return out_path
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  # ---------------------------------------------------------
158
  # Gradio Interface
159
  # ---------------------------------------------------------
160
  with gr.Blocks(title="SRT Master Tool") as demo:
161
+ gr.Markdown(f"# 🎬 Auto Subtitle (Whisper {WHISPER_MODEL_SIZE}) & Translator")
162
 
163
  with gr.Tabs():
164
  # --- TAB 1 ---
 
166
  gr.Markdown("### 1. Upload Video -> 2. Check Preview -> 3. Download SRT")
167
  with gr.Row():
168
  video_input = gr.Video(label="Upload Video", sources=["upload"])
 
169
  with gr.Column():
170
  preview_output = gr.HTML(label="Preview Player")
171
  srt_output_gen = gr.File(label="Download Generated SRT")