moFouad1 commited on
Commit
24c3abb
·
verified ·
1 Parent(s): 45c5e8e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +259 -0
app.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import torch
4
+ import ffmpeg
5
+ import yt_dlp
6
+ import torchaudio
7
+ import gradio as gr
8
+ import shutil
9
+
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled, NoTranscriptFound, CouldNotRetrieveTranscript, VideoUnavailable
12
+ from youtube_transcript_api.formatters import TextFormatter
13
+ from transformers import (
14
+ pipeline,
15
+ WhisperProcessor,
16
+ WhisperForConditionalGeneration,
17
+ )
18
+
19
+ from fastapi import FastAPI, UploadFile, File
20
+ from fastapi.responses import JSONResponse
21
+
22
+ import uvicorn
23
+
24
+ # === FASTAPI APP ===
25
+ app = FastAPI()
26
+
27
+ # === UTILS ===
28
+
29
+ def is_youtube_url(url):
30
+ return "youtube.com" in url or "youtu.be" in url
31
+
32
+ def is_web_url(url):
33
+ return url.startswith("http://") or url.startswith("https://")
34
+
35
+ def get_video_id(url):
36
+ match = re.search(r'(?:v=|\/)([0-9A-Za-z_-]{11})', url)
37
+ return match.group(1) if match else None
38
+
39
+ def try_download_transcript(video_id):
40
+ try:
41
+ transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=["en"])
42
+ formatted = TextFormatter().format_transcript(transcript)
43
+ return formatted
44
+ except (TranscriptsDisabled, NoTranscriptFound, CouldNotRetrieveTranscript, VideoUnavailable):
45
+ return None
46
+ except Exception as e:
47
+ print(f"Transcript error: {e}")
48
+ return None
49
+
50
+ def download_audio_youtube(url, output_path="audio.wav", cookies_path=None):
51
+ import subprocess
52
+
53
+ fallback_video_path = "fallback_video.mp4"
54
+ video_id= get_video_id(url)
55
+
56
+ ydl_opts = {
57
+ "format": "best",
58
+ "outtmpl": fallback_video_path,
59
+ "user_agent": "com.google.android.youtube/17.31.35 (Linux; U; Android 11)",
60
+ "compat_opts": ["allow_unplayable_formats"]
61
+ }
62
+
63
+ if cookies_path:
64
+ ydl_opts["cookiefile"] = cookies_path
65
+
66
+ try:
67
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
68
+ ydl.download([url])
69
+ except Exception as e:
70
+ try:
71
+ list_cmd = ["yt-dlp", "-F", url]
72
+ if cookies_path:
73
+ list_cmd += ["--cookies", cookies_path]
74
+ result = subprocess.run(list_cmd, capture_output=True, text=True, timeout=15)
75
+ formats = result.stdout or "No formats found."
76
+ except Exception as format_err:
77
+ formats = f"\u26a0\ufe0f Could not list formats due to: {format_err}"
78
+
79
+ raise RuntimeError(
80
+ "\u26a0\ufe0f Could not download this YouTube video due to restrictions. "
81
+ "Please use this alternative tool to extract the transcript manually:\n\n"
82
+ f"<https://youtubetotranscript.com/transcript?v={video_id}&current_language_code=en>"
83
+ )
84
+
85
+ return extract_audio_from_video(fallback_video_path, audio_path=output_path)
86
+
87
+ def download_video_direct(url, output_path="video.mp4"):
88
+ ydl_opts = {
89
+ "format": "best",
90
+ "outtmpl": output_path
91
+ }
92
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
93
+ ydl.download([url])
94
+ return output_path
95
+
96
+ def extract_audio_from_video(video_path, audio_path="audio.wav"):
97
+ ffmpeg.input(video_path).output(audio_path, ac=1, ar=16000).run(overwrite_output=True)
98
+ return audio_path
99
+
100
+ def split_audio(input_path, chunk_length_sec=30, target_sr=16000):
101
+ waveform, sr = torchaudio.load(input_path)
102
+ if sr != target_sr:
103
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
104
+ waveform = resampler(waveform)
105
+ if waveform.shape[0] > 1:
106
+ waveform = waveform.mean(dim=0, keepdim=True)
107
+ chunk_samples = target_sr * chunk_length_sec
108
+ chunks = [waveform[:, i:i+chunk_samples] for i in range(0, waveform.shape[1], chunk_samples)]
109
+ return chunks, target_sr
110
+
111
+ class AudioChunksDataset(Dataset):
112
+ def __init__(self, chunks):
113
+ self.chunks = chunks
114
+
115
+ def __len__(self):
116
+ return len(self.chunks)
117
+
118
+ def __getitem__(self, idx):
119
+ return self.chunks[idx].squeeze(0)
120
+
121
+ def collate_audio_batch(batch):
122
+ max_len = max([b.shape[0] for b in batch])
123
+ padded_batch = [torch.nn.functional.pad(b, (0, max_len - b.shape[0])) for b in batch]
124
+ return torch.stack(padded_batch)
125
+
126
+ def transcribe_chunks_dataset(chunks, sr, model_name="openai/whisper-small", batch_size=4):
127
+ device = "cuda" if torch.cuda.is_available() else "cpu"
128
+ processor = WhisperProcessor.from_pretrained(model_name)
129
+ model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device)
130
+ model.eval()
131
+
132
+ dataset = AudioChunksDataset(chunks)
133
+ dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_audio_batch)
134
+
135
+ full_transcript = []
136
+ for batch_waveforms in dataloader:
137
+ wave_list = [waveform.numpy() for waveform in batch_waveforms]
138
+ input_features = processor(wave_list, sampling_rate=sr, return_tensors="pt", padding="max_length").input_features.to(device)
139
+ with torch.no_grad():
140
+ predicted_ids = model.generate(input_features, language="en")
141
+ transcriptions = processor.batch_decode(predicted_ids, skip_special_tokens=True)
142
+ full_transcript.extend(transcriptions)
143
+
144
+ return " ".join(full_transcript)
145
+
146
+ def summarize_with_bart(text, max_tokens=1024):
147
+ summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device=0 if torch.cuda.is_available() else -1)
148
+ sentences = text.split(". ")
149
+ chunks = []
150
+ current_chunk = ""
151
+
152
+ for sentence in sentences:
153
+ if len(current_chunk + sentence) <= max_tokens:
154
+ current_chunk += sentence + ". "
155
+ else:
156
+ chunks.append(current_chunk.strip())
157
+ current_chunk = sentence + ". "
158
+ if current_chunk:
159
+ chunks.append(current_chunk.strip())
160
+
161
+ summary = ""
162
+ for chunk in chunks:
163
+ out = summarizer(chunk, max_length=150, min_length=30, do_sample=False)
164
+ summary += out[0]['summary_text'] + " "
165
+
166
+ return summary.strip()
167
+
168
+ def generate_questions_with_pipeline(text, num_questions=5):
169
+ question_generator = pipeline("text2text-generation", model="valhalla/t5-base-qg-hl", device=0 if torch.cuda.is_available() else -1)
170
+ sentences = text.split(". ")
171
+ questions = []
172
+
173
+ for sentence in sentences[:num_questions * 2]:
174
+ if not sentence.strip():
175
+ continue
176
+ input_text = f"generate question: {sentence.strip()}"
177
+ out = question_generator(input_text, max_length=50, do_sample=True, temperature=0.9)
178
+ question = out[0]["generated_text"].strip()
179
+ if question:
180
+ questions.append(question)
181
+
182
+ return questions[:num_questions]
183
+
184
+ # === FASTAPI ROUTE FOR DIRECT FILE UPLOAD ===
185
+
186
+ @app.post("/upload")
187
+ async def upload(file: UploadFile = File(...)):
188
+ try:
189
+ file_path = f"temp_{file.filename}"
190
+ with open(file_path, "wb") as f:
191
+ f.write(await file.read())
192
+
193
+ audio_path = extract_audio_from_video(file_path)
194
+ chunks, sr = split_audio(audio_path, chunk_length_sec=15)
195
+ transcript = transcribe_chunks_dataset(chunks, sr)
196
+ summary = summarize_with_bart(transcript)
197
+ questions = generate_questions_with_pipeline(summary)
198
+ os.remove(file_path)
199
+ return JSONResponse({"summary": summary, "questions": questions})
200
+ except Exception as e:
201
+ return JSONResponse({"error": str(e)})
202
+
203
+ # === GRADIO UI ===
204
+
205
+ def process_input_gradio(url_input, file_input, text_input):
206
+ try:
207
+ transcript = ""
208
+
209
+ if text_input:
210
+ transcript = text_input.strip()
211
+
212
+ elif file_input is not None:
213
+ audio_path = extract_audio_from_video(file_input.name)
214
+ chunks, sr = split_audio(audio_path, chunk_length_sec=15)
215
+ transcript = transcribe_chunks_dataset(chunks, sr)
216
+
217
+ elif url_input:
218
+ if is_youtube_url(url_input):
219
+ video_id = get_video_id(url_input)
220
+ transcript = try_download_transcript(video_id)
221
+ if not transcript:
222
+ audio_path = download_audio_youtube(url_input)
223
+ chunks, sr = split_audio(audio_path, chunk_length_sec=15)
224
+ transcript = transcribe_chunks_dataset(chunks, sr)
225
+ else:
226
+ video_file = download_video_direct(url_input)
227
+ audio_path = extract_audio_from_video(video_file)
228
+ chunks, sr = split_audio(audio_path, chunk_length_sec=15)
229
+ transcript = transcribe_chunks_dataset(chunks, sr)
230
+ else:
231
+ return "Please provide a URL, upload a video file, or paste text.", ""
232
+
233
+ summary = summarize_with_bart(transcript)
234
+ questions = generate_questions_with_pipeline(summary)
235
+ return summary, "\n".join([f"{i+1}. {q}" for i, q in enumerate(questions)])
236
+ except Exception as e:
237
+ return f"Error: {str(e)}", ""
238
+
239
+ iface = gr.Interface(
240
+ fn=process_input_gradio,
241
+ inputs=[
242
+ gr.Textbox(label="YouTube or Direct Video URL", placeholder="https://..."),
243
+ gr.File(label="Or Upload a Video File", file_types=[".mp4", ".mkv", ".webm"]),
244
+ gr.Textbox(label="Or Paste Transcript/Text Directly", lines=10, placeholder="Paste transcript or text here...")
245
+ ],
246
+ outputs=[
247
+ gr.Textbox(label="Summary", lines=10),
248
+ gr.Textbox(label="Generated Questions", lines=10),
249
+ ],
250
+ title="Lecture Summary & Question Generator",
251
+ description="Provide a YouTube/Direct video URL, upload a video file, or paste text. If the video is restricted, upload the video file directly."
252
+ )
253
+
254
+ app = gr.mount_gradio_app(app, iface, path="/")
255
+
256
+ # === RUNNING BOTH FASTAPI + GRADIO ===
257
+
258
+ if __name__ == "__main__":
259
+ uvicorn.run(app, host="0.0.0.0", port=7860)