# Transcript Hybrid RAG — Gradio app (HF Spaces).
# NOTE(review): the hosted Space showed "Runtime error"; the crash traces to
# the Utterance class being constructed positionally without @dataclass.
# --- Imports & global configuration -------------------------------------
# Stdlib
import json
import os
import re
from dataclasses import dataclass
from typing import List

# Third-party
import gradio as gr
import nltk
from dotenv import load_dotenv
from nltk.tokenize import sent_tokenize
from sentence_transformers import CrossEncoder
from llama_index.core import Settings, VectorStoreIndex
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core.schema import TextNode
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.retrievers.bm25 import BM25Retriever

# Load .env first so OPENAI_* variables exist before Settings is configured.
load_dotenv()

# Sentence tokenizer model required by sent_tokenize() in build_subchunks().
nltk.download("punkt_tab")

# Module-level retriever handle; populated by index_file(), read by run_query().
RETRIEVER = None

# Shared llama-index configuration: local MiniLM embeddings + an
# OpenAI-compatible LLM endpoint taken from the environment.
Settings.embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
Settings.llm = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url=os.environ.get("OPENAI_API_BASE"))
@dataclass
class Utterance:
    """One speaker turn parsed from a WebVTT transcript.

    The @dataclass decorator was missing: parse_webvtt() constructs this
    positionally as Utterance(start, end, speaker, text), which raised a
    TypeError without a generated __init__.
    """
    start: float  # cue start time, seconds
    end: float    # cue end time, seconds
    speaker: str  # speaker label; "UNKNOWN" when the cue has no "Name:" prefix
    text: str     # utterance text
def ts_to_sec(ts: str) -> float:
    """Convert a WebVTT timestamp to seconds.

    Accepts both "HH:MM:SS.mmm" and the hour-less "MM:SS.mmm" form that the
    WebVTT spec permits; the original 3-way unpack crashed on the latter.
    """
    parts = ts.split(":")
    if len(parts) == 2:  # MM:SS(.mmm) — hours component omitted
        parts = ["0"] + parts
    h, m, s = parts
    return int(h) * 3600 + int(m) * 60 + float(s)
def parse_webvtt(path: str) -> list[Utterance]:
    """Parse a WebVTT file into a flat list of Utterance records.

    Only the first text line after each "start --> end" cue line is read
    (multi-line cue payloads beyond the first line are ignored, matching the
    original behavior). A "Speaker: text" line sets the speaker; otherwise
    the speaker stays "UNKNOWN".
    """
    utterances = []
    # Context manager fixes the original's leaked file handle.
    with open(path, encoding="utf-8") as fh:
        lines = fh.readlines()
    i = 0
    while i < len(lines):
        line = lines[i].strip()
        if "-->" in line:
            start_ts, end_ts = map(str.strip, line.split("-->"))
            start, end = ts_to_sec(start_ts), ts_to_sec(end_ts)
            i += 1
            if i >= len(lines):  # guard: cue timing line was the last line
                break
            speaker, text = "UNKNOWN", ""
            if ":" in lines[i]:
                speaker, text = lines[i].split(":", 1)
                speaker, text = speaker.strip(), text.strip()
            else:
                text = lines[i].strip()
            utterances.append(Utterance(start, end, speaker, text))
        i += 1
    return utterances
def build_subchunks(
    utterances,
    max_gap_sec=25,
    max_words=120,
    sentences_per_chunk=3
):
    """Group utterances into bounded chunks, then window each chunk's text.

    A new chunk starts when the silence gap before an utterance exceeds
    max_gap_sec, or when the words accumulated so far exceed max_words.
    Each chunk's text is then split into consecutive windows of
    sentences_per_chunk sentences, all inheriting the chunk's time span
    and speaker set.
    """
    groups, bucket = [], []
    prev_end = None
    for utt in utterances:
        silence = None if prev_end is None else utt.start - prev_end
        words_so_far = sum(len(item.text.split()) for item in bucket)
        if (silence and silence > max_gap_sec) or words_so_far > max_words:
            groups.append(bucket)
            bucket = []
        bucket.append(utt)
        prev_end = utt.end
    if bucket:
        groups.append(bucket)

    windows = []
    for grp in groups:
        sentences = sent_tokenize(" ".join(item.text for item in grp))
        for offset in range(0, len(sentences), sentences_per_chunk):
            windows.append({
                "text": " ".join(sentences[offset:offset + sentences_per_chunk]),
                "start": grp[0].start,
                "end": grp[-1].end,
                "speakers": list(set(item.speaker for item in grp)),
            })
    return windows
# Keyword/phrase → topic tagging rules consumed by tag_topics();
# matching is case-insensitive, whole-word (\b-delimited) search.
TOPIC_RULES = {
    "gpu": ["gpu", "graphics card", "cuda", "vram", "nvidia"],
    "technical_challenge": [
        "issue", "problem", "challenge", "difficulty",
        "error", "not working", "failed", "crash"
    ],
    "real_world_use_case": [
        "use case", "real world", "industry",
        "production", "business case", "example"
    ],
    "qa": [
        "question", "follow up", "does that help",
        "good question", "let me clarify"
    ]
}


def tag_topics(text: str) -> list[str]:
    """Return the topic labels whose keyword list matches text (whole words)."""
    lowered = text.lower()
    matched = set()
    for topic, keywords in TOPIC_RULES.items():
        for kw in keywords:
            if re.search(rf"\b{re.escape(kw)}\b", lowered):
                matched.add(topic)
                break  # one hit is enough for this topic
    return list(matched)
def build_nodes(subchunks):
    """Wrap each subchunk dict in a TextNode carrying time/speaker/topic metadata."""
    return [
        TextNode(
            text=chunk["text"],
            metadata={
                "start": chunk["start"],
                "end": chunk["end"],
                "speakers": chunk["speakers"],
                "topics": tag_topics(chunk["text"]),
            },
        )
        for chunk in subchunks
    ]
def build_hybrid_retriever(nodes):
    """Fuse a BM25 (lexical) and a vector (dense) retriever over the same nodes.

    Both candidate retrievers return 20 hits; reciprocal-rank fusion keeps
    the top 10 combined results.
    """
    vector_index = VectorStoreIndex(nodes)
    lexical = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=20)
    dense = vector_index.as_retriever(similarity_top_k=20)
    return QueryFusionRetriever(
        retrievers=[lexical, dense],
        similarity_top_k=10,
        mode="reciprocal_rerank"
    )
def expand_query(q: str) -> str:
    """Append hand-curated synonyms for each trigger word found in the query."""
    synonym_map = {
        "gpu": ["graphics card", "cuda", "vram"],
        "challenge": ["issue", "problem", "difficulty", "error"],
    }
    lowered = q.lower()
    additions = [
        " ".join(words)
        for trigger, words in synonym_map.items()
        if trigger in lowered
    ]
    if additions:
        q = q + " " + " ".join(additions)
    return q
def infer_required_topics(q: str) -> set[str]:
    """Map query trigger words to topic tags that retrieved chunks must carry."""
    lowered = q.lower()
    triggers = {
        "gpu": ("gpu", "cuda", "vram"),
        "technical_challenge": ("challenge", "issue", "problem", "difficulty"),
    }
    return {
        topic
        for topic, words in triggers.items()
        if any(w in lowered for w in words)
    }
# Cross-encoder used as a second-stage reranker over the fused retrieval hits.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
def rerank(query, nodes):
    """Return nodes ordered by cross-encoder relevance to query, best first.

    The sort keys on the score alone: the original sorted (score, node)
    tuples with reverse=True, which falls back to comparing the node objects
    whenever two scores tie — nodes are not orderable, so that raised
    TypeError. Keying on the score also makes ties stable (original order).
    """
    scores = reranker.predict([[query, n.text] for n in nodes])
    ranked = sorted(zip(scores, nodes), key=lambda pair: pair[0], reverse=True)
    return [node for _, node in ranked]
def retrieve(query, retriever, top_k=5):
    """Hybrid retrieval pipeline: expand → retrieve → topic-filter → rerank.

    Returns up to top_k dicts holding the chunk text plus its timing,
    speaker, and topic metadata. NOTE(review): assumes retriever hits expose
    .text and .metadata directly — confirm against the llama-index
    NodeWithScore proxy behavior in use.
    """
    expanded = expand_query(query)
    must_have = infer_required_topics(query)

    hits = retriever.retrieve(expanded)
    if must_have:
        # Keep only hits tagged with every required topic.
        hits = [h for h in hits if must_have.issubset(set(h.metadata["topics"]))]

    ordered = rerank(expanded, hits)
    results = []
    for hit in ordered[:top_k]:
        results.append({
            "text": hit.text,
            "topics": hit.metadata["topics"],
            "start": hit.metadata["start"],
            "end": hit.metadata["end"],
            "speakers": hit.metadata["speakers"],
        })
    return results
| # ----------------------------- | |
| # Gradio App | |
| # ----------------------------- | |
def index_file(file):
    """Parse the uploaded transcript and rebuild the module-level retriever."""
    global RETRIEVER
    parsed = parse_webvtt(file.name)
    chunked = build_subchunks(parsed)
    RETRIEVER = build_hybrid_retriever(build_nodes(chunked))
    return "✅ Index built successfully"
def run_query(query):
    """Run retrieval against the built index and return displayable text.

    Returns a friendly error string until index_file() has populated
    RETRIEVER.
    """
    global RETRIEVER
    if RETRIEVER is None:
        return "❌ Please upload and index a transcript first."
    results = retrieve(query, RETRIEVER)
    # Serialize to JSON so the Textbox shows readable evidence instead of
    # the raw Python repr of a list of dicts.
    return json.dumps(results, ensure_ascii=False, indent=2)
# Wire the UI: one column with upload/index controls on top and the
# query/evidence pair underneath.
with gr.Blocks(title="Transcript Hybrid RAG") as demo:
    gr.Markdown("## 🎙️ Transcript Hybrid Search (BM25 + Vectors)")
    gr.Markdown(
        "Upload a transcript and ask questions. "
        "**Retrieval only** (no hallucinations)."
    )

    # Indexing controls
    transcript_upload = gr.File(
        label="Upload transcript",
        file_types=[".vtt", ".txt", ".transcript"],
    )
    build_btn = gr.Button("Build Index")
    status_box = gr.Textbox(label="Status")
    build_btn.click(fn=index_file, inputs=transcript_upload, outputs=status_box)

    # Query controls — Enter in the textbox triggers retrieval.
    question_box = gr.Textbox(
        label="Ask a question",
        placeholder="Did the instructor face GPU challenges?",
    )
    evidence_box = gr.Textbox(label="Retrieved Evidence", lines=15)
    question_box.submit(fn=run_query, inputs=question_box, outputs=evidence_box)

demo.launch()