|
|
import os |
|
|
import re |
|
|
import ast |
|
|
import threading |
|
|
from dataclasses import dataclass |
|
|
from typing import List, Tuple, Optional, Dict, Any |
|
|
from itertools import islice |
|
|
|
|
|
import numpy as np |
|
|
import gradio as gr |
|
|
from rank_bm25 import BM25Okapi |
|
|
from sentence_transformers import SentenceTransformer, CrossEncoder |
|
|
from litellm import completion |
|
|
from datasets import load_dataset |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Configuration ----------------------------------------------------------

# Hugging Face dataset to index.
HF_DATASET_NAME = "CodeKapital/CookingRecipes"

# Bi-encoder for dense retrieval and cross-encoder for reranking.
DENSE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
RERANK_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"

# Word-based chunking parameters.
CHUNK_SIZE_WORDS = 350
CHUNK_OVERLAP_WORDS = 60

# Retrieval fan-out per retriever and final context size fed to the LLM.
TOPK_BM25 = 25
TOPK_DENSE = 25
TOPK_AFTER_RERANK = 6

# Base URL for `ollama/...` models routed through LiteLLM.
OLLAMA_BASE_URL = "http://localhost:11434"

# Number of recipes indexed at startup.
DEFAULT_N_RECORDS = 500


@dataclass
class Chunk:
    """One retrievable unit of text produced by chunking a recipe document."""

    # Unique id of the form "<source>::chunk<N>".
    chunk_id: str
    # Human-readable provenance label for citations.
    source: str
    # The chunk's text content.
    text: str


# Collapses any run of whitespace (including newlines) into one space.
_whitespace_re = re.compile(r"\s+")

# Alphanumeric tokens: Latin, Cyrillic (incl. Ukrainian letters), digits.
_token_re = re.compile(r"[A-Za-zА-Яа-яІіЇїЄє0-9]+")


def normalize_text(text: str) -> str:
    """Replace non-breaking spaces and collapse whitespace runs to single spaces."""
    text = (text or "").replace("\u00a0", " ")
    text = _whitespace_re.sub(" ", text).strip()
    return text


def tokenize_for_bm25(text: str) -> List[str]:
    """Return lowercased alphanumeric tokens for BM25 keyword matching."""
    return [t.lower() for t in _token_re.findall(text or "")]


def chunk_text(
    source: str,
    text: str,
    chunk_size_words: int = CHUNK_SIZE_WORDS,
    overlap_words: int = CHUNK_OVERLAP_WORDS
) -> List[Chunk]:
    """Split *text* into word-based chunks with overlap.

    Args:
        source: Provenance label copied into every chunk.
        text: Raw document text; split on whitespace.
        chunk_size_words: Maximum number of words per chunk.
        overlap_words: Words shared between consecutive chunks.

    Returns:
        List of Chunk objects; empty if *text* contains no words.
    """
    words = (text or "").split()
    if not words:
        return []

    chunks: List[Chunk] = []
    start = 0
    idx = 0

    while start < len(words):
        end = min(start + chunk_size_words, len(words))
        chunk_str = " ".join(words[start:end]).strip()

        if chunk_str:
            chunks.append(Chunk(
                chunk_id=f"{source}::chunk{idx}",
                source=source,
                text=chunk_str
            ))
            idx += 1

        if end == len(words):
            break
        # BUGFIX: the original `start = max(0, end - overlap_words)` loops
        # forever when overlap_words >= chunk_size_words, because `start`
        # never advances. Guarantee forward progress; with the default
        # overlap < chunk size the behavior is unchanged.
        next_start = end - overlap_words
        start = next_start if next_start > start else end

    return chunks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _to_list(x: Any) -> List[str]: |
|
|
"""ingredients/directions можуть бути list або строкою зі списком.""" |
|
|
if x is None: |
|
|
return [] |
|
|
if isinstance(x, list): |
|
|
return [str(i).strip() for i in x if str(i).strip()] |
|
|
if isinstance(x, str): |
|
|
s = x.strip() |
|
|
if not s: |
|
|
return [] |
|
|
try: |
|
|
v = ast.literal_eval(s) |
|
|
if isinstance(v, list): |
|
|
return [str(i).strip() for i in v if str(i).strip()] |
|
|
except Exception: |
|
|
pass |
|
|
if "\n" in s: |
|
|
parts = [p.strip(" -•\t") for p in s.splitlines()] |
|
|
else: |
|
|
parts = [p.strip() for p in s.split(",")] |
|
|
return [p for p in parts if p] |
|
|
return [str(x).strip()] if str(x).strip() else [] |
|
|
|
|
|
|
|
|
def recipe_row_to_doc(row: Dict[str, Any], idx: int) -> Tuple[str, str]:
    """Convert one dataset row into a (source_name, full_text) pair."""
    title = (row.get("title") or "").strip()
    link = (row.get("link") or "").strip()
    src = (row.get("source") or "").strip()

    ingredients = _to_list(row.get("ingredients"))
    directions = _to_list(row.get("directions"))

    # Source label: stable index, then optional truncated title and link.
    safe_title = title[:80].replace("\n", " ").strip()
    label_parts = [f"CookingRecipes#{idx}"]
    if safe_title:
        label_parts.append(safe_title)
    if link:
        label_parts.append(link)
    source_name = " | ".join(label_parts)

    # Document body: metadata header plus bulleted/numbered sections.
    sections = [f"Title: {title or '(unknown)'}"]
    if src:
        sections.append(f"Source: {src}")
    if link:
        sections.append(f"Link: {link}")

    if ingredients:
        bullets = "\n".join(f"- {i}" for i in ingredients)
        sections.append("Ingredients:\n" + bullets)
    if directions:
        numbered = "\n".join(f"{i+1}. {d}" for i, d in enumerate(directions))
        sections.append("Directions:\n" + numbered)

    full_text = normalize_text("\n\n".join(sections))
    return source_name, full_text
|
|
|
|
|
|
|
|
def load_first_n_recipes(n: int, streaming: bool = True) -> List[Tuple[str, str]]:
    """Fetch the first *n* dataset rows and convert each to (source, text)."""
    limit = int(max(0, n))
    if limit == 0:
        return []

    if streaming:
        # Streaming avoids downloading the whole dataset for a small prefix.
        stream = load_dataset(HF_DATASET_NAME, split="train", streaming=True)
        rows = islice(stream, limit)
    else:
        rows = load_dataset(HF_DATASET_NAME, split=f"train[:{limit}]")

    result: List[Tuple[str, str]] = []
    for idx, row in enumerate(rows):
        name, body = recipe_row_to_doc(row, idx)
        # Skip rows that produced no usable text.
        if body.strip():
            result.append((name, body))
    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RAGEngine:
    """Chunk index (BM25 + dense embeddings) plus retrieval and LLM answering.

    Not thread-safe by itself; callers serialize access with an external lock.
    """

    def __init__(self):
        self.chunks: List[Chunk] = []
        self.bm25: Optional[BM25Okapi] = None
        self.bm25_corpus_tokens: List[List[str]] = []

        # Models are loaded lazily on first index build (see ensure_models).
        self.dense_model: Optional[SentenceTransformer] = None
        self.rerank_model: Optional[CrossEncoder] = None
        self.chunk_embeddings: Optional[np.ndarray] = None

        self.last_build_info: str = "Index not built yet."

    def ensure_models(self) -> None:
        """Load the bi-encoder and cross-encoder once, on first use."""
        if self.dense_model is None:
            self.dense_model = SentenceTransformer(DENSE_MODEL_NAME)
        if self.rerank_model is None:
            self.rerank_model = CrossEncoder(RERANK_MODEL_NAME)

    def build_from_dataset(self, n_records: int, streaming: bool) -> None:
        """(Re)build the BM25 and dense indexes from the first *n_records* recipes."""
        docs = load_first_n_recipes(n_records, streaming=streaming)

        all_chunks: List[Chunk] = []
        for source, text in docs:
            all_chunks.extend(chunk_text(source, text))

        self.chunks = all_chunks

        if not self.chunks:
            # Leave the engine in a consistent "empty" state.
            self.bm25 = None
            self.chunk_embeddings = None
            self.last_build_info = "No chunks built (N too small or empty rows)."
            return

        self.ensure_models()

        # Keyword index.
        self.bm25_corpus_tokens = [tokenize_for_bm25(c.text) for c in self.chunks]
        self.bm25 = BM25Okapi(self.bm25_corpus_tokens)

        # Dense index: normalized embeddings so dot product == cosine similarity.
        embs = self.dense_model.encode(
            [c.text for c in self.chunks],
            batch_size=64,
            show_progress_bar=True,
            normalize_embeddings=True
        )
        self.chunk_embeddings = np.asarray(embs, dtype=np.float32)

        self.last_build_info = (
            f"Built index from {len(docs)} recipes → {len(self.chunks)} chunks. "
            f"Streaming={streaming}."
        )

    def retrieve_candidates(
        self,
        query: str,
        use_bm25: bool,
        use_dense: bool,
        topk_bm25: int = TOPK_BM25,
        topk_dense: int = TOPK_DENSE
    ) -> List[int]:
        """Return candidate chunk indices, best-scored first per retriever.

        BUGFIX: the original collected candidates in a set and returned
        list(set), which destroys ranking order — the no-rerank caller then
        truncated arbitrary candidates with `cands[:k]`. An insertion-ordered
        dict keeps BM25 hits first (in score order), then dense hits (in
        score order), deduplicated.
        """
        if not self.chunks:
            return []

        ordered: Dict[int, None] = {}

        if use_bm25 and self.bm25 is not None:
            q_tokens = tokenize_for_bm25(query)
            scores = self.bm25.get_scores(q_tokens)
            for i in np.argsort(scores)[::-1][:int(topk_bm25)]:
                ordered.setdefault(int(i))

        if use_dense and self.dense_model is not None and self.chunk_embeddings is not None:
            q_emb = self.dense_model.encode([query], normalize_embeddings=True)
            q_emb = np.asarray(q_emb, dtype=np.float32)[0]
            # Embeddings are normalized, so the dot product is cosine similarity.
            sims = self.chunk_embeddings @ q_emb
            for i in np.argsort(sims)[::-1][:int(topk_dense)]:
                ordered.setdefault(int(i))

        return list(ordered)

    def rerank(self, query: str, candidate_idx: List[int], top_n: int = TOPK_AFTER_RERANK) -> List[int]:
        """Re-score candidates with the cross-encoder and return the top_n indices."""
        if not candidate_idx:
            return []
        if self.rerank_model is None:
            # Models not loaded yet: fall back to the incoming order.
            return candidate_idx[:int(top_n)]

        pairs = [(query, self.chunks[i].text) for i in candidate_idx]
        scores = self.rerank_model.predict(pairs)
        order = np.argsort(scores)[::-1]
        return [candidate_idx[i] for i in order[:int(top_n)]]

    def build_context(self, selected_idx: List[int]) -> str:
        """Format the selected chunks as numbered, citable context blocks."""
        blocks = []
        for j, i in enumerate(selected_idx, start=1):
            c = self.chunks[i]
            blocks.append(
                f"[{j}] Source: {c.source} | {c.chunk_id}\n{c.text}"
            )
        return "\n\n---\n\n".join(blocks)

    def answer_with_llm(self, query: str, context: str, model: str, api_key: str, temperature: float = 0.2) -> str:
        """Ask the LLM (via LiteLLM) to answer *query* using only *context*.

        Side effect: exports the provider API key into os.environ for
        providers that LiteLLM resolves from the environment.
        """
        model = (model or "").strip()
        api_key = (api_key or "").strip()
        if not model:
            return "Model is empty."

        if model.startswith("openai/") or model.startswith("gpt-"):
            if api_key:
                os.environ["OPENAI_API_KEY"] = api_key
        elif model.startswith("openrouter/"):
            if api_key:
                os.environ["OPENROUTER_API_KEY"] = api_key
        elif model.startswith("groq/"):
            if api_key:
                os.environ["GROQ_API_KEY"] = api_key

        system = (
            "You are a helpful QA assistant.\n"
            "Answer the user's question using ONLY the provided context.\n"
            "If the answer is not in the context, say you don't know.\n"
            "When you use facts from the context, add citations like [1] referring to the chunk numbers."
        )
        user = f"Question: {query}\n\nContext:\n{context}"

        extra = {}
        if model.startswith("ollama/"):
            # Local Ollama needs an explicit base URL instead of an API key.
            extra["api_base"] = OLLAMA_BASE_URL

        resp = completion(
            model=model,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            temperature=temperature,
            api_key=api_key if api_key else None,
            **extra
        )
        return resp["choices"][0]["message"]["content"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Global engine shared by all Gradio callbacks; ENGINE_LOCK serializes access
# because RAGEngine itself is not thread-safe.
ENGINE = RAGEngine()

ENGINE_LOCK = threading.Lock()


# Build the initial index at import time so the UI starts ready to answer.
# NOTE(review): this downloads data and loads models on import — intentional
# for this demo script, but slows startup.
with ENGINE_LOCK:

    ENGINE.build_from_dataset(DEFAULT_N_RECORDS, streaming=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def rebuild_index(n_records: int, streaming: bool) -> str:
    """Gradio callback: rebuild the global index and report its status line."""
    with ENGINE_LOCK:
        ENGINE.build_from_dataset(int(n_records), bool(streaming))
        status = ENGINE.last_build_info
    return status
|
|
|
|
|
|
|
|
def qa(
    question: str,
    use_bm25: bool,
    use_dense: bool,
    use_rerank: bool,
    model: str,
    api_key: str,
    topk_bm25: int,
    topk_dense: int,
    topk_final: int
):
    """Gradio callback: retrieve context for *question* and answer with the LLM.

    Returns a (answer_markdown, context_text) pair for the two output widgets.
    """
    question = (question or "").strip()
    if not question:
        return "Type a question.", ""

    if not use_bm25 and not use_dense:
        return "Enable BM25 and/or Dense retrieval (otherwise there is no context).", ""

    # BUGFIX: hold the lock only while touching shared index state. The
    # original kept ENGINE_LOCK held across the LLM network call, blocking
    # every other query and any index rebuild for the full completion time.
    with ENGINE_LOCK:
        if not ENGINE.chunks:
            return "Index is empty. Click 'Rebuild index' with N>0.", ""

        cands = ENGINE.retrieve_candidates(
            question,
            use_bm25=use_bm25,
            use_dense=use_dense,
            topk_bm25=int(topk_bm25),
            topk_dense=int(topk_dense)
        )
        if not cands:
            return "No candidates retrieved.", ""

        if use_rerank:
            selected = ENGINE.rerank(question, cands, top_n=int(topk_final))
        else:
            selected = cands[:int(topk_final)]

        context = ENGINE.build_context(selected)

    # LLM call happens outside the lock; it reads no shared engine state
    # beyond the already-built context string.
    try:
        answer = ENGINE.answer_with_llm(question, context, model=model, api_key=api_key)
    except Exception as e:
        # Surface provider/network failures in the UI instead of crashing.
        answer = f"LLM call failed: {type(e).__name__}: {e}"

    return answer, context
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_demo() -> gr.Blocks:
    """Assemble the Gradio UI: index controls, retrieval toggles, and QA I/O."""
    with gr.Blocks(title="RAG QA on CookingRecipes (BM25 + Dense + Rerank)") as demo:
        gr.Markdown(
            "# RAG QA (CookingRecipes)\n"
            f"Dataset: `{HF_DATASET_NAME}`. Індексуємо **перші N рецептів**.\n\n"
        )

        # Index controls: how many recipes to (re)index and how to load them.
        with gr.Row():
            n_records = gr.Slider(50, 5000, value=DEFAULT_N_RECORDS, step=50, label="N recipes to index (first N)")
            streaming = gr.Checkbox(value=True, label="Use streaming (recommended)")

        build_btn = gr.Button("Rebuild index")
        build_status = gr.Markdown(value=f"**Status:** {ENGINE.last_build_info}")

        build_btn.click(fn=rebuild_index, inputs=[n_records, streaming], outputs=[build_status])

        gr.Markdown("---")

        # Question input.
        with gr.Row():
            question = gr.Textbox(label="Question", placeholder="Ask about recipes...", lines=2)

        # Retrieval-pipeline toggles.
        with gr.Row():
            use_bm25 = gr.Checkbox(value=True, label="Use BM25 (keyword)")
            use_dense = gr.Checkbox(value=True, label="Use Dense (embeddings)")
            use_rerank = gr.Checkbox(value=True, label="Use Cross-Encoder Reranker")

        # LLM backend selection (any LiteLLM provider string).
        with gr.Row():
            model = gr.Textbox(
                label="LLM model (LiteLLM)",
                value="openai/gpt-4o-mini",
                placeholder="e.g. openai/gpt-4o-mini OR groq/... OR openrouter/..."
            )
            api_key = gr.Textbox(
                label="API key (leave empty for Ollama)",
                placeholder="Empty for local ollama",
                type="password"
            )

        # Retrieval depth knobs mirroring the module-level defaults.
        with gr.Row():
            topk_bm25 = gr.Slider(5, 80, value=TOPK_BM25, step=1, label="Top-K BM25 candidates")
            topk_dense = gr.Slider(5, 80, value=TOPK_DENSE, step=1, label="Top-K Dense candidates")
            topk_final = gr.Slider(1, 12, value=TOPK_AFTER_RERANK, step=1, label="Chunks to LLM (final)")

        run_btn = gr.Button("Answer")

        # Outputs: the LLM answer plus the raw retrieved context for debugging.
        answer = gr.Markdown(label="Answer")
        context = gr.Textbox(label="Retrieved context (debug)", lines=16)

        run_btn.click(
            fn=qa,
            inputs=[question, use_bm25, use_dense, use_rerank, model, api_key, topk_bm25, topk_dense, topk_final],
            outputs=[answer, context]
        )

    return demo
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the UI and start the local Gradio server.
    app = build_demo()
    app.launch()
|
|
|
|
|
|
|
|
|