PBThuong96 commited on
Commit
1abbebe
·
verified ·
1 Parent(s): 4f16805

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +496 -176
app.py CHANGED
@@ -11,7 +11,10 @@ import docx2txt
11
  import chromadb
12
  from chromadb.config import Settings
13
  from shutil import rmtree
 
 
14
 
 
15
  from langchain_google_genai import ChatGoogleGenerativeAI
16
  from langchain_chroma import Chroma
17
  from langchain_community.document_loaders import PyPDFLoader
@@ -24,213 +27,319 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
24
  from langchain_core.messages import HumanMessage, AIMessage
25
  from langchain_core.documents import Document
26
  from langchain_huggingface import HuggingFaceEmbeddings
27
- from langchain.retrievers import ContextualCompressionRetriever
28
- from langchain.retrievers.document_compressors import CrossEncoderReranker
29
- from langchain_community.cross_encoders import HuggingFaceCrossEncoder
 
 
30
 
31
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
32
  DATA_PATH = "medical_data"
33
  DB_PATH = "chroma_db"
34
- MAX_HISTORY_TURNS = 6
35
  FORCE_REBUILD_DB = False
36
 
37
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
38
 
39
  def process_excel_file(file_path: str, filename: str) -> list[Document]:
40
- """
41
- Xử lý Excel thông minh: Biến mỗi dòng thành một Document riêng biệt
42
- giúp tìm kiếm chính xác từng bản ghi thuốc/bệnh nhân.
43
- """
44
  docs = []
45
  try:
46
  if file_path.endswith(".csv"):
47
- df = pd.read_csv(file_path)
48
  else:
49
  df = pd.read_excel(file_path)
50
-
51
- df.dropna(how='all', inplace=True)
52
- df.fillna("Không thông tin", inplace=True)
53
-
54
- for idx, row in df.iterrows():
55
- content_parts = []
56
- for col_name, val in row.items():
57
- clean_val = str(val).strip()
58
- if clean_val and clean_val.lower() != "nan":
59
- content_parts.append(f"{col_name}: {clean_val}")
60
-
61
- if content_parts:
62
- page_content = f"Dữ liệu từ file {filename} (Dòng {idx+1}):\n" + "\n".join(content_parts)
63
- metadata = {"source": filename, "row": idx+1, "type": "excel_record"}
64
- docs.append(Document(page_content=page_content, metadata=metadata))
65
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  except Exception as e:
67
  logging.error(f"Lỗi xử lý Excel {filename}: {e}")
68
 
69
  return docs
70
 
71
  def load_documents_from_folder(folder_path: str) -> list[Document]:
72
- logging.info(f"--- Bắt đầu quét thư mục: {folder_path} ---")
73
- documents: list[Document] = []
74
  if not os.path.exists(folder_path):
75
  os.makedirs(folder_path, exist_ok=True)
76
  return []
77
-
78
- for root, _, files in os.walk(folder_path):
79
- for filename in files:
80
- file_path = os.path.join(root, filename)
81
- filename_lower = filename.lower()
82
- try:
83
- if filename_lower.endswith(".pdf"):
84
- loader = PyPDFLoader(file_path)
85
- docs = loader.load()
86
- for d in docs: d.metadata["source"] = filename
87
- documents.extend(docs)
88
-
89
- elif filename_lower.endswith(".docx"):
90
- text = docx2txt.process(file_path)
91
- if text.strip():
92
- documents.append(Document(page_content=text, metadata={"source": filename}))
93
-
94
- elif filename_lower.endswith((".xlsx", ".xls", ".csv")):
95
- excel_docs = process_excel_file(file_path, filename)
96
- documents.extend(excel_docs)
97
-
98
- elif filename_lower.endswith((".txt", ".md")):
99
- with open(file_path, "r", encoding="utf-8") as f: text = f.read()
100
- if text.strip():
101
- documents.append(Document(page_content=text, metadata={"source": filename}))
102
 
103
- except Exception as e:
104
- logging.error(f"Lỗi đọc file {filename}: {e}")
105
-
106
- logging.info(f"Tổng cộng đã load: {len(documents)} tài liệu gốc.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  return documents
108
 
109
  def get_retriever_chain():
 
110
  logging.info("--- Tải Embedding Model ---")
111
- embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
 
 
 
 
 
 
112
 
113
  vectorstore = None
114
  splits = []
115
-
116
- chroma_settings = Settings(anonymized_telemetry=False)
117
-
 
 
 
118
  if FORCE_REBUILD_DB and os.path.exists(DB_PATH):
119
- logging.warning("Đang xóa DB cũ theo yêu cầu FORCE_REBUILD...")
120
  rmtree(DB_PATH, ignore_errors=True)
121
-
122
  if os.path.exists(DB_PATH) and os.listdir(DB_PATH):
123
  try:
124
  vectorstore = Chroma(
125
- persist_directory=DB_PATH,
126
  embedding_function=embedding_model,
127
- client_settings=chroma_settings
128
  )
129
 
130
- existing_data = vectorstore.get()
131
- if existing_data['documents']:
132
- for text, meta in zip(existing_data['documents'], existing_data['metadatas']):
 
 
 
 
133
  splits.append(Document(page_content=text, metadata=meta))
134
- logging.info(f"Đã khôi phục {len(splits)} chunks từ DB.")
135
  else:
136
- logging.warning("DB rỗng, sẽ tạo mới.")
137
  vectorstore = None
138
  except Exception as e:
139
- logging.error(f"DB lỗi: {e}. Đang reset...")
140
  rmtree(DB_PATH, ignore_errors=True)
141
  vectorstore = None
142
-
143
  if not vectorstore:
144
- logging.info("--- Tạo Index dữ liệu mới ---")
145
  raw_docs = load_documents_from_folder(DATA_PATH)
146
  if not raw_docs:
147
- logging.warning("Không có dữ liệu trong thư mục medical_data.")
148
  return None
149
-
150
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
 
 
 
 
 
 
 
151
  splits = text_splitter.split_documents(raw_docs)
 
 
 
 
 
152
 
153
  vectorstore = Chroma.from_documents(
154
- documents=splits,
155
- embedding=embedding_model,
156
  persist_directory=DB_PATH,
157
  client_settings=chroma_settings
158
  )
159
- logging.info("Đã lưu VectorStore thành công.")
160
-
161
- vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
162
 
 
 
 
 
 
 
 
 
 
163
  if splits:
164
  bm25_retriever = BM25Retriever.from_documents(splits)
165
- bm25_retriever.k = 10
166
- ensemble_retriever = EnsembleRetriever(
167
- retrievers=[bm25_retriever, vector_retriever],
168
- weights=[0.4, 0.6]
169
- )
170
- else:
171
- ensemble_retriever = vector_retriever
172
-
173
- logging.info("--- Tải Reranker Model (BGE-M3) ---")
174
- reranker_model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
175
- compressor = CrossEncoderReranker(model=reranker_model, top_n=5)
176
 
177
- final_retriever = ContextualCompressionRetriever(
178
- base_compressor=compressor,
179
- base_retriever=ensemble_retriever
 
180
  )
181
 
182
- return final_retriever
 
 
 
183
 
184
  class DeepMedBot:
185
  def __init__(self):
186
  self.rag_chain = None
187
  self.ready = False
 
 
 
188
 
 
 
 
 
 
 
 
189
  if not GOOGLE_API_KEY:
190
- logging.error("⚠️ Thiếu GOOGLE_API_KEY! Vui lòng thiết lập biến môi trường.")
191
- return
192
-
193
  try:
 
194
  self.retriever = get_retriever_chain()
195
- if not self.retriever:
196
- logging.warning("⚠️ Chưa có dữ liệu để Retreive. Bot sẽ chỉ trả lời bằng kiến thức nền.")
197
 
 
198
  self.llm = ChatGoogleGenerativeAI(
199
- model="gemini-2.5-flash",
200
- temperature=0.11,
201
- google_api_key=GOOGLE_API_KEY
 
 
202
  )
 
203
  self._build_chains()
204
  self.ready = True
205
- logging.info("✅ Bot DeepMed đã sẵn sàng phục vụ!")
 
 
206
  except Exception as e:
207
- logging.error(f"🔥 Lỗi khởi tạo bot: {e}")
208
- logging.debug(traceback.format_exc())
209
-
210
  def _build_chains(self):
211
- context_system_prompt = (
212
- "Dựa trên lịch sử chat và câu hỏi mới nhất của người dùng, "
213
- "hãy viết lại câu hỏi đó thành một câu đầy đủ ngữ cảnh để hệ thống có thể hiểu được. "
214
- "KHÔNG trả lời câu hỏi, chỉ viết lại nó."
 
 
 
 
 
215
  )
216
- context_prompt = ChatPromptTemplate.from_messages([
217
- ("system", context_system_prompt),
 
218
  MessagesPlaceholder("chat_history"),
219
  ("human", "{input}"),
220
  ])
221
 
222
- if self.retriever:
223
- history_aware_retriever = create_history_aware_retriever(
224
- self.llm, self.retriever, context_prompt
225
- )
226
-
227
- qa_system_prompt = (
228
- "Bạn 'DeepMed-AI' - Trợ lý Dược lâm sàng tại Trung Tâm Y Tế. "
229
- "Sử dụng các thông tin được cung cấp trong phần Context dưới đây để trả lời câu hỏi về thuốc, bệnh học và y lệnh.\n"
230
- "Nếu Context dữ liệu từ Excel, hãy trình bày dạng bảng hoặc gạch đầu dòng rõ ràng.\n"
231
- "Nếu không tìm thấy thông tin trong Context, hãy nói 'Tôi không tìm thấy thông tin trong dữ liệu nội bộ' gợi ý dựa trên kiến thức y khoa chung của bạn.\n\n"
232
- "Context:\n{context}"
233
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
  qa_prompt = ChatPromptTemplate.from_messages([
236
  ("system", qa_system_prompt),
@@ -238,93 +347,304 @@ class DeepMedBot:
238
  ("human", "{input}"),
239
  ])
240
 
241
- question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)
242
-
243
  if self.retriever:
 
 
 
 
 
 
 
 
 
244
  self.rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
245
  else:
 
246
  self.rag_chain = qa_prompt | self.llm
247
-
248
  def chat_stream(self, message: str, history: list):
 
249
  if not self.ready:
250
- yield "Hệ thống đang khởi động hoặc gặp lỗi cấu hình."
251
- return
252
-
 
 
253
  chat_history = []
254
  for u, b in history[-MAX_HISTORY_TURNS:]:
255
  chat_history.append(HumanMessage(content=str(u)))
256
  chat_history.append(AIMessage(content=str(b)))
257
-
258
  full_response = ""
259
  retrieved_docs = []
260
 
261
  try:
262
- stream_input = {"input": message, "chat_history": chat_history} if self.retriever else {"input": message, "chat_history": chat_history}
 
263
 
264
- if self.rag_chain:
265
- for chunk in self.rag_chain.stream(stream_input):
266
-
 
 
 
267
  if isinstance(chunk, dict):
268
  if "answer" in chunk:
269
  full_response += chunk["answer"]
270
- yield full_response
271
-
272
  if "context" in chunk:
273
  retrieved_docs = chunk["context"]
274
-
275
  elif hasattr(chunk, 'content'):
276
  full_response += chunk.content
277
  yield full_response
278
-
279
- elif isinstance(chunk, str):
280
- full_response += chunk
281
- yield full_response
282
-
283
- if retrieved_docs:
284
- refs = self._build_references_text(retrieved_docs)
285
- if refs:
286
- full_response += f"\n\n---\n📚 **Nguồn tham khảo:**\n{refs}"
287
- yield full_response
288
-
 
 
 
 
 
 
 
 
 
289
  except Exception as e:
290
- logging.error(f"Lỗi khi chat: {e}")
291
- logging.debug(traceback.format_exc())
292
- yield f"Đã xảy ra lỗi: {str(e)}"
293
-
294
  @staticmethod
295
  def _build_references_text(docs) -> str:
296
- lines = []
297
- seen = set()
298
  for doc in docs:
299
- src = doc.metadata.get("source", "Tài liệu")
300
- row_info = ""
301
- if "row" in doc.metadata:
302
- row_info = f"(Dòng {doc.metadata['row']})"
303
 
304
- ref_str = f"- {src} {row_info}"
305
-
306
- if ref_str not in seen:
307
- lines.append(ref_str)
308
- seen.add(ref_str)
309
- return "\n".join(lines)
 
 
 
 
310
 
 
311
  bot = DeepMedBot()
312
 
313
  def gradio_chat_stream(message, history):
 
314
  yield from bot.chat_stream(message, history)
315
 
 
316
  css = """
317
- .gradio-container {min_height: 600px !important;}
318
- h1 {text-align: center; color: #2E86C1;}
319
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
 
321
- with gr.Blocks(css=css, title="DeepMed AI") as demo:
322
- gr.Markdown("# 🏥 DeepMed AI - Trợ lý Lâm Sàng")
323
- gr.Markdown("Hệ thống hỗ trợ lâm sàng tại Trung Tâm Y Tế Khu Vực Thanh Ba.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
 
325
- chat_interface = gr.ChatInterface(
326
- fn=gradio_chat_stream,
327
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
 
329
  if __name__ == "__main__":
330
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
11
  import chromadb
12
  from chromadb.config import Settings
13
  from shutil import rmtree
14
+ import gc
15
+ import torch
16
 
17
+ # Optimization: Import only what's needed
18
  from langchain_google_genai import ChatGoogleGenerativeAI
19
  from langchain_chroma import Chroma
20
  from langchain_community.document_loaders import PyPDFLoader
 
27
  from langchain_core.messages import HumanMessage, AIMessage
28
  from langchain_core.documents import Document
29
  from langchain_huggingface import HuggingFaceEmbeddings
30
+
31
+ # Bỏ CrossEncoder để giảm memory, thay bằng các kỹ thuật khác
32
+ # from langchain.retrievers import ContextualCompressionRetriever
33
+ # from langchain.retrievers.document_compressors import CrossEncoderReranker
34
+ # from langchain_community.cross_encoders import HuggingFaceCrossEncoder
35
 
36
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
37
  DATA_PATH = "medical_data"
38
  DB_PATH = "chroma_db"
39
+ MAX_HISTORY_TURNS = 5 # Giảm để tăng tốc
40
  FORCE_REBUILD_DB = False
41
 
42
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
43
 
44
  def process_excel_file(file_path: str, filename: str) -> list[Document]:
45
+ """Tối ưu hóa xử lý Excel để tăng độ chính xác"""
 
 
 
46
  docs = []
47
  try:
48
  if file_path.endswith(".csv"):
49
+ df = pd.read_csv(file_path, encoding='utf-8')
50
  else:
51
  df = pd.read_excel(file_path)
52
+
53
+ # Xử lý thông minh cho dữ liệu y tế
54
+ # Phát hiện loại dữ liệu tự động
55
+ if any(col in df.columns.str.lower() for col in ['tên thuốc', 'thuốc', 'drug']):
56
+ # Dữ liệu thuốc
57
+ for idx, row in df.iterrows():
58
+ content = f"THÔNG TIN THUỐC - Dòng {idx+1}:\n"
59
+ for col in df.columns:
60
+ if pd.notna(row[col]):
61
+ content += f"{col}: {str(row[col]).strip()}\n"
62
+ docs.append(Document(
63
+ page_content=content,
64
+ metadata={"source": filename, "row": idx+1, "type": "drug_info"}
65
+ ))
66
+ elif any(col in df.columns.str.lower() for col in ['bệnh nhân', 'patient', 'mã bn']):
67
+ # Dữ liệu bệnh nhân
68
+ for idx, row in df.iterrows():
69
+ content = f"HỒ SƠ BỆNH NHÂN - Dòng {idx+1}:\n"
70
+ for col in df.columns:
71
+ if pd.notna(row[col]):
72
+ content += f"{col}: {str(row[col]).strip()}\n"
73
+ docs.append(Document(
74
+ page_content=content,
75
+ metadata={"source": filename, "row": idx+1, "type": "patient_record"}
76
+ ))
77
+ else:
78
+ # Dữ liệu chung
79
+ for idx, row in df.iterrows():
80
+ content_parts = [f"{col}: {str(row[col]).strip()}"
81
+ for col in df.columns if pd.notna(row[col])]
82
+ if content_parts:
83
+ docs.append(Document(
84
+ page_content=f"Dữ liệu từ {filename} (Dòng {idx+1}):\n" + "\n".join(content_parts),
85
+ metadata={"source": filename, "row": idx+1, "type": "general_data"}
86
+ ))
87
+
88
  except Exception as e:
89
  logging.error(f"Lỗi xử lý Excel {filename}: {e}")
90
 
91
  return docs
92
 
93
  def load_documents_from_folder(folder_path: str) -> list[Document]:
94
+ """Tải xử tài liệu với metadata phong phú"""
95
+ documents = []
96
  if not os.path.exists(folder_path):
97
  os.makedirs(folder_path, exist_ok=True)
98
  return []
99
+
100
+ # Ưu tiên xử theo thứ tự để tăng độ chính xác
101
+ file_extensions = ['.pdf', '.docx', '.xlsx', '.xls', '.csv', '.txt']
102
+
103
+ for ext in file_extensions:
104
+ for root, _, files in os.walk(folder_path):
105
+ for filename in files:
106
+ if filename.lower().endswith(ext):
107
+ file_path = os.path.join(root, filename)
108
+ try:
109
+ if filename.lower().endswith(".pdf"):
110
+ loader = PyPDFLoader(file_path)
111
+ docs = loader.load()
112
+ for i, d in enumerate(docs):
113
+ d.metadata.update({
114
+ "source": filename,
115
+ "page": i+1,
116
+ "file_type": "pdf",
117
+ "doc_id": f"{filename}_page_{i+1}"
118
+ })
119
+ documents.extend(docs)
 
 
 
 
120
 
121
+ elif filename.lower().endswith(".docx"):
122
+ text = docx2txt.process(file_path)
123
+ if text.strip():
124
+ doc = Document(
125
+ page_content=text,
126
+ metadata={
127
+ "source": filename,
128
+ "file_type": "docx",
129
+ "doc_id": filename
130
+ }
131
+ )
132
+ documents.append(doc)
133
+
134
+ elif filename.lower().endswith((".xlsx", ".xls", ".csv")):
135
+ excel_docs = process_excel_file(file_path, filename)
136
+ documents.extend(excel_docs)
137
+
138
+ elif filename.lower().endswith((".txt", ".md")):
139
+ with open(file_path, "r", encoding="utf-8") as f:
140
+ text = f.read()
141
+ if text.strip():
142
+ doc = Document(
143
+ page_content=text,
144
+ metadata={
145
+ "source": filename,
146
+ "file_type": "txt",
147
+ "doc_id": filename
148
+ }
149
+ )
150
+ documents.append(doc)
151
+
152
+ except Exception as e:
153
+ logging.error(f"Lỗi đọc file {filename}: {e}")
154
+
155
+ logging.info(f"Đã load {len(documents)} tài liệu")
156
  return documents
157
 
158
  def get_retriever_chain():
159
+ """Tạo retriever tối ưu cho Hugging Face"""
160
  logging.info("--- Tải Embedding Model ---")
161
+
162
+ # Model tối ưu cho tiếng Việt và memory
163
+ embedding_model = HuggingFaceEmbeddings(
164
+ model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L6-v2",
165
+ model_kwargs={'device': 'cpu'},
166
+ encode_kwargs={'normalize_embeddings': True}
167
+ )
168
 
169
  vectorstore = None
170
  splits = []
171
+
172
+ chroma_settings = Settings(
173
+ anonymized_telemetry=False,
174
+ allow_reset=True
175
+ )
176
+
177
  if FORCE_REBUILD_DB and os.path.exists(DB_PATH):
178
+ logging.warning("Đang xóa DB cũ...")
179
  rmtree(DB_PATH, ignore_errors=True)
180
+
181
  if os.path.exists(DB_PATH) and os.listdir(DB_PATH):
182
  try:
183
  vectorstore = Chroma(
184
+ persist_directory=DB_PATH,
185
  embedding_function=embedding_model,
186
+ client_settings=chroma_settings
187
  )
188
 
189
+ # Kiểm tra số lượng documents
190
+ count = vectorstore._collection.count()
191
+ if count > 0:
192
+ logging.info(f"Đã khôi phục {count} documents từ DB")
193
+ # Lấy splits cho BM25
194
+ results = vectorstore._collection.get()
195
+ for text, meta in zip(results['documents'], results['metadatas']):
196
  splits.append(Document(page_content=text, metadata=meta))
 
197
  else:
198
+ logging.warning("DB rỗng, tạo mới...")
199
  vectorstore = None
200
  except Exception as e:
201
+ logging.error(f"DB lỗi: {e}")
202
  rmtree(DB_PATH, ignore_errors=True)
203
  vectorstore = None
204
+
205
  if not vectorstore:
206
+ logging.info("--- Tạo Index mới ---")
207
  raw_docs = load_documents_from_folder(DATA_PATH)
208
  if not raw_docs:
209
+ logging.warning("Không có dữ liệu")
210
  return None
211
+
212
+ # Text splitter tối ưu cho y tế
213
+ text_splitter = RecursiveCharacterTextSplitter(
214
+ chunk_size=800, # Tăng độ chính xác với chunk nhỏ hơn
215
+ chunk_overlap=150,
216
+ separators=["\n\n", "\n", "。", ".", "!", "?", ";", ";", ",", ",", " ", ""],
217
+ length_function=len,
218
+ )
219
+
220
  splits = text_splitter.split_documents(raw_docs)
221
+ logging.info(f"Đã chia thành {len(splits)} chunks")
222
+
223
+ # Giảm memory bằng cách xóa raw_docs
224
+ del raw_docs
225
+ gc.collect()
226
 
227
  vectorstore = Chroma.from_documents(
228
+ documents=splits,
229
+ embedding=embedding_model,
230
  persist_directory=DB_PATH,
231
  client_settings=chroma_settings
232
  )
 
 
 
233
 
234
+ # Tăng số lượng retrieved documents để bù đắp độ chính xác
235
+ vector_retriever = vectorstore.as_retriever(
236
+ search_kwargs={
237
+ "k": 15, # Tăng từ 10 lên 15
238
+ "score_threshold": 0.3 # Ngưỡng similarity
239
+ }
240
+ )
241
+
242
+ # BM25 Retriever cho keyword matching
243
  if splits:
244
  bm25_retriever = BM25Retriever.from_documents(splits)
245
+ bm25_retriever.k = 15 # Tăng số documents
 
 
 
 
 
 
 
 
 
 
246
 
247
+ # Ensemble Retriever với weights tối ưu
248
+ ensemble_retriever = EnsembleRetriever(
249
+ retrievers=[bm25_retriever, vector_retriever],
250
+ weights=[0.5, 0.5] # Cân bằng giữa keyword và semantic
251
  )
252
 
253
+ # Memory management
254
+ gc.collect()
255
+
256
+ return ensemble_retriever
257
 
258
  class DeepMedBot:
259
  def __init__(self):
260
  self.rag_chain = None
261
  self.ready = False
262
+ self.retriever = None
263
+ self.llm = None
264
+ self.chat_history = [] # Lưu history riêng
265
 
266
+ logging.info("Initializing DeepMedBot...")
267
+
268
+ def initialize(self):
269
+ """Khởi tạo lazy để giảm startup time"""
270
+ if self.ready:
271
+ return True
272
+
273
  if not GOOGLE_API_KEY:
274
+ logging.error("⚠️ Thiếu GOOGLE_API_KEY!")
275
+ return False
276
+
277
  try:
278
+ # Khởi tạo retriever
279
  self.retriever = get_retriever_chain()
 
 
280
 
281
+ # Khởi tạo LLM với config tối ưu
282
  self.llm = ChatGoogleGenerativeAI(
283
+ model="gemini-1.5-flash", # Dùng flash thay vì 2.5 cho ổn định
284
+ temperature=0.1,
285
+ google_api_key=GOOGLE_API_KEY,
286
+ max_output_tokens=2000,
287
+ timeout=30
288
  )
289
+
290
  self._build_chains()
291
  self.ready = True
292
+ logging.info("✅ DeepMedBot đã sẵn sàng!")
293
+ return True
294
+
295
  except Exception as e:
296
+ logging.error(f"🔥 Lỗi khởi tạo: {e}")
297
+ return False
298
+
299
  def _build_chains(self):
300
+ """Xây dựng chains với prompt tối ưu"""
301
+ # Contextualize question với medical focus
302
+ contextualize_q_system_prompt = (
303
+ "Bạn trợ lý y tế. Dựa vào lịch sử chat và câu hỏi mới, "
304
+ "hãy viết lại câu hỏi thành một phiên bản đầy đủ, rõ ràng, "
305
+ "chuyên nghiệp về y tế để tìm kiếm thông tin.\n"
306
+ "Ví dụ:\n"
307
+ "User: 'tác dụng phụ?' -> 'Thuốc này có những tác dụng phụ gì?'\n"
308
+ "KHÔNG trả lời câu hỏi, chỉ VIẾT LẠI câu hỏi."
309
  )
310
+
311
+ contextualize_q_prompt = ChatPromptTemplate.from_messages([
312
+ ("system", contextualize_q_system_prompt),
313
  MessagesPlaceholder("chat_history"),
314
  ("human", "{input}"),
315
  ])
316
 
317
+ # QA prompt tối ưu cho y tế
318
+ qa_system_prompt = """
319
+ Bạn là "DeepMed AI" - Trợ lý Dược lâm sàng thông minh tại Trung Tâm Y Tế Khu Vực Thanh Ba.
320
+
321
+ HƯỚNG DẪN TRẢ LỜI:
322
+ 1. **NGUYÊN TẮC VÀNG**: Luôn kiểm tra kỹ thông tin từ Context trước khi trả lời
323
+ 2. **ĐỊNH DẠNG RÀNG**:
324
+ - Thuốc: Tên thuốc (IN HOA), liều lượng, chống chỉ định, tác dụng phụ
325
+ - Bệnh nhân: BN, tuổi, chẩn đoán, phác đồ
326
+ - Số liệu: Trình bày dạng bảng hoặc bullet points
327
+ 3. **MỨC ĐỘ TIN CẬY**:
328
+ ✅ "Theo dữ liệu nội bộ: [thông tin]"
329
+ ⚠️ "Thông tin không đầy đủ trong dữ liệu, theo kiến thức y khoa: [thông tin]"
330
+ ❌ "Không tìm thấy trong dữ liệu, vui lòng kiểm tra lại"
331
+ 4. **AN TOÀN Y TẾ**: Luôn nhắc "Vui lòng tham khảo ý kiến bác sĩ trước khi sử dụng"
332
+
333
+ Context:
334
+ {context}
335
+
336
+ Hãy trả lời câu hỏi dựa trên Context trên. Nếu không có thông tin trong Context, hãy:
337
+ 1. Nói rõ "Không tìm thấy trong dữ liệu nội bộ"
338
+ 2. Cung cấp kiến thức y khoa chung (nếu có)
339
+ 3. Gợi ý tham khảo bác sĩ chuyên khoa
340
+
341
+ Câu hỏi: {input}
342
+ """
343
 
344
  qa_prompt = ChatPromptTemplate.from_messages([
345
  ("system", qa_system_prompt),
 
347
  ("human", "{input}"),
348
  ])
349
 
 
 
350
  if self.retriever:
351
+ # Tạo history-aware retriever
352
+ history_aware_retriever = create_history_aware_retriever(
353
+ self.llm, self.retriever, contextualize_q_prompt
354
+ )
355
+
356
+ # Tạo chain trả lời
357
+ question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)
358
+
359
+ # Tạo retrieval chain hoàn chỉnh
360
  self.rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
361
  else:
362
+ # Fallback chain
363
  self.rag_chain = qa_prompt | self.llm
364
+
365
  def chat_stream(self, message: str, history: list):
366
+ """Stream chat với memory management"""
367
  if not self.ready:
368
+ if not self.initialize():
369
+ yield "Hệ thống đang khởi động..."
370
+ return
371
+
372
+ # Giới hạn history để tránh memory leak
373
  chat_history = []
374
  for u, b in history[-MAX_HISTORY_TURNS:]:
375
  chat_history.append(HumanMessage(content=str(u)))
376
  chat_history.append(AIMessage(content=str(b)))
377
+
378
  full_response = ""
379
  retrieved_docs = []
380
 
381
  try:
382
+ # Thêm delay nhỏ để tránh timeout
383
+ import time
384
 
385
+ # Stream response
386
+ if hasattr(self.rag_chain, 'stream'):
387
+ for chunk in self.rag_chain.stream({
388
+ "input": message,
389
+ "chat_history": chat_history
390
+ }):
391
  if isinstance(chunk, dict):
392
  if "answer" in chunk:
393
  full_response += chunk["answer"]
394
+ yield full_response
 
395
  if "context" in chunk:
396
  retrieved_docs = chunk["context"]
 
397
  elif hasattr(chunk, 'content'):
398
  full_response += chunk.content
399
  yield full_response
400
+ time.sleep(0.01) # Small delay
401
+ else:
402
+ # Fallback non-stream
403
+ response = self.rag_chain.invoke({
404
+ "input": message,
405
+ "chat_history": chat_history
406
+ })
407
+ full_response = response.content if hasattr(response, 'content') else str(response)
408
+ yield full_response
409
+
410
+ # Thêm references nếu có
411
+ if retrieved_docs:
412
+ refs = self._build_references_text(retrieved_docs)
413
+ if refs:
414
+ full_response += f"\n\n---\n📚 **Tài liệu tham khảo:**\n{refs}"
415
+ yield full_response
416
+
417
+ # Memory cleanup
418
+ gc.collect()
419
+
420
  except Exception as e:
421
+ logging.error(f"Chat error: {e}")
422
+ yield f"⚠️ Có lỗi xảy ra: {str(e)[:100]}"
423
+
 
424
  @staticmethod
425
  def _build_references_text(docs) -> str:
426
+ """Xây dựng references với format đẹp"""
427
+ references = {}
428
  for doc in docs:
429
+ source = doc.metadata.get("source", "Tài liệu")
430
+ file_type = doc.metadata.get("file_type", "")
431
+ row = doc.metadata.get("row", "")
 
432
 
433
+ key = f"{source}_{row}"
434
+ if key not in references:
435
+ ref_info = f"📄 {source}"
436
+ if file_type:
437
+ ref_info += f" ({file_type.upper()})"
438
+ if row:
439
+ ref_info += f" - Dòng {row}"
440
+ references[key] = ref_info
441
+
442
+ return "\n".join(references.values())
443
 
444
+ # Global bot instance với lazy loading
445
  bot = DeepMedBot()
446
 
447
  def gradio_chat_stream(message, history):
448
+ """Wrapper cho Gradio"""
449
  yield from bot.chat_stream(message, history)
450
 
451
+ # CSS responsive cho cả mobile và PC
452
  css = """
453
+ /* Base styles */
454
+ .gradio-container {
455
+ min-height: 100vh !important;
456
+ max-width: 100% !important;
457
+ margin: 0 auto !important;
458
+ padding: 10px !important;
459
+ }
460
+
461
+ /* Header */
462
+ h1 {
463
+ text-align: center;
464
+ color: #2E86C1;
465
+ font-size: 24px !important;
466
+ margin: 10px 0 !important;
467
+ padding: 10px !important;
468
+ }
469
+
470
+ /* Chat container */
471
+ #chatbot {
472
+ min-height: 400px !important;
473
+ max-height: 60vh !important;
474
+ overflow-y: auto !important;
475
+ border: 1px solid #e0e0e0 !important;
476
+ border-radius: 10px !important;
477
+ padding: 15px !important;
478
+ background: #f9f9f9 !important;
479
+ }
480
 
481
+ /* Messages */
482
+ .user, .assistant {
483
+ padding: 10px 15px !important;
484
+ margin: 8px 0 !important;
485
+ border-radius: 15px !important;
486
+ max-width: 85% !important;
487
+ word-wrap: break-word !important;
488
+ }
489
+
490
+ .user {
491
+ background: #E3F2FD !important;
492
+ margin-left: auto !important;
493
+ }
494
+
495
+ .assistant {
496
+ background: #F5F5F5 !important;
497
+ margin-right: auto !important;
498
+ }
499
+
500
+ /* Input area */
501
+ #text-input {
502
+ border-radius: 20px !important;
503
+ padding: 12px 20px !important;
504
+ font-size: 14px !important;
505
+ border: 2px solid #2E86C1 !important;
506
+ }
507
+
508
+ /* Buttons */
509
+ button {
510
+ border-radius: 20px !important;
511
+ padding: 10px 20px !important;
512
+ font-weight: bold !important;
513
+ transition: all 0.3s !important;
514
+ }
515
+
516
+ button:hover {
517
+ transform: translateY(-2px) !important;
518
+ box-shadow: 0 4px 8px rgba(0,0,0,0.1) !important;
519
+ }
520
+
521
+ /* Footer */
522
+ .footer {
523
+ text-align: center;
524
+ padding: 10px;
525
+ color: #666;
526
+ font-size: 12px;
527
+ }
528
+
529
+ /* Mobile optimization */
530
+ @media screen and (max-width: 768px) {
531
+ .gradio-container {
532
+ padding: 5px !important;
533
+ }
534
 
535
+ h1 {
536
+ font-size: 20px !important;
537
+ padding: 5px !important;
538
+ }
539
+
540
+ #chatbot {
541
+ min-height: 300px !important;
542
+ max-height: 50vh !important;
543
+ padding: 10px !important;
544
+ }
545
+
546
+ .user, .assistant {
547
+ max-width: 90% !important;
548
+ padding: 8px 12px !important;
549
+ font-size: 14px !important;
550
+ }
551
+
552
+ #text-input {
553
+ padding: 10px 15px !important;
554
+ font-size: 13px !important;
555
+ }
556
+
557
+ button {
558
+ padding: 8px 16px !important;
559
+ margin: 5px !important;
560
+ }
561
+ }
562
+
563
+ /* PC optimization */
564
+ @media screen and (min-width: 1200px) {
565
+ .gradio-container {
566
+ max-width: 900px !important;
567
+ }
568
+
569
+ #chatbot {
570
+ max-height: 500px !important;
571
+ }
572
+ }
573
+
574
+ /* Loading animation */
575
+ @keyframes pulse {
576
+ 0% { opacity: 1; }
577
+ 50% { opacity: 0.5; }
578
+ 100% { opacity: 1; }
579
+ }
580
+
581
+ .typing {
582
+ animation: pulse 1.5s infinite;
583
+ }
584
+
585
+ /* Table formatting for medical data */
586
+ table {
587
+ border-collapse: collapse;
588
+ width: 100%;
589
+ margin: 10px 0;
590
+ }
591
+
592
+ th, td {
593
+ border: 1px solid #ddd;
594
+ padding: 8px;
595
+ text-align: left;
596
+ }
597
+
598
+ th {
599
+ background-color: #f2f2f2;
600
+ }
601
+
602
+ /* Scrollbar styling */
603
+ ::-webkit-scrollbar {
604
+ width: 6px;
605
+ }
606
+
607
+ ::-webkit-scrollbar-track {
608
+ background: #f1f1f1;
609
+ }
610
+
611
+ ::-webkit-scrollbar-thumb {
612
+ background: #888;
613
+ border-radius: 3px;
614
+ }
615
+
616
+ ::-webkit-scrollbar-thumb:hover {
617
+ background: #555;
618
+ }
619
+ """
620
+
621
+ # Config cho Hugging Face Spaces
622
+ def get_spaces_config():
623
+ return {
624
+ "title": "DeepMed AI - Medical Assistant",
625
+ "description": "Trợ lý lâm sàng AI cho trung tâm y tế",
626
+ "thumbnail": "https://huggingface.co/spaces/your-space/your-app/raw/main/thumbnail.png",
627
+ "theme": "light",
628
+ "sdk": "gradio",
629
+ "sdk_version": "4.0.0",
630
+ }
631
+
632
+ # Memory management
633
+ def cleanup():
634
+ """Cleanup function for Hugging Face"""
635
+ if torch.cuda.is_available():
636
+ torch.cuda.empty_cache()
637
+ gc.collect()
638
 
639
  if __name__ == "__main__":
640
+ # Hugging Face Spaces config
641
+ demo.queue(max_size=20) # Giới hạn queue
642
+ demo.launch(
643
+ server_name="0.0.0.0",
644
+ server_port=7860,
645
+ show_error=True,
646
+ debug=False,
647
+ share=False,
648
+ favicon_path=None
649
+ )
650
+ cleanup()