Kimty commited on
Commit
1f1a99e
·
1 Parent(s): 9fe8208

Add Gradio app with RAG pipeline for chemical regulation lookup

Browse files
Files changed (10) hide show
  1. .gitignore +3 -0
  2. app.py +101 -0
  3. core/__init__.py +3 -0
  4. core/param.py +28 -0
  5. core/prompt.py +45 -0
  6. core/rag.py +69 -0
  7. core/utils.py +25 -0
  8. core/vectorstore.py +239 -0
  9. requirements.txt +10 -0
  10. scripts/upload_vectordb.py +44 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ vector_db/
2
+ .env
3
+ __pycache__/
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from core import RAG, parse_cas_input
3
+
4
+ rag = RAG()
5
+
6
+ EXAMPLE_INPUTS = [
7
+ "7664-39-3",
8
+ "7664-39-3, 128-37-0",
9
+ "50-00-0, 7664-93-9, 67-64-1",
10
+ ]
11
+
12
+
13
+ def analyse(raw_input: str, provider: str):
14
+ if not raw_input or not raw_input.strip():
15
+ raise gr.Error("Vui lòng nhập ít nhất một mã CAS.")
16
+
17
+ cas_numbers = parse_cas_input(raw_input)
18
+ if not cas_numbers:
19
+ raise gr.Error(
20
+ f"Không tìm thấy mã CAS hợp lệ trong '{raw_input}'. "
21
+ "Định dạng đúng: 7664-39-3"
22
+ )
23
+
24
+ provider_key = "google" if "Gemini" in provider else "openai"
25
+ result = rag.pipeline(cas_numbers, provider=provider_key)
26
+
27
+ rows = []
28
+ for r in result.results:
29
+ rows.append([r.casNumber, r.chemicalName, r.status, r.reason])
30
+
31
+ summary_parts = []
32
+ for r in result.results:
33
+ summary_parts.append(
34
+ f"### {r.casNumber} — {r.chemicalName}\n"
35
+ f"**Yêu cầu pháp lý:** {r.status}\n\n"
36
+ f"**Cơ sở:** {r.reason}\n"
37
+ )
38
+ summary_md = "\n---\n".join(summary_parts)
39
+
40
+ return rows, summary_md
41
+
42
+
43
+ with gr.Blocks(
44
+ title="Chemical & Precursor Declaration Checker",
45
+ theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"),
46
+ css="""
47
+ .main-header {text-align:center; margin-bottom:4px}
48
+ .sub-header {text-align:center; color:#555; margin-top:0}
49
+ """,
50
+ ) as demo:
51
+ gr.Markdown(
52
+ "<h1 class='main-header'>Chemical & Precursor Declaration Checker</h1>"
53
+ "<p class='sub-header'>"
54
+ "Tra cứu nghĩa vụ khai báo hóa chất & tiền chất theo quy định Việt Nam"
55
+ "</p>"
56
+ )
57
+
58
+ with gr.Row():
59
+ with gr.Column(scale=3):
60
+ cas_input = gr.Textbox(
61
+ label="Nhập mã CAS",
62
+ placeholder="VD: 7664-39-3, 128-37-0 (phân cách bằng dấu phẩy hoặc xuống dòng)",
63
+ lines=2,
64
+ )
65
+ with gr.Column(scale=1, min_width=180):
66
+ provider_dd = gr.Dropdown(
67
+ choices=["OpenAI GPT-4.1-mini", "Google Gemini 2.5 Flash"],
68
+ value="OpenAI GPT-4.1-mini",
69
+ label="LLM Provider",
70
+ )
71
+ submit_btn = gr.Button("Tra cứu", variant="primary", size="lg")
72
+
73
+ gr.Examples(
74
+ examples=EXAMPLE_INPUTS,
75
+ inputs=cas_input,
76
+ label="Ví dụ",
77
+ )
78
+
79
+ gr.Markdown("### Kết quả")
80
+
81
+ results_table = gr.Dataframe(
82
+ headers=["CAS", "Tên hóa chất", "Trạng thái", "Kết quả tra cứu"],
83
+ datatype=["str", "str", "str", "str"],
84
+ interactive=False,
85
+ wrap=True,
86
+ )
87
+ detail_md = gr.Markdown(label="Chi tiết")
88
+
89
+ submit_btn.click(
90
+ fn=analyse,
91
+ inputs=[cas_input, provider_dd],
92
+ outputs=[results_table, detail_md],
93
+ )
94
+ cas_input.submit(
95
+ fn=analyse,
96
+ inputs=[cas_input, provider_dd],
97
+ outputs=[results_table, detail_md],
98
+ )
99
+
100
+ if __name__ == "__main__":
101
+ demo.launch()
core/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .rag import RAG, parse_cas_input
2
+
3
+ __all__ = ["RAG", "parse_cas_input"]
core/param.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import List
3
+
4
+
5
+ class AnalysisResult(BaseModel):
6
+ """Single CAS analysis result returned by the LLM."""
7
+ casNumber: str = Field(..., description="Mã CAS (e.g. 128-37-0)")
8
+ chemicalName: str = Field(..., description="The only name of the chemical")
9
+ status: str = Field(
10
+ ...,
11
+ description=(
12
+ "Yêu cầu pháp lý cụ thể — lấy nguyên văn phần sau dấu '=>' trong QUY ĐỊNH ÁP DỤNG. "
13
+ "Nếu CAS xuất hiện ở nhiều phụ lục, đánh số từng yêu cầu. "
14
+ "VD: 'Chỉ được Nhập khẩu hoặc Xuất khẩu nếu có Giấy chứng nhận...'"
15
+ ),
16
+ )
17
+ reason: str = Field(
18
+ ...,
19
+ description=(
20
+ "Kết quả tra cứu: ghi rõ vị trí trong văn bản pháp luật. "
21
+ "VD: 'STT 15 của Phụ lục II Nghị định 24/2026/NĐ-CP: Danh mục hoá chất sản xuất, kinh doanh có điều kiện'"
22
+ ),
23
+ )
24
+
25
+
26
+ class CASAnalysisOutput(BaseModel):
27
+ """Structured output schema for LLM CAS analysis."""
28
+ results: List[AnalysisResult] = Field(..., description="Danh sách kết quả phân tích cho từng mã CAS")
core/prompt.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ SYSTEM_PROMPT3 = """Bạn là trợ lý phân tích pháp luật hóa chất.
2
+ ========================
3
+ QUY TẮC TRÍCH XUẤT
4
+ ========================
5
+ 1. casNumber
6
+ - Ghi đúng mã CAS đang phân tích.
7
+ 2. chemicalName
8
+ - Chỉ lấy "Tên chất" (cột tên chính trong bảng).
9
+ - Không lấy tên khoa học (IUPAC).
10
+ - Không lấy nội dung trong ngoặc.
11
+ - Không tự suy đoán.
12
+ 3. status (CHỈ 1 GIÁ TRỊ DUY NHẤT)
13
+ ƯU TIÊN THEO THỨ TỰ:
14
+
15
+ A. Nếu có nội dung nghĩa vụ hành chính cụ thể
16
+ (có "=>", hoặc chứa từ như: "Phải", "Chỉ được", "Cấm", "Xin giấy phép"...)
17
+ → Lấy nguyên văn nội dung nghĩa vụ đó làm status.
18
+
19
+ B. Nếu không có nghĩa vụ cụ thể nhưng CAS thuộc DANH MỤC (I, II, III, IV…)
20
+ → Lấy nguyên văn mô tả quản lý của mục con (ví dụ IVA, IVB…).
21
+ → Format:
22
+ "Thuộc danh mục IVA: [nguyên văn mô tả]"
23
+
24
+ C. Nếu không tìm thấy trong tài liệu
25
+ → "Không có trong danh sách"
26
+
27
+ - Không gộp nhiều status.
28
+ - Không lấy nhiều nguồn cho status.
29
+ - Chỉ được có 1 status chính.
30
+
31
+ 4. reason (LIỆT KÊ ĐẦY ĐỦ VỊ TRÍ)
32
+
33
+ - Liệt kê TẤT CẢ vị trí CAS xuất hiện.
34
+ - Format mỗi vị trí:
35
+ "STT [số] của [TÊN ĐẦY ĐỦ PHỤ LỤC/DANH MỤC CHÍNH]"
36
+ - Luôn dùng tiêu đề chính (ví dụ: DANH MỤC IV), không dùng IVA trong reason.
37
+ - Ngăn cách các vị trí bằng dấu chấm và khoảng trắng.
38
+ ========================
39
+ LƯU Ý QUAN TRỌNG
40
+ ========================
41
+ - status = nghĩa vụ phải làm (nếu có).
42
+ - reason = truy vết đầy đủ tất cả văn bản.
43
+ - Nếu có cả nghĩa vụ hành chính và danh mục:
44
+ → status lấy nghĩa vụ.
45
+ → reason ghi cả hai vị trí."""
core/rag.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from typing import List, Dict
4
+
5
+ from dotenv import load_dotenv
6
+ from langchain_core.prompts import ChatPromptTemplate
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_openai import ChatOpenAI
9
+
10
+ from .vectorstore import VectorStore
11
+ from .utils import build_grouped_context, _build_user_prompt
12
+ from .param import CASAnalysisOutput
13
+ from .prompt import SYSTEM_PROMPT3
14
+
15
+ load_dotenv()
16
+
17
+ CAS_PATTERN = re.compile(r"\d{2,7}-\d{2}-\d")
18
+
19
+
20
+ def parse_cas_input(raw: str) -> List[str]:
21
+ """Extract valid CAS numbers from free-form user input."""
22
+ return CAS_PATTERN.findall(raw.strip())
23
+
24
+
25
+ class RAG:
26
+ def __init__(self):
27
+ self.vector_store = VectorStore()
28
+
29
+ google_key = os.getenv("GOOGLE_API_KEY")
30
+ openai_key = os.getenv("OPENAI_API_KEY")
31
+
32
+ self.llm_google = (
33
+ ChatGoogleGenerativeAI(model="gemini-2.5-flash", api_key=google_key, temperature=0)
34
+ if google_key else None
35
+ )
36
+ self.llm_openai = (
37
+ ChatOpenAI(model="gpt-4.1-mini", api_key=openai_key, temperature=0)
38
+ if openai_key else None
39
+ )
40
+
41
+ def _get_llm(self, provider: str = "openai"):
42
+ if provider == "google" and self.llm_google:
43
+ return self.llm_google
44
+ if provider == "openai" and self.llm_openai:
45
+ return self.llm_openai
46
+ available = self.llm_openai or self.llm_google
47
+ if available is None:
48
+ raise RuntimeError("No LLM API key configured. Set OPENAI_API_KEY or GOOGLE_API_KEY.")
49
+ return available
50
+
51
+ def retrieve_build_context(self, cas_numbers: List[str]) -> str:
52
+ chunks_by_cas = self.vector_store.search_per_cas(cas_numbers)
53
+ context_chunking = build_grouped_context(chunks_by_cas)
54
+ user_prompt = _build_user_prompt(cas_numbers, context_chunking)
55
+ return user_prompt
56
+
57
+ def pipeline(self, cas_numbers: List[str], provider: str = "openai") -> CASAnalysisOutput:
58
+ user_prompt = self.retrieve_build_context(cas_numbers)
59
+
60
+ prompt = ChatPromptTemplate.from_messages([
61
+ ("system", SYSTEM_PROMPT3),
62
+ ("user", "{user_prompt}")
63
+ ])
64
+
65
+ llm = self._get_llm(provider)
66
+ llm_structured = llm.with_structured_output(CASAnalysisOutput, method="function_calling")
67
+ chain = prompt | llm_structured
68
+ result = chain.invoke({"user_prompt": user_prompt})
69
+ return result
core/utils.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+
4
+ def build_grouped_context(results: dict[str, list[dict]]) -> str:
5
+ blocks = []
6
+
7
+ for cas, chunks in results.items():
8
+ cas_block = [f"========================\n[CAS: {cas}]\n"]
9
+
10
+ for i, item in enumerate(chunks, 1):
11
+ cas_block.append(f"\nCHUNK {i}:\n")
12
+ cas_block.append(item["content"])
13
+ cas_block.append("\n------------------------")
14
+
15
+ blocks.append("".join(cas_block))
16
+
17
+ return "\n\n".join(blocks)
18
+
19
+
20
+ def _build_user_prompt(cas_numbers: List[str], context: str) -> str:
21
+ cas_list = ", ".join(cas_numbers)
22
+ return (
23
+ f"Phân tích các mã CAS sau: {cas_list}\n\n"
24
+ f"CONTEXT:\n{context}"
25
+ )
core/vectorstore.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from typing import List, Dict
4
+ import chromadb
5
+ from dotenv import load_dotenv
6
+
7
+ load_dotenv()
8
+
9
+ DEFAULT_DB_PATH = "./vector_db/chroma_db"
10
+
11
+
12
+ def _ensure_vectordb(db_path: str) -> str:
13
+ """Download ChromaDB from HF Dataset if not present locally."""
14
+ if os.path.isdir(db_path) and os.listdir(db_path):
15
+ return db_path
16
+
17
+ repo_id = os.getenv("HF_DATASET_REPO")
18
+ if not repo_id:
19
+ raise RuntimeError(
20
+ f"Vector DB not found at '{db_path}' and HF_DATASET_REPO is not set. "
21
+ "Either place the DB locally or set HF_DATASET_REPO=your-username/chemical-vectordb"
22
+ )
23
+
24
+ from huggingface_hub import snapshot_download
25
+
26
+ print(f" Downloading vector DB from HF dataset: {repo_id} ...")
27
+ downloaded = snapshot_download(
28
+ repo_id=repo_id,
29
+ repo_type="dataset",
30
+ token=os.getenv("HF_TOKEN"),
31
+ )
32
+
33
+ os.makedirs(db_path, exist_ok=True)
34
+ for item in os.listdir(downloaded):
35
+ src = os.path.join(downloaded, item)
36
+ dst = os.path.join(db_path, item)
37
+ if item.startswith("."):
38
+ continue
39
+ if os.path.isdir(src):
40
+ shutil.copytree(src, dst, dirs_exist_ok=True)
41
+ else:
42
+ shutil.copy2(src, dst)
43
+
44
+ print(f" Vector DB ready at: {db_path}")
45
+ return db_path
46
+
47
+
48
+ class VectorStore:
49
+ """ChromaDB vector store with hybrid search and E5 prefix support."""
50
+
51
+ EXACT_MATCH_BOOST = 0.3
52
+
53
+ def __init__(self):
54
+ self.collection_name = os.getenv('COLLECTION_NAME', 'chemical_regulations')
55
+ self.vector_db_path = os.getenv('VECTOR_DB_PATH', DEFAULT_DB_PATH)
56
+ self.lightweight = True
57
+
58
+ self.model_name = os.getenv('EMBEDDING_MODEL', 'intfloat/multilingual-e5-base')
59
+ self._use_prefix = 'e5' in self.model_name.lower()
60
+ self.model = None
61
+
62
+ if self.lightweight:
63
+ print(f" Lightweight mode — exact match only, no embedding model")
64
+ else:
65
+ from sentence_transformers import SentenceTransformer
66
+ print(f" Loading embedding model: {self.model_name}")
67
+ self.model = SentenceTransformer(self.model_name)
68
+
69
+ self.vector_db_path = _ensure_vectordb(self.vector_db_path)
70
+ self.client = chromadb.PersistentClient(path=self.vector_db_path)
71
+
72
+ try:
73
+ self.collection = self.client.get_collection(self.collection_name)
74
+ print(f" Loaded collection: {self.collection_name} "
75
+ f"({self.collection.count()} vectors)")
76
+ except Exception as e:
77
+ raise RuntimeError(
78
+ f"Cannot load collection '{self.collection_name}'. "
79
+ f"Run scripts/2_create_embeddings.py first. Error: {e}"
80
+ )
81
+
82
+ def _encode_query(self, text: str):
83
+ """Encode with 'query: ' prefix for E5 models."""
84
+ if self.model is None:
85
+ raise RuntimeError(
86
+ "Semantic search unavailable in lightweight mode. "
87
+ "Set LIGHTWEIGHT_MODE=false or unset it to enable."
88
+ )
89
+ if self._use_prefix:
90
+ text = f"query: {text}"
91
+ return self.model.encode(text, normalize_embeddings=True)
92
+
93
+ def semantic_search(self, query: str, top_k: int = 10) -> List[Dict]:
94
+ """Embedding-based semantic search."""
95
+ query_embedding = self._encode_query(query)
96
+
97
+ results = self.collection.query(
98
+ query_embeddings=[query_embedding.tolist()],
99
+ n_results=top_k,
100
+ include=["documents", "metadatas", "distances"],
101
+ )
102
+
103
+ formatted = []
104
+ if results['ids'] and results['ids'][0]:
105
+ for i in range(len(results['ids'][0])):
106
+ formatted.append({
107
+ "id": results['ids'][0][i],
108
+ "content": results['documents'][0][i],
109
+ "metadata": results['metadatas'][0][i],
110
+ "similarity": 1 - results['distances'][0][i],
111
+ "match_type": "semantic",
112
+ })
113
+ return formatted
114
+
115
+ def exact_match_search(self, cas_number: str, top_k: int = 5) -> List[Dict]:
116
+ """
117
+ True lexical match — find chunks whose document text literally
118
+ contains the CAS number string (ChromaDB where_document $contains).
119
+ """
120
+ try:
121
+ results = self.collection.get(
122
+ where_document={"$contains": cas_number},
123
+ include=["documents", "metadatas"],
124
+ limit=top_k,
125
+ )
126
+ except Exception:
127
+ return []
128
+
129
+ formatted = []
130
+ if results['ids']:
131
+ for i in range(len(results['ids'])):
132
+ formatted.append({
133
+ "id": results['ids'][i],
134
+ "content": results['documents'][i],
135
+ "metadata": results['metadatas'][i],
136
+ "similarity": 1.0,
137
+ "match_type": "exact",
138
+ })
139
+ return formatted
140
+
141
+ def hybrid_search(self, cas_numbers: List[str], top_k: int = 10) -> List[Dict]:
142
+ """
143
+ Hybrid search strategy:
144
+ 1. Exact-match per CAS number (highest fidelity)
145
+ 2. Semantic search with expanded query (broader context) — skipped in lightweight mode
146
+ 3. Deduplicate -> rerank -> top_k
147
+ """
148
+ all_results: List[Dict] = []
149
+ seen_ids: set = set()
150
+
151
+ for cas in cas_numbers:
152
+ for r in self.exact_match_search(cas, top_k=5):
153
+ if r['id'] not in seen_ids:
154
+ all_results.append(r)
155
+ seen_ids.add(r['id'])
156
+
157
+ if not self.lightweight:
158
+ expanded = self._expand_query(cas_numbers)
159
+ for r in self.semantic_search(expanded, top_k=top_k):
160
+ if r['id'] not in seen_ids:
161
+ all_results.append(r)
162
+ seen_ids.add(r['id'])
163
+
164
+ all_results = self._rerank(all_results, cas_numbers)
165
+ return all_results[:top_k]
166
+
167
+ def search_per_cas(
168
+ self, cas_numbers: List[str], top_k_per_cas: int = 5,
169
+ ) -> Dict[str, List[Dict]]:
170
+ """
171
+ Search for each CAS number individually so every CAS gets its own
172
+ dedicated chunk slots (no competition between CAS numbers).
173
+ """
174
+ results_by_cas: Dict[str, List[Dict]] = {cas: [] for cas in cas_numbers}
175
+ seen_per_cas: Dict[str, set] = {cas: set() for cas in cas_numbers}
176
+
177
+ for cas in cas_numbers:
178
+ for r in self.exact_match_search(cas, top_k=top_k_per_cas):
179
+ if r['id'] not in seen_per_cas[cas]:
180
+ results_by_cas[cas].append(r)
181
+ seen_per_cas[cas].add(r['id'])
182
+
183
+ if not self.lightweight:
184
+ expanded = self._expand_query(cas_numbers)
185
+ for r in self.semantic_search(expanded, top_k=top_k_per_cas * len(cas_numbers)):
186
+ for cas in cas_numbers:
187
+ in_content = cas in r['content']
188
+ in_meta = cas in r['metadata'].get('cas_numbers', '')
189
+ in_section_cas = cas in r['metadata'].get('section_cas_numbers', '')
190
+ if (in_content or in_meta or in_section_cas) and r['id'] not in seen_per_cas[cas]:
191
+ results_by_cas[cas].append(r)
192
+ seen_per_cas[cas].add(r['id'])
193
+
194
+ for cas in cas_numbers:
195
+ results_by_cas[cas] = self._rerank(results_by_cas[cas], [cas])
196
+ results_by_cas[cas] = results_by_cas[cas][:top_k_per_cas]
197
+
198
+ return results_by_cas
199
+
200
+ def _expand_query(self, cas_numbers: List[str]) -> str:
201
+ """Build a richer query so semantic search captures regulatory context."""
202
+ cas_list = ', '.join(cas_numbers)
203
+ return (
204
+ f"Thông tin quy định pháp luật về hóa chất có mã CAS: {cas_list}. "
205
+ f"Giấy phép nhập khẩu, khai báo hóa chất, tiền chất, danh mục quản lý."
206
+ )
207
+
208
+ def _rerank(self, results: List[Dict], cas_numbers: List[str]) -> List[Dict]:
209
+ """Boost chunks that literally contain one of the queried CAS numbers."""
210
+ for r in results:
211
+ boost = 0.0
212
+ content = r['content']
213
+ cas_meta = r['metadata'].get('cas_numbers', '')
214
+ section_cas_meta = r['metadata'].get('section_cas_numbers', '')
215
+
216
+ for cas in cas_numbers:
217
+ if cas in content:
218
+ boost += self.EXACT_MATCH_BOOST
219
+ if cas in cas_meta:
220
+ boost += self.EXACT_MATCH_BOOST * 0.5
221
+ elif cas in section_cas_meta:
222
+ boost += self.EXACT_MATCH_BOOST * 0.25
223
+
224
+ r['similarity'] = min(1.0, r['similarity'] + boost)
225
+
226
+ results.sort(key=lambda x: x['similarity'], reverse=True)
227
+ return results
228
+
229
+ def get_stats(self) -> Dict:
230
+ stats = {
231
+ "total_chunks": self.collection.count(),
232
+ "embedding_model": self.model_name,
233
+ "collection_name": self.collection_name,
234
+ "e5_prefix_enabled": self._use_prefix,
235
+ "lightweight_mode": self.lightweight,
236
+ }
237
+ if self.model is not None:
238
+ stats["embedding_dimension"] = self.model.get_sentence_embedding_dimension()
239
+ return stats
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ pypdf
2
+ langchain_google_genai
3
+ langchain_openai
4
+ langchain_community
5
+ langchain_core
6
+ chromadb
7
+ python-dotenv
8
+ pydantic
9
+ gradio
10
+ huggingface_hub
scripts/upload_vectordb.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Upload ChromaDB vector database to a Hugging Face Dataset repo.
3
+
4
+ Usage:
5
+ python scripts/upload_vectordb.py --repo-id YOUR_USERNAME/chemical-vectordb
6
+
7
+ Prerequisites:
8
+ pip install huggingface_hub
9
+ huggingface-cli login
10
+ """
11
+
12
+ import argparse
13
+ from huggingface_hub import HfApi
14
+
15
+
16
+ def main():
17
+ parser = argparse.ArgumentParser(description="Upload ChromaDB to HF Dataset")
18
+ parser.add_argument(
19
+ "--repo-id", required=True,
20
+ help="HF dataset repo id, e.g. your-username/chemical-vectordb",
21
+ )
22
+ parser.add_argument(
23
+ "--local-dir",
24
+ default="./vector_db/content/vector_db/chroma_db",
25
+ help="Local path to the chroma_db folder",
26
+ )
27
+ parser.add_argument("--private", action="store_true", help="Make the dataset private")
28
+ args = parser.parse_args()
29
+
30
+ api = HfApi()
31
+
32
+ api.create_repo(repo_id=args.repo_id, repo_type="dataset", private=args.private, exist_ok=True)
33
+ print(f"Repo ready: https://huggingface.co/datasets/{args.repo_id}")
34
+
35
+ api.upload_folder(
36
+ folder_path=args.local_dir,
37
+ repo_id=args.repo_id,
38
+ repo_type="dataset",
39
+ )
40
+ print(f"Upload complete! Files pushed to datasets/{args.repo_id}")
41
+
42
+
43
+ if __name__ == "__main__":
44
+ main()