| import os |
| import time |
| from pathlib import Path |
|
|
| from datasets import load_dataset |
| from dotenv import load_dotenv |
| from langchain_text_splitters import RecursiveCharacterTextSplitter |
| from langchain_chroma import Chroma |
| from langchain_core.documents import Document |
| from langchain_core.output_parsers import StrOutputParser |
| from langchain_core.prompts import ChatPromptTemplate |
| from langchain_core.runnables import RunnablePassthrough |
| from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings |
|
|
| CHROMA_DIR = Path(__file__).parent / "chroma_db" |
| DATASET_NAME = "undertheseanlp/UTS_VLC" |
| DATASET_SPLIT = "2026" |
| MAX_CHUNKS = 1000 |
|
|
| SYSTEM_PROMPT = """\ |
| Bạn là trợ lý pháp luật Việt Nam. Hãy trả lời câu hỏi dựa trên các văn bản luật được cung cấp bên dưới. |
| |
| Quy tắc: |
| - Chỉ trả lời dựa trên nội dung trong context được cung cấp. |
| - Trích dẫn tên văn bản luật, điều khoản cụ thể khi có thể. |
| - Nếu context không đủ thông tin, hãy nói rõ rằng không tìm thấy thông tin liên quan. |
| - Trả lời bằng tiếng Việt. |
| |
| Context: |
| {context}""" |
|
|
|
|
| def load_env(): |
| load_dotenv(Path(__file__).parent / ".env") |
|
|
|
|
| def load_documents(): |
| print("Đang tải dữ liệu từ HuggingFace...") |
| ds = load_dataset(DATASET_NAME, split=DATASET_SPLIT) |
| documents = [] |
| for row in ds: |
| content = row.get("content", "") |
| if not content: |
| continue |
| metadata = {k: v for k, v in row.items() if k != "content" and v is not None} |
| documents.append(Document(page_content=content, metadata=metadata)) |
| print(f"Đã tải {len(documents)} văn bản luật.") |
| return documents |
|
|
|
|
| def chunk_documents(documents): |
| print("Đang chia nhỏ văn bản...") |
| splitter = RecursiveCharacterTextSplitter( |
| chunk_size=1000, |
| chunk_overlap=200, |
| ) |
| chunks = splitter.split_documents(documents) |
| print(f"Đã tạo {len(chunks)} chunks.") |
| return chunks |
|
|
|
|
| def get_embeddings(): |
| return AzureOpenAIEmbeddings( |
| azure_deployment=os.environ["AZURE_OPENAI_EMBEDDING_MODEL"], |
| azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"], |
| api_key=os.environ["AZURE_OPENAI_API_KEY"], |
| api_version=os.environ["AZURE_OPENAI_API_VERSION"], |
| ) |
|
|
|
|
| def build_vectorstore(): |
| embeddings = get_embeddings() |
|
|
| if CHROMA_DIR.exists() and any(CHROMA_DIR.iterdir()): |
| print("Đang tải vector store từ ổ đĩa...") |
| vectorstore = Chroma( |
| persist_directory=str(CHROMA_DIR), |
| embedding_function=embeddings, |
| ) |
| count = vectorstore._collection.count() |
| print(f"Đã tải vector store với {count} chunks.") |
| return vectorstore |
|
|
| documents = load_documents() |
| chunks = chunk_documents(documents) |
| chunks = chunks[:MAX_CHUNKS] |
| print(f"Sử dụng {len(chunks)} chunks (sample).") |
|
|
| print("Đang tạo vector store...") |
| batch_size = 100 |
| vectorstore = Chroma( |
| persist_directory=str(CHROMA_DIR), |
| embedding_function=embeddings, |
| ) |
| for i in range(0, len(chunks), batch_size): |
| batch = chunks[i : i + batch_size] |
| for attempt in range(5): |
| try: |
| vectorstore.add_documents(batch) |
| break |
| except Exception as e: |
| if "429" in str(e) or "RateLimit" in str(e): |
| wait = 60 * (attempt + 1) |
| print(f" Rate limit, đợi {wait}s...") |
| time.sleep(wait) |
| else: |
| raise |
| print(f" Đã index {min(i + batch_size, len(chunks))}/{len(chunks)} chunks") |
| print("Đã tạo vector store thành công.") |
| return vectorstore |
|
|
|
|
| def build_rag_chain(vectorstore): |
| retriever = vectorstore.as_retriever(search_kwargs={"k": 5}) |
|
|
| llm = AzureChatOpenAI( |
| azure_deployment=os.environ["AZURE_OPENAI_DEPLOYMENT"], |
| azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"], |
| api_key=os.environ["AZURE_OPENAI_API_KEY"], |
| api_version=os.environ["AZURE_OPENAI_API_VERSION"], |
| ) |
|
|
| prompt = ChatPromptTemplate.from_messages([ |
| ("system", SYSTEM_PROMPT), |
| ("human", "{question}"), |
| ]) |
|
|
| def format_docs(docs): |
| return "\n\n---\n\n".join( |
| f"[{doc.metadata.get('title', 'Không rõ')}]\n{doc.page_content}" |
| for doc in docs |
| ) |
|
|
| chain = ( |
| {"context": retriever | format_docs, "question": RunnablePassthrough()} |
| | prompt |
| | llm |
| | StrOutputParser() |
| ) |
| return chain |
|
|
|
|
| def main(): |
| load_env() |
|
|
| print("=" * 60) |
| print(" RAG Agent - Hỏi đáp Pháp luật Việt Nam") |
| print("=" * 60) |
|
|
| vectorstore = build_vectorstore() |
| chain = build_rag_chain(vectorstore) |
|
|
| print("\nSẵn sàng! Nhập câu hỏi về pháp luật Việt Nam (gõ 'quit' để thoát).\n") |
|
|
| while True: |
| question = input("Câu hỏi: ").strip() |
| if not question: |
| continue |
| if question.lower() in ("quit", "exit", "q"): |
| print("Tạm biệt!") |
| break |
|
|
| print("\nĐang tìm kiếm và trả lời...\n") |
| answer = chain.invoke(question) |
| print(answer) |
| print() |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|