Corin1998 committed
Commit b8d8524 · verified · 1 Parent(s): ef1019a

Update irpr/deps.py

Files changed (1)
  1. irpr/deps.py +89 -131
irpr/deps.py CHANGED
@@ -1,95 +1,103 @@
  import os
- import pickle
  import numpy as np
- 
- # ========= Cache/data directories (at import time, only create them and set env vars) =========
- CACHE_DIR = os.environ.get("HF_HOME") or "/tmp/hf-cache"
- for k in ["HF_HOME", "TRANSFORMERS_CACHE", "SENTENCE_TRANSFORMERS_HOME", "TORCH_HOME"]:
-     os.environ.setdefault(k, CACHE_DIR)
- os.makedirs(CACHE_DIR, exist_ok=True)
- 
  from irpr.config import settings
  
- DATA_DIR = settings.DATA_DIR or "data"
- os.makedirs(DATA_DIR, exist_ok=True)
- 
- VEC_PATH = os.path.join(DATA_DIR, "vectors.npy")
- STORE_PATH = os.path.join(DATA_DIR, "store.pkl")
- 
- # ========= Global state (do not create models at import time!) =========
- from typing import Optional, List, Dict
- 
- _EMB = None  # SentenceTransformer
  _EMB_DIM: Optional[int] = None
- _TOK = None  # AutoTokenizer
- _MODEL = None  # AutoModelForCausalLM
- _GEN = None  # pipeline("text-generation")
  
- _VECTORS: Optional[np.ndarray] = None  # shape [N, D] float32
- _STORE: Optional[List[Dict]] = None
- 
- # ========= Persistent index I/O =========
- def _load_index():
-     """Lazily load vectors/metadata (does not touch the models)."""
-     global _VECTORS, _STORE
-     if _VECTORS is None:
-         if os.path.exists(VEC_PATH):
-             try:
-                 arr = np.load(VEC_PATH)
-                 _VECTORS = arr.astype(np.float32, copy=False)
-             except Exception:
-                 _VECTORS = np.empty((0, 0), dtype=np.float32)
-         else:
-             _VECTORS = np.empty((0, 0), dtype=np.float32)
-     if _STORE is None:
-         if os.path.exists(STORE_PATH):
-             try:
-                 with open(STORE_PATH, "rb") as f:
-                     s = pickle.load(f)
-                 _STORE = s if isinstance(s, list) else []
-             except Exception:
-                 _STORE = []
-         else:
-             _STORE = []
- 
- def _save_index():
-     global _VECTORS, _STORE
-     if _VECTORS is None or _STORE is None:
-         return
-     os.makedirs(os.path.dirname(VEC_PATH), exist_ok=True)
-     np.save(VEC_PATH, _VECTORS)
-     with open(STORE_PATH, "wb") as f:
-         pickle.dump(_STORE, f)
- 
- # ========= Models (loaded only on first call) =========
- def _get_emb_model():
-     """Load the SentenceTransformer only when it is first needed."""
      global _EMB, _EMB_DIM
      if _EMB is None:
-         from sentence_transformers import SentenceTransformer  # first import happens here
-         model_name = settings.EMB_MODEL or "intfloat/multilingual-e5-base"
-         _EMB = SentenceTransformer(model_name, cache_folder=CACHE_DIR)
          _EMB_DIM = _EMB.get_sentence_embedding_dimension()
- 
-         # Keep the stored vector array dimensionally consistent with the model
-         _load_index()
-         global _VECTORS
-         if _VECTORS.size == 0 or (_VECTORS.ndim == 2 and _VECTORS.shape[1] != _EMB_DIM):
-             _VECTORS = np.empty((0, _EMB_DIM), dtype=np.float32)
      return _EMB
  def _get_gen_pipeline():
-     """Load the text-generation pipeline only on the first call."""
-     global _TOK, _MODEL, _GEN
      if _GEN is None:
-         from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline  # lazy import
-         import torch  # lazy import
- 
-         gen_name = settings.GEN_MODEL or "Qwen/Qwen2.5-3B-Instruct"
-         _TOK = AutoTokenizer.from_pretrained(gen_name, cache_dir=CACHE_DIR)
          _MODEL = AutoModelForCausalLM.from_pretrained(
-             gen_name,
-             cache_dir=CACHE_DIR,
              torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
              device_map="auto",
              low_cpu_mem_usage=True,
@@ -97,58 +105,8 @@ def _get_gen_pipeline():
          _GEN = pipeline("text-generation", model=_MODEL, tokenizer=_TOK)
      return _GEN, _TOK
  
- # ========= Embedding and search =========
- def embed_texts(texts: List[str]) -> np.ndarray:
-     emb = _get_emb_model()
-     v = emb.encode(
-         texts,
-         normalize_embeddings=True,
-         convert_to_numpy=True,
-         show_progress_bar=False,
-     )
-     return v.astype(np.float32, copy=False)
- 
- def add_to_index(records: List[Dict]):
-     """
-     records: [{"text": ..., "source_url": ..., "title": ..., "doc_id": ..., "chunk_id": ...}]
-     """
-     if not records:
-         return
-     _load_index()
-     vecs = embed_texts([r["text"] for r in records])  # [M, D]
-     global _VECTORS, _STORE
-     if _VECTORS.size == 0:
-         _VECTORS = vecs
-     else:
-         _VECTORS = np.vstack([_VECTORS, vecs])
-     _STORE.extend(records)
-     _save_index()
- 
- def search(query: str, top_k=8):
-     _load_index()
-     if _VECTORS.size == 0 or not _STORE:
-         return []
-     qv = embed_texts([query])[0]
-     sims = _VECTORS @ qv
-     top_k = min(top_k, sims.shape[0])
-     idx = np.argpartition(-sims, top_k - 1)[:top_k]
-     idx = idx[np.argsort(-sims[idx])]
-     hits = []
-     for i in idx.tolist():
-         rec = _STORE[i].copy()
-         rec["score"] = float(sims[i])
-         hits.append(rec)
-     return hits
- 
- # ========= Generation utilities =========
- def generate_chat(messages: list[dict], max_new_tokens=800, temperature=0.2) -> str:
      gen, tok = _get_gen_pipeline()
      prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-     out = gen(
-         prompt,
-         do_sample=(temperature > 0.0),
-         temperature=temperature,
-         max_new_tokens=max_new_tokens,
-     )[0]["generated_text"]
-     generated = out[len(prompt):].strip()
-     return generated or out
+ # irpr/deps.py --- ChromaDB version (no faiss; also works without an LLM)
+ from __future__ import annotations
  import os
+ from typing import List, Dict, Optional
  import numpy as np
  from irpr.config import settings
  
+ # Ensure every cache/data directory is writable before anything loads a model
+ for d in [
+     os.environ.get("HF_HOME", "/data/.hf-home"),
+     os.environ.get("TRANSFORMERS_CACHE", "/data/.hf-cache"),
+     os.environ.get("SENTENCE_TRANSFORMERS_HOME", "/data/.hf-cache"),
+     os.environ.get("HUGGINGFACE_HUB_CACHE", "/data/.hf-cache"),
+     settings.DATA_DIR,
+     settings.CHROMA_PATH,
+ ]:
+     if d: os.makedirs(d, exist_ok=True)
+ 
+ # Lazily loaded singletons
+ _EMB = None
  _EMB_DIM: Optional[int] = None
+ _CHROMA_COLLECTION = None
+ _GEN = None
+ _TOK = None
  
+ def _get_embedder():
      global _EMB, _EMB_DIM
      if _EMB is None:
+         from sentence_transformers import SentenceTransformer
+         _EMB = SentenceTransformer(settings.EMB_MODEL, cache_folder=os.environ.get("HF_HOME", "/data/.hf-cache"))
          _EMB_DIM = _EMB.get_sentence_embedding_dimension()
      return _EMB
 
+ def embed_texts(texts: List[str]) -> np.ndarray:
+     emb = _get_embedder()
+     arr = emb.encode(texts, batch_size=16, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=False)
+     return arr.astype(np.float32, copy=False)
+ 
+ def _get_chroma():
+     global _CHROMA_COLLECTION
+     if _CHROMA_COLLECTION is None:
+         import chromadb
+         from chromadb.config import Settings as CS
+         client = chromadb.PersistentClient(path=settings.CHROMA_PATH, settings=CS(allow_reset=True))
+         _CHROMA_COLLECTION = client.get_or_create_collection(name="irpr_docs")
+     return _CHROMA_COLLECTION
+ 
+ def add_to_index(records: List[Dict]):
+     if not records: return
+     col = _get_chroma()
+     texts = [r["text"] for r in records]
+     embs = embed_texts(texts)
+     ids, metas = [], []
+     for r in records:
+         doc_id = r.get("doc_id") or "doc"
+         chunk_id = r.get("chunk_id") or ""
+         rid = f"{doc_id}:{chunk_id}" if chunk_id else doc_id
+         ids.append(rid)
+         metas.append({
+             # Chroma metadata values must be str/int/float/bool, so map None to ""
+             "source_url": r.get("source_url") or "",
+             "title": r.get("title") or "",
+             "doc_id": doc_id,
+             "chunk_id": chunk_id,
+         })
+     # Pass plain lists of floats; numpy-array support varies across chromadb versions
+     col.add(ids=ids, documents=texts, embeddings=embs.tolist(), metadatas=metas)
+ 
+ def search(query: str, top_k=8) -> List[Dict]:
+     col = _get_chroma()
+     q_emb = embed_texts([query])
+     # ids are always returned by query() and must not be listed in `include`
+     res = col.query(query_embeddings=q_emb.tolist(), n_results=top_k, include=["documents", "metadatas", "distances"])
+     docs = res.get("documents", [[]])[0]
+     metas = res.get("metadatas", [[]])[0]
+     dists = res.get("distances", [[]])[0]
+     out: List[Dict] = []
+     for text, meta, dist in zip(docs, metas, dists):
+         # With normalized embeddings in the default L2 space, dist = 2 - 2*cos,
+         # so cosine similarity is recovered as 1 - dist/2
+         score = (1.0 - float(dist) / 2.0) if dist is not None else None
+         out.append({
+             "text": text,
+             "source_url": (meta or {}).get("source_url"),
+             "title": (meta or {}).get("title"),
+             "doc_id": (meta or {}).get("doc_id"),
+             "chunk_id": (meta or {}).get("chunk_id"),
+             "score": score,
+         })
+     return out
+ 
+ # ==== Generation (optional) ====
  def _get_gen_pipeline():
+     """If GEN_MODEL is empty, raise as the signal that the LLM is disabled."""
+     if not settings.GEN_MODEL:
+         raise RuntimeError("GEN_MODEL is empty (LLM disabled).")
+     global _GEN, _TOK
      if _GEN is None:
+         from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+         import torch
+         name = settings.GEN_MODEL
+         _TOK = AutoTokenizer.from_pretrained(name, cache_dir=os.environ.get("HF_HOME", "/data/.hf-cache"))
          _MODEL = AutoModelForCausalLM.from_pretrained(
+             name,
+             cache_dir=os.environ.get("HF_HOME", "/data/.hf-cache"),
              torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
              device_map="auto",
              low_cpu_mem_usage=True,
          )
          _GEN = pipeline("text-generation", model=_MODEL, tokenizer=_TOK)
      return _GEN, _TOK
  
+ def generate_chat(messages: List[Dict], max_new_tokens=600, temperature=0.2) -> str:
      gen, tok = _get_gen_pipeline()
      prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     out = gen(prompt, do_sample=(temperature > 0.0), temperature=temperature, max_new_tokens=max_new_tokens)[0]["generated_text"]
+     return out[len(prompt):].strip()
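
For reviewers who want to exercise the new ChromaDB-backed API end to end, here is a minimal smoke-test sketch. It assumes settings.EMB_MODEL and settings.CHROMA_PATH point at a valid model name and a writable directory, that chromadb and sentence-transformers are installed, and that the sample records and query below are invented purely for illustration.

# Hypothetical smoke test for the new irpr.deps API (sample data made up for illustration)
from irpr.deps import add_to_index, search, generate_chat

add_to_index([
    {"text": "ChromaDB persists vectors under CHROMA_PATH.", "source_url": "https://example.com/a",
     "title": "Doc A", "doc_id": "a", "chunk_id": "0"},
    {"text": "Embeddings are L2-normalized before indexing.", "source_url": "https://example.com/b",
     "title": "Doc B", "doc_id": "b", "chunk_id": "0"},
])  # embeds both texts and adds them to the "irpr_docs" collection

for hit in search("where are vectors stored?", top_k=2):
    print(hit["score"], hit["title"], hit["source_url"])

# Generation stays opt-in: with GEN_MODEL empty, _get_gen_pipeline() raises RuntimeError,
# so callers guard the call instead of assuming an LLM is present
try:
    print(generate_chat([{"role": "user", "content": "Summarize Doc A."}]))
except RuntimeError:
    print("LLM disabled (GEN_MODEL is empty); running retrieval-only.")

Since add_to_index calls collection.add rather than upsert, re-running this sketch against the same persistent store may complain about duplicate IDs; switching to upsert would make re-indexing idempotent.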