# app.py — RAG QA demo (HF Space "RAG" by renatavl, commit 860ef8a "init")
import os
import re
import ast
import threading
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict, Any
from itertools import islice
import numpy as np
import gradio as gr
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder
from litellm import completion
from datasets import load_dataset
# -----------------------------
# Config
# -----------------------------
HF_DATASET_NAME = "CodeKapital/CookingRecipes"  # HF dataset indexed by the app
DENSE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # bi-encoder for dense retrieval
RERANK_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # cross-encoder for reranking
CHUNK_SIZE_WORDS = 350  # words per chunk
CHUNK_OVERLAP_WORDS = 60  # words shared between consecutive chunks
TOPK_BM25 = 25  # BM25 candidates per query
TOPK_DENSE = 25  # dense candidates per query
TOPK_AFTER_RERANK = 6  # chunks handed to the LLM after reranking
OLLAMA_BASE_URL = "http://localhost:11434"  # local Ollama endpoint
DEFAULT_N_RECORDS = 500  # recipes indexed at startup
# -----------------------------
# Data structures
# -----------------------------
@dataclass
class Chunk:
    """A single indexed piece of one recipe document."""
    chunk_id: str  # unique id of the form "<source>::chunk<i>"
    source: str  # human-readable provenance (dataset row index, title, link)
    text: str  # normalized chunk text fed to BM25 and the embedder
# -----------------------------
# Preprocessing + chunking
# -----------------------------
_whitespace_re = re.compile(r"\s+")
_token_re = re.compile(r"[A-Za-zА-Яа-яІіЇїЄє0-9]+")


def normalize_text(text: str) -> str:
    """Collapse whitespace runs (including NBSP) into single spaces and trim."""
    cleaned = (text or "").replace("\u00a0", " ")
    return _whitespace_re.sub(" ", cleaned).strip()


def tokenize_for_bm25(text: str) -> List[str]:
    """Lower-cased alphanumeric tokens (Latin + Cyrillic) for BM25 matching."""
    return [token.lower() for token in _token_re.findall(text or "")]
def chunk_text(
    source: str,
    text: str,
    chunk_size_words: int = CHUNK_SIZE_WORDS,
    overlap_words: int = CHUNK_OVERLAP_WORDS
) -> List[Chunk]:
    """Split *text* into word-based chunks of ``chunk_size_words`` with overlap.

    Args:
        source: provenance label; becomes part of each chunk_id.
        text: raw document text (may be None/empty → returns []).
        chunk_size_words: window size in words (clamped to >= 1).
        overlap_words: words repeated between consecutive chunks
            (clamped so the window always advances).

    Returns:
        List of Chunk objects with ids "<source>::chunk<i>".
    """
    words = (text or "").split()
    if not words:
        return []
    chunk_size_words = max(1, int(chunk_size_words))
    # BUGFIX: with overlap_words >= chunk_size_words the original
    # `start = max(0, end - overlap_words)` never advanced, looping forever.
    # Clamp the overlap so each iteration moves forward by at least one word.
    overlap_words = min(max(0, int(overlap_words)), chunk_size_words - 1)
    chunks: List[Chunk] = []
    start = 0
    idx = 0
    while start < len(words):
        end = min(start + chunk_size_words, len(words))
        chunk_str = " ".join(words[start:end]).strip()
        if chunk_str:
            chunks.append(Chunk(
                chunk_id=f"{source}::chunk{idx}",
                source=source,
                text=chunk_str
            ))
            idx += 1
        if end == len(words):
            break
        start = end - overlap_words
    return chunks
# -----------------------------
# HF dataset helpers
# -----------------------------
def _to_list(x: Any) -> List[str]:
"""ingredients/directions можуть бути list або строкою зі списком."""
if x is None:
return []
if isinstance(x, list):
return [str(i).strip() for i in x if str(i).strip()]
if isinstance(x, str):
s = x.strip()
if not s:
return []
try:
v = ast.literal_eval(s)
if isinstance(v, list):
return [str(i).strip() for i in v if str(i).strip()]
except Exception:
pass
if "\n" in s:
parts = [p.strip(" -•\t") for p in s.splitlines()]
else:
parts = [p.strip() for p in s.split(",")]
return [p for p in parts if p]
return [str(x).strip()] if str(x).strip() else []
def recipe_row_to_doc(row: Dict[str, Any], idx: int) -> Tuple[str, str]:
    """Render one dataset row as a (source_name, full_text) document pair."""
    title = (row.get("title") or "").strip()
    link = (row.get("link") or "").strip()
    src = (row.get("source") or "").strip()
    ingredients = _to_list(row.get("ingredients"))
    directions = _to_list(row.get("directions"))

    # Source label: row index, then a truncated single-line title, then the link.
    safe_title = title[:80].replace("\n", " ").strip()
    name_parts = [f"CookingRecipes#{idx}"]
    if safe_title:
        name_parts.append(safe_title)
    if link:
        name_parts.append(link)
    source_name = " | ".join(name_parts)

    # Document body: title always present; other sections only when non-empty.
    sections = [f"Title: {title or '(unknown)'}"]
    if src:
        sections.append(f"Source: {src}")
    if link:
        sections.append(f"Link: {link}")
    if ingredients:
        sections.append("Ingredients:\n" + "\n".join(f"- {i}" for i in ingredients))
    if directions:
        numbered = "\n".join(f"{i+1}. {d}" for i, d in enumerate(directions))
        sections.append("Directions:\n" + numbered)
    return source_name, normalize_text("\n\n".join(sections))
def load_first_n_recipes(n: int, streaming: bool = True) -> List[Tuple[str, str]]:
    """Load the first *n* recipes from the HF dataset as (source, text) docs.

    With streaming=True only the first n rows are fetched lazily; otherwise a
    materialized ``train[:n]`` slice is downloaded. Rows that render to empty
    text are skipped.
    """
    count = int(max(0, n))
    if count == 0:
        return []
    if streaming:
        rows = islice(load_dataset(HF_DATASET_NAME, split="train", streaming=True), count)
    else:
        rows = load_dataset(HF_DATASET_NAME, split=f"train[:{count}]")
    docs: List[Tuple[str, str]] = []
    for idx, row in enumerate(rows):
        source_name, text = recipe_row_to_doc(row, idx)
        if text.strip():
            docs.append((source_name, text))
    return docs
# -----------------------------
# RAG Engine
# -----------------------------
class RAGEngine:
    """In-memory hybrid retriever (BM25 + dense) with optional cross-encoder rerank."""

    def __init__(self):
        self.chunks: List[Chunk] = []
        self.bm25: Optional[BM25Okapi] = None
        self.bm25_corpus_tokens: List[List[str]] = []
        self.dense_model: Optional[SentenceTransformer] = None
        self.rerank_model: Optional[CrossEncoder] = None
        # (n_chunks, dim) float32; rows L2-normalized by encode() below.
        self.chunk_embeddings: Optional[np.ndarray] = None
        self.last_build_info: str = "Index not built yet."

    def ensure_models(self) -> None:
        """Lazily load the embedding and reranking models (idempotent)."""
        if self.dense_model is None:
            self.dense_model = SentenceTransformer(DENSE_MODEL_NAME)
        if self.rerank_model is None:
            self.rerank_model = CrossEncoder(RERANK_MODEL_NAME)

    def build_from_dataset(self, n_records: int, streaming: bool) -> None:
        """(Re)build the BM25 and dense indexes from the first n_records recipes."""
        docs = load_first_n_recipes(n_records, streaming=streaming)
        all_chunks: List[Chunk] = []
        for source, text in docs:
            all_chunks.extend(chunk_text(source, text))
        self.chunks = all_chunks
        if not self.chunks:
            self.bm25 = None
            self.chunk_embeddings = None
            self.last_build_info = "No chunks built (N too small or empty rows)."
            return
        # Models
        self.ensure_models()
        # BM25 (keyword) index
        self.bm25_corpus_tokens = [tokenize_for_bm25(c.text) for c in self.chunks]
        self.bm25 = BM25Okapi(self.bm25_corpus_tokens)
        # Dense embeddings; normalized so a dot product equals cosine similarity.
        embs = self.dense_model.encode(
            [c.text for c in self.chunks],
            batch_size=64,
            show_progress_bar=True,
            normalize_embeddings=True
        )
        self.chunk_embeddings = np.asarray(embs, dtype=np.float32)
        self.last_build_info = (
            f"Built index from {len(docs)} recipes → {len(self.chunks)} chunks. "
            f"Streaming={streaming}."
        )

    def retrieve_candidates(
        self,
        query: str,
        use_bm25: bool,
        use_dense: bool,
        topk_bm25: int = TOPK_BM25,
        topk_dense: int = TOPK_DENSE
    ) -> List[int]:
        """Return deduplicated candidate chunk indices in retrieval-rank order.

        BUGFIX: the previous version accumulated ids into a ``set`` and
        returned ``list(set)``, which destroyed score order and made the
        result nondeterministic — slicing it without reranking picked
        arbitrary chunks. Order is now deterministic: BM25 hits best-first,
        followed by dense hits not already present.
        """
        if not self.chunks:
            return []
        ordered: Dict[int, None] = {}  # insertion-ordered dedupe
        if use_bm25 and self.bm25 is not None:
            q_tokens = tokenize_for_bm25(query)
            scores = self.bm25.get_scores(q_tokens)
            for i in np.argsort(scores)[::-1][:int(topk_bm25)].tolist():
                ordered.setdefault(i, None)
        if use_dense and self.dense_model is not None and self.chunk_embeddings is not None:
            q_emb = self.dense_model.encode([query], normalize_embeddings=True)
            q_emb = np.asarray(q_emb, dtype=np.float32)[0]
            sims = self.chunk_embeddings @ q_emb  # cosine sims (both sides normalized)
            for i in np.argsort(sims)[::-1][:int(topk_dense)].tolist():
                ordered.setdefault(i, None)
        return list(ordered.keys())

    def rerank(self, query: str, candidate_idx: List[int], top_n: int = TOPK_AFTER_RERANK) -> List[int]:
        """Score (query, chunk) pairs with the cross-encoder; return the best top_n ids."""
        if not candidate_idx:
            return []
        if self.rerank_model is None:
            # No reranker loaded: fall back to the incoming candidate order.
            return candidate_idx[:int(top_n)]
        pairs = [(query, self.chunks[i].text) for i in candidate_idx]
        scores = self.rerank_model.predict(pairs)
        order = np.argsort(scores)[::-1]
        return [candidate_idx[i] for i in order[:int(top_n)]]

    def build_context(self, selected_idx: List[int]) -> str:
        """Format selected chunks as numbered, citable context blocks."""
        blocks = []
        for j, i in enumerate(selected_idx, start=1):
            c = self.chunks[i]
            blocks.append(
                f"[{j}] Source: {c.source} | {c.chunk_id}\n{c.text}"
            )
        return "\n\n---\n\n".join(blocks)

    def answer_with_llm(self, query: str, context: str, model: str, api_key: str, temperature: float = 0.2) -> str:
        """Ask the LLM (via LiteLLM) to answer *query* using only *context*.

        The api_key, when given, is also exported to the provider-specific
        environment variable so LiteLLM's provider resolution can find it.
        Raises whatever `litellm.completion` raises; the caller handles it.
        """
        model = (model or "").strip()
        api_key = (api_key or "").strip()
        if not model:
            return "Model is empty."
        if model.startswith("openai/") or model.startswith("gpt-"):
            if api_key:
                os.environ["OPENAI_API_KEY"] = api_key
        elif model.startswith("openrouter/"):
            if api_key:
                os.environ["OPENROUTER_API_KEY"] = api_key
        elif model.startswith("groq/"):
            if api_key:
                os.environ["GROQ_API_KEY"] = api_key
        system = (
            "You are a helpful QA assistant.\n"
            "Answer the user's question using ONLY the provided context.\n"
            "If the answer is not in the context, say you don't know.\n"
            "When you use facts from the context, add citations like [1] referring to the chunk numbers."
        )
        user = f"Question: {query}\n\nContext:\n{context}"
        extra = {}
        if model.startswith("ollama/"):
            # Ollama has no API key; it needs the local server's base URL instead.
            extra["api_base"] = OLLAMA_BASE_URL
        resp = completion(
            model=model,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            temperature=temperature,
            api_key=api_key if api_key else None,
            **extra
        )
        return resp["choices"][0]["message"]["content"]
# -----------------------------
# Global engine + lock
# -----------------------------
ENGINE = RAGEngine()  # single shared engine for all Gradio sessions
ENGINE_LOCK = threading.Lock()  # serializes index rebuilds and queries
# build once on startup
# NOTE: runs at import time, so startup downloads the dataset slice and
# both models before the UI becomes available.
with ENGINE_LOCK:
    ENGINE.build_from_dataset(DEFAULT_N_RECORDS, streaming=True)
# -----------------------------
# Gradio UI callbacks
# -----------------------------
def rebuild_index(n_records: int, streaming: bool) -> str:
    """Rebuild the shared index from the first n_records recipes; return status text."""
    count, use_streaming = int(n_records), bool(streaming)
    with ENGINE_LOCK:
        ENGINE.build_from_dataset(count, use_streaming)
    return ENGINE.last_build_info
def qa(
    question: str,
    use_bm25: bool,
    use_dense: bool,
    use_rerank: bool,
    model: str,
    api_key: str,
    topk_bm25: int,
    topk_dense: int,
    topk_final: int
):
    """Answer *question* over the indexed recipes.

    Returns a (answer_markdown, debug_context) pair for the Gradio outputs.
    LLM failures are reported in the answer text rather than raised.
    """
    query = (question or "").strip()
    if not query:
        return "Type a question.", ""
    if not (use_bm25 or use_dense):
        return "Enable BM25 and/or Dense retrieval (otherwise there is no context).", ""
    with ENGINE_LOCK:
        if not ENGINE.chunks:
            return "Index is empty. Click 'Rebuild index' with N>0.", ""
        candidates = ENGINE.retrieve_candidates(
            query,
            use_bm25=use_bm25,
            use_dense=use_dense,
            topk_bm25=int(topk_bm25),
            topk_dense=int(topk_dense)
        )
        if not candidates:
            return "No candidates retrieved.", ""
        if use_rerank:
            chosen = ENGINE.rerank(query, candidates, top_n=int(topk_final))
        else:
            chosen = candidates[:int(topk_final)]
        context = ENGINE.build_context(chosen)
        try:
            answer = ENGINE.answer_with_llm(query, context, model=model, api_key=api_key)
        except Exception as e:
            answer = f"LLM call failed: {type(e).__name__}: {e}"
        return answer, context
# -----------------------------
# Launch UI
# -----------------------------
def build_demo() -> gr.Blocks:
    """Assemble and return the Gradio Blocks UI (does not launch it)."""
    with gr.Blocks(title="RAG QA on CookingRecipes (BM25 + Dense + Rerank)") as demo:
        gr.Markdown(
            "# RAG QA (CookingRecipes)\n"
            f"Dataset: `{HF_DATASET_NAME}`. Індексуємо **перші N рецептів**.\n\n"
        )
        # Index controls: how many recipes to (re)index and how to fetch them.
        with gr.Row():
            n_records = gr.Slider(50, 5000, value=DEFAULT_N_RECORDS, step=50, label="N recipes to index (first N)")
            streaming = gr.Checkbox(value=True, label="Use streaming (recommended)")
        build_btn = gr.Button("Rebuild index")
        build_status = gr.Markdown(value=f"**Status:** {ENGINE.last_build_info}")
        build_btn.click(fn=rebuild_index, inputs=[n_records, streaming], outputs=[build_status])
        gr.Markdown("---")
        # Question input and retrieval-stage toggles.
        with gr.Row():
            question = gr.Textbox(label="Question", placeholder="Ask about recipes...", lines=2)
        with gr.Row():
            use_bm25 = gr.Checkbox(value=True, label="Use BM25 (keyword)")
            use_dense = gr.Checkbox(value=True, label="Use Dense (embeddings)")
            use_rerank = gr.Checkbox(value=True, label="Use Cross-Encoder Reranker")
        # LLM selection; key is optional for local Ollama models.
        with gr.Row():
            model = gr.Textbox(
                label="LLM model (LiteLLM)",
                value="openai/gpt-4o-mini",
                placeholder="e.g. openai/gpt-4o-mini OR groq/... OR openrouter/..."
            )
            api_key = gr.Textbox(
                label="API key (leave empty for Ollama)",
                placeholder="Empty for local ollama",
                type="password"
            )
        # Retrieval depth knobs (candidate pools and final context size).
        with gr.Row():
            topk_bm25 = gr.Slider(5, 80, value=TOPK_BM25, step=1, label="Top-K BM25 candidates")
            topk_dense = gr.Slider(5, 80, value=TOPK_DENSE, step=1, label="Top-K Dense candidates")
            topk_final = gr.Slider(1, 12, value=TOPK_AFTER_RERANK, step=1, label="Chunks to LLM (final)")
        run_btn = gr.Button("Answer")
        answer = gr.Markdown(label="Answer")
        context = gr.Textbox(label="Retrieved context (debug)", lines=16)
        run_btn.click(
            fn=qa,
            inputs=[question, use_bm25, use_dense, use_rerank, model, api_key, topk_bm25, topk_dense, topk_final],
            outputs=[answer, context]
        )
    return demo
if __name__ == "__main__":
    # Build the UI and start the Gradio server with default settings.
    demo = build_demo()
    demo.launch()
    # for local run with fixed port:
    # demo.launch(server_name="127.0.0.1", server_port=7860)