Corin1998 committed on
Commit
1bfa547
·
verified ·
1 Parent(s): 460c48d

Update irpr/deps.py

Browse files
Files changed (1) hide show
  1. irpr/deps.py +41 -31
irpr/deps.py CHANGED
@@ -1,28 +1,39 @@
1
  # irpr/deps.py --- OpenAI埋め込み + 自前ベクタストア(numpy)/LLM生成
2
  from __future__ import annotations
3
  import os, json, uuid
4
- from typing import List, Dict, Optional, Tuple
5
  import numpy as np
6
  from irpr.config import settings
7
 
8
- # ==== 書き込み可能ディレクトリ決定 ====
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def _pick_writable_dir() -> str:
10
- candidates = [settings.DATA_DIR, "/data", "./var", "/tmp/irpr", "."]
 
 
 
11
  for base in candidates:
12
- try:
13
- if not base: continue
14
- os.makedirs(base, exist_ok=True)
15
- p = os.path.join(base, ".write_test")
16
- with open(p, "w") as w: w.write("ok")
17
- os.remove(p)
18
  return base
19
- except Exception:
20
- continue
21
  return "."
22
 
23
  BASE_DIR = _pick_writable_dir()
24
  INDEX_DIR = settings.INDEX_DIR or os.path.join(BASE_DIR, "simple_index")
25
- os.makedirs(INDEX_DIR, exist_ok=True)
26
 
27
  VECS_PATH = os.path.join(INDEX_DIR, "vectors.npy") # np.float32 [N,D](正規化済)
28
  META_PATH = os.path.join(INDEX_DIR, "meta.jsonl") # 1行1メタ
@@ -40,13 +51,17 @@ def _openai_client():
40
  return OpenAI(api_key=key)
41
 
42
  # ==== 収納・ロード ====
43
- def _load_index() -> Tuple[np.ndarray, List[dict], List[str]]:
44
  if os.path.exists(VECS_PATH):
45
- vecs = np.load(VECS_PATH).astype(np.float32, copy=False)
 
 
 
46
  else:
47
  vecs = np.zeros((0, 0), dtype=np.float32)
48
- metas: List[dict] = []
49
- texts: List[str] = []
 
50
  if os.path.exists(META_PATH):
51
  with open(META_PATH, "r", encoding="utf-8") as f:
52
  for line in f:
@@ -57,30 +72,29 @@ def _load_index() -> Tuple[np.ndarray, List[dict], List[str]]:
57
  with open(TEXT_PATH, "r", encoding="utf-8") as f:
58
  for line in f:
59
  texts.append(line.rstrip("\n"))
60
- # 整合性チェック
61
  if vecs.size == 0:
62
  return np.zeros((0, 0), dtype=np.float32), [], []
63
  n = vecs.shape[0]
64
  if len(metas) != n or len(texts) != n:
65
- # 壊れているなら初期化
66
  return np.zeros((0, 0), dtype=np.float32), [], []
67
  return vecs, metas, texts
68
 
69
- def _save_index(vecs: np.ndarray, metas: List[dict], texts: List[str]) -> None:
70
- os.makedirs(INDEX_DIR, exist_ok=True)
71
  np.save(VECS_PATH, vecs.astype(np.float32, copy=False))
72
  with open(META_PATH, "w", encoding="utf-8") as f:
73
  for m in metas:
74
  f.write(json.dumps(m, ensure_ascii=False) + "\n")
75
  with open(TEXT_PATH, "w", encoding="utf-8") as f:
76
  for t in texts:
77
- f.write((t or "").replace("\n", "\\n") + "\n") # 1行1テキストに正規化
78
 
79
  # ==== Embedding ====
80
  def embed_texts(texts: List[str]) -> np.ndarray:
81
  client = _openai_client()
82
- model = settings.OPENAI_EMBED_MODEL
83
- # バッチで呼ぶ
84
  B = 128
85
  out = []
86
  for i in range(0, len(texts), B):
@@ -88,15 +102,11 @@ def embed_texts(texts: List[str]) -> np.ndarray:
88
  resp = client.embeddings.create(model=model, input=batch)
89
  out.extend([d.embedding for d in resp.data])
90
  arr = np.array(out, dtype=np.float32)
91
- # 正規化(コサイン類似度用)
92
  norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
93
  return arr / norms
94
 
95
  # ==== 追加 ====
96
  def add_to_index(records: List[Dict]) -> int:
97
- """
98
- records: [{text, title, source_url, doc_id, chunk_id}]
99
- """
100
  if not records:
101
  return 0
102
  texts = [r["text"] for r in records]
@@ -109,7 +119,7 @@ def add_to_index(records: List[Dict]) -> int:
109
  old_texts = []
110
  else:
111
  if vecs.shape[1] != vecs_new.shape[1]:
112
- # 埋め込み次元が違う(モデルを変えた等)→作り直し
113
  vecs = vecs_new
114
  metas = []
115
  old_texts = []
@@ -136,8 +146,8 @@ def search(query: str, top_k=8) -> List[Dict]:
136
  vecs, metas, texts = _load_index()
137
  if vecs.size == 0:
138
  return []
139
- q = embed_texts([query])[0] # (D,)
140
- scores = vecs @ q # cosine (正規化済み)
141
  idx = np.argsort(-scores)[:max(1, top_k)]
142
  out: List[Dict] = []
143
  for i in idx.tolist():
@@ -155,7 +165,7 @@ def search(query: str, top_k=8) -> List[Dict]:
155
  # ==== 生成 ====
156
  def generate_chat(messages: List[Dict], max_new_tokens=600, temperature=0.2) -> str:
157
  client = _openai_client()
158
- model = settings.OPENAI_CHAT_MODEL
159
  resp = client.chat.completions.create(
160
  model=model,
161
  messages=messages,
 
1
  # irpr/deps.py --- OpenAI埋め込み + 自前ベクタストア(numpy)/LLM生成
2
  from __future__ import annotations
3
  import os, json, uuid
4
+ from typing import List, Dict, Tuple
5
  import numpy as np
6
  from irpr.config import settings
7
 
8
+ # ==== 書き込み(/mnt/data を最優先) ====
9
+ def _ensure_dir_writable(path: str) -> bool:
10
+ try:
11
+ os.makedirs(path, exist_ok=True)
12
+ try:
13
+ os.chmod(path, 0o777)
14
+ except Exception:
15
+ pass
16
+ testfile = os.path.join(path, ".write_test")
17
+ with open(testfile, "wb") as f:
18
+ f.write(b"ok")
19
+ os.remove(testfile)
20
+ return True
21
+ except Exception:
22
+ return False
23
+
24
def _pick_writable_dir() -> str:
    """Pick the first directory we can actually write to.

    Preference order: the configured ``settings.DATA_DIR`` (if set), then
    /mnt/data, /data, ./var, /tmp/irpr, and the current directory.  Falls
    back to "." unconditionally when nothing probes as writable.
    """
    candidates = []
    if settings.DATA_DIR:
        candidates.append(settings.DATA_DIR)
    candidates.extend(["/mnt/data", "/data", "./var", "/tmp/irpr", "."])
    for candidate in candidates:
        if _ensure_dir_writable(candidate):
            return candidate
    return "."
33
 
34
  BASE_DIR = _pick_writable_dir()
35
  INDEX_DIR = settings.INDEX_DIR or os.path.join(BASE_DIR, "simple_index")
36
+ _ensure_dir_writable(INDEX_DIR)
37
 
38
  VECS_PATH = os.path.join(INDEX_DIR, "vectors.npy") # np.float32 [N,D](正規化済)
39
  META_PATH = os.path.join(INDEX_DIR, "meta.jsonl") # 1行1メタ
 
51
  return OpenAI(api_key=key)
52
 
53
  # ==== 収納・ロード ====
54
+ def _load_index() -> Tuple[np.ndarray, list, list]:
55
  if os.path.exists(VECS_PATH):
56
+ try:
57
+ vecs = np.load(VECS_PATH).astype(np.float32, copy=False)
58
+ except Exception:
59
+ vecs = np.zeros((0, 0), dtype=np.float32)
60
  else:
61
  vecs = np.zeros((0, 0), dtype=np.float32)
62
+
63
+ metas = []
64
+ texts = []
65
  if os.path.exists(META_PATH):
66
  with open(META_PATH, "r", encoding="utf-8") as f:
67
  for line in f:
 
72
  with open(TEXT_PATH, "r", encoding="utf-8") as f:
73
  for line in f:
74
  texts.append(line.rstrip("\n"))
75
+
76
  if vecs.size == 0:
77
  return np.zeros((0, 0), dtype=np.float32), [], []
78
  n = vecs.shape[0]
79
  if len(metas) != n or len(texts) != n:
80
+ # 整合性が崩れたら初期化
81
  return np.zeros((0, 0), dtype=np.float32), [], []
82
  return vecs, metas, texts
83
 
84
def _save_index(vecs: np.ndarray, metas: list, texts: list) -> None:
    """Persist the index triple (vectors, metadata, texts) under INDEX_DIR.

    Vectors are stored as float32 in ``VECS_PATH`` (.npy); metadata is
    written one JSON object per line to ``META_PATH``; texts go to
    ``TEXT_PATH`` one record per line, with embedded newlines escaped as
    the literal two characters ``\\n`` so each record stays on one line.
    """
    _ensure_dir_writable(INDEX_DIR)
    np.save(VECS_PATH, vecs.astype(np.float32, copy=False))
    with open(META_PATH, "w", encoding="utf-8") as meta_file:
        meta_file.writelines(
            json.dumps(meta, ensure_ascii=False) + "\n" for meta in metas
        )
    with open(TEXT_PATH, "w", encoding="utf-8") as text_file:
        text_file.writelines(
            (text or "").replace("\n", "\\n") + "\n" for text in texts
        )
93
 
94
  # ==== Embedding ====
95
  def embed_texts(texts: List[str]) -> np.ndarray:
96
  client = _openai_client()
97
+ model = os.environ.get("OPENAI_EMBED_MODEL", settings.OPENAI_EMBED_MODEL)
 
98
  B = 128
99
  out = []
100
  for i in range(0, len(texts), B):
 
102
  resp = client.embeddings.create(model=model, input=batch)
103
  out.extend([d.embedding for d in resp.data])
104
  arr = np.array(out, dtype=np.float32)
 
105
  norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
106
  return arr / norms
107
 
108
  # ==== 追加 ====
109
  def add_to_index(records: List[Dict]) -> int:
 
 
 
110
  if not records:
111
  return 0
112
  texts = [r["text"] for r in records]
 
119
  old_texts = []
120
  else:
121
  if vecs.shape[1] != vecs_new.shape[1]:
122
+ # 埋め込みモデル変更→既存を捨てて作り直し
123
  vecs = vecs_new
124
  metas = []
125
  old_texts = []
 
146
  vecs, metas, texts = _load_index()
147
  if vecs.size == 0:
148
  return []
149
+ q = embed_texts([query])[0]
150
+ scores = vecs @ q
151
  idx = np.argsort(-scores)[:max(1, top_k)]
152
  out: List[Dict] = []
153
  for i in idx.tolist():
 
165
  # ==== 生成 ====
166
  def generate_chat(messages: List[Dict], max_new_tokens=600, temperature=0.2) -> str:
167
  client = _openai_client()
168
+ model = os.environ.get("OPENAI_CHAT_MODEL", settings.OPENAI_CHAT_MODEL)
169
  resp = client.chat.completions.create(
170
  model=model,
171
  messages=messages,