"""Lecture summary and question generator exposed through FastAPI and Gradio.

Input can be a YouTube/direct video URL, an uploaded video file, or pasted
text. Media is converted to 16 kHz mono audio, transcribed with Whisper,
summarized with BART and turned into questions with a T5 model.
"""
import os
import re
import shutil
import subprocess
import tempfile
import textwrap

import torch
import ffmpeg
import yt_dlp
import torchaudio
import gradio as gr
from torch.utils.data import Dataset, DataLoader
from youtube_transcript_api import (
    YouTubeTranscriptApi,
    TranscriptsDisabled,
    NoTranscriptFound,
    CouldNotRetrieveTranscript,
    VideoUnavailable,
)
from youtube_transcript_api.formatters import TextFormatter
from transformers import (
    pipeline,
    WhisperProcessor,
    WhisperForConditionalGeneration,
)
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import uvicorn

# === FASTAPI APP ===
app = FastAPI()


# === UTILS ===
def is_youtube_url(url):
    """Return True when *url* contains a YouTube host name (substring check)."""
    return "youtube.com" in url or "youtu.be" in url


def is_web_url(url):
    """Return True when *url* starts with an http:// or https:// scheme."""
    return url.startswith(("http://", "https://"))


def get_video_id(url):
    """Return the first 11-character id following ``v=`` or ``/`` in *url*, or None."""
    match = re.search(r'(?:v=|\/)([0-9A-Za-z_-]{11})', url)
    return match.group(1) if match else None


def try_download_transcript(video_id):
    """Fetch the English transcript for *video_id* as plain text.

    Returns None when no id is given, when the transcript is unavailable, or
    when any other error occurs (the latter is printed before returning).
    """
    if not video_id:
        # Nothing to look up; avoid a pointless API call and error print.
        return None
    try:
        transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=["en"])
        return TextFormatter().format_transcript(transcript)
    except (TranscriptsDisabled, NoTranscriptFound, CouldNotRetrieveTranscript, VideoUnavailable):
        return None
    except Exception as e:
        print(f"Transcript error: {e}")
        return None


def _list_youtube_formats(url, cookies_path=None):
    """Best-effort ``yt-dlp -F`` listing for *url*; never raises."""
    try:
        list_cmd = ["yt-dlp", "-F"]
        if cookies_path:
            list_cmd += ["--cookies", cookies_path]
        # "--" ends option parsing so a URL starting with "-" cannot be
        # interpreted as a yt-dlp option.
        list_cmd += ["--", url]
        result = subprocess.run(list_cmd, capture_output=True, text=True, timeout=15)
        return result.stdout or "No formats found."
    except Exception as format_err:
        return f"\u26a0\ufe0f Could not list formats due to: {format_err}"


def download_audio_youtube(url, output_path="audio.wav", cookies_path=None):
    """Download a YouTube video and extract its audio track to *output_path*.

    The intermediate video is written as ``fallback_video.mp4`` in the same
    directory as *output_path*.

    Args:
        url: Video URL handed to yt-dlp.
        output_path: Destination of the extracted audio file.
        cookies_path: Optional cookie file passed to yt-dlp.

    Returns:
        The path of the extracted audio file.

    Raises:
        RuntimeError: If yt-dlp cannot download the video; the original
            exception is attached as ``__cause__``.
    """
    fallback_video_path = os.path.join(os.path.dirname(output_path), "fallback_video.mp4")
    ydl_opts = {
        "format": "best",
        "outtmpl": fallback_video_path,
        # The output name is fixed, so without this an earlier download
        # would be reused for a different URL.
        "overwrites": True,
        "user_agent": "com.google.android.youtube/17.31.35 (Linux; U; Android 11)",
        "compat_opts": ["allow_unplayable_formats"],
    }
    if cookies_path:
        ydl_opts["cookiefile"] = cookies_path

    try:
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.download([url])
    except Exception as e:
        # Surface the diagnostics instead of computing and discarding them.
        formats = _list_youtube_formats(url, cookies_path)
        print(f"yt-dlp download failed: {e}\nAvailable formats:\n{formats}")
        raise RuntimeError(
            "\u26a0\ufe0f Could not download this YouTube video due to restrictions. "
            "Please use this alternative tool to extract the transcript manually:\n\n"
        ) from e

    return extract_audio_from_video(fallback_video_path, audio_path=output_path)


def download_video_direct(url, output_path="video.mp4"):
    """Download *url* with yt-dlp to *output_path* and return that path."""
    ydl_opts = {
        "format": "best",
        "outtmpl": output_path,
        # Fixed output name: force a fresh download rather than reusing a
        # file left behind by a previous call.
        "overwrites": True,
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        ydl.download([url])
    return output_path


def extract_audio_from_video(video_path, audio_path="audio.wav"):
    """Convert *video_path* to mono 16 kHz audio at *audio_path* via ffmpeg."""
    ffmpeg.input(video_path).output(audio_path, ac=1, ar=16000).run(overwrite_output=True)
    return audio_path


def split_audio(input_path, chunk_length_sec=30, target_sr=16000):
    """Load an audio file and cut it into fixed-length chunks.

    The waveform is resampled to *target_sr* when needed and averaged down to
    a single channel. The last chunk may be shorter than the others.

    Returns:
        A ``(chunks, target_sr)`` tuple where each chunk is a
        ``(1, samples)`` tensor.
    """
    waveform, sr = torchaudio.load(input_path)
    if sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
        waveform = resampler(waveform)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    chunk_samples = target_sr * chunk_length_sec
    chunks = [
        waveform[:, start:start + chunk_samples]
        for start in range(0, waveform.shape[1], chunk_samples)
    ]
    return chunks, target_sr


class AudioChunksDataset(Dataset):
    """Dataset view over a list of ``(1, samples)`` audio chunks."""

    def __init__(self, chunks):
        self.chunks = chunks

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        # Drop the channel dimension so each item is a 1-D waveform.
        return self.chunks[idx].squeeze(0)


def collate_audio_batch(batch):
    """Right-pad 1-D waveforms with zeros to a common length and stack them."""
    max_len = max(b.shape[0] for b in batch)
    padded_batch = [torch.nn.functional.pad(b, (0, max_len - b.shape[0])) for b in batch]
    return torch.stack(padded_batch)


def transcribe_chunks_dataset(chunks, sr, model_name="openai/whisper-small", batch_size=4):
    """Transcribe audio chunks in batches and return the joined text.

    Args:
        chunks: Audio chunks as produced by ``split_audio``.
        sr: Sampling rate of the chunks.
        model_name: Hugging Face id of the Whisper checkpoint to load.
        batch_size: Number of chunks decoded per forward pass.

    Returns:
        All chunk transcriptions joined with single spaces.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = WhisperProcessor.from_pretrained(model_name)
    model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device)
    model.eval()

    dataset = AudioChunksDataset(chunks)
    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_audio_batch)

    full_transcript = []
    for batch_waveforms in dataloader:
        wave_list = [waveform.numpy() for waveform in batch_waveforms]
        input_features = processor(
            wave_list, sampling_rate=sr, return_tensors="pt", padding="max_length"
        ).input_features.to(device)
        with torch.no_grad():
            predicted_ids = model.generate(input_features, language="en")
        transcriptions = processor.batch_decode(predicted_ids, skip_special_tokens=True)
        full_transcript.extend(transcriptions)
    return " ".join(full_transcript)


def _chunk_text(text, max_chars):
    """Group the sentences of *text* into chunks of roughly *max_chars* characters.

    Sentences are split on ``". "``. A sentence longer than *max_chars* (for
    example text without any sentence punctuation) is hard-wrapped on word
    boundaries into its own chunks. Empty sentences and empty chunks are
    dropped.
    """
    chunks = []
    current = ""
    for raw_sentence in text.split(". "):
        sentence = raw_sentence.strip()
        if not sentence:
            continue
        if len(sentence) > max_chars:
            # Flush what we have, then wrap the oversized sentence so no
            # single chunk grows without bound.
            if current:
                chunks.append(current)
                current = ""
            chunks.extend(textwrap.wrap(sentence, width=max_chars))
            continue
        if not sentence.endswith((".", "!", "?")):
            # Restore the period consumed by the split.
            sentence += "."
        if current and len(current) + 1 + len(sentence) > max_chars:
            chunks.append(current)
            current = sentence
        else:
            current = f"{current} {sentence}" if current else sentence
    if current:
        chunks.append(current)
    return chunks


def summarize_with_bart(text, max_tokens=1024):
    """Summarize *text* chunk by chunk with ``facebook/bart-large-cnn``.

    Args:
        text: Text to summarize.
        max_tokens: Approximate chunk size. Despite the name it is compared
            against string lengths, i.e. it is measured in characters. Must
            be positive.

    Returns:
        The chunk summaries joined with spaces, or an empty string when
        *text* contains nothing to summarize.
    """
    chunks = _chunk_text(text, max_tokens)
    if not chunks:
        # Skip loading the model when there is no input at all.
        return ""

    summarizer = pipeline(
        "summarization",
        model="facebook/bart-large-cnn",
        device=0 if torch.cuda.is_available() else -1,
    )
    summaries = [
        # truncation=True is a safety net for chunks that still tokenize to
        # more than the model accepts.
        summarizer(chunk, max_length=150, min_length=30, do_sample=False, truncation=True)[0]["summary_text"]
        for chunk in chunks
    ]
    return " ".join(summaries).strip()


def generate_questions_with_pipeline(text, num_questions=5):
    """Generate up to *num_questions* questions from the sentences of *text*.

    Only the first ``num_questions * 2`` sentences (split on ``". "``) are
    considered; blank sentences and empty generations are skipped. Sampling
    is enabled, so output varies between calls.
    """
    question_generator = pipeline(
        "text2text-generation",
        model="valhalla/t5-base-qg-hl",
        device=0 if torch.cuda.is_available() else -1,
    )
    sentences = text.split(". ")
    questions = []
    for sentence in sentences[:num_questions * 2]:
        if not sentence.strip():
            continue
        input_text = f"generate question: {sentence.strip()}"
        out = question_generator(input_text, max_length=50, do_sample=True, temperature=0.9)
        question = out[0]["generated_text"].strip()
        if question:
            questions.append(question)
    return questions[:num_questions]


def _transcribe_audio_file(audio_path, chunk_length_sec=15):
    """Split the audio file into chunks and return its transcription."""
    chunks, sr = split_audio(audio_path, chunk_length_sec=chunk_length_sec)
    return transcribe_chunks_dataset(chunks, sr)


def _transcribe_video_file(video_path, workdir):
    """Extract audio from *video_path* into *workdir* and transcribe it."""
    audio_path = extract_audio_from_video(
        video_path, audio_path=os.path.join(workdir, "audio.wav")
    )
    return _transcribe_audio_file(audio_path)


def _transcript_from_url(url):
    """Return a transcript for *url*, preferring published YouTube captions."""
    from_youtube = is_youtube_url(url)
    if from_youtube:
        transcript = try_download_transcript(get_video_id(url))
        if transcript:
            return transcript

    # Per-request scratch directory: concurrent requests cannot clobber each
    # other's files and everything is removed afterwards.
    with tempfile.TemporaryDirectory() as workdir:
        audio_path = os.path.join(workdir, "audio.wav")
        if from_youtube:
            download_audio_youtube(url, output_path=audio_path)
        else:
            video_file = download_video_direct(
                url, output_path=os.path.join(workdir, "video.mp4")
            )
            extract_audio_from_video(video_file, audio_path=audio_path)
        return _transcribe_audio_file(audio_path)


# === FASTAPI ROUTE FOR DIRECT FILE UPLOAD ===
@app.post("/upload")
async def upload(file: UploadFile = File(...)):
    """Transcribe an uploaded video and return its summary and questions.

    Responds with ``{"summary": ..., "questions": [...]}`` on success and
    ``{"error": ...}`` when any step raises.
    """
    try:
        with tempfile.TemporaryDirectory() as workdir:
            # The filename is client-controlled: keep only its last path
            # component so it cannot point outside the scratch directory.
            safe_name = os.path.basename(file.filename or "") or "upload"
            file_path = os.path.join(workdir, f"temp_{safe_name}")
            with open(file_path, "wb") as f:
                f.write(await file.read())
            transcript = _transcribe_video_file(file_path, workdir)

        summary = summarize_with_bart(transcript)
        questions = generate_questions_with_pipeline(summary)
        return JSONResponse({"summary": summary, "questions": questions})
    except Exception as e:
        return JSONResponse({"error": str(e)})


# === GRADIO UI ===
def process_input_gradio(url_input, file_input, text_input):
    """Gradio callback: build a summary and numbered questions from one input.

    Pasted text takes priority, then the uploaded file, then the URL.

    Returns:
        A ``(summary, questions)`` tuple of strings. On failure the first
        element carries an ``Error: ...`` message and the second is empty.
    """
    try:
        # Whitespace-only fields count as "not provided" so they do not
        # shadow a real file or URL.
        text = (text_input or "").strip()
        url = (url_input or "").strip()

        if text:
            transcript = text
        elif file_input is not None:
            # Works whether the component hands over a path string or an
            # object exposing the path as `.name`.
            video_path = getattr(file_input, "name", file_input)
            with tempfile.TemporaryDirectory() as workdir:
                transcript = _transcribe_video_file(video_path, workdir)
        elif url:
            transcript = _transcript_from_url(url)
        else:
            return "Please provide a URL, upload a video file, or paste text.", ""

        summary = summarize_with_bart(transcript)
        questions = generate_questions_with_pipeline(summary)
        return summary, "\n".join([f"{i+1}. {q}" for i, q in enumerate(questions)])
    except Exception as e:
        return f"Error: {str(e)}", ""


iface = gr.Interface(
    fn=process_input_gradio,
    inputs=[
        gr.Textbox(label="YouTube or Direct Video URL", placeholder="https://..."),
        gr.File(label="Or Upload a Video File", file_types=[".mp4", ".mkv", ".webm"]),
        gr.Textbox(label="Or Paste Transcript/Text Directly", lines=10, placeholder="Paste transcript or text here..."),
    ],
    outputs=[
        gr.Textbox(label="Summary", lines=10),
        gr.Textbox(label="Generated Questions", lines=10),
    ],
    title="Lecture Summary & Question Generator",
    description="Provide a YouTube/Direct video URL, upload a video file, or paste text. If the video is restricted, upload the video file directly.",
)

# Serve the Gradio UI at "/" on the same FastAPI application.
app = gr.mount_gradio_app(app, iface, path="/")

# === RUNNING BOTH FASTAPI + GRADIO ===
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)