Corin1998 committed on
Commit
3f322c7
·
verified ·
1 Parent(s): 04baeca

Update app/deps.py

Browse files
Files changed (1) hide show
  1. app/deps.py +16 -9
app/deps.py CHANGED
@@ -4,10 +4,10 @@ import faiss, os, pickle, torch
4
 
5
  from app.config import settings
6
 
7
- # --- data ディレクトリの存在を保証(初回起動での書き込み失敗を防止) ---
8
  os.makedirs(settings.DATA_DIR, exist_ok=True)
9
 
10
- # Embeddings (multilingual-e5)
11
  emb_model = SentenceTransformer(settings.EMB_MODEL)
12
 
13
  # FAISS index (create or load)
@@ -20,9 +20,9 @@ if os.path.exists(index_path) and os.path.exists(store_path):
20
  store = pickle.load(f)
21
  else:
22
  index = faiss.IndexFlatIP(emb_model.get_sentence_embedding_dimension())
23
- store = [] # list[dict]: {text, source_url, title, doc_id, chunk_id, score...}
24
 
25
- # Generation model (Japanese-capable small model by default)
26
  tok = AutoTokenizer.from_pretrained(settings.GEN_MODEL)
27
  model = AutoModelForCausalLM.from_pretrained(
28
  settings.GEN_MODEL,
@@ -30,15 +30,12 @@ model = AutoModelForCausalLM.from_pretrained(
30
  device_map="auto",
31
  low_cpu_mem_usage=True
32
  )
33
-
34
- gen = pipeline("text-generation", model=model, tokenizer=tok, max_new_tokens=800)
35
 
36
  def embed_texts(texts: list[str]):
37
- v = emb_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=False)
38
- return v
39
 
40
  def add_to_index(records: list[dict]):
41
- # records: [{"text":..., "source_url":..., "title":..., "doc_id":..., "chunk_id":...}]
42
  vecs = embed_texts([r["text"] for r in records])
43
  index.add(vecs)
44
  store.extend(records)
@@ -57,3 +54,13 @@ def search(query: str, top_k=8):
57
  rec["score"] = float(scores[0][rank])
58
  hits.append(rec)
59
  return hits
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  from app.config import settings
6
 
7
+ # Ensure the data directory exists (avoids write failures on first startup)
8
  os.makedirs(settings.DATA_DIR, exist_ok=True)
9
 
10
+ # Embeddings
11
  emb_model = SentenceTransformer(settings.EMB_MODEL)
12
 
13
  # FAISS index (create or load)
 
20
  store = pickle.load(f)
21
  else:
22
  index = faiss.IndexFlatIP(emb_model.get_sentence_embedding_dimension())
23
+ store = [] # [{text, source_url, title, doc_id, chunk_id, ...}]
24
 
25
+ # Text-generation model
26
  tok = AutoTokenizer.from_pretrained(settings.GEN_MODEL)
27
  model = AutoModelForCausalLM.from_pretrained(
28
  settings.GEN_MODEL,
 
30
  device_map="auto",
31
  low_cpu_mem_usage=True
32
  )
33
+ gen = pipeline("text-generation", model=model, tokenizer=tok)
 
34
 
35
def embed_texts(texts: list[str]):
    """Encode *texts* into L2-normalized embedding vectors (numpy array)."""
    vectors = emb_model.encode(
        texts,
        normalize_embeddings=True,   # unit vectors so inner product == cosine similarity
        convert_to_numpy=True,
        show_progress_bar=False,
    )
    return vectors
 
37
 
38
def add_to_index(records: list[dict]):
    """Embed each record's text and append vectors + metadata to the index.

    records: [{"text": ..., "source_url": ..., "title": ..., "doc_id": ..., "chunk_id": ...}]

    Keeps `index` and `store` in lockstep: row i of the FAISS index
    corresponds to store[i], which `search` relies on.
    """
    if not records:
        # Nothing to add; also avoids passing a zero-row/ill-shaped array to faiss.
        return
    vecs = embed_texts([r["text"] for r in records])
    index.add(vecs)
    store.extend(records)
 
54
  rec["score"] = float(scores[0][rank])
55
  hits.append(rec)
56
  return hits
57
+
58
def generate_chat(messages: list[dict], max_new_tokens=800, temperature=0.2):
    """Render a chat prompt with the model's chat template and return the reply.

    messages: [{"role": "system"/"user"/"assistant", "content": "..."}]
    max_new_tokens: generation budget forwarded to the pipeline.
    temperature: <= 0.0 selects greedy decoding; > 0.0 enables sampling.
    """
    prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    do_sample = temperature > 0.0
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        # return_full_text=False makes the pipeline return only the newly
        # generated tokens — more robust than slicing the prompt off the
        # concatenated output, which assumes a verbatim prompt echo.
        "return_full_text": False,
    }
    if do_sample:
        # Only pass temperature when sampling; transformers warns when it is
        # combined with greedy decoding (do_sample=False).
        gen_kwargs["temperature"] = temperature
    out = gen(prompt, **gen_kwargs)[0]["generated_text"]
    generated = out.strip()
    return generated or out