Create retriever.py
Browse files- rag/retriever.py +54 -0
rag/retriever.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
import os, pickle
|
| 4 |
+
import numpy as np
|
| 5 |
+
import faiss
|
| 6 |
+
|
| 7 |
+
from .embeddings import embed_texts
|
| 8 |
+
from .util_text import clean_text
|
| 9 |
+
|
| 10 |
+
class RagStore:
    """Minimal retrieval store backed by a FAISS inner-product index.

    - build(docs): embed the documents and load them into a FAISS IndexFlatIP
    - search(query, k): embed the query and return the top-k hits by inner product
    - available(): whether an index has been built and holds at least one document
    """

    def __init__(self, index_dir: str = "data/index"):
        """Create an empty store and make sure ``index_dir`` exists.

        Args:
            index_dir: Directory where ``build`` writes ``docs.pkl``.
        """
        self.index_dir = index_dir
        self.index: Optional[faiss.Index] = None
        self.docs: List[dict] = []
        os.makedirs(index_dir, exist_ok=True)

    def available(self) -> bool:
        """Return True when an index is set and at least one document is held."""
        return self.index is not None and len(self.docs) > 0

    def build(self, docs: List[dict]) -> None:
        """Embed ``docs`` and (re)build the in-memory index.

        Each input dict is read through the keys ``"text"``, ``"meta"`` and
        ``"id"``; entries whose ``"text"`` is missing or falsy are skipped.
        The kept ids, metas and cleaned texts are also pickled to
        ``<index_dir>/docs.pkl``. The FAISS index itself is not written to
        disk by this method.

        Args:
            docs: Documents to index.
        """
        # Filter once so texts / metas / ids are guaranteed to stay aligned.
        kept = [d for d in docs if d.get("text")]
        texts = [clean_text(d.get("text", "")) for d in kept]
        metas = [d.get("meta") or {} for d in kept]
        ids = [d.get("id") for d in kept]

        # FAISS consumes C-contiguous float32 matrices; this is a no-op when
        # embed_texts already returns one. Expected shape: (n, dim).
        embs = np.ascontiguousarray(embed_texts(texts), dtype=np.float32)
        # 384 is the fallback width used when nothing was embedded
        # (presumably the embedding model's output size -- TODO confirm).
        dim = int(embs.shape[1]) if embs.ndim == 2 and embs.size > 0 else 384
        index = faiss.IndexFlatIP(dim)
        if embs.shape[0] > 0:
            index.add(embs)

        # Update index and docs together, before any disk I/O, so a failed
        # write cannot leave a new index paired with stale documents.
        self.index = index
        self.docs = [
            {"id": i, "text": t, "meta": m}
            for i, t, m in zip(ids, texts, metas)
        ]

        # Save the document payload (optional).
        with open(os.path.join(self.index_dir, "docs.pkl"), "wb") as f:
            pickle.dump({"ids": ids, "metas": metas, "texts": texts}, f)

    def search(self, query: str, k: int = 10) -> List[dict]:
        """Return up to ``k`` documents ranked by inner product with ``query``.

        Args:
            query: Free-text query; it is cleaned and embedded before search.
            k: Maximum number of hits. Non-positive values yield ``[]``.

        Returns:
            Shallow copies of the stored document dicts, each with an added
            ``"score"`` float, in the order FAISS returned them. Empty when
            the store is not available or the query produced no embedding.
        """
        if k <= 0 or not self.available():
            return []
        qv = np.ascontiguousarray(
            embed_texts([clean_text(query)]), dtype=np.float32
        )
        if qv.shape[0] == 0:
            return []

        n_docs = len(self.docs)
        scores, labels = self.index.search(qv, min(k, n_docs))

        out = []
        # Walk labels and scores in lockstep so each hit keeps its own score
        # even when an out-of-range label (e.g. -1 padding) is skipped.
        for label, score in zip(labels[0], scores[0]):
            pos = int(label)
            if 0 <= pos < n_docs:
                hit = dict(self.docs[pos])  # copy so callers cannot mutate the store
                hit["score"] = float(score)
                out.append(hit)
        return out
|