sửa rerank
Browse files- core/embeddings/retrival.py +22 -10
- core/gradio/gradio_rag_qwen.py +3 -25
core/embeddings/retrival.py
CHANGED
|
@@ -34,12 +34,13 @@ class RetrievalConfig:
|
|
| 34 |
rerank_api_base_url: str = "https://api.siliconflow.com/v1"
|
| 35 |
rerank_model: str = "Qwen/Qwen3-Reranker-4B"
|
| 36 |
rerank_top_n: int = 10
|
| 37 |
-
initial_k: int =
|
| 38 |
top_k: int = 5
|
| 39 |
vector_weight: float = 0.5
|
| 40 |
bm25_weight: float = 0.5
|
| 41 |
|
| 42 |
|
|
|
|
| 43 |
_retrieval_config: RetrievalConfig | None = None
|
| 44 |
|
| 45 |
|
|
@@ -82,7 +83,7 @@ class SiliconFlowReranker(BaseDocumentCompressor):
|
|
| 82 |
"documents": [doc.page_content for doc in documents],
|
| 83 |
"top_n": self.top_n or len(documents),
|
| 84 |
},
|
| 85 |
-
timeout=
|
| 86 |
)
|
| 87 |
response.raise_for_status()
|
| 88 |
data = response.json()
|
|
@@ -109,11 +110,10 @@ class SiliconFlowReranker(BaseDocumentCompressor):
|
|
| 109 |
return list(documents)
|
| 110 |
|
| 111 |
|
|
|
|
|
|
|
| 112 |
class Retriever:
|
| 113 |
def __init__(self, vector_db: "ChromaVectorDB", use_reranker: bool = True):
|
| 114 |
-
import time
|
| 115 |
-
start = time.time()
|
| 116 |
-
|
| 117 |
self._vector_db = vector_db
|
| 118 |
self._config = get_retrieval_config()
|
| 119 |
self._reranker: Optional[SiliconFlowReranker] = None
|
|
@@ -138,7 +138,8 @@ class Retriever:
|
|
| 138 |
if use_reranker:
|
| 139 |
self._reranker = self._init_reranker()
|
| 140 |
|
| 141 |
-
logger.info(
|
|
|
|
| 142 |
|
| 143 |
def _save_bm25_cache(self, bm25: BM25Retriever) -> None:
|
| 144 |
"""Save BM25 retriever to disk for fast loading."""
|
|
@@ -153,9 +154,9 @@ class Retriever:
|
|
| 153 |
logger.warning(f"Failed to save BM25 cache: {e}")
|
| 154 |
|
| 155 |
def _load_bm25_cache(self) -> Optional[BM25Retriever]:
|
| 156 |
-
"""Load BM25 retriever from disk cache."""
|
| 157 |
if not self._bm25_cache_path or not self._bm25_cache_path.exists():
|
| 158 |
return None
|
|
|
|
| 159 |
try:
|
| 160 |
import pickle
|
| 161 |
import time
|
|
@@ -168,9 +169,10 @@ class Retriever:
|
|
| 168 |
except Exception as e:
|
| 169 |
logger.warning(f"Failed to load BM25 cache: {e}")
|
| 170 |
return None
|
|
|
|
|
|
|
| 171 |
|
| 172 |
def _init_bm25(self) -> Optional[BM25Retriever]:
|
| 173 |
-
"""Initialize BM25 retriever (lazy-loaded, with disk cache)."""
|
| 174 |
if self._bm25_initialized:
|
| 175 |
return self._bm25_retriever
|
| 176 |
|
|
@@ -330,6 +332,8 @@ class Retriever:
|
|
| 330 |
where: Optional[Dict[str, Any]] = None,
|
| 331 |
initial_k: int | None = None,
|
| 332 |
) -> List[Dict[str, Any]]:
|
|
|
|
|
|
|
| 333 |
if not text.strip():
|
| 334 |
return []
|
| 335 |
|
|
@@ -353,12 +357,20 @@ class Retriever:
|
|
| 353 |
if bm25:
|
| 354 |
bm25.k = initial_k
|
| 355 |
|
| 356 |
-
|
| 357 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
return [
|
| 359 |
self._to_result(doc, i + 1, rerank_score=doc.metadata.get("rerank_score"))
|
| 360 |
for i, doc in enumerate(results[:k])
|
| 361 |
]
|
|
|
|
|
|
|
| 362 |
|
| 363 |
def flexible_search(
|
| 364 |
self,
|
|
|
|
| 34 |
rerank_api_base_url: str = "https://api.siliconflow.com/v1"
|
| 35 |
rerank_model: str = "Qwen/Qwen3-Reranker-4B"
|
| 36 |
rerank_top_n: int = 10
|
| 37 |
+
initial_k: int = 25 # Reduced to minimize reranker time
|
| 38 |
top_k: int = 5
|
| 39 |
vector_weight: float = 0.5
|
| 40 |
bm25_weight: float = 0.5
|
| 41 |
|
| 42 |
|
| 43 |
+
|
| 44 |
_retrieval_config: RetrievalConfig | None = None
|
| 45 |
|
| 46 |
|
|
|
|
| 83 |
"documents": [doc.page_content for doc in documents],
|
| 84 |
"top_n": self.top_n or len(documents),
|
| 85 |
},
|
| 86 |
+
timeout=120,
|
| 87 |
)
|
| 88 |
response.raise_for_status()
|
| 89 |
data = response.json()
|
|
|
|
| 110 |
return list(documents)
|
| 111 |
|
| 112 |
|
| 113 |
+
|
| 114 |
+
|
| 115 |
class Retriever:
|
| 116 |
def __init__(self, vector_db: "ChromaVectorDB", use_reranker: bool = True):
|
|
|
|
|
|
|
|
|
|
| 117 |
self._vector_db = vector_db
|
| 118 |
self._config = get_retrieval_config()
|
| 119 |
self._reranker: Optional[SiliconFlowReranker] = None
|
|
|
|
| 138 |
if use_reranker:
|
| 139 |
self._reranker = self._init_reranker()
|
| 140 |
|
| 141 |
+
logger.info("Retriever initialized")
|
| 142 |
+
|
| 143 |
|
| 144 |
def _save_bm25_cache(self, bm25: BM25Retriever) -> None:
|
| 145 |
"""Save BM25 retriever to disk for fast loading."""
|
|
|
|
| 154 |
logger.warning(f"Failed to save BM25 cache: {e}")
|
| 155 |
|
| 156 |
def _load_bm25_cache(self) -> Optional[BM25Retriever]:
|
|
|
|
| 157 |
if not self._bm25_cache_path or not self._bm25_cache_path.exists():
|
| 158 |
return None
|
| 159 |
+
|
| 160 |
try:
|
| 161 |
import pickle
|
| 162 |
import time
|
|
|
|
| 169 |
except Exception as e:
|
| 170 |
logger.warning(f"Failed to load BM25 cache: {e}")
|
| 171 |
return None
|
| 172 |
+
|
| 173 |
+
|
| 174 |
|
| 175 |
def _init_bm25(self) -> Optional[BM25Retriever]:
|
|
|
|
| 176 |
if self._bm25_initialized:
|
| 177 |
return self._bm25_retriever
|
| 178 |
|
|
|
|
| 332 |
where: Optional[Dict[str, Any]] = None,
|
| 333 |
initial_k: int | None = None,
|
| 334 |
) -> List[Dict[str, Any]]:
|
| 335 |
+
import time
|
| 336 |
+
|
| 337 |
if not text.strip():
|
| 338 |
return []
|
| 339 |
|
|
|
|
| 357 |
if bm25:
|
| 358 |
bm25.k = initial_k
|
| 359 |
|
| 360 |
+
ensemble = self._get_ensemble_retriever()
|
| 361 |
+
ensemble_results = ensemble.invoke(text)
|
| 362 |
+
|
| 363 |
+
if self._reranker:
|
| 364 |
+
results = self._reranker.compress_documents(ensemble_results, text)
|
| 365 |
+
else:
|
| 366 |
+
results = ensemble_results
|
| 367 |
+
|
| 368 |
return [
|
| 369 |
self._to_result(doc, i + 1, rerank_score=doc.metadata.get("rerank_score"))
|
| 370 |
for i, doc in enumerate(results[:k])
|
| 371 |
]
|
| 372 |
+
|
| 373 |
+
|
| 374 |
|
| 375 |
def flexible_search(
|
| 376 |
self,
|
core/gradio/gradio_rag_qwen.py
CHANGED
|
@@ -30,7 +30,7 @@ from core.embeddings.generator import RAGContextBuilder, build_context, build_pr
|
|
| 30 |
|
| 31 |
_load_env()
|
| 32 |
|
| 33 |
-
RETRIEVAL_MODE = RetrievalMode.
|
| 34 |
|
| 35 |
# LLM Config (hardcoded sau khi xóa LLMConfig từ generator)
|
| 36 |
LLM_MODEL = os.getenv("LLM_MODEL", "qwen/qwen3-32b")
|
|
@@ -83,12 +83,7 @@ def _init_resources() -> None:
|
|
| 83 |
|
| 84 |
|
| 85 |
def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
|
| 86 |
-
import time
|
| 87 |
-
total_start = time.time()
|
| 88 |
-
|
| 89 |
-
init_start = time.time()
|
| 90 |
_init_resources()
|
| 91 |
-
init_time = time.time() - init_start
|
| 92 |
|
| 93 |
assert STATE.db is not None
|
| 94 |
assert STATE.client is not None
|
|
@@ -96,14 +91,12 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
|
|
| 96 |
assert STATE.rag_builder is not None
|
| 97 |
|
| 98 |
# Bước 1: Retrieve và prepare context
|
| 99 |
-
retrieval_start = time.time()
|
| 100 |
prepared = STATE.rag_builder.retrieve_and_prepare(
|
| 101 |
message,
|
| 102 |
k=RETRIEVAL_CFG.top_k,
|
| 103 |
initial_k=RETRIEVAL_CFG.initial_k,
|
| 104 |
mode=RETRIEVAL_MODE.value,
|
| 105 |
)
|
| 106 |
-
retrieval_time = time.time() - retrieval_start
|
| 107 |
results = prepared["results"]
|
| 108 |
|
| 109 |
if not results:
|
|
@@ -111,7 +104,6 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
|
|
| 111 |
return
|
| 112 |
|
| 113 |
# Bước 2: Gọi LLM streaming để generate answer
|
| 114 |
-
llm_start = time.time()
|
| 115 |
completion = STATE.client.chat.completions.create(
|
| 116 |
model=LLM_MODEL,
|
| 117 |
messages=[{"role": "user", "content": prepared["prompt"]}],
|
|
@@ -121,26 +113,11 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
|
|
| 121 |
)
|
| 122 |
|
| 123 |
acc = ""
|
| 124 |
-
first_token_time = None
|
| 125 |
for chunk in completion:
|
| 126 |
-
if first_token_time is None:
|
| 127 |
-
first_token_time = time.time() - llm_start
|
| 128 |
delta = getattr(chunk.choices[0].delta, "content", "") or ""
|
| 129 |
if delta:
|
| 130 |
acc += delta
|
| 131 |
yield acc
|
| 132 |
-
|
| 133 |
-
llm_time = time.time() - llm_start
|
| 134 |
-
total_time = time.time() - total_start
|
| 135 |
-
|
| 136 |
-
# Timing info
|
| 137 |
-
timing_info = f"\n\n---\n**⏱️ Timing:**\n"
|
| 138 |
-
timing_info += f"- Init: {init_time:.2f}s\n"
|
| 139 |
-
timing_info += f"- Retrieval: {retrieval_time:.2f}s\n"
|
| 140 |
-
timing_info += f"- LLM (first token): {first_token_time:.2f}s\n" if first_token_time else ""
|
| 141 |
-
timing_info += f"- LLM (total): {llm_time:.2f}s\n"
|
| 142 |
-
timing_info += f"- **Total: {total_time:.2f}s**\n"
|
| 143 |
-
|
| 144 |
|
| 145 |
# Debug info with mode indicator
|
| 146 |
debug_info = f"\n\n---\n\n**Retrieved (Top {len(results)} | Mode: {RETRIEVAL_MODE.value})**\n\n"
|
|
@@ -182,7 +159,8 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
|
|
| 182 |
debug_info += f" - **Mục:** {header[:80]}{'...' if len(header) > 80 else ''}\n"
|
| 183 |
debug_info += f" - **Content:** {content[:200]}{'...' if len(content) > 200 else ''}\n\n"
|
| 184 |
|
| 185 |
-
yield acc +
|
|
|
|
| 186 |
|
| 187 |
|
| 188 |
|
|
|
|
| 30 |
|
| 31 |
_load_env()
|
| 32 |
|
| 33 |
+
RETRIEVAL_MODE = RetrievalMode.HYBRID_RERANK # Test with debug logs
|
| 34 |
|
| 35 |
# LLM Config (hardcoded sau khi xóa LLMConfig từ generator)
|
| 36 |
LLM_MODEL = os.getenv("LLM_MODEL", "qwen/qwen3-32b")
|
|
|
|
| 83 |
|
| 84 |
|
| 85 |
def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
_init_resources()
|
|
|
|
| 87 |
|
| 88 |
assert STATE.db is not None
|
| 89 |
assert STATE.client is not None
|
|
|
|
| 91 |
assert STATE.rag_builder is not None
|
| 92 |
|
| 93 |
# Bước 1: Retrieve và prepare context
|
|
|
|
| 94 |
prepared = STATE.rag_builder.retrieve_and_prepare(
|
| 95 |
message,
|
| 96 |
k=RETRIEVAL_CFG.top_k,
|
| 97 |
initial_k=RETRIEVAL_CFG.initial_k,
|
| 98 |
mode=RETRIEVAL_MODE.value,
|
| 99 |
)
|
|
|
|
| 100 |
results = prepared["results"]
|
| 101 |
|
| 102 |
if not results:
|
|
|
|
| 104 |
return
|
| 105 |
|
| 106 |
# Bước 2: Gọi LLM streaming để generate answer
|
|
|
|
| 107 |
completion = STATE.client.chat.completions.create(
|
| 108 |
model=LLM_MODEL,
|
| 109 |
messages=[{"role": "user", "content": prepared["prompt"]}],
|
|
|
|
| 113 |
)
|
| 114 |
|
| 115 |
acc = ""
|
|
|
|
| 116 |
for chunk in completion:
|
|
|
|
|
|
|
| 117 |
delta = getattr(chunk.choices[0].delta, "content", "") or ""
|
| 118 |
if delta:
|
| 119 |
acc += delta
|
| 120 |
yield acc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
# Debug info with mode indicator
|
| 123 |
debug_info = f"\n\n---\n\n**Retrieved (Top {len(results)} | Mode: {RETRIEVAL_MODE.value})**\n\n"
|
|
|
|
| 159 |
debug_info += f" - **Mục:** {header[:80]}{'...' if len(header) > 80 else ''}\n"
|
| 160 |
debug_info += f" - **Content:** {content[:200]}{'...' if len(content) > 200 else ''}\n\n"
|
| 161 |
|
| 162 |
+
yield acc + debug_info
|
| 163 |
+
|
| 164 |
|
| 165 |
|
| 166 |
|