chatbot / app.py
PBThuong96's picture
Update app.py
df60532 verified
import os
import sys
import logging
import gradio as gr
# --- 1. SỬA LỖI SQLITE TRÊN HUGGING FACE ---
try:
__import__("pysqlite3")
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
except ImportError:
pass
import chromadb
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
# --- IMPORT ĐƠN GIẢN HÓA (LOẠI BỎ CÁC MODULE GÂY LỖI _type) ---
# Chỉ sử dụng các thành phần cốt lõi ổn định nhất
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
# --- CẤU HÌNH ---
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
DB_PATH = "chroma_db"
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
def get_category_vn_name(cat_code):
return {
"drug_info": "💊 Thuốc Nội Bộ",
"local_regimen": "🏥 Phác Đồ Thanh Ba",
"moh_regimen": "🏛️ Bộ Y Tế",
"association": "🌐 Hiệp Hội"
}.get(cat_code, "Khác")
# --- 2. LOAD DB (VECTOR SEARCH THUẦN TÚY - ỔN ĐỊNH 100%) ---
def get_retrievers():
if not os.path.exists(DB_PATH):
raise FileNotFoundError(f"❌ LỖI: Không tìm thấy thư mục '{DB_PATH}'. Bạn đã upload folder này vào phần Files chưa?")
logging.info("--- Đang tải dữ liệu từ ChromaDB... ---")
embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
vectorstore = Chroma(persist_directory=DB_PATH, embedding_function=embedding)
# Kiểm tra dữ liệu
try:
all_data = vectorstore.get()
if not all_data['documents']:
raise ValueError("Database rỗng")
logging.info(f"✅ Đã tải thành công {len(all_data['documents'])} tài liệu từ Database.")
except Exception as e:
logging.error(f"Lỗi đọc dữ liệu Chroma: {e}")
raise ValueError(f"Không thể đọc dữ liệu từ ChromaDB: {e}")
# --- TẠO RETRIEVER ĐƠN GIẢN ---
# Thay vì dùng Ensemble/Reranker (dễ lỗi), ta dùng Vector Search trực tiếp.
# Mode 1: FAST (Tìm kiếm Thuốc - Lấy 5 kết quả sát nhất)
logging.info("--- Khởi tạo Fast Retriever (Vector Only) ---")
fast_retriever = vectorstore.as_retriever(
search_kwargs={
"k": 5,
"filter": {"category": "drug_info"}
}
)
# Mode 2: DEEP (Tìm kiếm Phác đồ - Lấy 15 kết quả sát nhất)
# Tăng k lên để bù đắp việc thiếu Reranker
logging.info("--- Khởi tạo Deep Retriever (Vector Only) ---")
cats = ["local_regimen", "moh_regimen", "association", "drug_info"]
deep_retriever = vectorstore.as_retriever(
search_kwargs={
"k": 15,
"filter": {"category": {"$in": cats}}
}
)
return fast_retriever, deep_retriever
# --- 3. BOT LOGIC ---
class DeepMedBot:
def __init__(self):
self.ready = False
self.init_error = "Đang khởi động..."
if not GOOGLE_API_KEY:
self.init_error = "❌ LỖI: Chưa cấu hình GOOGLE_API_KEY trong Settings."
return
try:
self.fast_retriever, self.deep_retriever = get_retrievers()
self.llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0.2, google_api_key=GOOGLE_API_KEY)
self._build_chains()
self.ready = True
self.init_error = ""
logging.info("✅ BOT KHỞI ĐỘNG THÀNH CÔNG (CHẾ ĐỘ VECTOR STABLE)!")
except Exception as e:
self.init_error = f"❌ LỖI KHỞI TẠO: {str(e)}"
logging.error(self.init_error)
def _build_chains(self):
# Prompt Nhanh
fast_sys = (
"Bạn là Dược sĩ Lâm sàng.\n"
"Tra cứu [💊 Thuốc Nội Bộ] và trả lời bằng **Bảng Markdown**:\n"
"| Tên thuốc | Hoạt chất | Hàm lượng | Đơn vị | Ghi chú |\n"
"| --- | --- | --- | --- | --- |\n"
"Nếu không thấy, báo: '❌ Không tìm thấy trong kho'."
"Context:\n{context}"
)
fast_chain = create_stuff_documents_chain(self.llm, ChatPromptTemplate.from_messages([("system", fast_sys), ("human", "{input}")]))
self.fast_chain = create_retrieval_chain(self.fast_retriever, fast_chain)
# Prompt Chuyên sâu
deep_sys = (
"Bạn là Bác sĩ Trưởng khoa.\n"
"1. **Tìm phác đồ:** Ưu tiên tuyệt đối [🏥 Phác Đồ Thanh Ba]. Nếu không có mới dùng [Bộ Y Tế].\n"
"2. **Đối chiếu thuốc:** Kiểm tra thuốc trong phác đồ có trong [💊 Thuốc Nội Bộ] không.\n"
"3. **Định dạng trả lời:**\n"
" - Chẩn đoán/Nguyên tắc.\n"
" - Phác đồ (Ghi rõ nguồn).\n"
" - **Bảng kê đơn:**\n"
" | Tên thuốc | Liều dùng | Có trong kho? | Thay thế |\n"
" | --- | --- | --- | --- |\n"
"Context:\n{context}"
)
deep_chain = create_stuff_documents_chain(self.llm, ChatPromptTemplate.from_messages([("system", deep_sys), ("human", "{input}")]))
self.deep_chain = create_retrieval_chain(self.deep_retriever, deep_chain)
def chat(self, msg, history, mode):
if not self.ready:
return f"⚠️ HỆ THỐNG GẶP LỖI.\n\nChi tiết lỗi:\n{self.init_error}\n\nHãy thử Restart Space trong phần Settings."
chain = self.deep_chain if mode == "Chuyên sâu" else self.fast_chain
try:
res = chain.invoke({"input": msg})
ans = res['answer']
if 'context' in res and res['context']:
refs = list(set([f"- [{get_category_vn_name(d.metadata.get('category'))}] {d.metadata.get('source')}" for d in res['context']]))
ans += "\n\n---\n📚 **Nguồn:**\n" + "\n".join(refs)
return ans
except Exception as e:
return f"❌ Lỗi khi trả lời: {str(e)}"
bot = DeepMedBot()
def respond(message, history, mode):
return bot.chat(message, history, mode)
demo = gr.ChatInterface(
fn=respond,
additional_inputs=[gr.Radio(["Tra cứu nhanh (Chỉ thuốc)", "Chuyên sâu"], value="Tra cứu nhanh (Chỉ thuốc)", label="Chế độ")],
title="TTYT Thanh Ba - Hỗ trợ Lâm sàng",
description="Hệ thống tra cứu Phác đồ & Thuốc nội bộ.",
css=".gradio-container {min_height: 600px}"
)
if __name__ == "__main__":
demo.launch()