| |
|
|
| import os |
| import re |
| import gradio as gr |
| import numpy as np |
| import faiss |
|
|
| from youtube_transcript_api import ( |
| YouTubeTranscriptApi, |
| TranscriptsDisabled, |
| NoTranscriptFound, |
| VideoUnavailable, |
| ) |
| from langchain_text_splitters import RecursiveCharacterTextSplitter |
| from sentence_transformers import SentenceTransformer |
| from huggingface_hub import InferenceClient |
|
|
| |
| |
| |
# Sentence-embedding model used to vectorize transcript chunks and queries.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Module-level state for the currently indexed video; rebuilt from scratch
# on every process_video() call.
faiss_index = None      # FAISS index over chunk embeddings, or None before processing
chunk_store = []        # transcript chunks, row-aligned with faiss_index vectors
full_transcript = ""    # raw transcript text of the last processed video

# Hugging Face Inference API client used for answer generation. The token is
# read from the HF_TOKEN environment variable (e.g. a Space secret); an empty
# value falls back to None (unauthenticated access).
HF_TOKEN = os.environ.get("HF_TOKEN", "")
LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
inference_client = InferenceClient(model=LLM_MODEL, token=HF_TOKEN or None)
|
|
| |
| |
| |
| def _extract_video_id(url: str) -> str: |
| patterns = [ |
| r"(?:v=)([A-Za-z0-9_-]{11})", |
| r"(?:youtu\.be/)([A-Za-z0-9_-]{11})", |
| r"(?:embed/)([A-Za-z0-9_-]{11})", |
| r"(?:shorts/)([A-Za-z0-9_-]{11})", |
| ] |
| for pattern in patterns: |
| match = re.search(pattern, url) |
| if match: |
| return match.group(1) |
| raise ValueError(f"Could not extract a valid video ID from: {url}") |
|
|
| |
| |
| |
| |
| |
| |
def get_transcript(url: str) -> str:
    """Fetch a transcript for a YouTube video and return it as one string.

    Strategy:
      1. Try a direct English-language transcript fetch.
      2. Fall back to listing all available transcripts, preferring the first
         English variant, then the first transcript of any language.

    Raises:
        ValueError: if the video is unavailable, transcripts are disabled,
            no transcript exists, or any other retrieval failure occurs.
    """
    video_id = _extract_video_id(url)

    # Fast path: ask directly for an English transcript.
    try:
        snippets = YouTubeTranscriptApi.get_transcript(
            video_id, languages=["en", "en-US", "en-GB"]
        )
        return " ".join(s["text"] for s in snippets)
    except (NoTranscriptFound, TranscriptsDisabled):
        # No direct English transcript — fall through to the listing path.
        pass
    except VideoUnavailable:
        raise ValueError("This video is unavailable or private.")
    except Exception:
        # Best-effort: any other fetch failure also falls through to the
        # listing path below instead of aborting here.
        pass

    # Fallback: enumerate every available transcript and pick a candidate.
    try:
        transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
        transcript = None

        # Prefer any English variant (en, en-US, en-GB, ...).
        for t in transcript_list:
            if t.language_code.startswith("en"):
                transcript = t
                break

        # Otherwise take the first transcript in iteration order,
        # whatever its language.
        if transcript is None:
            for t in transcript_list:
                transcript = t
                break
        if transcript is None:
            raise ValueError("No transcripts are available for this video.")

        snippets = transcript.fetch()
        # NOTE(review): assumes fetch() yields dict-like snippets with a
        # "text" key (youtube-transcript-api < 1.0 behavior) — confirm this
        # matches the installed library version.
        return " ".join(s["text"] for s in snippets)
    except ValueError:
        # Re-raise our own user-facing errors unchanged.
        raise
    except TranscriptsDisabled:
        raise ValueError("Transcripts are disabled for this video.")
    except Exception as exc:
        # Wrap anything else (network errors, API changes) in ValueError so
        # the UI layer can show a friendly message.
        raise ValueError(f"Could not retrieve transcript: {exc}")
|
|
| |
| |
| |
def process_video(url: str):
    """Fetch, chunk, embed, and index a video transcript for retrieval.

    Resets the module-level FAISS index, chunk store, and transcript, then
    rebuilds them from the given URL.

    Returns:
        A ``(status_message, transcript_text)`` pair for the Gradio outputs.
    """
    global faiss_index, chunk_store, full_transcript

    # Drop any previously indexed video before processing a new one.
    faiss_index = None
    chunk_store = []
    full_transcript = ""

    if not url.strip():
        return "⚠️ Please enter a YouTube URL.", ""

    try:
        text = get_transcript(url)
    except ValueError as exc:
        return f"❌ {exc}", ""
    except Exception as exc:
        return f"❌ Unexpected error: {exc}", ""

    if not text.strip():
        return "❌ Transcript is empty for this video.", ""

    full_transcript = text

    # Overlapping character chunks keep sentence fragments retrievable.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        length_function=len,
    )
    pieces = splitter.split_text(text)
    if not pieces:
        return "❌ Could not split transcript into chunks.", text

    chunk_store = pieces

    # FAISS requires float32 vectors.
    vectors = np.asarray(
        embedding_model.encode(pieces, show_progress_bar=False),
        dtype="float32",
    )

    embed_dim = vectors.shape[1]
    vector_index = faiss.IndexFlatL2(embed_dim)
    vector_index.add(vectors)
    faiss_index = vector_index

    status = (
        f"✅ Video processed successfully!\n"
        f" • Chunks created : {len(pieces)}\n"
        f" • Embedding dim : {embed_dim}\n"
        f" • FAISS vectors : {vector_index.ntotal}\n\n"
        f"Switch to the 💬 Chat with Video tab to ask questions."
    )
    return status, text
|
|
| |
| |
| |
def retrieve_context(query: str, top_k: int = 3) -> str:
    """Return the *top_k* transcript chunks most similar to *query*.

    Embeds the query, searches the module-level FAISS index, and joins the
    matching chunks with blank lines. Returns "" when no video is indexed.
    """
    if faiss_index is None or not chunk_store:
        return ""

    # FAISS expects a float32 matrix of query vectors.
    encoded = np.asarray(
        embedding_model.encode([query], show_progress_bar=False),
        dtype="float32",
    )

    neighbors = min(top_k, len(chunk_store))
    _, hit_ids = faiss_index.search(encoded, neighbors)

    # Guard against out-of-range ids (FAISS pads with -1 when short).
    hits = []
    for idx in hit_ids[0]:
        if 0 <= idx < len(chunk_store):
            hits.append(chunk_store[idx])
    return "\n\n".join(hits)
|
|
| |
| |
| |
def generate_answer(query: str) -> str:
    """Answer *query* with RAG: retrieve transcript context, then ask the LLM.

    Always returns a user-facing string; failures are reported as messages
    rather than raised.
    """
    if faiss_index is None:
        return (
            "⚠️ No video processed yet. "
            "Go to 📥 Process Video tab first."
        )

    context = retrieve_context(query, top_k=3)
    if not context:
        return "⚠️ Could not retrieve relevant context for your question."

    # Ground the model strictly in the retrieved transcript chunks.
    system_prompt = (
        "You are a helpful assistant that answers questions strictly "
        "based on the provided video transcript context. "
        "If the answer is not in the context, say: "
        "'I could not find this information in the video transcript.' "
        "Do NOT hallucinate or make up information."
    )
    user_prompt = (
        f"Context from the video transcript:\n"
        f"---\n{context}\n---\n\n"
        f"Question: {query}\n\n"
        f"Answer:"
    )
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    try:
        reply = inference_client.chat_completion(
            messages=conversation,
            max_tokens=512,
            temperature=0.2,
            top_p=0.9,
        )
    except Exception as exc:
        return (
            f"❌ Inference failed: {exc}\n"
            "Check that HF_TOKEN is set correctly as a Space secret."
        )
    return reply.choices[0].message.content.strip()
|
|
| |
| |
| |
| |
def chat(user_message: str, history: list):
    """Append one question/answer turn to the chat history.

    Returns the updated history plus an empty string to clear the input box.
    """
    if not user_message.strip():
        # Blank input: show a warning turn instead of querying the model.
        return history + [["", "⚠️ Please enter a question."]], ""
    return history + [[user_message, generate_answer(user_message)]], ""
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# Gradio UI: two tabs — one to fetch/index a video, one to chat with it.
# ---------------------------------------------------------------------------
with gr.Blocks(title="YouTube RAG Chatbot") as app:

    gr.Markdown(
        """
        # 🎬 YouTube RAG Chatbot
        **Fetch any YouTube transcript and chat with it using RAG + Mistral-7B.**
        > 🔑 Add your `HF_TOKEN` in Space **Settings → Secrets** for the LLM to work.
        """
    )

    with gr.Tabs():

        # --- Tab 1: fetch the transcript and build the FAISS index ---------
        with gr.TabItem("📥 Process Video"):
            gr.Markdown("Enter a YouTube URL and click **Process** to index the transcript.")

            with gr.Row():
                url_input = gr.Textbox(
                    label="YouTube URL",
                    placeholder="https://www.youtube.com/watch?v=...",
                    scale=5,
                )
                process_btn = gr.Button("⚙️ Process", variant="primary", scale=1)

            status_output = gr.Textbox(
                label="Status",
                lines=6,
                interactive=False,
            )
            transcript_output = gr.Textbox(
                label="Transcript",
                lines=15,
                interactive=False,
            )

            process_btn.click(
                fn=process_video,
                inputs=[url_input],
                outputs=[status_output, transcript_output],
            )

        # --- Tab 2: ask questions grounded in the indexed transcript --------
        with gr.TabItem("💬 Chat with Video"):
            gr.Markdown("Ask questions about the video. Answers are grounded in the transcript.")

            # Chatbot displays [user, assistant] message pairs.
            chatbot = gr.Chatbot(label="Conversation", height=450)

            with gr.Row():
                query_input = gr.Textbox(
                    label="Your question",
                    placeholder="What is the main topic of this video?",
                    scale=5,
                )
                send_btn = gr.Button("Send 🚀", variant="primary", scale=1)

            clear_btn = gr.Button("🗑️ Clear", variant="secondary")

            # Per-session copy of the conversation, fed back into chat().
            chat_history = gr.State([])

            # Button click: run chat(), then mirror the chatbot contents back
            # into the session state so the next turn sees the full history.
            send_btn.click(
                fn=chat,
                inputs=[query_input, chat_history],
                outputs=[chatbot, query_input],
            ).then(
                fn=lambda h: h,
                inputs=[chatbot],
                outputs=[chat_history],
            )

            # Enter key in the textbox: identical wiring to the Send button.
            query_input.submit(
                fn=chat,
                inputs=[query_input, chat_history],
                outputs=[chatbot, query_input],
            ).then(
                fn=lambda h: h,
                inputs=[chatbot],
                outputs=[chat_history],
            )

            # Reset both the visible conversation and the stored history.
            clear_btn.click(
                fn=lambda: ([], []),
                outputs=[chatbot, chat_history],
            )


if __name__ == "__main__":
    app.launch()