hungnha committed
Commit 5cc85a5
1 Parent(s): 794ce9a

fix rerank
core/embeddings/retrival.py CHANGED
@@ -34,12 +34,13 @@ class RetrievalConfig:
     rerank_api_base_url: str = "https://api.siliconflow.com/v1"
     rerank_model: str = "Qwen/Qwen3-Reranker-4B"
     rerank_top_n: int = 10
-    initial_k: int = 100
+    initial_k: int = 25  # Reduced to minimize reranker time
     top_k: int = 5
     vector_weight: float = 0.5
     bm25_weight: float = 0.5
 
 
+
 _retrieval_config: RetrievalConfig | None = None
 
 
@@ -82,7 +83,7 @@ class SiliconFlowReranker(BaseDocumentCompressor):
                 "documents": [doc.page_content for doc in documents],
                 "top_n": self.top_n or len(documents),
             },
-            timeout=30,
+            timeout=120,
         )
         response.raise_for_status()
         data = response.json()
@@ -109,11 +110,10 @@ class SiliconFlowReranker(BaseDocumentCompressor):
         return list(documents)
 
 
+
+
 class Retriever:
     def __init__(self, vector_db: "ChromaVectorDB", use_reranker: bool = True):
-        import time
-        start = time.time()
-
         self._vector_db = vector_db
         self._config = get_retrieval_config()
         self._reranker: Optional[SiliconFlowReranker] = None
@@ -138,7 +138,8 @@ class Retriever:
         if use_reranker:
             self._reranker = self._init_reranker()
 
-        logger.info(f"Retriever initialized in {time.time() - start:.2f}s (BM25 lazy-loaded)")
+        logger.info("Retriever initialized")
+
 
     def _save_bm25_cache(self, bm25: BM25Retriever) -> None:
         """Save BM25 retriever to disk for fast loading."""
@@ -153,9 +154,9 @@
             logger.warning(f"Failed to save BM25 cache: {e}")
 
     def _load_bm25_cache(self) -> Optional[BM25Retriever]:
-        """Load BM25 retriever from disk cache."""
         if not self._bm25_cache_path or not self._bm25_cache_path.exists():
             return None
+
         try:
             import pickle
             import time
@@ -168,9 +169,10 @@
         except Exception as e:
             logger.warning(f"Failed to load BM25 cache: {e}")
             return None
+
+
 
     def _init_bm25(self) -> Optional[BM25Retriever]:
-        """Initialize BM25 retriever (lazy-loaded, with disk cache)."""
         if self._bm25_initialized:
             return self._bm25_retriever
 
@@ -330,6 +332,8 @@
         where: Optional[Dict[str, Any]] = None,
         initial_k: int | None = None,
     ) -> List[Dict[str, Any]]:
+        import time
+
         if not text.strip():
             return []
 
@@ -353,12 +357,20 @@
         if bm25:
             bm25.k = initial_k
 
-        final_retriever = self._build_final()
-        results = final_retriever.invoke(text)
+        ensemble = self._get_ensemble_retriever()
+        ensemble_results = ensemble.invoke(text)
+
+        if self._reranker:
+            results = self._reranker.compress_documents(ensemble_results, text)
+        else:
+            results = ensemble_results
+
         return [
            self._to_result(doc, i + 1, rerank_score=doc.metadata.get("rerank_score"))
            for i, doc in enumerate(results[:k])
         ]
+
+
 
     def flexible_search(
         self,
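
The last hunk is the core of this fix: the opaque `_build_final()` call is replaced with an explicit two-stage pipeline. Below is a minimal sketch of that flow, assuming LangChain-style interfaces (`invoke` on the retriever, `compress_documents` on the compressor) as used in the diff; `hybrid_rerank` itself is a hypothetical helper, not part of the codebase.

from typing import Any, List, Optional

def hybrid_rerank(ensemble: Any, reranker: Optional[Any], text: str, k: int) -> List[Any]:
    # Stage 1 (recall): the vector + BM25 ensemble returns up to initial_k
    # candidates; this commit drops initial_k from 100 to 25 to keep the
    # rerank request small.
    candidates = ensemble.invoke(text)
    # Stage 2 (precision): the reranker reorders the candidates and writes a
    # "rerank_score" into each document's metadata; it is skipped when the
    # reranker is disabled.
    if reranker is not None:
        candidates = reranker.compress_documents(candidates, text)
    # Truncate to the top_k results surfaced to the caller.
    return list(candidates)[:k]

The raised `timeout=120` in the second hunk covers the HTTP request behind `compress_documents`. A hedged sketch of that call follows, assuming a SiliconFlow-style `/rerank` endpoint and the `requests` library; the endpoint path and auth scheme are assumptions, while the base URL, model name, and payload fields mirror RetrievalConfig and the request body shown above.

import requests

def call_rerank_api(api_key: str, query: str, docs: list[str], top_n: int) -> dict:
    # Endpoint path and Authorization header are assumptions; the payload
    # fields follow the diff above.
    response = requests.post(
        "https://api.siliconflow.com/v1/rerank",
        headers={"Authorization": f"Bearer {api_key}"},
        json={
            "model": "Qwen/Qwen3-Reranker-4B",
            "query": query,
            "documents": docs,
            "top_n": top_n,
        },
        timeout=120,  # raised from 30 s so larger candidate sets do not time out
    )
    response.raise_for_status()
    return response.json()
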
core/gradio/gradio_rag_qwen.py CHANGED
@@ -30,7 +30,7 @@ from core.embeddings.generator import RAGContextBuilder, build_context, build_pr
 
 _load_env()
 
-RETRIEVAL_MODE = RetrievalMode.VECTOR_ONLY  # Fastest mode - no BM25/reranker
+RETRIEVAL_MODE = RetrievalMode.HYBRID_RERANK  # Test with debug logs
 
 # LLM Config (hardcoded after removing LLMConfig from generator)
 LLM_MODEL = os.getenv("LLM_MODEL", "qwen/qwen3-32b")
@@ -83,12 +83,7 @@ def _init_resources() -> None:
 
 
 def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
-    import time
-    total_start = time.time()
-
-    init_start = time.time()
     _init_resources()
-    init_time = time.time() - init_start
 
     assert STATE.db is not None
     assert STATE.client is not None
@@ -96,14 +91,12 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
     assert STATE.rag_builder is not None
 
     # Step 1: Retrieve and prepare the context
-    retrieval_start = time.time()
     prepared = STATE.rag_builder.retrieve_and_prepare(
         message,
         k=RETRIEVAL_CFG.top_k,
         initial_k=RETRIEVAL_CFG.initial_k,
         mode=RETRIEVAL_MODE.value,
     )
-    retrieval_time = time.time() - retrieval_start
     results = prepared["results"]
 
     if not results:
@@ -111,7 +104,6 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
         return
 
     # Step 2: Call the LLM with streaming to generate the answer
-    llm_start = time.time()
     completion = STATE.client.chat.completions.create(
         model=LLM_MODEL,
         messages=[{"role": "user", "content": prepared["prompt"]}],
@@ -121,26 +113,11 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
     )
 
     acc = ""
-    first_token_time = None
     for chunk in completion:
-        if first_token_time is None:
-            first_token_time = time.time() - llm_start
         delta = getattr(chunk.choices[0].delta, "content", "") or ""
         if delta:
             acc += delta
             yield acc
-
-    llm_time = time.time() - llm_start
-    total_time = time.time() - total_start
-
-    # Timing info
-    timing_info = f"\n\n---\n**⏱️ Timing:**\n"
-    timing_info += f"- Init: {init_time:.2f}s\n"
-    timing_info += f"- Retrieval: {retrieval_time:.2f}s\n"
-    timing_info += f"- LLM (first token): {first_token_time:.2f}s\n" if first_token_time else ""
-    timing_info += f"- LLM (total): {llm_time:.2f}s\n"
-    timing_info += f"- **Total: {total_time:.2f}s**\n"
-
 
     # Debug info with mode indicator
     debug_info = f"\n\n---\n\n**Retrieved (Top {len(results)} | Mode: {RETRIEVAL_MODE.value})**\n\n"
@@ -182,7 +159,8 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
         debug_info += f" - **Mục:** {header[:80]}{'...' if len(header) > 80 else ''}\n"
         debug_info += f" - **Content:** {content[:200]}{'...' if len(content) > 200 else ''}\n\n"
 
-    yield acc + timing_info + debug_info
+    yield acc + debug_info
+
 
 
 
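
Switching `RETRIEVAL_MODE` to `HYBRID_RERANK` routes `retrieve_and_prepare` through the ensemble-plus-rerank path from the first file. A hedged sketch of the enum this constant implies, inferred only from the two members referenced in this file (`VECTOR_ONLY`, `HYBRID_RERANK`); the actual definition and string values may differ.

from enum import Enum

class RetrievalMode(str, Enum):
    # Dense vector search only; skips BM25 and the reranker entirely.
    VECTOR_ONLY = "vector_only"
    # Vector + BM25 ensemble followed by the SiliconFlow reranker.
    HYBRID_RERANK = "hybrid_rerank"

Because `rag_chat` passes `mode=RETRIEVAL_MODE.value`, the mode travels into the retrieval layer as a plain string.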