# Transcript Hybrid RAG — app.py
# (provenance: uploaded by sohchattglc11111, commit 45d99ed, verified)
# --- Imports & global configuration -----------------------------------------
# stdlib
import os
import json
import re
from dataclasses import dataclass
from typing import List

# third-party
import gradio as gr
import nltk
from dotenv import load_dotenv
from nltk.tokenize import sent_tokenize
from sentence_transformers import CrossEncoder
from llama_index.core import Settings, VectorStoreIndex
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core.schema import TextNode
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.retrievers.bm25 import BM25Retriever

# Load .env BEFORE any os.environ reads below (OPENAI_API_KEY / OPENAI_API_BASE).
load_dotenv()

# Sentence-tokenizer model data used by sent_tokenize; fetched once at startup.
nltk.download("punkt_tab")

# Populated by index_file() after a transcript has been uploaded and indexed.
RETRIEVER = None

# Local embedding model (no API cost); the LLM reads credentials from the env.
Settings.embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
Settings.llm = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url=os.environ.get("OPENAI_API_BASE"))
@dataclass
class Utterance:
    """One speaker turn parsed from a WebVTT transcript cue."""

    start: float    # cue start time, seconds from transcript start
    end: float      # cue end time, seconds
    speaker: str    # speaker label, or "UNKNOWN" when the cue has none
    text: str       # spoken text of the cue
def ts_to_sec(ts: str) -> float:
    """Convert a WebVTT timestamp to seconds.

    Accepts both "HH:MM:SS.mmm" and the spec-allowed short form "MM:SS.mmm"
    (the WebVTT grammar makes the hours component optional).

    Raises:
        ValueError: if the timestamp does not have 2 or 3 colon-separated fields.
    """
    parts = ts.split(":")
    if len(parts) == 2:          # MM:SS.mmm — hours omitted
        m, s = parts
        return int(m) * 60 + float(s)
    if len(parts) == 3:          # HH:MM:SS.mmm
        h, m, s = parts
        return int(h) * 3600 + int(m) * 60 + float(s)
    raise ValueError(f"malformed WebVTT timestamp: {ts!r}")
def parse_webvtt(path: str) -> list[Utterance]:
    """Parse a WebVTT transcript file into a list of Utterance records.

    Recognizes cue timing lines ("start --> end"), then treats the next line
    as the cue text, optionally prefixed "Speaker: text".

    Args:
        path: filesystem path to the .vtt file.

    Returns:
        Utterances in file order.
    """
    utterances: list[Utterance] = []
    # `with` closes the file deterministically (original leaked the handle).
    with open(path, encoding="utf-8") as fh:
        lines = fh.readlines()
    i = 0
    n = len(lines)
    while i < n:
        line = lines[i].strip()
        if "-->" in line:
            start_ts, end_ts = map(str.strip, line.split("-->"))
            # WebVTT allows cue settings after the end time ("... align:start");
            # keep only the timestamp token.
            end_ts = end_ts.split()[0] if end_ts else end_ts
            start, end = ts_to_sec(start_ts), ts_to_sec(end_ts)
            i += 1
            if i >= n:        # timing line at EOF with no text line — stop cleanly
                break
            speaker, text = "UNKNOWN", ""
            if ":" in lines[i]:
                speaker, text = lines[i].split(":", 1)
                speaker, text = speaker.strip(), text.strip()
            else:
                text = lines[i].strip()
            utterances.append(Utterance(start, end, speaker, text))
        i += 1
    return utterances
def build_subchunks(
    utterances,
    max_gap_sec=25,
    max_words=120,
    sentences_per_chunk=3
):
    """Group utterances into gap/size-bounded chunks, then window each chunk
    into groups of ``sentences_per_chunk`` sentences.

    A new chunk starts when the silence between consecutive utterances exceeds
    ``max_gap_sec`` seconds, or when the chunk's word count exceeds ``max_words``.

    Returns:
        List of dicts with keys "text", "start", "end", "speakers".
    """
    chunks, current = [], []
    last_end = None
    # Running word count for `current` — the original recomputed the sum over
    # the whole chunk on every utterance (accidentally O(n^2)).
    word_count = 0
    for utt in utterances:
        gap = None if last_end is None else utt.start - last_end
        if (gap is not None and gap > max_gap_sec) or word_count > max_words:
            chunks.append(current)
            current, word_count = [], 0
        current.append(utt)
        word_count += len(utt.text.split())
        last_end = utt.end
    if current:
        chunks.append(current)

    subchunks = []
    for chunk in chunks:
        sentences = sent_tokenize(" ".join(u.text for u in chunk))
        for start_idx in range(0, len(sentences), sentences_per_chunk):
            subchunks.append({
                "text": " ".join(sentences[start_idx:start_idx + sentences_per_chunk]),
                "start": chunk[0].start,       # chunk-level timing: nltk sentences
                "end": chunk[-1].end,          # can't be mapped back to cue times
                "speakers": list(set(u.speaker for u in chunk)),
            })
    return subchunks
# Keyword rules for lightweight topic tagging of transcript chunks.
# A chunk receives a topic label when ANY of the topic's keywords matches
# as a whole word/phrase, case-insensitively (see tag_topics).
TOPIC_RULES = {
"gpu": ["gpu", "graphics card", "cuda", "vram", "nvidia"],
"technical_challenge": [
"issue", "problem", "challenge", "difficulty",
"error", "not working", "failed", "crash"
],
"real_world_use_case": [
"use case", "real world", "industry",
"production", "business case", "example"
],
"qa": [
"question", "follow up", "does that help",
"good question", "let me clarify"
]
}
def tag_topics(text: str) -> list[str]:
    """Return the topic labels from TOPIC_RULES that match *text*.

    A topic matches when any of its keywords occurs as a whole word or
    phrase (case-insensitive, word-boundary regex).
    """
    lowered = text.lower()
    matched = {
        topic
        for topic, keywords in TOPIC_RULES.items()
        if any(re.search(rf"\b{re.escape(kw)}\b", lowered) for kw in keywords)
    }
    return list(matched)
def build_nodes(subchunks):
    """Wrap subchunk dicts as llama-index TextNodes carrying retrieval metadata
    (timing, speakers, and rule-based topic tags)."""
    return [
        TextNode(
            text=chunk["text"],
            metadata={
                "start": chunk["start"],
                "end": chunk["end"],
                "speakers": chunk["speakers"],
                "topics": tag_topics(chunk["text"]),
            },
        )
        for chunk in subchunks
    ]
def build_hybrid_retriever(nodes):
    """Build a hybrid retriever fusing BM25 (lexical) and dense-vector search.

    Each sub-retriever pulls 20 candidates; reciprocal-rank fusion keeps the
    top 10 overall.
    """
    vector_index = VectorStoreIndex(nodes)
    # nodes= keyword kept explicit, matching BM25Retriever.from_defaults API
    lexical = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=20)
    dense = vector_index.as_retriever(similarity_top_k=20)
    return QueryFusionRetriever(
        retrievers=[lexical, dense],
        similarity_top_k=10,
        mode="reciprocal_rerank",
    )
def expand_query(q: str) -> str:
    """Append synonym terms for known concepts detected in the query.

    Matching is case-insensitive substring matching against the original
    query; synonyms are appended, never replacing the user's wording.
    """
    synonym_map = {
        "gpu": ["graphics card", "cuda", "vram"],
        "challenge": ["issue", "problem", "difficulty", "error"],
    }
    lowered = q.lower()
    expanded = q
    for trigger, synonyms in synonym_map.items():
        if trigger in lowered:
            expanded = expanded + " " + " ".join(synonyms)
    return expanded
def infer_required_topics(q: str) -> set[str]:
    """Infer which topic tags retrieved chunks MUST carry for this query.

    Used as a hard filter downstream: only chunks tagged with every
    required topic survive retrieval.
    """
    lowered = q.lower()
    required: set[str] = set()
    if any(term in lowered for term in ("gpu", "cuda", "vram")):
        required.add("gpu")
    if any(term in lowered for term in ("challenge", "issue", "problem", "difficulty")):
        required.add("technical_challenge")
    return required
# Cross-encoder used to re-score (query, passage) pairs after fusion retrieval.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

def rerank(query, nodes):
    """Return *nodes* ordered by cross-encoder relevance to *query*, best first.

    Sorting is keyed on the score only: the original
    ``sorted(zip(scores, nodes), reverse=True)`` falls back to comparing the
    node objects themselves on tied scores, raising TypeError for unorderable
    node types. Score-keyed argsort is also stable on ties.
    """
    if not nodes:  # avoid calling predict() on an empty batch
        return []
    scores = reranker.predict([[query, n.text] for n in nodes])
    order = sorted(range(len(nodes)), key=lambda i: scores[i], reverse=True)
    return [nodes[i] for i in order]
def retrieve(query, retriever, top_k=5):
    """Hybrid retrieval pipeline: expand query, fetch candidates, hard-filter
    by required topics, rerank with the cross-encoder.

    Args:
        query: raw user question.
        retriever: fusion retriever from build_hybrid_retriever().
        top_k: maximum number of evidence dicts to return.

    Returns:
        Up to *top_k* dicts with keys "text", "topics", "start", "end",
        "speakers", ordered most-relevant first.
    """
    expanded = expand_query(query)
    required_topics = infer_required_topics(query)
    candidates = retriever.retrieve(expanded)
    if required_topics:
        # Hard filter: keep only chunks tagged with EVERY required topic.
        candidates = [
            node for node in candidates
            if required_topics.issubset(set(node.metadata["topics"]))
        ]
    if not candidates:  # the topic filter can empty the pool — don't rerank nothing
        return []
    reranked = rerank(expanded, candidates)
    return [{
        "text": node.text,
        "topics": node.metadata["topics"],
        "start": node.metadata["start"],
        "end": node.metadata["end"],
        "speakers": node.metadata["speakers"]
    } for node in reranked[:top_k]]
# -----------------------------
# Gradio App
# -----------------------------
def index_file(file):
    """Gradio callback: parse the uploaded transcript and (re)build the
    module-level hybrid retriever."""
    global RETRIEVER
    parsed = parse_webvtt(file.name)
    chunked = build_subchunks(parsed)
    RETRIEVER = build_hybrid_retriever(build_nodes(chunked))
    return "✅ Index built successfully"
def run_query(query):
    """Gradio callback: retrieve evidence for *query* from the built index.

    Returns an error string when no transcript has been indexed yet.
    """
    global RETRIEVER
    if RETRIEVER is None:
        return "❌ Please upload and index a transcript first."
    # NOTE(review): this returns a list of dicts into a gr.Textbox, which will
    # stringify it — consider gr.JSON or json.dumps for nicer display.
    return retrieve(query, RETRIEVER)
# UI layout: upload/index controls on top, query box + evidence view below.
with gr.Blocks(title="Transcript Hybrid RAG") as demo:
    gr.Markdown("## 🎙️ Transcript Hybrid Search (BM25 + Vectors)")
    gr.Markdown(
        "Upload a transcript and ask questions. "
        "**Retrieval only** (no hallucinations)."
    )

    # Indexing controls
    upload = gr.File(
        label="Upload transcript",
        file_types=[".vtt", ".txt", ".transcript"]
    )
    index_btn = gr.Button("Build Index")
    status = gr.Textbox(label="Status")
    index_btn.click(
        fn=index_file,
        inputs=upload,
        outputs=status
    )

    # Query controls — submitting the textbox triggers retrieval.
    query = gr.Textbox(
        label="Ask a question",
        placeholder="Did the instructor face GPU challenges?"
    )
    output = gr.Textbox(
        label="Retrieved Evidence",
        lines=15
    )
    query.submit(
        fn=run_query,
        inputs=query,
        outputs=output
    )

# Standard Gradio entry point; HF Spaces executes app.py as __main__, so this
# guard preserves launch behavior while making the module safely importable.
if __name__ == "__main__":
    demo.launch()