Food_RAG_AgenticAI / rag_tool.py
NTThong0710
change
eab7c1b
import os
import time
from langchain_community.document_loaders import (
PyPDFLoader,
TextLoader,
CSVLoader
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_community.vectorstores import Chroma
# =========================
# PATH
# =========================
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
CHROMA_DB_DIR = os.path.join(
os.path.dirname(__file__),
"chroma_db"
)
# =========================
# LOAD DOCUMENTS
# =========================
def load_documents():
"""Đọc toàn bộ tài liệu trong thư mục data."""
documents = []
if not os.path.exists(DATA_DIR):
print(f"Thư mục {DATA_DIR} không tồn tại!")
return documents
for filename in os.listdir(DATA_DIR):
file_path = os.path.join(DATA_DIR, filename)
try:
if filename.endswith(".pdf"):
loader = PyPDFLoader(file_path)
documents.extend(loader.load())
elif filename.endswith(".txt"):
loader = TextLoader(
file_path,
encoding="utf-8"
)
documents.extend(loader.load())
elif filename.endswith(".csv"):
loader = CSVLoader(
file_path,
encoding="utf-8"
)
documents.extend(loader.load())
except Exception as e:
print(f"Lỗi đọc file {filename}: {e}")
print(f"Đã tải {len(documents)} trang/phần từ tài liệu.")
return documents
# =========================
# EMBEDDING MODEL
# =========================
def get_embeddings():
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
raise ValueError(
"Thiếu GEMINI_API_KEY trong biến môi trường."
)
embeddings = GoogleGenerativeAIEmbeddings(
model="models/gemini-embedding-001",
google_api_key=api_key
)
return embeddings
# =========================
# LOAD VECTOR DB
# =========================
def get_vectorstore():
"""
Load ChromaDB đã build sẵn.
Production KHÔNG tự build lại DB.
"""
if not os.path.exists(CHROMA_DB_DIR):
raise Exception(
"Không tìm thấy thư mục chroma_db. "
"Hãy build DB trước rồi upload lên server."
)
print("Đang tải Vector Database hiện có...")
vectorstore = Chroma(
persist_directory=CHROMA_DB_DIR,
embedding_function=get_embeddings()
)
return vectorstore
# =========================
# CREATE RETRIEVER
# =========================
def create_rag_retriever():
vectorstore = get_vectorstore()
# Sử dụng thuật toán MMR thay vì Similarity thuần túy
# Lấy ra 10 kết quả liên quan nhất, sau đó chọn lại 4 kết quả đa dạng nhất
retriever = vectorstore.as_retriever(
search_type="mmr",
search_kwargs={"k": 4, "fetch_k": 10}
)
return retriever
# =========================
# FILE LIST
# =========================
def get_list_files():
if not os.path.exists(DATA_DIR):
return []
files_list = []
for filename in os.listdir(DATA_DIR):
file_path = os.path.join(DATA_DIR, filename)
if os.path.isfile(file_path):
size_kb = os.path.getsize(file_path) / 1024
if size_kb > 1024:
size_str = f"{size_kb / 1024:.2f} MB"
else:
size_str = f"{size_kb:.2f} KB"
files_list.append({
"name": filename,
"size": size_str
})
return files_list
# =========================
# DELETE FILE
# =========================
def delete_file(filename: str):
file_path = os.path.join(DATA_DIR, filename)
if os.path.exists(file_path):
os.remove(file_path)
return True
return False
# =========================
# PROCESS NEW FILE
# =========================
def process_new_file(filename: str):
file_path = os.path.join(DATA_DIR, filename)
if not os.path.exists(file_path):
return False
final_chunks = []
try:
if filename.endswith(".pdf"):
loader = PyPDFLoader(file_path)
docs = loader.load()
# PDF: Chia nhỏ vừa phải, có overlap
splitter = RecursiveCharacterTextSplitter(
chunk_size=700, chunk_overlap=150,
separators=["\n\n", "\n", ".", " ", ""]
)
final_chunks.extend(splitter.split_documents(docs))
elif filename.endswith(".txt"):
loader = TextLoader(file_path, encoding="utf-8")
docs = loader.load()
# TXT Công thức: Ưu tiên ôm trọn 1 món, cắt theo đoạn
splitter = RecursiveCharacterTextSplitter(
chunk_size=1500, chunk_overlap=0,
separators=["\n\n", "\n"]
)
final_chunks.extend(splitter.split_documents(docs))
elif filename.endswith(".csv"):
loader = CSVLoader(file_path, encoding="utf-8")
# CSV: Giữ nguyên từng dòng là 1 chunk, KHÔNG dùng text_splitter
final_chunks.extend(loader.load())
except Exception as e:
print(f"Lỗi đọc file mới: {e}")
return False
if not final_chunks:
return False
print(f"Đã xử lý {filename} thành {len(final_chunks)} chunks.")
vectorstore = get_vectorstore()
vectorstore.add_documents(final_chunks)
print("Đã thêm dữ liệu vào ChromaDB.")
return True
# =========================
# BUILD DB LOCAL ONLY
# =========================
def build_vector_database():
documents = load_documents()
if not documents:
raise Exception("Không có tài liệu để embed.")
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=700,
chunk_overlap=150,
separators=["\n\n", "\n", ".", " ", ""]
)
chunks = text_splitter.split_documents(documents)
print(f"Đã chia thành {len(chunks)} chunks.")
embeddings = get_embeddings()
print("Đang tạo ChromaDB với chiến lược Auto-Retry...")
vectorstore = Chroma(
persist_directory=CHROMA_DB_DIR,
embedding_function=embeddings
)
# Đảm bảo nhỏ hơn 100 để an toàn tuyệt đối
BATCH_SIZE = 95
for i in range(0, len(chunks), BATCH_SIZE):
batch = chunks[i : i + BATCH_SIZE]
print(f"⏳ Đang nhúng batch từ {i} đến {i + len(batch)}...")
# Cơ chế tự động thử lại nếu bị Google chặn
max_retries = 3
for attempt in range(max_retries):
try:
vectorstore.add_documents(batch)
# Nếu chưa phải batch cuối, bắt buộc ngủ 60s để reset Quota 1 phút
if i + BATCH_SIZE < len(chunks):
print(f"✅ Xong mẻ. Đang nghỉ 60s để nạp lại Quota...")
time.sleep(60)
break # Thành công thì thoát vòng lặp retry
except Exception as e:
error_msg = str(e)
if "429" in error_msg or "RESOURCE_EXHAUSTED" in error_msg:
print(f"⚠️ Quá tải API! Đang chờ 65 giây để thử lại (Lần {attempt + 1}/{max_retries})...")
time.sleep(65)
else:
# Nếu là lỗi khác (như mất mạng) thì văng lỗi ra
raise e
print("🎉 Tạo ChromaDB thành công!")
return vectorstore
# =========================
# MAIN
# =========================
if __name__ == "__main__":
env_path = os.path.join(
os.path.dirname(__file__),
"..",
".env"
)
if os.path.exists(env_path):
from dotenv import load_dotenv
load_dotenv(env_path)
build_vector_database()