Spaces:
Build error
Build error
improve cache model
Browse files- app.py +47 -21
- data/answer_embeddings.npy +3 -0
- data/faiss_answer.index +0 -0
- data/faiss_question.index +0 -0
- data/question_embeddings.npy +3 -0
- preprocess.py +6 -14
app.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
-
import faiss
|
| 3 |
import numpy as np
|
| 4 |
import json
|
| 5 |
-
from sentence_transformers import SentenceTransformer
|
| 6 |
import time
|
| 7 |
|
|
|
|
| 8 |
# データを読み込む
|
| 9 |
with open("data/qa_data.json", "r", encoding="utf-8") as f:
|
| 10 |
data = json.load(f)
|
|
@@ -12,15 +12,27 @@ with open("data/qa_data.json", "r", encoding="utf-8") as f:
|
|
| 12 |
questions = [item["question"] for item in data]
|
| 13 |
answers = [item["answer"] for item in data]
|
| 14 |
|
| 15 |
-
# 埋め込みモデルをロード
|
| 16 |
-
model = SentenceTransformer("pkshatech/GLuCoSE-base-ja")
|
| 17 |
|
| 18 |
-
#
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
# サイドバー設定
|
| 23 |
-
st.set_page_config(initial_sidebar_state="collapsed")
|
| 24 |
with st.sidebar.expander("⚙️ 設定", expanded=False):
|
| 25 |
threshold_q = st.slider("質問の類似度しきい値", 0.0, 1.0, 0.7, 0.01)
|
| 26 |
threshold_a = st.slider("回答の類似度しきい値", 0.0, 1.0, 0.65, 0.01)
|
|
@@ -31,24 +43,37 @@ with st.sidebar.expander("⚙️ 設定", expanded=False):
|
|
| 31 |
|
| 32 |
|
| 33 |
def search_answer(user_input):
|
| 34 |
-
"""
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
if score_q >= threshold_q:
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
| 44 |
|
| 45 |
-
#
|
| 46 |
-
|
| 47 |
-
|
|
|
|
| 48 |
|
| 49 |
if score_a >= threshold_a:
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
| 52 |
|
| 53 |
return "申し訳ありませんが、ご質問の答えを見つけることができませんでした。もう少し詳しく説明していただけますか?", "一致なし"
|
| 54 |
|
|
@@ -81,6 +106,7 @@ if user_input := st.chat_input("質問を入力してください:"):
|
|
| 81 |
|
| 82 |
with st.spinner("考え中... お待ちください。"):
|
| 83 |
answer, info = search_answer(user_input)
|
|
|
|
| 84 |
|
| 85 |
with st.chat_message("assistant"):
|
| 86 |
response_placeholder = st.empty()
|
|
|
|
| 1 |
import streamlit as st
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import json
|
| 4 |
+
from sentence_transformers import SentenceTransformer, util
|
| 5 |
import time
|
| 6 |
|
| 7 |
+
st.set_page_config(initial_sidebar_state="collapsed")
|
| 8 |
# データを読み込む
|
| 9 |
with open("data/qa_data.json", "r", encoding="utf-8") as f:
|
| 10 |
data = json.load(f)
|
|
|
|
| 12 |
questions = [item["question"] for item in data]
|
| 13 |
answers = [item["answer"] for item in data]
|
| 14 |
|
|
|
|
|
|
|
| 15 |
|
| 16 |
+
# Cache model ở level app
|
| 17 |
+
@st.cache_resource
|
| 18 |
+
def load_model():
|
| 19 |
+
return SentenceTransformer("pkshatech/GLuCoSE-base-ja")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Cache embeddings data
|
| 23 |
+
@st.cache_data
|
| 24 |
+
def load_embeddings():
|
| 25 |
+
return (
|
| 26 |
+
np.load("data/question_embeddings.npy"),
|
| 27 |
+
np.load("data/answer_embeddings.npy"),
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Load model và embeddings một lần
|
| 32 |
+
model = load_model()
|
| 33 |
+
question_embeddings, answer_embeddings = load_embeddings()
|
| 34 |
|
| 35 |
# サイドバー設定
|
|
|
|
| 36 |
with st.sidebar.expander("⚙️ 設定", expanded=False):
|
| 37 |
threshold_q = st.slider("質問の類似度しきい値", 0.0, 1.0, 0.7, 0.01)
|
| 38 |
threshold_a = st.slider("回答の類似度しきい値", 0.0, 1.0, 0.65, 0.01)
|
|
|
|
| 43 |
|
| 44 |
|
| 45 |
def search_answer(user_input):
|
| 46 |
+
"""Tìm kiếm câu trả lời sử dụng cosine similarity"""
|
| 47 |
+
# Encode với batch_size và show_progress_bar=False để tăng tốc
|
| 48 |
+
user_embedding = model.encode(
|
| 49 |
+
[user_input],
|
| 50 |
+
convert_to_numpy=True,
|
| 51 |
+
batch_size=1,
|
| 52 |
+
show_progress_bar=False,
|
| 53 |
+
normalize_embeddings=True, # Pre-normalize để tăng tốc cosine similarity
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
# Tính cosine similarity với câu hỏi
|
| 57 |
+
cos_scores_q = util.cos_sim(user_embedding, question_embeddings)[0]
|
| 58 |
+
best_q_idx = np.argmax(cos_scores_q)
|
| 59 |
+
score_q = cos_scores_q[best_q_idx]
|
| 60 |
|
| 61 |
if score_q >= threshold_q:
|
| 62 |
+
return (
|
| 63 |
+
answers[best_q_idx].replace("\n", " \n"),
|
| 64 |
+
f"質問にマッチ ({score_q:.2f})",
|
| 65 |
+
)
|
| 66 |
|
| 67 |
+
# Tính cosine similarity với câu trả lời
|
| 68 |
+
cos_scores_a = model.util.cos_sim(user_embedding, answer_embeddings)[0]
|
| 69 |
+
best_a_idx = np.argmax(cos_scores_a)
|
| 70 |
+
score_a = cos_scores_a[best_a_idx]
|
| 71 |
|
| 72 |
if score_a >= threshold_a:
|
| 73 |
+
return (
|
| 74 |
+
answers[best_a_idx].replace("\n", " \n"),
|
| 75 |
+
f"回答にマッチ ({score_a:.2f})",
|
| 76 |
+
)
|
| 77 |
|
| 78 |
return "申し訳ありませんが、ご質問の答えを見つけることができませんでした。もう少し詳しく説明していただけますか?", "一致なし"
|
| 79 |
|
|
|
|
| 106 |
|
| 107 |
with st.spinner("考え中... お待ちください。"):
|
| 108 |
answer, info = search_answer(user_input)
|
| 109 |
+
print(info)
|
| 110 |
|
| 111 |
with st.chat_message("assistant"):
|
| 112 |
response_placeholder = st.empty()
|
data/answer_embeddings.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:925632dc69ab4df0223970df60cc9054dd46e2958e597e6998514bd3b33fc703
|
| 3 |
+
size 67712
|
data/faiss_answer.index
DELETED
|
Binary file (67.6 kB)
|
|
|
data/faiss_question.index
DELETED
|
Binary file (67.6 kB)
|
|
|
data/question_embeddings.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:45127b6e8615f93324b2debb37305b93d3963c5f91f054f9c56def8cd00c1ca5
|
| 3 |
+
size 67712
|
preprocess.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import json
|
| 2 |
-
import faiss
|
| 3 |
import numpy as np
|
| 4 |
from sentence_transformers import SentenceTransformer
|
| 5 |
|
|
@@ -14,20 +13,13 @@ answers = [item["answer"] for item in data]
|
|
| 14 |
model = SentenceTransformer("pkshatech/GLuCoSE-base-ja")
|
| 15 |
|
| 16 |
# Tạo embedding cho câu hỏi và câu trả lời
|
| 17 |
-
question_embeddings = model.encode(questions)
|
| 18 |
-
answer_embeddings = model.encode(answers)
|
| 19 |
|
| 20 |
-
# Lưu
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
index_a = faiss.IndexFlatL2(dim)
|
| 24 |
-
|
| 25 |
-
index_q.add(np.array(question_embeddings).astype(np.float32))
|
| 26 |
-
index_a.add(np.array(answer_embeddings).astype(np.float32))
|
| 27 |
-
|
| 28 |
-
faiss.write_index(index_q, "faiss_question.index")
|
| 29 |
-
faiss.write_index(index_a, "faiss_answer.index")
|
| 30 |
|
| 31 |
# Lưu dữ liệu gốc
|
| 32 |
-
with open("qa_data.json", "w", encoding="utf-8") as f:
|
| 33 |
json.dump(data, f, ensure_ascii=False, indent=2)
|
|
|
|
| 1 |
import json
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
from sentence_transformers import SentenceTransformer
|
| 4 |
|
|
|
|
| 13 |
model = SentenceTransformer("pkshatech/GLuCoSE-base-ja")
|
| 14 |
|
| 15 |
# Tạo embedding cho câu hỏi và câu trả lời
|
| 16 |
+
question_embeddings = model.encode(questions, convert_to_numpy=True)
|
| 17 |
+
answer_embeddings = model.encode(answers, convert_to_numpy=True)
|
| 18 |
|
| 19 |
+
# Lưu embedding dưới dạng numpy array
|
| 20 |
+
np.save("data/question_embeddings.npy", question_embeddings)
|
| 21 |
+
np.save("data/answer_embeddings.npy", answer_embeddings)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
# Lưu dữ liệu gốc
|
| 24 |
+
with open("data/qa_data.json", "w", encoding="utf-8") as f:
|
| 25 |
json.dump(data, f, ensure_ascii=False, indent=2)
|