Spaces:
Sleeping
Sleeping
| __import__("pysqlite3") | |
| import sys | |
| sys.modules["sqlite3"] = sys.modules.pop("pysqlite3") | |
| import os | |
| import logging | |
| import traceback | |
| import gradio as gr | |
| import pandas as pd | |
| import docx2txt | |
| import chromadb | |
| from chromadb.config import Settings | |
| from shutil import rmtree | |
| # --- CÁC THƯ VIỆN LANGCHAIN --- | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_chroma import Chroma | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_community.retrievers import BM25Retriever | |
| from langchain.retrievers.ensemble import EnsembleRetriever | |
| from langchain.chains import create_retrieval_chain, create_history_aware_retriever | |
| from langchain.chains.combine_documents import create_stuff_documents_chain | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain_core.messages import HumanMessage, AIMessage | |
| from langchain_core.documents import Document | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain.retrievers import ContextualCompressionRetriever | |
| # --- THƯ VIỆN TỐI ƯU TỐC ĐỘ (CACHE & RERANK) --- | |
| from langchain.retrievers.document_compressors import FlashrankRerank | |
| from langchain.globals import set_llm_cache | |
| from langchain_community.cache import SQLiteCache | |
| # --- CẤU HÌNH HỆ THỐNG --- | |
| GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") | |
| DATA_PATH = "medical_data" | |
| DB_PATH = "chroma_db" | |
| CACHE_DB_PATH = "llm_cache.db" # File lưu bộ nhớ đệm | |
| MAX_HISTORY_TURNS = 6 | |
| FORCE_REBUILD_DB = False | |
| logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") | |
| # --- KÍCH HOẠT CACHING --- | |
| # Hệ thống sẽ lưu câu trả lời vào file .db. | |
| # Lần sau gặp câu hỏi y hệt, nó sẽ lấy từ đệm ra ngay lập tức. | |
| if not os.path.exists(CACHE_DB_PATH): | |
| logging.info("Khởi tạo file cache mới.") | |
| set_llm_cache(SQLiteCache(database_path=CACHE_DB_PATH)) | |
| def process_excel_file(file_path: str, filename: str) -> list[Document]: | |
| """Xử lý Excel: Biến mỗi dòng thành một Document.""" | |
| docs = [] | |
| try: | |
| if file_path.endswith(".csv"): | |
| df = pd.read_csv(file_path) | |
| else: | |
| df = pd.read_excel(file_path) | |
| df.dropna(how='all', inplace=True) | |
| df.fillna("Không có thông tin", inplace=True) | |
| for idx, row in df.iterrows(): | |
| content_parts = [] | |
| for col_name, val in row.items(): | |
| clean_val = str(val).strip() | |
| if clean_val and clean_val.lower() != "nan": | |
| content_parts.append(f"{col_name}: {clean_val}") | |
| if content_parts: | |
| page_content = f"Dữ liệu từ file {filename} (Dòng {idx+1}):\n" + "\n".join(content_parts) | |
| metadata = {"source": filename, "row": idx+1, "type": "excel_record"} | |
| docs.append(Document(page_content=page_content, metadata=metadata)) | |
| except Exception as e: | |
| logging.error(f"Lỗi xử lý Excel {filename}: {e}") | |
| return docs | |
| def load_documents_from_folder(folder_path: str) -> list[Document]: | |
| logging.info(f"--- Bắt đầu quét thư mục: {folder_path} ---") | |
| documents: list[Document] = [] | |
| if not os.path.exists(folder_path): | |
| os.makedirs(folder_path, exist_ok=True) | |
| return [] | |
| for root, _, files in os.walk(folder_path): | |
| for filename in files: | |
| file_path = os.path.join(root, filename) | |
| filename_lower = filename.lower() | |
| try: | |
| if filename_lower.endswith(".pdf"): | |
| loader = PyPDFLoader(file_path) | |
| docs = loader.load() | |
| for d in docs: d.metadata["source"] = filename | |
| documents.extend(docs) | |
| elif filename_lower.endswith(".docx"): | |
| text = docx2txt.process(file_path) | |
| if text.strip(): | |
| documents.append(Document(page_content=text, metadata={"source": filename})) | |
| elif filename_lower.endswith((".xlsx", ".xls", ".csv")): | |
| excel_docs = process_excel_file(file_path, filename) | |
| documents.extend(excel_docs) | |
| elif filename_lower.endswith((".txt", ".md")): | |
| with open(file_path, "r", encoding="utf-8") as f: text = f.read() | |
| if text.strip(): | |
| documents.append(Document(page_content=text, metadata={"source": filename})) | |
| except Exception as e: | |
| logging.error(f"Lỗi đọc file {filename}: {e}") | |
| logging.info(f"Tổng cộng đã load: {len(documents)} tài liệu gốc.") | |
| return documents | |
| def get_retriever_chain(): | |
| logging.info("--- Tải Embedding Model ---") | |
| # Chạy trên CPU để tiết kiệm resource, đổi 'cpu' thành 'cuda' nếu có GPU | |
| embedding_model = HuggingFaceEmbeddings( | |
| model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", | |
| model_kwargs={'device': 'cpu'} | |
| ) | |
| vectorstore = None | |
| chroma_settings = Settings(anonymized_telemetry=False) | |
| if FORCE_REBUILD_DB and os.path.exists(DB_PATH): | |
| rmtree(DB_PATH, ignore_errors=True) | |
| # 1. TỐI ƯU: Kiểm tra nhanh DB bằng count() thay vì load toàn bộ | |
| if os.path.exists(DB_PATH) and os.listdir(DB_PATH): | |
| try: | |
| vectorstore = Chroma( | |
| persist_directory=DB_PATH, | |
| embedding_function=embedding_model, | |
| client_settings=chroma_settings | |
| ) | |
| if vectorstore._collection.count() > 0: | |
| logging.info(f"Đã kết nối DB cũ. Size: {vectorstore._collection.count()}") | |
| else: | |
| vectorstore = None | |
| except Exception as e: | |
| logging.error(f"DB lỗi: {e}. Reset DB...") | |
| rmtree(DB_PATH, ignore_errors=True) | |
| vectorstore = None | |
| if not vectorstore: | |
| logging.info("--- Tạo Index dữ liệu mới ---") | |
| raw_docs = load_documents_from_folder(DATA_PATH) | |
| if not raw_docs: | |
| logging.warning("Không có dữ liệu trong thư mục medical_data.") | |
| return None | |
| text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) | |
| splits = text_splitter.split_documents(raw_docs) | |
| vectorstore = Chroma.from_documents( | |
| documents=splits, | |
| embedding=embedding_model, | |
| persist_directory=DB_PATH, | |
| client_settings=chroma_settings | |
| ) | |
| logging.info("Đã lưu VectorStore thành công.") | |
| # 2. TỐI ƯU: Giảm k ban đầu xuống 6 để bớt tính toán | |
| vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 6}) | |
| # 3. TỐI ƯU: Sử dụng FlashRank (Siêu nhẹ & Nhanh) thay vì CrossEncoder | |
| logging.info("--- Tải Reranker Model (FlashRank) ---") | |
| compressor = FlashrankRerank(model="ms-marco-MiniLM-L-12-v2") # Model ~40MB | |
| final_retriever = ContextualCompressionRetriever( | |
| base_compressor=compressor, | |
| base_retriever=vector_retriever | |
| ) | |
| return final_retriever | |
| class DeepMedBot: | |
| def __init__(self): | |
| self.rag_chain = None | |
| self.ready = False | |
| if not GOOGLE_API_KEY: | |
| logging.error("⚠️ Thiếu GOOGLE_API_KEY!") | |
| return | |
| try: | |
| self.retr2.5-flash", | |
| temperature=0.3, | |
| google_api_key=GOOGLE_API_KEY | |
| ) | |
| self._build_chains() | |
| self.ready = True | |
| logging.info("✅ Bot DeepMed đã sẵn sàng!") | |
| except Exception as e: | |
| logging.error(f"🔥 Lỗi khởi tạo bot: {e}") | |
| logging.debug(traceback.format_exc()) | |
| def _build_chains(self): | |
| context_system_prompt = ( | |
| "Viết lại câu hỏi của người dùng thành câu đầy đủ ngữ cảnh. " | |
| "KHÔNG trả lời, chỉ viết lại." | |
| ) | |
| context_prompt = ChatPromptTemplate.from_messages([ | |
| Ba)") | |
| chat_interface = gr.ChatInterface( | |
| fn=gradio_chat_stream, | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() |