chatbot / app.py
PBThuong's picture
Update app.py
2bb0244 verified
raw
history blame
18.5 kB
# =====================================================
# 1. FIX sqlite3 CHO CHROMA (pysqlite3 hack)
# =====================================================
__import__("pysqlite3")
import sys
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
# =====================================================
# 2. IMPORTS
# =====================================================
import os
import logging
import traceback
import gradio as gr
import pandas as pd
import docx2txt
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers.ensemble import EnsembleRetriever
from langchain.chains import create_retrieval_chain, create_history_aware_retriever
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
# =====================================================
# 3. CẤU HÌNH CHUNG
# =====================================================
# Lấy API key từ biến môi trường
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
DATA_PATH = "medical_data" # thư mục chứa tài liệu
DB_PATH = "chroma_db" # thư mục chứa database Chroma
# Số lượt đối thoại gần nhất gửi vào LLM (mỗi lượt gồm 1 user + 1 bot)
MAX_HISTORY_TURNS = 6
# Chọn chiến lược truy vấn:
USE_BM25 = True # True = dùng hybrid (BM25 + Vector), False = chỉ dùng Vector
USE_MMR = True # True = dùng MMR cho retriever vector
# Logging cơ bản
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
)
# =====================================================
# 4. HÀM LOAD TÀI LIỆU ĐA ĐỊNH DẠNG (PDF, DOCX, EXCEL, CSV, TXT, MD)
# =====================================================
def load_documents_from_folder(folder_path: str) -> list[Document]:
"""
Quét thư mục (kể cả subfolder), đọc các file hỗ trợ và trả về list Document.
"""
logging.info(f"--- Bắt đầu quét thư mục: {folder_path} ---")
documents: list[Document] = []
if not os.path.exists(folder_path):
os.makedirs(folder_path, exist_ok=True)
logging.warning(f"Thư mục {folder_path} chưa tồn tại. Đã tạo mới, tạm thời chưa có tài liệu.")
return []
for root, _, files in os.walk(folder_path):
for filename in files:
file_path = os.path.join(root, filename)
filename_lower = filename.lower()
try:
# 1. PDF
if filename_lower.endswith(".pdf"):
logging.info(f"-> Đang xử lý PDF: {file_path}")
loader = PyPDFLoader(file_path)
docs = loader.load() # mỗi trang = 1 Document
# Thêm metadata source gọn (chỉ tên file)
for d in docs:
d.metadata["source"] = filename
documents.extend(docs)
# 2. DOCX
elif filename_lower.endswith(".docx"):
logging.info(f"-> Đang xử lý Word: {file_path}")
text = docx2txt.process(file_path)
if text and text.strip():
documents.append(
Document(
page_content=text,
metadata={"source": filename}
)
)
else:
logging.warning(f"File Word rỗng: {filename}")
# 3. EXCEL (XLS, XLSX)
elif filename_lower.endswith((".xlsx", ".xls")):
logging.info(f"-> Đang xử lý Excel: {file_path}")
try:
df = pd.read_excel(file_path)
text_data = ""
for _, row in df.iterrows():
row_str = " | ".join(
f"{col}: {val}"
for col, val in row.items()
if pd.notna(val)
)
if row_str:
text_data += row_str + "\n"
if text_data.strip():
documents.append(
Document(
page_content=text_data,
metadata={"source": filename}
)
)
else:
logging.warning(f"File Excel rỗng: {filename}")
except Exception as e:
logging.error(f"Lỗi đọc Excel {filename}: {e}")
# 4. CSV
elif filename_lower.endswith(".csv"):
logging.info(f"-> Đang xử lý CSV: {file_path}")
try:
df = pd.read_csv(file_path)
text_data = ""
for _, row in df.iterrows():
row_str = " | ".join(
f"{col}: {val}"
for col, val in row.items()
if pd.notna(val)
)
if row_str:
text_data += row_str + "\n"
if text_data.strip():
documents.append(
Document(
page_content=text_data,
metadata={"source": filename}
)
)
else:
logging.warning(f"File CSV rỗng: {filename}")
except Exception as e:
logging.error(f"Lỗi đọc CSV {filename}: {e}")
# 5. TEXT / MARKDOWN
elif filename_lower.endswith((".txt", ".md")):
logging.info(f"-> Đang xử lý Text/Markdown: {file_path}")
text = ""
try:
with open(file_path, "r", encoding="utf-8") as f:
text = f.read()
except UnicodeDecodeError:
logging.warning(f"Encoding UTF-8 thất bại, thử Latin-1 cho {filename}")
with open(file_path, "r", encoding="latin-1") as f:
text = f.read()
if text and text.strip():
documents.append(
Document(
page_content=text,
metadata={"source": filename}
)
)
else:
logging.info(f"-> Bỏ qua file không hỗ trợ: {file_path}")
except Exception as e:
logging.error(f"❌ LỖI khi đọc file {filename}: {e}")
logging.debug(traceback.format_exc())
logging.info(f"--- Hoàn tất load tài liệu. Tổng số Document: {len(documents)} ---")
return documents
# =====================================================
# 5. XÂY DỰNG VECTORSTORE + RETRIEVER
# =====================================================
def build_vectorstore_and_corpus(embedding_model):
"""
- Nếu đã có DB Chroma: load lên.
- Nếu chưa: đọc folder, split, tạo DB mới.
Trả về (vectorstore, splits).
"""
from shutil import rmtree
splits: list[Document] = []
vectorstore = None
# TH1: có DB cũ
if os.path.exists(DB_PATH) and os.listdir(DB_PATH):
try:
logging.info("--- Tìm thấy ChromaDB cũ, đang load... ---")
vectorstore = Chroma(
persist_directory=DB_PATH,
embedding_function=embedding_model
)
existing = vectorstore.get()
if existing.get("documents"):
for text, meta in zip(existing["documents"], existing["metadatas"]):
splits.append(Document(page_content=text, metadata=meta))
logging.info(f"Tải lại corpus từ DB, tổng số chunk: {len(splits)}")
else:
logging.warning("ChromaDB không chứa documents. Sẽ rebuild.")
splits = []
except Exception as e:
logging.error(f"Lỗi đọc DB cũ: {e}. Xóa và rebuild lại.")
logging.debug(traceback.format_exc())
rmtree(DB_PATH, ignore_errors=True)
splits = []
vectorstore = None
# TH2: chưa có dữ liệu trong splits → đọc file gốc, split và tạo DB mới
if not splits:
logging.info("--- Đang đọc tài liệu gốc để tạo index mới... ---")
documents = load_documents_from_folder(DATA_PATH)
if not documents:
logging.error("Không tìm thấy tài liệu nào trong thư mục medical_data.")
return None, []
# Chunking
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=800, # nhỏ hơn cho câu trả lời cụ thể hơn
chunk_overlap=150
)
splits = text_splitter.split_documents(documents)
logging.info(f"Đã split thành {len(splits)} chunks.")
# Tạo Chroma mới
logging.info("--- Đang mã hoá embedding vào ChromaDB... ---")
vectorstore = Chroma.from_documents(
documents=splits,
embedding=embedding_model,
persist_directory=DB_PATH
)
return vectorstore, splits
def get_retriever():
"""
Khởi tạo retriever:
- Load hoặc tạo ChromaDB.
- Tuỳ cấu hình: dùng hybrid (BM25 + Vector) hoặc chỉ Vector.
- Có thể dùng MMR để tăng đa dạng context.
"""
logging.info("--- Đang tải model Embedding (HuggingFace) ---")
embedding_model = HuggingFaceEmbeddings(
model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
)
vectorstore, splits = build_vectorstore_and_corpus(embedding_model)
if vectorstore is None or not splits:
logging.error("Không thể khởi tạo retriever vì không có dữ liệu.")
return None
# Vector retriever (có thể dùng MMR)
if USE_MMR:
logging.info("Sử dụng Vector Retriever với MMR.")
chroma_retriever = vectorstore.as_retriever(
search_type="mmr",
search_kwargs={"k": 8, "lambda_mult": 0.7}
)
else:
logging.info("Sử dụng Vector Retriever với similarity search.")
chroma_retriever = vectorstore.as_retriever(
search_kwargs={"k": 8}
)
# Nếu không muốn dùng BM25 -> trả về luôn retriever vector
if not USE_BM25:
logging.info("Chỉ dùng retriever Vector (không dùng BM25).")
return chroma_retriever
# Tạo BM25 retriever từ corpus
logging.info("Khởi tạo BM25 Retriever từ corpus (hybrid search).")
bm25_retriever = BM25Retriever.from_documents(splits)
bm25_retriever.k = 8
# Ensemble retriever: BM25 + Vector
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, chroma_retriever],
weights=[0.4, 0.6]
)
logging.info("Đã khởi tạo Ensemble Retriever (BM25 + Vector).")
return ensemble_retriever
# =====================================================
# 6. LỚP CHATBOT DEEPMED
# =====================================================
class DeepMedBot:
def __init__(self):
self.retriever = None
self.rag_chain = None
self.ready = False
if not GOOGLE_API_KEY:
logging.error("GOOGLE_API_KEY chưa được thiết lập trong biến môi trường.")
return
try:
self.retriever = get_retriever()
if self.retriever is None:
logging.error("Không khởi tạo được retriever.")
return
self.llm = ChatGoogleGenerativeAI(
model="gemini-2.5-flash", # có thể đổi sang model mạnh hơn nếu muốn
temperature=0.3,
google_api_key=GOOGLE_API_KEY
)
self._build_chains()
self.ready = True
logging.info("--- DeepMedBot đã sẵn sàng ---")
except Exception as e:
logging.error(f"Lỗi khởi tạo DeepMedBot: {e}")
logging.debug(traceback.format_exc())
self.ready = False
def _build_chains(self):
# 1. Prompt contextualize question
contextualize_q_system_prompt = (
"Dựa trên lịch sử trò chuyện và câu hỏi mới nhất của người dùng, "
"nếu câu hỏi liên quan đến ngữ cảnh trước đó, hãy viết lại nó thành một câu hỏi độc lập đầy đủ ý nghĩa. "
"Nếu không liên quan, giữ nguyên câu hỏi gốc. KHÔNG trả lời câu hỏi, chỉ viết lại."
)
contextualize_q_prompt = ChatPromptTemplate.from_messages([
("system", contextualize_q_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
])
history_aware_retriever = create_history_aware_retriever(
self.llm, self.retriever, contextualize_q_prompt
)
# 2. Prompt trả lời y khoa an toàn
qa_system_prompt = (
"Bạn là trợ lý y tế DeepMed. Bạn hỗ trợ giải thích thông tin y khoa dựa trên tài liệu được cung cấp (Context) "
"và kiến thức y khoa chung của bạn.\n\n"
"QUY TẮC AN TOÀN:\n"
"- Nếu Context không đủ để trả lời chính xác, hãy nói bạn không chắc chắn và khuyên người dùng hỏi bác sĩ.\n"
"- Ưu tiên sử dụng thông tin trong Context. Nếu có mâu thuẫn giữa Context và hiểu biết của bạn, "
"hãy nói rõ và khuyên tham khảo nguồn chính thống mới nhất.\n\n"
"CÁCH TRẢ LỜI:\n"
"- Trả lời ngắn gọn, dễ hiểu.\n"
"- Nếu phù hợp, chia câu trả lời thành các mục: Tóm tắt, Chi tiết, Lưu ý.\n\n"
"Context:\n{context}"
)
qa_prompt = ChatPromptTemplate.from_messages([
("system", qa_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
])
question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)
# 3. Gộp thành RAG chain
self.rag_chain = create_retrieval_chain(
history_aware_retriever,
question_answer_chain
)
def chat(self, message: str, history: list[list[str]]):
"""
Hàm dùng cho Gradio:
- history: list[[user_msg, bot_msg], ...]
"""
if not self.ready:
return "❗ Hệ thống chưa sẵn sàng. Vui lòng kiểm tra lại API key và dữ liệu (medical_data)."
# Giới hạn lịch sử gửi vào LLM cho đỡ nặng
if len(history) > MAX_HISTORY_TURNS:
history = history[-MAX_HISTORY_TURNS:]
# Chuyển lịch sử Gradio sang LangChain messages
chat_history = []
for user_msg, bot_msg in history:
chat_history.append(HumanMessage(content=user_msg))
chat_history.append(AIMessage(content=bot_msg))
try:
response = self.rag_chain.invoke({
"input": message,
"chat_history": chat_history
})
answer = response.get("answer", "")
# Xử lý trích dẫn nguồn
references_text = self._build_references_text(response)
if references_text:
answer += "\n\n---\n📚 **Tài liệu tham khảo:**\n" + references_text
return answer
except Exception as e:
logging.error(f"Lỗi khi xử lý chat: {e}")
logging.debug(traceback.format_exc())
return "🤖 Xin lỗi, hệ thống gặp lỗi nội bộ. Bạn hãy thử lại sau ít phút."
@staticmethod
def _build_references_text(response) -> str:
"""
Gom nguồn trích dẫn:
- Gộp theo tên file.
- Nếu có số trang, liệt kê danh sách trang.
"""
from collections import defaultdict
if "context" not in response:
return ""
source_pages = defaultdict(set) # {source_name: {page1, page2, ...}}
for doc in response["context"]:
src = os.path.basename(doc.metadata.get("source", "Tài liệu không tên"))
page = doc.metadata.get("page", None)
if page is not None:
source_pages[src].add(page + 1)
else:
# Đảm bảo vẫn tạo key nếu chưa có trang
_ = source_pages[src]
lines = []
for src, pages in source_pages.items():
if pages:
pages_str = ", ".join(str(p) for p in sorted(pages))
lines.append(f"- {src} (Trang {pages_str})")
else:
lines.append(f"- {src}")
return "\n".join(lines)
bot = DeepMedBot()
def gradio_chat(message, history):
return bot.chat(message, history)
demo = gr.ChatInterface(
fn=gradio_chat,
title="🏥 DeepMed AI Pro - Trợ lý Y tế tại Trung Tâm Y Tế khu vực Thanh Ba",
description=(
"Chatbot hỗ trợ tra cứu thông tin y khoa từ kho tài liệu nội bộ.\n"
),
)
if __name__ == "__main__":
demo.launch()