DARKWICK committed on
Commit
5d13512
·
verified ·
1 Parent(s): d3c713e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -210
app.py CHANGED
@@ -1,210 +1,140 @@
1
- import os
2
- import io
3
- import re
4
- import math
5
- import tempfile
6
- from typing import List, Tuple
7
-
8
- import gradio as gr
9
- import numpy as np
10
- from moviepy.editor import VideoFileClip
11
- from transformers import (
12
- pipeline,
13
- AutoModelForSeq2SeqLM,
14
- AutoTokenizer,
15
- M2M100ForConditionalGeneration,
16
- M2M100Tokenizer,
17
- )
18
-
19
# ---------------------------
# Model choices (balanced for CPU Spaces)
# ---------------------------
ASR_MODEL_ID = "openai/whisper-small"  # good balance of quality/speed
SUMM_MODEL_ID = "sshleifer/distilbart-cnn-12-6"  # light summarizer
PARA_MODEL_ID = "google/flan-t5-base"  # for “modernization” rewrite
TRANS_MODEL_ID = "facebook/m2m100_418M"  # many-to-many language translation

# Preload pipelines/models once (Space warm-up); everything below runs at
# import time, so the first request doesn't pay the model-load cost.
asr_pipe = pipeline(
    "automatic-speech-recognition",
    model=ASR_MODEL_ID,
    chunk_length_s=30,  # chunking helps long audio on CPU
)

summ_pipe = pipeline(
    "summarization",
    model=SUMM_MODEL_ID,
)

# Paraphrase ("modernize") model: tokenizer + seq2seq pair used directly.
para_tok = AutoTokenizer.from_pretrained(PARA_MODEL_ID)
para_model = AutoModelForSeq2SeqLM.from_pretrained(PARA_MODEL_ID)

# Translation model: M2M100 needs its dedicated tokenizer class.
m2m_tok = M2M100Tokenizer.from_pretrained(TRANS_MODEL_ID)
m2m_model = M2M100ForConditionalGeneration.from_pretrained(TRANS_MODEL_ID)

# Maps for M2M100 language codes (expand as needed).
# Keys are the human-readable labels shown in the UI dropdowns; values are
# the ISO-639-1 codes M2M100 expects for src_lang / get_lang_id().
M2M_LANGS = {
    "English": "en",
    "Arabic": "ar",
    "French": "fr",
    "German": "de",
    "Hindi": "hi",
    "Italian": "it",
    "Japanese": "ja",
    "Korean": "ko",
    "Portuguese": "pt",
    "Russian": "ru",
    "Spanish": "es",
    "Turkish": "tr",
    "Urdu": "ur",
    "Chinese (simplified)": "zh",
}
62
-
63
def _extract_audio_to_wav(video_path: str) -> Tuple[str, float]:
    """
    Extract audio to a temp .wav (mono 16k) using moviepy.

    Returns (wav_path, duration_seconds). The caller owns the temp file
    and is responsible for deleting it.

    Raises:
        ValueError: if the video has no audio track (clip.audio is None,
            which previously crashed with AttributeError).
    """
    clip = VideoFileClip(video_path)
    # Reserve a temp wav path; close the handle so ffmpeg can write to it.
    tmp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    tmp_wav.close()
    try:
        duration = float(clip.duration)
        if clip.audio is None:
            raise ValueError("Video has no audio track.")
        clip.audio.write_audiofile(
            tmp_wav.name,
            fps=16000,
            nbytes=2,
            codec="pcm_s16le",
            ffmpeg_params=["-ac", "1"],  # force mono
        )
    except Exception:
        # Don't leak the temp file when extraction fails.
        if os.path.exists(tmp_wav.name):
            os.remove(tmp_wav.name)
        raise
    finally:
        # Always release moviepy's ffmpeg readers, success or failure.
        clip.close()
    return tmp_wav.name, duration
76
-
77
def _clean_text(s: str) -> str:
    """Collapse every run of whitespace to one space and trim the ends."""
    collapsed = re.sub(r"\s+", " ", s)
    return collapsed.strip()
79
-
80
def transcribe_video(video_file) -> Tuple[str, float]:
    """
    Run Whisper ASR over the uploaded video's audio track.

    Returns (clean_transcript, duration_seconds); ("", 0.0) when no file
    was provided.
    """
    if video_file is None:
        return "", 0.0
    wav_path, duration = _extract_audio_to_wav(video_file)
    try:
        result = asr_pipe(wav_path)
        if isinstance(result, dict):
            text = result["text"]
        else:
            text = str(result)
        return _clean_text(text), duration
    finally:
        # The temp wav is ours to clean up, success or failure.
        if os.path.exists(wav_path):
            os.remove(wav_path)
94
-
95
def translate_text_m2m(txt: str, src_code: str, tgt_code: str, max_len=1024) -> str:
    """
    Translate using M2M100. Handles long text by chunking on sentence boundaries.

    Args:
        txt: source text (may be long; it is split into sentence chunks).
        src_code / tgt_code: M2M100 ISO language codes (see M2M_LANGS).
        max_len: tokenizer/generation length cap per chunk.

    Returns the translated chunks re-joined with spaces ("" for blank input).
    """
    if not txt.strip():
        return ""
    # Conservative chunk size: characters roughly over-count tokens 2:1.
    chunks = smart_sentence_chunks(txt, max_len=max_len // 2)
    # Fix: src_lang is loop-invariant — set it once instead of per chunk.
    m2m_tok.src_lang = src_code
    outputs = []
    for ch in chunks:
        encoded = m2m_tok(ch, return_tensors="pt", truncation=True, max_length=max_len)
        generated_tokens = m2m_model.generate(
            **encoded,
            # forced_bos_token_id steers M2M100 to the target language.
            forced_bos_token_id=m2m_tok.get_lang_id(tgt_code),
            max_length=max_len,
            num_beams=4,
        )
        outputs.append(m2m_tok.batch_decode(generated_tokens, skip_special_tokens=True)[0])
    return _clean_text(" ".join(outputs))
114
-
115
def summarize_text(txt: str) -> str:
    """Summarize text, chunking long input and re-squeezing the result."""
    if not txt.strip():
        return ""
    # Chunk long text to stay under the summarizer's input limits.
    partial = [
        summ_pipe(chunk, max_length=180, min_length=60, do_sample=False)[0]["summary_text"]
        for chunk in smart_sentence_chunks(txt, max_len=900)
    ]
    joined = " ".join(partial)
    # Multiple chunk summaries can still be long — do one final pass.
    if len(joined.split()) > 300:
        joined = summ_pipe(joined, max_length=220, min_length=80, do_sample=False)[0]["summary_text"]
    return _clean_text(joined)
129
-
130
def modernize_text(txt: str, style: str = "concise") -> str:
    """Paraphrase / modernize the text via a FLAN-T5 instruction prompt."""
    if not txt.strip():
        return ""
    prompt = (
        "Rewrite the text into modern, clear, and natural language. "
        "Preserve meaning and important details. Style: " + style + ".\n\nText:\n" + txt
    )
    inputs = para_tok(prompt, return_tensors="pt", truncation=True, max_length=2048)
    outputs = para_model.generate(**inputs, max_length=512, num_beams=4)
    decoded = para_tok.decode(outputs[0], skip_special_tokens=True)
    return _clean_text(decoded)
143
-
144
# ------------ Text / SRT helpers ------------
def smart_sentence_chunks(text: str, max_len: int = 800) -> List[str]:
    """
    Greedily pack sentences into chunks of at most ~max_len characters.

    Sentences come from a crude end-of-punctuation split; a single sentence
    longer than max_len still becomes its own (oversized) chunk.
    """
    sentences = re.split(r'(?<=[.!?])\s+', _clean_text(text))
    chunks: List[str] = []
    current = ""
    for sentence in sentences:
        if len(current) + len(sentence) + 1 <= max_len:
            current = (current + " " + sentence).strip()
            continue
        if current:
            chunks.append(current)
        current = sentence
    if current:
        chunks.append(current)
    return chunks
162
-
163
def make_naive_srt(transcript: str, total_seconds: float) -> str:
    """
    Build a "good enough" SRT by assigning equal time slices per sentence.

    We have no per-token timestamps, so every sentence gets the same window:
    at least 1.5 s each, or a flat 3 s when the total duration is unknown.
    """
    sentences = [s for s in re.split(r'(?<=[.!?])\s+', _clean_text(transcript)) if s]
    count = max(1, len(sentences))
    # Avoid too-short windows: min 1.5s per sentence.
    slice_len = max(1.5, total_seconds / count) if total_seconds > 0 else 3.0
    lines = []
    cursor = 0.0
    for idx, sentence in enumerate(sentences, start=1):
        begin = cursor
        finish = cursor + slice_len
        cursor = finish
        lines.extend([
            str(idx),
            f"{_fmt_srt_time(begin)} --> {_fmt_srt_time(finish)}",
            sentence,
            "",  # blank separator required by the SRT format
        ])
    return "\n".join(lines).strip()
183
-
184
def _fmt_srt_time(sec: float) -> str:
    """Format non-negative seconds as an SRT timestamp "HH:MM:SS,mmm"."""
    sec = max(0.0, sec)
    whole = int(sec)
    ms = int((sec - whole) * 1000)
    m, s = divmod(whole, 60)
    h, m = divmod(m, 60)
    return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
191
-
192
# ------------ Gradio Handlers ------------
def ui_video_translate(video, src_lang, tgt_lang):
    """Gradio handler: transcribe a video, translate it, and build an SRT."""
    if video is None:
        hidden_file = gr.update(value=b"", visible=False)
        return gr.update(value=""), gr.update(value=""), hidden_file, gr.update(value="")
    source_code = M2M_LANGS[src_lang]
    target_code = M2M_LANGS[tgt_lang]
    transcript, duration = transcribe_video(video)
    translated = translate_text_m2m(transcript, source_code, target_code)
    srt_text = make_naive_srt(translated, duration)
    # Wrap the SRT text as an in-memory file for the download widget.
    srt_file = io.BytesIO(srt_text.encode("utf-8"))
    srt_file.name = "subtitles_translated.srt"
    return transcript, translated, srt_file, srt_text
206
-
207
- def ui_video_summarize(video, lang_hint):
208
- if video is None:
209
- return "", ""
210
- # Transcribe then summarize (lang-hint doesn’t constrain
 
1
+ import os
2
+ import tempfile
3
+ from pathlib import Path
4
+ import gradio as gr
5
+ import yt_dlp
6
+ from faster_whisper import WhisperModel
7
+
8
# -------- Settings you can tweak --------
# Whisper checkpoint loaded by default; override with the WHISPER_MODEL env var.
DEFAULT_MODEL = os.getenv("WHISPER_MODEL", "small") # small | medium | large-v3 (requires more RAM)
# faster-whisper (CTranslate2) precision/quantization mode.
COMPUTE_TYPE = os.getenv("COMPUTE_TYPE", "int8") # int8 | int8_float16 | float16 | float32
# Videos longer than this (seconds) are rejected before transcription.
MAX_DURATION_SEC = int(os.getenv("MAX_DURATION_SEC", "1800")) # 30 min cap to keep things predictable
# ---------------------------------------
13
+
14
# Lazy-load model once per container
_model = None

def get_model():
    """Return the process-wide WhisperModel, instantiating it lazily."""
    global _model
    if _model is not None:
        return _model
    _model = WhisperModel(DEFAULT_MODEL, compute_type=COMPUTE_TYPE)
    return _model
21
+
22
def _download_youtube_audio(url: str, workdir: str) -> str:
    """
    Download YouTube audio and convert to WAV mono 16 kHz using FFmpegExtractAudio.

    The duration cap is enforced with a metadata-only probe *before* any
    media is downloaded (the original checked after download=True, paying
    for the full transfer just to reject the video).

    Returns path to the WAV file.

    Raises:
        gr.Error: if the video exceeds MAX_DURATION_SEC or no WAV was produced.
    """
    outtmpl = str(Path(workdir) / "%(id)s.%(ext)s")
    ydl_opts = {
        "format": "bestaudio/best",
        "outtmpl": outtmpl,
        "noplaylist": True,
        "quiet": True,
        "no_warnings": True,
        "postprocessors": [
            {
                "key": "FFmpegExtractAudio",
                "preferredcodec": "wav",
                "preferredquality": "5",
            }
        ],
        # ensure mono @ 16 kHz
        "postprocessor_args": ["-ac", "1", "-ar", "16000"],
    }

    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        # Probe first (no download) so over-long videos are rejected cheaply.
        info = ydl.extract_info(url, download=False)
        duration = info.get("duration") or 0
        if duration and duration > MAX_DURATION_SEC:
            raise gr.Error(f"Video too long ({duration//60} min). Max allowed is {MAX_DURATION_SEC//60} min.")
        info = ydl.extract_info(url, download=True)

    # Prefer the exact file our outtmpl produced for this video id; fall
    # back to globbing, since some extractors vary the output name.
    expected = Path(workdir) / f"{info.get('id', '')}.wav"
    if expected.is_file():
        return str(expected)
    wavs = list(Path(workdir).glob("*.wav"))
    if not wavs:
        raise gr.Error("Audio extraction failed. Try a different video.")
    return str(wavs[0])
56
+
57
+
58
def _write_srt(segments, path: str):
    """
    Write transcription segments to *path* in SRT format (UTF-8).

    Each segment must expose .start / .end (seconds as float) and .text,
    matching faster-whisper's Segment objects.
    """
    def srt_timestamp(t):
        # t in seconds -> "HH:MM:SS,mmm". Rounding the TOTAL milliseconds
        # first avoids float-truncation artifacts: the original
        # int((t - int(t)) * 1000) turned e.g. 0.001 s into 0 ms.
        total_ms = max(0, int(round(t * 1000)))
        s, ms = divmod(total_ms, 1000)
        m, s = divmod(s, 60)
        h, m = divmod(m, 60)
        return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"

    with open(path, "w", encoding="utf-8") as f:
        for i, seg in enumerate(segments, start=1):
            f.write(f"{i}\n")
            f.write(f"{srt_timestamp(seg.start)} --> {srt_timestamp(seg.end)}\n")
            f.write(seg.text.strip() + "\n\n")
72
+
73
def transcribe(youtube_url, upload_file, model_size, language, translate_to_english):
    """
    Gradio handler: turn a YouTube URL or uploaded media file into text + SRT.

    Returns (full_transcript, srt_file_path).

    Raises:
        gr.Error: when neither input is provided, or download/extraction fails.
    """
    if not youtube_url and not upload_file:
        raise gr.Error("Provide a YouTube URL or upload a file.")

    # Update model on-the-fly if user changes it
    global _model
    if _model is None or getattr(_model, "_model_size", None) != model_size:
        _model = WhisperModel(model_size, compute_type=COMPUTE_TYPE)
        _model._model_size = model_size  # tag for reuse

    with tempfile.TemporaryDirectory() as td:
        if youtube_url:
            audio_path = _download_youtube_audio(youtube_url.strip(), td)
        else:
            # Gradio may hand us a plain filepath string (v4 default) or a
            # file-like object with .name/.read() (v3). The original called
            # upload_file.read() unconditionally, which breaks on v4.
            # faster-whisper/ffmpeg decodes containers directly, so no
            # conversion is needed either way.
            if isinstance(upload_file, str):
                audio_path = upload_file
            elif hasattr(upload_file, "read"):
                src = Path(td) / Path(upload_file.name).name
                with open(src, "wb") as w:
                    w.write(upload_file.read())
                audio_path = str(src)
            else:
                audio_path = str(upload_file.name)

        # Transcribe (segments is a lazy generator)
        segments, info = _model.transcribe(
            audio_path,
            language=None if language == "auto" else language,
            task="translate" if translate_to_english else "transcribe",
            vad_filter=True
        )

        # Materialize while the temp audio file still exists.
        segs = list(segments)

    full_text = "".join(s.text for s in segs).strip()
    # Fix: write the SRT OUTSIDE the temporary directory. The original
    # returned a path inside `td`, which is deleted when the `with` block
    # exits — Gradio could never serve that file.
    srt_fd, srt_path = tempfile.mkstemp(suffix=".srt", prefix="subtitles_")
    os.close(srt_fd)
    _write_srt(segs, srt_path)
    return full_text, srt_path
108
+
109
# ---- Gradio UI ----
with gr.Blocks(title="YouTube → Text (Whisper)") as demo:
    gr.Markdown("## 🎬 YouTube → 📝 Text\nPaste a YouTube link **or** upload a media file to get a transcript.")
    with gr.Row():
        youtube_url = gr.Textbox(label="YouTube URL", placeholder="https://www.youtube.com/watch?v=...")
    with gr.Row():
        upload_file = gr.File(label="Or upload a video/audio file", file_count="single")
    with gr.Row():
        # Model size, language, and translate toggle share one row.
        model_size = gr.Dropdown(
            ["small", "medium", "large-v3"],
            value=DEFAULT_MODEL,
            label="Model size (larger = more accurate, slower)"
        )
        language = gr.Dropdown(
            ["auto","en","ar","fr","de","es","hi","ur","fa","ru","zh"],
            value="auto",
            label="Language (auto-detect or force)"
        )
        translate_to_english = gr.Checkbox(value=False, label="Translate to English")

    run_btn = gr.Button("Transcribe", variant="primary")
    transcript = gr.Textbox(label="Transcript", lines=12)
    srt_file = gr.File(label="Download SRT (subtitles)")

    # Wire the button to the transcribe() handler defined above.
    run_btn.click(
        transcribe,
        inputs=[youtube_url, upload_file, model_size, language, translate_to_english],
        outputs=[transcript, srt_file]
    )

# Bind to all interfaces on the standard HF Spaces port.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)