hungnha commited on
Commit
d43db89
·
1 Parent(s): 39f858f

sửa truy vấn bảng

Browse files
core/embeddings/chunk.py CHANGED
@@ -1,8 +1,11 @@
1
  from __future__ import annotations
 
2
  import re
 
3
  from pathlib import Path
4
- from typing import List, Tuple, Dict, Any
5
  import yaml
 
6
  from llama_index.core import Document
7
  from llama_index.core.node_parser import MarkdownNodeParser, SentenceSplitter
8
  from llama_index.core.schema import BaseNode, TextNode
@@ -13,6 +16,12 @@ CHUNK_OVERLAP = 150
13
  MIN_CHUNK_SIZE = 200
14
  TABLE_ROWS_PER_CHUNK = 15
15
 
 
 
 
 
 
 
16
  # Regex
17
  COURSE_PATTERN = re.compile(r"Học\s*phần\s+(.+?)\s*\(\s*m[ãa]\s+([^\)]+)\)", re.I | re.DOTALL)
18
  TABLE_PLACEHOLDER = re.compile(r"__TBL_(\d+)__")
@@ -94,6 +103,101 @@ def _split_table(header: str, rows: List[str], max_rows: int = TABLE_ROWS_PER_CH
94
 
95
  return [header + "\n".join(r) for r in chunks]
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  def _enrich_metadata(node: BaseNode, source_path: Path | None) -> None:
99
  if source_path:
@@ -158,13 +262,23 @@ def chunk_markdown(text: str, source_path: str | Path | None = None) -> List[Bas
158
  if (before := content[last_end:match.start()].strip()) and len(before) >= MIN_CHUNK_SIZE:
159
  nodes.extend(_chunk_text(before, meta) if len(before) > CHUNK_SIZE else [TextNode(text=before, metadata=meta.copy())])
160
 
161
- # Table chunks
162
  if (idx := int(match.group(1))) < len(tables):
163
  header, rows = tables[idx]
164
- chunks = _split_table(header, rows)
165
- for i, chunk in enumerate(chunks):
166
- tbl_meta = {**meta, "is_table": True, **({"table_part": f"{i+1}/{len(chunks)}"} if len(chunks) > 1 else {})}
167
- nodes.append(TextNode(text=chunk, metadata=tbl_meta))
 
 
 
 
 
 
 
 
 
 
168
  last_end = match.end()
169
 
170
  # Text after table
 
1
  from __future__ import annotations
2
+ import os
3
  import re
4
+ import uuid
5
  from pathlib import Path
6
+ from typing import List, Tuple, Dict, Any, Optional
7
  import yaml
8
+ from openai import OpenAI
9
  from llama_index.core import Document
10
  from llama_index.core.node_parser import MarkdownNodeParser, SentenceSplitter
11
  from llama_index.core.schema import BaseNode, TextNode
 
16
  MIN_CHUNK_SIZE = 200
17
  TABLE_ROWS_PER_CHUNK = 15
18
 
19
# ---- Small-to-Big configuration ----
# When enabled, large tables are replaced by an embedded LLM summary
# ("small") that links back to the raw table ("big") at retrieval time.
ENABLE_TABLE_SUMMARY = True
MIN_TABLE_ROWS_FOR_SUMMARY = 5  # only summarize tables with at least this many rows
SUMMARY_MODEL = "nex-agi/DeepSeek-V3.1-Nex-N1"
SILICONFLOW_BASE_URL = "https://api.siliconflow.com/v1"
25
  # Regex
26
  COURSE_PATTERN = re.compile(r"Học\s*phần\s+(.+?)\s*\(\s*m[ãa]\s+([^\)]+)\)", re.I | re.DOTALL)
27
  TABLE_PLACEHOLDER = re.compile(r"__TBL_(\d+)__")
 
103
 
104
  return [header + "\n".join(r) for r in chunks]
105
 
106
+ _summary_client: Optional[OpenAI] = None
107
+
108
+ def _get_summary_client() -> Optional[OpenAI]:
109
+ global _summary_client
110
+ if _summary_client is not None:
111
+ return _summary_client
112
+
113
+ api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
114
+ if not api_key:
115
+ print("SILICONFLOW_API_KEY not set. Table summarization disabled.")
116
+ return None
117
+
118
+ _summary_client = OpenAI(api_key=api_key, base_url=SILICONFLOW_BASE_URL)
119
+ return _summary_client
120
+
121
+
122
def _summarize_table(table_text: str, context_hint: str = "") -> Optional[str]:
    """Ask the LLM for a short Vietnamese summary of a markdown table.

    The table text is truncated to 3000 characters before being sent.
    Returns the stripped summary, or ``None`` when summarization is
    disabled, no client is available, the API call fails, or the model
    returns an empty answer.
    """
    if not ENABLE_TABLE_SUMMARY:
        return None

    client = _get_summary_client()
    if client is None:
        return None

    # Optional section/heading context helps the model ground the summary.
    context_line = f"Ngữ cảnh: {context_hint}" if context_hint else ""
    prompt = f"""Tóm tắt ngắn gọn nội dung bảng sau trong 2-4 câu bằng tiếng Việt.
Ghi rõ:
- Bảng này liệt kê/quy định về cái gì
- Các cột chính trong bảng
- Thông tin quan trọng (nếu có số liệu cụ thể thì nêu ví dụ)

{context_line}

Bảng:
{table_text[:3000]}
"""

    try:
        response = client.chat.completions.create(
            model=SUMMARY_MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3,
            max_tokens=1000,
        )
        text = (response.choices[0].message.content or "").strip()
        return text or None
    except Exception as e:
        # Best-effort feature: degrade to "no summary" instead of failing the pipeline.
        print(f" Table summarization failed: {e}")
        return None
154
+
155
+
156
def _create_table_nodes(
    table_text: str,
    metadata: dict,
    context_hint: str = ""
) -> List[TextNode]:
    """Build retrieval nodes for one table chunk (Small-to-Big pattern).

    Small tables are returned as a single directly-embeddable node.
    Large tables produce a parent node holding the raw table (flagged
    ``is_parent`` so the vector store skips embedding it) plus a summary
    node that is embedded for search and links back to the parent via
    ``parent_id``.

    Args:
        table_text: Markdown table text (header row, separator row, data rows).
        metadata: Base metadata copied onto every node produced.
        context_hint: Optional section/heading text forwarded to the
            summarizer for better grounding.

    Returns:
        ``[table_node]`` for small tables or when summarization fails,
        otherwise ``[parent_node, summary_node]``.
    """
    # Count *data* rows: exclude the header and separator lines so the
    # MIN_TABLE_ROWS_FOR_SUMMARY threshold matches its documented meaning
    # ("only summarize tables with >= 5 rows"). The previous
    # table_text.count("\n") over-counted by the two header lines.
    # NOTE(review): assumes _split_table chunks start with header + separator
    # rows — confirm against _split_table's output format.
    data_rows = max(0, len(table_text.strip().splitlines()) - 2)

    if data_rows < MIN_TABLE_ROWS_FOR_SUMMARY:
        # Table too small to justify an LLM call; embed it directly.
        return [TextNode(text=table_text, metadata={**metadata, "is_table": True})]

    summary = _summarize_table(table_text, context_hint)

    if summary is None:
        # Summarization disabled or failed: fall back to the raw table.
        return [TextNode(text=table_text, metadata={**metadata, "is_table": True})]

    # Parent node: raw table, skipped at embedding time (is_parent flag).
    parent_id = str(uuid.uuid4())
    parent_node = TextNode(
        text=table_text,
        metadata={
            **metadata,
            "is_table": True,
            "is_parent": True,  # flag checked by the vector store to skip embedding
            "node_id": parent_id,
        }
    )
    parent_node.id_ = parent_id

    # Summary node: embedded for search, linked to the parent for swap-back
    # at retrieval time.
    summary_node = TextNode(
        text=summary,
        metadata={
            **metadata,
            "is_table_summary": True,
            "parent_id": parent_id,  # link to parent
        }
    )

    print(f"Created summary for table ({data_rows} rows)")
    return [parent_node, summary_node]
200
+
201
 
202
  def _enrich_metadata(node: BaseNode, source_path: Path | None) -> None:
203
  if source_path:
 
262
  if (before := content[last_end:match.start()].strip()) and len(before) >= MIN_CHUNK_SIZE:
263
  nodes.extend(_chunk_text(before, meta) if len(before) > CHUNK_SIZE else [TextNode(text=before, metadata=meta.copy())])
264
 
265
+ # Table chunks - using Small-to-Big pattern
266
  if (idx := int(match.group(1))) < len(tables):
267
  header, rows = tables[idx]
268
+ table_chunks = _split_table(header, rows)
269
+
270
+ # Get context hint from header path
271
+ context_hint = meta.get("Header 1", "") or meta.get("section", "")
272
+
273
+ for i, chunk in enumerate(table_chunks):
274
+ chunk_meta = {**meta}
275
+ if len(table_chunks) > 1:
276
+ chunk_meta["table_part"] = f"{i+1}/{len(table_chunks)}"
277
+
278
+ # Create parent + summary nodes if applicable
279
+ table_nodes = _create_table_nodes(chunk, chunk_meta, context_hint)
280
+ nodes.extend(table_nodes)
281
+
282
  last_end = match.end()
283
 
284
  # Text after table
core/embeddings/retrival.py CHANGED
@@ -174,10 +174,25 @@ class Retriever:
174
  return self._reranker is not None
175
 
176
  def _to_result(self, doc: Document, rank: int, **extra) -> Dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  return {
178
- "id": (doc.metadata or {}).get("id"),
179
- "content": doc.page_content,
180
- "metadata": doc.metadata,
181
  "final_rank": rank,
182
  **extra,
183
  }
 
174
  return self._reranker is not None
175
 
176
  def _to_result(self, doc: Document, rank: int, **extra) -> Dict[str, Any]:
177
+ metadata = doc.metadata or {}
178
+ content = doc.page_content
179
+
180
+ # Small-to-Big: If this is a summary node, swap with parent (raw table)
181
+ if metadata.get("is_table_summary") and metadata.get("parent_id"):
182
+ parent = self._vector_db.get_parent_node(metadata["parent_id"])
183
+ if parent:
184
+ content = parent.get("content", content)
185
+ # Merge metadata, keeping summary info for debugging
186
+ metadata = {
187
+ **parent.get("metadata", {}),
188
+ "original_summary": doc.page_content[:200],
189
+ "swapped_from_summary": True,
190
+ }
191
+
192
  return {
193
+ "id": metadata.get("id"),
194
+ "content": content,
195
+ "metadata": metadata,
196
  "final_rank": rank,
197
  **extra,
198
  }
core/embeddings/vector_store.py CHANGED
@@ -30,6 +30,11 @@ class ChromaVectorDB:
30
  self.embedder = embedder
31
  self.config = config or ChromaConfig()
32
  self._hasher = HashProcessor(verbose=False)
 
 
 
 
 
33
 
34
  self._vs = Chroma(
35
  collection_name=self.config.collection_name,
@@ -37,6 +42,28 @@ class ChromaVectorDB:
37
  persist_directory=self.config.persist_dir,
38
  )
39
  logger.info(f"ChromaVectorDB initialized: {self.config.collection_name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  @property
42
  def collection(self):
@@ -113,13 +140,42 @@ class ChromaVectorDB:
113
  if ids is not None and len(ids) != len(docs):
114
  raise ValueError("ids length must match docs length")
115
 
116
- all_ids = list(ids) if ids is not None else [self._doc_id(d) for d in docs]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  bs = max(1, batch_size)
118
  total = 0
119
 
120
- for start in range(0, len(docs), bs):
121
- batch = docs[start : start + bs]
122
- batch_ids = all_ids[start : start + bs]
123
  lc_docs = self._to_documents(batch, batch_ids)
124
 
125
  try:
@@ -131,7 +187,7 @@ class ChromaVectorDB:
131
  total += len(batch)
132
 
133
  logger.info(f"Added {total} documents to vector store")
134
- return total
135
 
136
  def upsert_documents(
137
  self,
@@ -146,17 +202,46 @@ class ChromaVectorDB:
146
  if ids is not None and len(ids) != len(docs):
147
  raise ValueError("ids length must match docs length")
148
 
149
- all_ids = list(ids) if ids is not None else [self._doc_id(d) for d in docs]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  bs = max(1, batch_size)
151
  col = self.collection
152
 
153
  if col is None:
154
- return self.add_documents(docs, ids=all_ids, batch_size=bs)
155
 
156
  total = 0
157
- for start in range(0, len(docs), bs):
158
- batch = docs[start : start + bs]
159
- batch_ids = all_ids[start : start + bs]
160
  lc_docs = self._to_documents(batch, batch_ids)
161
  texts = [d.page_content for d in lc_docs]
162
  metas = [d.metadata for d in lc_docs]
@@ -165,7 +250,7 @@ class ChromaVectorDB:
165
  total += len(batch)
166
 
167
  logger.info(f"Upserted {total} documents to vector store")
168
- return total
169
 
170
  def count(self) -> int:
171
  col = self.collection
@@ -198,3 +283,10 @@ class ChromaVectorDB:
198
  col.delete(ids=list(ids))
199
  logger.info(f"Deleted {len(ids)} documents from vector store")
200
  return len(ids)
 
 
 
 
 
 
 
 
30
  self.embedder = embedder
31
  self.config = config or ChromaConfig()
32
  self._hasher = HashProcessor(verbose=False)
33
+
34
+ # Storage for parent nodes (not embedded, used for Small-to-Big retrieval)
35
+ # Persist to JSON file in same directory as ChromaDB
36
+ self._parent_nodes_path = Path(self.config.persist_dir) / "parent_nodes.json"
37
+ self._parent_nodes: Dict[str, Dict[str, Any]] = self._load_parent_nodes()
38
 
39
  self._vs = Chroma(
40
  collection_name=self.config.collection_name,
 
42
  persist_directory=self.config.persist_dir,
43
  )
44
  logger.info(f"ChromaVectorDB initialized: {self.config.collection_name}")
45
+
46
+ def _load_parent_nodes(self) -> Dict[str, Dict[str, Any]]:
47
+ """Load parent nodes from JSON file if exists."""
48
+ if self._parent_nodes_path.exists():
49
+ try:
50
+ with open(self._parent_nodes_path, 'r', encoding='utf-8') as f:
51
+ data = json.load(f)
52
+ logger.info(f"Loaded {len(data)} parent nodes from {self._parent_nodes_path}")
53
+ return data
54
+ except Exception as e:
55
+ logger.warning(f"Failed to load parent nodes: {e}")
56
+ return {}
57
+
58
def _save_parent_nodes(self) -> None:
    """Persist the in-memory parent-node map to its JSON sidecar file.

    Failures are logged as warnings rather than raised, so a broken
    persist directory never aborts an ingestion run.
    """
    path = self._parent_nodes_path
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, 'w', encoding='utf-8') as fh:
            json.dump(self._parent_nodes, fh, ensure_ascii=False, indent=2)
        logger.info(f"Saved {len(self._parent_nodes)} parent nodes to {path}")
    except Exception as e:
        logger.warning(f"Failed to save parent nodes: {e}")
67
 
68
  @property
69
  def collection(self):
 
140
  if ids is not None and len(ids) != len(docs):
141
  raise ValueError("ids length must match docs length")
142
 
143
+ # Separate parent nodes (not embedded) from regular nodes
144
+ regular_docs = []
145
+ regular_ids = []
146
+ parent_count = 0
147
+
148
+ for i, d in enumerate(docs):
149
+ normalized = self._normalize_doc(d)
150
+ md = normalized.get("metadata", {}) or {}
151
+ doc_id = ids[i] if ids else self._doc_id(d)
152
+
153
+ if md.get("is_parent"):
154
+ # Store parent node separately (for Small-to-Big retrieval)
155
+ parent_id = md.get("node_id", doc_id)
156
+ self._parent_nodes[parent_id] = {
157
+ "id": parent_id,
158
+ "content": normalized.get("content", ""),
159
+ "metadata": md,
160
+ }
161
+ parent_count += 1
162
+ else:
163
+ regular_docs.append(d)
164
+ regular_ids.append(doc_id)
165
+
166
+ if parent_count > 0:
167
+ logger.info(f"Stored {parent_count} parent nodes (not embedded)")
168
+ self._save_parent_nodes() # Persist to disk
169
+
170
+ if not regular_docs:
171
+ return parent_count
172
+
173
  bs = max(1, batch_size)
174
  total = 0
175
 
176
+ for start in range(0, len(regular_docs), bs):
177
+ batch = regular_docs[start : start + bs]
178
+ batch_ids = regular_ids[start : start + bs]
179
  lc_docs = self._to_documents(batch, batch_ids)
180
 
181
  try:
 
187
  total += len(batch)
188
 
189
  logger.info(f"Added {total} documents to vector store")
190
+ return total + parent_count
191
 
192
  def upsert_documents(
193
  self,
 
202
  if ids is not None and len(ids) != len(docs):
203
  raise ValueError("ids length must match docs length")
204
 
205
+ # Separate parent nodes (not embedded) from regular nodes
206
+ regular_docs = []
207
+ regular_ids = []
208
+ parent_count = 0
209
+
210
+ for i, d in enumerate(docs):
211
+ normalized = self._normalize_doc(d)
212
+ md = normalized.get("metadata", {}) or {}
213
+ doc_id = ids[i] if ids else self._doc_id(d)
214
+
215
+ if md.get("is_parent"):
216
+ # Store parent node separately (for Small-to-Big retrieval)
217
+ parent_id = md.get("node_id", doc_id)
218
+ self._parent_nodes[parent_id] = {
219
+ "id": parent_id,
220
+ "content": normalized.get("content", ""),
221
+ "metadata": md,
222
+ }
223
+ parent_count += 1
224
+ else:
225
+ regular_docs.append(d)
226
+ regular_ids.append(doc_id)
227
+
228
+ if parent_count > 0:
229
+ logger.info(f"Stored {parent_count} parent nodes (not embedded)")
230
+ self._save_parent_nodes() # Persist to disk
231
+
232
+ if not regular_docs:
233
+ return parent_count
234
+
235
  bs = max(1, batch_size)
236
  col = self.collection
237
 
238
  if col is None:
239
+ return self.add_documents(regular_docs, ids=regular_ids, batch_size=bs) + parent_count
240
 
241
  total = 0
242
+ for start in range(0, len(regular_docs), bs):
243
+ batch = regular_docs[start : start + bs]
244
+ batch_ids = regular_ids[start : start + bs]
245
  lc_docs = self._to_documents(batch, batch_ids)
246
  texts = [d.page_content for d in lc_docs]
247
  metas = [d.metadata for d in lc_docs]
 
250
  total += len(batch)
251
 
252
  logger.info(f"Upserted {total} documents to vector store")
253
+ return total + parent_count
254
 
255
  def count(self) -> int:
256
  col = self.collection
 
283
  col.delete(ids=list(ids))
284
  logger.info(f"Deleted {len(ids)} documents from vector store")
285
  return len(ids)
286
+
287
def get_parent_node(self, parent_id: str) -> Optional[Dict[str, Any]]:
    """Look up a stored parent node by id, or ``None`` when unknown."""
    return self._parent_nodes.get(parent_id)

@property
def parent_nodes(self) -> Dict[str, Dict[str, Any]]:
    """All parent nodes currently held in memory, keyed by node id."""
    return self._parent_nodes
scripts/test_single_file.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """Test Small-to-Big với 1 file duy nhất."""
3
+ import sys
4
+ from pathlib import Path
5
+ from dotenv import find_dotenv, load_dotenv
6
+
7
+ load_dotenv(find_dotenv(usecwd=True))
8
+
9
+ REPO_ROOT = Path(__file__).resolve().parents[1]
10
+ if str(REPO_ROOT) not in sys.path:
11
+ sys.path.insert(0, str(REPO_ROOT))
12
+
13
+ from core.embeddings.chunk import chunk_markdown_file
14
+ from core.embeddings.embedding_model import EmbeddingConfig, QwenEmbeddings
15
+ from core.embeddings.vector_store import ChromaConfig, ChromaVectorDB
16
+
17
+ # Test với 1 file chứa nhiều bảng
18
+ TEST_FILE = REPO_ROOT / "data/data_process/quyet_dinh/tieng_anh/06_ Quy định ngoại ngữ từ K70_chính quy_final.md"
19
+
20
+
21
+ def main():
22
+ print("=" * 60)
23
+ print("TEST SMALL-TO-BIG (1 file)")
24
+ print("=" * 60)
25
+
26
+ # 1. Chunk file
27
+ print(f"\n[1/4] Chunking: {TEST_FILE.name}")
28
+ nodes = chunk_markdown_file(TEST_FILE)
29
+
30
+ parent_nodes = [n for n in nodes if n.metadata.get("is_parent")]
31
+ summary_nodes = [n for n in nodes if n.metadata.get("is_table_summary")]
32
+ other_nodes = [n for n in nodes if not n.metadata.get("is_parent")]
33
+
34
+ print(f" Total nodes: {len(nodes)}")
35
+ print(f" - Parent nodes (NOT embedded): {len(parent_nodes)}")
36
+ print(f" - Summary nodes: {len(summary_nodes)}")
37
+ print(f" - Other nodes (text + small tables): {len(other_nodes) - len(summary_nodes)}")
38
+
39
+ # 2. Init DB (với persist_dir tạm)
40
+ print("\n[2/4] Initializing test DB...")
41
+ emb_cfg = EmbeddingConfig()
42
+ emb = QwenEmbeddings(emb_cfg)
43
+
44
+ # Dùng folder tạm để không ảnh hưởng DB chính
45
+ test_persist = str(REPO_ROOT / "data" / "chroma_test")
46
+ db_cfg = ChromaConfig(persist_dir=test_persist, collection_name="test_s2b")
47
+ db = ChromaVectorDB(embedder=emb, config=db_cfg)
48
+ print(f" Persist dir: {test_persist}")
49
+
50
+ # 3. Upsert
51
+ print("\n[3/4] Upserting documents...")
52
+ count = db.upsert_documents(nodes)
53
+ print(f" Upserted: {count}")
54
+ print(f" ChromaDB count: {db.count()}")
55
+ print(f" Parent nodes stored: {len(db.parent_nodes)}")
56
+
57
+ # Check file JSON
58
+ json_path = Path(test_persist) / "parent_nodes.json"
59
+ if json_path.exists():
60
+ print(f" ✅ parent_nodes.json exists ({json_path.stat().st_size} bytes)")
61
+ else:
62
+ print(f" ❌ parent_nodes.json NOT found!")
63
+
64
+ # 4. Test retrieval
65
+ print("\n[4/4] Testing retrieval...")
66
+ from core.embeddings.retrival import Retriever, RetrievalMode
67
+
68
+ retriever = Retriever(vector_db=db, use_reranker=False)
69
+
70
+ test_query = "TOEIC Nghe 350 điểm tương đương bậc mấy?"
71
+ print(f" Query: {test_query}")
72
+
73
+ results = retriever.vector_search(test_query, k=3)
74
+
75
+ for i, r in enumerate(results, 1):
76
+ meta = r.get("metadata", {})
77
+ content = r.get("content", "")[:200]
78
+
79
+ print(f"\n [{i}]")
80
+ print(f" is_table_summary: {meta.get('is_table_summary', False)}")
81
+ print(f" swapped_from_summary: {meta.get('swapped_from_summary', False)}")
82
+ print(f" source: {meta.get('source_file', 'N/A')}")
83
+ print(f" content: {content}...")
84
+
85
+ print("\n" + "=" * 60)
86
+ print("TEST COMPLETE")
87
+ print("=" * 60)
88
+
89
+ # Cleanup prompt
90
+ print(f"\nTo cleanup test data: rm -rf {test_persist}")
91
+
92
+
93
+ if __name__ == "__main__":
94
+ main()
test/chunk_results.md ADDED
The diff for this file is too large to render. See raw diff
 
test/test_small_to_big.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """Test Small-to-Big table summarization."""
3
+ from __future__ import annotations
4
+ import sys
5
+ from pathlib import Path
6
+ from datetime import datetime
7
+
8
+ REPO_ROOT = Path(__file__).resolve().parents[1]
9
+ sys.path.insert(0, str(REPO_ROOT))
10
+
11
+ from dotenv import load_dotenv
12
+ load_dotenv()
13
+
14
+ from core.embeddings.chunk import chunk_markdown_file
15
+
16
+ def test_chunk_with_summary():
17
+ """Test chunking a file with tables to verify summary generation."""
18
+
19
+ # Use the K70 English requirements file which has many tables
20
+ test_file = REPO_ROOT / "data/data_process/quyet_dinh/tieng_anh/06_ Quy định ngoại ngữ từ K70_chính quy_final.md"
21
+
22
+ if not test_file.exists():
23
+ print(f"❌ Test file not found: {test_file}")
24
+ return
25
+
26
+ print(f"📄 Processing: {test_file.name}")
27
+ print("=" * 60)
28
+
29
+ nodes = chunk_markdown_file(test_file)
30
+
31
+ print(f"\n📊 Total nodes: {len(nodes)}")
32
+
33
+ # Count different types
34
+ parent_nodes = [n for n in nodes if n.metadata.get("is_parent")]
35
+ summary_nodes = [n for n in nodes if n.metadata.get("is_table_summary")]
36
+ table_nodes = [n for n in nodes if n.metadata.get("is_table") and not n.metadata.get("is_parent")]
37
+ text_nodes = [n for n in nodes if not n.metadata.get("is_table") and not n.metadata.get("is_table_summary")]
38
+
39
+ print(f" - Parent nodes (raw tables, NOT embedded): {len(parent_nodes)}")
40
+ print(f" - Summary nodes (embedded for search): {len(summary_nodes)}")
41
+ print(f" - Small table nodes (embedded directly): {len(table_nodes)}")
42
+ print(f" - Text nodes: {len(text_nodes)}")
43
+
44
+ # Debug: Show sample metadata
45
+ if nodes:
46
+ print("\n🔍 Sample metadata from first node:")
47
+ sample = nodes[0].metadata
48
+ for k, v in sample.items():
49
+ print(f" - {k}: {v}")
50
+
51
+ # Export to markdown
52
+ output_file = REPO_ROOT / "test" / "chunk_results.md"
53
+ export_to_markdown(nodes, output_file, test_file.name)
54
+ print(f"\n📝 Exported detailed results to: {output_file}")
55
+
56
+
57
+ def export_to_markdown(nodes, output_path: Path, source_name: str):
58
+ """Export all chunks to a markdown file for review."""
59
+
60
+ lines = [
61
+ f"# Chunk Results: {source_name}",
62
+ f"",
63
+ f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
64
+ f"",
65
+ f"## Summary",
66
+ f"",
67
+ f"| Type | Count |",
68
+ f"|------|-------|",
69
+ ]
70
+
71
+ # Count types
72
+ parent_nodes = [n for n in nodes if n.metadata.get("is_parent")]
73
+ summary_nodes = [n for n in nodes if n.metadata.get("is_table_summary")]
74
+ table_nodes = [n for n in nodes if n.metadata.get("is_table") and not n.metadata.get("is_parent")]
75
+ text_nodes = [n for n in nodes if not n.metadata.get("is_table") and not n.metadata.get("is_table_summary")]
76
+
77
+ lines.extend([
78
+ f"| Parent nodes (raw tables, NOT embedded) | {len(parent_nodes)} |",
79
+ f"| Summary nodes (embedded for search) | {len(summary_nodes)} |",
80
+ f"| Small table nodes (embedded directly) | {len(table_nodes)} |",
81
+ f"| Text nodes | {len(text_nodes)} |",
82
+ f"| **Total** | **{len(nodes)}** |",
83
+ f"",
84
+ f"---",
85
+ f"",
86
+ ])
87
+
88
+ # Group: Summary nodes with their parents
89
+ lines.append("## 📝 Summary Nodes (with Parent Tables)")
90
+ lines.append("")
91
+
92
+ parent_map = {n.metadata.get("node_id"): n for n in parent_nodes}
93
+
94
+ for i, node in enumerate(summary_nodes, 1):
95
+ parent_id = node.metadata.get("parent_id", "")
96
+ parent = parent_map.get(parent_id)
97
+ meta = node.metadata
98
+
99
+ lines.append(f"### Summary #{i}")
100
+ lines.append(f"")
101
+ lines.append(f"**Metadata:**")
102
+ lines.append(f"- is_table_summary: True")
103
+ lines.append(f"- parent_id: `{parent_id}`")
104
+ if meta.get("source_file"):
105
+ lines.append(f"- source_file: {meta.get('source_file')}")
106
+ if meta.get("applicable_cohorts"):
107
+ lines.append(f"- applicable_cohorts: {meta.get('applicable_cohorts')}")
108
+ lines.append(f"")
109
+ lines.append(f"**Summary Text (embedded for search):**")
110
+ lines.append(f"")
111
+ lines.append(f"> {node.get_content()}")
112
+ lines.append(f"")
113
+
114
+ if parent:
115
+ lines.append(f"**Parent Table (raw, NOT embedded):**")
116
+ lines.append(f"")
117
+ lines.append(f"```markdown")
118
+ lines.append(parent.get_content())
119
+ lines.append(f"```")
120
+ lines.append(f"")
121
+
122
+ lines.append(f"---")
123
+ lines.append(f"")
124
+
125
+ # Small tables (embedded directly)
126
+ if table_nodes:
127
+ lines.append("## 📋 Small Tables (embedded directly)")
128
+ lines.append("")
129
+
130
+ for i, node in enumerate(table_nodes, 1):
131
+ meta = node.metadata
132
+ lines.append(f"### Small Table #{i}")
133
+ lines.append(f"")
134
+ lines.append(f"**Metadata:**")
135
+ lines.append(f"- is_table: True")
136
+ if meta.get("table_part"):
137
+ lines.append(f"- table_part: {meta.get('table_part')}")
138
+ if meta.get("source_file"):
139
+ lines.append(f"- source_file: {meta.get('source_file')}")
140
+ if meta.get("applicable_cohorts"):
141
+ lines.append(f"- applicable_cohorts: {meta.get('applicable_cohorts')}")
142
+ if meta.get("chunk_index") is not None:
143
+ lines.append(f"- chunk_index: {meta.get('chunk_index')}")
144
+ lines.append(f"")
145
+ lines.append(f"```markdown")
146
+ lines.append(node.get_content())
147
+ lines.append(f"```")
148
+ lines.append(f"")
149
+ lines.append(f"---")
150
+ lines.append(f"")
151
+
152
+ # Text nodes
153
+ lines.append("## 📄 Text Nodes")
154
+ lines.append("")
155
+
156
+ for i, node in enumerate(text_nodes, 1):
157
+ content = node.get_content()
158
+ meta = node.metadata
159
+
160
+ lines.append(f"### Text #{i}")
161
+ lines.append(f"")
162
+ lines.append(f"**Metadata:**")
163
+ if meta.get("document_type"):
164
+ lines.append(f"- document_type: {meta.get('document_type')}")
165
+ if meta.get("title"):
166
+ lines.append(f"- title: {meta.get('title')}")
167
+ if meta.get("applicable_cohorts"):
168
+ lines.append(f"- applicable_cohorts: {meta.get('applicable_cohorts')}")
169
+ if meta.get("source_file"):
170
+ lines.append(f"- source_file: {meta.get('source_file')}")
171
+ if meta.get("header_path"):
172
+ lines.append(f"- header_path: {meta.get('header_path')}")
173
+ if meta.get("Header 1"):
174
+ lines.append(f"- Header 1: {meta.get('Header 1')}")
175
+ if meta.get("chunk_index") is not None:
176
+ lines.append(f"- chunk_index: {meta.get('chunk_index')}")
177
+ lines.append(f"")
178
+
179
+ lines.append(f"**Content:**")
180
+ lines.append(f"")
181
+ lines.append(content)
182
+ lines.append(f"")
183
+ lines.append(f"---")
184
+ lines.append(f"")
185
+
186
+ # Write to file
187
+ output_path.write_text("\n".join(lines), encoding="utf-8")
188
+
189
+
190
+ if __name__ == "__main__":
191
+ test_chunk_with_summary()