# Corin1998's picture
# Update rag/retriever.py
# 9916250 verified
import os
import pickle
from typing import List, Dict, Any
import numpy as np
import faiss
# 絶対インポートで安定化
from rag.embeddings import embed_texts
class RagStore:
    """FAISS vector index plus document metadata, persisted under one directory.

    Two files are kept in ``index_dir``: ``index.faiss`` (the serialized FAISS
    index) and ``meta.pkl`` (the pickled list of document dicts). ``build``
    adds embeddings in the same order as the documents it stores, so row ``i``
    of the index corresponds to ``meta[i]``.
    """

    def __init__(self, index_dir: str = "data/index"):
        """Remember where the index lives; nothing is read from disk yet."""
        self.index_dir = index_dir
        self.index = None  # FAISS index; stays None until build() or load() succeeds
        self.meta: List[Dict[str, Any]] = []  # original documents (text + metadata)

    def _index_path(self) -> str:
        """Return the path of the serialized FAISS index file."""
        return os.path.join(self.index_dir, "index.faiss")

    def _meta_path(self) -> str:
        """Return the path of the pickled metadata file."""
        return os.path.join(self.index_dir, "meta.pkl")

    def available(self) -> bool:
        """Return True only when both the index file and the metadata file exist."""
        return os.path.exists(self._index_path()) and os.path.exists(self._meta_path())

    def load(self) -> None:
        """Load the FAISS index and metadata into memory if not already loaded.

        Does nothing when both are already in memory, or when the files are
        not present on disk.
        """
        if self.index is not None and self.meta:
            return
        if not self.available():
            return
        index = faiss.read_index(self._index_path())
        # NOTE(security): pickle can execute arbitrary code while loading; only
        # point index_dir at files this application wrote itself.
        with open(self._meta_path(), "rb") as f:
            meta = pickle.load(f)
        # Assign both together so a failed unpickle cannot leave a freshly read
        # index paired with stale or empty metadata.
        self.index = index
        self.meta = meta

    def build(self, docs: List[Dict[str, Any]]) -> None:
        """Build the index from ``docs``, write it to disk and keep it in memory.

        Args:
            docs: Items shaped like ``{'text': str, 'meta': {...}}``; only the
                ``'text'`` key is read here, the whole dict is stored as-is.
        """
        os.makedirs(self.index_dir, exist_ok=True)
        texts = [d["text"] for d in docs]
        # FAISS works on C-contiguous float32 matrices; make that explicit
        # rather than relying on whatever layout embed_texts returns.
        embs = np.ascontiguousarray(embed_texts(texts).astype("float32"))
        index = faiss.IndexFlatIP(embs.shape[1])  # exact inner-product search
        index.add(embs)

        # Persist to disk.
        faiss.write_index(index, self._index_path())
        with open(self._meta_path(), "wb") as f:
            pickle.dump(docs, f)

        # Keep in memory as well so the next search() needs no reload.
        self.index = index
        self.meta = docs

    def search(self, query: str, k: int = 8) -> List[Dict[str, Any]]:
        """Return up to ``k`` best matches for ``query``.

        Each hit is a shallow copy of the stored document dict with an added
        ``'score'`` key holding the FAISS inner-product score as a float.
        Returns an empty list when no index is available or ``k`` is not
        positive.
        """
        self.load()
        if self.index is None or not self.meta:
            return []
        if k <= 0:
            return []

        doc_count = len(self.meta)
        # Asking for more neighbours than stored documents only yields padding.
        top_k = min(k, doc_count)

        q = np.ascontiguousarray(embed_texts([query]).astype("float32"))
        scores, ids = self.index.search(q, top_k)

        hits: List[Dict[str, Any]] = []
        for idx, score in zip(ids[0], scores[0]):
            # FAISS pads missing results with -1; an id past the end of meta
            # means the index and metadata files are out of sync, so skip it
            # instead of raising IndexError in the middle of a query.
            if idx < 0 or idx >= doc_count:
                continue
            item = dict(self.meta[int(idx)])  # copy so the stored doc is not mutated
            item["score"] = float(score)
            hits.append(item)
        return hits