File size: 5,322 Bytes
0fb0b85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import os
import time
from pathlib import Path

from datasets import load_dataset
from dotenv import load_dotenv
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings

CHROMA_DIR = Path(__file__).parent / "chroma_db"
DATASET_NAME = "undertheseanlp/UTS_VLC"
DATASET_SPLIT = "2026"
MAX_CHUNKS = 1000

SYSTEM_PROMPT = """\
Bạn là trợ lý pháp luật Việt Nam. Hãy trả lời câu hỏi dựa trên các văn bản luật được cung cấp bên dưới.

Quy tắc:
- Chỉ trả lời dựa trên nội dung trong context được cung cấp.
- Trích dẫn tên văn bản luật, điều khoản cụ thể khi có thể.
- Nếu context không đủ thông tin, hãy nói rõ rằng không tìm thấy thông tin liên quan.
- Trả lời bằng tiếng Việt.

Context:
{context}"""


def load_env():
    load_dotenv(Path(__file__).parent / ".env")


def load_documents():
    print("Đang tải dữ liệu từ HuggingFace...")
    ds = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
    documents = []
    for row in ds:
        content = row.get("content", "")
        if not content:
            continue
        metadata = {k: v for k, v in row.items() if k != "content" and v is not None}
        documents.append(Document(page_content=content, metadata=metadata))
    print(f"Đã tải {len(documents)} văn bản luật.")
    return documents


def chunk_documents(documents):
    print("Đang chia nhỏ văn bản...")
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
    )
    chunks = splitter.split_documents(documents)
    print(f"Đã tạo {len(chunks)} chunks.")
    return chunks


def get_embeddings():
    return AzureOpenAIEmbeddings(
        azure_deployment=os.environ["AZURE_OPENAI_EMBEDDING_MODEL"],
        azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
        api_key=os.environ["AZURE_OPENAI_API_KEY"],
        api_version=os.environ["AZURE_OPENAI_API_VERSION"],
    )


def build_vectorstore():
    embeddings = get_embeddings()

    if CHROMA_DIR.exists() and any(CHROMA_DIR.iterdir()):
        print("Đang tải vector store từ ổ đĩa...")
        vectorstore = Chroma(
            persist_directory=str(CHROMA_DIR),
            embedding_function=embeddings,
        )
        count = vectorstore._collection.count()
        print(f"Đã tải vector store với {count} chunks.")
        return vectorstore

    documents = load_documents()
    chunks = chunk_documents(documents)
    chunks = chunks[:MAX_CHUNKS]
    print(f"Sử dụng {len(chunks)} chunks (sample).")

    print("Đang tạo vector store...")
    batch_size = 100
    vectorstore = Chroma(
        persist_directory=str(CHROMA_DIR),
        embedding_function=embeddings,
    )
    for i in range(0, len(chunks), batch_size):
        batch = chunks[i : i + batch_size]
        for attempt in range(5):
            try:
                vectorstore.add_documents(batch)
                break
            except Exception as e:
                if "429" in str(e) or "RateLimit" in str(e):
                    wait = 60 * (attempt + 1)
                    print(f"  Rate limit, đợi {wait}s...")
                    time.sleep(wait)
                else:
                    raise
        print(f"  Đã index {min(i + batch_size, len(chunks))}/{len(chunks)} chunks")
    print("Đã tạo vector store thành công.")
    return vectorstore


def build_rag_chain(vectorstore):
    retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

    llm = AzureChatOpenAI(
        azure_deployment=os.environ["AZURE_OPENAI_DEPLOYMENT"],
        azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
        api_key=os.environ["AZURE_OPENAI_API_KEY"],
        api_version=os.environ["AZURE_OPENAI_API_VERSION"],
    )

    prompt = ChatPromptTemplate.from_messages([
        ("system", SYSTEM_PROMPT),
        ("human", "{question}"),
    ])

    def format_docs(docs):
        return "\n\n---\n\n".join(
            f"[{doc.metadata.get('title', 'Không rõ')}]\n{doc.page_content}"
            for doc in docs
        )

    chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return chain


def main():
    load_env()

    print("=" * 60)
    print("  RAG Agent - Hỏi đáp Pháp luật Việt Nam")
    print("=" * 60)

    vectorstore = build_vectorstore()
    chain = build_rag_chain(vectorstore)

    print("\nSẵn sàng! Nhập câu hỏi về pháp luật Việt Nam (gõ 'quit' để thoát).\n")

    while True:
        question = input("Câu hỏi: ").strip()
        if not question:
            continue
        if question.lower() in ("quit", "exit", "q"):
            print("Tạm biệt!")
            break

        print("\nĐang tìm kiếm và trả lời...\n")
        answer = chain.invoke(question)
        print(answer)
        print()


if __name__ == "__main__":
    main()