import os
import re
import ast
import threading
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict, Any
from itertools import islice

import numpy as np
import gradio as gr
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder
from litellm import completion
from datasets import load_dataset

# -----------------------------
# Config
# -----------------------------
HF_DATASET_NAME = "CodeKapital/CookingRecipes"
DENSE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
RERANK_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"

CHUNK_SIZE_WORDS = 350
CHUNK_OVERLAP_WORDS = 60

TOPK_BM25 = 25
TOPK_DENSE = 25
TOPK_AFTER_RERANK = 6

OLLAMA_BASE_URL = "http://localhost:11434"  # local Ollama instance

DEFAULT_N_RECORDS = 500


# -----------------------------
# Data structures
# -----------------------------
@dataclass
class Chunk:
    chunk_id: str  # unique id of the form "<source>::chunk<n>"
    source: str    # human-readable provenance (recipe title/link)
    text: str      # chunk contents


# -----------------------------
# Preprocessing + chunking
# -----------------------------
_whitespace_re = re.compile(r"\s+")
# Latin + Cyrillic (incl. Ukrainian І/Ї/Є) letters and digits.
_token_re = re.compile(r"[A-Za-zА-Яа-яІіЇїЄє0-9]+")


def normalize_text(text: str) -> str:
    """Collapse all whitespace (incl. non-breaking spaces) to single spaces."""
    text = (text or "").replace("\u00a0", " ")
    text = _whitespace_re.sub(" ", text).strip()
    return text


def tokenize_for_bm25(text: str) -> List[str]:
    """Lowercased alphanumeric tokens for BM25 keyword matching."""
    return [t.lower() for t in _token_re.findall(text or "")]


def chunk_text(
    source: str,
    text: str,
    chunk_size_words: int = CHUNK_SIZE_WORDS,
    overlap_words: int = CHUNK_OVERLAP_WORDS
) -> List[Chunk]:
    """Split text into word-based chunks with overlap.

    Args:
        source: provenance label; becomes part of each chunk_id.
        text: the document text (None/empty yields no chunks).
        chunk_size_words: max words per chunk (clamped to >= 1).
        overlap_words: words shared between consecutive chunks
            (clamped to < chunk_size_words so the loop always advances).

    Returns:
        List of Chunk objects, possibly empty.
    """
    words = (text or "").split()
    if not words:
        return []
    # Fix: guard against degenerate parameters. With overlap_words >=
    # chunk_size_words, `start = max(0, end - overlap_words)` would never
    # move forward and the while-loop below would spin forever.
    chunk_size_words = max(1, int(chunk_size_words))
    overlap_words = min(max(0, int(overlap_words)), chunk_size_words - 1)
    chunks: List[Chunk] = []
    start = 0
    idx = 0
    while start < len(words):
        end = min(start + chunk_size_words, len(words))
        chunk_str = " ".join(words[start:end]).strip()
        if chunk_str:
            chunks.append(Chunk(
                chunk_id=f"{source}::chunk{idx}",
                source=source,
                text=chunk_str
            ))
            idx += 1
        if end == len(words):
            break
        # Step back by the overlap so consecutive chunks share context.
        start = max(0, end - overlap_words)
    return chunks


# -----------------------------
# HF dataset helpers
# -----------------------------
def _to_list(x: Any) -> List[str]:
    """Coerce ingredients/directions to a list of non-empty strings.

    The dataset stores them either as a real list or as a string that may
    contain a Python-literal list, newline-separated bullets, or a
    comma-separated enumeration.
    """
    if x is None:
        return []
    if isinstance(x, list):
        return [str(i).strip() for i in x if str(i).strip()]
    if isinstance(x, str):
        s = x.strip()
        if not s:
            return []
        # First try to parse a stringified Python list, e.g. "['a', 'b']".
        try:
            v = ast.literal_eval(s)
            if isinstance(v, list):
                return [str(i).strip() for i in v if str(i).strip()]
        except Exception:
            pass  # not a literal — fall through to heuristic splitting
        if "\n" in s:
            parts = [p.strip(" -•\t") for p in s.splitlines()]
        else:
            parts = [p.strip() for p in s.split(",")]
        return [p for p in parts if p]
    return [str(x).strip()] if str(x).strip() else []


def recipe_row_to_doc(row: Dict[str, Any], idx: int) -> Tuple[str, str]:
    """Return (source_name, full_text) for a single recipe row."""
    title = (row.get("title") or "").strip()
    link = (row.get("link") or "").strip()
    src = (row.get("source") or "").strip()
    ingredients = _to_list(row.get("ingredients"))
    directions = _to_list(row.get("directions"))
    safe_title = title[:80].replace("\n", " ").strip()
    source_name = f"CookingRecipes#{idx}"
    if safe_title:
        source_name += f" | {safe_title}"
    if link:
        source_name += f" | {link}"
    parts = []
    parts.append(f"Title: {title or '(unknown)'}")
    if src:
        parts.append(f"Source: {src}")
    if link:
        parts.append(f"Link: {link}")
    if ingredients:
        parts.append("Ingredients:\n" + "\n".join(f"- {i}" for i in ingredients))
    if directions:
        parts.append("Directions:\n" + "\n".join(f"{i+1}. {d}" for i, d in enumerate(directions)))
    full_text = normalize_text("\n\n".join(parts))
    return source_name, full_text


def load_first_n_recipes(n: int, streaming: bool = True) -> List[Tuple[str, str]]:
    """Load the first *n* recipes as (source_name, text) pairs.

    streaming=True avoids downloading the whole dataset; otherwise a
    `train[:n]` slice is materialized.
    """
    n = int(max(0, n))
    if n == 0:
        return []
    if streaming:
        ds = load_dataset(HF_DATASET_NAME, split="train", streaming=True)
        iterator = islice(ds, n)
    else:
        ds = load_dataset(HF_DATASET_NAME, split=f"train[:{n}]")
        iterator = ds
    docs: List[Tuple[str, str]] = []
    for idx, row in enumerate(iterator):
        source_name, text = recipe_row_to_doc(row, idx)
        if text.strip():
            docs.append((source_name, text))
    return docs


# -----------------------------
# RAG Engine
# -----------------------------
class RAGEngine:
    """Hybrid retriever (BM25 + dense embeddings) with optional
    cross-encoder reranking and LiteLLM-based answer generation.

    Not thread-safe by itself; callers serialize access via ENGINE_LOCK.
    """

    def __init__(self):
        self.chunks: List[Chunk] = []
        self.bm25: Optional[BM25Okapi] = None
        self.bm25_corpus_tokens: List[List[str]] = []
        self.dense_model: Optional[SentenceTransformer] = None
        self.rerank_model: Optional[CrossEncoder] = None
        self.chunk_embeddings: Optional[np.ndarray] = None
        self.last_build_info: str = "Index not built yet."

    def ensure_models(self) -> None:
        """Lazily instantiate the embedding and reranking models."""
        if self.dense_model is None:
            self.dense_model = SentenceTransformer(DENSE_MODEL_NAME)
        if self.rerank_model is None:
            self.rerank_model = CrossEncoder(RERANK_MODEL_NAME)

    def build_from_dataset(self, n_records: int, streaming: bool) -> None:
        """(Re)build BM25 and dense indexes from the first n_records recipes."""
        docs = load_first_n_recipes(n_records, streaming=streaming)
        all_chunks: List[Chunk] = []
        for source, text in docs:
            all_chunks.extend(chunk_text(source, text))
        self.chunks = all_chunks
        if not self.chunks:
            self.bm25 = None
            self.chunk_embeddings = None
            self.last_build_info = "No chunks built (N too small or empty rows)."
            return
        # Models
        self.ensure_models()
        # BM25
        self.bm25_corpus_tokens = [tokenize_for_bm25(c.text) for c in self.chunks]
        self.bm25 = BM25Okapi(self.bm25_corpus_tokens)
        # Dense embeddings (normalized, so dot product == cosine similarity)
        embs = self.dense_model.encode(
            [c.text for c in self.chunks],
            batch_size=64,
            show_progress_bar=True,
            normalize_embeddings=True
        )
        self.chunk_embeddings = np.asarray(embs, dtype=np.float32)
        self.last_build_info = (
            f"Built index from {len(docs)} recipes → {len(self.chunks)} chunks. "
            f"Streaming={streaming}."
        )

    def retrieve_candidates(
        self,
        query: str,
        use_bm25: bool,
        use_dense: bool,
        topk_bm25: int = TOPK_BM25,
        topk_dense: int = TOPK_DENSE
    ) -> List[int]:
        """Return candidate chunk indexes for *query*.

        Fix: candidates were previously collected in a `set`, so the
        returned list had arbitrary order — and the caller truncates it
        with `[:topk_final]` when reranking is off, sending an arbitrary
        subset to the LLM. We now dedupe with an insertion-ordered dict so
        candidates stay in descending-score order (BM25 hits first, then
        dense hits). Same interface, same candidate membership.
        """
        if not self.chunks:
            return []
        candidate_ids: Dict[int, None] = {}
        if use_bm25 and self.bm25 is not None:
            q_tokens = tokenize_for_bm25(query)
            scores = self.bm25.get_scores(q_tokens)
            top_idx = np.argsort(scores)[::-1][:int(topk_bm25)]
            for i in top_idx.tolist():
                candidate_ids.setdefault(i, None)
        if use_dense and self.dense_model is not None and self.chunk_embeddings is not None:
            q_emb = self.dense_model.encode([query], normalize_embeddings=True)
            q_emb = np.asarray(q_emb, dtype=np.float32)[0]
            sims = self.chunk_embeddings @ q_emb
            top_idx = np.argsort(sims)[::-1][:int(topk_dense)]
            for i in top_idx.tolist():
                candidate_ids.setdefault(i, None)
        return list(candidate_ids)

    def rerank(self, query: str, candidate_idx: List[int], top_n: int = TOPK_AFTER_RERANK) -> List[int]:
        """Re-order candidates by cross-encoder relevance; keep top_n."""
        if not candidate_idx:
            return []
        if self.rerank_model is None:
            # No reranker loaded — fall back to the incoming order.
            return candidate_idx[:int(top_n)]
        pairs = [(query, self.chunks[i].text) for i in candidate_idx]
        scores = self.rerank_model.predict(pairs)
        order = np.argsort(scores)[::-1]
        return [candidate_idx[i] for i in order[:int(top_n)]]

    def build_context(self, selected_idx: List[int]) -> str:
        """Format selected chunks as numbered, citable context blocks."""
        blocks = []
        for j, i in enumerate(selected_idx, start=1):
            c = self.chunks[i]
            blocks.append(
                f"[{j}] Source: {c.source} | {c.chunk_id}\n{c.text}"
            )
        return "\n\n---\n\n".join(blocks)

    def answer_with_llm(self, query: str, context: str, model: str, api_key: str, temperature: float = 0.2) -> str:
        """Ask the LLM to answer *query* using only *context*.

        The provider is inferred from the model prefix; the API key is
        exported to the matching env var (LiteLLM also receives it
        directly). Ollama models are routed to the local server.
        """
        model = (model or "").strip()
        api_key = (api_key or "").strip()
        if not model:
            return "Model is empty."
        # NOTE(review): setting env vars here is a process-wide side effect;
        # kept for compatibility since some providers read only the env var.
        if model.startswith("openai/") or model.startswith("gpt-"):
            if api_key:
                os.environ["OPENAI_API_KEY"] = api_key
        elif model.startswith("openrouter/"):
            if api_key:
                os.environ["OPENROUTER_API_KEY"] = api_key
        elif model.startswith("groq/"):
            if api_key:
                os.environ["GROQ_API_KEY"] = api_key
        system = (
            "You are a helpful QA assistant.\n"
            "Answer the user's question using ONLY the provided context.\n"
            "If the answer is not in the context, say you don't know.\n"
            "When you use facts from the context, add citations like [1] referring to the chunk numbers."
        )
        user = f"Question: {query}\n\nContext:\n{context}"
        extra = {}
        if model.startswith("ollama/"):
            extra["api_base"] = OLLAMA_BASE_URL
        resp = completion(
            model=model,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            temperature=temperature,
            api_key=api_key if api_key else None,
            **extra
        )
        return resp["choices"][0]["message"]["content"]


# -----------------------------
# Global engine + lock
# -----------------------------
ENGINE = RAGEngine()
ENGINE_LOCK = threading.Lock()

# Build once on startup.
# NOTE(review): this downloads data/models at import time; consider moving
# into __main__ if importing this module without the UI becomes a use case.
with ENGINE_LOCK:
    ENGINE.build_from_dataset(DEFAULT_N_RECORDS, streaming=True)


# -----------------------------
# Gradio UI callbacks
# -----------------------------
def rebuild_index(n_records: int, streaming: bool) -> str:
    """UI callback: rebuild the index and report build status."""
    with ENGINE_LOCK:
        ENGINE.build_from_dataset(int(n_records), bool(streaming))
        return ENGINE.last_build_info


def qa(
    question: str,
    use_bm25: bool,
    use_dense: bool,
    use_rerank: bool,
    model: str,
    api_key: str,
    topk_bm25: int,
    topk_dense: int,
    topk_final: int
):
    """UI callback: retrieve context for *question* and answer with the LLM.

    Returns (answer_markdown, debug_context) for the two output widgets.
    """
    question = (question or "").strip()
    if not question:
        return "Type a question.", ""
    if not use_bm25 and not use_dense:
        return "Enable BM25 and/or Dense retrieval (otherwise there is no context).", ""
    # NOTE(review): the lock is held through the LLM call, so concurrent
    # rebuilds wait for it — confirm this is intended before changing.
    with ENGINE_LOCK:
        if not ENGINE.chunks:
            return "Index is empty. Click 'Rebuild index' with N>0.", ""
        cands = ENGINE.retrieve_candidates(
            question,
            use_bm25=use_bm25,
            use_dense=use_dense,
            topk_bm25=int(topk_bm25),
            topk_dense=int(topk_dense)
        )
        if not cands:
            return "No candidates retrieved.", ""
        if use_rerank:
            selected = ENGINE.rerank(question, cands, top_n=int(topk_final))
        else:
            selected = cands[:int(topk_final)]
        context = ENGINE.build_context(selected)
        try:
            answer = ENGINE.answer_with_llm(question, context, model=model, api_key=api_key)
        except Exception as e:
            # Surface the failure in the UI instead of crashing the callback.
            answer = f"LLM call failed: {type(e).__name__}: {e}"
        return answer, context


# -----------------------------
# Launch UI
# -----------------------------
def build_demo() -> gr.Blocks:
    """Assemble and return the Gradio Blocks UI."""
    with gr.Blocks(title="RAG QA on CookingRecipes (BM25 + Dense + Rerank)") as demo:
        gr.Markdown(
            "# RAG QA (CookingRecipes)\n"
            f"Dataset: `{HF_DATASET_NAME}`. Індексуємо **перші N рецептів**.\n\n"
        )
        with gr.Row():
            n_records = gr.Slider(50, 5000, value=DEFAULT_N_RECORDS, step=50, label="N recipes to index (first N)")
            streaming = gr.Checkbox(value=True, label="Use streaming (recommended)")
            build_btn = gr.Button("Rebuild index")
        build_status = gr.Markdown(value=f"**Status:** {ENGINE.last_build_info}")
        build_btn.click(fn=rebuild_index, inputs=[n_records, streaming], outputs=[build_status])
        gr.Markdown("---")
        with gr.Row():
            question = gr.Textbox(label="Question", placeholder="Ask about recipes...", lines=2)
        with gr.Row():
            use_bm25 = gr.Checkbox(value=True, label="Use BM25 (keyword)")
            use_dense = gr.Checkbox(value=True, label="Use Dense (embeddings)")
            use_rerank = gr.Checkbox(value=True, label="Use Cross-Encoder Reranker")
        with gr.Row():
            model = gr.Textbox(
                label="LLM model (LiteLLM)",
                value="openai/gpt-4o-mini",
                placeholder="e.g. openai/gpt-4o-mini OR groq/... OR openrouter/..."
            )
            api_key = gr.Textbox(
                label="API key (leave empty for Ollama)",
                placeholder="Empty for local ollama",
                type="password"
            )
        with gr.Row():
            topk_bm25 = gr.Slider(5, 80, value=TOPK_BM25, step=1, label="Top-K BM25 candidates")
            topk_dense = gr.Slider(5, 80, value=TOPK_DENSE, step=1, label="Top-K Dense candidates")
            topk_final = gr.Slider(1, 12, value=TOPK_AFTER_RERANK, step=1, label="Chunks to LLM (final)")
        run_btn = gr.Button("Answer")
        answer = gr.Markdown(label="Answer")
        context = gr.Textbox(label="Retrieved context (debug)", lines=16)
        run_btn.click(
            fn=qa,
            inputs=[question, use_bm25, use_dense, use_rerank, model, api_key, topk_bm25, topk_dense, topk_final],
            outputs=[answer, context]
        )
    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()
    # for local run with fixed port:
    # demo.launch(server_name="127.0.0.1", server_port=7860)