VietCat committed on
Commit
f6c9376
·
1 Parent(s): 79753dc

add reranker

Browse files
Files changed (4) hide show
  1. app/config.py +3 -0
  2. app/main.py +10 -0
  3. app/reranker.py +47 -0
  4. app/supabase_db.py +1 -1
app/config.py CHANGED
@@ -42,6 +42,9 @@ class Settings(BaseSettings):
42
  embedding_provider: str = os.getenv("EMBEDDING_PROVIDER", "gemini") or ""
43
  embedding_model: str = os.getenv("EMBEDDING_MODEL", "models/embedding-001") or ""
44
 
 
 
 
45
  class Config:
46
  env_file = ".env"
47
 
 
42
  embedding_provider: str = os.getenv("EMBEDDING_PROVIDER", "gemini") or ""
43
  embedding_model: str = os.getenv("EMBEDDING_MODEL", "models/embedding-001") or ""
44
 
45
+ rerank_provider: str = os.getenv("RERANK_PROVIDER", "") or llm_provider
46
+ rerank_model: str = os.getenv("RERANK_MODEL", "") or llm_model
47
+
48
  class Config:
49
  env_file = ".env"
50
 
app/main.py CHANGED
@@ -18,6 +18,7 @@ from .utils import setup_logging, extract_command, extract_keywords, timing_deco
18
  from .constants import VEHICLE_KEYWORDS, SHEET_RANGE, VEHICLE_KEYWORD_TO_COLUMN
19
  from .health import router as health_router
20
  from .llm import create_llm_client
 
21
 
22
  app = FastAPI(title="WeBot Facebook Messenger API")
23
 
@@ -67,6 +68,8 @@ llm_client = create_llm_client(
67
  model=settings.gemini_model
68
  )
69
 
 
 
70
  logger.info("[STARTUP] Mount health router...")
71
  app.include_router(health_router)
72
 
@@ -395,6 +398,13 @@ async def process_business_logic(log_kwargs: Dict[str, Any], page_token: str) ->
395
  async def format_search_results(question: str, matches: List[Dict[str, Any]]) -> str:
396
  if not matches:
397
  return "Không tìm thấy kết quả phù hợp."
 
 
 
 
 
 
 
398
  # Tìm item có similarity cao nhất
399
  top = None
400
  top_result_text = ""
 
18
  from .constants import VEHICLE_KEYWORDS, SHEET_RANGE, VEHICLE_KEYWORD_TO_COLUMN
19
  from .health import router as health_router
20
  from .llm import create_llm_client
21
+ from .reranker import Reranker
22
 
23
  app = FastAPI(title="WeBot Facebook Messenger API")
24
 
 
68
  model=settings.gemini_model
69
  )
70
 
71
+ reranker = Reranker()
72
+
73
  logger.info("[STARTUP] Mount health router...")
74
  app.include_router(health_router)
75
 
 
398
  async def format_search_results(question: str, matches: List[Dict[str, Any]]) -> str:
399
  if not matches:
400
  return "Không tìm thấy kết quả phù hợp."
401
+ # Rerank matches trước khi format cho LLM
402
+ try:
403
+ reranked = await reranker.rerank(question, matches, top_k=5)
404
+ if reranked:
405
+ matches = reranked
406
+ except Exception as e:
407
+ logger.error(f"[RERANK] Lỗi khi rerank: {e}")
408
  # Tìm item có similarity cao nhất
409
  top = None
410
  top_result_text = ""
app/reranker.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict
2
+ from .config import get_settings
3
+ from .gemini_client import GeminiClient
4
+ from loguru import logger
5
+ import asyncio
6
+
7
class Reranker:
    """Re-rank retrieved documents by asking an LLM to score their relevance.

    The provider and model are read from the application settings
    (``rerank_provider`` / ``rerank_model``), falling back to the main LLM
    provider and model when those attributes are absent. Only the
    ``gemini`` provider is implemented.
    """

    # Bounds of the scale the prompt asks the model to answer on.
    MIN_SCORE = 0.0
    MAX_SCORE = 10.0

    def __init__(self):
        """Build the LLM client for the configured rerank provider.

        Raises:
            NotImplementedError: If the configured provider is not ``gemini``.
        """
        settings = get_settings()
        self.provider = getattr(settings, 'rerank_provider', settings.llm_provider)
        self.model = getattr(settings, 'rerank_model', settings.llm_model)
        if self.provider == 'gemini':
            self.client = GeminiClient(settings.gemini_api_key, model=self.model)
        else:
            # NOTE: other providers (e.g. OpenAI, Cohere) are not wired up yet.
            raise NotImplementedError(f"Rerank provider {self.provider} not supported yet.")

    @staticmethod
    def _parse_score(raw) -> float:
        """Extract the first number from a model reply, clamped to the 0-10 scale.

        Accepts replies that wrap the number in extra text ("8/10",
        "Điểm: 8", "8.") and a decimal comma ("8,5").

        Raises:
            ValueError: If the reply contains no number at all.
        """
        import re  # local import: only this helper needs it

        found = re.search(r"-?\d+(?:[.,]\d+)?", str(raw))
        if found is None:
            raise ValueError(f"no numeric score in model reply: {raw!r}")
        value = float(found.group(0).replace(',', '.'))
        return max(Reranker.MIN_SCORE, min(Reranker.MAX_SCORE, value))

    async def _ask_model(self, prompt: str):
        """Send ``prompt`` to the configured provider and return its raw reply.

        Raises:
            NotImplementedError: If the provider has no scoring implementation.
        """
        if self.provider == 'gemini':
            # generate_text is invoked as a plain (blocking) callable, so it
            # is pushed to the default executor to keep the event loop free.
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, self.client.generate_text, prompt)
        raise NotImplementedError(f"Rerank provider {self.provider} not supported yet in rerank method.")

    async def rerank(self, query: str, docs: List[Dict], top_k: int = 5) -> List[Dict]:
        """Rank ``docs`` by relevance to ``query`` and return the best ``top_k``.

        Each document is scored with one LLM call, made sequentially. The
        score is written onto the document itself under the
        ``rerank_score`` key, so the caller's dicts are mutated.

        Args:
            query: The user's question.
            docs: Candidate documents; the ``tieude`` and ``noidung`` keys
                are read, and a missing or ``None`` value counts as empty.
            top_k: Maximum number of documents to return.

        Returns:
            Up to ``top_k`` documents, highest score first. Documents with
            equal scores keep their incoming order. Returns an empty list,
            without calling the model, when ``docs`` is empty or ``top_k``
            is not positive.
        """
        if not docs or top_k <= 0:
            return []

        scored = []
        for doc in docs:
            content = (doc.get('tieude', '') or '') + ' ' + (doc.get('noidung', '') or '')
            prompt = (
                f"Đoạn luật: {content}\n"
                f"Câu hỏi: {query}\n"
                "Hãy đánh giá mức độ liên quan giữa đoạn luật và câu hỏi trên thang điểm 0-10. "
                "Chỉ trả về một số duy nhất."
            )
            try:
                score = self._parse_score(await self._ask_model(prompt))
            except Exception as e:
                # Best effort: one failed or unparsable reply must not sink
                # the whole rerank, so that document drops to the bottom.
                logger.error(f"[RERANK] Lỗi khi tính score: {e} | doc: {doc}")
                score = 0.0
            doc['rerank_score'] = score
            scored.append(doc)

        # sorted() is stable, so ties preserve the original retrieval order.
        scored.sort(key=lambda item: item['rerank_score'], reverse=True)
        return scored[:top_k]
app/supabase_db.py CHANGED
@@ -31,7 +31,7 @@ class SupabaseClient:
31
  return None
32
 
33
  @timing_decorator_sync
34
- def match_documents(self, embedding: List[float], match_count: int = 10, vehicle_keywords: Optional[List[str]] = None):
35
  """
36
  Truy vấn vector similarity search qua RPC match_documents.
37
  Input: embedding (list[float]), match_count (int), vehicle_keywords (list[str] hoặc None)
 
31
  return None
32
 
33
  @timing_decorator_sync
34
+ def match_documents(self, embedding: List[float], match_count: int = 20, vehicle_keywords: Optional[List[str]] = None):
35
  """
36
  Truy vấn vector similarity search qua RPC match_documents.
37
  Input: embedding (list[float]), match_count (int), vehicle_keywords (list[str] hoặc None)