Update app/deps.py

app/deps.py  +16 -9
@@ -4,10 +4,10 @@ import faiss, os, pickle, torch
 
 from app.config import settings
 
-#
+# Ensure the data directory exists
 os.makedirs(settings.DATA_DIR, exist_ok=True)
 
-# Embeddings
+# Embeddings
 emb_model = SentenceTransformer(settings.EMB_MODEL)
 
 # FAISS index (create or load)
@@ -20,9 +20,9 @@ if os.path.exists(index_path) and os.path.exists(store_path):
         store = pickle.load(f)
 else:
     index = faiss.IndexFlatIP(emb_model.get_sentence_embedding_dimension())
-    store = [] #
+    store = [] # [{text, source_url, title, doc_id, chunk_id, ...}]
 
-#
+# Text-generation model
 tok = AutoTokenizer.from_pretrained(settings.GEN_MODEL)
 model = AutoModelForCausalLM.from_pretrained(
     settings.GEN_MODEL,
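
Note on this hunk: IndexFlatIP ranks by raw inner product, which behaves as cosine similarity only because embed_texts (updated in the next hunk) normalizes every vector. A minimal sketch of that identity, assuming nothing beyond NumPy:

import numpy as np

# Scale both vectors to unit length, as normalize_embeddings=True produces
a = np.array([3.0, 4.0], dtype="float32")
b = np.array([4.0, 3.0], dtype="float32")
a /= np.linalg.norm(a)
b /= np.linalg.norm(b)

# Inner product of unit vectors equals cosine similarity (0.96 here)
print(float(a @ b))
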
@@ -30,15 +30,12 @@ model = AutoModelForCausalLM.from_pretrained(
     device_map="auto",
     low_cpu_mem_usage=True
 )
-
-gen = pipeline("text-generation", model=model, tokenizer=tok, max_new_tokens=800)
+gen = pipeline("text-generation", model=model, tokenizer=tok)
 
 def embed_texts(texts: list[str]):
-
-    return v
+    return emb_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=False)
 
 def add_to_index(records: list[dict]):
-    # records: [{"text":..., "source_url":..., "title":..., "doc_id":..., "chunk_id":...}]
     vecs = embed_texts([r["text"] for r in records])
     index.add(vecs)
     store.extend(records)
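
Two details in this hunk: max_new_tokens=800 moves from pipeline construction to call time, so the new generate_chat below can set it per request, and convert_to_numpy=True makes explicit that index.add needs a NumPy float32 array rather than a torch tensor. A usage sketch of the indexing path, with placeholder record values (the URLs and ids are illustrative, not from the commit):

add_to_index([
    {"text": "FAISS is a library for similarity search.",
     "source_url": "https://example.com/faiss", "title": "FAISS",
     "doc_id": "d1", "chunk_id": 0},
    {"text": "SentenceTransformer encodes text into dense vectors.",
     "source_url": "https://example.com/sbert", "title": "SBERT",
     "doc_id": "d2", "chunk_id": 0},
])
# index and store grow in lockstep, one vector per record
print(index.ntotal, len(store))
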
@@ -57,3 +54,13 @@ def search(query: str, top_k=8):
         rec["score"] = float(scores[0][rank])
         hits.append(rec)
     return hits
+
+def generate_chat(messages: list[dict], max_new_tokens=800, temperature=0.2):
+    """
+    messages: [{"role":"system"/"user"/"assistant", "content":"..."}]
+    """
+    prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    out = gen(prompt, do_sample=(temperature > 0.0), temperature=temperature, max_new_tokens=max_new_tokens)[0]["generated_text"]
+    # The pipeline may return the prompt plus the generation, so strip the prompt
+    generated = out[len(prompt):].strip()
+    return generated or out
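
The new generate_chat applies the tokenizer's chat template, samples only when temperature > 0.0, and strips the echoed prompt, since text-generation pipelines return the prompt plus the continuation by default. A sketch of how it would pair with search for a retrieve-then-answer flow; the answer helper and its prompt wording are assumptions, not part of the commit:

def answer(question: str) -> str:
    # Retrieve the best-matching chunks, then ground the model on them
    hits = search(question, top_k=4)
    context = "\n\n".join(f"[{h['title']}] {h['text']}" for h in hits)
    messages = [
        {"role": "system", "content": "Answer using only the provided context."},
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {question}"},
    ]
    return generate_chat(messages, max_new_tokens=400, temperature=0.2)
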