Spaces:
Running
Running
| # ===================================================== | |
| # 1. FIX sqlite3 CHO CHROMA (pysqlite3 hack) | |
| # ===================================================== | |
| __import__("pysqlite3") | |
| import sys | |
| sys.modules["sqlite3"] = sys.modules.pop("pysqlite3") | |
| # ===================================================== | |
| # 2. IMPORTS | |
| # ===================================================== | |
| import os | |
| import logging | |
| import traceback | |
| import gradio as gr | |
| import pandas as pd | |
| import docx2txt | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_chroma import Chroma | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_community.retrievers import BM25Retriever | |
| from langchain.retrievers.ensemble import EnsembleRetriever | |
| from langchain.chains import create_retrieval_chain, create_history_aware_retriever | |
| from langchain.chains.combine_documents import create_stuff_documents_chain | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain_core.messages import HumanMessage, AIMessage | |
| from langchain_core.documents import Document | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| # ===================================================== | |
| # 3. CẤU HÌNH CHUNG | |
| # ===================================================== | |
| # Lấy API key từ biến môi trường | |
| GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") | |
| DATA_PATH = "medical_data" # thư mục chứa tài liệu | |
| DB_PATH = "chroma_db" # thư mục chứa database Chroma | |
| # Số lượt đối thoại gần nhất gửi vào LLM (mỗi lượt gồm 1 user + 1 bot) | |
| MAX_HISTORY_TURNS = 6 | |
| # Chọn chiến lược truy vấn: | |
| USE_BM25 = True # True = dùng hybrid (BM25 + Vector), False = chỉ dùng Vector | |
| USE_MMR = True # True = dùng MMR cho retriever vector | |
| # Logging cơ bản | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format="%(asctime)s [%(levelname)s] %(message)s", | |
| ) | |
| # ===================================================== | |
| # 4. HÀM LOAD TÀI LIỆU ĐA ĐỊNH DẠNG (PDF, DOCX, EXCEL, CSV, TXT, MD) | |
| # ===================================================== | |
| def load_documents_from_folder(folder_path: str) -> list[Document]: | |
| """ | |
| Quét thư mục (kể cả subfolder), đọc các file hỗ trợ và trả về list Document. | |
| """ | |
| logging.info(f"--- Bắt đầu quét thư mục: {folder_path} ---") | |
| documents: list[Document] = [] | |
| if not os.path.exists(folder_path): | |
| os.makedirs(folder_path, exist_ok=True) | |
| logging.warning(f"Thư mục {folder_path} chưa tồn tại. Đã tạo mới, tạm thời chưa có tài liệu.") | |
| return [] | |
| for root, _, files in os.walk(folder_path): | |
| for filename in files: | |
| file_path = os.path.join(root, filename) | |
| filename_lower = filename.lower() | |
| try: | |
| # 1. PDF | |
| if filename_lower.endswith(".pdf"): | |
| logging.info(f"-> Đang xử lý PDF: {file_path}") | |
| loader = PyPDFLoader(file_path) | |
| docs = loader.load() # mỗi trang = 1 Document | |
| # Thêm metadata source gọn (chỉ tên file) | |
| for d in docs: | |
| d.metadata["source"] = filename | |
| documents.extend(docs) | |
| # 2. DOCX | |
| elif filename_lower.endswith(".docx"): | |
| logging.info(f"-> Đang xử lý Word: {file_path}") | |
| text = docx2txt.process(file_path) | |
| if text and text.strip(): | |
| documents.append( | |
| Document( | |
| page_content=text, | |
| metadata={"source": filename} | |
| ) | |
| ) | |
| else: | |
| logging.warning(f"File Word rỗng: {filename}") | |
| # 3. EXCEL (XLS, XLSX) | |
| elif filename_lower.endswith((".xlsx", ".xls")): | |
| logging.info(f"-> Đang xử lý Excel: {file_path}") | |
| try: | |
| df = pd.read_excel(file_path) | |
| text_data = "" | |
| for _, row in df.iterrows(): | |
| row_str = " | ".join( | |
| f"{col}: {val}" | |
| for col, val in row.items() | |
| if pd.notna(val) | |
| ) | |
| if row_str: | |
| text_data += row_str + "\n" | |
| if text_data.strip(): | |
| documents.append( | |
| Document( | |
| page_content=text_data, | |
| metadata={"source": filename} | |
| ) | |
| ) | |
| else: | |
| logging.warning(f"File Excel rỗng: {filename}") | |
| except Exception as e: | |
| logging.error(f"Lỗi đọc Excel {filename}: {e}") | |
| # 4. CSV | |
| elif filename_lower.endswith(".csv"): | |
| logging.info(f"-> Đang xử lý CSV: {file_path}") | |
| try: | |
| df = pd.read_csv(file_path) | |
| text_data = "" | |
| for _, row in df.iterrows(): | |
| row_str = " | ".join( | |
| f"{col}: {val}" | |
| for col, val in row.items() | |
| if pd.notna(val) | |
| ) | |
| if row_str: | |
| text_data += row_str + "\n" | |
| if text_data.strip(): | |
| documents.append( | |
| Document( | |
| page_content=text_data, | |
| metadata={"source": filename} | |
| ) | |
| ) | |
| else: | |
| logging.warning(f"File CSV rỗng: {filename}") | |
| except Exception as e: | |
| logging.error(f"Lỗi đọc CSV {filename}: {e}") | |
| # 5. TEXT / MARKDOWN | |
| elif filename_lower.endswith((".txt", ".md")): | |
| logging.info(f"-> Đang xử lý Text/Markdown: {file_path}") | |
| text = "" | |
| try: | |
| with open(file_path, "r", encoding="utf-8") as f: | |
| text = f.read() | |
| except UnicodeDecodeError: | |
| logging.warning(f"Encoding UTF-8 thất bại, thử Latin-1 cho {filename}") | |
| with open(file_path, "r", encoding="latin-1") as f: | |
| text = f.read() | |
| if text and text.strip(): | |
| documents.append( | |
| Document( | |
| page_content=text, | |
| metadata={"source": filename} | |
| ) | |
| ) | |
| else: | |
| logging.info(f"-> Bỏ qua file không hỗ trợ: {file_path}") | |
| except Exception as e: | |
| logging.error(f"❌ LỖI khi đọc file {filename}: {e}") | |
| logging.debug(traceback.format_exc()) | |
| logging.info(f"--- Hoàn tất load tài liệu. Tổng số Document: {len(documents)} ---") | |
| return documents | |
| # ===================================================== | |
| # 5. XÂY DỰNG VECTORSTORE + RETRIEVER | |
| # ===================================================== | |
| def build_vectorstore_and_corpus(embedding_model): | |
| """ | |
| - Nếu đã có DB Chroma: load lên. | |
| - Nếu chưa: đọc folder, split, tạo DB mới. | |
| Trả về (vectorstore, splits). | |
| """ | |
| from shutil import rmtree | |
| splits: list[Document] = [] | |
| vectorstore = None | |
| # TH1: có DB cũ | |
| if os.path.exists(DB_PATH) and os.listdir(DB_PATH): | |
| try: | |
| logging.info("--- Tìm thấy ChromaDB cũ, đang load... ---") | |
| vectorstore = Chroma( | |
| persist_directory=DB_PATH, | |
| embedding_function=embedding_model | |
| ) | |
| existing = vectorstore.get() | |
| if existing.get("documents"): | |
| for text, meta in zip(existing["documents"], existing["metadatas"]): | |
| splits.append(Document(page_content=text, metadata=meta)) | |
| logging.info(f"Tải lại corpus từ DB, tổng số chunk: {len(splits)}") | |
| else: | |
| logging.warning("ChromaDB không chứa documents. Sẽ rebuild.") | |
| splits = [] | |
| except Exception as e: | |
| logging.error(f"Lỗi đọc DB cũ: {e}. Xóa và rebuild lại.") | |
| logging.debug(traceback.format_exc()) | |
| rmtree(DB_PATH, ignore_errors=True) | |
| splits = [] | |
| vectorstore = None | |
| # TH2: chưa có dữ liệu trong splits → đọc file gốc, split và tạo DB mới | |
| if not splits: | |
| logging.info("--- Đang đọc tài liệu gốc để tạo index mới... ---") | |
| documents = load_documents_from_folder(DATA_PATH) | |
| if not documents: | |
| logging.error("Không tìm thấy tài liệu nào trong thư mục medical_data.") | |
| return None, [] | |
| # Chunking | |
| text_splitter = RecursiveCharacterTextSplitter( | |
| chunk_size=800, # nhỏ hơn cho câu trả lời cụ thể hơn | |
| chunk_overlap=150 | |
| ) | |
| splits = text_splitter.split_documents(documents) | |
| logging.info(f"Đã split thành {len(splits)} chunks.") | |
| # Tạo Chroma mới | |
| logging.info("--- Đang mã hoá embedding vào ChromaDB... ---") | |
| vectorstore = Chroma.from_documents( | |
| documents=splits, | |
| embedding=embedding_model, | |
| persist_directory=DB_PATH | |
| ) | |
| return vectorstore, splits | |
| def get_retriever(): | |
| """ | |
| Khởi tạo retriever: | |
| - Load hoặc tạo ChromaDB. | |
| - Tuỳ cấu hình: dùng hybrid (BM25 + Vector) hoặc chỉ Vector. | |
| - Có thể dùng MMR để tăng đa dạng context. | |
| """ | |
| logging.info("--- Đang tải model Embedding (HuggingFace) ---") | |
| embedding_model = HuggingFaceEmbeddings( | |
| model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" | |
| ) | |
| vectorstore, splits = build_vectorstore_and_corpus(embedding_model) | |
| if vectorstore is None or not splits: | |
| logging.error("Không thể khởi tạo retriever vì không có dữ liệu.") | |
| return None | |
| # Vector retriever (có thể dùng MMR) | |
| if USE_MMR: | |
| logging.info("Sử dụng Vector Retriever với MMR.") | |
| chroma_retriever = vectorstore.as_retriever( | |
| search_type="mmr", | |
| search_kwargs={"k": 8, "lambda_mult": 0.7} | |
| ) | |
| else: | |
| logging.info("Sử dụng Vector Retriever với similarity search.") | |
| chroma_retriever = vectorstore.as_retriever( | |
| search_kwargs={"k": 8} | |
| ) | |
| # Nếu không muốn dùng BM25 -> trả về luôn retriever vector | |
| if not USE_BM25: | |
| logging.info("Chỉ dùng retriever Vector (không dùng BM25).") | |
| return chroma_retriever | |
| # Tạo BM25 retriever từ corpus | |
| logging.info("Khởi tạo BM25 Retriever từ corpus (hybrid search).") | |
| bm25_retriever = BM25Retriever.from_documents(splits) | |
| bm25_retriever.k = 8 | |
| # Ensemble retriever: BM25 + Vector | |
| ensemble_retriever = EnsembleRetriever( | |
| retrievers=[bm25_retriever, chroma_retriever], | |
| weights=[0.4, 0.6] | |
| ) | |
| logging.info("Đã khởi tạo Ensemble Retriever (BM25 + Vector).") | |
| return ensemble_retriever | |
| # ===================================================== | |
| # 6. LỚP CHATBOT DEEPMED | |
| # ===================================================== | |
| class DeepMedBot: | |
| def __init__(self): | |
| self.retriever = None | |
| self.rag_chain = None | |
| self.ready = False | |
| if not GOOGLE_API_KEY: | |
| logging.error("GOOGLE_API_KEY chưa được thiết lập trong biến môi trường.") | |
| return | |
| try: | |
| self.retriever = get_retriever() | |
| if self.retriever is None: | |
| logging.error("Không khởi tạo được retriever.") | |
| return | |
| self.llm = ChatGoogleGenerativeAI( | |
| model="gemini-2.5-flash", # có thể đổi sang model mạnh hơn nếu muốn | |
| temperature=0.3, | |
| google_api_key=GOOGLE_API_KEY | |
| ) | |
| self._build_chains() | |
| self.ready = True | |
| logging.info("--- DeepMedBot đã sẵn sàng ---") | |
| except Exception as e: | |
| logging.error(f"Lỗi khởi tạo DeepMedBot: {e}") | |
| logging.debug(traceback.format_exc()) | |
| self.ready = False | |
| def _build_chains(self): | |
| # 1. Prompt contextualize question | |
| contextualize_q_system_prompt = ( | |
| "Dựa trên lịch sử trò chuyện và câu hỏi mới nhất của người dùng, " | |
| "nếu câu hỏi liên quan đến ngữ cảnh trước đó, hãy viết lại nó thành một câu hỏi độc lập đầy đủ ý nghĩa. " | |
| "Nếu không liên quan, giữ nguyên câu hỏi gốc. KHÔNG trả lời câu hỏi, chỉ viết lại." | |
| ) | |
| contextualize_q_prompt = ChatPromptTemplate.from_messages([ | |
| ("system", contextualize_q_system_prompt), | |
| MessagesPlaceholder("chat_history"), | |
| ("human", "{input}"), | |
| ]) | |
| history_aware_retriever = create_history_aware_retriever( | |
| self.llm, self.retriever, contextualize_q_prompt | |
| ) | |
| # 2. Prompt trả lời y khoa an toàn | |
| qa_system_prompt = ( | |
| "Bạn là trợ lý y tế DeepMed. Bạn hỗ trợ giải thích thông tin y khoa dựa trên tài liệu được cung cấp (Context) " | |
| "và kiến thức y khoa chung của bạn.\n\n" | |
| "QUY TẮC AN TOÀN:\n" | |
| "- Nếu Context không đủ để trả lời chính xác, hãy nói bạn không chắc chắn và khuyên người dùng hỏi bác sĩ.\n" | |
| "- Ưu tiên sử dụng thông tin trong Context. Nếu có mâu thuẫn giữa Context và hiểu biết của bạn, " | |
| "hãy nói rõ và khuyên tham khảo nguồn chính thống mới nhất.\n\n" | |
| "CÁCH TRẢ LỜI:\n" | |
| "- Trả lời ngắn gọn, dễ hiểu.\n" | |
| "- Nếu phù hợp, chia câu trả lời thành các mục: Tóm tắt, Chi tiết, Lưu ý.\n\n" | |
| "Context:\n{context}" | |
| ) | |
| qa_prompt = ChatPromptTemplate.from_messages([ | |
| ("system", qa_system_prompt), | |
| MessagesPlaceholder("chat_history"), | |
| ("human", "{input}"), | |
| ]) | |
| question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt) | |
| # 3. Gộp thành RAG chain | |
| self.rag_chain = create_retrieval_chain( | |
| history_aware_retriever, | |
| question_answer_chain | |
| ) | |
| def chat(self, message: str, history: list[list[str]]): | |
| """ | |
| Hàm dùng cho Gradio: | |
| - history: list[[user_msg, bot_msg], ...] | |
| """ | |
| if not self.ready: | |
| return "❗ Hệ thống chưa sẵn sàng. Vui lòng kiểm tra lại API key và dữ liệu (medical_data)." | |
| # Giới hạn lịch sử gửi vào LLM cho đỡ nặng | |
| if len(history) > MAX_HISTORY_TURNS: | |
| history = history[-MAX_HISTORY_TURNS:] | |
| # Chuyển lịch sử Gradio sang LangChain messages | |
| chat_history = [] | |
| for user_msg, bot_msg in history: | |
| chat_history.append(HumanMessage(content=user_msg)) | |
| chat_history.append(AIMessage(content=bot_msg)) | |
| try: | |
| response = self.rag_chain.invoke({ | |
| "input": message, | |
| "chat_history": chat_history | |
| }) | |
| answer = response.get("answer", "") | |
| # Xử lý trích dẫn nguồn | |
| references_text = self._build_references_text(response) | |
| if references_text: | |
| answer += "\n\n---\n📚 **Tài liệu tham khảo:**\n" + references_text | |
| return answer | |
| except Exception as e: | |
| logging.error(f"Lỗi khi xử lý chat: {e}") | |
| logging.debug(traceback.format_exc()) | |
| return "🤖 Xin lỗi, hệ thống gặp lỗi nội bộ. Bạn hãy thử lại sau ít phút." | |
| def _build_references_text(response) -> str: | |
| """ | |
| Gom nguồn trích dẫn: | |
| - Gộp theo tên file. | |
| - Nếu có số trang, liệt kê danh sách trang. | |
| """ | |
| from collections import defaultdict | |
| if "context" not in response: | |
| return "" | |
| source_pages = defaultdict(set) # {source_name: {page1, page2, ...}} | |
| for doc in response["context"]: | |
| src = os.path.basename(doc.metadata.get("source", "Tài liệu không tên")) | |
| page = doc.metadata.get("page", None) | |
| if page is not None: | |
| source_pages[src].add(page + 1) | |
| else: | |
| # Đảm bảo vẫn tạo key nếu chưa có trang | |
| _ = source_pages[src] | |
| lines = [] | |
| for src, pages in source_pages.items(): | |
| if pages: | |
| pages_str = ", ".join(str(p) for p in sorted(pages)) | |
| lines.append(f"- {src} (Trang {pages_str})") | |
| else: | |
| lines.append(f"- {src}") | |
| return "\n".join(lines) | |
| bot = DeepMedBot() | |
| def gradio_chat(message, history): | |
| return bot.chat(message, history) | |
| demo = gr.ChatInterface( | |
| fn=gradio_chat, | |
| title="🏥 DeepMed AI Pro - Trợ lý Y tế tại Trung Tâm Y Tế khu vực Thanh Ba", | |
| description=( | |
| "Chatbot hỗ trợ tra cứu thông tin y khoa từ kho tài liệu nội bộ.\n" | |
| ), | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() |