File size: 18,452 Bytes
2bb0244
 
 
 
1469b0c
2bb0244
1469b0c
2bb0244
 
 
1469b0c
2bb0244
 
1469b0c
2bb0244
87c3fc4
 
 
1469b0c
 
2bb0244
 
1469b0c
 
8ba3a76
2bb0244
1469b0c
 
2bb0244
1469b0c
 
 
2bb0244
1469b0c
 
2bb0244
 
 
 
 
 
1469b0c
2bb0244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1469b0c
2bb0244
 
1469b0c
 
2bb0244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87c3fc4
2bb0244
87c3fc4
2bb0244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87c3fc4
2bb0244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87c3fc4
2bb0244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1469b0c
 
 
2bb0244
 
 
 
 
 
 
 
 
 
 
 
 
1469b0c
2bb0244
 
1469b0c
 
2bb0244
 
 
 
 
 
 
 
1469b0c
2bb0244
1469b0c
2bb0244
 
1469b0c
2bb0244
 
 
 
 
1469b0c
2bb0244
1469b0c
2bb0244
1469b0c
 
2bb0244
 
1469b0c
2bb0244
 
 
 
 
1469b0c
2bb0244
1469b0c
2bb0244
 
1469b0c
2bb0244
 
1469b0c
 
 
2bb0244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1469b0c
2bb0244
 
 
 
 
 
 
1469b0c
2bb0244
1469b0c
2bb0244
1469b0c
 
 
 
2bb0244
1469b0c
 
 
2bb0244
 
 
1469b0c
2bb0244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1469b0c
 
 
2bb0244
 
 
 
 
1469b0c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
# =====================================================
# 1. FIX sqlite3 CHO CHROMA (pysqlite3 hack)
# =====================================================
__import__("pysqlite3")
import sys
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

# =====================================================
# 2. IMPORTS
# =====================================================
import os
import logging
import traceback

import gradio as gr
import pandas as pd
import docx2txt

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_chroma import Chroma

from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers.ensemble import EnsembleRetriever

from langchain.chains import create_retrieval_chain, create_history_aware_retriever
from langchain.chains.combine_documents import create_stuff_documents_chain

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.documents import Document

from langchain_huggingface import HuggingFaceEmbeddings


# =====================================================
# 3. CẤU HÌNH CHUNG
# =====================================================

# Lấy API key từ biến môi trường
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

DATA_PATH = "medical_data"   # thư mục chứa tài liệu
DB_PATH = "chroma_db"        # thư mục chứa database Chroma

# Số lượt đối thoại gần nhất gửi vào LLM (mỗi lượt gồm 1 user + 1 bot)
MAX_HISTORY_TURNS = 6

# Chọn chiến lược truy vấn:
USE_BM25 = True   # True = dùng hybrid (BM25 + Vector), False = chỉ dùng Vector
USE_MMR = True    # True = dùng MMR cho retriever vector

# Logging cơ bản
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)


# =====================================================
# 4. HÀM LOAD TÀI LIỆU ĐA ĐỊNH DẠNG (PDF, DOCX, EXCEL, CSV, TXT, MD)
# =====================================================

def load_documents_from_folder(folder_path: str) -> list[Document]:
    """
    Quét thư mục (kể cả subfolder), đọc các file hỗ trợ và trả về list Document.
    """
    logging.info(f"--- Bắt đầu quét thư mục: {folder_path} ---")
    documents: list[Document] = []

    if not os.path.exists(folder_path):
        os.makedirs(folder_path, exist_ok=True)
        logging.warning(f"Thư mục {folder_path} chưa tồn tại. Đã tạo mới, tạm thời chưa có tài liệu.")
        return []

    for root, _, files in os.walk(folder_path):
        for filename in files:
            file_path = os.path.join(root, filename)
            filename_lower = filename.lower()

            try:
                # 1. PDF
                if filename_lower.endswith(".pdf"):
                    logging.info(f"-> Đang xử lý PDF: {file_path}")
                    loader = PyPDFLoader(file_path)
                    docs = loader.load()  # mỗi trang = 1 Document
                    # Thêm metadata source gọn (chỉ tên file)
                    for d in docs:
                        d.metadata["source"] = filename
                    documents.extend(docs)

                # 2. DOCX
                elif filename_lower.endswith(".docx"):
                    logging.info(f"-> Đang xử lý Word: {file_path}")
                    text = docx2txt.process(file_path)
                    if text and text.strip():
                        documents.append(
                            Document(
                                page_content=text,
                                metadata={"source": filename}
                            )
                        )
                    else:
                        logging.warning(f"File Word rỗng: {filename}")

                # 3. EXCEL (XLS, XLSX)
                elif filename_lower.endswith((".xlsx", ".xls")):
                    logging.info(f"-> Đang xử lý Excel: {file_path}")
                    try:
                        df = pd.read_excel(file_path)
                        text_data = ""
                        for _, row in df.iterrows():
                            row_str = " | ".join(
                                f"{col}: {val}"
                                for col, val in row.items()
                                if pd.notna(val)
                            )
                            if row_str:
                                text_data += row_str + "\n"
                        if text_data.strip():
                            documents.append(
                                Document(
                                    page_content=text_data,
                                    metadata={"source": filename}
                                )
                            )
                        else:
                            logging.warning(f"File Excel rỗng: {filename}")
                    except Exception as e:
                        logging.error(f"Lỗi đọc Excel {filename}: {e}")

                # 4. CSV
                elif filename_lower.endswith(".csv"):
                    logging.info(f"-> Đang xử lý CSV: {file_path}")
                    try:
                        df = pd.read_csv(file_path)
                        text_data = ""
                        for _, row in df.iterrows():
                            row_str = " | ".join(
                                f"{col}: {val}"
                                for col, val in row.items()
                                if pd.notna(val)
                            )
                            if row_str:
                                text_data += row_str + "\n"
                        if text_data.strip():
                            documents.append(
                                Document(
                                    page_content=text_data,
                                    metadata={"source": filename}
                                )
                            )
                        else:
                            logging.warning(f"File CSV rỗng: {filename}")
                    except Exception as e:
                        logging.error(f"Lỗi đọc CSV {filename}: {e}")

                # 5. TEXT / MARKDOWN
                elif filename_lower.endswith((".txt", ".md")):
                    logging.info(f"-> Đang xử lý Text/Markdown: {file_path}")
                    text = ""
                    try:
                        with open(file_path, "r", encoding="utf-8") as f:
                            text = f.read()
                    except UnicodeDecodeError:
                        logging.warning(f"Encoding UTF-8 thất bại, thử Latin-1 cho {filename}")
                        with open(file_path, "r", encoding="latin-1") as f:
                            text = f.read()

                    if text and text.strip():
                        documents.append(
                            Document(
                                page_content=text,
                                metadata={"source": filename}
                            )
                        )

                else:
                    logging.info(f"-> Bỏ qua file không hỗ trợ: {file_path}")

            except Exception as e:
                logging.error(f"❌ LỖI khi đọc file {filename}: {e}")
                logging.debug(traceback.format_exc())

    logging.info(f"--- Hoàn tất load tài liệu. Tổng số Document: {len(documents)} ---")
    return documents


# =====================================================
# 5. XÂY DỰNG VECTORSTORE + RETRIEVER
# =====================================================

def build_vectorstore_and_corpus(embedding_model):
    """
    - Nếu đã có DB Chroma: load lên.
    - Nếu chưa: đọc folder, split, tạo DB mới.
    Trả về (vectorstore, splits).
    """
    from shutil import rmtree

    splits: list[Document] = []
    vectorstore = None

    # TH1: có DB cũ
    if os.path.exists(DB_PATH) and os.listdir(DB_PATH):
        try:
            logging.info("--- Tìm thấy ChromaDB cũ, đang load... ---")
            vectorstore = Chroma(
                persist_directory=DB_PATH,
                embedding_function=embedding_model
            )
            existing = vectorstore.get()
            if existing.get("documents"):
                for text, meta in zip(existing["documents"], existing["metadatas"]):
                    splits.append(Document(page_content=text, metadata=meta))
                logging.info(f"Tải lại corpus từ DB, tổng số chunk: {len(splits)}")
            else:
                logging.warning("ChromaDB không chứa documents. Sẽ rebuild.")
                splits = []
        except Exception as e:
            logging.error(f"Lỗi đọc DB cũ: {e}. Xóa và rebuild lại.")
            logging.debug(traceback.format_exc())
            rmtree(DB_PATH, ignore_errors=True)
            splits = []
            vectorstore = None

    # TH2: chưa có dữ liệu trong splits → đọc file gốc, split và tạo DB mới
    if not splits:
        logging.info("--- Đang đọc tài liệu gốc để tạo index mới... ---")
        documents = load_documents_from_folder(DATA_PATH)
        if not documents:
            logging.error("Không tìm thấy tài liệu nào trong thư mục medical_data.")
            return None, []

        # Chunking
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,     # nhỏ hơn cho câu trả lời cụ thể hơn
            chunk_overlap=150
        )
        splits = text_splitter.split_documents(documents)
        logging.info(f"Đã split thành {len(splits)} chunks.")

        # Tạo Chroma mới
        logging.info("--- Đang mã hoá embedding vào ChromaDB... ---")
        vectorstore = Chroma.from_documents(
            documents=splits,
            embedding=embedding_model,
            persist_directory=DB_PATH
        )

    return vectorstore, splits


def get_retriever():
    """
    Khởi tạo retriever:
    - Load hoặc tạo ChromaDB.
    - Tuỳ cấu hình: dùng hybrid (BM25 + Vector) hoặc chỉ Vector.
    - Có thể dùng MMR để tăng đa dạng context.
    """
    logging.info("--- Đang tải model Embedding (HuggingFace) ---")
    embedding_model = HuggingFaceEmbeddings(
        model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
    )

    vectorstore, splits = build_vectorstore_and_corpus(embedding_model)
    if vectorstore is None or not splits:
        logging.error("Không thể khởi tạo retriever vì không có dữ liệu.")
        return None

    # Vector retriever (có thể dùng MMR)
    if USE_MMR:
        logging.info("Sử dụng Vector Retriever với MMR.")
        chroma_retriever = vectorstore.as_retriever(
            search_type="mmr",
            search_kwargs={"k": 8, "lambda_mult": 0.7}
        )
    else:
        logging.info("Sử dụng Vector Retriever với similarity search.")
        chroma_retriever = vectorstore.as_retriever(
            search_kwargs={"k": 8}
        )

    # Nếu không muốn dùng BM25 -> trả về luôn retriever vector
    if not USE_BM25:
        logging.info("Chỉ dùng retriever Vector (không dùng BM25).")
        return chroma_retriever

    # Tạo BM25 retriever từ corpus
    logging.info("Khởi tạo BM25 Retriever từ corpus (hybrid search).")
    bm25_retriever = BM25Retriever.from_documents(splits)
    bm25_retriever.k = 8

    # Ensemble retriever: BM25 + Vector
    ensemble_retriever = EnsembleRetriever(
        retrievers=[bm25_retriever, chroma_retriever],
        weights=[0.4, 0.6]
    )
    logging.info("Đã khởi tạo Ensemble Retriever (BM25 + Vector).")
    return ensemble_retriever


# =====================================================
# 6. LỚP CHATBOT DEEPMED
# =====================================================

class DeepMedBot:
    def __init__(self):
        self.retriever = None
        self.rag_chain = None
        self.ready = False

        if not GOOGLE_API_KEY:
            logging.error("GOOGLE_API_KEY chưa được thiết lập trong biến môi trường.")
            return

        try:
            self.retriever = get_retriever()
            if self.retriever is None:
                logging.error("Không khởi tạo được retriever.")
                return

            self.llm = ChatGoogleGenerativeAI(
                model="gemini-2.5-flash",   # có thể đổi sang model mạnh hơn nếu muốn
                temperature=0.3,
                google_api_key=GOOGLE_API_KEY
            )
            self._build_chains()
            self.ready = True
            logging.info("--- DeepMedBot đã sẵn sàng ---")
        except Exception as e:
            logging.error(f"Lỗi khởi tạo DeepMedBot: {e}")
            logging.debug(traceback.format_exc())
            self.ready = False

    def _build_chains(self):
        # 1. Prompt contextualize question
        contextualize_q_system_prompt = (
            "Dựa trên lịch sử trò chuyện và câu hỏi mới nhất của người dùng, "
            "nếu câu hỏi liên quan đến ngữ cảnh trước đó, hãy viết lại nó thành một câu hỏi độc lập đầy đủ ý nghĩa. "
            "Nếu không liên quan, giữ nguyên câu hỏi gốc. KHÔNG trả lời câu hỏi, chỉ viết lại."
        )
        contextualize_q_prompt = ChatPromptTemplate.from_messages([
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ])

        history_aware_retriever = create_history_aware_retriever(
            self.llm, self.retriever, contextualize_q_prompt
        )

        # 2. Prompt trả lời y khoa an toàn
        qa_system_prompt = (
            "Bạn là trợ lý y tế DeepMed. Bạn hỗ trợ giải thích thông tin y khoa dựa trên tài liệu được cung cấp (Context) "
            "và kiến thức y khoa chung của bạn.\n\n"
            "QUY TẮC AN TOÀN:\n"
            "- Nếu Context không đủ để trả lời chính xác, hãy nói bạn không chắc chắn và khuyên người dùng hỏi bác sĩ.\n"
            "- Ưu tiên sử dụng thông tin trong Context. Nếu có mâu thuẫn giữa Context và hiểu biết của bạn, "
            "hãy nói rõ và khuyên tham khảo nguồn chính thống mới nhất.\n\n"
            "CÁCH TRẢ LỜI:\n"
            "- Trả lời ngắn gọn, dễ hiểu.\n"
            "- Nếu phù hợp, chia câu trả lời thành các mục: Tóm tắt, Chi tiết, Lưu ý.\n\n"
            "Context:\n{context}"
        )

        qa_prompt = ChatPromptTemplate.from_messages([
            ("system", qa_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ])

        question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)

        # 3. Gộp thành RAG chain
        self.rag_chain = create_retrieval_chain(
            history_aware_retriever,
            question_answer_chain
        )

    def chat(self, message: str, history: list[list[str]]):
        """
        Hàm dùng cho Gradio:
        - history: list[[user_msg, bot_msg], ...]
        """
        if not self.ready:
            return "❗ Hệ thống chưa sẵn sàng. Vui lòng kiểm tra lại API key và dữ liệu (medical_data)."

        # Giới hạn lịch sử gửi vào LLM cho đỡ nặng
        if len(history) > MAX_HISTORY_TURNS:
            history = history[-MAX_HISTORY_TURNS:]

        # Chuyển lịch sử Gradio sang LangChain messages
        chat_history = []
        for user_msg, bot_msg in history:
            chat_history.append(HumanMessage(content=user_msg))
            chat_history.append(AIMessage(content=bot_msg))

        try:
            response = self.rag_chain.invoke({
                "input": message,
                "chat_history": chat_history
            })

            answer = response.get("answer", "")

            # Xử lý trích dẫn nguồn
            references_text = self._build_references_text(response)
            if references_text:
                answer += "\n\n---\n📚 **Tài liệu tham khảo:**\n" + references_text

            return answer

        except Exception as e:
            logging.error(f"Lỗi khi xử lý chat: {e}")
            logging.debug(traceback.format_exc())
            return "🤖 Xin lỗi, hệ thống gặp lỗi nội bộ. Bạn hãy thử lại sau ít phút."

    @staticmethod
    def _build_references_text(response) -> str:
        """
        Gom nguồn trích dẫn:
        - Gộp theo tên file.
        - Nếu có số trang, liệt kê danh sách trang.
        """
        from collections import defaultdict
        if "context" not in response:
            return ""

        source_pages = defaultdict(set)  # {source_name: {page1, page2, ...}}

        for doc in response["context"]:
            src = os.path.basename(doc.metadata.get("source", "Tài liệu không tên"))
            page = doc.metadata.get("page", None)
            if page is not None:
                source_pages[src].add(page + 1)
            else:
                # Đảm bảo vẫn tạo key nếu chưa có trang
                _ = source_pages[src]

        lines = []
        for src, pages in source_pages.items():
            if pages:
                pages_str = ", ".join(str(p) for p in sorted(pages))
                lines.append(f"- {src} (Trang {pages_str})")
            else:
                lines.append(f"- {src}")

        return "\n".join(lines)



bot = DeepMedBot()

def gradio_chat(message, history):
    return bot.chat(message, history)


demo = gr.ChatInterface(
    fn=gradio_chat,
    title="🏥 DeepMed AI Pro - Trợ lý Y tế tại Trung Tâm Y Tế khu vực Thanh Ba",
    description=(
        "Chatbot hỗ trợ tra cứu thông tin y khoa từ kho tài liệu nội bộ.\n"
    ),
)

if __name__ == "__main__":
    demo.launch()