|
|
try: |
|
|
__import__("pysqlite3") |
|
|
import sys |
|
|
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3") |
|
|
except ImportError: |
|
|
pass |
|
|
|
|
|
import os |
|
|
import logging |
|
|
import traceback |
|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
import docx2txt |
|
|
import chromadb |
|
|
from chromadb.config import Settings |
|
|
from shutil import rmtree |
|
|
|
|
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
|
from langchain_chroma import Chroma |
|
|
from langchain_community.document_loaders import PyPDFLoader |
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter |
|
|
from langchain_community.retrievers import BM25Retriever |
|
|
from langchain.retrievers.ensemble import EnsembleRetriever |
|
|
from langchain.chains import create_retrieval_chain, create_history_aware_retriever |
|
|
from langchain.chains.combine_documents import create_stuff_documents_chain |
|
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
|
|
from langchain_core.messages import HumanMessage, AIMessage |
|
|
from langchain_core.documents import Document |
|
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
|
from langchain.retrievers import ContextualCompressionRetriever |
|
|
from langchain.retrievers.document_compressors import CrossEncoderReranker |
|
|
from langchain_community.cross_encoders import HuggingFaceCrossEncoder |
|
|
|
|
|
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") |
|
|
DATA_PATH = "medical_data" |
|
|
DB_PATH = "chroma_db" |
|
|
MAX_HISTORY_TURNS = 4 |
|
|
FORCE_REBUILD_DB = False |
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") |
|
|
|
|
|
|
|
|
def process_excel_file(file_path: str, filename: str) -> list[Document]: |
|
|
docs = [] |
|
|
try: |
|
|
if file_path.endswith(".csv"): |
|
|
df = pd.read_csv(file_path) |
|
|
else: |
|
|
df = pd.read_excel(file_path) |
|
|
|
|
|
df.dropna(how='all', inplace=True) |
|
|
df.fillna("Không có thông tin", inplace=True) |
|
|
|
|
|
for idx, row in df.iterrows(): |
|
|
content_parts = [] |
|
|
for col_name, val in row.items(): |
|
|
clean_val = str(val).strip() |
|
|
if clean_val and clean_val.lower() != "nan": |
|
|
content_parts.append(f"{col_name}: {clean_val}") |
|
|
|
|
|
if content_parts: |
|
|
page_content = f"Dữ liệu từ file {filename} (Dòng {idx+1}):\n" + "\n".join(content_parts) |
|
|
metadata = {"source": filename, "row": idx+1, "type": "excel_record"} |
|
|
docs.append(Document(page_content=page_content, metadata=metadata)) |
|
|
|
|
|
except Exception as e: |
|
|
logging.error(f"Lỗi xử lý Excel {filename}: {e}") |
|
|
|
|
|
return docs |
|
|
|
|
|
def load_documents_from_folder(folder_path: str) -> list[Document]: |
|
|
logging.info(f"--- Bắt đầu quét thư mục: {folder_path} ---") |
|
|
documents: list[Document] = [] |
|
|
if not os.path.exists(folder_path): |
|
|
os.makedirs(folder_path, exist_ok=True) |
|
|
return [] |
|
|
|
|
|
for root, _, files in os.walk(folder_path): |
|
|
for filename in files: |
|
|
file_path = os.path.join(root, filename) |
|
|
filename_lower = filename.lower() |
|
|
try: |
|
|
if filename_lower.endswith(".pdf"): |
|
|
loader = PyPDFLoader(file_path) |
|
|
docs = loader.load() |
|
|
for d in docs: d.metadata["source"] = filename |
|
|
documents.extend(docs) |
|
|
|
|
|
elif filename_lower.endswith(".docx"): |
|
|
text = docx2txt.process(file_path) |
|
|
if text.strip(): |
|
|
documents.append(Document(page_content=text, metadata={"source": filename})) |
|
|
|
|
|
elif filename_lower.endswith((".xlsx", ".xls", ".csv")): |
|
|
excel_docs = process_excel_file(file_path, filename) |
|
|
documents.extend(excel_docs) |
|
|
|
|
|
elif filename_lower.endswith((".txt", ".md")): |
|
|
with open(file_path, "r", encoding="utf-8") as f: text = f.read() |
|
|
if text.strip(): |
|
|
documents.append(Document(page_content=text, metadata={"source": filename})) |
|
|
|
|
|
except Exception as e: |
|
|
logging.error(f"Lỗi đọc file {filename}: {e}") |
|
|
|
|
|
logging.info(f"Tổng cộng đã load: {len(documents)} tài liệu gốc.") |
|
|
return documents |
|
|
|
|
|
|
|
|
def get_retrievers(): |
|
|
logging.info("--- Tải Embedding Model ---") |
|
|
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2") |
|
|
|
|
|
vectorstore = None |
|
|
splits = [] |
|
|
|
|
|
chroma_settings = Settings(anonymized_telemetry=False) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not vectorstore: |
|
|
logging.info("--- Tạo Index dữ liệu mới ---") |
|
|
raw_docs = load_documents_from_folder(DATA_PATH) |
|
|
if not raw_docs: return None, None |
|
|
|
|
|
|
|
|
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1200, chunk_overlap=200) |
|
|
splits = text_splitter.split_documents(raw_docs) |
|
|
vectorstore = Chroma.from_documents(documents=splits, embedding=embedding_model, persist_directory=DB_PATH, client_settings=chroma_settings) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 15}) |
|
|
|
|
|
ensemble_retriever = vector_retriever |
|
|
if splits: |
|
|
bm25_retriever = BM25Retriever.from_documents(splits) |
|
|
bm25_retriever.k = 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ensemble_retriever = EnsembleRetriever( |
|
|
retrievers=[bm25_retriever, vector_retriever], |
|
|
weights=[0.5, 0.5] |
|
|
) |
|
|
|
|
|
fast_retriever = ensemble_retriever |
|
|
|
|
|
|
|
|
vector_retriever_deep = vectorstore.as_retriever(search_kwargs={"k": 25}) |
|
|
ensemble_retriever_deep = vector_retriever_deep |
|
|
if splits: |
|
|
bm25_retriever_deep = BM25Retriever.from_documents(splits) |
|
|
bm25_retriever_deep.k = 25 |
|
|
ensemble_retriever_deep = EnsembleRetriever( |
|
|
retrievers=[bm25_retriever_deep, vector_retriever_deep], |
|
|
weights=[0.5, 0.5] |
|
|
) |
|
|
|
|
|
logging.info("--- Tải Reranker Model (BGE-M3) ---") |
|
|
reranker_model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3") |
|
|
compressor = CrossEncoderReranker(model=reranker_model, top_n=5) |
|
|
|
|
|
deep_retriever = ContextualCompressionRetriever( |
|
|
base_compressor=compressor, |
|
|
base_retriever=ensemble_retriever_deep |
|
|
) |
|
|
return fast_retriever, deep_retriever |
|
|
|
|
|
class DeepMedBot: |
|
|
def __init__(self): |
|
|
self.fast_chain = None |
|
|
self.deep_chain = None |
|
|
self.ready = False |
|
|
self.fallback_llm = None |
|
|
|
|
|
if not GOOGLE_API_KEY: |
|
|
logging.error("⚠️ Thiếu GOOGLE_API_KEY!") |
|
|
return |
|
|
|
|
|
try: |
|
|
self.fast_retriever, self.deep_retriever = get_retrievers() |
|
|
|
|
|
self.llm = ChatGoogleGenerativeAI( |
|
|
model="gemini-2.5-flash", |
|
|
temperature=0.2, |
|
|
google_api_key=GOOGLE_API_KEY, |
|
|
convert_system_message_to_human=True |
|
|
) |
|
|
self.fallback_llm = self.llm |
|
|
|
|
|
if self.fast_retriever and self.deep_retriever: |
|
|
self._build_chains() |
|
|
self.ready = True |
|
|
logging.info("✅ Bot DeepMed đã sẵn sàng với 2 chế độ!") |
|
|
else: |
|
|
logging.warning("⚠️ Không có dữ liệu. Bot sẽ chỉ dùng kiến thức nền.") |
|
|
self.ready = True |
|
|
|
|
|
except Exception as e: |
|
|
logging.error(f"🔥 Lỗi khởi tạo bot: {e}") |
|
|
logging.debug(traceback.format_exc()) |
|
|
|
|
|
def _build_chains(self): |
|
|
context_system_prompt = ( |
|
|
"Dựa trên lịch sử chat và câu hỏi mới nhất, hãy viết lại câu hỏi " |
|
|
"thành một câu hoàn chỉnh để tìm kiếm thông tin. " |
|
|
"Ví dụ: Nếu user hỏi 'thuốc đó dùng sao', hãy viết lại thành 'cách dùng thuốc [tên thuốc trước đó]'. " |
|
|
"CHỈ TRẢ VỀ CÂU HỎI ĐÃ VIẾT LẠI, KHÔNG TRẢ LỜI." |
|
|
) |
|
|
context_prompt = ChatPromptTemplate.from_messages([ |
|
|
("system", context_system_prompt), |
|
|
MessagesPlaceholder("chat_history"), |
|
|
("human", "{input}"), |
|
|
]) |
|
|
|
|
|
qa_system_prompt = ( |
|
|
"Bạn là 'DeepMed-AI' - Trợ lý Dược lâm sàng chuyên nghiệp.\n" |
|
|
"Nhiệm vụ của bạn là tư vấn điều trị CHỈ DỰA TRÊN Dữ liệu nội bộ (Context) được cung cấp bên dưới.\n\n" |
|
|
"QUY TẮC AN TOÀN (BẮT BUỘC):\n" |
|
|
"1. **Trung thực tuyệt đối:** Nếu thông tin không có trong Context, hãy trả lời: 'Xin lỗi, tôi không tìm thấy thông tin này trong dữ liệu nội bộ'. KHÔNG tự bịa ra phác đồ.\n" |
|
|
"2. **Kiểm tra thuốc:** Khi đề xuất thuốc, chỉ nêu tên các thuốc có trong danh sách Context (có kèm giá/số lượng là tốt nhất).\n" |
|
|
"3. **Trích dẫn:** Mọi khẳng định y khoa phải được trích dẫn từ Context.\n\n" |
|
|
"Context:\n{context}" |
|
|
) |
|
|
qa_prompt = ChatPromptTemplate.from_messages([ |
|
|
("system", qa_system_prompt), |
|
|
MessagesPlaceholder("chat_history"), |
|
|
("human", "{input}"), |
|
|
]) |
|
|
|
|
|
question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt) |
|
|
|
|
|
history_aware_fast = create_history_aware_retriever(self.llm, self.fast_retriever, context_prompt) |
|
|
self.fast_chain = create_retrieval_chain(history_aware_fast, question_answer_chain) |
|
|
|
|
|
history_aware_deep = create_history_aware_retriever(self.llm, self.deep_retriever, context_prompt) |
|
|
self.deep_chain = create_retrieval_chain(history_aware_deep, question_answer_chain) |
|
|
|
|
|
def chat_stream(self, message: str, history: list, mode: str): |
|
|
if not self.ready: |
|
|
yield "Hệ thống đang khởi động hoặc gặp lỗi cấu hình..." |
|
|
return |
|
|
|
|
|
chat_history = [] |
|
|
if history: |
|
|
for turn in history[-MAX_HISTORY_TURNS:]: |
|
|
if isinstance(turn, (list, tuple)) and len(turn) == 2: |
|
|
u, b = turn |
|
|
if u and b and str(u).strip() and str(b).strip(): |
|
|
chat_history.append(HumanMessage(content=str(u))) |
|
|
chat_history.append(AIMessage(content=str(b))) |
|
|
|
|
|
active_chain = self.deep_chain if mode == "Chuyên sâu (Chậm & Chính xác)" else self.fast_chain |
|
|
|
|
|
if not active_chain: |
|
|
try: |
|
|
resp = self.llm.invoke([HumanMessage(content=message)]) |
|
|
yield f"⚠️ (Chế độ kiến thức chung) {resp.content}" |
|
|
return |
|
|
except: |
|
|
yield "Lỗi: Không thể kết nối với AI. Vui lòng kiểm tra API Key." |
|
|
return |
|
|
|
|
|
full_response = "" |
|
|
retrieved_docs = [] |
|
|
|
|
|
try: |
|
|
for chunk in active_chain.stream({"input": message, "chat_history": chat_history}): |
|
|
if "answer" in chunk: |
|
|
full_response += chunk["answer"] |
|
|
yield full_response |
|
|
elif "context" in chunk: |
|
|
retrieved_docs = chunk["context"] |
|
|
|
|
|
if retrieved_docs: |
|
|
refs = self._build_references_text(retrieved_docs) |
|
|
if refs: |
|
|
full_response += f"\n\n---\n📚 **Nguồn tham khảo ({mode}):**\n{refs}" |
|
|
yield full_response |
|
|
|
|
|
except Exception as e: |
|
|
logging.error(f"Lỗi khi chat: {e}") |
|
|
logging.error(traceback.format_exc()) |
|
|
|
|
|
if not full_response: |
|
|
try: |
|
|
yield "⚠️ Gặp lỗi khi truy xuất dữ liệu. Đang chuyển sang chế độ trả lời nhanh...\n\n" |
|
|
fallback_resp = self.llm.invoke([HumanMessage(content=message)]) |
|
|
yield fallback_resp.content |
|
|
except: |
|
|
yield f"Đã xảy ra lỗi hệ thống. Vui lòng nhấn nút 'Clear' để xóa lịch sử chat và thử lại. (Lỗi: {str(e)})" |
|
|
else: |
|
|
yield full_response + f"\n\n[Lỗi ngắt kết nối: {str(e)}]" |
|
|
|
|
|
@staticmethod |
|
|
def _build_references_text(docs) -> str: |
|
|
lines = [] |
|
|
seen = set() |
|
|
for doc in docs: |
|
|
src = doc.metadata.get("source", "Tài liệu") |
|
|
row_info = f"(Dòng {doc.metadata['row']})" if "row" in doc.metadata else "" |
|
|
type_info = " [Kho thuốc]" if doc.metadata.get("type") == "excel_record" else "" |
|
|
|
|
|
ref_str = f"- {src}{type_info} {row_info}" |
|
|
if ref_str not in seen: |
|
|
lines.append(ref_str) |
|
|
seen.add(ref_str) |
|
|
return "\n".join(lines) |
|
|
|
|
|
bot = DeepMedBot() |
|
|
|
|
|
def gradio_chat_handler(message, history, mode): |
|
|
yield from bot.chat_stream(message, history, mode) |
|
|
|
|
|
css = """ |
|
|
.gradio-container {min_height: 600px !important;} |
|
|
h1 {text-align: center; color: #2E86C1;} |
|
|
.ref-box {font-size: 0.8em; color: gray;} |
|
|
""" |
|
|
|
|
|
with gr.Blocks(title="DeepMed AI") as demo: |
|
|
gr.HTML(f"<style>{css}</style>") |
|
|
|
|
|
gr.Markdown("# 🏥 DeepMed AI - Hệ Thống Hỗ Trợ Lâm Sàng") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=4): |
|
|
gr.Markdown("Nhập câu hỏi về phác đồ, thuốc hoặc tra cứu bệnh án.") |
|
|
with gr.Column(scale=1): |
|
|
mode_select = gr.Radio( |
|
|
choices=["Tốc độ (Nhanh)", "Chuyên sâu (Chậm & Chính xác)"], |
|
|
value="Tốc độ (Nhanh)", |
|
|
label="Chế độ hoạt động", |
|
|
info="Chọn 'Tốc độ' để trả lời nhanh, 'Chuyên sâu' để lọc kỹ dữ liệu." |
|
|
) |
|
|
|
|
|
chat_interface = gr.ChatInterface( |
|
|
fn=gradio_chat_handler, |
|
|
additional_inputs=[mode_select], |
|
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
demo.launch() |