"""DeepMed: a Vietnamese medical RAG chatbot (Gradio UI + LangChain + Chroma)."""

import sys

# HACK: swap pysqlite3 in as the stdlib `sqlite3` module. ChromaDB needs a
# newer SQLite than some hosts (e.g. HF Spaces) ship; this must happen before
# anything else imports sqlite3. Fixed: made optional so the app still starts
# on machines where pysqlite3 is not installed and the system SQLite suffices.
try:
    __import__("pysqlite3")
    sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
except ImportError:
    pass

import os
import logging
import traceback
from shutil import rmtree

import gradio as gr
import pandas as pd
import docx2txt
import chromadb
from chromadb.config import Settings

# --- LangChain libraries ---
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers.ensemble import EnsembleRetriever
from langchain.chains import create_retrieval_chain, create_history_aware_retriever
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.retrievers import ContextualCompressionRetriever

# --- Speed optimisation libraries (cache & rerank) ---
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain.globals import set_llm_cache
from langchain_community.cache import SQLiteCache

# --- System configuration ---
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
DATA_PATH = "medical_data"      # folder scanned for source documents
DB_PATH = "chroma_db"           # persisted Chroma vector store directory
CACHE_DB_PATH = "llm_cache.db"  # SQLite file backing the LLM response cache
MAX_HISTORY_TURNS = 6           # most recent chat turns kept as context
FORCE_REBUILD_DB = False        # set True to wipe and re-index on startup

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

# --- Enable response caching ---
# Answers are stored in the .db file; the next time the exact same question
# is asked, the cached answer is returned immediately.
if not os.path.exists(CACHE_DB_PATH):
    logging.info("Khởi tạo file cache mới.")
set_llm_cache(SQLiteCache(database_path=CACHE_DB_PATH))


def process_excel_file(file_path: str, filename: str) -> list[Document]:
    """Turn each non-empty row of an Excel/CSV file into one Document.

    Args:
        file_path: path used to read the file (.csv goes through read_csv,
            everything else through read_excel).
        filename: display name stored in each Document's metadata.

    Returns:
        A list of Documents; empty on read/parse failure (error is logged).
    """
    docs: list[Document] = []
    try:
        if file_path.endswith(".csv"):
            df = pd.read_csv(file_path)
        else:
            df = pd.read_excel(file_path)
        df.dropna(how='all', inplace=True)
        df.fillna("Không có thông tin", inplace=True)
        for idx, row in df.iterrows():
            # One "Column: value" line per cell, skipping blanks/NaN remnants.
            content_parts = []
            for col_name, val in row.items():
                clean_val = str(val).strip()
                if clean_val and clean_val.lower() != "nan":
                    content_parts.append(f"{col_name}: {clean_val}")
            if content_parts:
                # NOTE(review): the source text literally read "(unknown)" here —
                # the {filename} placeholder was evidently lost in transit; restored.
                page_content = f"Dữ liệu từ file {filename} (Dòng {idx+1}):\n" + "\n".join(content_parts)
                metadata = {"source": filename, "row": idx + 1, "type": "excel_record"}
                docs.append(Document(page_content=page_content, metadata=metadata))
    except Exception as e:
        logging.error(f"Lỗi xử lý Excel {filename}: {e}")
    return docs


def load_documents_from_folder(folder_path: str) -> list[Document]:
    """Recursively load PDF, DOCX, Excel/CSV, and text files as Documents.

    Creates the folder (and returns []) when it does not exist; per-file
    errors are logged and skipped so one bad file cannot abort indexing.
    """
    logging.info(f"--- Bắt đầu quét thư mục: {folder_path} ---")
    documents: list[Document] = []
    if not os.path.exists(folder_path):
        os.makedirs(folder_path, exist_ok=True)
        return []
    for root, _, files in os.walk(folder_path):
        for filename in files:
            file_path = os.path.join(root, filename)
            filename_lower = filename.lower()
            try:
                if filename_lower.endswith(".pdf"):
                    loader = PyPDFLoader(file_path)
                    docs = loader.load()
                    # PyPDFLoader stores the full path; keep just the name.
                    for d in docs:
                        d.metadata["source"] = filename
                    documents.extend(docs)
                elif filename_lower.endswith(".docx"):
                    text = docx2txt.process(file_path)
                    if text.strip():
                        documents.append(Document(page_content=text, metadata={"source": filename}))
                elif filename_lower.endswith((".xlsx", ".xls", ".csv")):
                    excel_docs = process_excel_file(file_path, filename)
                    documents.extend(excel_docs)
                elif filename_lower.endswith((".txt", ".md")):
                    with open(file_path, "r", encoding="utf-8") as f:
                        text = f.read()
                    if text.strip():
                        documents.append(Document(page_content=text, metadata={"source": filename}))
            except Exception as e:
                # NOTE(review): "(unknown)" in the source — restored {filename}.
                logging.error(f"Lỗi đọc file {filename}: {e}")
    logging.info(f"Tổng cộng đã load: {len(documents)} tài liệu gốc.")
    return documents


def get_retriever_chain():
    """Build the retrieval pipeline: Chroma vector search + FlashRank rerank.

    Reuses the persisted Chroma DB when it exists and is non-empty; otherwise
    (re)indexes everything under DATA_PATH. Returns None when there is no
    data to index.
    """
    logging.info("--- Tải Embedding Model ---")
    # CPU to save resources; change 'cpu' to 'cuda' when a GPU is available.
    embedding_model = HuggingFaceEmbeddings(
        model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
        model_kwargs={'device': 'cpu'},
    )
    vectorstore = None
    chroma_settings = Settings(anonymized_telemetry=False)
    if FORCE_REBUILD_DB and os.path.exists(DB_PATH):
        rmtree(DB_PATH, ignore_errors=True)

    # 1. OPTIMISATION: quick DB sanity check via count() instead of a full load.
    if os.path.exists(DB_PATH) and os.listdir(DB_PATH):
        try:
            vectorstore = Chroma(
                persist_directory=DB_PATH,
                embedding_function=embedding_model,
                client_settings=chroma_settings,
            )
            existing = vectorstore._collection.count()  # count once, not twice
            if existing > 0:
                logging.info(f"Đã kết nối DB cũ. Size: {existing}")
            else:
                vectorstore = None
        except Exception as e:
            logging.error(f"DB lỗi: {e}. Reset DB...")
            rmtree(DB_PATH, ignore_errors=True)
            vectorstore = None

    if not vectorstore:
        logging.info("--- Tạo Index dữ liệu mới ---")
        raw_docs = load_documents_from_folder(DATA_PATH)
        if not raw_docs:
            logging.warning("Không có dữ liệu trong thư mục medical_data.")
            return None
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        splits = text_splitter.split_documents(raw_docs)
        vectorstore = Chroma.from_documents(
            documents=splits,
            embedding=embedding_model,
            persist_directory=DB_PATH,
            client_settings=chroma_settings,
        )
        logging.info("Đã lưu VectorStore thành công.")

    # 2. OPTIMISATION: keep initial k at 6 to reduce vector computation.
    vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 6})

    # 3. OPTIMISATION: FlashRank (very small & fast) instead of a CrossEncoder.
    logging.info("--- Tải Reranker Model (FlashRank) ---")
    compressor = FlashrankRerank(model="ms-marco-MiniLM-L-12-v2")  # ~40MB model
    final_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=vector_retriever
    )
    return final_retriever


class DeepMedBot:
    """Wires the retriever and Gemini LLM into a history-aware RAG chain.

    `ready` stays False when the API key is missing or any init step fails,
    so the UI can degrade gracefully instead of crashing.
    """

    def __init__(self):
        self.rag_chain = None
        self.ready = False
        if not GOOGLE_API_KEY:
            logging.error("⚠️ Thiếu GOOGLE_API_KEY!")
            return
        try:
            # NOTE(review): the source was corrupted here ("self.retr2.5-flash");
            # reconstructed as retriever setup + Gemini client, matching the
            # surviving fragment `..., temperature=0.3, google_api_key=...`.
            self.retriever = get_retriever_chain()
            self.llm = ChatGoogleGenerativeAI(
                model="gemini-2.5-flash",
                temperature=0.3,
                google_api_key=GOOGLE_API_KEY,
            )
            self._build_chains()
            self.ready = True
            logging.info("✅ Bot DeepMed đã sẵn sàng!")
        except Exception as e:
            logging.error(f"🔥 Lỗi khởi tạo bot: {e}")
            logging.debug(traceback.format_exc())

    def _build_chains(self):
        """Assemble the history-aware retrieval + answer-generation chain."""
        # Step 1: rewrite the user question into a self-contained query.
        context_system_prompt = (
            "Viết lại câu hỏi của người dùng thành câu đầy đủ ngữ cảnh. "
            "KHÔNG trả lời, chỉ viết lại."
        )
        # NOTE(review): everything below in this method was corrupted in the
        # source (it broke off at `from_messages([ Ba)")`); reconstructed with
        # the standard create_history_aware_retriever / create_retrieval_chain
        # pattern the imports imply. Verify prompts against the upstream file.
        context_prompt = ChatPromptTemplate.from_messages([
            ("system", context_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ])
        history_aware_retriever = create_history_aware_retriever(
            self.llm, self.retriever, context_prompt
        )

        # Step 2: answer strictly from the retrieved context.
        qa_system_prompt = (
            "Bạn là trợ lý y tế DeepMed. Hãy trả lời dựa trên ngữ cảnh được "
            "cung cấp. Nếu không có thông tin, hãy nói rõ là không biết.\n\n"
            "Ngữ cảnh:\n{context}"
        )
        qa_prompt = ChatPromptTemplate.from_messages([
            ("system", qa_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ])
        answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)
        self.rag_chain = create_retrieval_chain(history_aware_retriever, answer_chain)


# ---------------------------------------------------------------------------
# NOTE(review): from here to `demo.launch()` the source was corrupted; only
# `chat_interface = gr.ChatInterface(fn=gradio_chat_stream,)` and the
# `__main__` guard survived. Reconstructed along the standard Gradio pattern
# those fragments imply — verify against the upstream original.
# ---------------------------------------------------------------------------
bot = DeepMedBot()


def gradio_chat_stream(message, history):
    """Stream an answer for gr.ChatInterface.

    `history` is Gradio's list of (user, assistant) pairs; only the last
    MAX_HISTORY_TURNS turns are converted into LangChain messages.
    """
    if not bot.ready or bot.rag_chain is None:
        yield "Bot chưa sẵn sàng. Vui lòng kiểm tra GOOGLE_API_KEY và dữ liệu."
        return
    chat_history = []
    for user_msg, ai_msg in history[-MAX_HISTORY_TURNS:]:
        if user_msg:
            chat_history.append(HumanMessage(content=user_msg))
        if ai_msg:
            chat_history.append(AIMessage(content=ai_msg))
    partial = ""
    try:
        for chunk in bot.rag_chain.stream({"input": message, "chat_history": chat_history}):
            piece = chunk.get("answer", "")
            if piece:
                partial += piece
                yield partial
    except Exception as e:
        logging.error(f"Lỗi khi trả lời: {e}")
        yield partial or "Đã xảy ra lỗi, vui lòng thử lại."


with gr.Blocks(title="DeepMed AI") as demo:
    # NOTE(review): original heading text lost (fragment ended "...Ba)").
    gr.Markdown("# 🩺 DeepMed AI")
    chat_interface = gr.ChatInterface(
        fn=gradio_chat_stream,
    )

if __name__ == "__main__":
    demo.launch()