renatavl commited on
Commit
860ef8a
·
0 Parent(s):
Files changed (2) hide show
  1. app.py +454 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import ast
4
+ import threading
5
+ from dataclasses import dataclass
6
+ from typing import List, Tuple, Optional, Dict, Any
7
+ from itertools import islice
8
+
9
+ import numpy as np
10
+ import gradio as gr
11
+ from rank_bm25 import BM25Okapi
12
+ from sentence_transformers import SentenceTransformer, CrossEncoder
13
+ from litellm import completion
14
+ from datasets import load_dataset
15
+
16
+
17
# -----------------------------
# Config
# -----------------------------
# HF Hub dataset with recipe rows (title / ingredients / directions / link / source).
HF_DATASET_NAME = "CodeKapital/CookingRecipes"

# Bi-encoder used for dense retrieval and cross-encoder used for reranking.
DENSE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
RERANK_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"

# Word-based chunking parameters (see chunk_text).
CHUNK_SIZE_WORDS = 350
CHUNK_OVERLAP_WORDS = 60

# Candidate pool size per retriever, and final context size after reranking.
TOPK_BM25 = 25
TOPK_DENSE = 25
TOPK_AFTER_RERANK = 6

OLLAMA_BASE_URL = "http://localhost:11434"  # local Ollama server

# Number of recipes indexed at startup.
DEFAULT_N_RECORDS = 500
35
+
36
+
37
+ # -----------------------------
38
+ # Data structures
39
+ # -----------------------------
40
@dataclass
class Chunk:
    """A single retrievable piece of one recipe document."""
    chunk_id: str  # unique id of the form "<source>::chunk<n>"
    source: str    # human-readable source label (dataset row + title/link)
    text: str      # chunk contents fed to BM25 / embeddings / reranker
45
+
46
+
47
# -----------------------------
# Preprocessing + chunking
# -----------------------------
# Matches any run of whitespace (collapsed to a single space by normalize_text).
_whitespace_re = re.compile(r"\s+")
# Word tokens for BM25: Latin letters, Cyrillic (incl. Ukrainian І/Ї/Є), digits.
_token_re = re.compile(r"[A-Za-zА-Яа-яІіЇїЄє0-9]+")
52
+
53
+
54
def normalize_text(text: str) -> str:
    """Replace non-breaking spaces and collapse whitespace runs to single spaces."""
    cleaned = (text or "").replace("\u00a0", " ")
    return _whitespace_re.sub(" ", cleaned).strip()
58
+
59
+
60
def tokenize_for_bm25(text: str) -> List[str]:
    """Lowercased word tokens for BM25 (Latin/Cyrillic letters and digits)."""
    return list(map(str.lower, _token_re.findall(text or "")))
62
+
63
+
64
def chunk_text(
    source: str,
    text: str,
    chunk_size_words: int = CHUNK_SIZE_WORDS,
    overlap_words: int = CHUNK_OVERLAP_WORDS
) -> List[Chunk]:
    """Split *text* into word-based chunks with overlap.

    Args:
        source: label copied onto every chunk and used in chunk ids.
        text: raw document text; split on whitespace.
        chunk_size_words: maximum words per chunk.
        overlap_words: words shared between consecutive chunks.

    Returns:
        List of Chunk objects; empty when *text* contains no words.
    """
    words = (text or "").split()
    if not words:
        return []

    chunks: List[Chunk] = []
    start = 0
    idx = 0

    while start < len(words):
        end = min(start + chunk_size_words, len(words))
        chunk_str = " ".join(words[start:end]).strip()

        if chunk_str:
            chunks.append(Chunk(
                chunk_id=f"{source}::chunk{idx}",
                source=source,
                text=chunk_str
            ))
            idx += 1

        if end == len(words):
            break
        # Advance by at least one word so overlap_words >= chunk_size_words
        # cannot stall the loop (original `max(0, end - overlap_words)` could
        # leave `start` unchanged and loop forever for such parameters).
        start = max(start + 1, end - overlap_words)

    return chunks
96
+
97
+
98
+ # -----------------------------
99
+ # HF dataset helpers
100
+ # -----------------------------
101
+ def _to_list(x: Any) -> List[str]:
102
+ """ingredients/directions можуть бути list або строкою зі списком."""
103
+ if x is None:
104
+ return []
105
+ if isinstance(x, list):
106
+ return [str(i).strip() for i in x if str(i).strip()]
107
+ if isinstance(x, str):
108
+ s = x.strip()
109
+ if not s:
110
+ return []
111
+ try:
112
+ v = ast.literal_eval(s)
113
+ if isinstance(v, list):
114
+ return [str(i).strip() for i in v if str(i).strip()]
115
+ except Exception:
116
+ pass
117
+ if "\n" in s:
118
+ parts = [p.strip(" -•\t") for p in s.splitlines()]
119
+ else:
120
+ parts = [p.strip() for p in s.split(",")]
121
+ return [p for p in parts if p]
122
+ return [str(x).strip()] if str(x).strip() else []
123
+
124
+
125
def recipe_row_to_doc(row: Dict[str, Any], idx: int) -> Tuple[str, str]:
    """Render one dataset row as (source_name, full_text) for indexing."""
    title = (row.get("title") or "").strip()
    link = (row.get("link") or "").strip()
    src = (row.get("source") or "").strip()

    ingredients = _to_list(row.get("ingredients"))
    directions = _to_list(row.get("directions"))

    # Source label: dataset row index, then truncated title and link if present.
    safe_title = title[:80].replace("\n", " ").strip()
    name_bits = [f"CookingRecipes#{idx}"]
    if safe_title:
        name_bits.append(safe_title)
    if link:
        name_bits.append(link)
    source_name = " | ".join(name_bits)

    # Document body: header fields, then bulleted ingredients, numbered steps.
    parts = [f"Title: {title or '(unknown)'}"]
    if src:
        parts.append(f"Source: {src}")
    if link:
        parts.append(f"Link: {link}")
    if ingredients:
        bullet_lines = "\n".join(f"- {item}" for item in ingredients)
        parts.append("Ingredients:\n" + bullet_lines)
    if directions:
        step_lines = "\n".join(f"{n + 1}. {step}" for n, step in enumerate(directions))
        parts.append("Directions:\n" + step_lines)

    return source_name, normalize_text("\n\n".join(parts))
155
+
156
+
157
def load_first_n_recipes(n: int, streaming: bool = True) -> List[Tuple[str, str]]:
    """Fetch the first *n* recipes and convert them to (source, text) docs.

    Streaming avoids downloading the whole dataset; the non-streaming path
    requests a HF split slice instead. Rows that render to empty text are
    dropped.
    """
    n = int(max(0, n))
    if n == 0:
        return []

    if streaming:
        rows = islice(load_dataset(HF_DATASET_NAME, split="train", streaming=True), n)
    else:
        rows = load_dataset(HF_DATASET_NAME, split=f"train[:{n}]")

    docs: List[Tuple[str, str]] = []
    for idx, row in enumerate(rows):
        source_name, text = recipe_row_to_doc(row, idx)
        if text.strip():
            docs.append((source_name, text))
    return docs
175
+
176
+
177
+ # -----------------------------
178
+ # RAG Engine
179
+ # -----------------------------
180
# -----------------------------
# RAG Engine
# -----------------------------
class RAGEngine:
    """Hybrid retrieval (BM25 + dense embeddings) with optional cross-encoder
    reranking and LLM answer generation via LiteLLM.

    Not thread-safe by itself; callers serialize access with an external lock.
    """

    def __init__(self):
        self.chunks: List[Chunk] = []
        self.bm25: Optional[BM25Okapi] = None
        self.bm25_corpus_tokens: List[List[str]] = []

        self.dense_model: Optional[SentenceTransformer] = None
        self.rerank_model: Optional[CrossEncoder] = None
        # Row-aligned with self.chunks; L2-normalized float32 vectors.
        self.chunk_embeddings: Optional[np.ndarray] = None

        self.last_build_info: str = "Index not built yet."

    def ensure_models(self) -> None:
        """Lazily load the embedding and reranking models (idempotent)."""
        if self.dense_model is None:
            self.dense_model = SentenceTransformer(DENSE_MODEL_NAME)
        if self.rerank_model is None:
            self.rerank_model = CrossEncoder(RERANK_MODEL_NAME)

    def build_from_dataset(self, n_records: int, streaming: bool) -> None:
        """(Re)build the BM25 and dense indexes from the first n_records recipes."""
        docs = load_first_n_recipes(n_records, streaming=streaming)

        all_chunks: List[Chunk] = []
        for source, text in docs:
            all_chunks.extend(chunk_text(source, text))

        self.chunks = all_chunks

        if not self.chunks:
            self.bm25 = None
            self.chunk_embeddings = None
            self.last_build_info = "No chunks built (N too small or empty rows)."
            return

        # Models
        self.ensure_models()

        # BM25
        self.bm25_corpus_tokens = [tokenize_for_bm25(c.text) for c in self.chunks]
        self.bm25 = BM25Okapi(self.bm25_corpus_tokens)

        # Dense embeddings (normalized, so a dot product is cosine similarity).
        embs = self.dense_model.encode(
            [c.text for c in self.chunks],
            batch_size=64,
            show_progress_bar=True,
            normalize_embeddings=True
        )
        self.chunk_embeddings = np.asarray(embs, dtype=np.float32)

        self.last_build_info = (
            f"Built index from {len(docs)} recipes → {len(self.chunks)} chunks. "
            f"Streaming={streaming}."
        )

    def retrieve_candidates(
        self,
        query: str,
        use_bm25: bool,
        use_dense: bool,
        topk_bm25: int = TOPK_BM25,
        topk_dense: int = TOPK_DENSE
    ) -> List[int]:
        """Return candidate chunk indices for *query*.

        Fix over the original implementation: candidates were returned as
        ``list(set(...))``, which discards retrieval ranking and makes the
        no-reranker path (``cands[:topk_final]``) pick arbitrary chunks.
        The result is now a deterministic, rank-ordered, de-duplicated list
        (BM25 hits first in score order, then dense hits in score order).
        """
        if not self.chunks:
            return []

        ordered: List[int] = []
        seen: set = set()

        def _add_ranked(indices) -> None:
            # Preserve each retriever's ranking; skip duplicates.
            for i in indices:
                i = int(i)
                if i not in seen:
                    seen.add(i)
                    ordered.append(i)

        if use_bm25 and self.bm25 is not None:
            q_tokens = tokenize_for_bm25(query)
            scores = self.bm25.get_scores(q_tokens)
            _add_ranked(np.argsort(scores)[::-1][:int(topk_bm25)])

        if use_dense and self.dense_model is not None and self.chunk_embeddings is not None:
            q_emb = self.dense_model.encode([query], normalize_embeddings=True)
            q_emb = np.asarray(q_emb, dtype=np.float32)[0]
            sims = self.chunk_embeddings @ q_emb
            _add_ranked(np.argsort(sims)[::-1][:int(topk_dense)])

        return ordered

    def rerank(self, query: str, candidate_idx: List[int], top_n: int = TOPK_AFTER_RERANK) -> List[int]:
        """Re-score (query, chunk) pairs with the cross-encoder; keep the top_n."""
        if not candidate_idx:
            return []
        if self.rerank_model is None:
            # No reranker loaded: fall back to the incoming candidate order.
            return candidate_idx[:int(top_n)]

        pairs = [(query, self.chunks[i].text) for i in candidate_idx]
        scores = self.rerank_model.predict(pairs)
        order = np.argsort(scores)[::-1]
        return [candidate_idx[i] for i in order[:int(top_n)]]

    def build_context(self, selected_idx: List[int]) -> str:
        """Format the selected chunks as numbered, citable context blocks."""
        blocks = []
        for j, i in enumerate(selected_idx, start=1):
            c = self.chunks[i]
            blocks.append(
                f"[{j}] Source: {c.source} | {c.chunk_id}\n{c.text}"
            )
        return "\n\n---\n\n".join(blocks)

    def answer_with_llm(self, query: str, context: str, model: str, api_key: str, temperature: float = 0.2) -> str:
        """Ask the LLM (via LiteLLM) to answer *query* using only *context*.

        Exports provider API keys to the environment because some LiteLLM
        providers read them from there; also passes api_key explicitly.
        Ollama models are routed to the local server via api_base.
        """
        model = (model or "").strip()
        api_key = (api_key or "").strip()
        if not model:
            return "Model is empty."

        if model.startswith("openai/") or model.startswith("gpt-"):
            if api_key:
                os.environ["OPENAI_API_KEY"] = api_key
        elif model.startswith("openrouter/"):
            if api_key:
                os.environ["OPENROUTER_API_KEY"] = api_key
        elif model.startswith("groq/"):
            if api_key:
                os.environ["GROQ_API_KEY"] = api_key

        system = (
            "You are a helpful QA assistant.\n"
            "Answer the user's question using ONLY the provided context.\n"
            "If the answer is not in the context, say you don't know.\n"
            "When you use facts from the context, add citations like [1] referring to the chunk numbers."
        )
        user = f"Question: {query}\n\nContext:\n{context}"

        extra = {}
        if model.startswith("ollama/"):
            extra["api_base"] = OLLAMA_BASE_URL

        resp = completion(
            model=model,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            temperature=temperature,
            api_key=api_key if api_key else None,
            **extra
        )
        return resp["choices"][0]["message"]["content"]
321
+
322
+
323
+ # -----------------------------
324
+ # Global engine + lock
325
+ # -----------------------------
326
# Single shared engine instance for the whole app.
ENGINE = RAGEngine()
# Serializes index rebuilds and queries across Gradio worker threads.
ENGINE_LOCK = threading.Lock()

# build once on startup
with ENGINE_LOCK:
    ENGINE.build_from_dataset(DEFAULT_N_RECORDS, streaming=True)
332
+
333
+
334
+ # -----------------------------
335
+ # Gradio UI callbacks
336
+ # -----------------------------
337
def rebuild_index(n_records: int, streaming: bool) -> str:
    """Gradio callback: rebuild the global index and report the build status."""
    count = int(n_records)
    use_streaming = bool(streaming)
    with ENGINE_LOCK:
        ENGINE.build_from_dataset(count, use_streaming)
        return ENGINE.last_build_info
341
+
342
+
343
def qa(
    question: str,
    use_bm25: bool,
    use_dense: bool,
    use_rerank: bool,
    model: str,
    api_key: str,
    topk_bm25: int,
    topk_dense: int,
    topk_final: int
):
    """Gradio callback: retrieve context for *question* and ask the LLM.

    Returns a pair (answer_markdown, context_debug_text).
    """
    question = (question or "").strip()
    if not question:
        return "Type a question.", ""

    if not use_bm25 and not use_dense:
        return "Enable BM25 and/or Dense retrieval (otherwise there is no context).", ""

    # Hold the lock only while touching the shared index. The LLM call below
    # is a network round-trip that can take seconds; keeping it inside the
    # lock (as the original did) blocks every other query and rebuild.
    with ENGINE_LOCK:
        if not ENGINE.chunks:
            return "Index is empty. Click 'Rebuild index' with N>0.", ""

        cands = ENGINE.retrieve_candidates(
            question,
            use_bm25=use_bm25,
            use_dense=use_dense,
            topk_bm25=int(topk_bm25),
            topk_dense=int(topk_dense)
        )
        if not cands:
            return "No candidates retrieved.", ""

        if use_rerank:
            selected = ENGINE.rerank(question, cands, top_n=int(topk_final))
        else:
            selected = cands[:int(topk_final)]

        context = ENGINE.build_context(selected)

    # LLM call happens outside the lock, on an immutable context string.
    try:
        answer = ENGINE.answer_with_llm(question, context, model=model, api_key=api_key)
    except Exception as e:
        answer = f"LLM call failed: {type(e).__name__}: {e}"

    return answer, context
388
+
389
+
390
+ # -----------------------------
391
+ # Launch UI
392
+ # -----------------------------
393
def build_demo() -> gr.Blocks:
    """Assemble the Gradio Blocks UI.

    NOTE: component creation order matters — widgets attach to the active
    Blocks/Row context in the order they are constructed.
    """
    with gr.Blocks(title="RAG QA on CookingRecipes (BM25 + Dense + Rerank)") as demo:
        gr.Markdown(
            "# RAG QA (CookingRecipes)\n"
            f"Dataset: `{HF_DATASET_NAME}`. Індексуємо **перші N рецептів**.\n\n"
        )

        # --- Index controls ---
        with gr.Row():
            n_records = gr.Slider(50, 5000, value=DEFAULT_N_RECORDS, step=50, label="N recipes to index (first N)")
            streaming = gr.Checkbox(value=True, label="Use streaming (recommended)")

        build_btn = gr.Button("Rebuild index")
        build_status = gr.Markdown(value=f"**Status:** {ENGINE.last_build_info}")

        build_btn.click(fn=rebuild_index, inputs=[n_records, streaming], outputs=[build_status])

        gr.Markdown("---")

        # --- Question input and retrieval toggles ---
        with gr.Row():
            question = gr.Textbox(label="Question", placeholder="Ask about recipes...", lines=2)

        with gr.Row():
            use_bm25 = gr.Checkbox(value=True, label="Use BM25 (keyword)")
            use_dense = gr.Checkbox(value=True, label="Use Dense (embeddings)")
            use_rerank = gr.Checkbox(value=True, label="Use Cross-Encoder Reranker")

        # --- LLM settings (LiteLLM model string + optional API key) ---
        with gr.Row():
            model = gr.Textbox(
                label="LLM model (LiteLLM)",
                value="openai/gpt-4o-mini",
                placeholder="e.g. openai/gpt-4o-mini OR groq/... OR openrouter/..."
            )
            api_key = gr.Textbox(
                label="API key (leave empty for Ollama)",
                placeholder="Empty for local ollama",
                type="password"
            )

        # --- Retrieval depth sliders ---
        with gr.Row():
            topk_bm25 = gr.Slider(5, 80, value=TOPK_BM25, step=1, label="Top-K BM25 candidates")
            topk_dense = gr.Slider(5, 80, value=TOPK_DENSE, step=1, label="Top-K Dense candidates")
            topk_final = gr.Slider(1, 12, value=TOPK_AFTER_RERANK, step=1, label="Chunks to LLM (final)")

        run_btn = gr.Button("Answer")

        answer = gr.Markdown(label="Answer")
        context = gr.Textbox(label="Retrieved context (debug)", lines=16)

        run_btn.click(
            fn=qa,
            inputs=[question, use_bm25, use_dense, use_rerank, model, api_key, topk_bm25, topk_dense, topk_final],
            outputs=[answer, context]
        )

    return demo
448
+
449
+
450
if __name__ == "__main__":
    # Build the UI and start the Gradio server (blocks until shutdown).
    demo = build_demo()
    demo.launch()
    # for local run with fixed port:
    # demo.launch(server_name="127.0.0.1", server_port=7860)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ numpy>=1.24.0
3
+ rank-bm25>=0.2.2
4
+ sentence-transformers>=2.6.0
5
+ litellm>=1.40.0
6
+ pypdf>=4.0.0
7
+ datasets>=2.18.0