Corin1998 committed on
Commit
c52f5bf
·
verified ·
1 Parent(s): 1716018

Update modules/rag_retiever.py

Browse files
Files changed (1) hide show
  1. modules/rag_retiever.py +10 -53
modules/rag_retiever.py CHANGED
@@ -1,42 +1,32 @@
 
1
  import os
2
  import json
3
  import time
4
  from pathlib import Path
5
- from typing import List, Tuple, Dict, Any, Optional
6
 
7
  import numpy as np
8
 
9
- # 依存は遅延ロード(Space 起動を速く&環境差での ImportError を回避)
10
  def _lazy_imports():
11
  from sentence_transformers import SentenceTransformer
12
- import numpy as _np # noqa
13
  return SentenceTransformer
14
 
15
- # 内部ユーティリティ
16
  def _now() -> int:
17
  return int(time.time())
18
 
19
- # === ストレージ場所は utils の auto-pick に従う ===
20
  try:
21
  from modules.utils import ensure_dirs, data_dir
22
  except Exception:
23
- # 非常時フォールバック
24
  def ensure_dirs() -> None:
25
  Path("/tmp/agent_studio").mkdir(parents=True, exist_ok=True)
26
  def data_dir() -> Path:
27
  ensure_dirs()
28
  return Path("/tmp/agent_studio")
29
 
30
- # ========= チャンク読み込み =========
31
  def _chunks_path() -> Path:
32
  return data_dir() / "chunks.jsonl"
33
 
34
  def _load_chunks() -> List[Dict[str, Any]]:
35
- """
36
- rag_indexer が書き出した想定の簡易フォーマット:
37
- each line: {"text": "...", "source": "path_or_url", "meta": {...}}
38
- 無ければ空リスト。
39
- """
40
  p = _chunks_path()
41
  if not p.exists():
42
  return []
@@ -54,9 +44,7 @@ def _load_chunks() -> List[Dict[str, Any]]:
54
  continue
55
  return rows
56
 
57
- # ========= Embedding モデルとキャッシュ =========
58
  def _emb_model_name() -> str:
59
- # indexer と揃える前提。未指定なら軽量モデル
60
  return os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
61
 
62
  def _emb_cache_dir() -> Path:
@@ -67,12 +55,6 @@ def _emb_cache_paths() -> Tuple[Path, Path]:
67
  return d / "embeddings.npy", d / "meta.json"
68
 
69
  def _load_or_build_embeddings(chunks: List[Dict[str, Any]]) -> Tuple[np.ndarray, List[int]]:
70
- """
71
- - 既存キャッシュ(embeddings.npy, meta.json)が chunks と一致すればそれをロード
72
- - 不一致または欠損なら再計算して保存
73
- Returns:
74
- (emb_matrix: np.ndarray [N, D], indices: List[int] mapping to chunks)
75
- """
76
  ensure_dirs()
77
  _emb_cache_dir().mkdir(parents=True, exist_ok=True)
78
  npy_path, meta_path = _emb_cache_paths()
@@ -81,55 +63,36 @@ def _load_or_build_embeddings(chunks: List[Dict[str, Any]]) -> Tuple[np.ndarray,
81
  try:
82
  with open(meta_path, "r", encoding="utf-8") as f:
83
  meta = json.load(f)
84
- N_meta = int(meta.get("n", -1))
85
- if N_meta == len(chunks) and meta.get("model") == _emb_model_name():
86
  emb = np.load(npy_path)
87
- idx = list(range(len(chunks)))
88
- # 次元不整合チェック
89
  if emb.shape[0] == len(chunks):
90
- return emb, idx
91
  except Exception:
92
- pass # 再計算へ
93
 
94
- # 再計算
95
  SentenceTransformer = _lazy_imports()
96
  model = SentenceTransformer(_emb_model_name())
97
  texts = [str(c.get("text", "")) for c in chunks]
98
  if not texts:
99
  return np.zeros((0, 384), dtype="float32"), []
100
- emb = model.encode(texts, normalize_embeddings=True, convert_to_numpy=True) # shape: [N, D]
101
  np.save(npy_path, emb)
102
  with open(meta_path, "w", encoding="utf-8") as f:
103
  json.dump({"n": len(chunks), "model": _emb_model_name(), "ts": _now()}, f)
104
  return emb, list(range(len(chunks)))
105
 
106
- # ========= 類似度計算 =========
107
  def _cosine_topk(matrix: np.ndarray, query_vec: np.ndarray, top_k: int) -> List[int]:
108
- """
109
- 行列: [N, D](正規化済み想定), query: [D]
110
- 返り値: 上位インデックス
111
- """
112
  if matrix.size == 0:
113
  return []
114
- # dot がそのまま cos 類似度(normalize_embeddings=True を前提)
115
  sims = matrix @ query_vec
116
- # np.argpartition で高速 topk
117
  k = min(top_k, matrix.shape[0])
118
  part = np.argpartition(-sims, k - 1)[:k]
119
- # 類似度で並べ替え
120
  part_sorted = part[np.argsort(-sims[part])]
121
  return part_sorted.tolist()
122
 
123
- # ========= 公開 API =========
124
  def retrieve_contexts(query: str, top_k: int = 5) -> List[str]:
125
- """
126
- クエリに対して、保存済みのチャンク(chunks.jsonl)から上位コンテキストを返す。
127
- - 埋め込みは emb_cache にキャッシュ
128
- - モデル: EMBEDDING_MODEL(未指定時 all-MiniLM-L6-v2)
129
- """
130
  chunks = _load_chunks()
131
  if not chunks:
132
- # インデックス未構築
133
  return []
134
 
135
  SentenceTransformer = _lazy_imports()
@@ -143,14 +106,8 @@ def retrieve_contexts(query: str, top_k: int = 5) -> List[str]:
143
  top_idx = _cosine_topk(emb_matrix, q_vec, top_k)
144
  results: List[str] = []
145
  for i in top_idx:
146
- try:
147
- ch = chunks[idx_map[i]]
148
- txt = str(ch.get("text", "")).strip()
149
- src = ch.get("source")
150
- if src:
151
- results.append(f"{txt}\n[source] {src}")
152
- else:
153
- results.append(txt)
154
- except Exception:
155
- continue
156
  return results
 
1
+ # modules/rag_retiever.py  # NOTE: matches the actual (typo'd) repo path; consider renaming the file to rag_retriever.py
2
  import os
3
  import json
4
  import time
5
  from pathlib import Path
6
+ from typing import List, Tuple, Dict, Any
7
 
8
  import numpy as np
9
 
 
10
def _lazy_imports():
    """Defer the heavy sentence-transformers import until first use.

    Keeps module import cheap at startup and avoids raising ImportError
    in environments where the package is not installed unless retrieval
    is actually attempted.
    """
    from sentence_transformers import SentenceTransformer as _SentenceTransformer
    return _SentenceTransformer
13
 
 
14
  def _now() -> int:
15
  return int(time.time())
16
 
 
17
try:
    from modules.utils import ensure_dirs, data_dir
except Exception:
    # Emergency fallback when modules.utils cannot be imported:
    # pin storage to a fixed location under /tmp.
    def ensure_dirs() -> None:
        """Create the fallback data directory (idempotent)."""
        Path("/tmp/agent_studio").mkdir(parents=True, exist_ok=True)

    def data_dir() -> Path:
        """Return the fallback data directory, creating it first."""
        ensure_dirs()
        return Path("/tmp/agent_studio")
25
 
 
26
def _chunks_path() -> Path:
    """Path of the JSONL chunk store inside the data directory."""
    return data_dir().joinpath("chunks.jsonl")
28
 
29
  def _load_chunks() -> List[Dict[str, Any]]:
 
 
 
 
 
30
  p = _chunks_path()
31
  if not p.exists():
32
  return []
 
44
  continue
45
  return rows
46
 
 
47
  def _emb_model_name() -> str:
 
48
  return os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
49
 
50
  def _emb_cache_dir() -> Path:
 
55
  return d / "embeddings.npy", d / "meta.json"
56
 
57
  def _load_or_build_embeddings(chunks: List[Dict[str, Any]]) -> Tuple[np.ndarray, List[int]]:
 
 
 
 
 
 
58
  ensure_dirs()
59
  _emb_cache_dir().mkdir(parents=True, exist_ok=True)
60
  npy_path, meta_path = _emb_cache_paths()
 
63
  try:
64
  with open(meta_path, "r", encoding="utf-8") as f:
65
  meta = json.load(f)
66
+ if int(meta.get("n", -1)) == len(chunks) and meta.get("model") == _emb_model_name():
 
67
  emb = np.load(npy_path)
 
 
68
  if emb.shape[0] == len(chunks):
69
+ return emb, list(range(len(chunks)))
70
  except Exception:
71
+ pass
72
 
 
73
  SentenceTransformer = _lazy_imports()
74
  model = SentenceTransformer(_emb_model_name())
75
  texts = [str(c.get("text", "")) for c in chunks]
76
  if not texts:
77
  return np.zeros((0, 384), dtype="float32"), []
78
+ emb = model.encode(texts, normalize_embeddings=True, convert_to_numpy=True)
79
  np.save(npy_path, emb)
80
  with open(meta_path, "w", encoding="utf-8") as f:
81
  json.dump({"n": len(chunks), "model": _emb_model_name(), "ts": _now()}, f)
82
  return emb, list(range(len(chunks)))
83
 
 
84
  def _cosine_topk(matrix: np.ndarray, query_vec: np.ndarray, top_k: int) -> List[int]:
 
 
 
 
85
  if matrix.size == 0:
86
  return []
 
87
  sims = matrix @ query_vec
 
88
  k = min(top_k, matrix.shape[0])
89
  part = np.argpartition(-sims, k - 1)[:k]
 
90
  part_sorted = part[np.argsort(-sims[part])]
91
  return part_sorted.tolist()
92
 
 
93
  def retrieve_contexts(query: str, top_k: int = 5) -> List[str]:
 
 
 
 
 
94
  chunks = _load_chunks()
95
  if not chunks:
 
96
  return []
97
 
98
  SentenceTransformer = _lazy_imports()
 
106
  top_idx = _cosine_topk(emb_matrix, q_vec, top_k)
107
  results: List[str] = []
108
  for i in top_idx:
109
+ ch = chunks[idx_map[i]]
110
+ txt = str(ch.get("text", "")).strip()
111
+ src = ch.get("source")
112
+ results.append(f"{txt}\n[source] {src}" if src else txt)
 
 
 
 
 
 
113
  return results