agents / main.py
rain1024's picture
Initial commit: RAG Agent chat UI for Vietnamese law
0fb0b85
import os
import time
from pathlib import Path
from datasets import load_dataset
from dotenv import load_dotenv
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
CHROMA_DIR = Path(__file__).parent / "chroma_db"
DATASET_NAME = "undertheseanlp/UTS_VLC"
DATASET_SPLIT = "2026"
MAX_CHUNKS = 1000
SYSTEM_PROMPT = """\
Bạn là trợ lý pháp luật Việt Nam. Hãy trả lời câu hỏi dựa trên các văn bản luật được cung cấp bên dưới.
Quy tắc:
- Chỉ trả lời dựa trên nội dung trong context được cung cấp.
- Trích dẫn tên văn bản luật, điều khoản cụ thể khi có thể.
- Nếu context không đủ thông tin, hãy nói rõ rằng không tìm thấy thông tin liên quan.
- Trả lời bằng tiếng Việt.
Context:
{context}"""
def load_env():
load_dotenv(Path(__file__).parent / ".env")
def load_documents():
print("Đang tải dữ liệu từ HuggingFace...")
ds = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
documents = []
for row in ds:
content = row.get("content", "")
if not content:
continue
metadata = {k: v for k, v in row.items() if k != "content" and v is not None}
documents.append(Document(page_content=content, metadata=metadata))
print(f"Đã tải {len(documents)} văn bản luật.")
return documents
def chunk_documents(documents):
print("Đang chia nhỏ văn bản...")
splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200,
)
chunks = splitter.split_documents(documents)
print(f"Đã tạo {len(chunks)} chunks.")
return chunks
def get_embeddings():
return AzureOpenAIEmbeddings(
azure_deployment=os.environ["AZURE_OPENAI_EMBEDDING_MODEL"],
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
api_key=os.environ["AZURE_OPENAI_API_KEY"],
api_version=os.environ["AZURE_OPENAI_API_VERSION"],
)
def build_vectorstore():
embeddings = get_embeddings()
if CHROMA_DIR.exists() and any(CHROMA_DIR.iterdir()):
print("Đang tải vector store từ ổ đĩa...")
vectorstore = Chroma(
persist_directory=str(CHROMA_DIR),
embedding_function=embeddings,
)
count = vectorstore._collection.count()
print(f"Đã tải vector store với {count} chunks.")
return vectorstore
documents = load_documents()
chunks = chunk_documents(documents)
chunks = chunks[:MAX_CHUNKS]
print(f"Sử dụng {len(chunks)} chunks (sample).")
print("Đang tạo vector store...")
batch_size = 100
vectorstore = Chroma(
persist_directory=str(CHROMA_DIR),
embedding_function=embeddings,
)
for i in range(0, len(chunks), batch_size):
batch = chunks[i : i + batch_size]
for attempt in range(5):
try:
vectorstore.add_documents(batch)
break
except Exception as e:
if "429" in str(e) or "RateLimit" in str(e):
wait = 60 * (attempt + 1)
print(f" Rate limit, đợi {wait}s...")
time.sleep(wait)
else:
raise
print(f" Đã index {min(i + batch_size, len(chunks))}/{len(chunks)} chunks")
print("Đã tạo vector store thành công.")
return vectorstore
def build_rag_chain(vectorstore):
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
llm = AzureChatOpenAI(
azure_deployment=os.environ["AZURE_OPENAI_DEPLOYMENT"],
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
api_key=os.environ["AZURE_OPENAI_API_KEY"],
api_version=os.environ["AZURE_OPENAI_API_VERSION"],
)
prompt = ChatPromptTemplate.from_messages([
("system", SYSTEM_PROMPT),
("human", "{question}"),
])
def format_docs(docs):
return "\n\n---\n\n".join(
f"[{doc.metadata.get('title', 'Không rõ')}]\n{doc.page_content}"
for doc in docs
)
chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
return chain
def main():
load_env()
print("=" * 60)
print(" RAG Agent - Hỏi đáp Pháp luật Việt Nam")
print("=" * 60)
vectorstore = build_vectorstore()
chain = build_rag_chain(vectorstore)
print("\nSẵn sàng! Nhập câu hỏi về pháp luật Việt Nam (gõ 'quit' để thoát).\n")
while True:
question = input("Câu hỏi: ").strip()
if not question:
continue
if question.lower() in ("quit", "exit", "q"):
print("Tạm biệt!")
break
print("\nĐang tìm kiếm và trả lời...\n")
answer = chain.invoke(question)
print(answer)
print()
if __name__ == "__main__":
main()