Spaces:
Sleeping
Sleeping
name khoa : supabase -> qdrant
Browse files- core/collection_router_retriever.py +11 -11
- core/collection_utils.py +31 -0
- core/config.py +0 -7
- core/qa_pipeline.py +9 -38
- main.py +6 -17
core/collection_router_retriever.py
CHANGED
|
@@ -4,7 +4,7 @@ from typing import List
|
|
| 4 |
|
| 5 |
from langchain_core.documents import Document as LangChainDocument
|
| 6 |
|
| 7 |
-
from .collection_utils import
|
| 8 |
from .document_db import SessionLocal, list_active_collection_names
|
| 9 |
|
| 10 |
logger = logging.getLogger(__name__)
|
|
@@ -45,18 +45,18 @@ class CollectionRouterRetriever:
|
|
| 45 |
finally:
|
| 46 |
db.close()
|
| 47 |
|
| 48 |
-
def _select_target_collections(self,
|
| 49 |
fetch_limit = max(self.top_n_collections * 4, 12)
|
| 50 |
active_collections = self._get_active_collections(limit=fetch_limit)
|
| 51 |
if not active_collections:
|
| 52 |
return []
|
| 53 |
|
| 54 |
-
|
| 55 |
-
if
|
| 56 |
return [
|
| 57 |
collection_name
|
| 58 |
for collection_name in active_collections
|
| 59 |
-
if
|
| 60 |
]
|
| 61 |
|
| 62 |
return active_collections[: self.top_n_collections]
|
|
@@ -111,15 +111,15 @@ class CollectionRouterRetriever:
|
|
| 111 |
scored_docs.sort(key=lambda row: row[0], reverse=True)
|
| 112 |
return [doc for _, doc in scored_docs]
|
| 113 |
|
| 114 |
-
def search(self, query: str, k: int = 10, alpha: float = 0.6,
|
| 115 |
if k <= 0:
|
| 116 |
return []
|
| 117 |
|
| 118 |
candidate_k = max(k * 4, k)
|
| 119 |
-
|
| 120 |
-
target_collections = self._select_target_collections(
|
| 121 |
|
| 122 |
-
if
|
| 123 |
return []
|
| 124 |
|
| 125 |
routed_docs = self._search_target_collections(
|
|
@@ -128,7 +128,7 @@ class CollectionRouterRetriever:
|
|
| 128 |
limit=candidate_k,
|
| 129 |
)
|
| 130 |
|
| 131 |
-
if
|
| 132 |
deduplicated = []
|
| 133 |
seen = set()
|
| 134 |
for doc in routed_docs:
|
|
@@ -148,7 +148,7 @@ class CollectionRouterRetriever:
|
|
| 148 |
query,
|
| 149 |
k=candidate_k,
|
| 150 |
alpha=alpha,
|
| 151 |
-
|
| 152 |
)
|
| 153 |
except TypeError:
|
| 154 |
fallback_docs = self.base_retriever.search(
|
|
|
|
| 4 |
|
| 5 |
from langchain_core.documents import Document as LangChainDocument
|
| 6 |
|
| 7 |
+
from .collection_utils import collection_matches_cohort
|
| 8 |
from .document_db import SessionLocal, list_active_collection_names
|
| 9 |
|
| 10 |
logger = logging.getLogger(__name__)
|
|
|
|
| 45 |
finally:
|
| 46 |
db.close()
|
| 47 |
|
| 48 |
+
def _select_target_collections(self, cohort_key: str | None) -> List[str]:
    """
    Choose which active collections a search should be routed to.

    Args:
        cohort_key: Optional cohort identifier. None, empty and
            whitespace-only values all mean "no cohort scoping".

    Returns:
        An empty list when no active collections are available. With a cohort,
        every fetched collection accepted by ``collection_matches_cohort``
        (possibly none). Without a cohort, the first ``top_n_collections``
        fetched collections.
    """
    # Fetch more candidates than we finally keep: 4x the target, never < 12.
    window = max(self.top_n_collections * 4, 12)
    candidates = self._get_active_collections(limit=window)
    if not candidates:
        return []

    cohort = (cohort_key or "").strip()
    if not cohort:
        # Unscoped request: keep only the leading top-N candidates.
        return candidates[: self.top_n_collections]

    # Cohort-scoped request: keep every match; the top-N cap is not applied.
    matching = []
    for name in candidates:
        if collection_matches_cohort(name, cohort):
            matching.append(name)
    return matching
|
|
|
|
| 111 |
scored_docs.sort(key=lambda row: row[0], reverse=True)
|
| 112 |
return [doc for _, doc in scored_docs]
|
| 113 |
|
| 114 |
+
def search(self, query: str, k: int = 10, alpha: float = 0.6, cohort_key: str | None = None) -> List:
|
| 115 |
if k <= 0:
|
| 116 |
return []
|
| 117 |
|
| 118 |
candidate_k = max(k * 4, k)
|
| 119 |
+
cohort_scoped = bool((cohort_key or "").strip())
|
| 120 |
+
target_collections = self._select_target_collections(cohort_key)
|
| 121 |
|
| 122 |
+
if cohort_scoped and not target_collections:
|
| 123 |
return []
|
| 124 |
|
| 125 |
routed_docs = self._search_target_collections(
|
|
|
|
| 128 |
limit=candidate_k,
|
| 129 |
)
|
| 130 |
|
| 131 |
+
if cohort_scoped:
|
| 132 |
deduplicated = []
|
| 133 |
seen = set()
|
| 134 |
for doc in routed_docs:
|
|
|
|
| 148 |
query,
|
| 149 |
k=candidate_k,
|
| 150 |
alpha=alpha,
|
| 151 |
+
cohort_key=cohort_key,
|
| 152 |
)
|
| 153 |
except TypeError:
|
| 154 |
fallback_docs = self.base_retriever.search(
|
core/collection_utils.py
CHANGED
|
@@ -24,6 +24,37 @@ def extract_year_tokens(value: str) -> Set[str]:
|
|
| 24 |
return {token for token in _YEAR_PATTERN.findall(value or "")}
|
| 25 |
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
def collection_matches_year(collection_name: str, year_scope: str) -> bool:
|
| 28 |
if not year_scope:
|
| 29 |
return False
|
|
|
|
| 24 |
return {token for token in _YEAR_PATTERN.findall(value or "")}
|
| 25 |
|
| 26 |
|
| 27 |
+
def extract_folder_key_from_collection_name(collection_name: str, prefix: str = "rag") -> str | None:
|
| 28 |
+
"""
|
| 29 |
+
Extract folder_key from collection name.
|
| 30 |
+
E.g., 'rag_k63' -> 'k63', 'rag_2023_2024' -> '2023_2024'
|
| 31 |
+
Returns None if collection_name doesn't match the expected pattern.
|
| 32 |
+
"""
|
| 33 |
+
if not collection_name:
|
| 34 |
+
return None
|
| 35 |
+
|
| 36 |
+
prefix_with_underscore = f"{prefix}_"
|
| 37 |
+
if collection_name.startswith(prefix_with_underscore):
|
| 38 |
+
return collection_name[len(prefix_with_underscore):]
|
| 39 |
+
|
| 40 |
+
return None
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def collection_matches_cohort(collection_name: str, cohort_key: str, prefix: str = "rag") -> bool:
    """
    Check if collection matches the given cohort_key.

    E.g., collection='rag_k63', cohort_key='k63' -> True

    The comparison ignores letter case and any whitespace surrounding
    ``cohort_key``.

    Args:
        collection_name: Full collection name, e.g. 'rag_k63'.
        cohort_key: Cohort identifier to look for, e.g. 'k63'.
        prefix: Collection namespace forwarded to
            ``extract_folder_key_from_collection_name``.

    Returns:
        True when the folder key extracted from ``collection_name`` equals the
        trimmed ``cohort_key`` case-insensitively. False when the cohort key
        is empty/blank or no folder key can be extracted.
    """
    # Trim first so a key such as "k63 " (stray whitespace from a request
    # payload) still matches, and a blank key is rejected up front.
    normalized_cohort = (cohort_key or "").strip()
    if not normalized_cohort:
        return False

    extracted = extract_folder_key_from_collection_name(collection_name, prefix)
    if not extracted:
        return False

    return extracted.lower() == normalized_cohort.lower()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
def collection_matches_year(collection_name: str, year_scope: str) -> bool:
|
| 59 |
if not year_scope:
|
| 60 |
return False
|
core/config.py
CHANGED
|
@@ -67,13 +67,6 @@ SUPABASE_SYNC_ALLOWED_IPS = [ip.strip() for ip in os.getenv('SUPABASE_SYNC_ALLOW
|
|
| 67 |
SUPABASE_SYNC_ALLOW_PRIVATE_NETWORK = os.getenv('SUPABASE_SYNC_ALLOW_PRIVATE_NETWORK', 'true').strip().lower() in {'1', 'true', 'yes', 'on'}
|
| 68 |
COLLECTION_ROUTER_TOP_N = _bounded_int_from_env('COLLECTION_ROUTER_TOP_N', 3, 1, 20)
|
| 69 |
|
| 70 |
-
# Cohort to academic year mapping
|
| 71 |
-
COHORT_TO_YEAR = {
|
| 72 |
-
'k65': '2023-2024',
|
| 73 |
-
'k64': '2022-2023',
|
| 74 |
-
'k63': '2021-2022',
|
| 75 |
-
}
|
| 76 |
-
|
| 77 |
# - Context and output limits
|
| 78 |
MAX_CONTEXT_CHARS = int(os.getenv('MAX_CONTEXT_CHARS', '12000'))
|
| 79 |
MAX_OUT_CHARS = int(os.getenv('MAX_OUT_CHARS', '3000'))
|
|
|
|
| 67 |
SUPABASE_SYNC_ALLOW_PRIVATE_NETWORK = os.getenv('SUPABASE_SYNC_ALLOW_PRIVATE_NETWORK', 'true').strip().lower() in {'1', 'true', 'yes', 'on'}
|
| 68 |
COLLECTION_ROUTER_TOP_N = _bounded_int_from_env('COLLECTION_ROUTER_TOP_N', 3, 1, 20)
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
# - Context and output limits
|
| 71 |
MAX_CONTEXT_CHARS = int(os.getenv('MAX_CONTEXT_CHARS', '12000'))
|
| 72 |
MAX_OUT_CHARS = int(os.getenv('MAX_OUT_CHARS', '3000'))
|
core/qa_pipeline.py
CHANGED
|
@@ -221,16 +221,16 @@ def generate_standalone_query(message: str, history: List) -> str:
|
|
| 221 |
|
| 222 |
return message
|
| 223 |
|
| 224 |
-
def ask_ai_improved(message: str, history: List, hybrid_retriever,
|
| 225 |
full_response = ""
|
| 226 |
-
for delta in ask_ai_stream_delta(message, history, hybrid_retriever,
|
| 227 |
full_response += delta
|
| 228 |
if len(full_response) > MAX_OUT_CHARS:
|
| 229 |
yield full_response[:MAX_OUT_CHARS] + "\n\n[Đã cắt bớt nội dung dài]"
|
| 230 |
return
|
| 231 |
yield full_response
|
| 232 |
|
| 233 |
-
def ask_ai_stream_delta(message: str, history: List, hybrid_retriever,
|
| 234 |
if not message.strip():
|
| 235 |
yield " Bạn chưa nhập câu hỏi."
|
| 236 |
return
|
|
@@ -241,12 +241,6 @@ def ask_ai_stream_delta(message: str, history: List, hybrid_retriever, year_scop
|
|
| 241 |
|
| 242 |
logger.info(f" CÂU HỎI GỐC: {message}")
|
| 243 |
question = generate_standalone_query(message, history)
|
| 244 |
-
# [YEAR-AWARE CHANGE] Xac dinh pham vi nam ma nguoi dung yeu cau.
|
| 245 |
-
requested_year_range, mentioned_years = detect_requested_year(f"{message}\n{question}")
|
| 246 |
-
if requested_year_range:
|
| 247 |
-
logger.info(f"Lọc theo năm học yêu cầu: {requested_year_range}")
|
| 248 |
-
elif mentioned_years:
|
| 249 |
-
logger.info(f"Lọc theo năm được nhắc tới: {sorted(mentioned_years)}")
|
| 250 |
|
| 251 |
processed_data = analyze_and_expand_query(question)
|
| 252 |
|
|
@@ -261,12 +255,9 @@ def ask_ai_stream_delta(message: str, history: List, hybrid_retriever, year_scop
|
|
| 261 |
|
| 262 |
all_docs: List = []
|
| 263 |
seen = set()
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
logger.info(f"Sử dụng year_scope từ cohort: {year_scope_hint}")
|
| 268 |
-
else:
|
| 269 |
-
year_scope_hint = requested_year_range or (", ".join(sorted(mentioned_years)) if mentioned_years else None)
|
| 270 |
for query in queries:
|
| 271 |
#Giữ nguyên logic alpha ngành CNTT của Minh
|
| 272 |
current_alpha = 0.4 if "CNTT" in query.upper() else 0.5
|
|
@@ -274,7 +265,7 @@ def ask_ai_stream_delta(message: str, history: List, hybrid_retriever, year_scop
|
|
| 274 |
query,
|
| 275 |
k=TOP_K_RESULTS,
|
| 276 |
alpha=current_alpha,
|
| 277 |
-
|
| 278 |
)
|
| 279 |
for doc in docs:
|
| 280 |
content_hash = hashlib.sha256(doc.page_content.encode("utf-8")).hexdigest()
|
|
@@ -287,23 +278,6 @@ def ask_ai_stream_delta(message: str, history: List, hybrid_retriever, year_scop
|
|
| 287 |
yield "Không tìm thấy thông tin liên quan trong tài liệu."
|
| 288 |
return
|
| 289 |
|
| 290 |
-
# [YEAR-AWARE CHANGE] Lọc theo năm nhưng vẫn fallback nếu không có tài liệu đúng năm.
|
| 291 |
-
year_scope = None
|
| 292 |
-
year_filter_requested = bool(requested_year_range or mentioned_years)
|
| 293 |
-
year_filtered_docs = filter_docs_by_year(all_docs, requested_year_range, mentioned_years)
|
| 294 |
-
|
| 295 |
-
if year_filter_requested:
|
| 296 |
-
if year_filtered_docs:
|
| 297 |
-
if len(year_filtered_docs) != len(all_docs):
|
| 298 |
-
logger.info(f"Đã lọc theo năm: còn {len(year_filtered_docs)}/{len(all_docs)} documents")
|
| 299 |
-
all_docs = year_filtered_docs
|
| 300 |
-
if requested_year_range:
|
| 301 |
-
year_scope = requested_year_range
|
| 302 |
-
elif mentioned_years:
|
| 303 |
-
year_scope = ", ".join(sorted(mentioned_years))
|
| 304 |
-
else:
|
| 305 |
-
logger.warning("Không tìm thấy tài liệu đúng năm yêu cầu, fallback sang tập tài liệu tổng quát")
|
| 306 |
-
|
| 307 |
final_docs = advanced_rerank(question, all_docs, top_k=FINAL_TOP_K)
|
| 308 |
|
| 309 |
context_parts = []
|
|
@@ -311,10 +285,7 @@ def ask_ai_stream_delta(message: str, history: List, hybrid_retriever, year_scop
|
|
| 311 |
for doc in final_docs:
|
| 312 |
page = doc.metadata.get('page_number', 'N/A')
|
| 313 |
file_name = doc.metadata.get('source_file') or doc.metadata.get('source')
|
| 314 |
-
|
| 315 |
-
doc_year = infer_doc_academic_year(doc)
|
| 316 |
-
year_label = f"Năm {doc_year}" if doc_year != "ALL" else "Áp dụng nhiều năm"
|
| 317 |
-
source = f"[{year_label} | {os.path.basename(file_name)} | Trang {page}]" if file_name else f"[{year_label} | Trang {page}]"
|
| 318 |
block = f"{source}\n{doc.page_content}"
|
| 319 |
if total_chars + len(block) > MAX_CONTEXT_CHARS:
|
| 320 |
break
|
|
@@ -324,7 +295,7 @@ def ask_ai_stream_delta(message: str, history: List, hybrid_retriever, year_scop
|
|
| 324 |
context = "\n\n---\n\n".join(context_parts)
|
| 325 |
topic_hint = processed_data.get('topic') or processed_data.get('root_question') or question
|
| 326 |
|
| 327 |
-
prompt = create_advanced_prompt(question, context, question_type, topic_hint
|
| 328 |
|
| 329 |
logger.info("Đang tạo câu trả lời cuối cùng ...")
|
| 330 |
|
|
|
|
| 221 |
|
| 222 |
return message
|
| 223 |
|
| 224 |
+
def ask_ai_improved(message: str, history: List, hybrid_retriever, cohort_key: str | None = None) -> Generator[str, None, None]:
|
| 225 |
full_response = ""
|
| 226 |
+
for delta in ask_ai_stream_delta(message, history, hybrid_retriever, cohort_key=cohort_key):
|
| 227 |
full_response += delta
|
| 228 |
if len(full_response) > MAX_OUT_CHARS:
|
| 229 |
yield full_response[:MAX_OUT_CHARS] + "\n\n[Đã cắt bớt nội dung dài]"
|
| 230 |
return
|
| 231 |
yield full_response
|
| 232 |
|
| 233 |
+
def ask_ai_stream_delta(message: str, history: List, hybrid_retriever, cohort_key: str | None = None) -> Generator[str, None, None]:
|
| 234 |
if not message.strip():
|
| 235 |
yield " Bạn chưa nhập câu hỏi."
|
| 236 |
return
|
|
|
|
| 241 |
|
| 242 |
logger.info(f" CÂU HỎI GỐC: {message}")
|
| 243 |
question = generate_standalone_query(message, history)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
processed_data = analyze_and_expand_query(question)
|
| 246 |
|
|
|
|
| 255 |
|
| 256 |
all_docs: List = []
|
| 257 |
seen = set()
|
| 258 |
+
if cohort_key:
|
| 259 |
+
logger.info(f"Sử dụng cohort_key: {cohort_key}")
|
| 260 |
+
|
|
|
|
|
|
|
|
|
|
| 261 |
for query in queries:
|
| 262 |
#Giữ nguyên logic alpha ngành CNTT của Minh
|
| 263 |
current_alpha = 0.4 if "CNTT" in query.upper() else 0.5
|
|
|
|
| 265 |
query,
|
| 266 |
k=TOP_K_RESULTS,
|
| 267 |
alpha=current_alpha,
|
| 268 |
+
cohort_key=cohort_key,
|
| 269 |
)
|
| 270 |
for doc in docs:
|
| 271 |
content_hash = hashlib.sha256(doc.page_content.encode("utf-8")).hexdigest()
|
|
|
|
| 278 |
yield "Không tìm thấy thông tin liên quan trong tài liệu."
|
| 279 |
return
|
| 280 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
final_docs = advanced_rerank(question, all_docs, top_k=FINAL_TOP_K)
|
| 282 |
|
| 283 |
context_parts = []
|
|
|
|
| 285 |
for doc in final_docs:
|
| 286 |
page = doc.metadata.get('page_number', 'N/A')
|
| 287 |
file_name = doc.metadata.get('source_file') or doc.metadata.get('source')
|
| 288 |
+
source = f"[{os.path.basename(file_name)} | Trang {page}]" if file_name else f"[Trang {page}]"
|
|
|
|
|
|
|
|
|
|
| 289 |
block = f"{source}\n{doc.page_content}"
|
| 290 |
if total_chars + len(block) > MAX_CONTEXT_CHARS:
|
| 291 |
break
|
|
|
|
| 295 |
context = "\n\n---\n\n".join(context_parts)
|
| 296 |
topic_hint = processed_data.get('topic') or processed_data.get('root_question') or question
|
| 297 |
|
| 298 |
+
prompt = create_advanced_prompt(question, context, question_type, topic_hint)
|
| 299 |
|
| 300 |
logger.info("Đang tạo câu trả lời cuối cùng ...")
|
| 301 |
|
main.py
CHANGED
|
@@ -14,7 +14,6 @@ from qdrant_client import QdrantClient
|
|
| 14 |
#Import các model và các hàm cần thiết từ core
|
| 15 |
from core.config import (
|
| 16 |
COLLECTION_ROUTER_TOP_N,
|
| 17 |
-
COHORT_TO_YEAR,
|
| 18 |
DATABASE_URL,
|
| 19 |
QDRANT_API_KEY,
|
| 20 |
QDRANT_URL,
|
|
@@ -344,20 +343,15 @@ async def chat_endpoint(payload: ChatRequest, request: Request):
|
|
| 344 |
user_id = payload.user_id # Lấy user_id từ request
|
| 345 |
cohort_key = payload.cohort_key # Lấy cohort_key từ request
|
| 346 |
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
if cohort_key and cohort_key in COHORT_TO_YEAR:
|
| 350 |
-
year_scope = COHORT_TO_YEAR[cohort_key]
|
| 351 |
-
logger.info(f"Sử dụng cohort: {cohort_key} -> năm học: {year_scope}")
|
| 352 |
-
elif cohort_key:
|
| 353 |
-
logger.warning(f"Cohort không hợp lệ: {cohort_key}")
|
| 354 |
|
| 355 |
history = await get_history_async(db_pool, session_id)
|
| 356 |
|
| 357 |
# Tập hợp toàn bộ response từ generator
|
| 358 |
full_response = ""
|
| 359 |
try:
|
| 360 |
-
async for chunk in iterate_in_threadpool(ask_ai_improved(user_msg, history, retriever,
|
| 361 |
full_response = chunk
|
| 362 |
except Exception:
|
| 363 |
logger.exception("Lỗi khi xử lý phản hồi từ AI:", exc_info=True)
|
|
@@ -381,13 +375,8 @@ async def chat_stream_endpoint(payload: ChatRequest, request: Request):
|
|
| 381 |
user_id = payload.user_id # Lấy user_id từ request
|
| 382 |
cohort_key = payload.cohort_key # Lấy cohort_key từ request
|
| 383 |
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
if cohort_key and cohort_key in COHORT_TO_YEAR:
|
| 387 |
-
year_scope = COHORT_TO_YEAR[cohort_key]
|
| 388 |
-
logger.info(f"Sử dụng cohort: {cohort_key} -> năm học: {year_scope}")
|
| 389 |
-
elif cohort_key:
|
| 390 |
-
logger.warning(f"Cohort không hợp lệ: {cohort_key}")
|
| 391 |
|
| 392 |
history = await get_history_async(db_pool, session_id)
|
| 393 |
|
|
@@ -396,7 +385,7 @@ async def chat_stream_endpoint(payload: ChatRequest, request: Request):
|
|
| 396 |
full_response = ""
|
| 397 |
try:
|
| 398 |
# ask_ai_stream_delta yield từng delta chunk (không cumulative)
|
| 399 |
-
async for delta_chunk in iterate_in_threadpool(ask_ai_stream_delta(user_msg, history, retriever,
|
| 400 |
full_response += delta_chunk
|
| 401 |
# Gửi SSE event với delta chunk
|
| 402 |
sse_data = json.dumps({"delta": delta_chunk, "done": False}, ensure_ascii=False)
|
|
|
|
| 14 |
#Import các model và các hàm cần thiết từ core
|
| 15 |
from core.config import (
|
| 16 |
COLLECTION_ROUTER_TOP_N,
|
|
|
|
| 17 |
DATABASE_URL,
|
| 18 |
QDRANT_API_KEY,
|
| 19 |
QDRANT_URL,
|
|
|
|
| 343 |
user_id = payload.user_id # Lấy user_id từ request
|
| 344 |
cohort_key = payload.cohort_key # Lấy cohort_key từ request
|
| 345 |
|
| 346 |
+
if cohort_key:
|
| 347 |
+
logger.info(f"Sử dụng cohort: {cohort_key}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
|
| 349 |
history = await get_history_async(db_pool, session_id)
|
| 350 |
|
| 351 |
# Tập hợp toàn bộ response từ generator
|
| 352 |
full_response = ""
|
| 353 |
try:
|
| 354 |
+
async for chunk in iterate_in_threadpool(ask_ai_improved(user_msg, history, retriever, cohort_key=cohort_key)):
|
| 355 |
full_response = chunk
|
| 356 |
except Exception:
|
| 357 |
logger.exception("Lỗi khi xử lý phản hồi từ AI:", exc_info=True)
|
|
|
|
| 375 |
user_id = payload.user_id # Lấy user_id từ request
|
| 376 |
cohort_key = payload.cohort_key # Lấy cohort_key từ request
|
| 377 |
|
| 378 |
+
if cohort_key:
|
| 379 |
+
logger.info(f"Sử dụng cohort: {cohort_key}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
|
| 381 |
history = await get_history_async(db_pool, session_id)
|
| 382 |
|
|
|
|
| 385 |
full_response = ""
|
| 386 |
try:
|
| 387 |
# ask_ai_stream_delta yield từng delta chunk (không cumulative)
|
| 388 |
+
async for delta_chunk in iterate_in_threadpool(ask_ai_stream_delta(user_msg, history, retriever, cohort_key=cohort_key)):
|
| 389 |
full_response += delta_chunk
|
| 390 |
# Gửi SSE event với delta chunk
|
| 391 |
sse_data = json.dumps({"delta": delta_chunk, "done": False}, ensure_ascii=False)
|