Spaces:
Sleeping
Sleeping
Oleksii Obolonskyi committed on
Commit ·
2c5e1f2
1
Parent(s): 44720da
Add token-aware context limits
Browse files
README.md
CHANGED
|
@@ -107,6 +107,9 @@ export RAG_HF_PROVIDER=hf-inference
|
|
| 107 |
export RAG_HF_MODEL=Qwen/Qwen2.5-7B-Instruct-1M
|
| 108 |
export RAG_LLM_BACKEND=hf
|
| 109 |
export RAG_HF_API_URL=https://router.huggingface.co/hf-inference/models/Qwen/Qwen2.5-7B-Instruct-1M
|
|
|
|
|
|
|
|
|
|
| 110 |
export RAG_OUT_DIR=data/normalized
|
| 111 |
export RAG_ARTICLE_SOURCES=sources_articles.json
|
| 112 |
```
|
|
|
|
| 107 |
export RAG_HF_MODEL=Qwen/Qwen2.5-7B-Instruct-1M
|
| 108 |
export RAG_LLM_BACKEND=hf
|
| 109 |
export RAG_HF_API_URL=https://router.huggingface.co/hf-inference/models/Qwen/Qwen2.5-7B-Instruct-1M
|
| 110 |
+
export RAG_MAX_CONTEXT_TOKENS=6000
|
| 111 |
+
export RAG_MAX_CHUNKS=6
|
| 112 |
+
export RAG_MAX_GENERATION_TOKENS=512
|
| 113 |
export RAG_OUT_DIR=data/normalized
|
| 114 |
export RAG_ARTICLE_SOURCES=sources_articles.json
|
| 115 |
```
|
app.py
CHANGED
|
@@ -42,6 +42,10 @@ if not HF_API_URL:
|
|
| 42 |
OLLAMA_BASE_URL = os.environ.get("RAG_OLLAMA_URL", "http://localhost:11434").rstrip("/")
|
| 43 |
OLLAMA_MODEL = os.environ.get("RAG_OLLAMA_MODEL", "llama3.2:1b")
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
REPO_OWNER = "16bitSega"
|
| 46 |
REPO_NAME = "RAG_project"
|
| 47 |
|
|
@@ -134,6 +138,11 @@ def normalize_display_text(s: str) -> str:
|
|
| 134 |
s = re.sub(r"\s+", " ", s).strip()
|
| 135 |
return s
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
def is_company_question(q: str) -> bool:
|
| 138 |
q = (q or "").lower()
|
| 139 |
patterns = [
|
|
@@ -395,6 +404,49 @@ def build_context(
|
|
| 395 |
parts.append("ARTICLE EXCERPTS:\n" + "\n\n".join(article_parts))
|
| 396 |
return "\n\n".join(parts)
|
| 397 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 398 |
def chunk_keyword_overlap(chunk: Chunk, terms: List[str]) -> int:
|
| 399 |
if not terms:
|
| 400 |
return 0
|
|
@@ -468,7 +520,7 @@ def answer_question(
|
|
| 468 |
if not all_hits or not_found_by_terms(question, all_hits):
|
| 469 |
return "Not found in dataset.", citations, False
|
| 470 |
|
| 471 |
-
context =
|
| 472 |
avoid_text = "; ".join(AVOID_PHRASES)
|
| 473 |
base_rules = (
|
| 474 |
"You must answer using only the provided context.\n"
|
|
@@ -498,6 +550,17 @@ def answer_question(
|
|
| 498 |
+ format_rules
|
| 499 |
+ f"\nQuestion:\n{question}\n\nContext:\n{context}\n\nAnswer:"
|
| 500 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 501 |
answer, err = llm_chat(prompt)
|
| 502 |
if err:
|
| 503 |
st.error(err)
|
|
@@ -540,7 +603,7 @@ def hf_chat(prompt: str, timeout: Tuple[int, int] = (10, 600)) -> Tuple[str, Opt
|
|
| 540 |
resp = client.chat.completions.create(
|
| 541 |
model=HF_MODEL,
|
| 542 |
messages=messages,
|
| 543 |
-
max_tokens=
|
| 544 |
temperature=0.2,
|
| 545 |
)
|
| 546 |
text = (resp.choices[0].message.content or "").strip()
|
|
@@ -549,7 +612,7 @@ def hf_chat(prompt: str, timeout: Tuple[int, int] = (10, 600)) -> Tuple[str, Opt
|
|
| 549 |
try:
|
| 550 |
out = client.text_generation(
|
| 551 |
prompt,
|
| 552 |
-
max_new_tokens=
|
| 553 |
temperature=0.2,
|
| 554 |
do_sample=True,
|
| 555 |
return_full_text=False,
|
|
@@ -571,7 +634,7 @@ def hf_chat(prompt: str, timeout: Tuple[int, int] = (10, 600)) -> Tuple[str, Opt
|
|
| 571 |
)
|
| 572 |
out = retry_client.text_generation(
|
| 573 |
prompt,
|
| 574 |
-
max_new_tokens=
|
| 575 |
temperature=0.2,
|
| 576 |
do_sample=True,
|
| 577 |
return_full_text=False,
|
|
@@ -595,7 +658,7 @@ def ollama_chat(prompt: str, timeout: Tuple[int, int] = (10, 600)) -> Tuple[str,
|
|
| 595 |
{"role": "user", "content": prompt},
|
| 596 |
],
|
| 597 |
"stream": False,
|
| 598 |
-
"options": {"temperature": 0.2},
|
| 599 |
}
|
| 600 |
try:
|
| 601 |
r = requests.post(url, json=payload, timeout=timeout)
|
|
@@ -706,8 +769,19 @@ with st.sidebar:
|
|
| 706 |
st.write("")
|
| 707 |
st.subheader("Retrieval settings")
|
| 708 |
st.caption(f"book_k={BOOK_K}, article_k={ARTICLE_K}, per_doc_cap={PER_DOC_CAP}, overlap_filter={OVERLAP_FILTER}")
|
| 709 |
-
st.
|
| 710 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 711 |
@st.cache_data(show_spinner=False)
|
| 712 |
def load_dataset(path: str) -> List[Chunk]:
|
| 713 |
return read_chunks_jsonl(path)
|
|
@@ -844,6 +918,16 @@ def run_regen():
|
|
| 844 |
"Generate exactly 3 concise user questions about MCP and AI agents orchestration. "
|
| 845 |
"Return each question on its own line without extra text."
|
| 846 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 847 |
text, err = llm_chat(gen_prompt)
|
| 848 |
if err:
|
| 849 |
st.error(err)
|
|
|
|
| 42 |
# Ollama backend configuration (local inference server). Trailing slash is
# stripped so URL joins elsewhere don't produce "//".
OLLAMA_BASE_URL = os.environ.get("RAG_OLLAMA_URL", "http://localhost:11434").rstrip("/")
OLLAMA_MODEL = os.environ.get("RAG_OLLAMA_MODEL", "llama3.2:1b")

# Token-aware limits (all overridable via environment):
#   MAX_CONTEXT_TOKENS    - estimated-token budget for retrieved context
#   MAX_CHUNKS            - hard cap on retrieved chunks per answer
#   MAX_GENERATION_TOKENS - max tokens the LLM may generate per request
# NOTE: unified on os.environ.get (os.getenv is the same function) for
# consistency with the Ollama settings above.
MAX_CONTEXT_TOKENS = int(os.environ.get("RAG_MAX_CONTEXT_TOKENS", "6000"))
MAX_CHUNKS = int(os.environ.get("RAG_MAX_CHUNKS", "6"))
MAX_GENERATION_TOKENS = int(os.environ.get("RAG_MAX_GENERATION_TOKENS", "512"))

REPO_OWNER = "16bitSega"
REPO_NAME = "RAG_project"
|
| 51 |
|
|
|
|
| 138 |
s = re.sub(r"\s+", " ", s).strip()
|
| 139 |
return s
|
| 140 |
|
| 141 |
+
def estimate_tokens(text: str) -> int:
    """Cheap token-count heuristic: roughly one token per four characters.

    Empty or falsy input counts as zero tokens; any non-empty text counts
    as at least one token so short strings are never budgeted as free.
    """
    return max(1, len(text) // 4) if text else 0
|
| 145 |
+
|
| 146 |
def is_company_question(q: str) -> bool:
|
| 147 |
q = (q or "").lower()
|
| 148 |
patterns = [
|
|
|
|
| 404 |
parts.append("ARTICLE EXCERPTS:\n" + "\n\n".join(article_parts))
|
| 405 |
return "\n\n".join(parts)
|
| 406 |
|
| 407 |
+
def build_limited_context(
    hits: List[Tuple[float, Chunk]],
    doc_index: Dict[str, Dict],
    tags: Dict[str, str],
    max_chars_per_chunk: int = 1400,
) -> Tuple[str, Dict[str, int]]:
    """Assemble a token-budgeted context string from ranked retrieval hits.

    Walks *hits* in rank order, truncating each chunk body to
    ``max_chars_per_chunk`` characters, and stops as soon as either
    ``MAX_CHUNKS`` chunks have been included or adding the next block would
    push the estimated token total past ``MAX_CONTEXT_TOKENS``. A section
    header ("ARTICLE EXCERPTS:" or "BOOK EXCERPTS:") is emitted once, right
    before the first chunk of that source type; later chunks of the same
    type are appended without repeating it.

    Returns a tuple of (joined context string, stats dict with keys
    ``context_tokens``, ``used_chunks``, ``max_chunks``,
    ``max_context_tokens``).
    """
    pieces: List[str] = []
    token_total = 0
    chunk_count = 0
    emitted_headers = set()

    for _, chunk in hits:
        if chunk_count >= MAX_CHUNKS:
            break

        body = normalize_display_text(chunk.text)
        if len(body) > max_chars_per_chunk:
            body = body[:max_chars_per_chunk] + "..."
        excerpt = f"{chunk_heading(chunk, doc_index, tags)}\n{body}"

        kind = infer_source_type(chunk.doc_id, doc_index.get(chunk.doc_id))
        header = "ARTICLE EXCERPTS:" if kind == "article" else "BOOK EXCERPTS:"
        pending_header = header if header not in emitted_headers else ""

        # The token budget check must cover the header (when one is about
        # to be emitted) as well as the excerpt itself.
        candidate = f"{pending_header}\n{excerpt}" if pending_header else excerpt
        cost = estimate_tokens(candidate)
        if token_total + cost > MAX_CONTEXT_TOKENS:
            break

        if pending_header:
            pieces.append(pending_header)
            emitted_headers.add(header)
        pieces.append(excerpt)
        token_total += cost
        chunk_count += 1

    stats = {
        "context_tokens": token_total,
        "used_chunks": chunk_count,
        "max_chunks": MAX_CHUNKS,
        "max_context_tokens": MAX_CONTEXT_TOKENS,
    }
    return "\n\n".join(pieces), stats
|
| 449 |
+
|
| 450 |
def chunk_keyword_overlap(chunk: Chunk, terms: List[str]) -> int:
|
| 451 |
if not terms:
|
| 452 |
return 0
|
|
|
|
| 520 |
if not all_hits or not_found_by_terms(question, all_hits):
|
| 521 |
return "Not found in dataset.", citations, False
|
| 522 |
|
| 523 |
+
context, ctx_stats = build_limited_context(all_hits, doc_index, citation_tags)
|
| 524 |
avoid_text = "; ".join(AVOID_PHRASES)
|
| 525 |
base_rules = (
|
| 526 |
"You must answer using only the provided context.\n"
|
|
|
|
| 550 |
+ format_rules
|
| 551 |
+ f"\nQuestion:\n{question}\n\nContext:\n{context}\n\nAnswer:"
|
| 552 |
)
|
| 553 |
+
prompt_tokens = estimate_tokens(prompt)
|
| 554 |
+
total_est = ctx_stats["context_tokens"] + prompt_tokens + MAX_GENERATION_TOKENS
|
| 555 |
+
st.session_state["token_stats"] = {
|
| 556 |
+
"context_tokens": ctx_stats["context_tokens"],
|
| 557 |
+
"prompt_tokens": prompt_tokens,
|
| 558 |
+
"generation_tokens": MAX_GENERATION_TOKENS,
|
| 559 |
+
"total_tokens": total_est,
|
| 560 |
+
"chunks_used": ctx_stats["used_chunks"],
|
| 561 |
+
"chunks_cap": MAX_CHUNKS,
|
| 562 |
+
"context_cap": MAX_CONTEXT_TOKENS,
|
| 563 |
+
}
|
| 564 |
answer, err = llm_chat(prompt)
|
| 565 |
if err:
|
| 566 |
st.error(err)
|
|
|
|
| 603 |
resp = client.chat.completions.create(
|
| 604 |
model=HF_MODEL,
|
| 605 |
messages=messages,
|
| 606 |
+
max_tokens=MAX_GENERATION_TOKENS,
|
| 607 |
temperature=0.2,
|
| 608 |
)
|
| 609 |
text = (resp.choices[0].message.content or "").strip()
|
|
|
|
| 612 |
try:
|
| 613 |
out = client.text_generation(
|
| 614 |
prompt,
|
| 615 |
+
max_new_tokens=MAX_GENERATION_TOKENS,
|
| 616 |
temperature=0.2,
|
| 617 |
do_sample=True,
|
| 618 |
return_full_text=False,
|
|
|
|
| 634 |
)
|
| 635 |
out = retry_client.text_generation(
|
| 636 |
prompt,
|
| 637 |
+
max_new_tokens=MAX_GENERATION_TOKENS,
|
| 638 |
temperature=0.2,
|
| 639 |
do_sample=True,
|
| 640 |
return_full_text=False,
|
|
|
|
| 658 |
{"role": "user", "content": prompt},
|
| 659 |
],
|
| 660 |
"stream": False,
|
| 661 |
+
"options": {"temperature": 0.2, "num_predict": MAX_GENERATION_TOKENS},
|
| 662 |
}
|
| 663 |
try:
|
| 664 |
r = requests.post(url, json=payload, timeout=timeout)
|
|
|
|
| 769 |
st.write("")
|
| 770 |
st.subheader("Retrieval settings")
|
| 771 |
st.caption(f"book_k={BOOK_K}, article_k={ARTICLE_K}, per_doc_cap={PER_DOC_CAP}, overlap_filter={OVERLAP_FILTER}")
|
| 772 |
+
st.markdown("### Dataset Stats")
|
| 773 |
+
ts = st.session_state.get("token_stats")
|
| 774 |
+
if ts:
|
| 775 |
+
st.markdown("**Token Consumption (est.)**")
|
| 776 |
+
st.markdown(f"- Context tokens: `{ts['context_tokens']}` / `{ts['context_cap']}`")
|
| 777 |
+
st.markdown(f"- Chunks used: `{ts['chunks_used']}` / `{ts['chunks_cap']}`")
|
| 778 |
+
st.markdown(f"- Prompt tokens: `{ts['prompt_tokens']}`")
|
| 779 |
+
st.markdown(f"- Generation tokens (max): `{ts['generation_tokens']}`")
|
| 780 |
+
st.markdown(f"- **Total per request (est.):** `{ts['total_tokens']}`")
|
| 781 |
+
if ts["context_tokens"] >= int(0.9 * ts["context_cap"]):
|
| 782 |
+
st.warning("Context near token limit; answers may truncate.")
|
| 783 |
+
else:
|
| 784 |
+
st.markdown("_Ask a question to see token usage._")
|
| 785 |
@st.cache_data(show_spinner=False)
def load_dataset(path: str) -> List[Chunk]:
    """Read the chunk dataset from the JSONL file at *path* (cached by Streamlit)."""
    chunks = read_chunks_jsonl(path)
    return chunks
|
|
|
|
| 918 |
"Generate exactly 3 concise user questions about MCP and AI agents orchestration. "
|
| 919 |
"Return each question on its own line without extra text."
|
| 920 |
)
|
| 921 |
+
prompt_tokens = estimate_tokens(gen_prompt)
|
| 922 |
+
st.session_state["token_stats"] = {
|
| 923 |
+
"context_tokens": 0,
|
| 924 |
+
"prompt_tokens": prompt_tokens,
|
| 925 |
+
"generation_tokens": MAX_GENERATION_TOKENS,
|
| 926 |
+
"total_tokens": prompt_tokens + MAX_GENERATION_TOKENS,
|
| 927 |
+
"chunks_used": 0,
|
| 928 |
+
"chunks_cap": MAX_CHUNKS,
|
| 929 |
+
"context_cap": MAX_CONTEXT_TOKENS,
|
| 930 |
+
}
|
| 931 |
text, err = llm_chat(gen_prompt)
|
| 932 |
if err:
|
| 933 |
st.error(err)
|