Corin1998 committed on
Commit
5373d47
·
verified ·
1 Parent(s): 1fd8f5d

Create retriever.py

Browse files
Files changed (1) hide show
  1. rag/retriever.py +54 -0
rag/retriever.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import List, Optional
3
+ import os, pickle
4
+ import numpy as np
5
+ import faiss
6
+
7
+ from .embeddings import embed_texts
8
+ from .util_text import clean_text
9
+
10
class RagStore:
    """Minimal in-memory retrieval store backed by a FAISS inner-product index.

    - ``build(docs)``: embed the documents and load them into a
      ``faiss.IndexFlatIP``.
    - ``search(query, k)``: embed the query and return the top-``k``
      documents ranked by inner product.
    - ``available()``: report whether an index with documents is loaded.
    """

    def __init__(self, index_dir: str = "data/index"):
        """Create an empty store and make sure ``index_dir`` exists.

        Args:
            index_dir: Directory where ``build`` writes ``docs.pkl``.
        """
        self.index_dir = index_dir
        self.index: Optional[faiss.Index] = None
        self.docs: List[dict] = []
        os.makedirs(index_dir, exist_ok=True)

    def available(self) -> bool:
        """Return True when an index has been built and holds documents."""
        return self.index is not None and len(self.docs) > 0

    def build(self, docs: List[dict]) -> None:
        """Embed ``docs`` and replace the current index with a fresh one.

        Entries whose ``"text"`` is missing or falsy are skipped. Each kept
        entry is stored as ``{"id", "text", "meta"}`` with the cleaned text
        and ``meta`` defaulting to ``{}``.

        Args:
            docs: Dicts read via the keys ``"id"``, ``"text"`` and ``"meta"``.
        """
        # Single pass so ids, texts and metas cannot drift out of alignment.
        records = [
            {
                "id": d.get("id"),
                "text": clean_text(d.get("text", "")),
                "meta": d.get("meta") or {},
            }
            for d in docs
            if d.get("text")
        ]
        texts = [r["text"] for r in records]

        # FAISS works on float32, C-contiguous matrices; coerce defensively.
        embs = np.ascontiguousarray(embed_texts(texts), dtype=np.float32)  # (n, dim)
        # Fallback dimension for an empty corpus; presumably matches the
        # embedding model's output size -- TODO confirm against embed_texts.
        dim = int(embs.shape[1]) if embs.ndim == 2 and embs.size > 0 else 384
        index = faiss.IndexFlatIP(dim)
        if embs.shape[0] > 0:
            index.add(embs)

        # Publish index and docs together so a failure while persisting below
        # cannot leave a new index paired with stale documents.
        self.index = index
        self.docs = records

        # Persist the document payload (optional convenience copy).
        with open(os.path.join(self.index_dir, "docs.pkl"), "wb") as f:
            pickle.dump(
                {
                    "ids": [r["id"] for r in records],
                    "metas": [r["meta"] for r in records],
                    "texts": texts,
                },
                f,
            )

    def search(self, query: str, k: int = 10) -> List[dict]:
        """Return up to ``k`` documents most similar to ``query``.

        Args:
            query: Free-text query; cleaned and embedded before searching.
            k: Maximum number of hits. Non-positive values yield ``[]``.

        Returns:
            Copies of the stored document dicts, each with an added
            ``"score"`` (the inner product reported by the index), in the
            order returned by the index. Empty when the store is not
            available or the query produces no embedding.
        """
        if k <= 0 or not self.available():
            return []
        qv = embed_texts([clean_text(query)])
        if qv.shape[0] == 0:
            return []
        qv = np.ascontiguousarray(qv, dtype=np.float32)

        n_docs = len(self.docs)
        scores, labels = self.index.search(qv, min(k, n_docs))

        out = []
        # Walk scores and labels in lockstep so each score stays attached to
        # its own label even when some labels are discarded.
        for score, label in zip(scores[0], labels[0]):
            j = int(label)
            if not 0 <= j < n_docs:
                # Out-of-range label (FAISS reports missing hits as -1).
                continue
            hit = dict(self.docs[j])  # copy so callers cannot mutate the store
            hit["score"] = float(score)
            out.append(hit)
        return out