minh-4T committed on
Commit
628deed
·
1 Parent(s): 23f0b25

name khoa : supabase -> qdrant

Browse files
core/collection_router_retriever.py CHANGED
@@ -4,7 +4,7 @@ from typing import List
4
 
5
  from langchain_core.documents import Document as LangChainDocument
6
 
7
- from .collection_utils import collection_matches_year
8
  from .document_db import SessionLocal, list_active_collection_names
9
 
10
  logger = logging.getLogger(__name__)
@@ -45,18 +45,18 @@ class CollectionRouterRetriever:
45
  finally:
46
  db.close()
47
 
48
- def _select_target_collections(self, year_scope: str | None) -> List[str]:
49
  fetch_limit = max(self.top_n_collections * 4, 12)
50
  active_collections = self._get_active_collections(limit=fetch_limit)
51
  if not active_collections:
52
  return []
53
 
54
- normalized_year_scope = (year_scope or "").strip()
55
- if normalized_year_scope:
56
  return [
57
  collection_name
58
  for collection_name in active_collections
59
- if collection_matches_year(collection_name, normalized_year_scope)
60
  ]
61
 
62
  return active_collections[: self.top_n_collections]
@@ -111,15 +111,15 @@ class CollectionRouterRetriever:
111
  scored_docs.sort(key=lambda row: row[0], reverse=True)
112
  return [doc for _, doc in scored_docs]
113
 
114
- def search(self, query: str, k: int = 10, alpha: float = 0.6, year_scope: str | None = None) -> List:
115
  if k <= 0:
116
  return []
117
 
118
  candidate_k = max(k * 4, k)
119
- year_scoped = bool((year_scope or "").strip())
120
- target_collections = self._select_target_collections(year_scope)
121
 
122
- if year_scoped and not target_collections:
123
  return []
124
 
125
  routed_docs = self._search_target_collections(
@@ -128,7 +128,7 @@ class CollectionRouterRetriever:
128
  limit=candidate_k,
129
  )
130
 
131
- if year_scoped:
132
  deduplicated = []
133
  seen = set()
134
  for doc in routed_docs:
@@ -148,7 +148,7 @@ class CollectionRouterRetriever:
148
  query,
149
  k=candidate_k,
150
  alpha=alpha,
151
- year_scope=year_scope,
152
  )
153
  except TypeError:
154
  fallback_docs = self.base_retriever.search(
 
4
 
5
  from langchain_core.documents import Document as LangChainDocument
6
 
7
+ from .collection_utils import collection_matches_cohort
8
  from .document_db import SessionLocal, list_active_collection_names
9
 
10
  logger = logging.getLogger(__name__)
 
45
  finally:
46
  db.close()
47
 
48
def _select_target_collections(self, cohort_key: str | None) -> List[str]:
    """Choose the collections a search should be routed to.

    Active collection names are requested from ``_get_active_collections``
    with ``limit=max(top_n_collections * 4, 12)``.

    Args:
        cohort_key: Cohort identifier; ``None`` or blank means "no cohort".

    Returns:
        An empty list when there are no active collections. With a non-blank
        ``cohort_key``, every fetched collection for which
        ``collection_matches_cohort`` is true. Otherwise the first
        ``top_n_collections`` fetched names.
    """
    candidates = self._get_active_collections(
        limit=max(self.top_n_collections * 4, 12)
    )
    if not candidates:
        return []

    cohort = (cohort_key or "").strip()
    if not cohort:
        # No cohort requested: fall back to the leading top-N collections.
        return candidates[: self.top_n_collections]

    return [name for name in candidates if collection_matches_cohort(name, cohort)]
 
111
  scored_docs.sort(key=lambda row: row[0], reverse=True)
112
  return [doc for _, doc in scored_docs]
113
 
114
+ def search(self, query: str, k: int = 10, alpha: float = 0.6, cohort_key: str | None = None) -> List:
115
  if k <= 0:
116
  return []
117
 
118
  candidate_k = max(k * 4, k)
119
+ cohort_scoped = bool((cohort_key or "").strip())
120
+ target_collections = self._select_target_collections(cohort_key)
121
 
122
+ if cohort_scoped and not target_collections:
123
  return []
124
 
125
  routed_docs = self._search_target_collections(
 
128
  limit=candidate_k,
129
  )
130
 
131
+ if cohort_scoped:
132
  deduplicated = []
133
  seen = set()
134
  for doc in routed_docs:
 
148
  query,
149
  k=candidate_k,
150
  alpha=alpha,
151
+ cohort_key=cohort_key,
152
  )
153
  except TypeError:
154
  fallback_docs = self.base_retriever.search(
core/collection_utils.py CHANGED
@@ -24,6 +24,37 @@ def extract_year_tokens(value: str) -> Set[str]:
24
  return {token for token in _YEAR_PATTERN.findall(value or "")}
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def collection_matches_year(collection_name: str, year_scope: str) -> bool:
28
  if not year_scope:
29
  return False
 
24
  return {token for token in _YEAR_PATTERN.findall(value or "")}
25
 
26
 
27
+ def extract_folder_key_from_collection_name(collection_name: str, prefix: str = "rag") -> str | None:
28
+ """
29
+ Extract folder_key from collection name.
30
+ E.g., 'rag_k63' -> 'k63', 'rag_2023_2024' -> '2023_2024'
31
+ Returns None if collection_name doesn't match the expected pattern.
32
+ """
33
+ if not collection_name:
34
+ return None
35
+
36
+ prefix_with_underscore = f"{prefix}_"
37
+ if collection_name.startswith(prefix_with_underscore):
38
+ return collection_name[len(prefix_with_underscore):]
39
+
40
+ return None
41
+
42
+
43
def collection_matches_cohort(collection_name: str, cohort_key: str, prefix: str = "rag") -> bool:
    """Report whether a collection belongs to the given cohort.

    The key embedded in the collection name (the part after ``<prefix>_``)
    is compared with ``cohort_key`` ignoring letter case; whitespace around
    ``cohort_key`` is ignored as well.
    E.g. collection='rag_k63', cohort_key='k63' -> True.

    Args:
        collection_name: Name of the collection to test.
        cohort_key: Cohort identifier to look for.
        prefix: Leading segment expected in the collection name.

    Returns:
        ``True`` on a match; ``False`` when ``cohort_key`` is empty or blank,
        when no key can be extracted from ``collection_name``, or when the
        keys differ.
    """
    # Normalize once so a padded key matches like its trimmed value and a
    # whitespace-only key is treated as "no cohort".
    wanted = (cohort_key or "").strip().lower()
    if not wanted:
        return False

    folder_key = extract_folder_key_from_collection_name(collection_name, prefix)
    if not folder_key:
        return False

    return folder_key.lower() == wanted
56
+
57
+
58
  def collection_matches_year(collection_name: str, year_scope: str) -> bool:
59
  if not year_scope:
60
  return False
core/config.py CHANGED
@@ -67,13 +67,6 @@ SUPABASE_SYNC_ALLOWED_IPS = [ip.strip() for ip in os.getenv('SUPABASE_SYNC_ALLOW
67
  SUPABASE_SYNC_ALLOW_PRIVATE_NETWORK = os.getenv('SUPABASE_SYNC_ALLOW_PRIVATE_NETWORK', 'true').strip().lower() in {'1', 'true', 'yes', 'on'}
68
  COLLECTION_ROUTER_TOP_N = _bounded_int_from_env('COLLECTION_ROUTER_TOP_N', 3, 1, 20)
69
 
70
- # Cohort to academic year mapping
71
- COHORT_TO_YEAR = {
72
- 'k65': '2023-2024',
73
- 'k64': '2022-2023',
74
- 'k63': '2021-2022',
75
- }
76
-
77
  # - Context and output limits
78
  MAX_CONTEXT_CHARS = int(os.getenv('MAX_CONTEXT_CHARS', '12000'))
79
  MAX_OUT_CHARS = int(os.getenv('MAX_OUT_CHARS', '3000'))
 
67
  SUPABASE_SYNC_ALLOW_PRIVATE_NETWORK = os.getenv('SUPABASE_SYNC_ALLOW_PRIVATE_NETWORK', 'true').strip().lower() in {'1', 'true', 'yes', 'on'}
68
  COLLECTION_ROUTER_TOP_N = _bounded_int_from_env('COLLECTION_ROUTER_TOP_N', 3, 1, 20)
69
 
 
 
 
 
 
 
 
70
  # - Context and output limits
71
  MAX_CONTEXT_CHARS = int(os.getenv('MAX_CONTEXT_CHARS', '12000'))
72
  MAX_OUT_CHARS = int(os.getenv('MAX_OUT_CHARS', '3000'))
core/qa_pipeline.py CHANGED
@@ -221,16 +221,16 @@ def generate_standalone_query(message: str, history: List) -> str:
221
 
222
  return message
223
 
224
- def ask_ai_improved(message: str, history: List, hybrid_retriever, year_scope: str | None = None) -> Generator[str, None, None]:
225
  full_response = ""
226
- for delta in ask_ai_stream_delta(message, history, hybrid_retriever, year_scope=year_scope):
227
  full_response += delta
228
  if len(full_response) > MAX_OUT_CHARS:
229
  yield full_response[:MAX_OUT_CHARS] + "\n\n[Đã cắt bớt nội dung dài]"
230
  return
231
  yield full_response
232
 
233
- def ask_ai_stream_delta(message: str, history: List, hybrid_retriever, year_scope: str | None = None) -> Generator[str, None, None]:
234
  if not message.strip():
235
  yield " Bạn chưa nhập câu hỏi."
236
  return
@@ -241,12 +241,6 @@ def ask_ai_stream_delta(message: str, history: List, hybrid_retriever, year_scop
241
 
242
  logger.info(f" CÂU HỎI GỐC: {message}")
243
  question = generate_standalone_query(message, history)
244
- # [YEAR-AWARE CHANGE] Xac dinh pham vi nam ma nguoi dung yeu cau.
245
- requested_year_range, mentioned_years = detect_requested_year(f"{message}\n{question}")
246
- if requested_year_range:
247
- logger.info(f"Lọc theo năm học yêu cầu: {requested_year_range}")
248
- elif mentioned_years:
249
- logger.info(f"Lọc theo năm được nhắc tới: {sorted(mentioned_years)}")
250
 
251
  processed_data = analyze_and_expand_query(question)
252
 
@@ -261,12 +255,9 @@ def ask_ai_stream_delta(message: str, history: List, hybrid_retriever, year_scop
261
 
262
  all_docs: List = []
263
  seen = set()
264
- # Prefer passed year_scope over detected year
265
- if year_scope:
266
- year_scope_hint = year_scope
267
- logger.info(f"Sử dụng year_scope từ cohort: {year_scope_hint}")
268
- else:
269
- year_scope_hint = requested_year_range or (", ".join(sorted(mentioned_years)) if mentioned_years else None)
270
  for query in queries:
271
  #Giữ nguyên logic alpha ngành CNTT của Minh
272
  current_alpha = 0.4 if "CNTT" in query.upper() else 0.5
@@ -274,7 +265,7 @@ def ask_ai_stream_delta(message: str, history: List, hybrid_retriever, year_scop
274
  query,
275
  k=TOP_K_RESULTS,
276
  alpha=current_alpha,
277
- year_scope=year_scope_hint,
278
  )
279
  for doc in docs:
280
  content_hash = hashlib.sha256(doc.page_content.encode("utf-8")).hexdigest()
@@ -287,23 +278,6 @@ def ask_ai_stream_delta(message: str, history: List, hybrid_retriever, year_scop
287
  yield "Không tìm thấy thông tin liên quan trong tài liệu."
288
  return
289
 
290
- # [YEAR-AWARE CHANGE] Lọc theo năm nhưng vẫn fallback nếu không có tài liệu đúng năm.
291
- year_scope = None
292
- year_filter_requested = bool(requested_year_range or mentioned_years)
293
- year_filtered_docs = filter_docs_by_year(all_docs, requested_year_range, mentioned_years)
294
-
295
- if year_filter_requested:
296
- if year_filtered_docs:
297
- if len(year_filtered_docs) != len(all_docs):
298
- logger.info(f"Đã lọc theo năm: còn {len(year_filtered_docs)}/{len(all_docs)} documents")
299
- all_docs = year_filtered_docs
300
- if requested_year_range:
301
- year_scope = requested_year_range
302
- elif mentioned_years:
303
- year_scope = ", ".join(sorted(mentioned_years))
304
- else:
305
- logger.warning("Không tìm thấy tài liệu đúng năm yêu cầu, fallback sang tập tài liệu tổng quát")
306
-
307
  final_docs = advanced_rerank(question, all_docs, top_k=FINAL_TOP_K)
308
 
309
  context_parts = []
@@ -311,10 +285,7 @@ def ask_ai_stream_delta(message: str, history: List, hybrid_retriever, year_scop
311
  for doc in final_docs:
312
  page = doc.metadata.get('page_number', 'N/A')
313
  file_name = doc.metadata.get('source_file') or doc.metadata.get('source')
314
- # [YEAR-AWARE CHANGE] Gan nhan nam trong context de LLM bam dung nguon.
315
- doc_year = infer_doc_academic_year(doc)
316
- year_label = f"Năm {doc_year}" if doc_year != "ALL" else "Áp dụng nhiều năm"
317
- source = f"[{year_label} | {os.path.basename(file_name)} | Trang {page}]" if file_name else f"[{year_label} | Trang {page}]"
318
  block = f"{source}\n{doc.page_content}"
319
  if total_chars + len(block) > MAX_CONTEXT_CHARS:
320
  break
@@ -324,7 +295,7 @@ def ask_ai_stream_delta(message: str, history: List, hybrid_retriever, year_scop
324
  context = "\n\n---\n\n".join(context_parts)
325
  topic_hint = processed_data.get('topic') or processed_data.get('root_question') or question
326
 
327
- prompt = create_advanced_prompt(question, context, question_type, topic_hint, year_scope=year_scope)
328
 
329
  logger.info("Đang tạo câu trả lời cuối cùng ...")
330
 
 
221
 
222
  return message
223
 
224
# Generator wrapper around ask_ai_stream_delta: each delta is appended to
# full_response. When the accumulated text exceeds MAX_OUT_CHARS, the first
# MAX_OUT_CHARS characters plus a truncation notice are yielded and the
# generator stops.
# NOTE(review): indentation is lost in this diff view, so it cannot be told
# here whether the final `yield full_response` runs once per delta or once
# after the loop - confirm against the real file.
+ def ask_ai_improved(message: str, history: List, hybrid_retriever, cohort_key: str | None = None) -> Generator[str, None, None]:
225
  full_response = ""
226
+ for delta in ask_ai_stream_delta(message, history, hybrid_retriever, cohort_key=cohort_key):
227
  full_response += delta
228
  if len(full_response) > MAX_OUT_CHARS:
229
  yield full_response[:MAX_OUT_CHARS] + "\n\n[Đã cắt bớt nội dung dài]"
230
  return
231
  yield full_response
232
 
233
+ def ask_ai_stream_delta(message: str, history: List, hybrid_retriever, cohort_key: str | None = None) -> Generator[str, None, None]:
234
  if not message.strip():
235
  yield " Bạn chưa nhập câu hỏi."
236
  return
 
241
 
242
  logger.info(f" CÂU HỎI GỐC: {message}")
243
  question = generate_standalone_query(message, history)
 
 
 
 
 
 
244
 
245
  processed_data = analyze_and_expand_query(question)
246
 
 
255
 
256
  all_docs: List = []
257
  seen = set()
258
+ if cohort_key:
259
+ logger.info(f"Sử dụng cohort_key: {cohort_key}")
260
+
 
 
 
261
  for query in queries:
262
  #Giữ nguyên logic alpha ngành CNTT của Minh
263
  current_alpha = 0.4 if "CNTT" in query.upper() else 0.5
 
265
  query,
266
  k=TOP_K_RESULTS,
267
  alpha=current_alpha,
268
+ cohort_key=cohort_key,
269
  )
270
  for doc in docs:
271
  content_hash = hashlib.sha256(doc.page_content.encode("utf-8")).hexdigest()
 
278
  yield "Không tìm thấy thông tin liên quan trong tài liệu."
279
  return
280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  final_docs = advanced_rerank(question, all_docs, top_k=FINAL_TOP_K)
282
 
283
  context_parts = []
 
285
  for doc in final_docs:
286
  page = doc.metadata.get('page_number', 'N/A')
287
  file_name = doc.metadata.get('source_file') or doc.metadata.get('source')
288
+ source = f"[{os.path.basename(file_name)} | Trang {page}]" if file_name else f"[Trang {page}]"
 
 
 
289
  block = f"{source}\n{doc.page_content}"
290
  if total_chars + len(block) > MAX_CONTEXT_CHARS:
291
  break
 
295
  context = "\n\n---\n\n".join(context_parts)
296
  topic_hint = processed_data.get('topic') or processed_data.get('root_question') or question
297
 
298
+ prompt = create_advanced_prompt(question, context, question_type, topic_hint)
299
 
300
  logger.info("Đang tạo câu trả lời cuối cùng ...")
301
 
main.py CHANGED
@@ -14,7 +14,6 @@ from qdrant_client import QdrantClient
14
  #Import các model và các hàm cần thiết từ core
15
  from core.config import (
16
  COLLECTION_ROUTER_TOP_N,
17
- COHORT_TO_YEAR,
18
  DATABASE_URL,
19
  QDRANT_API_KEY,
20
  QDRANT_URL,
@@ -344,20 +343,15 @@ async def chat_endpoint(payload: ChatRequest, request: Request):
344
  user_id = payload.user_id # Lấy user_id từ request
345
  cohort_key = payload.cohort_key # Lấy cohort_key từ request
346
 
347
- # Convert cohort_key to year_scope for collection routing
348
- year_scope = None
349
- if cohort_key and cohort_key in COHORT_TO_YEAR:
350
- year_scope = COHORT_TO_YEAR[cohort_key]
351
- logger.info(f"Sử dụng cohort: {cohort_key} -> năm học: {year_scope}")
352
- elif cohort_key:
353
- logger.warning(f"Cohort không hợp lệ: {cohort_key}")
354
 
355
  history = await get_history_async(db_pool, session_id)
356
 
357
  # Tập hợp toàn bộ response từ generator
358
  full_response = ""
359
  try:
360
- async for chunk in iterate_in_threadpool(ask_ai_improved(user_msg, history, retriever, year_scope=year_scope)):
361
  full_response = chunk
362
  except Exception:
363
  logger.exception("Lỗi khi xử lý phản hồi từ AI:", exc_info=True)
@@ -381,13 +375,8 @@ async def chat_stream_endpoint(payload: ChatRequest, request: Request):
381
  user_id = payload.user_id # Lấy user_id từ request
382
  cohort_key = payload.cohort_key # Lấy cohort_key từ request
383
 
384
- # Convert cohort_key to year_scope for collection routing
385
- year_scope = None
386
- if cohort_key and cohort_key in COHORT_TO_YEAR:
387
- year_scope = COHORT_TO_YEAR[cohort_key]
388
- logger.info(f"Sử dụng cohort: {cohort_key} -> năm học: {year_scope}")
389
- elif cohort_key:
390
- logger.warning(f"Cohort không hợp lệ: {cohort_key}")
391
 
392
  history = await get_history_async(db_pool, session_id)
393
 
@@ -396,7 +385,7 @@ async def chat_stream_endpoint(payload: ChatRequest, request: Request):
396
  full_response = ""
397
  try:
398
  # ask_ai_stream_delta yield từng delta chunk (không cumulative)
399
- async for delta_chunk in iterate_in_threadpool(ask_ai_stream_delta(user_msg, history, retriever, year_scope=year_scope)):
400
  full_response += delta_chunk
401
  # Gửi SSE event với delta chunk
402
  sse_data = json.dumps({"delta": delta_chunk, "done": False}, ensure_ascii=False)
 
14
  #Import các model và các hàm cần thiết từ core
15
  from core.config import (
16
  COLLECTION_ROUTER_TOP_N,
 
17
  DATABASE_URL,
18
  QDRANT_API_KEY,
19
  QDRANT_URL,
 
343
  user_id = payload.user_id # Lấy user_id từ request
344
  cohort_key = payload.cohort_key # Lấy cohort_key từ request
345
 
346
+ if cohort_key:
347
+ logger.info(f"Sử dụng cohort: {cohort_key}")
 
 
 
 
 
348
 
349
  history = await get_history_async(db_pool, session_id)
350
 
351
  # Tập hợp toàn bộ response từ generator
352
  full_response = ""
353
  try:
354
+ async for chunk in iterate_in_threadpool(ask_ai_improved(user_msg, history, retriever, cohort_key=cohort_key)):
355
  full_response = chunk
356
  except Exception:
357
  logger.exception("Lỗi khi xử lý phản hồi từ AI:", exc_info=True)
 
375
  user_id = payload.user_id # Lấy user_id từ request
376
  cohort_key = payload.cohort_key # Lấy cohort_key từ request
377
 
378
+ if cohort_key:
379
+ logger.info(f"Sử dụng cohort: {cohort_key}")
 
 
 
 
 
380
 
381
  history = await get_history_async(db_pool, session_id)
382
 
 
385
  full_response = ""
386
  try:
387
  # ask_ai_stream_delta yield từng delta chunk (không cumulative)
388
+ async for delta_chunk in iterate_in_threadpool(ask_ai_stream_delta(user_msg, history, retriever, cohort_key=cohort_key)):
389
  full_response += delta_chunk
390
  # Gửi SSE event với delta chunk
391
  sse_data = json.dumps({"delta": delta_chunk, "done": False}, ensure_ascii=False)