# RAGSample / rag_core/retriever.py
# Author: VietCat — commit 23597f5 ("update gemini")
import faiss
import numpy as np
import os
import pickle
import logging
from rag_core.utils import log_timed
# ✅ Use a relative directory so the path is resolved inside the app's working
# directory and is not mistaken for the root-level /data directory.
INDEX_PATH = "faiss_index/index.faiss"
META_PATH = "faiss_index/meta.pkl"
class Retriever:
    """FAISS-backed dense retriever over a list of text chunks.

    The vector index is persisted to ``INDEX_PATH`` (via ``faiss.write_index``)
    and the parallel list of chunk texts to ``META_PATH`` (pickle).  The two
    files must stay in sync: position ``i`` in the index corresponds to
    ``self.texts[i]``.
    """

    def __init__(self):
        # Load a previously persisted index.  BOTH files must exist — an
        # index without its meta pickle (or vice versa) is unusable, so in
        # that case start uninitialized instead of crashing on open().
        if os.path.exists(INDEX_PATH) and os.path.exists(META_PATH):
            logging.info(f"✅ Đã tìm thấy index: {INDEX_PATH}")
            self.index = faiss.read_index(INDEX_PATH)
            with open(META_PATH, "rb") as f:
                self.texts = pickle.load(f)
        else:
            logging.info("⚠️ Chưa có index. Cần xây dựng mới.")
            self.index = None
            self.texts = []

    @log_timed("xây FAISS index")
    def build(self, texts: list, embed_fn):
        """Embed *texts*, build a flat-L2 index, and persist index + texts.

        Chunks whose embedding raises are logged (with a 300-char preview)
        and skipped; only successfully embedded chunks are stored, keeping
        index rows and ``self.texts`` aligned.

        Raises:
            RuntimeError: if no chunk could be embedded at all.
        """
        embeddings = []
        valid_texts = []
        for i, t in enumerate(texts):
            try:
                emb = embed_fn(t)
            except Exception as e:
                logging.warning(f"❌ Lỗi embedding chunk {i}: {e}\nNội dung chunk: {t[:300]}{'...' if len(t) > 300 else ''}")
                continue
            embeddings.append(emb)
            valid_texts.append(t)
        if not embeddings:
            raise RuntimeError("Không có embedding nào thành công!")
        dim = len(embeddings[0])
        self.index = faiss.IndexFlatL2(dim)
        self.index.add(np.array(embeddings).astype("float32"))
        os.makedirs(os.path.dirname(INDEX_PATH), exist_ok=True)
        faiss.write_index(self.index, INDEX_PATH)
        with open(META_PATH, "wb") as f:
            pickle.dump(valid_texts, f)
        self.texts = valid_texts

    @log_timed("truy vấn FAISS")
    def query(self, query_text, embed_fn, k=3):
        """Return up to *k* stored chunks nearest to *query_text*.

        Raises:
            RuntimeError: if the index has not been built yet.
        """
        if self.index is None:
            raise RuntimeError("FAISS index chưa được xây dựng. Hãy build index trước khi truy vấn.")
        q_emb = np.array([embed_fn(query_text)]).astype("float32")
        _, I = self.index.search(q_emb, k)
        # FAISS pads the result with -1 when k exceeds the number of stored
        # vectors; without this filter, self.texts[-1] would silently return
        # the last chunk as a bogus match.
        return [self.texts[i] for i in I[0] if 0 <= i < len(self.texts)]

    @log_timed("bổ sung embedding bị thiếu")
    def rescan_and_append(self, full_texts, embed_fn):
        """Embed and append any chunks in *full_texts* not already indexed.

        Deduplicates by exact text equality against the stored chunks.
        Failed embeddings are logged and skipped; the index and meta files
        are rewritten only when at least one new chunk was added.

        Raises:
            RuntimeError: if the index has not been built yet.
        """
        if self.index is None:
            raise RuntimeError("FAISS index chưa được xây dựng. Hãy build index trước khi bổ sung embedding.")
        existing_set = set(self.texts)
        new_texts = [t for t in full_texts if t not in existing_set]
        if not new_texts:
            logging.info("📭 Không có chunk mới để thêm.")
            return
        new_embeddings = []
        for i, t in enumerate(new_texts):
            try:
                emb = embed_fn(t)
            except Exception as e:
                logging.warning(f"❌ Lỗi embedding chunk mới {i}: {e}\nNội dung chunk: {t[:300]}{'...' if len(t) > 300 else ''}")
                continue
            # Append the text only after its embedding succeeded, so index
            # rows and self.texts stay aligned.
            new_embeddings.append(emb)
            self.texts.append(t)
        if new_embeddings:
            self.index.add(np.array(new_embeddings).astype("float32"))
            faiss.write_index(self.index, INDEX_PATH)
            with open(META_PATH, "wb") as f:
                pickle.dump(self.texts, f)