import os import time from langchain_community.document_loaders import ( PyPDFLoader, TextLoader, CSVLoader ) from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_google_genai import GoogleGenerativeAIEmbeddings from langchain_community.vectorstores import Chroma # ========================= # PATH # ========================= DATA_DIR = os.path.join(os.path.dirname(__file__), "data") CHROMA_DB_DIR = os.path.join( os.path.dirname(__file__), "chroma_db" ) # ========================= # LOAD DOCUMENTS # ========================= def load_documents(): """Đọc toàn bộ tài liệu trong thư mục data.""" documents = [] if not os.path.exists(DATA_DIR): print(f"Thư mục {DATA_DIR} không tồn tại!") return documents for filename in os.listdir(DATA_DIR): file_path = os.path.join(DATA_DIR, filename) try: if filename.endswith(".pdf"): loader = PyPDFLoader(file_path) documents.extend(loader.load()) elif filename.endswith(".txt"): loader = TextLoader( file_path, encoding="utf-8" ) documents.extend(loader.load()) elif filename.endswith(".csv"): loader = CSVLoader( file_path, encoding="utf-8" ) documents.extend(loader.load()) except Exception as e: print(f"Lỗi đọc file {filename}: {e}") print(f"Đã tải {len(documents)} trang/phần từ tài liệu.") return documents # ========================= # EMBEDDING MODEL # ========================= def get_embeddings(): api_key = os.getenv("GEMINI_API_KEY") if not api_key: raise ValueError( "Thiếu GEMINI_API_KEY trong biến môi trường." ) embeddings = GoogleGenerativeAIEmbeddings( model="models/gemini-embedding-001", google_api_key=api_key ) return embeddings # ========================= # LOAD VECTOR DB # ========================= def get_vectorstore(): """ Load ChromaDB đã build sẵn. Production KHÔNG tự build lại DB. """ if not os.path.exists(CHROMA_DB_DIR): raise Exception( "Không tìm thấy thư mục chroma_db. " "Hãy build DB trước rồi upload lên server." ) print("Đang tải Vector Database hiện có...") vectorstore = Chroma( persist_directory=CHROMA_DB_DIR, embedding_function=get_embeddings() ) return vectorstore # ========================= # CREATE RETRIEVER # ========================= def create_rag_retriever(): vectorstore = get_vectorstore() # Sử dụng thuật toán MMR thay vì Similarity thuần túy # Lấy ra 10 kết quả liên quan nhất, sau đó chọn lại 4 kết quả đa dạng nhất retriever = vectorstore.as_retriever( search_type="mmr", search_kwargs={"k": 4, "fetch_k": 10} ) return retriever # ========================= # FILE LIST # ========================= def get_list_files(): if not os.path.exists(DATA_DIR): return [] files_list = [] for filename in os.listdir(DATA_DIR): file_path = os.path.join(DATA_DIR, filename) if os.path.isfile(file_path): size_kb = os.path.getsize(file_path) / 1024 if size_kb > 1024: size_str = f"{size_kb / 1024:.2f} MB" else: size_str = f"{size_kb:.2f} KB" files_list.append({ "name": filename, "size": size_str }) return files_list # ========================= # DELETE FILE # ========================= def delete_file(filename: str): file_path = os.path.join(DATA_DIR, filename) if os.path.exists(file_path): os.remove(file_path) return True return False # ========================= # PROCESS NEW FILE # ========================= def process_new_file(filename: str): file_path = os.path.join(DATA_DIR, filename) if not os.path.exists(file_path): return False final_chunks = [] try: if filename.endswith(".pdf"): loader = PyPDFLoader(file_path) docs = loader.load() # PDF: Chia nhỏ vừa phải, có overlap splitter = RecursiveCharacterTextSplitter( chunk_size=700, chunk_overlap=150, separators=["\n\n", "\n", ".", " ", ""] ) final_chunks.extend(splitter.split_documents(docs)) elif filename.endswith(".txt"): loader = TextLoader(file_path, encoding="utf-8") docs = loader.load() # TXT Công thức: Ưu tiên ôm trọn 1 món, cắt theo đoạn splitter = RecursiveCharacterTextSplitter( chunk_size=1500, chunk_overlap=0, separators=["\n\n", "\n"] ) final_chunks.extend(splitter.split_documents(docs)) elif filename.endswith(".csv"): loader = CSVLoader(file_path, encoding="utf-8") # CSV: Giữ nguyên từng dòng là 1 chunk, KHÔNG dùng text_splitter final_chunks.extend(loader.load()) except Exception as e: print(f"Lỗi đọc file mới: {e}") return False if not final_chunks: return False print(f"Đã xử lý {filename} thành {len(final_chunks)} chunks.") vectorstore = get_vectorstore() vectorstore.add_documents(final_chunks) print("Đã thêm dữ liệu vào ChromaDB.") return True # ========================= # BUILD DB LOCAL ONLY # ========================= def build_vector_database(): documents = load_documents() if not documents: raise Exception("Không có tài liệu để embed.") text_splitter = RecursiveCharacterTextSplitter( chunk_size=700, chunk_overlap=150, separators=["\n\n", "\n", ".", " ", ""] ) chunks = text_splitter.split_documents(documents) print(f"Đã chia thành {len(chunks)} chunks.") embeddings = get_embeddings() print("Đang tạo ChromaDB với chiến lược Auto-Retry...") vectorstore = Chroma( persist_directory=CHROMA_DB_DIR, embedding_function=embeddings ) # Đảm bảo nhỏ hơn 100 để an toàn tuyệt đối BATCH_SIZE = 95 for i in range(0, len(chunks), BATCH_SIZE): batch = chunks[i : i + BATCH_SIZE] print(f"⏳ Đang nhúng batch từ {i} đến {i + len(batch)}...") # Cơ chế tự động thử lại nếu bị Google chặn max_retries = 3 for attempt in range(max_retries): try: vectorstore.add_documents(batch) # Nếu chưa phải batch cuối, bắt buộc ngủ 60s để reset Quota 1 phút if i + BATCH_SIZE < len(chunks): print(f"✅ Xong mẻ. Đang nghỉ 60s để nạp lại Quota...") time.sleep(60) break # Thành công thì thoát vòng lặp retry except Exception as e: error_msg = str(e) if "429" in error_msg or "RESOURCE_EXHAUSTED" in error_msg: print(f"⚠️ Quá tải API! Đang chờ 65 giây để thử lại (Lần {attempt + 1}/{max_retries})...") time.sleep(65) else: # Nếu là lỗi khác (như mất mạng) thì văng lỗi ra raise e print("🎉 Tạo ChromaDB thành công!") return vectorstore # ========================= # MAIN # ========================= if __name__ == "__main__": env_path = os.path.join( os.path.dirname(__file__), "..", ".env" ) if os.path.exists(env_path): from dotenv import load_dotenv load_dotenv(env_path) build_vector_database()