File size: 5,322 Bytes
0fb0b85 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 | import os
import time
from pathlib import Path
from datasets import load_dataset
from dotenv import load_dotenv
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
CHROMA_DIR = Path(__file__).parent / "chroma_db"
DATASET_NAME = "undertheseanlp/UTS_VLC"
DATASET_SPLIT = "2026"
MAX_CHUNKS = 1000
SYSTEM_PROMPT = """\
Bạn là trợ lý pháp luật Việt Nam. Hãy trả lời câu hỏi dựa trên các văn bản luật được cung cấp bên dưới.
Quy tắc:
- Chỉ trả lời dựa trên nội dung trong context được cung cấp.
- Trích dẫn tên văn bản luật, điều khoản cụ thể khi có thể.
- Nếu context không đủ thông tin, hãy nói rõ rằng không tìm thấy thông tin liên quan.
- Trả lời bằng tiếng Việt.
Context:
{context}"""
def load_env():
load_dotenv(Path(__file__).parent / ".env")
def load_documents():
print("Đang tải dữ liệu từ HuggingFace...")
ds = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
documents = []
for row in ds:
content = row.get("content", "")
if not content:
continue
metadata = {k: v for k, v in row.items() if k != "content" and v is not None}
documents.append(Document(page_content=content, metadata=metadata))
print(f"Đã tải {len(documents)} văn bản luật.")
return documents
def chunk_documents(documents):
print("Đang chia nhỏ văn bản...")
splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200,
)
chunks = splitter.split_documents(documents)
print(f"Đã tạo {len(chunks)} chunks.")
return chunks
def get_embeddings():
return AzureOpenAIEmbeddings(
azure_deployment=os.environ["AZURE_OPENAI_EMBEDDING_MODEL"],
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
api_key=os.environ["AZURE_OPENAI_API_KEY"],
api_version=os.environ["AZURE_OPENAI_API_VERSION"],
)
def build_vectorstore():
embeddings = get_embeddings()
if CHROMA_DIR.exists() and any(CHROMA_DIR.iterdir()):
print("Đang tải vector store từ ổ đĩa...")
vectorstore = Chroma(
persist_directory=str(CHROMA_DIR),
embedding_function=embeddings,
)
count = vectorstore._collection.count()
print(f"Đã tải vector store với {count} chunks.")
return vectorstore
documents = load_documents()
chunks = chunk_documents(documents)
chunks = chunks[:MAX_CHUNKS]
print(f"Sử dụng {len(chunks)} chunks (sample).")
print("Đang tạo vector store...")
batch_size = 100
vectorstore = Chroma(
persist_directory=str(CHROMA_DIR),
embedding_function=embeddings,
)
for i in range(0, len(chunks), batch_size):
batch = chunks[i : i + batch_size]
for attempt in range(5):
try:
vectorstore.add_documents(batch)
break
except Exception as e:
if "429" in str(e) or "RateLimit" in str(e):
wait = 60 * (attempt + 1)
print(f" Rate limit, đợi {wait}s...")
time.sleep(wait)
else:
raise
print(f" Đã index {min(i + batch_size, len(chunks))}/{len(chunks)} chunks")
print("Đã tạo vector store thành công.")
return vectorstore
def build_rag_chain(vectorstore):
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
llm = AzureChatOpenAI(
azure_deployment=os.environ["AZURE_OPENAI_DEPLOYMENT"],
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
api_key=os.environ["AZURE_OPENAI_API_KEY"],
api_version=os.environ["AZURE_OPENAI_API_VERSION"],
)
prompt = ChatPromptTemplate.from_messages([
("system", SYSTEM_PROMPT),
("human", "{question}"),
])
def format_docs(docs):
return "\n\n---\n\n".join(
f"[{doc.metadata.get('title', 'Không rõ')}]\n{doc.page_content}"
for doc in docs
)
chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
return chain
def main():
load_env()
print("=" * 60)
print(" RAG Agent - Hỏi đáp Pháp luật Việt Nam")
print("=" * 60)
vectorstore = build_vectorstore()
chain = build_rag_chain(vectorstore)
print("\nSẵn sàng! Nhập câu hỏi về pháp luật Việt Nam (gõ 'quit' để thoát).\n")
while True:
question = input("Câu hỏi: ").strip()
if not question:
continue
if question.lower() in ("quit", "exit", "q"):
print("Tạm biệt!")
break
print("\nĐang tìm kiếm và trả lời...\n")
answer = chain.invoke(question)
print(answer)
print()
if __name__ == "__main__":
main()
|