minh-4T commited on
Commit
42ed92c
·
1 Parent(s): 2a96248

change model embedding + encoder

Browse files
Files changed (2) hide show
  1. core/chunking.py +5 -5
  2. core/config.py +4 -4
core/chunking.py CHANGED
@@ -28,7 +28,7 @@ LIST_PATTERNS = [
28
  (r"(?m)^\s*•\s+", "<LIST_BULLET>"),
29
  ]
30
 
31
-
32
  def extract_and_protect_tables(text: str) -> Tuple[str, dict]:
33
  table_pattern = re.compile(r"(?:\|.*\|[\r\n]+)+")
34
  tables = {}
@@ -41,7 +41,7 @@ def extract_and_protect_tables(text: str) -> Tuple[str, dict]:
41
  protected_text = re.sub(table_pattern, replace_table, text)
42
  return protected_text, tables
43
 
44
-
45
  def protect_lists(text: str) -> Tuple[str, dict]:
46
  placeholders = {}
47
  protected = text
@@ -55,14 +55,14 @@ def protect_lists(text: str) -> Tuple[str, dict]:
55
 
56
  return protected, placeholders
57
 
58
-
59
  def restore_placeholders(text: str, placeholders: dict) -> str:
60
  restored = text
61
  for placeholder, original in placeholders.items():
62
  restored = restored.replace(placeholder, original)
63
  return restored
64
 
65
-
66
  def split_by_structure(text: str) -> List[str]:
67
  parts = [text]
68
 
@@ -91,7 +91,7 @@ def split_by_structure(text: str) -> List[str]:
91
 
92
  return [part for part in parts if part.strip()]
93
 
94
-
95
  def smart_chunking(docs: List) -> List:
96
  logger.info("Chunking theo cau truc + do dai...")
97
  length_splitter = RecursiveCharacterTextSplitter(
 
28
  (r"(?m)^\s*•\s+", "<LIST_BULLET>"),
29
  ]
30
 
31
+ # Tách và thêm các thẻ <table> để bảo vệ cấu trúc bảng khỏi bị chia cắt trong quá trình chunking.
32
  def extract_and_protect_tables(text: str) -> Tuple[str, dict]:
33
  table_pattern = re.compile(r"(?:\|.*\|[\r\n]+)+")
34
  tables = {}
 
41
  protected_text = re.sub(table_pattern, replace_table, text)
42
  return protected_text, tables
43
 
44
+ # Bảo vệ các phần tử của danh sách khỏi bị chia cắt trong quá trình chunking
45
  def protect_lists(text: str) -> Tuple[str, dict]:
46
  placeholders = {}
47
  protected = text
 
55
 
56
  return protected, placeholders
57
 
58
+ # Khôi phục các phần từ được bảo vệ về nội dung gốc bằng cách thay thế các placeholder
59
  def restore_placeholders(text: str, placeholders: dict) -> str:
60
  restored = text
61
  for placeholder, original in placeholders.items():
62
  restored = restored.replace(placeholder, original)
63
  return restored
64
 
65
+ # Tách văn bản dựa trên cấu trúc được xây dựng từ đầu
66
  def split_by_structure(text: str) -> List[str]:
67
  parts = [text]
68
 
 
91
 
92
  return [part for part in parts if part.strip()]
93
 
94
+ # Hàm chính thực hiện chunking thông minh
95
  def smart_chunking(docs: List) -> List:
96
  logger.info("Chunking theo cau truc + do dai...")
97
  length_splitter = RecursiveCharacterTextSplitter(
core/config.py CHANGED
@@ -39,14 +39,14 @@ GEMINI_API_KEYS = os.getenv('GEMINI_API_KEYS', '').strip()
39
  # Name models
40
  LLM_MODEL = os.getenv('LLM_MODEL', 'llama-3.1-70b-versatile')
41
  FAST_LLM_MODEL = os.getenv('FAST_LLM_MODEL', 'llama-3.1-8b-instant')
42
- EMBED_MODEL = os.getenv('EMBED_MODEL', 'BAAI/bge-m3')
43
- CROSS_ENCODER_MODEL = os.getenv('CROSS_ENCODER_MODEL', 'BAAI/bge-reranker-v2-m3')
44
 
45
  # Chunking and retrieval settings
46
  CHUNK_SIZE = int(os.getenv('CHUNK_SIZE', '800'))
47
  CHUNK_OVERLAP = int(os.getenv('CHUNK_OVERLAP', '150'))
48
- TOP_K_RESULTS = int(os.getenv('TOP_K_RESULTS', '10'))
49
- FINAL_TOP_K = int(os.getenv('FINAL_TOP_K', '5'))
50
 
51
  QDRANT_COLLECTION = os.getenv('QDRANT_COLLECTION', 'rag_docs')
52
  DOCUMENTS_DATABASE_URL = os.getenv('DOCUMENTS_DATABASE_URL', _default_documents_db_url())
 
39
  # Name models
40
  LLM_MODEL = os.getenv('LLM_MODEL', 'llama-3.1-70b-versatile')
41
  FAST_LLM_MODEL = os.getenv('FAST_LLM_MODEL', 'llama-3.1-8b-instant')
42
+ EMBED_MODEL = os.getenv('EMBED_MODEL', 'bkai-foundation-models/vietnamese-bi-encoder')
43
+ CROSS_ENCODER_MODEL = os.getenv('CROSS_ENCODER_MODEL', 'itdainb/PhoRanker')
44
 
45
  # Chunking and retrieval settings
46
  CHUNK_SIZE = int(os.getenv('CHUNK_SIZE', '800'))
47
  CHUNK_OVERLAP = int(os.getenv('CHUNK_OVERLAP', '150'))
48
+ TOP_K_RESULTS = int(os.getenv('TOP_K_RESULTS', '8'))
49
+ FINAL_TOP_K = int(os.getenv('FINAL_TOP_K', '3'))
50
 
51
  QDRANT_COLLECTION = os.getenv('QDRANT_COLLECTION', 'rag_docs')
52
  DOCUMENTS_DATABASE_URL = os.getenv('DOCUMENTS_DATABASE_URL', _default_documents_db_url())