Spaces:
Sleeping
Sleeping
Update app.py
#3
by
PBThuong
- opened
app.py
CHANGED
|
@@ -12,6 +12,7 @@ import chromadb
|
|
| 12 |
from chromadb.config import Settings
|
| 13 |
from shutil import rmtree
|
| 14 |
|
|
|
|
| 15 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 16 |
from langchain_chroma import Chroma
|
| 17 |
from langchain_community.document_loaders import PyPDFLoader
|
|
@@ -25,22 +26,31 @@ from langchain_core.messages import HumanMessage, AIMessage
|
|
| 25 |
from langchain_core.documents import Document
|
| 26 |
from langchain_huggingface import HuggingFaceEmbeddings
|
| 27 |
from langchain.retrievers import ContextualCompressionRetriever
|
| 28 |
-
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
| 29 |
-
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
|
| 32 |
DATA_PATH = "medical_data"
|
| 33 |
DB_PATH = "chroma_db"
|
|
|
|
| 34 |
MAX_HISTORY_TURNS = 6
|
| 35 |
FORCE_REBUILD_DB = False
|
| 36 |
|
| 37 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
def process_excel_file(file_path: str, filename: str) -> list[Document]:
|
| 40 |
-
"""
|
| 41 |
-
Xử lý Excel thông minh: Biến mỗi dòng thành một Document riêng biệt
|
| 42 |
-
giúp tìm kiếm chính xác từng bản ghi thuốc/bệnh nhân.
|
| 43 |
-
"""
|
| 44 |
docs = []
|
| 45 |
try:
|
| 46 |
if file_path.endswith(".csv"):
|
|
@@ -108,35 +118,32 @@ def load_documents_from_folder(folder_path: str) -> list[Document]:
|
|
| 108 |
|
| 109 |
def get_retriever_chain():
|
| 110 |
logging.info("--- Tải Embedding Model ---")
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
vectorstore = None
|
| 114 |
-
splits = []
|
| 115 |
-
|
| 116 |
chroma_settings = Settings(anonymized_telemetry=False)
|
| 117 |
|
| 118 |
if FORCE_REBUILD_DB and os.path.exists(DB_PATH):
|
| 119 |
-
logging.warning("Đang xóa DB cũ theo yêu cầu FORCE_REBUILD...")
|
| 120 |
rmtree(DB_PATH, ignore_errors=True)
|
| 121 |
|
|
|
|
| 122 |
if os.path.exists(DB_PATH) and os.listdir(DB_PATH):
|
| 123 |
try:
|
| 124 |
vectorstore = Chroma(
|
| 125 |
persist_directory=DB_PATH,
|
| 126 |
embedding_function=embedding_model,
|
| 127 |
-
client_settings=chroma_settings
|
| 128 |
)
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
if existing_data['documents']:
|
| 132 |
-
for text, meta in zip(existing_data['documents'], existing_data['metadatas']):
|
| 133 |
-
splits.append(Document(page_content=text, metadata=meta))
|
| 134 |
-
logging.info(f"Đã khôi phục {len(splits)} chunks từ DB.")
|
| 135 |
else:
|
| 136 |
-
logging.warning("DB rỗng, sẽ tạo mới.")
|
| 137 |
vectorstore = None
|
| 138 |
except Exception as e:
|
| 139 |
-
logging.error(f"DB lỗi: {e}.
|
| 140 |
rmtree(DB_PATH, ignore_errors=True)
|
| 141 |
vectorstore = None
|
| 142 |
|
|
@@ -158,25 +165,16 @@ def get_retriever_chain():
|
|
| 158 |
)
|
| 159 |
logging.info("Đã lưu VectorStore thành công.")
|
| 160 |
|
| 161 |
-
|
|
|
|
| 162 |
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
ensemble_retriever = EnsembleRetriever(
|
| 167 |
-
retrievers=[bm25_retriever, vector_retriever],
|
| 168 |
-
weights=[0.4, 0.6]
|
| 169 |
-
)
|
| 170 |
-
else:
|
| 171 |
-
ensemble_retriever = vector_retriever
|
| 172 |
|
| 173 |
-
logging.info("--- Tải Reranker Model (BGE-M3) ---")
|
| 174 |
-
reranker_model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
|
| 175 |
-
compressor = CrossEncoderReranker(model=reranker_model, top_n=5)
|
| 176 |
-
|
| 177 |
final_retriever = ContextualCompressionRetriever(
|
| 178 |
base_compressor=compressor,
|
| 179 |
-
base_retriever=
|
| 180 |
)
|
| 181 |
|
| 182 |
return final_retriever
|
|
@@ -187,140 +185,28 @@ class DeepMedBot:
|
|
| 187 |
self.ready = False
|
| 188 |
|
| 189 |
if not GOOGLE_API_KEY:
|
| 190 |
-
logging.error("⚠️ Thiếu GOOGLE_API_KEY!
|
| 191 |
return
|
| 192 |
|
| 193 |
try:
|
| 194 |
-
self.
|
| 195 |
-
if not self.retriever:
|
| 196 |
-
logging.warning("⚠️ Chưa có dữ liệu để Retreive. Bot sẽ chỉ trả lời bằng kiến thức nền.")
|
| 197 |
-
|
| 198 |
-
self.llm = ChatGoogleGenerativeAI(
|
| 199 |
-
model="gemini-2.5-flash",
|
| 200 |
temperature=0.3,
|
| 201 |
google_api_key=GOOGLE_API_KEY
|
| 202 |
)
|
| 203 |
self._build_chains()
|
| 204 |
self.ready = True
|
| 205 |
-
logging.info("✅ Bot DeepMed đã sẵn sàng
|
| 206 |
except Exception as e:
|
| 207 |
logging.error(f"🔥 Lỗi khởi tạo bot: {e}")
|
| 208 |
logging.debug(traceback.format_exc())
|
| 209 |
|
| 210 |
def _build_chains(self):
|
| 211 |
context_system_prompt = (
|
| 212 |
-
"
|
| 213 |
-
"
|
| 214 |
-
"KHÔNG trả lời câu hỏi, chỉ viết lại nó."
|
| 215 |
)
|
| 216 |
context_prompt = ChatPromptTemplate.from_messages([
|
| 217 |
-
|
| 218 |
-
MessagesPlaceholder("chat_history"),
|
| 219 |
-
("human", "{input}"),
|
| 220 |
-
])
|
| 221 |
-
|
| 222 |
-
if self.retriever:
|
| 223 |
-
history_aware_retriever = create_history_aware_retriever(
|
| 224 |
-
self.llm, self.retriever, context_prompt
|
| 225 |
-
)
|
| 226 |
-
|
| 227 |
-
qa_system_prompt = (
|
| 228 |
-
"Bạn là 'DeepMed-AI' - Trợ lý Dược lâm sàng tại Trung Tâm Y Tế. "
|
| 229 |
-
"Sử dụng các thông tin được cung cấp trong phần Context dưới đây để trả lời câu hỏi về thuốc, bệnh học và y lệnh.\n"
|
| 230 |
-
"Nếu Context có dữ liệu từ Excel, hãy trình bày dạng bảng hoặc gạch đầu dòng rõ ràng.\n"
|
| 231 |
-
"Nếu không tìm thấy thông tin trong Context, hãy nói 'Tôi không tìm thấy thông tin trong dữ liệu nội bộ' và gợi ý dựa trên kiến thức y khoa chung của bạn.\n\n"
|
| 232 |
-
"Context:\n{context}"
|
| 233 |
-
)
|
| 234 |
-
|
| 235 |
-
qa_prompt = ChatPromptTemplate.from_messages([
|
| 236 |
-
("system", qa_system_prompt),
|
| 237 |
-
MessagesPlaceholder("chat_history"),
|
| 238 |
-
("human", "{input}"),
|
| 239 |
-
])
|
| 240 |
-
|
| 241 |
-
question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)
|
| 242 |
-
|
| 243 |
-
if self.retriever:
|
| 244 |
-
self.rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
|
| 245 |
-
else:
|
| 246 |
-
self.rag_chain = qa_prompt | self.llm
|
| 247 |
-
|
| 248 |
-
def chat_stream(self, message: str, history: list):
|
| 249 |
-
if not self.ready:
|
| 250 |
-
yield "Hệ thống đang khởi động hoặc gặp lỗi cấu hình."
|
| 251 |
-
return
|
| 252 |
-
|
| 253 |
-
chat_history = []
|
| 254 |
-
for u, b in history[-MAX_HISTORY_TURNS:]:
|
| 255 |
-
chat_history.append(HumanMessage(content=str(u)))
|
| 256 |
-
chat_history.append(AIMessage(content=str(b)))
|
| 257 |
-
|
| 258 |
-
full_response = ""
|
| 259 |
-
retrieved_docs = []
|
| 260 |
-
|
| 261 |
-
try:
|
| 262 |
-
stream_input = {"input": message, "chat_history": chat_history} if self.retriever else {"input": message, "chat_history": chat_history}
|
| 263 |
-
|
| 264 |
-
if self.rag_chain:
|
| 265 |
-
for chunk in self.rag_chain.stream(stream_input):
|
| 266 |
-
|
| 267 |
-
if isinstance(chunk, dict):
|
| 268 |
-
if "answer" in chunk:
|
| 269 |
-
full_response += chunk["answer"]
|
| 270 |
-
yield full_response
|
| 271 |
-
|
| 272 |
-
if "context" in chunk:
|
| 273 |
-
retrieved_docs = chunk["context"]
|
| 274 |
-
|
| 275 |
-
elif hasattr(chunk, 'content'):
|
| 276 |
-
full_response += chunk.content
|
| 277 |
-
yield full_response
|
| 278 |
-
|
| 279 |
-
elif isinstance(chunk, str):
|
| 280 |
-
full_response += chunk
|
| 281 |
-
yield full_response
|
| 282 |
-
|
| 283 |
-
if retrieved_docs:
|
| 284 |
-
refs = self._build_references_text(retrieved_docs)
|
| 285 |
-
if refs:
|
| 286 |
-
full_response += f"\n\n---\n📚 **Nguồn tham khảo:**\n{refs}"
|
| 287 |
-
yield full_response
|
| 288 |
-
|
| 289 |
-
except Exception as e:
|
| 290 |
-
logging.error(f"Lỗi khi chat: {e}")
|
| 291 |
-
logging.debug(traceback.format_exc())
|
| 292 |
-
yield f"Đã xảy ra lỗi: {str(e)}"
|
| 293 |
-
|
| 294 |
-
@staticmethod
|
| 295 |
-
def _build_references_text(docs) -> str:
|
| 296 |
-
lines = []
|
| 297 |
-
seen = set()
|
| 298 |
-
for doc in docs:
|
| 299 |
-
src = doc.metadata.get("source", "Tài liệu")
|
| 300 |
-
row_info = ""
|
| 301 |
-
if "row" in doc.metadata:
|
| 302 |
-
row_info = f"(Dòng {doc.metadata['row']})"
|
| 303 |
-
|
| 304 |
-
ref_str = f"- {src} {row_info}"
|
| 305 |
-
|
| 306 |
-
if ref_str not in seen:
|
| 307 |
-
lines.append(ref_str)
|
| 308 |
-
seen.add(ref_str)
|
| 309 |
-
return "\n".join(lines)
|
| 310 |
-
|
| 311 |
-
bot = DeepMedBot()
|
| 312 |
-
|
| 313 |
-
def gradio_chat_stream(message, history):
|
| 314 |
-
yield from bot.chat_stream(message, history)
|
| 315 |
-
|
| 316 |
-
css = """
|
| 317 |
-
.gradio-container {min_height: 600px !important;}
|
| 318 |
-
h1 {text-align: center; color: #2E86C1;}
|
| 319 |
-
"""
|
| 320 |
-
|
| 321 |
-
with gr.Blocks(css=css, title="DeepMed AI") as demo:
|
| 322 |
-
gr.Markdown("# 🏥 DeepMed AI - Trợ lý Lâm Sàng")
|
| 323 |
-
gr.Markdown("Hệ thống hỗ trợ lâm sàng tại Trung Tâm Y Tế Khu Vực Thanh Ba.")
|
| 324 |
|
| 325 |
chat_interface = gr.ChatInterface(
|
| 326 |
fn=gradio_chat_stream,
|
|
|
|
| 12 |
from chromadb.config import Settings
|
| 13 |
from shutil import rmtree
|
| 14 |
|
| 15 |
+
# --- CÁC THƯ VIỆN LANGCHAIN ---
|
| 16 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 17 |
from langchain_chroma import Chroma
|
| 18 |
from langchain_community.document_loaders import PyPDFLoader
|
|
|
|
| 26 |
from langchain_core.documents import Document
|
| 27 |
from langchain_huggingface import HuggingFaceEmbeddings
|
| 28 |
from langchain.retrievers import ContextualCompressionRetriever
|
|
|
|
|
|
|
| 29 |
|
| 30 |
+
# --- THƯ VIỆN TỐI ƯU TỐC ĐỘ (CACHE & RERANK) ---
|
| 31 |
+
from langchain.retrievers.document_compressors import FlashrankRerank
|
| 32 |
+
from langchain.globals import set_llm_cache
|
| 33 |
+
from langchain_community.cache import SQLiteCache
|
| 34 |
+
|
| 35 |
+
# --- CẤU HÌNH HỆ THỐNG ---
|
| 36 |
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
|
| 37 |
DATA_PATH = "medical_data"
|
| 38 |
DB_PATH = "chroma_db"
|
| 39 |
+
CACHE_DB_PATH = "llm_cache.db" # File lưu bộ nhớ đệm
|
| 40 |
MAX_HISTORY_TURNS = 6
|
| 41 |
FORCE_REBUILD_DB = False
|
| 42 |
|
| 43 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 44 |
|
| 45 |
+
# --- KÍCH HOẠT CACHING ---
|
| 46 |
+
# Hệ thống sẽ lưu câu trả lời vào file .db.
|
| 47 |
+
# Lần sau gặp câu hỏi y hệt, nó sẽ lấy từ đệm ra ngay lập tức.
|
| 48 |
+
if not os.path.exists(CACHE_DB_PATH):
|
| 49 |
+
logging.info("Khởi tạo file cache mới.")
|
| 50 |
+
set_llm_cache(SQLiteCache(database_path=CACHE_DB_PATH))
|
| 51 |
+
|
| 52 |
def process_excel_file(file_path: str, filename: str) -> list[Document]:
|
| 53 |
+
"""Xử lý Excel: Biến mỗi dòng thành một Document."""
|
|
|
|
|
|
|
|
|
|
| 54 |
docs = []
|
| 55 |
try:
|
| 56 |
if file_path.endswith(".csv"):
|
|
|
|
| 118 |
|
| 119 |
def get_retriever_chain():
|
| 120 |
logging.info("--- Tải Embedding Model ---")
|
| 121 |
+
# Chạy trên CPU để tiết kiệm resource, đổi 'cpu' thành 'cuda' nếu có GPU
|
| 122 |
+
embedding_model = HuggingFaceEmbeddings(
|
| 123 |
+
model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
|
| 124 |
+
model_kwargs={'device': 'cpu'}
|
| 125 |
+
)
|
| 126 |
|
| 127 |
vectorstore = None
|
|
|
|
|
|
|
| 128 |
chroma_settings = Settings(anonymized_telemetry=False)
|
| 129 |
|
| 130 |
if FORCE_REBUILD_DB and os.path.exists(DB_PATH):
|
|
|
|
| 131 |
rmtree(DB_PATH, ignore_errors=True)
|
| 132 |
|
| 133 |
+
# 1. TỐI ƯU: Kiểm tra nhanh DB bằng count() thay vì load toàn bộ
|
| 134 |
if os.path.exists(DB_PATH) and os.listdir(DB_PATH):
|
| 135 |
try:
|
| 136 |
vectorstore = Chroma(
|
| 137 |
persist_directory=DB_PATH,
|
| 138 |
embedding_function=embedding_model,
|
| 139 |
+
client_settings=chroma_settings
|
| 140 |
)
|
| 141 |
+
if vectorstore._collection.count() > 0:
|
| 142 |
+
logging.info(f"Đã kết nối DB cũ. Size: {vectorstore._collection.count()}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
else:
|
|
|
|
| 144 |
vectorstore = None
|
| 145 |
except Exception as e:
|
| 146 |
+
logging.error(f"DB lỗi: {e}. Reset DB...")
|
| 147 |
rmtree(DB_PATH, ignore_errors=True)
|
| 148 |
vectorstore = None
|
| 149 |
|
|
|
|
| 165 |
)
|
| 166 |
logging.info("Đã lưu VectorStore thành công.")
|
| 167 |
|
| 168 |
+
# 2. TỐI ƯU: Giảm k ban đầu xuống 6 để bớt tính toán
|
| 169 |
+
vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 6})
|
| 170 |
|
| 171 |
+
# 3. TỐI ƯU: Sử dụng FlashRank (Siêu nhẹ & Nhanh) thay vì CrossEncoder
|
| 172 |
+
logging.info("--- Tải Reranker Model (FlashRank) ---")
|
| 173 |
+
compressor = FlashrankRerank(model="ms-marco-MiniLM-L-12-v2") # Model ~40MB
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
final_retriever = ContextualCompressionRetriever(
|
| 176 |
base_compressor=compressor,
|
| 177 |
+
base_retriever=vector_retriever
|
| 178 |
)
|
| 179 |
|
| 180 |
return final_retriever
|
|
|
|
| 185 |
self.ready = False
|
| 186 |
|
| 187 |
if not GOOGLE_API_KEY:
|
| 188 |
+
logging.error("⚠️ Thiếu GOOGLE_API_KEY!")
|
| 189 |
return
|
| 190 |
|
| 191 |
try:
|
| 192 |
+
self.retr2.5-flash",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
temperature=0.3,
|
| 194 |
google_api_key=GOOGLE_API_KEY
|
| 195 |
)
|
| 196 |
self._build_chains()
|
| 197 |
self.ready = True
|
| 198 |
+
logging.info("✅ Bot DeepMed đã sẵn sàng!")
|
| 199 |
except Exception as e:
|
| 200 |
logging.error(f"🔥 Lỗi khởi tạo bot: {e}")
|
| 201 |
logging.debug(traceback.format_exc())
|
| 202 |
|
| 203 |
def _build_chains(self):
|
| 204 |
context_system_prompt = (
|
| 205 |
+
"Viết lại câu hỏi của người dùng thành câu đầy đủ ngữ cảnh. "
|
| 206 |
+
"KHÔNG trả lời, chỉ viết lại."
|
|
|
|
| 207 |
)
|
| 208 |
context_prompt = ChatPromptTemplate.from_messages([
|
| 209 |
+
Ba)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
chat_interface = gr.ChatInterface(
|
| 212 |
fn=gradio_chat_stream,
|