Spaces:
Sleeping
Sleeping
RollBack
Browse files- core/analyze_and_expand.py +2 -6
- core/chunking.py +5 -5
- core/collection_router_retriever.py +4 -56
- core/collection_utils.py +4 -7
- core/config.py +4 -4
- core/document_ingest_service.py +41 -224
- core/prompting.py +8 -6
- core/qa_pipeline.py +54 -145
- core/rerank.py +6 -13
- core/retriever.py +2 -60
- core/text_utils.py +7 -17
- main.py +0 -4
core/analyze_and_expand.py
CHANGED
|
@@ -126,15 +126,11 @@ def analyze_and_expand_query(question: str) -> Dict[str, Any]:
|
|
| 126 |
"expanded_queries": queries
|
| 127 |
}
|
| 128 |
|
| 129 |
-
|
| 130 |
-
"Phân loại: %s | Số truy vấn: %s",
|
| 131 |
-
final_result["question_type"],
|
| 132 |
-
len(final_result["expanded_queries"]),
|
| 133 |
-
)
|
| 134 |
return final_result
|
| 135 |
|
| 136 |
except Exception as e:
|
| 137 |
-
|
| 138 |
return {
|
| 139 |
"question_type": "simple",
|
| 140 |
"answer": None,
|
|
|
|
| 126 |
"expanded_queries": queries
|
| 127 |
}
|
| 128 |
|
| 129 |
+
print(f"Phân loại: {final_result['question_type']} | Queries: {len(final_result['expanded_queries'])}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
return final_result
|
| 131 |
|
| 132 |
except Exception as e:
|
| 133 |
+
print(f" Lỗi phân tích ({e}). Mặc định chuyển sang tìm kiếm.")
|
| 134 |
return {
|
| 135 |
"question_type": "simple",
|
| 136 |
"answer": None,
|
core/chunking.py
CHANGED
|
@@ -28,7 +28,7 @@ LIST_PATTERNS = [
|
|
| 28 |
(r"(?m)^\s*•\s+", "<LIST_BULLET>"),
|
| 29 |
]
|
| 30 |
|
| 31 |
-
|
| 32 |
def extract_and_protect_tables(text: str) -> Tuple[str, dict]:
|
| 33 |
table_pattern = re.compile(r"(?:\|.*\|[\r\n]+)+")
|
| 34 |
tables = {}
|
|
@@ -41,7 +41,7 @@ def extract_and_protect_tables(text: str) -> Tuple[str, dict]:
|
|
| 41 |
protected_text = re.sub(table_pattern, replace_table, text)
|
| 42 |
return protected_text, tables
|
| 43 |
|
| 44 |
-
|
| 45 |
def protect_lists(text: str) -> Tuple[str, dict]:
|
| 46 |
placeholders = {}
|
| 47 |
protected = text
|
|
@@ -55,14 +55,14 @@ def protect_lists(text: str) -> Tuple[str, dict]:
|
|
| 55 |
|
| 56 |
return protected, placeholders
|
| 57 |
|
| 58 |
-
|
| 59 |
def restore_placeholders(text: str, placeholders: dict) -> str:
|
| 60 |
restored = text
|
| 61 |
for placeholder, original in placeholders.items():
|
| 62 |
restored = restored.replace(placeholder, original)
|
| 63 |
return restored
|
| 64 |
|
| 65 |
-
|
| 66 |
def split_by_structure(text: str) -> List[str]:
|
| 67 |
parts = [text]
|
| 68 |
|
|
@@ -91,7 +91,7 @@ def split_by_structure(text: str) -> List[str]:
|
|
| 91 |
|
| 92 |
return [part for part in parts if part.strip()]
|
| 93 |
|
| 94 |
-
|
| 95 |
def smart_chunking(docs: List) -> List:
|
| 96 |
logger.info("Chunking theo cau truc + do dai...")
|
| 97 |
length_splitter = RecursiveCharacterTextSplitter(
|
|
|
|
| 28 |
(r"(?m)^\s*•\s+", "<LIST_BULLET>"),
|
| 29 |
]
|
| 30 |
|
| 31 |
+
|
| 32 |
def extract_and_protect_tables(text: str) -> Tuple[str, dict]:
|
| 33 |
table_pattern = re.compile(r"(?:\|.*\|[\r\n]+)+")
|
| 34 |
tables = {}
|
|
|
|
| 41 |
protected_text = re.sub(table_pattern, replace_table, text)
|
| 42 |
return protected_text, tables
|
| 43 |
|
| 44 |
+
|
| 45 |
def protect_lists(text: str) -> Tuple[str, dict]:
|
| 46 |
placeholders = {}
|
| 47 |
protected = text
|
|
|
|
| 55 |
|
| 56 |
return protected, placeholders
|
| 57 |
|
| 58 |
+
|
| 59 |
def restore_placeholders(text: str, placeholders: dict) -> str:
|
| 60 |
restored = text
|
| 61 |
for placeholder, original in placeholders.items():
|
| 62 |
restored = restored.replace(placeholder, original)
|
| 63 |
return restored
|
| 64 |
|
| 65 |
+
|
| 66 |
def split_by_structure(text: str) -> List[str]:
|
| 67 |
parts = [text]
|
| 68 |
|
|
|
|
| 91 |
|
| 92 |
return [part for part in parts if part.strip()]
|
| 93 |
|
| 94 |
+
|
| 95 |
def smart_chunking(docs: List) -> List:
|
| 96 |
logger.info("Chunking theo cau truc + do dai...")
|
| 97 |
length_splitter = RecursiveCharacterTextSplitter(
|
core/collection_router_retriever.py
CHANGED
|
@@ -1,10 +1,8 @@
|
|
| 1 |
import hashlib
|
| 2 |
import logging
|
| 3 |
-
import
|
| 4 |
-
from typing import List, Optional
|
| 5 |
|
| 6 |
from langchain_core.documents import Document as LangChainDocument
|
| 7 |
-
from qdrant_client.models import Filter, FieldCondition, HasIdCondition, MatchAny
|
| 8 |
|
| 9 |
from .collection_utils import collection_matches_year
|
| 10 |
from .document_db import SessionLocal, list_active_collection_names
|
|
@@ -12,47 +10,6 @@ from .document_db import SessionLocal, list_active_collection_names
|
|
| 12 |
logger = logging.getLogger(__name__)
|
| 13 |
|
| 14 |
|
| 15 |
-
def _build_year_filter(year_scope: Optional[str]) -> Optional[Filter]:
|
| 16 |
-
"""Tạo Qdrant Filter từ year_scope (ví dụ: '2023-2024' hoặc '2023')."""
|
| 17 |
-
if not year_scope:
|
| 18 |
-
return None
|
| 19 |
-
|
| 20 |
-
year_targets = []
|
| 21 |
-
year_scope = year_scope.strip()
|
| 22 |
-
|
| 23 |
-
# Parse year_scope: có thể là "2023-2024" hoặc "2023"
|
| 24 |
-
if "-" in year_scope:
|
| 25 |
-
parts = year_scope.split("-")
|
| 26 |
-
for p in parts:
|
| 27 |
-
try:
|
| 28 |
-
year_targets.append(int(p.strip()))
|
| 29 |
-
except ValueError:
|
| 30 |
-
pass
|
| 31 |
-
else:
|
| 32 |
-
try:
|
| 33 |
-
year_targets.append(int(year_scope))
|
| 34 |
-
except ValueError:
|
| 35 |
-
pass
|
| 36 |
-
|
| 37 |
-
if not year_targets:
|
| 38 |
-
return None
|
| 39 |
-
|
| 40 |
-
# Sử dụng MatchAny để filter theo danh sách years
|
| 41 |
-
from qdrant_client.models import HasIdCondition as QdrantHasId
|
| 42 |
-
try:
|
| 43 |
-
return Filter(
|
| 44 |
-
must=[
|
| 45 |
-
FieldCondition(
|
| 46 |
-
key="years",
|
| 47 |
-
match=MatchAny(any=year_targets),
|
| 48 |
-
)
|
| 49 |
-
]
|
| 50 |
-
)
|
| 51 |
-
except Exception:
|
| 52 |
-
# Fallback nếu MatchAny không work
|
| 53 |
-
return None
|
| 54 |
-
|
| 55 |
-
|
| 56 |
class CollectionRouterRetriever:
|
| 57 |
def __init__(
|
| 58 |
self,
|
|
@@ -104,7 +61,7 @@ class CollectionRouterRetriever:
|
|
| 104 |
|
| 105 |
return active_collections[: self.top_n_collections]
|
| 106 |
|
| 107 |
-
def _search_target_collections(self, query: str, collections: List[str], limit: int
|
| 108 |
if not collections:
|
| 109 |
return []
|
| 110 |
|
|
@@ -114,11 +71,6 @@ class CollectionRouterRetriever:
|
|
| 114 |
logger.exception("Failed to embed query for collection routing")
|
| 115 |
return []
|
| 116 |
|
| 117 |
-
# Tạo filter Qdrant nếu có year_scope
|
| 118 |
-
year_filter = _build_year_filter(year_scope)
|
| 119 |
-
if year_filter:
|
| 120 |
-
logger.info(f"Áp dụng Qdrant Filter cho year_scope: {year_scope}")
|
| 121 |
-
|
| 122 |
scored_docs = []
|
| 123 |
for collection_name in collections:
|
| 124 |
try:
|
|
@@ -127,10 +79,9 @@ class CollectionRouterRetriever:
|
|
| 127 |
query_vector=query_vector,
|
| 128 |
limit=limit,
|
| 129 |
with_payload=True,
|
| 130 |
-
query_filter=year_filter, # NEW: Áp dụng Qdrant Filter native
|
| 131 |
)
|
| 132 |
-
except Exception
|
| 133 |
-
logger.exception(
|
| 134 |
continue
|
| 135 |
|
| 136 |
for point in points:
|
|
@@ -144,11 +95,9 @@ class CollectionRouterRetriever:
|
|
| 144 |
"source_file": payload.get("filename") or payload.get("stored_name") or "",
|
| 145 |
"source_relpath": payload.get("object_path") or payload.get("path") or "",
|
| 146 |
"object_path": payload.get("object_path") or "",
|
| 147 |
-
"source_url": payload.get("source_url") or "", # NEW: Thêm source_url
|
| 148 |
"folder_key": payload.get("folder_key") or "",
|
| 149 |
"collection_name": collection_name,
|
| 150 |
"academic_year": payload.get("academic_year") or "",
|
| 151 |
-
"years": payload.get("years") or [], # NEW: Thêm years array
|
| 152 |
"chunk_index": payload.get("chunk_index"),
|
| 153 |
"page_number": payload.get("page_number"),
|
| 154 |
}
|
|
@@ -177,7 +126,6 @@ class CollectionRouterRetriever:
|
|
| 177 |
query=query,
|
| 178 |
collections=target_collections,
|
| 179 |
limit=candidate_k,
|
| 180 |
-
year_scope=year_scope, # NEW: Pass year_scope để Qdrant Filter
|
| 181 |
)
|
| 182 |
|
| 183 |
if year_scoped:
|
|
|
|
| 1 |
import hashlib
|
| 2 |
import logging
|
| 3 |
+
from typing import List
|
|
|
|
| 4 |
|
| 5 |
from langchain_core.documents import Document as LangChainDocument
|
|
|
|
| 6 |
|
| 7 |
from .collection_utils import collection_matches_year
|
| 8 |
from .document_db import SessionLocal, list_active_collection_names
|
|
|
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
class CollectionRouterRetriever:
|
| 14 |
def __init__(
|
| 15 |
self,
|
|
|
|
| 61 |
|
| 62 |
return active_collections[: self.top_n_collections]
|
| 63 |
|
| 64 |
+
def _search_target_collections(self, query: str, collections: List[str], limit: int) -> List:
|
| 65 |
if not collections:
|
| 66 |
return []
|
| 67 |
|
|
|
|
| 71 |
logger.exception("Failed to embed query for collection routing")
|
| 72 |
return []
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
scored_docs = []
|
| 75 |
for collection_name in collections:
|
| 76 |
try:
|
|
|
|
| 79 |
query_vector=query_vector,
|
| 80 |
limit=limit,
|
| 81 |
with_payload=True,
|
|
|
|
| 82 |
)
|
| 83 |
+
except Exception:
|
| 84 |
+
logger.exception("Qdrant search failed for collection=%s", collection_name)
|
| 85 |
continue
|
| 86 |
|
| 87 |
for point in points:
|
|
|
|
| 95 |
"source_file": payload.get("filename") or payload.get("stored_name") or "",
|
| 96 |
"source_relpath": payload.get("object_path") or payload.get("path") or "",
|
| 97 |
"object_path": payload.get("object_path") or "",
|
|
|
|
| 98 |
"folder_key": payload.get("folder_key") or "",
|
| 99 |
"collection_name": collection_name,
|
| 100 |
"academic_year": payload.get("academic_year") or "",
|
|
|
|
| 101 |
"chunk_index": payload.get("chunk_index"),
|
| 102 |
"page_number": payload.get("page_number"),
|
| 103 |
}
|
|
|
|
| 126 |
query=query,
|
| 127 |
collections=target_collections,
|
| 128 |
limit=candidate_k,
|
|
|
|
| 129 |
)
|
| 130 |
|
| 131 |
if year_scoped:
|
core/collection_utils.py
CHANGED
|
@@ -14,13 +14,10 @@ def normalize_folder_key(folder_key: str) -> str:
|
|
| 14 |
|
| 15 |
|
| 16 |
def build_collection_name(folder_key: str, prefix: str = "rag") -> str:
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
"""
|
| 22 |
-
# ✅ Force single collection: always return "rag_docs"
|
| 23 |
-
return f"{prefix}_docs"
|
| 24 |
|
| 25 |
|
| 26 |
def extract_year_tokens(value: str) -> Set[str]:
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
def build_collection_name(folder_key: str, prefix: str = "rag") -> str:
|
| 17 |
+
normalized = normalize_folder_key(folder_key)
|
| 18 |
+
base = f"{prefix}_{normalized}"
|
| 19 |
+
# Qdrant collection names should stay short and simple.
|
| 20 |
+
return base[:63]
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
def extract_year_tokens(value: str) -> Set[str]:
|
core/config.py
CHANGED
|
@@ -39,14 +39,14 @@ GEMINI_API_KEYS = os.getenv('GEMINI_API_KEYS', '').strip()
|
|
| 39 |
# Name models
|
| 40 |
LLM_MODEL = os.getenv('LLM_MODEL', 'llama-3.1-70b-versatile')
|
| 41 |
FAST_LLM_MODEL = os.getenv('FAST_LLM_MODEL', 'llama-3.1-8b-instant')
|
| 42 |
-
EMBED_MODEL = os.getenv('EMBED_MODEL', '
|
| 43 |
-
CROSS_ENCODER_MODEL = os.getenv('CROSS_ENCODER_MODEL', '
|
| 44 |
|
| 45 |
# Chunking and retrieval settings
|
| 46 |
CHUNK_SIZE = int(os.getenv('CHUNK_SIZE', '800'))
|
| 47 |
CHUNK_OVERLAP = int(os.getenv('CHUNK_OVERLAP', '150'))
|
| 48 |
-
TOP_K_RESULTS = int(os.getenv('TOP_K_RESULTS', '
|
| 49 |
-
FINAL_TOP_K = int(os.getenv('FINAL_TOP_K', '
|
| 50 |
|
| 51 |
QDRANT_COLLECTION = os.getenv('QDRANT_COLLECTION', 'rag_docs')
|
| 52 |
DOCUMENTS_DATABASE_URL = os.getenv('DOCUMENTS_DATABASE_URL', _default_documents_db_url())
|
|
|
|
| 39 |
# Name models
|
| 40 |
LLM_MODEL = os.getenv('LLM_MODEL', 'llama-3.1-70b-versatile')
|
| 41 |
FAST_LLM_MODEL = os.getenv('FAST_LLM_MODEL', 'llama-3.1-8b-instant')
|
| 42 |
+
EMBED_MODEL = os.getenv('EMBED_MODEL', 'BAAI/bge-m3')
|
| 43 |
+
CROSS_ENCODER_MODEL = os.getenv('CROSS_ENCODER_MODEL', 'BAAI/bge-reranker-v2-m3')
|
| 44 |
|
| 45 |
# Chunking and retrieval settings
|
| 46 |
CHUNK_SIZE = int(os.getenv('CHUNK_SIZE', '800'))
|
| 47 |
CHUNK_OVERLAP = int(os.getenv('CHUNK_OVERLAP', '150'))
|
| 48 |
+
TOP_K_RESULTS = int(os.getenv('TOP_K_RESULTS', '10'))
|
| 49 |
+
FINAL_TOP_K = int(os.getenv('FINAL_TOP_K', '5'))
|
| 50 |
|
| 51 |
QDRANT_COLLECTION = os.getenv('QDRANT_COLLECTION', 'rag_docs')
|
| 52 |
DOCUMENTS_DATABASE_URL = os.getenv('DOCUMENTS_DATABASE_URL', _default_documents_db_url())
|
core/document_ingest_service.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
| 1 |
-
import hashlib
|
| 2 |
import logging
|
| 3 |
import os
|
| 4 |
-
import re
|
| 5 |
import uuid
|
| 6 |
from datetime import datetime, timezone
|
| 7 |
from typing import List, Optional
|
|
@@ -20,7 +18,7 @@ from qdrant_client.models import (
|
|
| 20 |
)
|
| 21 |
|
| 22 |
from .chunking import smart_chunking
|
| 23 |
-
from .config import QDRANT_API_KEY, QDRANT_COLLECTION, QDRANT_URL
|
| 24 |
from .document_db import Document, DocumentChunk, SessionLocal
|
| 25 |
from .models import embeddings
|
| 26 |
from .text_utils import clean_text
|
|
@@ -28,36 +26,7 @@ from .vectorstore import extract_academic_year, load_documents_from_file
|
|
| 28 |
|
| 29 |
logger = logging.getLogger(__name__)
|
| 30 |
|
| 31 |
-
ACTIVE_CODE_PATTERN = re.compile(r"(20\d{2})\s*[-_/]\s*(20\d{2})")
|
| 32 |
_ALLOWED_EXTENSIONS = {".pdf", ".docx", ".txt"}
|
| 33 |
-
_ENSURED_PAYLOAD_INDEX_COLLECTIONS = set()
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def _build_supabase_file_url(object_path: str) -> str:
|
| 37 |
-
"""Tạo URL đầy đủ cho tài liệu từ Supabase Storage."""
|
| 38 |
-
if not SUPABASE_URL or not SUPABASE_STORAGE_BUCKET or not object_path:
|
| 39 |
-
return ""
|
| 40 |
-
|
| 41 |
-
clean_path = object_path.lstrip("/")
|
| 42 |
-
return f"{SUPABASE_URL}/storage/v1/object/public/{SUPABASE_STORAGE_BUCKET}/{clean_path}"
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def _extract_years_from_academic_year(academic_year: str) -> List[int]:
|
| 46 |
-
"""Trích xuất danh sách năm từ chuỗi năm học (ví dụ '2023-2024' -> [2023, 2024])."""
|
| 47 |
-
if not academic_year or academic_year == "ALL":
|
| 48 |
-
return []
|
| 49 |
-
|
| 50 |
-
years = []
|
| 51 |
-
match = ACTIVE_CODE_PATTERN.search(academic_year)
|
| 52 |
-
if match:
|
| 53 |
-
try:
|
| 54 |
-
start_year = int(match.group(1))
|
| 55 |
-
end_year = int(match.group(2))
|
| 56 |
-
years = [start_year, end_year]
|
| 57 |
-
except (ValueError, IndexError):
|
| 58 |
-
pass
|
| 59 |
-
|
| 60 |
-
return years
|
| 61 |
|
| 62 |
|
| 63 |
def _load_documents_for_ingest(path: str, extension: str) -> List[LangChainDocument]:
|
|
@@ -138,33 +107,13 @@ def _ensure_qdrant_collection(client: QdrantClient, vector_size: int, collection
|
|
| 138 |
|
| 139 |
|
| 140 |
def _ensure_payload_indexes(client: QdrantClient, collection_name: str) -> None:
|
| 141 |
-
|
| 142 |
-
return
|
| 143 |
-
|
| 144 |
-
# KEYWORD indexes cho filtering nhanh
|
| 145 |
-
for field_name in ("object_path", "document_id", "content_hash"):
|
| 146 |
-
try:
|
| 147 |
-
client.create_payload_index(
|
| 148 |
-
collection_name=collection_name,
|
| 149 |
-
field_name=field_name,
|
| 150 |
-
field_schema=PayloadSchemaType.KEYWORD,
|
| 151 |
-
wait=True,
|
| 152 |
-
)
|
| 153 |
-
except Exception as e:
|
| 154 |
-
logger.warning(f"Failed to create KEYWORD index for {field_name}: {e}")
|
| 155 |
-
|
| 156 |
-
# INTEGER array index cho years
|
| 157 |
-
try:
|
| 158 |
client.create_payload_index(
|
| 159 |
collection_name=collection_name,
|
| 160 |
-
field_name=
|
| 161 |
-
field_schema=PayloadSchemaType.
|
| 162 |
wait=True,
|
| 163 |
)
|
| 164 |
-
except Exception as e:
|
| 165 |
-
logger.warning(f"Failed to create INTEGER index for years: {e}")
|
| 166 |
-
|
| 167 |
-
_ENSURED_PAYLOAD_INDEX_COLLECTIONS.add(collection_name)
|
| 168 |
|
| 169 |
|
| 170 |
def _is_missing_payload_index_error(error: Exception) -> bool:
|
|
@@ -172,152 +121,6 @@ def _is_missing_payload_index_error(error: Exception) -> bool:
|
|
| 172 |
return "Index required but not found" in message
|
| 173 |
|
| 174 |
|
| 175 |
-
def _get_or_create_deduplicated_points(
|
| 176 |
-
client: QdrantClient,
|
| 177 |
-
collection_name: str,
|
| 178 |
-
chunk_docs: List[LangChainDocument],
|
| 179 |
-
vectors: List,
|
| 180 |
-
source_object_ref: str,
|
| 181 |
-
document: Document,
|
| 182 |
-
source_updated_at: Optional[str],
|
| 183 |
-
source_etag: Optional[str],
|
| 184 |
-
created_at: str,
|
| 185 |
-
effective_source_path: Optional[str] = None,
|
| 186 |
-
) -> tuple[List[PointStruct], List[DocumentChunk]]:
|
| 187 |
-
"""
|
| 188 |
-
Tích hợp MD5 deduplication: nếu content hash trùng, cập nhật years array thay vì tạo mới.
|
| 189 |
-
"""
|
| 190 |
-
points: List[PointStruct] = []
|
| 191 |
-
db_chunk_rows: List[DocumentChunk] = []
|
| 192 |
-
|
| 193 |
-
for index, (chunk_doc, vector) in enumerate(zip(chunk_docs, vectors)):
|
| 194 |
-
chunk_text = chunk_doc.page_content
|
| 195 |
-
metadata = chunk_doc.metadata if isinstance(chunk_doc.metadata, dict) else {}
|
| 196 |
-
|
| 197 |
-
# Tính content hash
|
| 198 |
-
content_hash = hashlib.md5(chunk_text.encode('utf-8')).hexdigest()
|
| 199 |
-
|
| 200 |
-
# Trích académie năm học
|
| 201 |
-
academic_year = metadata.get("academic_year") or "ALL"
|
| 202 |
-
years = _extract_years_from_academic_year(academic_year)
|
| 203 |
-
|
| 204 |
-
# Tạo source URL
|
| 205 |
-
source_url = _build_supabase_file_url(source_object_ref)
|
| 206 |
-
|
| 207 |
-
# Kiểm tra xem content_hash đã tồn tại
|
| 208 |
-
existing_point_id = None
|
| 209 |
-
try:
|
| 210 |
-
existing_points = client.scroll(
|
| 211 |
-
collection_name=collection_name,
|
| 212 |
-
limit=1,
|
| 213 |
-
scroll_filter=Filter(
|
| 214 |
-
must=[
|
| 215 |
-
FieldCondition(
|
| 216 |
-
key="content_hash",
|
| 217 |
-
match=MatchValue(value=content_hash),
|
| 218 |
-
)
|
| 219 |
-
]
|
| 220 |
-
),
|
| 221 |
-
)
|
| 222 |
-
|
| 223 |
-
if existing_points and existing_points[0]:
|
| 224 |
-
# Nếu tìm thấy point với hash trùng
|
| 225 |
-
existing_point_id = existing_points[0][0].id
|
| 226 |
-
logger.info(f"Tìm thấy content đã tồn tại hash={content_hash[:8]}..., sẽ cập nhật years")
|
| 227 |
-
except Exception as e:
|
| 228 |
-
logger.debug(f"Không thể tìm kiếm existing points: {e}")
|
| 229 |
-
|
| 230 |
-
if existing_point_id:
|
| 231 |
-
# Merge years array
|
| 232 |
-
try:
|
| 233 |
-
existing_payload = client.retrieve(collection_name, [existing_point_id])[0].payload
|
| 234 |
-
existing_years = set(existing_payload.get("years", []))
|
| 235 |
-
merged_years = sorted(list(set(years) | existing_years))
|
| 236 |
-
|
| 237 |
-
# Update payload với years mới
|
| 238 |
-
updated_payload = {
|
| 239 |
-
**existing_payload,
|
| 240 |
-
"years": merged_years,
|
| 241 |
-
"document_id": document.id, # Update document_id nếu tài liệu mới
|
| 242 |
-
"source_updated_at": source_updated_at or existing_payload.get("source_updated_at"),
|
| 243 |
-
}
|
| 244 |
-
|
| 245 |
-
# ✅ Dùng set_payload để cập nhật payload
|
| 246 |
-
client.set_payload(
|
| 247 |
-
collection_name=collection_name,
|
| 248 |
-
payload=updated_payload,
|
| 249 |
-
points=[existing_point_id],
|
| 250 |
-
)
|
| 251 |
-
logger.info(f"Đã cập nhật years cho hash {content_hash[:8]}...: {merged_years}")
|
| 252 |
-
# ✅ QUAN TRỌNG: Bỏ qua tạo point mới - vì đã cập nhật point đã tồn tại
|
| 253 |
-
continue
|
| 254 |
-
except Exception as e:
|
| 255 |
-
logger.warning(f"Lỗi cập nhật years cho point đã tồn tại: {e}, sẽ tạo point mới")
|
| 256 |
-
# Fallback: tạo point mới nếu cập nhật thất bại
|
| 257 |
-
pass
|
| 258 |
-
|
| 259 |
-
# Tạo point mới
|
| 260 |
-
point_id = str(uuid.uuid4())
|
| 261 |
-
payload = _build_payload(
|
| 262 |
-
document, source_object_ref, chunk_text, index, metadata,
|
| 263 |
-
academic_year, years, content_hash, source_url,
|
| 264 |
-
source_updated_at, source_etag, created_at, effective_source_path
|
| 265 |
-
)
|
| 266 |
-
points.append(PointStruct(id=point_id, vector=vector, payload=payload))
|
| 267 |
-
db_chunk_rows.append(
|
| 268 |
-
DocumentChunk(
|
| 269 |
-
document_id=document.id,
|
| 270 |
-
chunk_index=index,
|
| 271 |
-
content_preview=chunk_text[:200],
|
| 272 |
-
qdrant_point_id=point_id,
|
| 273 |
-
)
|
| 274 |
-
)
|
| 275 |
-
|
| 276 |
-
return points, db_chunk_rows
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
def _build_payload(
|
| 280 |
-
document: Document,
|
| 281 |
-
source_object_ref: str,
|
| 282 |
-
chunk_text: str,
|
| 283 |
-
index: int,
|
| 284 |
-
metadata: dict,
|
| 285 |
-
academic_year: str,
|
| 286 |
-
years: List[int],
|
| 287 |
-
content_hash: str,
|
| 288 |
-
source_url: str,
|
| 289 |
-
source_updated_at: Optional[str],
|
| 290 |
-
source_etag: Optional[str],
|
| 291 |
-
created_at: str,
|
| 292 |
-
effective_source_path: Optional[str] = None,
|
| 293 |
-
) -> dict:
|
| 294 |
-
"""Xây dựng payload dictionary cho point."""
|
| 295 |
-
source_name = os.path.basename(source_object_ref) if source_object_ref else document.stored_name
|
| 296 |
-
source_relpath = source_object_ref or source_name
|
| 297 |
-
|
| 298 |
-
return {
|
| 299 |
-
"document_id": document.id,
|
| 300 |
-
"filename": document.original_name,
|
| 301 |
-
"stored_effective_source_path or name": document.stored_name,
|
| 302 |
-
"path": document.path,
|
| 303 |
-
"object_path": source_object_ref,
|
| 304 |
-
"folder_key": document.folder_key,
|
| 305 |
-
"collection_name": document.collection_name or "",
|
| 306 |
-
"source_file": metadata.get("source_file") or source_name,
|
| 307 |
-
"source_relpath": metadata.get("source_relpath") or source_relpath,
|
| 308 |
-
"source_url": source_url,
|
| 309 |
-
"academic_year": academic_year,
|
| 310 |
-
"years": years,
|
| 311 |
-
"content_hash": content_hash,
|
| 312 |
-
"page_number": metadata.get("page_number"),
|
| 313 |
-
"source_updated_at": source_updated_at,
|
| 314 |
-
"source_etag": source_etag,
|
| 315 |
-
"chunk_index": index,
|
| 316 |
-
"created_at": created_at,
|
| 317 |
-
"content": chunk_text,
|
| 318 |
-
}
|
| 319 |
-
|
| 320 |
-
|
| 321 |
def _delete_existing_document_points(
|
| 322 |
client: QdrantClient,
|
| 323 |
collection_name: str,
|
|
@@ -357,7 +160,6 @@ def _delete_existing_document_points(
|
|
| 357 |
"Missing payload index detected while deleting old points in collection=%s. Rebuilding indexes and retrying once.",
|
| 358 |
collection_name,
|
| 359 |
)
|
| 360 |
-
_ENSURED_PAYLOAD_INDEX_COLLECTIONS.discard(collection_name)
|
| 361 |
_ensure_payload_indexes(client, collection_name)
|
| 362 |
client.delete(
|
| 363 |
collection_name=collection_name,
|
|
@@ -433,30 +235,46 @@ def process_document_ingest(
|
|
| 433 |
_delete_existing_document_points(client, target_collection, source_object_ref, document.id)
|
| 434 |
|
| 435 |
created_at = datetime.now(timezone.utc).isoformat()
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 450 |
|
| 451 |
-
|
| 452 |
-
if points:
|
| 453 |
-
client.upsert(collection_name=target_collection, points=points, wait=True)
|
| 454 |
|
| 455 |
db.query(DocumentChunk).filter(DocumentChunk.document_id == document.id).delete()
|
| 456 |
-
|
| 457 |
-
# ✅ Chỉ bulk save nếu có chunks mới
|
| 458 |
-
if db_chunk_rows:
|
| 459 |
-
db.bulk_save_objects(db_chunk_rows)
|
| 460 |
|
| 461 |
if effective_source_path:
|
| 462 |
document.path = effective_source_path
|
|
@@ -540,7 +358,6 @@ def delete_vectors_for_object_path(collection_name: str, object_path: str) -> bo
|
|
| 540 |
"Missing payload index detected while deleting object_path in collection=%s. Rebuilding indexes and retrying once.",
|
| 541 |
target_collection,
|
| 542 |
)
|
| 543 |
-
_ENSURED_PAYLOAD_INDEX_COLLECTIONS.discard(target_collection)
|
| 544 |
_ensure_payload_indexes(client, target_collection)
|
| 545 |
client.delete(
|
| 546 |
collection_name=target_collection,
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
import os
|
|
|
|
| 3 |
import uuid
|
| 4 |
from datetime import datetime, timezone
|
| 5 |
from typing import List, Optional
|
|
|
|
| 18 |
)
|
| 19 |
|
| 20 |
from .chunking import smart_chunking
|
| 21 |
+
from .config import QDRANT_API_KEY, QDRANT_COLLECTION, QDRANT_URL
|
| 22 |
from .document_db import Document, DocumentChunk, SessionLocal
|
| 23 |
from .models import embeddings
|
| 24 |
from .text_utils import clean_text
|
|
|
|
| 26 |
|
| 27 |
logger = logging.getLogger(__name__)
|
| 28 |
|
|
|
|
| 29 |
_ALLOWED_EXTENSIONS = {".pdf", ".docx", ".txt"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
def _load_documents_for_ingest(path: str, extension: str) -> List[LangChainDocument]:
|
|
|
|
| 107 |
|
| 108 |
|
| 109 |
def _ensure_payload_indexes(client: QdrantClient, collection_name: str) -> None:
|
| 110 |
+
for field_name in ("object_path", "document_id"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
client.create_payload_index(
|
| 112 |
collection_name=collection_name,
|
| 113 |
+
field_name=field_name,
|
| 114 |
+
field_schema=PayloadSchemaType.KEYWORD,
|
| 115 |
wait=True,
|
| 116 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
|
| 119 |
def _is_missing_payload_index_error(error: Exception) -> bool:
|
|
|
|
| 121 |
return "Index required but not found" in message
|
| 122 |
|
| 123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
def _delete_existing_document_points(
|
| 125 |
client: QdrantClient,
|
| 126 |
collection_name: str,
|
|
|
|
| 160 |
"Missing payload index detected while deleting old points in collection=%s. Rebuilding indexes and retrying once.",
|
| 161 |
collection_name,
|
| 162 |
)
|
|
|
|
| 163 |
_ensure_payload_indexes(client, collection_name)
|
| 164 |
client.delete(
|
| 165 |
collection_name=collection_name,
|
|
|
|
| 235 |
_delete_existing_document_points(client, target_collection, source_object_ref, document.id)
|
| 236 |
|
| 237 |
created_at = datetime.now(timezone.utc).isoformat()
|
| 238 |
+
points: List[PointStruct] = []
|
| 239 |
+
db_chunk_rows: List[DocumentChunk] = []
|
| 240 |
+
|
| 241 |
+
for index, (chunk_doc, vector) in enumerate(zip(chunk_docs, vectors)):
|
| 242 |
+
chunk_text = chunk_doc.page_content
|
| 243 |
+
metadata = chunk_doc.metadata if isinstance(chunk_doc.metadata, dict) else {}
|
| 244 |
+
point_id = str(uuid.uuid4())
|
| 245 |
+
payload = {
|
| 246 |
+
"document_id": document.id,
|
| 247 |
+
"filename": document.original_name,
|
| 248 |
+
"stored_name": document.stored_name,
|
| 249 |
+
"path": effective_source_path or document.path,
|
| 250 |
+
"object_path": source_object_ref,
|
| 251 |
+
"folder_key": document.folder_key,
|
| 252 |
+
"collection_name": target_collection,
|
| 253 |
+
"source_file": metadata.get("source_file") or source_name,
|
| 254 |
+
"source_relpath": metadata.get("source_relpath") or source_relpath,
|
| 255 |
+
"academic_year": metadata.get("academic_year") or "ALL",
|
| 256 |
+
"page_number": metadata.get("page_number"),
|
| 257 |
+
"source_updated_at": source_updated_at,
|
| 258 |
+
"source_etag": source_etag,
|
| 259 |
+
"chunk_index": index,
|
| 260 |
+
"created_at": created_at,
|
| 261 |
+
"content": chunk_text,
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
points.append(PointStruct(id=point_id, vector=vector, payload=payload))
|
| 265 |
+
db_chunk_rows.append(
|
| 266 |
+
DocumentChunk(
|
| 267 |
+
document_id=document.id,
|
| 268 |
+
chunk_index=index,
|
| 269 |
+
content_preview=chunk_text[:200],
|
| 270 |
+
qdrant_point_id=point_id,
|
| 271 |
+
)
|
| 272 |
+
)
|
| 273 |
|
| 274 |
+
client.upsert(collection_name=target_collection, points=points, wait=True)
|
|
|
|
|
|
|
| 275 |
|
| 276 |
db.query(DocumentChunk).filter(DocumentChunk.document_id == document.id).delete()
|
| 277 |
+
db.bulk_save_objects(db_chunk_rows)
|
|
|
|
|
|
|
|
|
|
| 278 |
|
| 279 |
if effective_source_path:
|
| 280 |
document.path = effective_source_path
|
|
|
|
| 358 |
"Missing payload index detected while deleting object_path in collection=%s. Rebuilding indexes and retrying once.",
|
| 359 |
target_collection,
|
| 360 |
)
|
|
|
|
| 361 |
_ensure_payload_indexes(client, target_collection)
|
| 362 |
client.delete(
|
| 363 |
collection_name=target_collection,
|
core/prompting.py
CHANGED
|
@@ -85,7 +85,7 @@ Về vấn đề [Chủ đề], theo **Điều [Số]**, các trường hợp ng
|
|
| 85 |
# Lấy ví dụ phù hợp (Fallback về simple nếu không khớp)
|
| 86 |
example = examples.get(question_type, examples['simple'])
|
| 87 |
|
| 88 |
-
# TOPIC INSTRUCTION: Rào chắn ngữ cảnh (Context Guardrail)
|
| 89 |
if topic:
|
| 90 |
topic_instr = (
|
| 91 |
f"\n\n **LƯU Ý ĐẶC BIỆT VỀ CHỦ ĐỀ MỞ RỘNG:**\n"
|
|
@@ -97,17 +97,19 @@ Về vấn đề [Chủ đề], theo **Điều [Số]**, các trường hợp ng
|
|
| 97 |
else:
|
| 98 |
topic_instr = ""
|
| 99 |
|
|
|
|
| 100 |
if year_scope:
|
| 101 |
year_instr = (
|
| 102 |
-
f"\n\n **RÀNG BUỘC NĂM HỌC (
|
| 103 |
-
f"- Người dùng đang hỏi
|
| 104 |
-
f"-
|
| 105 |
-
f"- Nếu
|
|
|
|
| 106 |
)
|
| 107 |
else:
|
| 108 |
year_instr = ""
|
| 109 |
|
| 110 |
-
# Gộp Prompt
|
| 111 |
full_prompt = f"""{base_system}
|
| 112 |
----------------
|
| 113 |
{example}
|
|
|
|
| 85 |
# Lấy ví dụ phù hợp (Fallback về simple nếu không khớp)
|
| 86 |
example = examples.get(question_type, examples['simple'])
|
| 87 |
|
| 88 |
+
# 3. TOPIC INSTRUCTION: Rào chắn ngữ cảnh (Context Guardrail)
|
| 89 |
if topic:
|
| 90 |
topic_instr = (
|
| 91 |
f"\n\n **LƯU Ý ĐẶC BIỆT VỀ CHỦ ĐỀ MỞ RỘNG:**\n"
|
|
|
|
| 97 |
else:
|
| 98 |
topic_instr = ""
|
| 99 |
|
| 100 |
+
# [YEAR-AWARE CHANGE] Rang buoc cau tra loi theo nam hoc duoc hoi.
|
| 101 |
if year_scope:
|
| 102 |
year_instr = (
|
| 103 |
+
f"\n\n **RÀNG BUỘC NĂM HỌC (BẮT BUỘC):**\n"
|
| 104 |
+
f"- Người dùng đang hỏi trong phạm vi năm: **{year_scope}**.\n"
|
| 105 |
+
f"- Ưu tiên các đoạn có nhãn nguồn cùng năm trong context (ví dụ: [Năm 2022-2023 | ...]).\n"
|
| 106 |
+
f"- Nếu chưa đủ bằng chứng đúng năm, được phép dùng đoạn có nhãn 'Áp dụng nhiều năm' hoặc quy định gần nhất và phải ghi chú rõ phạm vi áp dụng.\n"
|
| 107 |
+
f"- Không kết luận 'không có dữ liệu' chỉ vì thiếu đúng nhãn năm nếu vẫn có quy định bao quát liên quan.\n"
|
| 108 |
)
|
| 109 |
else:
|
| 110 |
year_instr = ""
|
| 111 |
|
| 112 |
+
# 4. Gộp Prompt
|
| 113 |
full_prompt = f"""{base_system}
|
| 114 |
----------------
|
| 115 |
{example}
|
core/qa_pipeline.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
-
from typing import List, Generator
|
| 2 |
import os, re, hashlib
|
| 3 |
import logging
|
| 4 |
import groq
|
| 5 |
import google.generativeai as genai
|
| 6 |
import json
|
| 7 |
-
|
| 8 |
from .models import llm
|
| 9 |
from .config import TOP_K_RESULTS, FINAL_TOP_K
|
| 10 |
from .rerank import advanced_rerank
|
|
@@ -12,7 +12,6 @@ from .prompting import create_advanced_prompt
|
|
| 12 |
from .retriever import HybridRetriever
|
| 13 |
from .analyze_and_expand import analyze_and_expand_query
|
| 14 |
from .llm_utils import safe_invoke, safe_stream
|
| 15 |
-
import concurrent.futures
|
| 16 |
|
| 17 |
logger = logging.getLogger(__name__)
|
| 18 |
|
|
@@ -23,15 +22,6 @@ MAX_OUT_CHARS = 3000
|
|
| 23 |
# [YEAR-AWARE CHANGE] Pattern nhan dien nam hoc trong cau hoi.
|
| 24 |
ACADEMIC_YEAR_PATTERN = re.compile(r"\b(20\d{2})\s*[-_/]\s*(20\d{2})\b")
|
| 25 |
SINGLE_YEAR_PATTERN = re.compile(r"\b(20\d{2})\b")
|
| 26 |
-
_SOCIAL_KEYWORDS = {
|
| 27 |
-
"hello", "hi", "xin chao", "chao", "alo", "hey", "thanks", "cam on", "tam biet", "bye"
|
| 28 |
-
}
|
| 29 |
-
_PERSONAL_NON_DOMAIN_PATTERNS = [
|
| 30 |
-
re.compile(r"\bb(ạn|an)\s+c[oó]\s+bi[eế]t\s+t[oô]i\s+l[aà]\s+ai\b", re.IGNORECASE),
|
| 31 |
-
re.compile(r"\bb(ạn|an)\s+l[aà]\s+ai\b", re.IGNORECASE),
|
| 32 |
-
re.compile(r"\bai\s+t[aạ]o\s+ra\s+b(ạn|an)\b", re.IGNORECASE),
|
| 33 |
-
re.compile(r"\b(ăn|an)\s+c[oơ]m\s+ch(ưa|ua)\b", re.IGNORECASE),
|
| 34 |
-
]
|
| 35 |
|
| 36 |
# Quản lý API Keys cho Groq và Gemini với xoay tua tự động khi gặp lỗi hoặc hết hạn
|
| 37 |
class AIProviderManager:
|
|
@@ -140,48 +130,6 @@ def sanitize_for_prompt(text: str) -> str:
|
|
| 140 |
text = re.sub(r"\b\d{8,12}\b", "[ID]", text)
|
| 141 |
return text.strip()
|
| 142 |
|
| 143 |
-
|
| 144 |
-
def remove_accents(input_str: str) -> str:
|
| 145 |
-
s1 = unicodedata.normalize('NFKD', input_str).encode('ASCII', 'ignore').decode('utf-8')
|
| 146 |
-
return s1.lower()
|
| 147 |
-
|
| 148 |
-
def _normalize_for_router(message: str) -> str:
|
| 149 |
-
compact = remove_accents(message or "")
|
| 150 |
-
compact = re.sub(r"[^\w\s]", " ", compact, flags=re.UNICODE)
|
| 151 |
-
return re.sub(r"\s+", " ", compact).strip()
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
def _quick_non_domain_reply(message: str) -> Optional[str]:
|
| 155 |
-
normalized = _normalize_for_router(message)
|
| 156 |
-
if not normalized:
|
| 157 |
-
return None
|
| 158 |
-
|
| 159 |
-
if normalized in _SOCIAL_KEYWORDS:
|
| 160 |
-
return "Chào bạn. Mình hỗ trợ tra cứu quy chế đào tạo, bạn cần hỏi nội dung nào?"
|
| 161 |
-
|
| 162 |
-
for pattern in _PERSONAL_NON_DOMAIN_PATTERNS:
|
| 163 |
-
if pattern.search(normalized):
|
| 164 |
-
return "Mình không có thông tin cá nhân của bạn. Mình chỉ hỗ trợ giải đáp về quy chế đào tạo."
|
| 165 |
-
|
| 166 |
-
return None
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
def _was_recently_prompted_for_year(history: List) -> bool:
|
| 170 |
-
if not history:
|
| 171 |
-
return False
|
| 172 |
-
|
| 173 |
-
reminder_snippet = "Vui lòng nhập kèm năm học để tra cứu nhanh hơn"
|
| 174 |
-
for item in reversed(history[-6:]):
|
| 175 |
-
if not isinstance(item, dict):
|
| 176 |
-
continue
|
| 177 |
-
if str(item.get("role") or "").strip().lower() != "assistant":
|
| 178 |
-
continue
|
| 179 |
-
content = str(item.get("content") or "")
|
| 180 |
-
if reminder_snippet in content:
|
| 181 |
-
return True
|
| 182 |
-
|
| 183 |
-
return False
|
| 184 |
-
|
| 185 |
def generate_standalone_query(message: str, history: List) -> str:
|
| 186 |
"""Tái tạo câu hỏi từ lịch sử """
|
| 187 |
if not history:
|
|
@@ -283,60 +231,24 @@ def ask_ai_improved(message: str, history: List, hybrid_retriever) -> Generator[
|
|
| 283 |
yield full_response
|
| 284 |
|
| 285 |
def ask_ai_stream_delta(message: str, history: List, hybrid_retriever) -> Generator[str, None, None]:
|
| 286 |
-
# Kiểm tra rỗng
|
| 287 |
if not message.strip():
|
| 288 |
-
yield "Bạn chưa nhập câu hỏi."
|
| 289 |
return
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
if quick_reply:
|
| 294 |
-
logger.info("Bỏ qua truy xuất tài liệu cho câu hỏi giao tiếp/ngoài phạm vi")
|
| 295 |
-
yield quick_reply
|
| 296 |
return
|
| 297 |
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
try:
|
| 302 |
-
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
| 303 |
-
# Call 1: Tạo standalone question từ history
|
| 304 |
-
future_standalone = executor.submit(
|
| 305 |
-
generate_standalone_query,
|
| 306 |
-
message,
|
| 307 |
-
history
|
| 308 |
-
)
|
| 309 |
-
|
| 310 |
-
# Call 2: Phân loại & mở rộng (song parallel)
|
| 311 |
-
# Dùng message gốc luôn, LLM sẽ handle context từ message
|
| 312 |
-
future_classify = executor.submit(
|
| 313 |
-
analyze_and_expand_query,
|
| 314 |
-
message # ✅ Dùng message gốc, không chờ standalone xong
|
| 315 |
-
)
|
| 316 |
-
|
| 317 |
-
# Chờ cả 2 xong (timeout 15s)
|
| 318 |
-
question = future_standalone.result(timeout=15)
|
| 319 |
-
processed_data = future_classify.result(timeout=15)
|
| 320 |
-
|
| 321 |
-
except concurrent.futures.TimeoutError:
|
| 322 |
-
logger.warning("Timeout khi gọi LLM song parallel, fallback...")
|
| 323 |
-
question = message
|
| 324 |
-
processed_data = {
|
| 325 |
-
"question_type": "simple",
|
| 326 |
-
"answer": None,
|
| 327 |
-
"expanded_queries": [message]
|
| 328 |
-
}
|
| 329 |
-
except Exception as e:
|
| 330 |
-
logger.warning(f"Lỗi parallel execution: {e}, fallback...")
|
| 331 |
-
question = message
|
| 332 |
-
processed_data = {
|
| 333 |
-
"question_type": "simple",
|
| 334 |
-
"answer": None,
|
| 335 |
-
"expanded_queries": [message]
|
| 336 |
-
}
|
| 337 |
-
|
| 338 |
requested_year_range, mentioned_years = detect_requested_year(f"{message}\n{question}")
|
| 339 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
|
| 341 |
if processed_data.get("question_type") == "normal":
|
| 342 |
ans = processed_data.get("answer") or "Chào bạn 👋 Mình hỗ trợ tra cứu quy chế đào tạo."
|
|
@@ -347,57 +259,57 @@ def ask_ai_stream_delta(message: str, history: List, hybrid_retriever) -> Genera
|
|
| 347 |
queries = processed_data['expanded_queries']
|
| 348 |
logger.info(f"Các truy vấn tìm kiếm: {queries}")
|
| 349 |
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
for
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
docs_temp.append(doc)
|
| 368 |
-
seen_temp.add(content_hash)
|
| 369 |
-
return docs_temp
|
| 370 |
-
# Tìm tài liệu
|
| 371 |
-
# Cố gắng tìm tài liệu khớp chính xác với năm học người dùng nhắc đến
|
| 372 |
-
all_docs = fetch_docs(year_scope_hint)
|
| 373 |
-
|
| 374 |
-
# Nếu lớp 1 tìm không ra hoặc người dùng hoàn toàn không nhập năm, hệ thống sẽ tự động hạ chuẩn, tìm trên toàn bộ cơ sở dữ liệu chung (ALL)
|
| 375 |
-
if not all_docs and year_scope_hint:
|
| 376 |
-
logger.info(f"Bộ lọc năm '{year_scope_hint}' quá gắt không ra kết quả. Tự động Fallback tìm trên bản chung...")
|
| 377 |
-
year_scope_hint = None # Reset lại biến hint để quét toàn bộ VectorDB
|
| 378 |
-
all_docs = fetch_docs(None)
|
| 379 |
|
| 380 |
logger.info(f"Tìm thấy tổng {len(all_docs)} documents.")
|
| 381 |
-
|
| 382 |
-
# Xử lý lịch sự nếu Vector DB thực sự "bó tay"
|
| 383 |
if not all_docs:
|
| 384 |
-
yield
|
| 385 |
return
|
| 386 |
|
| 387 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
final_docs = advanced_rerank(question, all_docs, top_k=FINAL_TOP_K)
|
| 389 |
|
| 390 |
-
# Gắn nhãn năm học vào Context cho LLM đọc
|
| 391 |
context_parts = []
|
| 392 |
total_chars = 0
|
| 393 |
for doc in final_docs:
|
| 394 |
page = doc.metadata.get('page_number', 'N/A')
|
| 395 |
file_name = doc.metadata.get('source_file') or doc.metadata.get('source')
|
| 396 |
-
|
| 397 |
doc_year = infer_doc_academic_year(doc)
|
| 398 |
year_label = f"Năm {doc_year}" if doc_year != "ALL" else "Áp dụng nhiều năm"
|
| 399 |
source = f"[{year_label} | {os.path.basename(file_name)} | Trang {page}]" if file_name else f"[{year_label} | Trang {page}]"
|
| 400 |
-
|
| 401 |
block = f"{source}\n{doc.page_content}"
|
| 402 |
if total_chars + len(block) > MAX_CONTEXT_CHARS:
|
| 403 |
break
|
|
@@ -407,14 +319,12 @@ def ask_ai_stream_delta(message: str, history: List, hybrid_retriever) -> Genera
|
|
| 407 |
context = "\n\n---\n\n".join(context_parts)
|
| 408 |
topic_hint = processed_data.get('topic') or processed_data.get('root_question') or question
|
| 409 |
|
| 410 |
-
|
| 411 |
-
prompt = create_advanced_prompt(question, context, question_type, topic_hint, year_scope=year_scope_hint)
|
| 412 |
|
| 413 |
logger.info("Đang tạo câu trả lời cuối cùng ...")
|
| 414 |
|
| 415 |
success = False
|
| 416 |
-
|
| 417 |
-
# Streaming qua Groq (Có xoay tua khi gặp lỗi 429)
|
| 418 |
for _ in range(len(api_manager.groq_keys)):
|
| 419 |
try:
|
| 420 |
client = api_manager.get_groq_client()
|
|
@@ -436,7 +346,7 @@ def ask_ai_stream_delta(message: str, history: List, hybrid_retriever) -> Genera
|
|
| 436 |
logger.error(f"Lỗi Groq: {e}")
|
| 437 |
break
|
| 438 |
|
| 439 |
-
#
|
| 440 |
if not success:
|
| 441 |
logger.warning("Chuyển sang Gemini ...")
|
| 442 |
for _ in range(max(1, len(api_manager.gemini_keys))):
|
|
@@ -453,6 +363,5 @@ def ask_ai_stream_delta(message: str, history: List, hybrid_retriever) -> Genera
|
|
| 453 |
api_manager.rotate_gemini()
|
| 454 |
logger.error(f"Lỗi Gemini: {e}")
|
| 455 |
|
| 456 |
-
# Báo lỗi khi cả 2 API đều sập
|
| 457 |
if not success:
|
| 458 |
-
yield "Đã xảy ra lỗi hệ thống hoặc quá tải
|
|
|
|
| 1 |
+
from typing import List, Generator
|
| 2 |
import os, re, hashlib
|
| 3 |
import logging
|
| 4 |
import groq
|
| 5 |
import google.generativeai as genai
|
| 6 |
import json
|
| 7 |
+
|
| 8 |
from .models import llm
|
| 9 |
from .config import TOP_K_RESULTS, FINAL_TOP_K
|
| 10 |
from .rerank import advanced_rerank
|
|
|
|
| 12 |
from .retriever import HybridRetriever
|
| 13 |
from .analyze_and_expand import analyze_and_expand_query
|
| 14 |
from .llm_utils import safe_invoke, safe_stream
|
|
|
|
| 15 |
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
|
|
|
| 22 |
# [YEAR-AWARE CHANGE] Pattern nhan dien nam hoc trong cau hoi.
|
| 23 |
ACADEMIC_YEAR_PATTERN = re.compile(r"\b(20\d{2})\s*[-_/]\s*(20\d{2})\b")
|
| 24 |
SINGLE_YEAR_PATTERN = re.compile(r"\b(20\d{2})\b")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
# Quản lý API Keys cho Groq và Gemini với xoay tua tự động khi gặp lỗi hoặc hết hạn
|
| 27 |
class AIProviderManager:
|
|
|
|
| 130 |
text = re.sub(r"\b\d{8,12}\b", "[ID]", text)
|
| 131 |
return text.strip()
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
def generate_standalone_query(message: str, history: List) -> str:
|
| 134 |
"""Tái tạo câu hỏi từ lịch sử """
|
| 135 |
if not history:
|
|
|
|
| 231 |
yield full_response
|
| 232 |
|
| 233 |
def ask_ai_stream_delta(message: str, history: List, hybrid_retriever) -> Generator[str, None, None]:
|
|
|
|
| 234 |
if not message.strip():
|
| 235 |
+
yield " Bạn chưa nhập câu hỏi."
|
| 236 |
return
|
| 237 |
|
| 238 |
+
if message.strip().lower() in {"hello", "hi", "xin chào", "chào"}:
|
| 239 |
+
yield "Chào bạn 👋 Mình hỗ trợ tra cứu quy chế đào tạo. Bạn cần hỏi điều gì?"
|
|
|
|
|
|
|
|
|
|
| 240 |
return
|
| 241 |
|
| 242 |
+
logger.info(f" CÂU HỎI GỐC: {message}")
|
| 243 |
+
question = generate_standalone_query(message, history)
|
| 244 |
+
# [YEAR-AWARE CHANGE] Xac dinh pham vi nam ma nguoi dung yeu cau.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
requested_year_range, mentioned_years = detect_requested_year(f"{message}\n{question}")
|
| 246 |
+
if requested_year_range:
|
| 247 |
+
logger.info(f"Lọc theo năm học yêu cầu: {requested_year_range}")
|
| 248 |
+
elif mentioned_years:
|
| 249 |
+
logger.info(f"Lọc theo năm được nhắc tới: {sorted(mentioned_years)}")
|
| 250 |
+
|
| 251 |
+
processed_data = analyze_and_expand_query(question)
|
| 252 |
|
| 253 |
if processed_data.get("question_type") == "normal":
|
| 254 |
ans = processed_data.get("answer") or "Chào bạn 👋 Mình hỗ trợ tra cứu quy chế đào tạo."
|
|
|
|
| 259 |
queries = processed_data['expanded_queries']
|
| 260 |
logger.info(f"Các truy vấn tìm kiếm: {queries}")
|
| 261 |
|
| 262 |
+
all_docs: List = []
|
| 263 |
+
seen = set()
|
| 264 |
+
year_scope_hint = requested_year_range or (", ".join(sorted(mentioned_years)) if mentioned_years else None)
|
| 265 |
+
for query in queries:
|
| 266 |
+
#Giữ nguyên logic alpha ngành CNTT của Minh
|
| 267 |
+
current_alpha = 0.4 if "CNTT" in query.upper() else 0.5
|
| 268 |
+
docs = hybrid_retriever.search(
|
| 269 |
+
query,
|
| 270 |
+
k=TOP_K_RESULTS,
|
| 271 |
+
alpha=current_alpha,
|
| 272 |
+
year_scope=year_scope_hint,
|
| 273 |
+
)
|
| 274 |
+
for doc in docs:
|
| 275 |
+
content_hash = hashlib.sha256(doc.page_content.encode("utf-8")).hexdigest()
|
| 276 |
+
if content_hash not in seen:
|
| 277 |
+
all_docs.append(doc)
|
| 278 |
+
seen.add(content_hash)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
|
| 280 |
logger.info(f"Tìm thấy tổng {len(all_docs)} documents.")
|
|
|
|
|
|
|
| 281 |
if not all_docs:
|
| 282 |
+
yield "Không tìm thấy thông tin liên quan trong tài liệu."
|
| 283 |
return
|
| 284 |
|
| 285 |
+
# [YEAR-AWARE CHANGE] Lọc theo năm nhưng vẫn fallback nếu không có tài liệu đúng năm.
|
| 286 |
+
year_scope = None
|
| 287 |
+
year_filter_requested = bool(requested_year_range or mentioned_years)
|
| 288 |
+
year_filtered_docs = filter_docs_by_year(all_docs, requested_year_range, mentioned_years)
|
| 289 |
+
|
| 290 |
+
if year_filter_requested:
|
| 291 |
+
if year_filtered_docs:
|
| 292 |
+
if len(year_filtered_docs) != len(all_docs):
|
| 293 |
+
logger.info(f"Đã lọc theo năm: còn {len(year_filtered_docs)}/{len(all_docs)} documents")
|
| 294 |
+
all_docs = year_filtered_docs
|
| 295 |
+
if requested_year_range:
|
| 296 |
+
year_scope = requested_year_range
|
| 297 |
+
elif mentioned_years:
|
| 298 |
+
year_scope = ", ".join(sorted(mentioned_years))
|
| 299 |
+
else:
|
| 300 |
+
logger.warning("Không tìm thấy tài liệu đúng năm yêu cầu, fallback sang tập tài liệu tổng quát")
|
| 301 |
+
|
| 302 |
final_docs = advanced_rerank(question, all_docs, top_k=FINAL_TOP_K)
|
| 303 |
|
|
|
|
| 304 |
context_parts = []
|
| 305 |
total_chars = 0
|
| 306 |
for doc in final_docs:
|
| 307 |
page = doc.metadata.get('page_number', 'N/A')
|
| 308 |
file_name = doc.metadata.get('source_file') or doc.metadata.get('source')
|
| 309 |
+
# [YEAR-AWARE CHANGE] Gan nhan nam trong context de LLM bam dung nguon.
|
| 310 |
doc_year = infer_doc_academic_year(doc)
|
| 311 |
year_label = f"Năm {doc_year}" if doc_year != "ALL" else "Áp dụng nhiều năm"
|
| 312 |
source = f"[{year_label} | {os.path.basename(file_name)} | Trang {page}]" if file_name else f"[{year_label} | Trang {page}]"
|
|
|
|
| 313 |
block = f"{source}\n{doc.page_content}"
|
| 314 |
if total_chars + len(block) > MAX_CONTEXT_CHARS:
|
| 315 |
break
|
|
|
|
| 319 |
context = "\n\n---\n\n".join(context_parts)
|
| 320 |
topic_hint = processed_data.get('topic') or processed_data.get('root_question') or question
|
| 321 |
|
| 322 |
+
prompt = create_advanced_prompt(question, context, question_type, topic_hint, year_scope=year_scope)
|
|
|
|
| 323 |
|
| 324 |
logger.info("Đang tạo câu trả lời cuối cùng ...")
|
| 325 |
|
| 326 |
success = False
|
| 327 |
+
# Thử với Groq
|
|
|
|
| 328 |
for _ in range(len(api_manager.groq_keys)):
|
| 329 |
try:
|
| 330 |
client = api_manager.get_groq_client()
|
|
|
|
| 346 |
logger.error(f"Lỗi Groq: {e}")
|
| 347 |
break
|
| 348 |
|
| 349 |
+
# Dự phòng sang Gemini (nếu Groq lỗi hoặc hết key)
|
| 350 |
if not success:
|
| 351 |
logger.warning("Chuyển sang Gemini ...")
|
| 352 |
for _ in range(max(1, len(api_manager.gemini_keys))):
|
|
|
|
| 363 |
api_manager.rotate_gemini()
|
| 364 |
logger.error(f"Lỗi Gemini: {e}")
|
| 365 |
|
|
|
|
| 366 |
if not success:
|
| 367 |
+
yield "Đã xảy ra lỗi hệ thống hoặc quá tải. Vui lòng thử lại sau giây lát!"
|
core/rerank.py
CHANGED
|
@@ -1,22 +1,15 @@
|
|
| 1 |
from typing import List
|
| 2 |
-
import logging
|
| 3 |
from .models import cross_encoder
|
| 4 |
|
| 5 |
-
MAX_RERANK_CHARS =
|
| 6 |
-
logger = logging.getLogger(__name__)
|
| 7 |
|
| 8 |
def advanced_rerank(question: str, docs: List, top_k: int = 5) -> List:
|
| 9 |
if not docs:
|
| 10 |
return []
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
scores = cross_encoder.predict(pairs, show_progress_bar=False)
|
| 18 |
-
ranked = sorted(zip(scores, pruned_docs), key=lambda x: x[0], reverse=True)
|
| 19 |
-
|
| 20 |
-
logger.info("Top 3 điểm: %s", [f"{s:.3f}" for s, _ in ranked[:3]])
|
| 21 |
return [doc for score, doc in ranked[:top_k]]
|
| 22 |
|
|
|
|
| 1 |
from typing import List
|
|
|
|
| 2 |
from .models import cross_encoder
|
| 3 |
|
| 4 |
+
MAX_RERANK_CHARS = 1200
|
|
|
|
| 5 |
|
| 6 |
def advanced_rerank(question: str, docs: List, top_k: int = 5) -> List:
|
| 7 |
if not docs:
|
| 8 |
return []
|
| 9 |
+
print(f"Đang rerank {len(docs)} documents với Cross-Encoder...")
|
| 10 |
+
pairs = [(question, (doc.page_content or "")[:MAX_RERANK_CHARS]) for doc in docs]
|
| 11 |
+
scores = cross_encoder.predict(pairs)
|
| 12 |
+
ranked = sorted(zip(scores, docs), key=lambda x: x[0], reverse=True)
|
| 13 |
+
print(f" Top 3 scores: {[f'{s:.3f}' for s, _ in ranked[:3]]}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
return [doc for score, doc in ranked[:top_k]]
|
| 15 |
|
core/retriever.py
CHANGED
|
@@ -13,57 +13,6 @@ class HybridRetriever:
|
|
| 13 |
self.rrf_c = 60
|
| 14 |
print(" BM25 sẵn sàng!")
|
| 15 |
|
| 16 |
-
@staticmethod
|
| 17 |
-
def _filter_by_year_scope(documents: List, year_scope: str | None) -> List:
|
| 18 |
-
"""Filter documents theo year_scope (ví dụ: '2023-2024' hoặc '2023')."""
|
| 19 |
-
if not year_scope:
|
| 20 |
-
return documents
|
| 21 |
-
|
| 22 |
-
filtered = []
|
| 23 |
-
year_targets = set()
|
| 24 |
-
|
| 25 |
-
# Parse year_scope: có thể là "2023-2024" hoặc "2023"
|
| 26 |
-
if "-" in year_scope:
|
| 27 |
-
parts = year_scope.split("-")
|
| 28 |
-
try:
|
| 29 |
-
year_targets = {int(p.strip()) for p in parts if p.strip()}
|
| 30 |
-
except ValueError:
|
| 31 |
-
return documents
|
| 32 |
-
else:
|
| 33 |
-
try:
|
| 34 |
-
year_targets = {int(year_scope.strip())}
|
| 35 |
-
except ValueError:
|
| 36 |
-
return documents
|
| 37 |
-
|
| 38 |
-
for doc in documents:
|
| 39 |
-
metadata = doc.metadata if isinstance(doc.metadata, dict) else {}
|
| 40 |
-
|
| 41 |
-
# Check years array (mới)
|
| 42 |
-
doc_years = metadata.get("years", [])
|
| 43 |
-
if isinstance(doc_years, list) and any(y in year_targets for y in doc_years):
|
| 44 |
-
filtered.append(doc)
|
| 45 |
-
continue
|
| 46 |
-
|
| 47 |
-
# Check academic_year string (cũ, để backwards compatibility)
|
| 48 |
-
academic_year = metadata.get("academic_year", "")
|
| 49 |
-
if academic_year and academic_year != "ALL":
|
| 50 |
-
doc_year_tokens = set()
|
| 51 |
-
for potential_year in academic_year.split("-"):
|
| 52 |
-
try:
|
| 53 |
-
doc_year_tokens.add(int(potential_year.strip()))
|
| 54 |
-
except ValueError:
|
| 55 |
-
pass
|
| 56 |
-
|
| 57 |
-
if doc_year_tokens.intersection(year_targets):
|
| 58 |
-
filtered.append(doc)
|
| 59 |
-
continue
|
| 60 |
-
|
| 61 |
-
# Include ALL documents không có year info
|
| 62 |
-
if not doc_years and academic_year == "ALL":
|
| 63 |
-
filtered.append(doc)
|
| 64 |
-
|
| 65 |
-
return filtered if filtered else documents
|
| 66 |
-
|
| 67 |
@staticmethod
|
| 68 |
def _doc_key(doc) -> str:
|
| 69 |
metadata = doc.metadata if isinstance(doc.metadata, dict) else {}
|
|
@@ -74,6 +23,7 @@ class HybridRetriever:
|
|
| 74 |
return f"{source}|{page}|{digest}"
|
| 75 |
|
| 76 |
def search(self, query: str, k: int = 10, alpha: float = 0.6, year_scope: str | None = None) -> List:
|
|
|
|
| 77 |
if not self.documents or k <= 0:
|
| 78 |
return []
|
| 79 |
|
|
@@ -84,15 +34,7 @@ class HybridRetriever:
|
|
| 84 |
# Lấy top k từ BM25
|
| 85 |
tokenized_query = query.lower().split()
|
| 86 |
candidate_k = min(max(k * 4, k), len(self.documents))
|
| 87 |
-
|
| 88 |
-
# Filter documents theo year_scope nếu có
|
| 89 |
-
docs_to_search = self.documents
|
| 90 |
-
if year_scope:
|
| 91 |
-
docs_to_search = self._filter_by_year_scope(self.documents, year_scope)
|
| 92 |
-
if not docs_to_search:
|
| 93 |
-
docs_to_search = self.documents # Fallback nếu không có doc match year
|
| 94 |
-
|
| 95 |
-
bm25_top_docs = self.bm25.get_top_n(tokenized_query, docs_to_search, n=candidate_k)
|
| 96 |
|
| 97 |
bm25_ranked = {}
|
| 98 |
all_retrieved = {}
|
|
|
|
| 13 |
self.rrf_c = 60
|
| 14 |
print(" BM25 sẵn sàng!")
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
@staticmethod
|
| 17 |
def _doc_key(doc) -> str:
|
| 18 |
metadata = doc.metadata if isinstance(doc.metadata, dict) else {}
|
|
|
|
| 23 |
return f"{source}|{page}|{digest}"
|
| 24 |
|
| 25 |
def search(self, query: str, k: int = 10, alpha: float = 0.6, year_scope: str | None = None) -> List:
|
| 26 |
+
del year_scope
|
| 27 |
if not self.documents or k <= 0:
|
| 28 |
return []
|
| 29 |
|
|
|
|
| 34 |
# Lấy top k từ BM25
|
| 35 |
tokenized_query = query.lower().split()
|
| 36 |
candidate_k = min(max(k * 4, k), len(self.documents))
|
| 37 |
+
bm25_top_docs = self.bm25.get_top_n(tokenized_query, self.documents, n=candidate_k)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
bm25_ranked = {}
|
| 40 |
all_retrieved = {}
|
core/text_utils.py
CHANGED
|
@@ -1,34 +1,24 @@
|
|
| 1 |
import re
|
| 2 |
|
| 3 |
-
#Compile regex patterns một lần toàn cục - tránh recompile mỗi lần gọi
|
| 4 |
-
_HYPHENATED_WORD_PATTERN = re.compile(r'(\w+)-\s*\n\s*(\w+)')
|
| 5 |
-
_INVALID_CHARS_PATTERN = re.compile(r'[^\w\s\.,;:!?\-$$\"\'\À-ỹ\n\|<>]')
|
| 6 |
-
_MULTIPLE_SPACES_PATTERN = re.compile(r'[ \t]+')
|
| 7 |
-
_SPACE_BEFORE_NEWLINE_PATTERN = re.compile(r' +\n')
|
| 8 |
-
_SPACE_AFTER_NEWLINE_PATTERN = re.compile(r'\n +')
|
| 9 |
-
_MULTIPLE_NEWLINES_PATTERN = re.compile(r'\n{3,}')
|
| 10 |
-
_SPACE_BEFORE_PUNCTUATION_PATTERN = re.compile(r'\s+([.,;:!?])')
|
| 11 |
-
|
| 12 |
-
|
| 13 |
def clean_text(text: str) -> str:
|
| 14 |
if not text or not text.strip():
|
| 15 |
return ""
|
| 16 |
|
| 17 |
# Nối các từ bị gãy ngang do xuống dòng
|
| 18 |
-
text =
|
| 19 |
|
| 20 |
# \| và < > vào để bảo vệ khung Bảng Markdown và các Placeholder
|
| 21 |
-
text =
|
| 22 |
|
| 23 |
# Chuẩn hóa khoảng trắng
|
| 24 |
-
text =
|
| 25 |
-
text =
|
| 26 |
-
text =
|
| 27 |
|
| 28 |
# Giới hạn tối đa 2 dòng trống liên tiếp
|
| 29 |
-
text =
|
| 30 |
|
| 31 |
# Sửa lỗi dư khoảng trắng trước dấu câu
|
| 32 |
-
text =
|
| 33 |
|
| 34 |
return text.strip()
|
|
|
|
| 1 |
import re
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
def clean_text(text: str) -> str:
|
| 4 |
if not text or not text.strip():
|
| 5 |
return ""
|
| 6 |
|
| 7 |
# Nối các từ bị gãy ngang do xuống dòng
|
| 8 |
+
text = re.sub(r'(\w+)-\s*\n\s*(\w+)', r'\1\2', text)
|
| 9 |
|
| 10 |
# \| và < > vào để bảo vệ khung Bảng Markdown và các Placeholder
|
| 11 |
+
text = re.sub(r'[^\w\s\.,;:!?\-$$\"\'\À-ỹ\n\|<>]', ' ', text)
|
| 12 |
|
| 13 |
# Chuẩn hóa khoảng trắng
|
| 14 |
+
text = re.sub(r'[ \t]+', ' ', text)
|
| 15 |
+
text = re.sub(r' +\n', '\n', text)
|
| 16 |
+
text = re.sub(r'\n +', '\n', text)
|
| 17 |
|
| 18 |
# Giới hạn tối đa 2 dòng trống liên tiếp
|
| 19 |
+
text = re.sub(r'\n{3,}', '\n\n', text)
|
| 20 |
|
| 21 |
# Sửa lỗi dư khoảng trắng trước dấu câu
|
| 22 |
+
text = re.sub(r'\s+([.,;:!?])', r'\1', text)
|
| 23 |
|
| 24 |
return text.strip()
|
main.py
CHANGED
|
@@ -37,10 +37,6 @@ from api.admin_sync_router import router as admin_sync_router
|
|
| 37 |
# Hàm log lỗi an toàn
|
| 38 |
logging.basicConfig(level=logging.INFO)
|
| 39 |
logger = logging.getLogger(__name__)
|
| 40 |
-
logging.getLogger("httpx").setLevel(logging.WARNING)
|
| 41 |
-
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
| 42 |
-
logging.getLogger("qdrant_client").setLevel(logging.WARNING)
|
| 43 |
-
logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
|
| 44 |
MAX_HISTORY_MESSAGES = int(os.getenv("MAX_HISTORY_MESSAGES", "20"))
|
| 45 |
POOL_MIN_SIZE = int(os.getenv("DB_POOL_MIN_SIZE", "1"))
|
| 46 |
POOL_MAX_SIZE = int(os.getenv("DB_POOL_MAX_SIZE", "10"))
|
|
|
|
| 37 |
# Hàm log lỗi an toàn
|
| 38 |
logging.basicConfig(level=logging.INFO)
|
| 39 |
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
MAX_HISTORY_MESSAGES = int(os.getenv("MAX_HISTORY_MESSAGES", "20"))
|
| 41 |
POOL_MIN_SIZE = int(os.getenv("DB_POOL_MIN_SIZE", "1"))
|
| 42 |
POOL_MAX_SIZE = int(os.getenv("DB_POOL_MAX_SIZE", "10"))
|