Corin1998 committed on
Commit
97c7e20
·
verified ·
1 Parent(s): b785837

Update irpr/deps.py

Browse files
Files changed (1) hide show
  1. irpr/deps.py +134 -126
irpr/deps.py CHANGED
@@ -1,157 +1,165 @@
1
- # irpr/deps.py --- Chromadb版faiss不使用・LLMなしでも動く
2
  from __future__ import annotations
3
- import os
4
- from typing import List, Dict, Optional
5
  import numpy as np
6
  from irpr.config import settings
7
 
8
# ===== Consolidate Hugging Face caches into a writable location =====
BASE = settings.DATA_DIR or "./var"
DEFAULT_CACHE = os.path.join(BASE, ".hf-cache")
DEFAULT_HOME = os.path.join(BASE, ".hf-home")

# Only install safe defaults here when nothing is configured yet
# (setdefault leaves any pre-existing environment values untouched).
os.environ.setdefault("HF_HOME", DEFAULT_HOME)
os.environ.setdefault("TRANSFORMERS_CACHE", DEFAULT_CACHE)
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", DEFAULT_CACHE)
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", DEFAULT_CACHE)
18
 
19
# Best-effort creation of every cache/data directory referenced above.
for d in [os.environ["HF_HOME"], os.environ["TRANSFORMERS_CACHE"],
          os.environ["SENTENCE_TRANSFORMERS_HOME"], os.environ["HUGGINGFACE_HUB_CACHE"],
          BASE, settings.CHROMA_PATH]:
    try:
        os.makedirs(d, exist_ok=True)
    except Exception:
        pass  # failing to create a folder is not fatal
 
 
 
 
26
 
27
# Lazily-initialized singletons, populated on first use by the _get_* helpers
_EMB = None                     # SentenceTransformer embedder
_EMB_DIM: Optional[int] = None  # embedding dimension, set alongside _EMB
_CHROMA_COLLECTION = None       # persistent Chroma collection
_GEN = None                     # transformers text-generation pipeline
_TOK = None                     # tokenizer paired with _GEN
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
def _get_embedder():
    """Lazily load the SentenceTransformer embedder and cache it module-wide."""
    global _EMB, _EMB_DIM
    if _EMB is not None:
        return _EMB
    from sentence_transformers import SentenceTransformer
    cache = os.environ.get("HF_HOME", DEFAULT_CACHE)
    _EMB = SentenceTransformer(settings.EMB_MODEL, cache_folder=cache)
    _EMB_DIM = _EMB.get_sentence_embedding_dimension()
    return _EMB
45
 
 
46
def embed_texts(texts: List[str]) -> np.ndarray:
    """Encode *texts* into L2-normalized float32 vectors via SentenceTransformer."""
    model = _get_embedder()
    vectors = model.encode(
        texts,
        batch_size=16,
        normalize_embeddings=True,
        convert_to_numpy=True,
        show_progress_bar=False
    )
    return vectors.astype(np.float32, copy=False)
56
-
57
def _get_chroma():
    """Return the persistent Chroma collection, creating client/collection on first use."""
    global _CHROMA_COLLECTION
    if _CHROMA_COLLECTION is not None:
        return _CHROMA_COLLECTION
    import chromadb
    from chromadb.config import Settings as CS
    client = chromadb.PersistentClient(
        path=settings.CHROMA_PATH,
        settings=CS(allow_reset=True)
    )
    _CHROMA_COLLECTION = client.get_or_create_collection(name="irpr_docs")
    return _CHROMA_COLLECTION
69
 
70
def add_to_index(records: List[Dict]):
    """Add records ({text, title, source_url, doc_id, chunk_id}) to the Chroma collection."""
    if not records:
        return
    collection = _get_chroma()
    documents = [rec["text"] for rec in records]
    vectors = embed_texts(documents)

    ids: List[str] = []
    metadatas: List[Dict] = []
    for rec in records:
        doc_id = rec.get("doc_id") or "doc"
        chunk_id = rec.get("chunk_id") or ""
        # stable id: "doc:chunk" when a chunk id exists, bare doc id otherwise
        ids.append(f"{doc_id}:{chunk_id}" if chunk_id else doc_id)
        metadatas.append({
            "source_url": rec.get("source_url"),
            "title": rec.get("title"),
            "doc_id": doc_id,
            "chunk_id": chunk_id,
        })
    collection.add(ids=ids, documents=documents, embeddings=vectors, metadatas=metadatas)
 
 
 
91
 
 
92
def search(query: str, top_k=8) -> List[Dict]:
    """
    Query the Chroma collection and return up to top_k hits.

    Each hit is a dict with text, source metadata and a similarity-style score.
    """
    col = _get_chroma()
    q_emb = embed_texts([query])
    # FIX: "ids" is not a valid `include` value for chroma query() — ids are
    # always returned — and passing it raises a validation error.
    res = col.query(
        query_embeddings=q_emb.tolist(),  # plain lists are accepted across chromadb versions
        n_results=top_k,
        include=["documents", "metadatas", "distances"]
    )
    docs = res.get("documents", [[]])[0]
    metas = res.get("metadatas", [[]])[0]
    dists = res.get("distances", [[]])[0]
    out: List[Dict] = []
    for text, meta, dist in zip(docs, metas, dists):
        # Embeddings are normalized, so distance d relates to cosine as
        # 1 - d/2 — a similarity-style score in [-1, 1].
        score = 1.0 - float(dist) / 2.0 if dist is not None else None
        out.append({
            "text": text,
            "source_url": (meta or {}).get("source_url"),
            "title": (meta or {}).get("title"),
            "doc_id": (meta or {}).get("doc_id"),
            "chunk_id": (meta or {}).get("chunk_id"),
            "score": score,
        })
    return out
115
 
116
- # ==== 生成(任意) ====
117
def _get_gen_pipeline():
    """
    Lazily build and cache the text-generation pipeline.

    Raises:
        RuntimeError: when settings.GEN_MODEL is empty (LLM disabled).

    dtype/device_map are chosen conservatively so CPU-only environments work.
    """
    if not settings.GEN_MODEL:
        raise RuntimeError("GEN_MODEL is empty (LLM disabled).")
    global _GEN, _TOK
    if _GEN is None:
        from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
        # torch is optional; fall back to CPU defaults when it is absent.
        # FIX: the original also computed a `torch_dtype` here that was never
        # used (it was recomputed in the CUDA branch below).
        try:
            import torch  # noqa
        except Exception:
            torch = None

        name = settings.GEN_MODEL
        cache_dir = os.environ.get("HF_HOME", DEFAULT_CACHE)
        _TOK = AutoTokenizer.from_pretrained(name, cache_dir=cache_dir)
        model_kwargs = dict(cache_dir=cache_dir, low_cpu_mem_usage=True)
        if torch is not None and hasattr(torch, "cuda") and torch.cuda.is_available():
            # prefer bfloat16, fall back to float16 on older torch builds
            model_kwargs["torch_dtype"] = getattr(torch, "bfloat16", None) or getattr(torch, "float16", None)
            model_kwargs["device_map"] = "auto"

        # FIX: local variable renamed from `_MODEL`, which misleadingly
        # looked like a module-level global.
        model = AutoModelForCausalLM.from_pretrained(name, **model_kwargs)
        _GEN = pipeline("text-generation", model=model, tokenizer=_TOK)
    return _GEN, _TOK
147
-
148
def generate_chat(messages: List[Dict], max_new_tokens=600, temperature=0.2) -> str:
    """Render the chat template, run generation, and return only the continuation."""
    gen, tok = _get_gen_pipeline()
    prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    sampling = temperature > 0.0
    completions = gen(
        prompt,
        do_sample=sampling,
        temperature=temperature,
        max_new_tokens=max_new_tokens
    )
    full_text = completions[0]["generated_text"]
    # strip the echoed prompt; keep only freshly generated text
    return full_text[len(prompt):].strip()
 
1
+ # irpr/deps.py --- OpenAI埋め込み + 自前ベクタストアnumpy/LLM生成
2
  from __future__ import annotations
3
+ import os, json, uuid
4
+ from typing import List, Dict, Optional, Tuple
5
  import numpy as np
6
  from irpr.config import settings
7
 
8
+ # ==== 書き込み可能ディレクトリの決定 ====
9
def _pick_writable_dir() -> str:
    """Return the first candidate directory we can actually write to ('.' as last resort)."""
    for base in (settings.DATA_DIR, "/data", "./var", "/tmp/irpr", "."):
        if not base:
            continue
        try:
            os.makedirs(base, exist_ok=True)
            probe = os.path.join(base, ".write_test")
            with open(probe, "w") as fh:
                fh.write("ok")
            os.remove(probe)
        except Exception:
            continue  # not writable; try the next candidate
        return base
    return "."
22
+
23
# All index artifacts live under INDEX_DIR (settings override, else a writable default).
BASE_DIR = _pick_writable_dir()
INDEX_DIR = settings.INDEX_DIR or os.path.join(BASE_DIR, "simple_index")
os.makedirs(INDEX_DIR, exist_ok=True)

VECS_PATH = os.path.join(INDEX_DIR, "vectors.npy")  # np.float32 [N, D] (normalized rows)
META_PATH = os.path.join(INDEX_DIR, "meta.jsonl")   # one metadata record per line
TEXT_PATH = os.path.join(INDEX_DIR, "texts.jsonl")  # one text per line
 
 
30
 
31
+ # ==== OpenAI クライアント ====
32
+ def _openai_client():
 
33
  try:
34
+ from openai import OpenAI
35
+ except Exception as e:
36
+ raise RuntimeError("`openai` パッケージが見つかりません。requirements.txt openai を追加してくださ。") from e
37
+ key = os.environ.get("OPENAI_API_KEY", "").strip()
38
+ if not key:
39
+ raise RuntimeError("OPENAI_API_KEY が未設定です。環境変数に設定してください。")
40
+ return OpenAI(api_key=key)
41
 
42
+ # ==== 収納・ロード ====
43
def _load_index() -> Tuple[np.ndarray, List[dict], List[str]]:
    """
    Load (vectors, metas, texts) from disk.

    Returns an empty index when files are absent, or when the three parts
    disagree in length (treated as a corrupt index and discarded).
    """
    if os.path.exists(VECS_PATH):
        vecs = np.load(VECS_PATH).astype(np.float32, copy=False)
    else:
        vecs = np.zeros((0, 0), dtype=np.float32)

    metas: List[dict] = []
    if os.path.exists(META_PATH):
        with open(META_PATH, "r", encoding="utf-8") as fh:
            metas = [json.loads(s) for s in (ln.strip() for ln in fh) if s]

    texts: List[str] = []
    if os.path.exists(TEXT_PATH):
        with open(TEXT_PATH, "r", encoding="utf-8") as fh:
            texts = [ln.rstrip("\n") for ln in fh]

    # consistency check: all three parts must describe the same N rows
    if vecs.size == 0:
        return np.zeros((0, 0), dtype=np.float32), [], []
    n_rows = vecs.shape[0]
    if len(metas) != n_rows or len(texts) != n_rows:
        return np.zeros((0, 0), dtype=np.float32), [], []
    return vecs, metas, texts
68
 
69
def _save_index(vecs: np.ndarray, metas: List[dict], texts: List[str]) -> None:
    """Persist the index: vectors as .npy, metas and texts as one-record-per-line files."""
    os.makedirs(INDEX_DIR, exist_ok=True)
    np.save(VECS_PATH, vecs.astype(np.float32, copy=False))
    with open(META_PATH, "w", encoding="utf-8") as mf:
        mf.writelines(json.dumps(m, ensure_ascii=False) + "\n" for m in metas)
    with open(TEXT_PATH, "w", encoding="utf-8") as tf:
        # normalize to one line per text: escape embedded newlines
        tf.writelines((t or "").replace("\n", "\\n") + "\n" for t in texts)
 
 
78
 
79
+ # ==== Embedding ====
80
def embed_texts(texts: List[str]) -> np.ndarray:
    """
    Embed *texts* with the OpenAI embeddings API.

    Returns a float32 array of shape (N, D), L2-normalized row-wise so that
    dot products are cosine similarities.

    FIX: empty input now returns shape (0, 0) without touching the API —
    previously np.array([]) was 1-D and np.linalg.norm(..., axis=1) raised.
    """
    if not texts:
        return np.zeros((0, 0), dtype=np.float32)
    client = _openai_client()
    model = settings.OPENAI_EMBED_MODEL
    batch_size = 128  # bounded request size per API call
    vectors: List[List[float]] = []
    for i in range(0, len(texts), batch_size):
        resp = client.embeddings.create(model=model, input=texts[i:i + batch_size])
        vectors.extend(d.embedding for d in resp.data)
    arr = np.array(vectors, dtype=np.float32)
    # normalize for cosine similarity; epsilon guards against zero vectors
    norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
    return arr / norms
 
 
 
 
 
 
 
 
 
94
 
95
+ # ==== 追加 ====
96
def add_to_index(records: List[Dict]) -> int:
    """
    Append records ({text, title, source_url, doc_id, chunk_id}) to the
    persistent index and return how many were added.
    """
    if not records:
        return 0

    new_texts = [rec["text"] for rec in records]
    new_vecs = embed_texts(new_texts)

    vecs, metas, texts = _load_index()
    compatible = vecs.size != 0 and vecs.shape[1] == new_vecs.shape[1]
    if compatible:
        vecs = np.vstack([vecs, new_vecs])
    else:
        # empty index, or embedding dimension changed (model swap) -> rebuild
        vecs, metas, texts = new_vecs, [], []

    for rec in records:
        doc_id = rec.get("doc_id") or str(uuid.uuid4())
        chunk_id = rec.get("chunk_id") or ""
        metas.append({
            "source_url": rec.get("source_url"),
            "title": rec.get("title"),
            "doc_id": doc_id,
            "chunk_id": chunk_id,
            "id": f"{doc_id}:{chunk_id}" if chunk_id else doc_id
        })
        texts.append(rec.get("text", ""))

    _save_index(vecs, metas, texts)
    return len(records)
133
 
134
+ # ==== 検索 ====
135
def search(query: str, top_k=8) -> List[Dict]:
    """Return the top_k chunks most similar to *query* (cosine similarity)."""
    vecs, metas, texts = _load_index()
    if vecs.size == 0:
        return []
    q_vec = embed_texts([query])[0]   # (D,), already normalized
    sims = vecs @ q_vec               # cosine, since index rows are normalized
    order = np.argsort(-sims)[:max(1, top_k)]

    results: List[Dict] = []
    for i in order.tolist():
        meta = metas[i]
        results.append({
            # undo the newline escaping applied when the text was stored
            "text": (texts[i] or "").replace("\\n", "\n"),
            "source_url": meta.get("source_url"),
            "title": meta.get("title"),
            "doc_id": meta.get("doc_id"),
            "chunk_id": meta.get("chunk_id"),
            "score": float(sims[i]),
        })
    return results
154
 
155
+ # ==== 生成 ====
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
def generate_chat(messages: List[Dict], max_new_tokens=600, temperature=0.2) -> str:
    """Run a single chat completion and return the stripped assistant text."""
    client = _openai_client()
    resp = client.chat.completions.create(
        model=settings.OPENAI_CHAT_MODEL,
        messages=messages,
        temperature=float(temperature),
        max_tokens=int(max_new_tokens),
    )
    content = resp.choices[0].message.content
    return (content or "").strip()