Oleksii Obolonskyi committed on
Commit
2c5e1f2
·
1 Parent(s): 44720da

Add token-aware context limits

Browse files
Files changed (2) hide show
  1. README.md +3 -0
  2. app.py +91 -7
README.md CHANGED
@@ -107,6 +107,9 @@ export RAG_HF_PROVIDER=hf-inference
107
  export RAG_HF_MODEL=Qwen/Qwen2.5-7B-Instruct-1M
108
  export RAG_LLM_BACKEND=hf
109
  export RAG_HF_API_URL=https://router.huggingface.co/hf-inference/models/Qwen/Qwen2.5-7B-Instruct-1M
 
 
 
110
  export RAG_OUT_DIR=data/normalized
111
  export RAG_ARTICLE_SOURCES=sources_articles.json
112
  ```
 
107
  export RAG_HF_MODEL=Qwen/Qwen2.5-7B-Instruct-1M
108
  export RAG_LLM_BACKEND=hf
109
  export RAG_HF_API_URL=https://router.huggingface.co/hf-inference/models/Qwen/Qwen2.5-7B-Instruct-1M
110
+ export RAG_MAX_CONTEXT_TOKENS=6000
111
+ export RAG_MAX_CHUNKS=6
112
+ export RAG_MAX_GENERATION_TOKENS=512
113
  export RAG_OUT_DIR=data/normalized
114
  export RAG_ARTICLE_SOURCES=sources_articles.json
115
  ```
app.py CHANGED
@@ -42,6 +42,10 @@ if not HF_API_URL:
42
  OLLAMA_BASE_URL = os.environ.get("RAG_OLLAMA_URL", "http://localhost:11434").rstrip("/")
43
  OLLAMA_MODEL = os.environ.get("RAG_OLLAMA_MODEL", "llama3.2:1b")
44
 
 
 
 
 
45
  REPO_OWNER = "16bitSega"
46
  REPO_NAME = "RAG_project"
47
 
@@ -134,6 +138,11 @@ def normalize_display_text(s: str) -> str:
134
  s = re.sub(r"\s+", " ", s).strip()
135
  return s
136
 
 
 
 
 
 
137
  def is_company_question(q: str) -> bool:
138
  q = (q or "").lower()
139
  patterns = [
@@ -395,6 +404,49 @@ def build_context(
395
  parts.append("ARTICLE EXCERPTS:\n" + "\n\n".join(article_parts))
396
  return "\n\n".join(parts)
397
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
  def chunk_keyword_overlap(chunk: Chunk, terms: List[str]) -> int:
399
  if not terms:
400
  return 0
@@ -468,7 +520,7 @@ def answer_question(
468
  if not all_hits or not_found_by_terms(question, all_hits):
469
  return "Not found in dataset.", citations, False
470
 
471
- context = build_context(book_hits, article_hits, doc_index, citation_tags)
472
  avoid_text = "; ".join(AVOID_PHRASES)
473
  base_rules = (
474
  "You must answer using only the provided context.\n"
@@ -498,6 +550,17 @@ def answer_question(
498
  + format_rules
499
  + f"\nQuestion:\n{question}\n\nContext:\n{context}\n\nAnswer:"
500
  )
 
 
 
 
 
 
 
 
 
 
 
501
  answer, err = llm_chat(prompt)
502
  if err:
503
  st.error(err)
@@ -540,7 +603,7 @@ def hf_chat(prompt: str, timeout: Tuple[int, int] = (10, 600)) -> Tuple[str, Opt
540
  resp = client.chat.completions.create(
541
  model=HF_MODEL,
542
  messages=messages,
543
- max_tokens=512,
544
  temperature=0.2,
545
  )
546
  text = (resp.choices[0].message.content or "").strip()
@@ -549,7 +612,7 @@ def hf_chat(prompt: str, timeout: Tuple[int, int] = (10, 600)) -> Tuple[str, Opt
549
  try:
550
  out = client.text_generation(
551
  prompt,
552
- max_new_tokens=512,
553
  temperature=0.2,
554
  do_sample=True,
555
  return_full_text=False,
@@ -571,7 +634,7 @@ def hf_chat(prompt: str, timeout: Tuple[int, int] = (10, 600)) -> Tuple[str, Opt
571
  )
572
  out = retry_client.text_generation(
573
  prompt,
574
- max_new_tokens=512,
575
  temperature=0.2,
576
  do_sample=True,
577
  return_full_text=False,
@@ -595,7 +658,7 @@ def ollama_chat(prompt: str, timeout: Tuple[int, int] = (10, 600)) -> Tuple[str,
595
  {"role": "user", "content": prompt},
596
  ],
597
  "stream": False,
598
- "options": {"temperature": 0.2},
599
  }
600
  try:
601
  r = requests.post(url, json=payload, timeout=timeout)
@@ -706,8 +769,19 @@ with st.sidebar:
706
  st.write("")
707
  st.subheader("Retrieval settings")
708
  st.caption(f"book_k={BOOK_K}, article_k={ARTICLE_K}, per_doc_cap={PER_DOC_CAP}, overlap_filter={OVERLAP_FILTER}")
709
- st.subheader("Dataset stats")
710
- st.caption("Local dataset only")
 
 
 
 
 
 
 
 
 
 
 
711
  @st.cache_data(show_spinner=False)
712
  def load_dataset(path: str) -> List[Chunk]:
713
  return read_chunks_jsonl(path)
@@ -844,6 +918,16 @@ def run_regen():
844
  "Generate exactly 3 concise user questions about MCP and AI agents orchestration. "
845
  "Return each question on its own line without extra text."
846
  )
 
 
 
 
 
 
 
 
 
 
847
  text, err = llm_chat(gen_prompt)
848
  if err:
849
  st.error(err)
 
42
  OLLAMA_BASE_URL = os.environ.get("RAG_OLLAMA_URL", "http://localhost:11434").rstrip("/")
43
  OLLAMA_MODEL = os.environ.get("RAG_OLLAMA_MODEL", "llama3.2:1b")
44
 
45
# Token-budget knobs, overridable via environment variables.
MAX_CONTEXT_TOKENS = int(os.environ.get("RAG_MAX_CONTEXT_TOKENS", "6000"))  # cap on estimated context tokens
MAX_CHUNKS = int(os.environ.get("RAG_MAX_CHUNKS", "6"))  # cap on retrieved chunks per answer
MAX_GENERATION_TOKENS = int(os.environ.get("RAG_MAX_GENERATION_TOKENS", "512"))  # max LLM output tokens
48
+
49
  REPO_OWNER = "16bitSega"
50
  REPO_NAME = "RAG_project"
51
 
 
138
  s = re.sub(r"\s+", " ", s).strip()
139
  return s
140
 
141
def estimate_tokens(text: str) -> int:
    """Roughly estimate the LLM token count of *text*.

    Uses the common ~4-characters-per-token heuristic; any non-empty
    string counts as at least one token, and falsy input counts as zero.
    """
    length = len(text) if text else 0
    return 0 if length == 0 else max(1, length // 4)
145
+
146
  def is_company_question(q: str) -> bool:
147
  q = (q or "").lower()
148
  patterns = [
 
404
  parts.append("ARTICLE EXCERPTS:\n" + "\n\n".join(article_parts))
405
  return "\n\n".join(parts)
406
 
407
def build_limited_context(
    hits: List[Tuple[float, Chunk]],
    doc_index: Dict[str, Dict],
    tags: Dict[str, str],
    max_chars_per_chunk: int = 1400,
) -> Tuple[str, Dict[str, int]]:
    """Assemble a token-budgeted context string from ranked retrieval hits.

    Walks *hits* in order, truncating each chunk's display text to
    *max_chars_per_chunk* characters (appending "..." when cut) and
    prefixing the first chunk of each source type with its section header
    ("ARTICLE EXCERPTS:" / "BOOK EXCERPTS:").  Stops once MAX_CHUNKS
    chunks have been used, or when adding the next chunk (header
    included) would push the estimate_tokens total past
    MAX_CONTEXT_TOKENS.

    Returns the "\\n\\n"-joined context plus a stats dict with keys
    context_tokens, used_chunks, max_chunks, max_context_tokens.
    """
    pieces: List[str] = []
    token_total = 0
    chunk_count = 0
    emitted_headers = set()

    for _, chunk in hits:
        if chunk_count >= MAX_CHUNKS:
            break

        body = normalize_display_text(chunk.text)
        if len(body) > max_chars_per_chunk:
            body = body[:max_chars_per_chunk] + "..."
        entry = f"{chunk_heading(chunk, doc_index, tags)}\n{body}"

        kind = infer_source_type(chunk.doc_id, doc_index.get(chunk.doc_id))
        header = "ARTICLE EXCERPTS:" if kind == "article" else "BOOK EXCERPTS:"
        needs_header = header not in emitted_headers

        # Cost the chunk together with its header so the budget check
        # reflects exactly what would be appended.
        candidate = f"{header}\n{entry}" if needs_header else entry
        cost = estimate_tokens(candidate)
        if token_total + cost > MAX_CONTEXT_TOKENS:
            break

        if needs_header:
            pieces.append(header)
        emitted_headers.add(header)
        pieces.append(entry)
        token_total += cost
        chunk_count += 1

    stats = {
        "context_tokens": token_total,
        "used_chunks": chunk_count,
        "max_chunks": MAX_CHUNKS,
        "max_context_tokens": MAX_CONTEXT_TOKENS,
    }
    return "\n\n".join(pieces), stats
449
+
450
  def chunk_keyword_overlap(chunk: Chunk, terms: List[str]) -> int:
451
  if not terms:
452
  return 0
 
520
  if not all_hits or not_found_by_terms(question, all_hits):
521
  return "Not found in dataset.", citations, False
522
 
523
+ context, ctx_stats = build_limited_context(all_hits, doc_index, citation_tags)
524
  avoid_text = "; ".join(AVOID_PHRASES)
525
  base_rules = (
526
  "You must answer using only the provided context.\n"
 
550
  + format_rules
551
  + f"\nQuestion:\n{question}\n\nContext:\n{context}\n\nAnswer:"
552
  )
553
+ prompt_tokens = estimate_tokens(prompt)
554
+ total_est = ctx_stats["context_tokens"] + prompt_tokens + MAX_GENERATION_TOKENS
555
+ st.session_state["token_stats"] = {
556
+ "context_tokens": ctx_stats["context_tokens"],
557
+ "prompt_tokens": prompt_tokens,
558
+ "generation_tokens": MAX_GENERATION_TOKENS,
559
+ "total_tokens": total_est,
560
+ "chunks_used": ctx_stats["used_chunks"],
561
+ "chunks_cap": MAX_CHUNKS,
562
+ "context_cap": MAX_CONTEXT_TOKENS,
563
+ }
564
  answer, err = llm_chat(prompt)
565
  if err:
566
  st.error(err)
 
603
  resp = client.chat.completions.create(
604
  model=HF_MODEL,
605
  messages=messages,
606
+ max_tokens=MAX_GENERATION_TOKENS,
607
  temperature=0.2,
608
  )
609
  text = (resp.choices[0].message.content or "").strip()
 
612
  try:
613
  out = client.text_generation(
614
  prompt,
615
+ max_new_tokens=MAX_GENERATION_TOKENS,
616
  temperature=0.2,
617
  do_sample=True,
618
  return_full_text=False,
 
634
  )
635
  out = retry_client.text_generation(
636
  prompt,
637
+ max_new_tokens=MAX_GENERATION_TOKENS,
638
  temperature=0.2,
639
  do_sample=True,
640
  return_full_text=False,
 
658
  {"role": "user", "content": prompt},
659
  ],
660
  "stream": False,
661
+ "options": {"temperature": 0.2, "num_predict": MAX_GENERATION_TOKENS},
662
  }
663
  try:
664
  r = requests.post(url, json=payload, timeout=timeout)
 
769
  st.write("")
770
  st.subheader("Retrieval settings")
771
  st.caption(f"book_k={BOOK_K}, article_k={ARTICLE_K}, per_doc_cap={PER_DOC_CAP}, overlap_filter={OVERLAP_FILTER}")
772
+ st.markdown("### Dataset Stats")
773
+ ts = st.session_state.get("token_stats")
774
+ if ts:
775
+ st.markdown("**Token Consumption (est.)**")
776
+ st.markdown(f"- Context tokens: `{ts['context_tokens']}` / `{ts['context_cap']}`")
777
+ st.markdown(f"- Chunks used: `{ts['chunks_used']}` / `{ts['chunks_cap']}`")
778
+ st.markdown(f"- Prompt tokens: `{ts['prompt_tokens']}`")
779
+ st.markdown(f"- Generation tokens (max): `{ts['generation_tokens']}`")
780
+ st.markdown(f"- **Total per request (est.):** `{ts['total_tokens']}`")
781
+ if ts["context_tokens"] >= int(0.9 * ts["context_cap"]):
782
+ st.warning("Context near token limit; answers may truncate.")
783
+ else:
784
+ st.markdown("_Ask a question to see token usage._")
785
  @st.cache_data(show_spinner=False)
786
  def load_dataset(path: str) -> List[Chunk]:
787
  return read_chunks_jsonl(path)
 
918
  "Generate exactly 3 concise user questions about MCP and AI agents orchestration. "
919
  "Return each question on its own line without extra text."
920
  )
921
+ prompt_tokens = estimate_tokens(gen_prompt)
922
+ st.session_state["token_stats"] = {
923
+ "context_tokens": 0,
924
+ "prompt_tokens": prompt_tokens,
925
+ "generation_tokens": MAX_GENERATION_TOKENS,
926
+ "total_tokens": prompt_tokens + MAX_GENERATION_TOKENS,
927
+ "chunks_used": 0,
928
+ "chunks_cap": MAX_CHUNKS,
929
+ "context_cap": MAX_CONTEXT_TOKENS,
930
+ }
931
  text, err = llm_chat(gen_prompt)
932
  if err:
933
  st.error(err)