Charles Chan
committed on
Commit
·
1e06430
1
Parent(s):
a6c563b
coding
Browse files
app.py
CHANGED
|
@@ -6,7 +6,7 @@ from langchain_community.vectorstores import FAISS
|
|
| 6 |
from datasets import load_dataset
|
| 7 |
|
| 8 |
# Streamlit 界面
|
| 9 |
-
st.title("
|
| 10 |
|
| 11 |
# 使用 假知识 数据集
|
| 12 |
if "data_list" not in st.session_state:
|
|
@@ -17,6 +17,9 @@ if not st.session_state.data_list:
|
|
| 17 |
try:
|
| 18 |
with st.spinner("正在读取数据库..."):
|
| 19 |
dataset = load_dataset("zeerd/fake_knowledge")
|
|
|
|
|
|
|
|
|
|
| 20 |
data_list = []
|
| 21 |
answer_list = []
|
| 22 |
for example in dataset["train"]:
|
|
@@ -66,6 +69,10 @@ def get_answer(prompt):
|
|
| 66 |
# 问答函数
|
| 67 |
def answer_question(repo_id, temperature, max_length, question):
|
| 68 |
# 初始化 Gemma 模型
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
|
| 70 |
try:
|
| 71 |
with st.spinner("正在初始化 Gemma 模型..."):
|
|
@@ -86,9 +93,9 @@ def answer_question(repo_id, temperature, max_length, question):
|
|
| 86 |
st.success("答案生成完毕(基于模型自身)!")
|
| 87 |
print("答案生成完毕(基于模型自身)!")
|
| 88 |
with st.spinner("正在筛选本地数据集..."):
|
| 89 |
-
|
| 90 |
# question_embedding_str = " ".join(map(str, question_embedding))
|
| 91 |
-
docs_and_scores = st.session_state.db.similarity_search_with_relevance_scores(question)
|
| 92 |
|
| 93 |
context_list = []
|
| 94 |
for doc, score in docs_and_scores:
|
|
|
|
| 6 |
from datasets import load_dataset
|
| 7 |
|
| 8 |
# Streamlit 界面
|
| 9 |
+
st.title("外挂知识库问答系统")
|
| 10 |
|
| 11 |
# 使用 假知识 数据集
|
| 12 |
if "data_list" not in st.session_state:
|
|
|
|
| 17 |
try:
|
| 18 |
with st.spinner("正在读取数据库..."):
|
| 19 |
dataset = load_dataset("zeerd/fake_knowledge")
|
| 20 |
+
# 输出前五条数据
|
| 21 |
+
print(dataset["train"][:5])
|
| 22 |
+
|
| 23 |
data_list = []
|
| 24 |
answer_list = []
|
| 25 |
for example in dataset["train"]:
|
|
|
|
| 69 |
# 问答函数
|
| 70 |
def answer_question(repo_id, temperature, max_length, question):
|
| 71 |
# 初始化 Gemma 模型
|
| 72 |
+
print('repo_id: ' + repo_id)
|
| 73 |
+
print('temperature: ' + str(temperature))
|
| 74 |
+
print('max_length: ' + str(max_length))
|
| 75 |
+
|
| 76 |
if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
|
| 77 |
try:
|
| 78 |
with st.spinner("正在初始化 Gemma 模型..."):
|
|
|
|
| 93 |
st.success("答案生成完毕(基于模型自身)!")
|
| 94 |
print("答案生成完毕(基于模型自身)!")
|
| 95 |
with st.spinner("正在筛选本地数据集..."):
|
| 96 |
+
question_embedding = st.session_state.embeddings.embed_query(question)
|
| 97 |
# question_embedding_str = " ".join(map(str, question_embedding))
|
| 98 |
+
docs_and_scores = st.session_state.db.similarity_search_with_relevance_scores(question, 8, question_embedding)
|
| 99 |
|
| 100 |
context_list = []
|
| 101 |
for doc, score in docs_and_scores:
|