kirbah committed on
Commit
4fa749b
·
verified ·
1 Parent(s): 2ca35b4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +226 -0
app.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import logging
4
+ from typing import Any, Dict
5
+
6
+ import torch
7
+ import yt_dlp
8
+ import gradio as gr
9
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
10
+ from huggingface_hub import login, InferenceClient
11
+
12
# Configure root logging once at import time so every module-level
# logging.info/error call below is emitted at INFO level or higher.
logging.basicConfig(level=logging.INFO)
14
+
15
# -------------------------------
# Download Audio from Video URL
# -------------------------------
def download_audio(url: str) -> str:
    """
    Fetch the best available audio stream for *url* and transcode it to MP3.

    Returns the path of the resulting .mp3 file. Raises RuntimeError (chained
    to the underlying cause) when the download or extraction fails.
    """
    downloader_options = {
        'format': 'bestaudio/best',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'mp3',
            'preferredquality': '192',
        }],
    }
    try:
        with yt_dlp.YoutubeDL(downloader_options) as downloader:
            video_info = downloader.extract_info(url, download=True)
            filename = downloader.prepare_filename(video_info)
            # prepare_filename reports the pre-postprocessing name; the
            # FFmpegExtractAudio step re-encodes to .mp3, so swap extensions.
            if not filename.endswith('.mp3'):
                filename = filename.rsplit('.', 1)[0] + '.mp3'
            logging.info("Audio downloaded successfully: %s", filename)
            return filename
    except Exception as exc:
        logging.error("Error downloading audio: %s", exc)
        raise RuntimeError("Audio download failed") from exc
41
+
42
# ---------------------------------------
# Set Up Speech Recognition Model & Pipe
# ---------------------------------------
# Pick the ASR configuration for the available hardware: a large fp16 model
# with big batches on GPU, a tiny fp32 model with small batches on CPU.
if torch.cuda.is_available():
    model_device, pipeline_device = "cuda", 0  # pipeline wants a GPU index
    torch_dtype = torch.float16
    speech_model_id = "openai/whisper-large-v3-turbo"
    batch_size = 16
else:
    model_device, pipeline_device = "cpu", -1  # -1 selects CPU for the pipeline
    torch_dtype = torch.float32
    speech_model_id = "openai/whisper-tiny"
    batch_size = 2
57
+
58
# Load the Whisper checkpoint and processor once at startup; failing fast
# here is better than failing on the first user request.
try:
    _load_kwargs = dict(
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    model = AutoModelForSpeechSeq2Seq.from_pretrained(speech_model_id, **_load_kwargs)
    model.to(model_device)
    processor = AutoProcessor.from_pretrained(speech_model_id)
except Exception as exc:
    logging.error("Error loading the speech model: %s", exc)
    raise

# Shared ASR pipeline, reused by every transcription request.
_pipeline_kwargs = dict(
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=pipeline_device,
)
pipe = pipeline("automatic-speech-recognition", **_pipeline_kwargs)
76
+
77
# --------------------------------------
# Transcription and SRT Conversion
# --------------------------------------
def transcribe_audio(audio_path: str, batch_size: int) -> Dict[str, Any]:
    """
    Run the shared ASR pipeline over *audio_path* and return its result dict.

    Audio is chunked into 10 s windows with (4 s, 2 s) strides so long files
    can be batched; per-chunk timestamps are requested for SRT conversion.
    Re-raises any pipeline error after logging it.
    """
    asr_kwargs = {
        "chunk_length_s": 10,
        "stride_length_s": (4, 2),
        "batch_size": batch_size,
        "return_timestamps": True,
    }
    try:
        return pipe(audio_path, **asr_kwargs)
    except Exception as exc:
        logging.error("Error during transcription: %s", exc)
        raise
96
+
97
def seconds_to_srt_time(seconds: float) -> str:
    """
    Convert a duration in seconds to an SRT timestamp (HH:MM:SS,mmm).

    Non-numeric or missing values fall back to "00:00:00,000". Negative
    values are clamped to zero as well — the previous floor-division logic
    rendered e.g. -5 as "-1:59:55,000", which is not a valid SRT time.
    """
    if not isinstance(seconds, (int, float)) or seconds < 0:
        return "00:00:00,000"
    # Work in integer milliseconds so one truncation covers all fields.
    total_millis = int(seconds * 1000)
    hours, remainder = divmod(total_millis, 3600000)
    minutes, remainder = divmod(remainder, 60000)
    secs, millis = divmod(remainder, 1000)
    return f"{hours:02}:{minutes:02}:{secs:02},{millis:03}"
108
+
109
def convert_to_srt(transcribed: Dict[str, Any]) -> str:
    """
    Render a transcription result dict as an SRT-formatted subtitle string.

    Each chunk becomes a numbered cue; chunks without a timestamp keep their
    number but omit the time line. When the result has no "chunks" key at
    all, the plain "text" field is returned instead.
    """
    if "chunks" not in transcribed:
        logging.warning("No chunks found; returning plain text.")
        return transcribed.get("text", "")

    cues = []
    for cue_number, chunk in enumerate(transcribed["chunks"], start=1):
        timestamp = chunk.get("timestamp")
        if timestamp is not None:
            start = seconds_to_srt_time(timestamp[0])
            end = seconds_to_srt_time(timestamp[1])
            cues.append(f"{cue_number}\n{start} --> {end}\n{chunk['text']}\n")
        else:
            cues.append(f"{cue_number}\n{chunk['text']}\n")
    # Each cue already ends with "\n"; joining with "\n" yields the blank
    # separator line SRT requires between cues.
    return "\n".join(cues)
126
+
127
# ------------------------------
# Hugging Face Login Adjustment
# ------------------------------
def hf_login() -> None:
    """
    Authenticate with the Hugging Face Hub.

    Reads the access token from the HF_TOKEN environment variable and raises
    ValueError when it is missing or empty.
    """
    token = os.environ.get('HF_TOKEN')
    if token is None or token == '':
        raise ValueError("HF_TOKEN not set in environment variables.")
    login(token=token)
    logging.info("Logged in to Hugging Face successfully.")

# Log in once (this can be done at startup)
hf_login()
142
+
143
# -------------------------------------------
# Generate Video Chapters from the Transcript
# -------------------------------------------
def generate_chapters(srt_text: str) -> str:
    """
    Ask a hosted text-generation model to turn an SRT transcript into a
    short numbered list of timestamped chapters, and return the raw text.

    Re-raises any inference error after logging it.
    """
    model_id = "Qwen/Qwen2.5-Coder-32B-Instruct"  # or another model if desired
    inference_client = InferenceClient(model=model_id)

    # Few-shot-style instructions followed by the transcript itself.
    instructions = (
        "Based on the following video transcript, generate a numbered list of concise, SEO-friendly video chapters with timestamps. "
        "Keep related parts together to limit the number of chapters (up to 5-10 chapters). "
        "Each chapter should be in the format '<timestamp> <chapter title>', where the first chapter starts at 0:00. "
        "Timestamps should be in the format 'm:ss' as needed. For example:\n\n"
        "0:00 Intro\n"
        "1:34 Why the GPT wrapper is bad\n"
        "2:14 Smart users workflow\n\n"
        "Only output the chapters list in the provided format. Stop after one list.\n"
        "Transcript:\n"
    )
    prompt = instructions + f"{srt_text}\n\n" + "Chapters:"

    try:
        return inference_client.text_generation(
            prompt,
            max_new_tokens=300,
            temperature=0.5,
            top_p=0.95,
            do_sample=True,
        )
    except Exception as exc:
        logging.error("Error generating chapters: %s", exc)
        raise
180
+
181
# -------------------------------------------
# Main Processing Function for Gradio UI
# -------------------------------------------
def process_video(video_url: str):
    """
    End-to-end handler for the Gradio UI.

    Downloads the audio for *video_url*, transcribes it, converts the
    transcription to SRT, and generates chapters from it.

    Returns:
        A (srt_text, chapters) tuple feeding the two output textboxes.
    """
    # Download audio from the provided URL.
    audio_file = download_audio(video_url)
    logging.info("Audio file saved as: %s", audio_file)

    try:
        # Transcribe the audio.
        transcribed_text = transcribe_audio(audio_file, batch_size)
    finally:
        # Fix: the downloaded MP3 was previously never removed, so repeated
        # requests filled the disk. Best-effort cleanup; a failure to delete
        # should not mask a transcription error or abort the request.
        try:
            os.remove(audio_file)
        except OSError as exc:
            logging.warning("Could not remove temporary audio file %s: %s",
                            audio_file, exc)

    # Clean up memory between requests.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Convert transcription to SRT format.
    srt_text = convert_to_srt(transcribed_text)

    # Generate chapters from the SRT.
    chapters = generate_chapters(srt_text)

    return srt_text, chapters
204
+
205
# -------------------------------------------
# Gradio Interface Definition
# -------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Video Chapter Generator")

    # One widget per row: URL input, trigger button, then the two outputs.
    with gr.Row():
        url_box = gr.Textbox(label="Video URL", placeholder="Enter video URL here", lines=1)

    with gr.Row():
        run_button = gr.Button("Process Video")

    with gr.Row():
        transcript_box = gr.Textbox(label="SRT Transcript", interactive=False, lines=15, show_copy_button=True)

    with gr.Row():
        chapters_box = gr.Textbox(label="Generated Chapters", interactive=False, lines=10, show_copy_button=True)

    run_button.click(fn=process_video, inputs=url_box, outputs=[transcript_box, chapters_box])

# Launch the Gradio app
demo.launch()