hungnha commited on
Commit
6c0b009
·
1 Parent(s): f337fcd

tach retrival

Browse files
core/embeddings/retrival.py CHANGED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from core.embeddings.vector_store import ChromaVectorDB

# Reranker - sentence_transformers with the BGE cross-encoder model.
# Imported defensively so retrieval still works without the dependency.
try:
    from sentence_transformers import CrossEncoder
    HAS_RERANKER = True
except ImportError:
    HAS_RERANKER = False

logger = logging.getLogger(__name__)

# Candidates pulled from the vector store before re-ranking.
DEFAULT_INITIAL_K = 100
# Results returned to the caller.
DEFAULT_TOP_K = 5
# Max sequence length fed to the cross-encoder per (query, passage) pair.
RERANKER_MAX_LENGTH = 512


class Retriever:
    """Two-stage retriever over a ChromaVectorDB.

    Stage 1 runs a vector similarity search; stage 2 optionally re-ranks
    the candidates with a cross-encoder. When no reranker is available,
    an inverse-vector-distance fallback score is used instead.
    """

    def __init__(
        self,
        vector_db: "ChromaVectorDB",
        reranker_model: str = "BAAI/bge-reranker-v2-m3",
        use_reranker: bool = True,
    ):
        """Bind the retriever to *vector_db*.

        Args:
            vector_db: wrapper exposing a LangChain ``vectorstore`` property.
            reranker_model: HuggingFace model id of the cross-encoder.
            use_reranker: eagerly load the reranker when the dependency exists.
        """
        self._vector_db = vector_db
        self._reranker: Optional[Any] = None
        self._reranker_model = reranker_model

        if use_reranker:
            if HAS_RERANKER:
                self._load_reranker(reranker_model)
            else:
                # Previously this failed silently; warn so the degraded
                # scoring mode is visible in logs.
                logger.warning(
                    "Reranker requested but sentence_transformers is not "
                    "installed; falling back to inverse-distance scoring"
                )

    def _load_reranker(self, model_name: str) -> None:
        """Load the cross-encoder; on any failure leave ``self._reranker`` None."""
        try:
            # Lazy %-args so messages are only formatted when INFO is enabled.
            logger.info("Loading reranker: %s...", model_name)
            self._reranker = CrossEncoder(model_name, max_length=RERANKER_MAX_LENGTH)
            logger.info("Reranker loaded successfully!")
        except Exception as e:
            logger.error("Reranker failed to load: %s", e)
            self._reranker = None

    @property
    def has_reranker(self) -> bool:
        """True when a cross-encoder reranker is loaded and usable."""
        return self._reranker is not None

    def query(
        self,
        text: str,
        *,
        k: int = DEFAULT_TOP_K,
        where: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Plain vector search without re-ranking.

        Args:
            text: query string; blank/whitespace input returns [].
            k: number of hits to return (must be positive).
            where: optional metadata filter forwarded to the vector store.

        Returns:
            Hits as dicts with ``id``, ``content``, ``metadata``, ``distance``.

        Raises:
            ValueError: if ``k`` <= 0.
        """
        if not text.strip():
            return []

        if k <= 0:
            raise ValueError("k must be positive")

        vectorstore = self._vector_db.vectorstore
        results = vectorstore.similarity_search_with_score(text, k=k, filter=where)

        out: List[Dict[str, Any]] = []
        for doc, score in results:
            out.append({
                "id": (doc.metadata or {}).get("id"),
                "content": doc.page_content,
                "metadata": doc.metadata,
                "distance": score,
            })
        return out

    def search_with_rerank(
        self,
        text: str,
        *,
        k: int = DEFAULT_TOP_K,
        where: Optional[Dict[str, Any]] = None,
        initial_k: int = DEFAULT_INITIAL_K,
    ) -> List[Dict[str, Any]]:
        """Vector search followed by cross-encoder re-ranking.

        Args:
            text: query string; blank/whitespace input returns [].
            k: number of final results (must be positive).
            where: optional metadata filter for the vector stage.
            initial_k: candidate pool size for stage 1; clamped up to ``k``.

        Returns:
            Top-``k`` candidates with ``rerank_score`` and ``final_rank`` set.

        Raises:
            ValueError: if ``k`` <= 0.
        """
        if not text.strip():
            return []

        if k <= 0:
            raise ValueError("k must be positive")

        if initial_k < k:
            logger.warning("initial_k (%s) < k (%s), setting initial_k = k", initial_k, k)
            initial_k = k

        # Stage 1: Vector Search
        vectorstore = self._vector_db.vectorstore
        vector_results = vectorstore.similarity_search_with_score(
            text, k=initial_k, filter=where
        )

        if not vector_results:
            return []

        # Build candidates list; fall back to a content prefix when the
        # stored metadata carries no explicit id.
        candidates = []
        for rank, (doc, score) in enumerate(vector_results):
            doc_id = (doc.metadata or {}).get("id", doc.page_content[:50])
            candidates.append({
                "id": doc_id,
                "content": doc.page_content,
                "metadata": doc.metadata,
                "vector_distance": score,
                "vector_rank": rank + 1,
            })

        # Stage 2: Re-ranking
        candidates = self._rerank_candidates(text, candidates)

        # Add final rank
        for i, c in enumerate(candidates[:k]):
            c["final_rank"] = i + 1

        return candidates[:k]

    def _rerank_candidates(
        self,
        query: str,
        candidates: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Score candidates with the cross-encoder (or fallback) and sort.

        Mutates the candidate dicts in place (adds ``rerank_score``) and
        returns the list sorted best-first when the reranker ran.
        """
        if self._reranker and len(candidates) > 1:
            try:
                pairs = [[query, c["content"]] for c in candidates]
                scores = self._reranker.predict(pairs)

                for i, score in enumerate(scores):
                    candidates[i]["rerank_score"] = float(score)

                candidates.sort(key=lambda x: x.get("rerank_score", 0), reverse=True)

            except Exception as e:
                # Best-effort: keep vector ordering, score by inverse distance.
                logger.error("Rerank error: %s", e)
                self._fallback_scoring(candidates)
        else:
            # No reranker (or a single candidate): inverse vector distance.
            self._fallback_scoring(candidates)

        return candidates

    def _fallback_scoring(self, candidates: List[Dict[str, Any]]) -> None:
        """Apply fallback scoring using inverse vector distance (in place)."""
        for c in candidates:
            c["rerank_score"] = 1.0 / (1.0 + c["vector_distance"])
core/embeddings/vector_store.py CHANGED
@@ -1,307 +1,239 @@
1
  from __future__ import annotations
 
2
  import json
 
3
  from dataclasses import dataclass
4
  from pathlib import Path
5
  from typing import Any, Dict, List, Optional, Sequence
6
- from core.hash_file.hash_file import HashProcessor
7
  from langchain_core.documents import Document
8
  from langchain_chroma import Chroma
 
 
9
  from utils.helpers import read_yaml
10
 
11
- # Reranker - sentence_transformers với model BGE-M3
12
- try:
13
- from sentence_transformers import CrossEncoder
14
- HAS_RERANKER = True
15
- except ImportError:
16
- HAS_RERANKER = False
17
 
18
 
19
  @dataclass
20
  class ChromaConfig:
21
- persist_dir: str
22
- collection_name: str
23
- space: str
24
-
25
- @staticmethod
26
- def default_yaml_path() -> Path:
27
- return Path(__file__).resolve().parents[2] / "config" / "vector_db.yaml"
28
-
29
- @classmethod
30
- def from_yaml(cls, path: str | Path | None = None) -> "ChromaConfig":
31
- cfg_path = Path(path) if path is not None else cls.default_yaml_path()
32
- try:
33
- if not cfg_path.exists():
34
- raise FileNotFoundError(f"Vector DB config not found: {cfg_path}")
35
- data = read_yaml(cfg_path) or {}
36
- if not isinstance(data, dict):
37
- raise ValueError(f"Invalid config format: {cfg_path}")
38
-
39
- required = {"persist_dir", "collection_name", "space"}
40
- missing = sorted([k for k in required if k not in data])
41
- if missing:
42
- raise KeyError(f"Missing keys in {cfg_path}: {', '.join(missing)}")
43
-
44
- cfg = cls(
45
- persist_dir=str(data["persist_dir"]),
46
- collection_name=str(data["collection_name"]),
47
- space=str(data["space"]),
48
- )
49
- p = Path(cfg.persist_dir)
50
- if not p.is_absolute():
51
- cfg.persist_dir = str((cfg_path.parent.parent / p).resolve())
52
- return cfg
53
- except Exception:
54
- raise
 
55
 
56
 
57
  class ChromaVectorDB:
58
- def __init__(
59
- self,
60
- embedder: Any,
61
- config: ChromaConfig | None = None,
62
- reranker_model: str = "BAAI/bge-reranker-v2-m3",
63
- ):
64
- self.embedder = embedder
65
- self.config = config or ChromaConfig.from_yaml()
66
- self._hasher = HashProcessor(verbose=False)
67
-
68
- self._vs = Chroma(
69
- collection_name=self.config.collection_name,
70
- embedding_function=self.embedder,
71
- persist_directory=self.config.persist_dir,
72
- )
73
-
74
- # Reranker (Cross-Encoder)
75
- self._reranker: Optional[Any] = None
76
- self._reranker_model = reranker_model
77
-
78
- if HAS_RERANKER:
79
- try:
80
- print(f"Loading reranker: {reranker_model}...")
81
- self._reranker = CrossEncoder(reranker_model, max_length=512)
82
- print(f"Reranker loaded successfully!")
83
- except Exception as e:
84
- print(f" Reranker failed to load: {e}")
85
- self._reranker = None
86
-
87
- @property
88
- def collection(self):
89
- return getattr(self._vs, "_collection", None)
90
-
91
- @property
92
- def vectorstore(self):
93
- return self._vs
94
-
95
- def _flatten_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
96
- out: Dict[str, Any] = {}
97
- for k, v in (metadata or {}).items():
98
- key = str(k)
99
- if v is None:
100
- continue
101
- if isinstance(v, (str, int, float, bool)):
102
- out[key] = v
103
- continue
104
- if isinstance(v, (list, tuple, set, dict)):
105
- out[key] = json.dumps(v, ensure_ascii=False)
106
- continue
107
- out[key] = str(v)
108
- return out
109
-
110
- def _to_documents(self, docs: Sequence[Dict[str, Any]], ids: Sequence[str]) -> List[Document]:
111
- out: List[Document] = []
112
- for d, doc_id in zip(docs, ids):
113
- md = self._flatten_metadata(d.get("metadata", {}) or {})
114
- md.setdefault("id", doc_id)
115
- out.append(Document(page_content=d.get("content", ""), metadata=md))
116
- return out
117
-
118
- def _doc_id(self, doc: Dict[str, Any]) -> str:
119
- md = doc.get("metadata") or {}
120
- key = {
121
- "source_path": md.get("source_path"),
122
- "source_file": md.get("source_file"),
123
- "source_basename": md.get("source_basename"),
124
- "section": md.get("section"),
125
- "section_path": md.get("section_path"),
126
- "type": md.get("type"),
127
- "course_code": md.get("course_code"),
128
- "stt": md.get("stt"),
129
- "chunk_index": md.get("chunk_index"),
130
- "chunk_in_section": md.get("chunk_in_section"),
131
- "content": doc.get("content"),
132
- }
133
- return self._hasher.get_string_hash(str(key))
134
-
135
- def _ensure_unique_ids(self, ids: Sequence[str]) -> List[str]:
136
- seen: Dict[str, int] = {}
137
- out: List[str] = []
138
- for i in ids:
139
- base = str(i)
140
- n = seen.get(base, 0)
141
- seen[base] = n + 1
142
- out.append(base if n == 0 else f"{base}__dup{n}")
143
- return out
144
-
145
- def add_documents(
146
- self,
147
- docs: Sequence[Dict[str, Any]],
148
- *,
149
- ids: Optional[Sequence[str]] = None,
150
- batch_size: int = 128,
151
- ) -> int:
152
- if not docs:
153
- return 0
154
-
155
- if ids is not None and len(ids) != len(docs):
156
- raise ValueError("ids length must match docs length")
157
-
158
- all_ids = list(ids) if ids is not None else [self._doc_id(d) for d in docs]
159
- all_ids = self._ensure_unique_ids(all_ids)
160
- bs = max(1, batch_size)
161
- total = 0
162
- for start in range(0, len(docs), bs):
163
- batch = docs[start : start + bs]
164
- batch_ids = all_ids[start : start + bs]
165
- lc_docs = self._to_documents(batch, batch_ids)
166
-
167
- try:
168
- self._vs.add_documents(lc_docs, ids=batch_ids)
169
- except TypeError:
170
- texts = [d.page_content for d in lc_docs]
171
- metas = [d.metadata for d in lc_docs]
172
- self._vs.add_texts(texts=texts, metadatas=metas, ids=batch_ids)
173
- total += len(batch)
174
-
175
- return total
176
-
177
- def upsert_documents(
178
- self,
179
- docs: Sequence[Dict[str, Any]],
180
- *,
181
- ids: Optional[Sequence[str]] = None,
182
- batch_size: int = 128,
183
- ) -> int:
184
- if not docs:
185
- return 0
186
-
187
- if ids is not None and len(ids) != len(docs):
188
- raise ValueError("ids length must match docs length")
189
-
190
- all_ids = list(ids) if ids is not None else [self._doc_id(d) for d in docs]
191
- all_ids = self._ensure_unique_ids(all_ids)
192
- bs = max(1, batch_size)
193
- col = getattr(self._vs, "_collection", None)
194
- if col is None:
195
- return self.add_documents(docs, ids=all_ids, batch_size=bs)
196
-
197
- total = 0
198
- for start in range(0, len(docs), bs):
199
- batch = docs[start : start + bs]
200
- batch_ids = all_ids[start : start + bs]
201
- lc_docs = self._to_documents(batch, batch_ids)
202
- texts = [d.page_content for d in lc_docs]
203
- metas = [d.metadata for d in lc_docs]
204
- embs = self.embedder.embed_documents(texts)
205
- col.upsert(ids=batch_ids, documents=texts, metadatas=metas, embeddings=embs)
206
- total += len(batch)
207
-
208
- return total
209
-
210
- def query(
211
- self,
212
- text: str,
213
- *,
214
- k: int = 5,
215
- where: Optional[Dict[str, Any]] = None,
216
- ) -> List[Dict[str, Any]]:
217
- if not text.strip():
218
- return []
219
-
220
- results = self._vs.similarity_search_with_score(text, k=k, filter=where)
221
- out: List[Dict[str, Any]] = []
222
- for doc, score in results:
223
- out.append({
224
- "id": (doc.metadata or {}).get("id"),
225
- "content": doc.page_content,
226
- "metadata": doc.metadata,
227
- "distance": score,
228
- })
229
- return out
230
-
231
- def count(self) -> int:
232
- col = getattr(self._vs, "_collection", None)
233
- if col is None:
234
- return 0
235
- return int(col.count())
236
-
237
- def get_all_documents(self, limit: int = 5000) -> List[Dict[str, Any]]:
238
- col = self.collection
239
- if col is None:
240
- return []
241
-
242
- result = col.get(limit=limit, include=['documents', 'metadatas'])
243
- docs = []
244
- for i, doc_content in enumerate(result.get('documents', [])):
245
- if doc_content:
246
- meta = result['metadatas'][i] if result.get('metadatas') else {}
247
- docs.append({
248
- 'id': result['ids'][i] if result.get('ids') else str(i),
249
- 'content': doc_content,
250
- 'metadata': meta or {},
251
- })
252
- return docs
253
-
254
- def search_with_rerank(
255
- self,
256
- text: str,
257
- *,
258
- k: int = 5,
259
- where: Optional[Dict[str, Any]] = None,
260
- initial_k: int = 100,
261
- ) -> List[Dict[str, Any]]:
262
-
263
- if not text.strip():
264
- return []
265
-
266
- # Stage 1: Vector Search
267
- vector_results = self._vs.similarity_search_with_score(text, k=initial_k, filter=where)
268
-
269
- if not vector_results:
270
- return []
271
-
272
- candidates = []
273
- for rank, (doc, score) in enumerate(vector_results):
274
- doc_id = (doc.metadata or {}).get("id", doc.page_content[:50])
275
- candidates.append({
276
- "id": doc_id,
277
- "content": doc.page_content,
278
- "metadata": doc.metadata,
279
- "vector_distance": score,
280
- "vector_rank": rank + 1,
281
- })
282
-
283
- # Stage 2: Re-ranking
284
- if self._reranker and len(candidates) > 1:
285
- try:
286
- pairs = [[text, c["content"]] for c in candidates]
287
- scores = self._reranker.predict(pairs)
288
-
289
- for i, score in enumerate(scores):
290
- candidates[i]["rerank_score"] = float(score)
291
-
292
- candidates.sort(key=lambda x: x.get("rerank_score", 0), reverse=True)
293
-
294
- except Exception as e:
295
- print(f" Rerank error: {e}")
296
- for c in candidates:
297
- c["rerank_score"] = 0.0
298
- else:
299
- # No reranker: use inverse vector distance
300
- for c in candidates:
301
- c["rerank_score"] = 1.0 / (1.0 + c["vector_distance"])
302
-
303
- # Add final rank
304
- for i, c in enumerate(candidates[:k]):
305
- c["final_rank"] = i + 1
306
-
307
- return candidates[:k]
 
1
  from __future__ import annotations
2
+
3
  import json
4
+ import logging
5
  from dataclasses import dataclass
6
  from pathlib import Path
7
  from typing import Any, Dict, List, Optional, Sequence
8
+
9
  from langchain_core.documents import Document
10
  from langchain_chroma import Chroma
11
+
12
+ from core.hash_file.hash_file import HashProcessor
13
  from utils.helpers import read_yaml
14
 
15
+ logger = logging.getLogger(__name__)
 
 
 
 
 
16
 
17
 
18
  @dataclass
19
  class ChromaConfig:
20
+
21
+ persist_dir: str
22
+ collection_name: str
23
+ space: str
24
+
25
+ @staticmethod
26
+ def default_yaml_path() -> Path:
27
+ return Path(__file__).resolve().parents[2] / "config" / "vector_db.yaml"
28
+
29
+ @classmethod
30
+ def from_yaml(cls, path: str | Path | None = None) -> "ChromaConfig":
31
+ cfg_path = Path(path) if path is not None else cls.default_yaml_path()
32
+ try:
33
+ if not cfg_path.exists():
34
+ raise FileNotFoundError(f"Vector DB config not found: {cfg_path}")
35
+ data = read_yaml(cfg_path) or {}
36
+ if not isinstance(data, dict):
37
+ raise ValueError(f"Invalid config format: {cfg_path}")
38
+
39
+ required = {"persist_dir", "collection_name", "space"}
40
+ missing = sorted([k for k in required if k not in data])
41
+ if missing:
42
+ raise KeyError(f"Missing keys in {cfg_path}: {', '.join(missing)}")
43
+
44
+ cfg = cls(
45
+ persist_dir=str(data["persist_dir"]),
46
+ collection_name=str(data["collection_name"]),
47
+ space=str(data["space"]),
48
+ )
49
+ p = Path(cfg.persist_dir)
50
+ if not p.is_absolute():
51
+ cfg.persist_dir = str((cfg_path.parent.parent / p).resolve())
52
+ return cfg
53
+ except Exception:
54
+ raise
55
 
56
 
57
class ChromaVectorDB:
    """Persistent Chroma vector store wrapper with content-hash document ids."""

    def __init__(
        self,
        embedder: Any,
        config: Optional["ChromaConfig"] = None,
    ):
        """Open (or create) the collection described by *config*.

        Args:
            embedder: LangChain-compatible embedding function.
            config: store configuration; loaded from the default YAML when omitted.
        """
        self.embedder = embedder
        self.config = config or ChromaConfig.from_yaml()
        self._hasher = HashProcessor(verbose=False)

        self._vs = Chroma(
            collection_name=self.config.collection_name,
            embedding_function=self.embedder,
            persist_directory=self.config.persist_dir,
        )
        # Lazy %-args: the message is only formatted when INFO is enabled.
        logger.info("ChromaVectorDB initialized: %s", self.config.collection_name)

    @property
    def collection(self):
        """Raw chromadb collection, or None if the backend does not expose one."""
        return getattr(self._vs, "_collection", None)

    @property
    def vectorstore(self):
        """The wrapped LangChain Chroma instance."""
        return self._vs

    def _flatten_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
        """Coerce metadata values to Chroma-compatible scalars.

        None values are dropped; list/tuple/dict values are JSON-encoded;
        sets are converted to lists first (json.dumps cannot serialize a
        set — the previous code raised TypeError here); anything else is
        stringified.
        """
        out: Dict[str, Any] = {}
        for k, v in (metadata or {}).items():
            key = str(k)
            if v is None:
                continue
            if isinstance(v, (str, int, float, bool)):
                out[key] = v
            elif isinstance(v, (set, frozenset)):
                # BUGFIX: json.dumps raises TypeError on sets.
                out[key] = json.dumps(list(v), ensure_ascii=False)
            elif isinstance(v, (list, tuple, dict)):
                out[key] = json.dumps(v, ensure_ascii=False)
            else:
                out[key] = str(v)
        return out

    def _to_documents(self, docs: Sequence[Dict[str, Any]], ids: Sequence[str]) -> "List[Document]":
        """Convert raw {content, metadata} dicts into LangChain Documents with ids."""
        out = []
        for d, doc_id in zip(docs, ids):
            md = self._flatten_metadata(d.get("metadata", {}) or {})
            md.setdefault("id", doc_id)
            out.append(Document(page_content=d.get("content", ""), metadata=md))
        return out

    def _doc_id(self, doc: Dict[str, Any]) -> str:
        """Derive a stable id by hashing the identifying metadata + content."""
        md = doc.get("metadata") or {}
        key = {
            "source_path": md.get("source_path"),
            "source_file": md.get("source_file"),
            "source_basename": md.get("source_basename"),
            "section": md.get("section"),
            "section_path": md.get("section_path"),
            "type": md.get("type"),
            "course_code": md.get("course_code"),
            "stt": md.get("stt"),
            "chunk_index": md.get("chunk_index"),
            "chunk_in_section": md.get("chunk_in_section"),
            "content": doc.get("content"),
        }
        return self._hasher.get_string_hash(str(key))

    def _ensure_unique_ids(self, ids: Sequence[str]) -> List[str]:
        """Deduplicate ids by suffixing repeats with ``__dupN``."""
        seen: Dict[str, int] = {}
        out: List[str] = []
        for i in ids:
            base = str(i)
            n = seen.get(base, 0)
            seen[base] = n + 1
            out.append(base if n == 0 else f"{base}__dup{n}")
        return out

    def _prepare_ids(
        self,
        docs: Sequence[Dict[str, Any]],
        ids: Optional[Sequence[str]],
    ) -> List[str]:
        """Validate caller-supplied ids or derive hash ids, then uniquify.

        Shared by add_documents/upsert_documents (previously duplicated).

        Raises:
            ValueError: if *ids* is given with a different length than *docs*.
        """
        if ids is not None and len(ids) != len(docs):
            raise ValueError("ids length must match docs length")
        all_ids = list(ids) if ids is not None else [self._doc_id(d) for d in docs]
        return self._ensure_unique_ids(all_ids)

    def add_documents(
        self,
        docs: Sequence[Dict[str, Any]],
        *,
        ids: Optional[Sequence[str]] = None,
        batch_size: int = 128,
    ) -> int:
        """Add documents in batches; returns the number of documents added.

        Raises:
            ValueError: if *ids* length does not match *docs* length.
        """
        if not docs:
            return 0

        all_ids = self._prepare_ids(docs, ids)
        bs = max(1, batch_size)
        total = 0

        for start in range(0, len(docs), bs):
            batch = docs[start : start + bs]
            batch_ids = all_ids[start : start + bs]
            lc_docs = self._to_documents(batch, batch_ids)

            try:
                self._vs.add_documents(lc_docs, ids=batch_ids)
            except TypeError:
                # Older langchain APIs reject Documents+ids; use add_texts.
                texts = [d.page_content for d in lc_docs]
                metas = [d.metadata for d in lc_docs]
                self._vs.add_texts(texts=texts, metadatas=metas, ids=batch_ids)
            total += len(batch)

        logger.info("Added %d documents to vector store", total)
        return total

    def upsert_documents(
        self,
        docs: Sequence[Dict[str, Any]],
        *,
        ids: Optional[Sequence[str]] = None,
        batch_size: int = 128,
    ) -> int:
        """Insert-or-update documents by id; returns the number processed.

        Falls back to plain add_documents when the raw collection is not
        exposed (no upsert API available in that case).

        Raises:
            ValueError: if *ids* length does not match *docs* length.
        """
        if not docs:
            return 0

        all_ids = self._prepare_ids(docs, ids)
        bs = max(1, batch_size)
        col = getattr(self._vs, "_collection", None)

        if col is None:
            return self.add_documents(docs, ids=all_ids, batch_size=bs)

        total = 0
        for start in range(0, len(docs), bs):
            batch = docs[start : start + bs]
            batch_ids = all_ids[start : start + bs]
            lc_docs = self._to_documents(batch, batch_ids)
            texts = [d.page_content for d in lc_docs]
            metas = [d.metadata for d in lc_docs]
            # Embed explicitly because collection.upsert bypasses the
            # vectorstore's embedding_function.
            embs = self.embedder.embed_documents(texts)
            col.upsert(ids=batch_ids, documents=texts, metadatas=metas, embeddings=embs)
            total += len(batch)

        logger.info("Upserted %d documents to vector store", total)
        return total

    def count(self) -> int:
        """Number of stored documents (0 when the collection is unavailable)."""
        col = getattr(self._vs, "_collection", None)
        if col is None:
            return 0
        return int(col.count())

    def get_all_documents(self, limit: int = 5000) -> List[Dict[str, Any]]:
        """Fetch up to *limit* documents as {id, content, metadata} dicts."""
        col = self.collection
        if col is None:
            return []

        result = col.get(limit=limit, include=['documents', 'metadatas'])
        # Guard with `or []`: col.get may map an include key to None.
        documents = result.get('documents') or []
        metadatas = result.get('metadatas') or []
        all_ids = result.get('ids') or []

        docs = []
        for i, doc_content in enumerate(documents):
            if doc_content:
                docs.append({
                    'id': all_ids[i] if i < len(all_ids) else str(i),
                    'content': doc_content,
                    'metadata': (metadatas[i] if i < len(metadatas) else {}) or {},
                })
        return docs

    def delete_documents(self, ids: Sequence[str]) -> int:
        """Delete documents by id; returns how many ids were submitted.

        Returns 0 (no-op) for an empty id list or a missing collection.
        """
        if not ids:
            return 0

        col = self.collection
        if col is None:
            return 0

        col.delete(ids=list(ids))
        logger.info("Deleted %d documents from vector store", len(ids))
        return len(ids)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/gradio/gradio_rag_qwen.py CHANGED
@@ -21,6 +21,7 @@ def _load_env() -> None:
21
 
22
  from core.embeddings.embedding_model import VietnameseBiEncoderConfig, VietnameseBiEncoderEmbeddings
23
  from core.embeddings.vector_store import ChromaConfig, ChromaVectorDB
 
24
 
25
  _load_env()
26
 
@@ -32,6 +33,7 @@ GROQ_MODEL = os.getenv("GROQ_MODEL", "qwen/qwen3-32b")
32
  class AppState:
33
  def __init__(self) -> None:
34
  self.db: Optional[ChromaVectorDB] = None
 
35
  self.groq: Optional[Groq] = None
36
 
37
 
@@ -56,6 +58,7 @@ def _init_resources() -> None:
56
  embedder=emb,
57
  config=db_cfg,
58
  )
 
59
 
60
  api_key = (os.getenv("GROQ_API_KEY") or "").strip()
61
  if not api_key:
@@ -72,7 +75,8 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
72
  assert STATE.groq is not None
73
 
74
  # Vector Search + Re-ranking
75
- results = STATE.db.search_with_rerank(message, k=TOP_K, initial_k=50)
 
76
 
77
  if not results:
78
  yield "Xin lỗi, tôi không tìm thấy thông tin phù hợp trong dữ liệu."
 
21
 
22
  from core.embeddings.embedding_model import VietnameseBiEncoderConfig, VietnameseBiEncoderEmbeddings
23
  from core.embeddings.vector_store import ChromaConfig, ChromaVectorDB
24
+ from core.embeddings.retrival import Retriever
25
 
26
  _load_env()
27
 
 
33
  class AppState:
34
  def __init__(self) -> None:
35
  self.db: Optional[ChromaVectorDB] = None
36
+ self.retriever: Optional[Retriever] = None
37
  self.groq: Optional[Groq] = None
38
 
39
 
 
58
  embedder=emb,
59
  config=db_cfg,
60
  )
61
+ STATE.retriever = Retriever(vector_db=STATE.db)
62
 
63
  api_key = (os.getenv("GROQ_API_KEY") or "").strip()
64
  if not api_key:
 
75
  assert STATE.groq is not None
76
 
77
  # Vector Search + Re-ranking
78
+ assert STATE.retriever is not None
79
+ results = STATE.retriever.search_with_rerank(message, k=TOP_K, initial_k=50)
80
 
81
  if not results:
82
  yield "Xin lỗi, tôi không tìm thấy thông tin phù hợp trong dữ liệu."
evaluation/simple_eval.py CHANGED
@@ -22,6 +22,7 @@ load_dotenv(find_dotenv(usecwd=True))
22
  from langchain_groq import ChatGroq
23
  from core.embeddings.embedding_model import VietnameseBiEncoderConfig, VietnameseBiEncoderEmbeddings
24
  from core.embeddings.vector_store import ChromaConfig, ChromaVectorDB
 
25
 
26
  TOP_K = int(os.getenv("TOP_K", "5"))
27
  INITIAL_K = int(os.getenv("INITIAL_K", "50"))
@@ -51,6 +52,7 @@ def extract_keywords(text: str) -> set:
51
  class SimpleRAGEvaluator:
52
  def __init__(self):
53
  self.db: Optional[ChromaVectorDB] = None
 
54
  self.embedder: Optional[VietnameseBiEncoderEmbeddings] = None
55
  self.llm = None
56
  self.llm_fast = None
@@ -71,6 +73,7 @@ class SimpleRAGEvaluator:
71
  print(f"Vector DB: {db_cfg.collection_name}")
72
 
73
  self.db = ChromaVectorDB(embedder=self.embedder, config=db_cfg)
 
74
 
75
  api_key = os.getenv("GROQ_API_KEY")
76
  if not api_key:
@@ -115,20 +118,48 @@ TRẢ LỜI:"""
115
 
116
  def retrieve_contexts(self, question: str) -> List[str]:
117
  try:
118
- results = self.db.search_with_rerank(question, k=TOP_K, initial_k=INITIAL_K)
119
  return [r.get("content", "")[:1000] for r in results if r.get("content")]
120
  except Exception as e:
121
  print(f"Retrieval error: {e}")
122
  return []
123
 
124
- def calculate_semantic_similarity(self, text1: str, text2: str) -> float:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  try:
126
- emb1 = np.array(self.embedder.embed_query(text1))
127
- emb2 = np.array(self.embedder.embed_query(text2))
128
- return cosine_similarity(emb1, emb2)
 
 
129
  except Exception as e:
130
- print(f"Embedding error: {e}")
131
- return 0.0
132
 
133
  def calculate_keyword_overlap(self, answer: str, ground_truth: str) -> float:
134
  gt_keywords = extract_keywords(ground_truth)
 
22
  from langchain_groq import ChatGroq
23
  from core.embeddings.embedding_model import VietnameseBiEncoderConfig, VietnameseBiEncoderEmbeddings
24
  from core.embeddings.vector_store import ChromaConfig, ChromaVectorDB
25
+ from core.embeddings.retrival import Retriever
26
 
27
  TOP_K = int(os.getenv("TOP_K", "5"))
28
  INITIAL_K = int(os.getenv("INITIAL_K", "50"))
 
52
  class SimpleRAGEvaluator:
53
  def __init__(self):
54
  self.db: Optional[ChromaVectorDB] = None
55
+ self.retriever: Optional[Retriever] = None
56
  self.embedder: Optional[VietnameseBiEncoderEmbeddings] = None
57
  self.llm = None
58
  self.llm_fast = None
 
73
  print(f"Vector DB: {db_cfg.collection_name}")
74
 
75
  self.db = ChromaVectorDB(embedder=self.embedder, config=db_cfg)
76
+ self.retriever = Retriever(vector_db=self.db)
77
 
78
  api_key = os.getenv("GROQ_API_KEY")
79
  if not api_key:
 
118
 
119
  def retrieve_contexts(self, question: str) -> List[str]:
120
  try:
121
+ results = self.retriever.search_with_rerank(question, k=TOP_K, initial_k=INITIAL_K)
122
  return [r.get("content", "")[:1000] for r in results if r.get("content")]
123
  except Exception as e:
124
  print(f"Retrieval error: {e}")
125
  return []
126
 
127
+ def calculate_semantic_similarity(self, answer: str, ground_truth: str) -> float:
128
+ """
129
+ Đánh giá semantic similarity giữa answer và ground_truth bằng LLM.
130
+ Thay thế cosine similarity bằng LLM-based scoring.
131
+ """
132
+ if not answer.strip() or not ground_truth.strip():
133
+ return 0.0
134
+
135
+ prompt = f"""Bạn là giám khảo chấm thi.
136
+ Nhiệm vụ: So sánh độ tương đồng ngữ nghĩa giữa CÂU TRẢ LỜI và ĐÁP ÁN CHUẨN.
137
+
138
+ ĐÁP ÁN CHUẨN:
139
+ {ground_truth[:800]}
140
+
141
+ CÂU TRẢ LỜI:
142
+ {answer[:800]}
143
+
144
+ Yêu cầu đánh giá độ tương đồng ngữ nghĩa:
145
+ - 1.0: Câu trả lời chứa đầy đủ và chính xác thông tin như đáp án chuẩn
146
+ - 0.8: Câu trả lời đúng ý chính, có thể thiếu một số chi tiết nhỏ
147
+ - 0.6: Câu trả lời đúng một phần, thiếu một số thông tin quan trọng
148
+ - 0.4: Câu trả lời có liên quan nhưng thiếu nhiều thông tin hoặc không chính xác
149
+ - 0.2: Câu trả lời chỉ đúng một phần rất nhỏ
150
+ - 0.0: Câu trả lời hoàn toàn sai hoặc không liên quan
151
+
152
+ CHỈ TRẢ VỀ MỘT CON SỐ (0.0, 0.2, 0.4, 0.6, 0.8 hoặc 1.0), KHÔNG GIẢI THÍCH:"""
153
+
154
  try:
155
+ response = self.llm_fast.invoke(prompt).content.strip()
156
+ match = re.search(r"(1\.0|0\.\d|0|1)", response)
157
+ if match:
158
+ return float(match.group())
159
+ return 0.5
160
  except Exception as e:
161
+ print(f"Semantic similarity LLM error: {e}")
162
+ return 0.5
163
 
164
  def calculate_keyword_overlap(self, answer: str, ground_truth: str) -> float:
165
  gt_keywords = extract_keywords(ground_truth)