Files changed (1) hide show
  1. app.py +38 -152
app.py CHANGED
@@ -12,6 +12,7 @@ import chromadb
12
  from chromadb.config import Settings
13
  from shutil import rmtree
14
 
 
15
  from langchain_google_genai import ChatGoogleGenerativeAI
16
  from langchain_chroma import Chroma
17
  from langchain_community.document_loaders import PyPDFLoader
@@ -25,22 +26,31 @@ from langchain_core.messages import HumanMessage, AIMessage
25
  from langchain_core.documents import Document
26
  from langchain_huggingface import HuggingFaceEmbeddings
27
  from langchain.retrievers import ContextualCompressionRetriever
28
- from langchain.retrievers.document_compressors import CrossEncoderReranker
29
- from langchain_community.cross_encoders import HuggingFaceCrossEncoder
30
 
 
 
 
 
 
 
31
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
32
  DATA_PATH = "medical_data"
33
  DB_PATH = "chroma_db"
 
34
  MAX_HISTORY_TURNS = 6
35
  FORCE_REBUILD_DB = False
36
 
37
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
38
 
 
 
 
 
 
 
 
39
  def process_excel_file(file_path: str, filename: str) -> list[Document]:
40
- """
41
- Xử lý Excel thông minh: Biến mỗi dòng thành một Document riêng biệt
42
- giúp tìm kiếm chính xác từng bản ghi thuốc/bệnh nhân.
43
- """
44
  docs = []
45
  try:
46
  if file_path.endswith(".csv"):
@@ -108,35 +118,32 @@ def load_documents_from_folder(folder_path: str) -> list[Document]:
108
 
109
  def get_retriever_chain():
110
  logging.info("--- Tải Embedding Model ---")
111
- embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
 
 
 
 
112
 
113
  vectorstore = None
114
- splits = []
115
-
116
  chroma_settings = Settings(anonymized_telemetry=False)
117
 
118
  if FORCE_REBUILD_DB and os.path.exists(DB_PATH):
119
- logging.warning("Đang xóa DB cũ theo yêu cầu FORCE_REBUILD...")
120
  rmtree(DB_PATH, ignore_errors=True)
121
 
 
122
  if os.path.exists(DB_PATH) and os.listdir(DB_PATH):
123
  try:
124
  vectorstore = Chroma(
125
  persist_directory=DB_PATH,
126
  embedding_function=embedding_model,
127
- client_settings=chroma_settings
128
  )
129
-
130
- existing_data = vectorstore.get()
131
- if existing_data['documents']:
132
- for text, meta in zip(existing_data['documents'], existing_data['metadatas']):
133
- splits.append(Document(page_content=text, metadata=meta))
134
- logging.info(f"Đã khôi phục {len(splits)} chunks từ DB.")
135
  else:
136
- logging.warning("DB rỗng, sẽ tạo mới.")
137
  vectorstore = None
138
  except Exception as e:
139
- logging.error(f"DB lỗi: {e}. Đang reset...")
140
  rmtree(DB_PATH, ignore_errors=True)
141
  vectorstore = None
142
 
@@ -158,25 +165,16 @@ def get_retriever_chain():
158
  )
159
  logging.info("Đã lưu VectorStore thành công.")
160
 
161
- vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
 
162
 
163
- if splits:
164
- bm25_retriever = BM25Retriever.from_documents(splits)
165
- bm25_retriever.k = 10
166
- ensemble_retriever = EnsembleRetriever(
167
- retrievers=[bm25_retriever, vector_retriever],
168
- weights=[0.4, 0.6]
169
- )
170
- else:
171
- ensemble_retriever = vector_retriever
172
 
173
- logging.info("--- Tải Reranker Model (BGE-M3) ---")
174
- reranker_model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
175
- compressor = CrossEncoderReranker(model=reranker_model, top_n=5)
176
-
177
  final_retriever = ContextualCompressionRetriever(
178
  base_compressor=compressor,
179
- base_retriever=ensemble_retriever
180
  )
181
 
182
  return final_retriever
@@ -187,140 +185,28 @@ class DeepMedBot:
187
  self.ready = False
188
 
189
  if not GOOGLE_API_KEY:
190
- logging.error("⚠️ Thiếu GOOGLE_API_KEY! Vui lòng thiết lập biến môi trường.")
191
  return
192
 
193
  try:
194
- self.retriever = get_retriever_chain()
195
- if not self.retriever:
196
- logging.warning("⚠️ Chưa có dữ liệu để Retreive. Bot sẽ chỉ trả lời bằng kiến thức nền.")
197
-
198
- self.llm = ChatGoogleGenerativeAI(
199
- model="gemini-2.5-flash",
200
  temperature=0.3,
201
  google_api_key=GOOGLE_API_KEY
202
  )
203
  self._build_chains()
204
  self.ready = True
205
- logging.info("✅ Bot DeepMed đã sẵn sàng phục vụ!")
206
  except Exception as e:
207
  logging.error(f"🔥 Lỗi khởi tạo bot: {e}")
208
  logging.debug(traceback.format_exc())
209
 
210
  def _build_chains(self):
211
  context_system_prompt = (
212
- "Dựa trên lịch sử chat câu hỏi mới nhất của người dùng, "
213
- "hãy viết lại câu hỏi đó thành một câu đầy đủ ngữ cảnh để hệ thống có thể hiểu được. "
214
- "KHÔNG trả lời câu hỏi, chỉ viết lại nó."
215
  )
216
  context_prompt = ChatPromptTemplate.from_messages([
217
- ("system", context_system_prompt),
218
- MessagesPlaceholder("chat_history"),
219
- ("human", "{input}"),
220
- ])
221
-
222
- if self.retriever:
223
- history_aware_retriever = create_history_aware_retriever(
224
- self.llm, self.retriever, context_prompt
225
- )
226
-
227
- qa_system_prompt = (
228
- "Bạn là 'DeepMed-AI' - Trợ lý Dược lâm sàng tại Trung Tâm Y Tế. "
229
- "Sử dụng các thông tin được cung cấp trong phần Context dưới đây để trả lời câu hỏi về thuốc, bệnh học và y lệnh.\n"
230
- "Nếu Context có dữ liệu từ Excel, hãy trình bày dạng bảng hoặc gạch đầu dòng rõ ràng.\n"
231
- "Nếu không tìm thấy thông tin trong Context, hãy nói 'Tôi không tìm thấy thông tin trong dữ liệu nội bộ' và gợi ý dựa trên kiến thức y khoa chung của bạn.\n\n"
232
- "Context:\n{context}"
233
- )
234
-
235
- qa_prompt = ChatPromptTemplate.from_messages([
236
- ("system", qa_system_prompt),
237
- MessagesPlaceholder("chat_history"),
238
- ("human", "{input}"),
239
- ])
240
-
241
- question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)
242
-
243
- if self.retriever:
244
- self.rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
245
- else:
246
- self.rag_chain = qa_prompt | self.llm
247
-
248
- def chat_stream(self, message: str, history: list):
249
- if not self.ready:
250
- yield "Hệ thống đang khởi động hoặc gặp lỗi cấu hình."
251
- return
252
-
253
- chat_history = []
254
- for u, b in history[-MAX_HISTORY_TURNS:]:
255
- chat_history.append(HumanMessage(content=str(u)))
256
- chat_history.append(AIMessage(content=str(b)))
257
-
258
- full_response = ""
259
- retrieved_docs = []
260
-
261
- try:
262
- stream_input = {"input": message, "chat_history": chat_history} if self.retriever else {"input": message, "chat_history": chat_history}
263
-
264
- if self.rag_chain:
265
- for chunk in self.rag_chain.stream(stream_input):
266
-
267
- if isinstance(chunk, dict):
268
- if "answer" in chunk:
269
- full_response += chunk["answer"]
270
- yield full_response
271
-
272
- if "context" in chunk:
273
- retrieved_docs = chunk["context"]
274
-
275
- elif hasattr(chunk, 'content'):
276
- full_response += chunk.content
277
- yield full_response
278
-
279
- elif isinstance(chunk, str):
280
- full_response += chunk
281
- yield full_response
282
-
283
- if retrieved_docs:
284
- refs = self._build_references_text(retrieved_docs)
285
- if refs:
286
- full_response += f"\n\n---\n📚 **Nguồn tham khảo:**\n{refs}"
287
- yield full_response
288
-
289
- except Exception as e:
290
- logging.error(f"Lỗi khi chat: {e}")
291
- logging.debug(traceback.format_exc())
292
- yield f"Đã xảy ra lỗi: {str(e)}"
293
-
294
- @staticmethod
295
- def _build_references_text(docs) -> str:
296
- lines = []
297
- seen = set()
298
- for doc in docs:
299
- src = doc.metadata.get("source", "Tài liệu")
300
- row_info = ""
301
- if "row" in doc.metadata:
302
- row_info = f"(Dòng {doc.metadata['row']})"
303
-
304
- ref_str = f"- {src} {row_info}"
305
-
306
- if ref_str not in seen:
307
- lines.append(ref_str)
308
- seen.add(ref_str)
309
- return "\n".join(lines)
310
-
311
- bot = DeepMedBot()
312
-
313
- def gradio_chat_stream(message, history):
314
- yield from bot.chat_stream(message, history)
315
-
316
- css = """
317
- .gradio-container {min_height: 600px !important;}
318
- h1 {text-align: center; color: #2E86C1;}
319
- """
320
-
321
- with gr.Blocks(css=css, title="DeepMed AI") as demo:
322
- gr.Markdown("# 🏥 DeepMed AI - Trợ lý Lâm Sàng")
323
- gr.Markdown("Hệ thống hỗ trợ lâm sàng tại Trung Tâm Y Tế Khu Vực Thanh Ba.")
324
 
325
  chat_interface = gr.ChatInterface(
326
  fn=gradio_chat_stream,
 
12
  from chromadb.config import Settings
13
  from shutil import rmtree
14
 
15
+ # --- CÁC THƯ VIỆN LANGCHAIN ---
16
  from langchain_google_genai import ChatGoogleGenerativeAI
17
  from langchain_chroma import Chroma
18
  from langchain_community.document_loaders import PyPDFLoader
 
26
  from langchain_core.documents import Document
27
  from langchain_huggingface import HuggingFaceEmbeddings
28
  from langchain.retrievers import ContextualCompressionRetriever
 
 
29
 
30
+ # --- THƯ VIỆN TỐI ƯU TỐC ĐỘ (CACHE & RERANK) ---
31
+ from langchain.retrievers.document_compressors import FlashrankRerank
32
+ from langchain.globals import set_llm_cache
33
+ from langchain_community.cache import SQLiteCache
34
+
35
+ # --- CẤU HÌNH HỆ THỐNG ---
36
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
37
  DATA_PATH = "medical_data"
38
  DB_PATH = "chroma_db"
39
+ CACHE_DB_PATH = "llm_cache.db" # File lưu bộ nhớ đệm
40
  MAX_HISTORY_TURNS = 6
41
  FORCE_REBUILD_DB = False
42
 
43
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
44
 
45
+ # --- KÍCH HOẠT CACHING ---
46
+ # Hệ thống sẽ lưu câu trả lời vào file .db.
47
+ # Lần sau gặp câu hỏi y hệt, nó sẽ lấy từ đệm ra ngay lập tức.
48
+ if not os.path.exists(CACHE_DB_PATH):
49
+ logging.info("Khởi tạo file cache mới.")
50
+ set_llm_cache(SQLiteCache(database_path=CACHE_DB_PATH))
51
+
52
  def process_excel_file(file_path: str, filename: str) -> list[Document]:
53
+ """Xử lý Excel: Biến mỗi dòng thành một Document."""
 
 
 
54
  docs = []
55
  try:
56
  if file_path.endswith(".csv"):
 
118
 
119
  def get_retriever_chain():
120
  logging.info("--- Tải Embedding Model ---")
121
+ # Chạy trên CPU để tiết kiệm resource, đổi 'cpu' thành 'cuda' nếu có GPU
122
+ embedding_model = HuggingFaceEmbeddings(
123
+ model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
124
+ model_kwargs={'device': 'cpu'}
125
+ )
126
 
127
  vectorstore = None
 
 
128
  chroma_settings = Settings(anonymized_telemetry=False)
129
 
130
  if FORCE_REBUILD_DB and os.path.exists(DB_PATH):
 
131
  rmtree(DB_PATH, ignore_errors=True)
132
 
133
+ # 1. TỐI ƯU: Kiểm tra nhanh DB bằng count() thay vì load toàn bộ
134
  if os.path.exists(DB_PATH) and os.listdir(DB_PATH):
135
  try:
136
  vectorstore = Chroma(
137
  persist_directory=DB_PATH,
138
  embedding_function=embedding_model,
139
+ client_settings=chroma_settings
140
  )
141
+ if vectorstore._collection.count() > 0:
142
+ logging.info(f"Đã kết nối DB cũ. Size: {vectorstore._collection.count()}")
 
 
 
 
143
  else:
 
144
  vectorstore = None
145
  except Exception as e:
146
+ logging.error(f"DB lỗi: {e}. Reset DB...")
147
  rmtree(DB_PATH, ignore_errors=True)
148
  vectorstore = None
149
 
 
165
  )
166
  logging.info("Đã lưu VectorStore thành công.")
167
 
168
+ # 2. TỐI ƯU: Giảm k ban đầu xuống 6 để bớt tính toán
169
+ vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 6})
170
 
171
+ # 3. TỐI ƯU: Sử dụng FlashRank (Siêu nhẹ & Nhanh) thay vì CrossEncoder
172
+ logging.info("--- Tải Reranker Model (FlashRank) ---")
173
+ compressor = FlashrankRerank(model="ms-marco-MiniLM-L-12-v2") # Model ~40MB
 
 
 
 
 
 
174
 
 
 
 
 
175
  final_retriever = ContextualCompressionRetriever(
176
  base_compressor=compressor,
177
+ base_retriever=vector_retriever
178
  )
179
 
180
  return final_retriever
 
185
  self.ready = False
186
 
187
  if not GOOGLE_API_KEY:
188
+ logging.error("⚠️ Thiếu GOOGLE_API_KEY!")
189
  return
190
 
191
  try:
192
+ self.retr2.5-flash",
 
 
 
 
 
193
  temperature=0.3,
194
  google_api_key=GOOGLE_API_KEY
195
  )
196
  self._build_chains()
197
  self.ready = True
198
+ logging.info("✅ Bot DeepMed đã sẵn sàng!")
199
  except Exception as e:
200
  logging.error(f"🔥 Lỗi khởi tạo bot: {e}")
201
  logging.debug(traceback.format_exc())
202
 
203
  def _build_chains(self):
204
  context_system_prompt = (
205
+ "Viết lại câu hỏi của người dùng thành câu đầy đủ ngữ cảnh. "
206
+ "KHÔNG trả lời, chỉ viết lại."
 
207
  )
208
  context_prompt = ChatPromptTemplate.from_messages([
209
+ Ba)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
  chat_interface = gr.ChatInterface(
212
  fn=gradio_chat_stream,