|
|
import os |
|
|
try: |
|
|
__import__("pysqlite3") |
|
|
import sys |
|
|
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3") |
|
|
except ImportError: |
|
|
pass |
|
|
|
|
|
import logging |
|
|
import traceback |
|
|
import pandas as pd |
|
|
import docx2txt |
|
|
import chromadb |
|
|
from chromadb.config import Settings |
|
|
import chainlit as cl |
|
|
|
|
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
|
from langchain_chroma import Chroma |
|
|
from langchain_community.document_loaders import PyPDFLoader |
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter |
|
|
from langchain_community.retrievers import BM25Retriever |
|
|
from langchain.retrievers.ensemble import EnsembleRetriever |
|
|
from langchain.chains import create_retrieval_chain, create_history_aware_retriever |
|
|
from langchain.chains.combine_documents import create_stuff_documents_chain |
|
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
|
|
from langchain_core.messages import HumanMessage, AIMessage |
|
|
from langchain_core.documents import Document |
|
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
|
from langchain.retrievers import ContextualCompressionRetriever |
|
|
from langchain.retrievers.document_compressors import CrossEncoderReranker |
|
|
from langchain_community.cross_encoders import HuggingFaceCrossEncoder |
|
|
|
|
|
|
|
|
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") |
|
|
DATA_PATH = "medical_data" |
|
|
DB_PATH = "chroma_db" |
|
|
MAX_HISTORY_TURNS = 4 |
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DeepMedBot: |
|
|
def __init__(self): |
|
|
self.fast_chain = None |
|
|
self.deep_chain = None |
|
|
self.ready = False |
|
|
|
|
|
if not GOOGLE_API_KEY: |
|
|
return |
|
|
|
|
|
try: |
|
|
self.fast_retriever, self.deep_retriever = get_retrievers() |
|
|
|
|
|
self.llm = ChatGoogleGenerativeAI( |
|
|
model="gemini-2.5-flash", |
|
|
temperature=0.2, |
|
|
google_api_key=GOOGLE_API_KEY, |
|
|
convert_system_message_to_human=True |
|
|
) |
|
|
|
|
|
if self.fast_retriever and self.deep_retriever: |
|
|
self._build_chains() |
|
|
self.ready = True |
|
|
else: |
|
|
self.ready = True |
|
|
|
|
|
except Exception as e: |
|
|
logging.error(f"Lỗi khởi tạo bot: {e}") |
|
|
|
|
|
def _build_chains(self): |
|
|
context_system_prompt = ( |
|
|
"Dựa trên lịch sử chat và câu hỏi mới nhất, hãy viết lại câu hỏi " |
|
|
"thành một câu hoàn chỉnh để tìm kiếm thông tin. " |
|
|
"CHỈ TRẢ VỀ CÂU HỎI ĐÃ VIẾT LẠI, KHÔNG TRẢ LỜI." |
|
|
) |
|
|
context_prompt = ChatPromptTemplate.from_messages([ |
|
|
("system", context_system_prompt), |
|
|
MessagesPlaceholder("chat_history"), |
|
|
("human", "{input}"), |
|
|
]) |
|
|
|
|
|
qa_system_prompt = ( |
|
|
"Bạn là 'DeepMed-AI' - Trợ lý Dược lâm sàng chuyên nghiệp.\n" |
|
|
"Nhiệm vụ: Tư vấn điều trị CHỈ DỰA TRÊN Dữ liệu nội bộ (Context).\n" |
|
|
"QUY TẮC: Trung thực tuyệt đối, Kiểm tra thuốc có trong kho, Trích dẫn nguồn.\n" |
|
|
"Context:\n{context}" |
|
|
) |
|
|
qa_prompt = ChatPromptTemplate.from_messages([ |
|
|
("system", qa_system_prompt), |
|
|
MessagesPlaceholder("chat_history"), |
|
|
("human", "{input}"), |
|
|
]) |
|
|
|
|
|
question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt) |
|
|
|
|
|
history_aware_fast = create_history_aware_retriever(self.llm, self.fast_retriever, context_prompt) |
|
|
self.fast_chain = create_retrieval_chain(history_aware_fast, question_answer_chain) |
|
|
|
|
|
history_aware_deep = create_history_aware_retriever(self.llm, self.deep_retriever, context_prompt) |
|
|
self.deep_chain = create_retrieval_chain(history_aware_deep, question_answer_chain) |
|
|
|
|
|
def _build_references_text(self, docs) -> str: |
|
|
lines = [] |
|
|
seen = set() |
|
|
for doc in docs: |
|
|
src = doc.metadata.get("source", "Tài liệu") |
|
|
row_info = f"(Dòng {doc.metadata['row']})" if "row" in doc.metadata else "" |
|
|
ref_str = f"- {src} {row_info}" |
|
|
if ref_str not in seen: |
|
|
lines.append(ref_str) |
|
|
seen.add(ref_str) |
|
|
return "\n".join(lines) |
|
|
|
|
|
|
|
|
|
|
|
@cl.on_chat_start |
|
|
async def start(): |
|
|
|
|
|
msg = cl.Message(content="Đang khởi động hệ thống DeepMed AI...") |
|
|
await msg.send() |
|
|
|
|
|
bot = DeepMedBot() |
|
|
cl.user_session.set("bot", bot) |
|
|
|
|
|
|
|
|
actions = [ |
|
|
cl.Action(name="fast_mode", value="fast", label="⚡ Tốc độ (Nhanh)", description="Trả lời nhanh"), |
|
|
cl.Action(name="deep_mode", value="deep", label="🧠 Chuyên sâu (Kỹ)", description="Phân tích kỹ dữ liệu") |
|
|
] |
|
|
cl.user_session.set("mode", "fast") |
|
|
|
|
|
await msg.update(content="👋 Xin chào! Tôi là **DeepMed AI**. \n\nTôi có thể hỗ trợ tra cứu thuốc, phác đồ điều trị và thông tin dược lâm sàng. \n\n*Mặc định đang ở chế độ: Tốc độ.*", actions=actions) |
|
|
|
|
|
@cl.action_callback("fast_mode") |
|
|
async def on_fast_mode(action): |
|
|
cl.user_session.set("mode", "fast") |
|
|
await cl.Message(content="✅ Đã chuyển sang chế độ: **Tốc độ (Nhanh)**").send() |
|
|
|
|
|
@cl.action_callback("deep_mode") |
|
|
async def on_deep_mode(action): |
|
|
cl.user_session.set("mode", "deep") |
|
|
await cl.Message(content="✅ Đã chuyển sang chế độ: **Chuyên sâu (Chậm & Chính xác)**").send() |
|
|
|
|
|
@cl.on_message |
|
|
async def main(message: cl.Message): |
|
|
bot = cl.user_session.get("bot") |
|
|
mode_setting = cl.user_session.get("mode") |
|
|
|
|
|
if not bot or not bot.ready: |
|
|
await cl.Message(content="⚠️ Hệ thống chưa sẵn sàng hoặc gặp lỗi khởi tạo.").send() |
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
active_chain = bot.deep_chain if mode_setting == "deep" else bot.fast_chain |
|
|
|
|
|
msg = cl.Message(content="") |
|
|
await msg.send() |
|
|
|
|
|
|
|
|
history = [] |
|
|
|
|
|
full_response = "" |
|
|
retrieved_docs = [] |
|
|
|
|
|
|
|
|
try: |
|
|
async for chunk in active_chain.astream({"input": message.content, "chat_history": history}): |
|
|
if "answer" in chunk: |
|
|
token = chunk["answer"] |
|
|
full_response += token |
|
|
await msg.stream_token(token) |
|
|
elif "context" in chunk: |
|
|
retrieved_docs = chunk["context"] |
|
|
|
|
|
|
|
|
if retrieved_docs: |
|
|
refs = bot._build_references_text(retrieved_docs) |
|
|
if refs: |
|
|
ref_block = f"\n\n---\n📚 **Nguồn tham khảo:**\n{refs}" |
|
|
await msg.stream_token(ref_block) |
|
|
|
|
|
await msg.update() |
|
|
|
|
|
except Exception as e: |
|
|
await cl.Message(content=f"⚠️ Lỗi: {str(e)}").send() |