Spaces:
Sleeping
Sleeping
Update modules/rag_retiever.py
Browse files- modules/rag_retiever.py +10 -53
modules/rag_retiever.py
CHANGED
|
@@ -1,42 +1,32 @@
|
|
|
|
|
| 1 |
import os
|
| 2 |
import json
|
| 3 |
import time
|
| 4 |
from pathlib import Path
|
| 5 |
-
from typing import List, Tuple, Dict, Any
|
| 6 |
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
-
# 依存は遅延ロード(Space 起動を速く&環境差での ImportError を回避)
|
| 10 |
def _lazy_imports():
|
| 11 |
from sentence_transformers import SentenceTransformer
|
| 12 |
-
import numpy as _np # noqa
|
| 13 |
return SentenceTransformer
|
| 14 |
|
| 15 |
-
# 内部ユーティリティ
|
| 16 |
def _now() -> int:
|
| 17 |
return int(time.time())
|
| 18 |
|
| 19 |
-
# === ストレージ場所は utils の auto-pick に従う ===
|
| 20 |
try:
|
| 21 |
from modules.utils import ensure_dirs, data_dir
|
| 22 |
except Exception:
|
| 23 |
-
# 非常時フォールバック
|
| 24 |
def ensure_dirs() -> None:
|
| 25 |
Path("/tmp/agent_studio").mkdir(parents=True, exist_ok=True)
|
| 26 |
def data_dir() -> Path:
|
| 27 |
ensure_dirs()
|
| 28 |
return Path("/tmp/agent_studio")
|
| 29 |
|
| 30 |
-
# ========= チャンク読み込み =========
|
| 31 |
def _chunks_path() -> Path:
|
| 32 |
return data_dir() / "chunks.jsonl"
|
| 33 |
|
| 34 |
def _load_chunks() -> List[Dict[str, Any]]:
|
| 35 |
-
"""
|
| 36 |
-
rag_indexer が書き出した想定の簡易フォーマット:
|
| 37 |
-
each line: {"text": "...", "source": "path_or_url", "meta": {...}}
|
| 38 |
-
無ければ空リスト。
|
| 39 |
-
"""
|
| 40 |
p = _chunks_path()
|
| 41 |
if not p.exists():
|
| 42 |
return []
|
|
@@ -54,9 +44,7 @@ def _load_chunks() -> List[Dict[str, Any]]:
|
|
| 54 |
continue
|
| 55 |
return rows
|
| 56 |
|
| 57 |
-
# ========= Embedding モデルとキャッシュ =========
|
| 58 |
def _emb_model_name() -> str:
|
| 59 |
-
# indexer と揃える前提。未指定なら軽量モデル
|
| 60 |
return os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
|
| 61 |
|
| 62 |
def _emb_cache_dir() -> Path:
|
|
@@ -67,12 +55,6 @@ def _emb_cache_paths() -> Tuple[Path, Path]:
|
|
| 67 |
return d / "embeddings.npy", d / "meta.json"
|
| 68 |
|
| 69 |
def _load_or_build_embeddings(chunks: List[Dict[str, Any]]) -> Tuple[np.ndarray, List[int]]:
|
| 70 |
-
"""
|
| 71 |
-
- 既存キャッシュ(embeddings.npy, meta.json)が chunks と一致すればそれをロード
|
| 72 |
-
- 不一致または欠損なら再計算して保存
|
| 73 |
-
Returns:
|
| 74 |
-
(emb_matrix: np.ndarray [N, D], indices: List[int] mapping to chunks)
|
| 75 |
-
"""
|
| 76 |
ensure_dirs()
|
| 77 |
_emb_cache_dir().mkdir(parents=True, exist_ok=True)
|
| 78 |
npy_path, meta_path = _emb_cache_paths()
|
|
@@ -81,55 +63,36 @@ def _load_or_build_embeddings(chunks: List[Dict[str, Any]]) -> Tuple[np.ndarray,
|
|
| 81 |
try:
|
| 82 |
with open(meta_path, "r", encoding="utf-8") as f:
|
| 83 |
meta = json.load(f)
|
| 84 |
-
|
| 85 |
-
if N_meta == len(chunks) and meta.get("model") == _emb_model_name():
|
| 86 |
emb = np.load(npy_path)
|
| 87 |
-
idx = list(range(len(chunks)))
|
| 88 |
-
# 次元不整合チェック
|
| 89 |
if emb.shape[0] == len(chunks):
|
| 90 |
-
return emb,
|
| 91 |
except Exception:
|
| 92 |
-
pass
|
| 93 |
|
| 94 |
-
# 再計算
|
| 95 |
SentenceTransformer = _lazy_imports()
|
| 96 |
model = SentenceTransformer(_emb_model_name())
|
| 97 |
texts = [str(c.get("text", "")) for c in chunks]
|
| 98 |
if not texts:
|
| 99 |
return np.zeros((0, 384), dtype="float32"), []
|
| 100 |
-
emb = model.encode(texts, normalize_embeddings=True, convert_to_numpy=True)
|
| 101 |
np.save(npy_path, emb)
|
| 102 |
with open(meta_path, "w", encoding="utf-8") as f:
|
| 103 |
json.dump({"n": len(chunks), "model": _emb_model_name(), "ts": _now()}, f)
|
| 104 |
return emb, list(range(len(chunks)))
|
| 105 |
|
| 106 |
-
# ========= 類似度計算 =========
|
| 107 |
def _cosine_topk(matrix: np.ndarray, query_vec: np.ndarray, top_k: int) -> List[int]:
|
| 108 |
-
"""
|
| 109 |
-
行列: [N, D](正規化済み想定), query: [D]
|
| 110 |
-
返り値: 上位インデックス
|
| 111 |
-
"""
|
| 112 |
if matrix.size == 0:
|
| 113 |
return []
|
| 114 |
-
# dot がそのまま cos 類似度(normalize_embeddings=True を前提)
|
| 115 |
sims = matrix @ query_vec
|
| 116 |
-
# np.argpartition で高速 topk
|
| 117 |
k = min(top_k, matrix.shape[0])
|
| 118 |
part = np.argpartition(-sims, k - 1)[:k]
|
| 119 |
-
# 類似度で並べ替え
|
| 120 |
part_sorted = part[np.argsort(-sims[part])]
|
| 121 |
return part_sorted.tolist()
|
| 122 |
|
| 123 |
-
# ========= 公開 API =========
|
| 124 |
def retrieve_contexts(query: str, top_k: int = 5) -> List[str]:
|
| 125 |
-
"""
|
| 126 |
-
クエリに対して、保存済みのチャンク(chunks.jsonl)から上位コンテキストを返す。
|
| 127 |
-
- 埋め込みは emb_cache にキャッシュ
|
| 128 |
-
- モデル: EMBEDDING_MODEL(未指定時 all-MiniLM-L6-v2)
|
| 129 |
-
"""
|
| 130 |
chunks = _load_chunks()
|
| 131 |
if not chunks:
|
| 132 |
-
# インデックス未構築
|
| 133 |
return []
|
| 134 |
|
| 135 |
SentenceTransformer = _lazy_imports()
|
|
@@ -143,14 +106,8 @@ def retrieve_contexts(query: str, top_k: int = 5) -> List[str]:
|
|
| 143 |
top_idx = _cosine_topk(emb_matrix, q_vec, top_k)
|
| 144 |
results: List[str] = []
|
| 145 |
for i in top_idx:
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
if src:
|
| 151 |
-
results.append(f"{txt}\n[source] {src}")
|
| 152 |
-
else:
|
| 153 |
-
results.append(txt)
|
| 154 |
-
except Exception:
|
| 155 |
-
continue
|
| 156 |
return results
|
|
|
|
| 1 |
+
# modules/rag_retriever.py
|
| 2 |
import os
|
| 3 |
import json
|
| 4 |
import time
|
| 5 |
from pathlib import Path
|
| 6 |
+
from typing import List, Tuple, Dict, Any
|
| 7 |
|
| 8 |
import numpy as np
|
| 9 |
|
|
|
|
| 10 |
def _lazy_imports():
    """Defer the sentence-transformers import until first use.

    Keeps Space startup fast and avoids a module-load-time ImportError
    in environments where the package is missing.
    """
    from sentence_transformers import SentenceTransformer as _ST
    return _ST
|
| 13 |
|
|
|
|
| 14 |
def _now() -> int:
|
| 15 |
return int(time.time())
|
| 16 |
|
|
|
|
| 17 |
# Storage location follows the auto-pick logic in modules.utils; when that
# module is unavailable, fall back to a fixed /tmp directory.
try:
    from modules.utils import ensure_dirs, data_dir
except Exception:
    _FALLBACK_DIR = Path("/tmp/agent_studio")

    def ensure_dirs() -> None:
        """Create the fallback storage directory if it does not exist."""
        _FALLBACK_DIR.mkdir(parents=True, exist_ok=True)

    def data_dir() -> Path:
        """Return the fallback storage directory, creating it first."""
        ensure_dirs()
        return _FALLBACK_DIR
|
| 25 |
|
|
|
|
| 26 |
def _chunks_path() -> Path:
    """Location of the chunk store written by the indexer."""
    base = data_dir()
    return base / "chunks.jsonl"
|
| 28 |
|
| 29 |
def _load_chunks() -> List[Dict[str, Any]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
p = _chunks_path()
|
| 31 |
if not p.exists():
|
| 32 |
return []
|
|
|
|
| 44 |
continue
|
| 45 |
return rows
|
| 46 |
|
|
|
|
| 47 |
def _emb_model_name() -> str:
|
|
|
|
| 48 |
return os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
|
| 49 |
|
| 50 |
def _emb_cache_dir() -> Path:
|
|
|
|
| 55 |
return d / "embeddings.npy", d / "meta.json"
|
| 56 |
|
| 57 |
def _load_or_build_embeddings(chunks: List[Dict[str, Any]]) -> Tuple[np.ndarray, List[int]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
ensure_dirs()
|
| 59 |
_emb_cache_dir().mkdir(parents=True, exist_ok=True)
|
| 60 |
npy_path, meta_path = _emb_cache_paths()
|
|
|
|
| 63 |
try:
|
| 64 |
with open(meta_path, "r", encoding="utf-8") as f:
|
| 65 |
meta = json.load(f)
|
| 66 |
+
if int(meta.get("n", -1)) == len(chunks) and meta.get("model") == _emb_model_name():
|
|
|
|
| 67 |
emb = np.load(npy_path)
|
|
|
|
|
|
|
| 68 |
if emb.shape[0] == len(chunks):
|
| 69 |
+
return emb, list(range(len(chunks)))
|
| 70 |
except Exception:
|
| 71 |
+
pass
|
| 72 |
|
|
|
|
| 73 |
SentenceTransformer = _lazy_imports()
|
| 74 |
model = SentenceTransformer(_emb_model_name())
|
| 75 |
texts = [str(c.get("text", "")) for c in chunks]
|
| 76 |
if not texts:
|
| 77 |
return np.zeros((0, 384), dtype="float32"), []
|
| 78 |
+
emb = model.encode(texts, normalize_embeddings=True, convert_to_numpy=True)
|
| 79 |
np.save(npy_path, emb)
|
| 80 |
with open(meta_path, "w", encoding="utf-8") as f:
|
| 81 |
json.dump({"n": len(chunks), "model": _emb_model_name(), "ts": _now()}, f)
|
| 82 |
return emb, list(range(len(chunks)))
|
| 83 |
|
|
|
|
| 84 |
def _cosine_topk(matrix: np.ndarray, query_vec: np.ndarray, top_k: int) -> List[int]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
if matrix.size == 0:
|
| 86 |
return []
|
|
|
|
| 87 |
sims = matrix @ query_vec
|
|
|
|
| 88 |
k = min(top_k, matrix.shape[0])
|
| 89 |
part = np.argpartition(-sims, k - 1)[:k]
|
|
|
|
| 90 |
part_sorted = part[np.argsort(-sims[part])]
|
| 91 |
return part_sorted.tolist()
|
| 92 |
|
|
|
|
| 93 |
def retrieve_contexts(query: str, top_k: int = 5) -> List[str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
chunks = _load_chunks()
|
| 95 |
if not chunks:
|
|
|
|
| 96 |
return []
|
| 97 |
|
| 98 |
SentenceTransformer = _lazy_imports()
|
|
|
|
| 106 |
top_idx = _cosine_topk(emb_matrix, q_vec, top_k)
|
| 107 |
results: List[str] = []
|
| 108 |
for i in top_idx:
|
| 109 |
+
ch = chunks[idx_map[i]]
|
| 110 |
+
txt = str(ch.get("text", "")).strip()
|
| 111 |
+
src = ch.get("source")
|
| 112 |
+
results.append(f"{txt}\n[source] {src}" if src else txt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
return results
|