# chatbot / app.py
# (Hugging Face Space page chrome preserved as a comment so the file parses:
#  PBThuong — "Update app.py" — commit 0506d3c verified — 8.46 kB)
# Replace the stdlib "sqlite3" module with pysqlite3 BEFORE anything imports it.
# Presumably needed because chromadb (imported below) requires a newer SQLite
# than the host image ships — NOTE(review): confirm pysqlite3-binary is listed
# in the Space's requirements, otherwise this raises ImportError at startup.
__import__("pysqlite3")
import sys
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
import os
import logging
import traceback
import gradio as gr
import pandas as pd
import docx2txt
import chromadb
from chromadb.config import Settings
from shutil import rmtree
# --- CÁC THƯ VIỆN LANGCHAIN ---
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers.ensemble import EnsembleRetriever
from langchain.chains import create_retrieval_chain, create_history_aware_retriever
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.retrievers import ContextualCompressionRetriever
# --- THƯ VIỆN TỐI ƯU TỐC ĐỘ (CACHE & RERANK) ---
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
# --- SYSTEM CONFIGURATION ---
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")  # Gemini API key; the bot refuses to start without it
DATA_PATH = "medical_data"      # folder scanned for source documents to index
DB_PATH = "chroma_db"           # directory holding the persisted Chroma vector store
CACHE_DB_PATH = "llm_cache.db"  # SQLite file backing the LLM response cache
MAX_HISTORY_TURNS = 6           # NOTE(review): not referenced in the visible code — confirm it is used elsewhere
FORCE_REBUILD_DB = False        # set True to wipe and re-index the vector store on startup
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
# --- ENABLE LLM CACHING ---
# Answers are persisted to a SQLite file; an identical prompt seen again later
# is served straight from the cache instead of calling the model.
if not os.path.exists(CACHE_DB_PATH):
    logging.info("Khởi tạo file cache mới.")
set_llm_cache(SQLiteCache(database_path=CACHE_DB_PATH))
def process_excel_file(file_path: str, filename: str) -> list[Document]:
    """Load an Excel/CSV file and turn each non-empty row into one Document.

    Args:
        file_path: Path on disk to the .xlsx/.xls/.csv file.
        filename: Display name recorded in each Document's metadata["source"].

    Returns:
        One Document per non-empty row; an empty list on any read/parse error
        (the error is logged, never raised — callers treat it as best-effort).
    """
    docs: list[Document] = []
    try:
        if file_path.endswith(".csv"):
            df = pd.read_csv(file_path)
        else:
            df = pd.read_excel(file_path)
        df.dropna(how='all', inplace=True)             # drop fully empty rows
        df.fillna("Không có thông tin", inplace=True)  # placeholder for missing cells
        for idx, row in df.iterrows():
            content_parts = []
            for col_name, val in row.items():
                clean_val = str(val).strip()
                # Skip blanks and the literal "nan" string left over from fillna/astype.
                if clean_val and clean_val.lower() != "nan":
                    content_parts.append(f"{col_name}: {clean_val}")
            if content_parts:
                # BUGFIX: the message previously contained the literal placeholder
                # "(unknown)" instead of interpolating the actual file name.
                page_content = f"Dữ liệu từ file {filename} (Dòng {idx+1}):\n" + "\n".join(content_parts)
                metadata = {"source": filename, "row": idx + 1, "type": "excel_record"}
                docs.append(Document(page_content=page_content, metadata=metadata))
    except Exception as e:
        # BUGFIX: same "(unknown)" placeholder restored to {filename}.
        logging.error(f"Lỗi xử lý Excel {filename}: {e}")
    return docs
def load_documents_from_folder(folder_path: str) -> list[Document]:
    """Recursively load every supported document under *folder_path*.

    Supported formats: .pdf (per-page via PyPDFLoader), .docx (docx2txt),
    .xlsx/.xls/.csv (one Document per row), .txt/.md (whole file).
    Each Document gets metadata["source"] = the bare file name.

    Returns:
        All loaded Documents; an empty list when the folder does not exist
        (it is created so the next run can be populated). Per-file errors
        are logged and skipped — one bad file never aborts the scan.
    """
    logging.info(f"--- Bắt đầu quét thư mục: {folder_path} ---")
    documents: list[Document] = []
    if not os.path.exists(folder_path):
        os.makedirs(folder_path, exist_ok=True)
        return []
    for root, _, files in os.walk(folder_path):
        for filename in files:
            file_path = os.path.join(root, filename)
            filename_lower = filename.lower()
            try:
                if filename_lower.endswith(".pdf"):
                    loader = PyPDFLoader(file_path)
                    docs = loader.load()
                    # Normalize source to the bare file name (loader stores the full path).
                    for d in docs:
                        d.metadata["source"] = filename
                    documents.extend(docs)
                elif filename_lower.endswith(".docx"):
                    text = docx2txt.process(file_path)
                    if text.strip():
                        documents.append(Document(page_content=text, metadata={"source": filename}))
                elif filename_lower.endswith((".xlsx", ".xls", ".csv")):
                    excel_docs = process_excel_file(file_path, filename)
                    documents.extend(excel_docs)
                elif filename_lower.endswith((".txt", ".md")):
                    with open(file_path, "r", encoding="utf-8") as f:
                        text = f.read()
                    if text.strip():
                        documents.append(Document(page_content=text, metadata={"source": filename}))
            except Exception as e:
                # BUGFIX: the log previously printed the literal placeholder
                # "(unknown)" instead of interpolating the offending file name.
                logging.error(f"Lỗi đọc file {filename}: {e}")
    logging.info(f"Tổng cộng đã load: {len(documents)} tài liệu gốc.")
    return documents
def get_retriever_chain():
    """Build the retrieval pipeline: embeddings -> Chroma store -> FlashRank rerank.

    Reuses the persisted Chroma DB when it exists and is non-empty; otherwise
    re-indexes every document found under DATA_PATH. A corrupt DB is wiped and
    rebuilt from scratch.

    Returns:
        A ContextualCompressionRetriever, or None when DATA_PATH holds no data.
    """
    logging.info("--- Tải Embedding Model ---")
    # Runs on CPU to save resources; change 'cpu' to 'cuda' if a GPU is available.
    embedding_model = HuggingFaceEmbeddings(
        model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
        model_kwargs={'device': 'cpu'}
    )
    vectorstore = None
    chroma_settings = Settings(anonymized_telemetry=False)
    if FORCE_REBUILD_DB and os.path.exists(DB_PATH):
        rmtree(DB_PATH, ignore_errors=True)
    # 1. Fast DB check via count() instead of loading every record.
    if os.path.exists(DB_PATH) and os.listdir(DB_PATH):
        try:
            vectorstore = Chroma(
                persist_directory=DB_PATH,
                embedding_function=embedding_model,
                client_settings=chroma_settings
            )
            # PERF: count() was previously called twice back-to-back — hoisted.
            db_size = vectorstore._collection.count()
            if db_size > 0:
                logging.info(f"Đã kết nối DB cũ. Size: {db_size}")
            else:
                vectorstore = None  # empty store: fall through to a fresh index
        except Exception as e:
            logging.error(f"DB lỗi: {e}. Reset DB...")
            rmtree(DB_PATH, ignore_errors=True)
            vectorstore = None
    if not vectorstore:
        logging.info("--- Tạo Index dữ liệu mới ---")
        raw_docs = load_documents_from_folder(DATA_PATH)
        if not raw_docs:
            logging.warning("Không có dữ liệu trong thư mục medical_data.")
            return None
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        splits = text_splitter.split_documents(raw_docs)
        vectorstore = Chroma.from_documents(
            documents=splits,
            embedding=embedding_model,
            persist_directory=DB_PATH,
            client_settings=chroma_settings
        )
        logging.info("Đã lưu VectorStore thành công.")
    # 2. Keep the initial candidate set small (k=6) to limit rerank work.
    vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 6})
    # 3. FlashRank reranker (~40 MB model) — lighter/faster than a CrossEncoder.
    logging.info("--- Tải Reranker Model (FlashRank) ---")
    compressor = FlashrankRerank(model="ms-marco-MiniLM-L-12-v2")
    final_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=vector_retriever
    )
    return final_retriever
class DeepMedBot:
def __init__(self):
self.rag_chain = None
self.ready = False
if not GOOGLE_API_KEY:
logging.error("⚠️ Thiếu GOOGLE_API_KEY!")
return
try:
self.retr2.5-flash",
temperature=0.3,
google_api_key=GOOGLE_API_KEY
)
self._build_chains()
self.ready = True
logging.info("✅ Bot DeepMed đã sẵn sàng!")
except Exception as e:
logging.error(f"🔥 Lỗi khởi tạo bot: {e}")
logging.debug(traceback.format_exc())
def _build_chains(self):
context_system_prompt = (
"Viết lại câu hỏi của người dùng thành câu đầy đủ ngữ cảnh. "
"KHÔNG trả lời, chỉ viết lại."
)
context_prompt = ChatPromptTemplate.from_messages([
Ba)")
chat_interface = gr.ChatInterface(
fn=gradio_chat_stream,
)
if __name__ == "__main__":
    # NOTE(review): `demo` is not defined anywhere in the visible portion of
    # this file — the Gradio app construction appears to sit in a garbled /
    # truncated region above (only `chat_interface = gr.ChatInterface(...)` is
    # visible). Confirm whether `demo` or `chat_interface` is the intended
    # launch target; as-is this line raises NameError.
    demo.launch()