hungnha committed
Commit 39f858f · Parent: 11133c9

Change prompt
core/embeddings/generator.py CHANGED
@@ -13,10 +13,17 @@ SYSTEM_PROMPT = """Bạn là Trợ lý học vụ Đại học Bách khoa Hà N
 
 ## NGUYÊN TẮC:
 1. Chỉ trả lời dựa trên CONTEXT được cung cấp. Không suy đoán, không bổ sung thông tin ngoài CONTEXT.
-2. Nếu trong CONTEXT có nội dung về "Hiệu lực thi hành" hoặc "Điều khoản chuyển tiếp", hãy nêu phạm vi áp dụng (theo khóa hoặc thời gian) đúng như nội dung đã nêu. Nếu CONTEXT không đề cập, không tự suy luận.
-3. Nếu CONTEXT chứa nhiều văn bản khác nhau, ưu tiên nội dung được nêu đang áp dụng, hoặc ghi thời điểm hiệu lực. Không tự xác định văn bản mới/cũ nếu CONTEXT không nói rõ.
+2. Nếu trong CONTEXT có nội dung về "Hiệu lực thi hành" hoặc "Điều khoản chuyển tiếp", KIỂM TRA xem có ngoại lệ theo khóa/thời gian không và GHI RÕ.
+3. Nếu CONTEXT chứa nhiều văn bản khác nhau, ưu tiên nội dung mới nhất, TRỪ KHI điều khoản chuyển tiếp nói khác.
 4. Cuối câu trả lời, trích dẫn nguồn đúng theo tài liệu xuất hiện trong CONTEXT. Không tự tạo nguồn.
-5. Nếu không tìm thấy thông tin trong CONTEXT, trả lời: "Không tìm thấy thông tin trong dữ liệu hiện có."
+5. PHÂN BIỆT các loại CTĐT:
+   - CTĐT CHUẨN: Phụ lục III (Bảng 3.x) - áp dụng cho đa số sinh viên
+   - CTĐT TÀI NĂNG: Phụ lục IV (Bảng 4.x)
+   - CTĐT ELITECH/Tiên tiến: Phụ lục V (Bảng 5.x)
+   - CTĐT HỢP TÁC QUỐC TẾ: Phụ lục VI (Bảng 6.x)
+   - CTĐT NGÔN NGỮ (FL1, FL2, FL3): Phụ lục VIII - KHÔNG ÁP DỤNG cho sinh viên thường
+   Khi người dùng nói "chương trình chuẩn", CHỈ trả lời theo Phụ lục III, KHÔNG lẫn với ngành ngôn ngữ.
+6. Nếu không tìm thấy thông tin trong CONTEXT, trả lời: "Không tìm thấy thông tin trong dữ liệu hiện có."
 """
 
 
@@ -117,3 +124,29 @@ class RAGGenerator:
             if delta:
                 acc += delta
                 yield acc
+
+    def generate_stream_from_results(
+        self, question: str, results: List[Dict[str, Any]]
+    ) -> Generator[str, None, None]:
+        """Stream generation from pre-fetched results (no retrieval)."""
+        if not results:
+            yield "Không tìm thấy thông tin trong dữ liệu hiện có."
+            return
+
+        context = build_context(results, self._max_context_chars)
+        prompt = self._build_prompt(question, context)
+
+        completion = self._groq.chat.completions.create(
+            model=self._llm_model,
+            messages=[{"role": "user", "content": prompt}],
+            temperature=self._temperature,
+            max_completion_tokens=self._max_tokens,
+            stream=True,
+        )
+
+        acc = ""
+        for chunk in completion:
+            delta = getattr(chunk.choices[0].delta, "content", "") or ""
+            if delta:
+                acc += delta
+                yield acc
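For orientation, a minimal sketch of how a caller might drive the new `generate_stream_from_results` method, assuming a `Retriever` and a `RAGGenerator` already initialized as in `gradio_rag_qwen.py`. Note the streaming contract: each yielded value is the accumulated answer so far, not a delta.

```python
question = "Điều kiện tốt nghiệp đại học là gì?"

# Retrieval and generation are now decoupled: fetch once, then stream from the results.
results = retriever.search_with_rerank(question, k=5)

answer = ""
for partial in generator.generate_stream_from_results(question, results):
    answer = partial  # each yield is the full accumulated text, not a delta
print(answer)
```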
core/embeddings/retrival.py CHANGED
@@ -1,8 +1,11 @@
 from __future__ import annotations
 import os
+import time
 import logging
 from dataclasses import dataclass
+from enum import Enum
 from typing import Any, Dict, List, Optional, Sequence, TYPE_CHECKING
+import re
 import requests
 from pydantic import Field
 from langchain_core.documents import Document
@@ -18,6 +21,14 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class RetrievalMode(str, Enum):
+    """Retrieval modes."""
+    VECTOR_ONLY = "vector_only"
+    BM25_ONLY = "bm25_only"
+    HYBRID = "hybrid"
+    HYBRID_RERANK = "hybrid_rerank"
+
+
 @dataclass
 class RetrievalConfig:
     rerank_api_base_url: str = "https://api.siliconflow.com/v1"
@@ -31,6 +42,7 @@ class RetrievalConfig:
 
 _retrieval_config: RetrievalConfig | None = None
 
+
 def get_retrieval_config() -> RetrievalConfig:
     global _retrieval_config
     if _retrieval_config is None:
@@ -53,144 +65,103 @@ class SiliconFlowReranker(BaseDocumentCompressor):
         query: str,
         callbacks: Optional[Callbacks] = None,
     ) -> Sequence[Document]:
-        if not documents:
-            return []
-
-        if not self.api_key:
-            logger.warning("No API key, returning documents as-is")
+        if not documents or not self.api_key:
             return list(documents)
 
-        import time
-        max_retries = 3
-
-        for attempt in range(max_retries):
+        for attempt in range(3):
             try:
-                url = f"{self.api_base_url}/rerank"
-                headers = {
-                    "Authorization": f"Bearer {self.api_key}",
-                    "Content-Type": "application/json",
-                }
-                payload = {
-                    "model": self.model,
-                    "query": query,
-                    "documents": [doc.page_content for doc in documents],
-                    "top_n": self.top_n or len(documents),
-                }
-
-                response = requests.post(url, headers=headers, json=payload, timeout=30)
+                response = requests.post(
+                    f"{self.api_base_url}/rerank",
+                    headers={
+                        "Authorization": f"Bearer {self.api_key}",
+                        "Content-Type": "application/json",
+                    },
+                    json={
+                        "model": self.model,
+                        "query": query,
+                        "documents": [doc.page_content for doc in documents],
+                        "top_n": self.top_n or len(documents),
+                    },
+                    timeout=30,
+                )
                 response.raise_for_status()
                 data = response.json()
 
                 if "results" not in data:
-                    logger.warning("Unexpected rerank response format")
                     return list(documents)
 
                 reranked: List[Document] = []
                 for result in data["results"]:
-                    idx = result["index"]
-                    score = result["relevance_score"]
-
-                    doc = documents[idx]
-                    new_metadata = dict(doc.metadata or {})
-                    new_metadata["rerank_score"] = score
-
-                    reranked.append(Document(
-                        page_content=doc.page_content,
-                        metadata=new_metadata
-                    ))
+                    doc = documents[result["index"]]
+                    meta = dict(doc.metadata or {})
+                    meta["rerank_score"] = result["relevance_score"]
+                    reranked.append(Document(page_content=doc.page_content, metadata=meta))
 
-                logger.debug(f"Reranked {len(reranked)} documents")
                 return reranked
 
             except Exception as e:
-                if "rate" in str(e).lower() and attempt < max_retries - 1:
-                    wait_time = 2 ** attempt
-                    logger.warning(f"Rate limit hit, waiting {wait_time}s...")
-                    time.sleep(wait_time)
+                if "rate" in str(e).lower() and attempt < 2:
+                    time.sleep(2 ** attempt)
                 else:
                     logger.error(f"Rerank error: {e}")
                     return list(documents)
 
         return list(documents)
 
 
 class Retriever:
-    def __init__(
-        self,
-        vector_db: "ChromaVectorDB",
-        use_reranker: bool = True,
-    ):
+    def __init__(self, vector_db: "ChromaVectorDB", use_reranker: bool = True):
         self._vector_db = vector_db
         self._config = get_retrieval_config()
         self._reranker: Optional[SiliconFlowReranker] = None
 
-        self._vector_retriever = self._init_vector_retriever()
-        self._bm25_retriever = self._init_bm25_retriever()
-        self._ensemble_retriever = self._init_ensemble_retriever()
+        self._vector_retriever = self._vector_db.vectorstore.as_retriever(
+            search_kwargs={"k": self._config.initial_k}
+        )
+        self._bm25_retriever = self._init_bm25()
+        self._ensemble_retriever = self._init_ensemble()
 
         if use_reranker:
             self._reranker = self._init_reranker()
 
-        self._final_retriever = self._build_final_retriever()
-
-    def _init_vector_retriever(self):
-        return self._vector_db.vectorstore.as_retriever(
-            search_kwargs={"k": self._config.initial_k}
-        )
+        self._final_retriever = self._build_final()
 
-    def _init_bm25_retriever(self) -> Optional[BM25Retriever]:
+    def _init_bm25(self) -> Optional[BM25Retriever]:
         try:
             docs = self._vector_db.get_all_documents()
             if not docs:
-                logger.warning("No documents for BM25 index")
                 return None
 
             lc_docs = [
-                Document(
-                    page_content=d["content"],
-                    metadata=d.get("metadata", {})
-                )
+                Document(page_content=d["content"], metadata=d.get("metadata", {}))
                 for d in docs
             ]
-
             bm25 = BM25Retriever.from_documents(lc_docs)
             bm25.k = self._config.initial_k
-            logger.info(f"BM25 index built with {len(lc_docs)} documents")
             return bm25
-
-        except Exception as e:
-            logger.error(f"Failed to build BM25 index: {e}")
+        except Exception:
             return None
 
-    def _init_ensemble_retriever(self) -> EnsembleRetriever:
-        retrievers: List[Any] = [self._vector_retriever]
-        weights: List[float] = [1.0]
-
+    def _init_ensemble(self) -> EnsembleRetriever:
         if self._bm25_retriever:
-            retrievers.append(self._bm25_retriever)
-            weights = [self._config.vector_weight, self._config.bm25_weight]
-
-        return EnsembleRetriever(
-            retrievers=retrievers,
-            weights=weights
-        )
+            return EnsembleRetriever(
+                retrievers=[self._vector_retriever, self._bm25_retriever],
+                weights=[self._config.vector_weight, self._config.bm25_weight]
+            )
+        return EnsembleRetriever(retrievers=[self._vector_retriever], weights=[1.0])
 
     def _init_reranker(self) -> Optional[SiliconFlowReranker]:
         api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
         if not api_key:
-            logger.warning("SILICONFLOW_API_KEY not found. Reranking disabled.")
             return None
-
-        reranker = SiliconFlowReranker(
+        return SiliconFlowReranker(
             api_key=api_key,
             api_base_url=self._config.rerank_api_base_url,
             model=self._config.rerank_model,
            top_n=self._config.rerank_top_n,
         )
-        logger.info(f"Reranker initialized: {self._config.rerank_model}")
-        return reranker
 
-    def _build_final_retriever(self):
+    def _build_final(self):
         if self._reranker:
             return ContextualCompressionRetriever(
                 base_compressor=self._reranker,
@@ -202,94 +173,165 @@ class Retriever:
     def has_reranker(self) -> bool:
         return self._reranker is not None
 
-    def query(
-        self,
-        text: str,
-        *,
-        k: int | None = None,
-        where: Optional[Dict[str, Any]] = None,
+    def _to_result(self, doc: Document, rank: int, **extra) -> Dict[str, Any]:
+        return {
+            "id": (doc.metadata or {}).get("id"),
+            "content": doc.page_content,
+            "metadata": doc.metadata,
+            "final_rank": rank,
+            **extra,
+        }
+
+    def vector_search(
+        self, text: str, *, k: int | None = None, where: Optional[Dict[str, Any]] = None
     ) -> List[Dict[str, Any]]:
         if not text.strip():
             return []
-
         k = k or self._config.top_k
-        vectorstore = self._vector_db.vectorstore
-        results = vectorstore.similarity_search_with_score(text, k=k, filter=where)
-
-        return [
-            {
-                "id": (doc.metadata or {}).get("id"),
-                "content": doc.page_content,
-                "metadata": doc.metadata,
-                "distance": score,
-            }
-            for doc, score in results
-        ]
+        results = self._vector_db.vectorstore.similarity_search_with_score(text, k=k, filter=where)
+        return [self._to_result(doc, i + 1, distance=score) for i, (doc, score) in enumerate(results)]
+
+    def bm25_search(self, text: str, *, k: int | None = None) -> List[Dict[str, Any]]:
+        if not text.strip():
+            return []
+        if not self._bm25_retriever:
+            return self.vector_search(text, k=k)
+
+        k = k or self._config.top_k
+        self._bm25_retriever.k = k
+        results = self._bm25_retriever.invoke(text)
+        return [self._to_result(doc, i + 1) for i, doc in enumerate(results[:k])]
 
     def hybrid_search(
-        self,
-        text: str,
-        *,
-        k: int | None = None,
-        initial_k: int | None = None,
+        self, text: str, *, k: int | None = None, initial_k: int | None = None
     ) -> List[Dict[str, Any]]:
+        """Hybrid search (Vector + BM25) WITHOUT reranking."""
         if not text.strip():
             return []
 
         k = k or self._config.top_k
-
         if initial_k:
             self._vector_retriever.search_kwargs["k"] = initial_k
             if self._bm25_retriever:
                 self._bm25_retriever.k = initial_k
 
-        results = self._final_retriever.invoke(text)
-
-        out: List[Dict[str, Any]] = []
-        for i, doc in enumerate(results[:k]):
-            out.append({
-                "id": (doc.metadata or {}).get("id"),
-                "content": doc.page_content,
-                "metadata": doc.metadata,
-                "rerank_score": doc.metadata.get("rerank_score"),
-                "final_rank": i + 1,
-            })
-
-        return out
+        # Use the ensemble retriever (NO reranker) instead of _final_retriever
+        results = self._ensemble_retriever.invoke(text)
+        return [self._to_result(doc, i + 1) for i, doc in enumerate(results[:k])]
 
     def search_with_rerank(
         self,
         text: str,
         *,
         k: int | None = None,
         where: Optional[Dict[str, Any]] = None,
         initial_k: int | None = None,
     ) -> List[Dict[str, Any]]:
+        """Hybrid search (Vector + BM25) WITH reranking."""
         if not text.strip():
             return []
 
         k = k or self._config.top_k
         initial_k = initial_k or self._config.initial_k
 
-        # If filter is provided, use vector-only search (BM25 doesn't support filters)
+        # Filter provided -> vector search + manual rerank (BM25 doesn't support filters)
         if where:
-            vectorstore = self._vector_db.vectorstore
-            results = vectorstore.similarity_search(text, k=initial_k, filter=where)
-
-            # Apply reranker if available
+            results = self._vector_db.vectorstore.similarity_search(text, k=initial_k, filter=where)
             if self._reranker:
                 results = self._reranker.compress_documents(results, text)
-
-            out: List[Dict[str, Any]] = []
-            for i, doc in enumerate(results[:k]):
-                out.append({
-                    "id": (doc.metadata or {}).get("id"),
-                    "content": doc.page_content,
-                    "metadata": doc.metadata,
-                    "rerank_score": doc.metadata.get("rerank_score"),
-                    "final_rank": i + 1,
-                })
-            return out
+            return [
+                self._to_result(doc, i + 1, rerank_score=doc.metadata.get("rerank_score"))
+                for i, doc in enumerate(results[:k])
+            ]
 
-        # No filter - use hybrid search
-        return self.hybrid_search(text, k=k, initial_k=initial_k)
+        # No filter -> _final_retriever (ensemble + reranker)
+        if initial_k:
+            self._vector_retriever.search_kwargs["k"] = initial_k
+            if self._bm25_retriever:
+                self._bm25_retriever.k = initial_k
+
+        results = self._final_retriever.invoke(text)
+        return [
+            self._to_result(doc, i + 1, rerank_score=doc.metadata.get("rerank_score"))
+            for i, doc in enumerate(results[:k])
+        ]
+
+    def flexible_search(
+        self,
+        text: str,
+        *,
+        mode: RetrievalMode | str = RetrievalMode.HYBRID_RERANK,
+        k: int | None = None,
+        initial_k: int | None = None,
+        where: Optional[Dict[str, Any]] = None,
+        auto_detect_cohort: bool = False,
+    ) -> List[Dict[str, Any]]:
+        if not text.strip():
+            return []
+
+        if isinstance(mode, str):
+            try:
+                mode = RetrievalMode(mode.lower())
+            except ValueError:
+                mode = RetrievalMode.HYBRID_RERANK
+
+        k = k or self._config.top_k
+        initial_k = initial_k or self._config.initial_k
+
+        # Auto-detect the cohort and build a filter if enabled
+        if auto_detect_cohort and where is None:
+            where = auto_filter_by_cohort(text)
+
+        if mode == RetrievalMode.VECTOR_ONLY:
+            return self.vector_search(text, k=k, where=where)
+        elif mode == RetrievalMode.BM25_ONLY:
+            return self.bm25_search(text, k=k)
+        elif mode == RetrievalMode.HYBRID:
+            if where:
+                return self.vector_search(text, k=k, where=where)
+            return self.hybrid_search(text, k=k, initial_k=initial_k)
+        else:  # HYBRID_RERANK
+            return self.search_with_rerank(text, k=k, where=where, initial_k=initial_k)
+
+    # Legacy alias
+    query = vector_search
+
+
+NGOAI_NGU_KEYWORDS = ["tiếng anh", "toeic", "ielts", "ngoại ngữ", "english", "chuẩn đầu ra"]
+
+
+def detect_cohort(text: str) -> Optional[str]:
+    patterns = [
+        r'\bK(\d{2})\b',
+        r'\bkhóa\s*(\d{2})\b',
+        r'\bkhoá\s*(\d{2})\b',
+    ]
+    for pattern in patterns:
+        match = re.search(pattern, text, re.IGNORECASE)
+        if match:
+            return f"K{match.group(1)}"
+    return None
+
+
+def cohort_to_filter(cohort: str) -> Optional[Dict[str, Any]]:
+    if not cohort:
+        return None
+    try:
+        num = int(cohort.replace("K", "").replace("k", ""))
+    except ValueError:
+        return None
+
+    if num >= 70:
+        return {"applicable_cohorts": ">=K70"}
+    elif num >= 68:
+        return {"applicable_cohorts": ">=K68"}
+    elif num >= 65:
+        return {"applicable_cohorts": ">=K65"}
+    return None
+
+
+def auto_filter_by_cohort(text: str) -> Optional[Dict[str, Any]]:
+    cohort = detect_cohort(text)
+    if cohort and any(kw in text.lower() for kw in NGOAI_NGU_KEYWORDS):
+        return cohort_to_filter(cohort)
+    return None
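The cohort helpers added at the bottom of `retrival.py` are pure functions, so their behavior is easy to check in isolation. A small sketch (the queries are illustrative; `retriever` is assumed to be a `Retriever` built over an existing `ChromaVectorDB`):

```python
from core.embeddings.retrival import detect_cohort, cohort_to_filter, auto_filter_by_cohort

assert detect_cohort("Chuẩn đầu ra tiếng Anh cho K70 là gì?") == "K70"
assert detect_cohort("Học phí một tín chỉ là bao nhiêu?") is None

# A detected cohort only becomes a metadata filter when the query also
# mentions a foreign-language keyword from NGOAI_NGU_KEYWORDS.
assert cohort_to_filter("K70") == {"applicable_cohorts": ">=K70"}
assert auto_filter_by_cohort("Chuẩn đầu ra tiếng Anh cho K70 là gì?") == {"applicable_cohorts": ">=K70"}
assert auto_filter_by_cohort("Sinh viên K70 học những môn gì?") is None  # no language keyword

# flexible_search coerces string modes and can apply the cohort filter itself:
hits = retriever.flexible_search(
    "Chuẩn đầu ra tiếng Anh cho K70 là gì?",
    mode="hybrid_rerank",      # coerced to RetrievalMode.HYBRID_RERANK
    auto_detect_cohort=True,   # injects {"applicable_cohorts": ">=K70"} here
)
```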
core/embeddings/vector_store.py CHANGED
@@ -58,22 +58,45 @@ class ChromaVectorDB:
             else:
                 out[str(k)] = str(v)
         return out
 
-    def _to_documents(self, docs: Sequence[Dict[str, Any]], ids: Sequence[str]) -> List[Document]:
+    def _normalize_doc(self, doc: Any) -> Dict[str, Any]:
+        # Already a dict
+        if isinstance(doc, dict):
+            return doc
+
+        # TextNode/BaseNode from llama_index
+        if hasattr(doc, "get_content") and hasattr(doc, "metadata"):
+            return {
+                "content": doc.get_content(),
+                "metadata": dict(doc.metadata) if doc.metadata else {},
+            }
+
+        # Document from langchain
+        if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
+            return {
+                "content": doc.page_content,
+                "metadata": dict(doc.metadata) if doc.metadata else {},
+            }
+
+        raise TypeError(f"Unsupported document type: {type(doc)}")
+
+    def _to_documents(self, docs: Sequence[Any], ids: Sequence[str]) -> List[Document]:
         out: List[Document] = []
         for d, doc_id in zip(docs, ids):
-            md = self._flatten_metadata(d.get("metadata", {}) or {})
+            normalized = self._normalize_doc(d)
+            md = self._flatten_metadata(normalized.get("metadata", {}) or {})
             md.setdefault("id", doc_id)
-            out.append(Document(page_content=d.get("content", ""), metadata=md))
+            out.append(Document(page_content=normalized.get("content", ""), metadata=md))
         return out
 
-    def _doc_id(self, doc: Dict[str, Any]) -> str:
-        md = doc.get("metadata") or {}
+    def _doc_id(self, doc: Any) -> str:
+        normalized = self._normalize_doc(doc)
+        md = normalized.get("metadata") or {}
         key = {
             "source_file": md.get("source_file"),
             "header_path": md.get("header_path"),
             "chunk_index": md.get("chunk_index"),
-            "content": doc.get("content"),
+            "content": normalized.get("content"),
         }
         return self._hasher.get_string_hash(str(key))
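A quick sketch of what `_normalize_doc` now accepts: plain dicts, langchain `Document`s, and llama_index nodes all collapse to the same `{"content", "metadata"}` shape. Constructor arguments for `ChromaVectorDB` are elided here and assumed to match the rest of the repo.

```python
from langchain_core.documents import Document

raw = {"content": "Điều 1 ...", "metadata": {"source_file": "quy_che.md"}}
lc_doc = Document(page_content="Điều 1 ...", metadata={"source_file": "quy_che.md"})

db = ChromaVectorDB(...)  # assumed: built with the same config as elsewhere in the repo
assert db._normalize_doc(raw) == db._normalize_doc(lc_doc)

# Anything without the duck-typed attributes is rejected loudly:
# db._normalize_doc(42)  -> TypeError: Unsupported document type: <class 'int'>
```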
core/gradio/gradio_rag_qwen.py CHANGED
@@ -29,10 +29,14 @@ def _load_env() -> None:
 
 from core.embeddings.embedding_model import EmbeddingConfig, QwenEmbeddings
 from core.embeddings.vector_store import ChromaConfig, ChromaVectorDB
-from core.embeddings.retrival import Retriever, get_retrieval_config
+from core.embeddings.retrival import Retriever, RetrievalMode, get_retrieval_config
+from core.embeddings.generator import RAGGenerator
 
 _load_env()
 
+RETRIEVAL_MODE = RetrievalMode.HYBRID_RERANK
+
+
 # Load all configs
 GRADIO_CFG = GradioConfig(
     llm_model="qwen/qwen3-32b",
@@ -48,6 +52,7 @@ class AppState:
     def __init__(self) -> None:
         self.db: Optional[ChromaVectorDB] = None
         self.retriever: Optional[Retriever] = None
+        self.generator: Optional[RAGGenerator] = None
         self.groq: Optional[Groq] = None
 
 
@@ -58,7 +63,8 @@ def _init_resources() -> None:
     if STATE.db is not None:
         return
 
-    print(" Đang khởi tạo Database & Re-ranker...")
+    print(f" Đang khởi tạo Database & Re-ranker...")
+    print(f" Retrieval Mode: {RETRIEVAL_MODE.value}")
 
     emb = QwenEmbeddings(EmbeddingConfig())
 
@@ -75,6 +81,15 @@
         raise RuntimeError("Missing GROQ_API_KEY")
     STATE.groq = Groq(api_key=api_key)
 
+    # Initialize RAGGenerator with shared retriever and groq client
+    STATE.generator = RAGGenerator(
+        retriever=STATE.retriever,
+        llm_model=GRADIO_CFG.llm_model,
+        temperature=GRADIO_CFG.llm_temperature,
+        max_tokens=GRADIO_CFG.llm_max_tokens,
+        groq_client=STATE.groq,
+    )
+
     print(" Đã sẵn sàng!")
 
 
@@ -83,92 +98,48 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
 
     assert STATE.db is not None
     assert STATE.groq is not None
-
-    # Vector Search + Re-ranking (use config values)
     assert STATE.retriever is not None
-    results = STATE.retriever.search_with_rerank(
-        message,
+    assert STATE.generator is not None
+
+    # Flexible search with auto_detect_cohort to filter by cohort automatically
+    results = STATE.retriever.flexible_search(
+        message,
+        mode=RETRIEVAL_MODE,
         k=RETRIEVAL_CFG.top_k,
-        initial_k=RETRIEVAL_CFG.initial_k
+        initial_k=RETRIEVAL_CFG.initial_k,
+        auto_detect_cohort=True,
     )
 
     if not results:
         yield "Xin lỗi, tôi không tìm thấy thông tin phù hợp trong dữ liệu."
         return
 
-    # Build context from results
-    structured_context = ""
-    for i, r in enumerate(results, 1):
-        md = r.get("metadata", {})
-        program = md.get("program_name", "")
-        doc_type = md.get("type", md.get("document_type", ""))
-        section = md.get("section", "")
-        source = md.get("source_file", "")
-        content = r.get("content", "").strip()
-        is_injected = r.get("_injected", False)
-
-        if is_injected or "hiệu lực" in section.lower() or "chuyển tiếp" in section.lower():
-            display_content = content
-        else:
-            display_content = content[:600]
-
-        structured_context += f"""
----
-[TÀI LIỆU {i}]{' [ĐIỀU KHOẢN HIỆU LỰC]' if is_injected else ''}
-- Phần/Điều: {section if section else 'N/A'}
-- Nguồn: {source if source else 'N/A'}
-{display_content}
-"""
-
-    max_context_chars = 5000
-    if len(structured_context) > max_context_chars:
-        structured_context = structured_context[:max_context_chars] + "\n\n[...truncated...]"
-
-    prompt = f"""Bạn là Trợ lý học vụ ĐHBK Hà Nội.
-
-## NGUYÊN TẮC:
-1. Chỉ trả lời dựa trên CONTEXT. Không bịa thông tin.
-2. Nếu thấy "Hiệu lực thi hành" hoặc "Điều khoản chuyển tiếp", KIỂM TRA xem có ngoại lệ theo khóa/thời gian không và GHI RÕ.
-3. Ưu tiên văn bản mới nhất, TRỪ KHI có điều khoản chuyển tiếp nói khác.
-4. Trích nguồn cuối câu trả lời.
-
-## CONTEXT:
-{structured_context}
-
-## CÂU HỎI: {message}
-
-## TRẢ LỜI:"""
-
-    completion = STATE.groq.chat.completions.create(
-        model=GRADIO_CFG.llm_model,
-        messages=[{"role": "user", "content": prompt}],
-        temperature=GRADIO_CFG.llm_temperature,
-        max_completion_tokens=GRADIO_CFG.llm_max_tokens,
-        stream=True,
-    )
-
+    # Use RAGGenerator for streaming response
     acc = ""
-    for chunk in completion:
-        try:
-            delta = chunk.choices[0].delta.content or ""
-        except Exception:
-            delta = ""
-        if not delta:
-            continue
-        acc += delta
+    for partial in STATE.generator.generate_stream_from_results(message, results):
+        acc = partial
         yield acc
 
-    # Debug info
-    debug_info = f"\n\n---\n\n**Retrieved (Top {len(results)} | Vector + Re-rank)**\n\n"
+    # Debug info with mode indicator
+    debug_info = f"\n\n---\n\n**Retrieved (Top {len(results)} | Mode: {RETRIEVAL_MODE.value})**\n\n"
     for i, r in enumerate(results, 1):
         md = r.get("metadata", {})
         content = r.get("content", "").strip()
-        rerank_score = r.get("rerank_score", 0)
-        vector_dist = r.get("vector_distance", 999.0)
+        rerank_score = r.get("rerank_score")
+        distance = r.get("distance")
         section = md.get("section", "N/A")
         doc_type = md.get("type", md.get("document_type", "N/A"))
 
-        debug_info += f"**#{i}** | Rerank: `{rerank_score:.4f}` | VecDist: `{vector_dist:.3f}`\n"
+        # Show relevant scores based on mode
+        score_info = ""
+        if rerank_score is not None:
+            score_info += f"Rerank: `{rerank_score:.4f}` "
+        if distance is not None:
+            score_info += f"Distance: `{distance:.4f}`"
+        if not score_info:
+            score_info = f"Rank: `{r.get('final_rank', i)}`"
+
+        debug_info += f"**#{i}** | {score_info}\n"
         debug_info += f"  - **Type:** {doc_type} | **Section:** {section[:60]}{'...' if len(section) > 60 else ''}\n"
         debug_info += f"  - **Content:** {content[:200]}{'...' if len(content) > 200 else ''}\n\n"
 
@@ -178,8 +149,8 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
 # Create Gradio interface
 demo = gr.ChatInterface(
     fn=rag_chat,
-    title="HUST RAG Assistant",
-    description="Trợ lý học vụ Đại học Bách khoa Hà Nội",
+    title=f"HUST RAG Assistant",
+    description=f"Trợ lý học vụ Đại học Bách khoa Hà Nội",
     examples=[
         "Điều kiện tốt nghiệp đại học là gì?",
         "Yêu cầu TOEIC của ngành Toán tin là bao nhiêu?",
@@ -188,6 +159,9 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
+    print(f"\n{'='*60}")
+    print(f"Starting HUST RAG Assistant")
+    print(f"{'='*60}\n")
     demo.launch(
         server_name=GRADIO_CFG.server_host,
         server_port=GRADIO_CFG.server_port
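Since `RETRIEVAL_MODE` is a hard-coded module constant in this commit, switching modes means editing the file. A hypothetical variant that reads it from an environment variable instead (the `RETRIEVAL_MODE` env name and this snippet are assumptions, not part of the commit):

```python
import os
from core.embeddings.retrival import RetrievalMode

# Hypothetical: fall back to HYBRID_RERANK for unset/unknown values,
# mirroring the string-coercion logic inside flexible_search().
try:
    RETRIEVAL_MODE = RetrievalMode(os.getenv("RETRIEVAL_MODE", "hybrid_rerank").lower())
except ValueError:
    RETRIEVAL_MODE = RetrievalMode.HYBRID_RERANK
```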
scripts/rag.py CHANGED
@@ -33,7 +33,7 @@ def main():
     args = parser.parse_args()
 
     print("=" * 60)
-    print("REBUILD HUST RAG DATABASE")
+    print("BUILD HUST RAG DATABASE")
     print("=" * 60)
 
     print("\n[1/4] Initializing embedder...")
@@ -93,18 +93,24 @@ def main():
     print("TESTING QUERY")
     print("=" * 60)
 
-    from core.embeddings.retrival import Retriever
+    from core.embeddings.retrival import Retriever, RetrievalMode
+
+    # Test with VECTOR_ONLY mode
+    test_mode = RetrievalMode.VECTOR_ONLY
     retriever = Retriever(vector_db=db, use_reranker=False)
 
     test_query = "Yêu cầu TOEIC của ngành Toán tin là bao nhiêu?"
     print(f"Query: {test_query}")
-    results = retriever.query(test_query, k=3)
+    print(f"Mode: {test_mode.value}")
+
+    results = retriever.flexible_search(test_query, mode=test_mode, k=3)
 
     if results:
         print(f"\nTop {len(results)} results:")
         for i, r in enumerate(results, 1):
-            print(f"\n[{i}] Distance: {r['distance']:.4f}")
-            print(f"    Program: {r['metadata'].get('program_name', 'N/A')}")
+            score = r.get('distance') or r.get('rerank_score') or r.get('final_rank')
+            print(f"\n[{i}] Score: {score}")
+            print(f"    Source: {r['metadata'].get('source_file', 'N/A')}")
             print(f"    Section: {r['metadata'].get('section', 'N/A')}")
             print(f"    Content: {r['content'][:150]}...")
     else:
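One caveat in the new score line: `r.get('distance') or r.get('rerank_score') or r.get('final_rank')` treats a legitimate distance of `0.0` as falsy and falls through to the next key. A sketch of a stricter lookup, if exact zeros matter:

```python
from typing import Any, Dict, Optional

def first_score(r: Dict[str, Any]) -> Optional[Any]:
    # Explicit None checks, so a 0.0 distance or rerank score is not skipped.
    for key in ("distance", "rerank_score", "final_rank"):
        if r.get(key) is not None:
            return r[key]
    return None
```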
test/test_chunk.py CHANGED
@@ -3,7 +3,7 @@ sys.path.insert(0, "/home/bahung/DoAn")
 
 from core.embeddings.chunk import chunk_markdown_file
 
-test_file = "data/data_process/chuong_trinh_dao_tao/1.1. Kỹ thuật điện tử.md"
+test_file = "data/data_process/quyet_dinh/tieng_anh/06_ Quy định ngoại ngữ từ K70_chính quy_final.md"
 
 print("=" * 70)
 print(f" File: {test_file}")
test_chunk.md CHANGED
The diff for this file is too large to render. See raw diff