Add "Search Only" to OpenAI model options and make OpenAI API key input optional
#1
by
shinichi-a
- opened
app.py
CHANGED
|
@@ -1,23 +1,15 @@
|
|
| 1 |
-
"""
|
| 2 |
-
streamlit run app.py --server.address 0.0.0.0
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
from __future__ import annotations
|
| 6 |
|
| 7 |
-
import streamlit as st
|
| 8 |
import os
|
| 9 |
-
|
| 10 |
-
import faiss
|
| 11 |
-
from sentence_transformers import SentenceTransformer
|
| 12 |
import torch
|
| 13 |
-
|
| 14 |
import streamlit as st
|
| 15 |
-
import pandas as pd
|
| 16 |
-
import os
|
| 17 |
from time import time
|
|
|
|
|
|
|
|
|
|
| 18 |
from datasets.download import DownloadManager
|
| 19 |
-
from datasets import load_dataset # type: ignore
|
| 20 |
-
|
| 21 |
|
| 22 |
WIKIPEDIA_JA_DS = "singletongue/wikipedia-utils"
|
| 23 |
WIKIPEDIA_JS_DS_NAME = "passages-c400-jawiki-20230403"
|
|
@@ -36,6 +28,7 @@ EMB_MODEL_NAMES = list(EMB_MODEL_PQ.keys())
|
|
| 36 |
OPENAI_MODEL_NAMES = [
|
| 37 |
"gpt-3.5-turbo-1106",
|
| 38 |
"gpt-4-1106-preview",
|
|
|
|
| 39 |
]
|
| 40 |
|
| 41 |
E5_QUERY_TYPES = [
|
|
@@ -60,7 +53,6 @@ Responses must be given in Japanese.
|
|
| 60 |
{question}
|
| 61 |
""".strip()
|
| 62 |
|
| 63 |
-
|
| 64 |
if os.getenv("SPACE_ID"):
|
| 65 |
USE_HF_SPACE = True
|
| 66 |
os.environ["HF_HOME"] = "/data/.huggingface"
|
|
@@ -68,9 +60,7 @@ if os.getenv("SPACE_ID"):
|
|
| 68 |
else:
|
| 69 |
USE_HF_SPACE = False
|
| 70 |
|
| 71 |
-
# for tokenizer
|
| 72 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 73 |
-
|
| 74 |
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
|
| 75 |
|
| 76 |
|
|
@@ -81,6 +71,7 @@ def get_model(name: str, max_seq_length=512):
|
|
| 81 |
device = "cuda"
|
| 82 |
elif torch.backends.mps.is_available():
|
| 83 |
device = "mps"
|
|
|
|
| 84 |
model = SentenceTransformer(name, device=device)
|
| 85 |
model.max_seq_length = max_seq_length
|
| 86 |
return model
|
|
@@ -93,9 +84,7 @@ def get_wikija_ds(name: str = WIKIPEDIA_JS_DS_NAME):
|
|
| 93 |
|
| 94 |
|
| 95 |
@st.cache_resource
|
| 96 |
-
def get_faiss_index(
|
| 97 |
-
index_name: str, ja_emb_ds: str = WIKIPEDIA_JA_EMB_DS, name=WIKIPEDIA_JS_DS_NAME
|
| 98 |
-
):
|
| 99 |
target_path = f"faiss_indexes/{name}/{index_name}"
|
| 100 |
dm = DownloadManager()
|
| 101 |
index_local_path = dm.download(
|
|
@@ -110,9 +99,7 @@ def text_to_emb(model, text: str, prefix: str):
|
|
| 110 |
return model.encode([prefix + text], normalize_embeddings=True)
|
| 111 |
|
| 112 |
|
| 113 |
-
def search(
|
| 114 |
-
faiss_index, emb_model, ds, question: str, search_text_prefix: str, top_k: int
|
| 115 |
-
):
|
| 116 |
start_time = time()
|
| 117 |
emb = text_to_emb(emb_model, question, search_text_prefix)
|
| 118 |
emb_exec_time = time() - start_time
|
|
@@ -121,7 +108,7 @@ def search(
|
|
| 121 |
scores = scores[0]
|
| 122 |
indexes = indexes[0]
|
| 123 |
results = []
|
| 124 |
-
for idx, score in zip(indexes, scores):
|
| 125 |
idx = int(idx)
|
| 126 |
passage = ds[idx]
|
| 127 |
results.append((score, passage))
|
|
@@ -133,7 +120,6 @@ def to_contexts(passages):
|
|
| 133 |
for passage in passages:
|
| 134 |
title = passage["title"]
|
| 135 |
text = passage["text"]
|
| 136 |
-
# section = passage["section"]
|
| 137 |
contexts += f"- {title}: {text}\n"
|
| 138 |
return contexts
|
| 139 |
|
|
@@ -211,15 +197,13 @@ def app():
|
|
| 211 |
key="question",
|
| 212 |
value="楽曲『約束はいらない』でデビューした、声優は誰?",
|
| 213 |
)
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
else:
|
| 222 |
-
st.session_state.openai_api_key = OPENAI_API_KEY
|
| 223 |
|
| 224 |
with st.expander("オプション"):
|
| 225 |
option_cols_main = st.columns(2)
|
|
@@ -229,6 +213,8 @@ def app():
|
|
| 229 |
st.selectbox(
|
| 230 |
"OpenAI Model", OPENAI_MODEL_NAMES, index=0, key="openai_model_name"
|
| 231 |
)
|
|
|
|
|
|
|
| 232 |
emb_model_name = st.session_state.emb_model_name
|
| 233 |
option_cols_sub = st.columns(2)
|
| 234 |
with option_cols_sub[0]:
|
|
@@ -300,10 +286,10 @@ def app():
|
|
| 300 |
st.dataframe(df, hide_index=True)
|
| 301 |
|
| 302 |
openai_api_key = st.session_state.openai_api_key
|
| 303 |
-
|
|
|
|
| 304 |
openai_api_key = openai_api_key.strip()
|
| 305 |
answer_header.subheader("Answer: ")
|
| 306 |
-
openai_model_name = st.session_state.openai_model_name
|
| 307 |
temperature = st.session_state.temperature
|
| 308 |
qa_prompt = st.session_state.qa_prompt
|
| 309 |
max_tokens = st.session_state.max_tokens
|
|
@@ -320,4 +306,4 @@ def app():
|
|
| 320 |
|
| 321 |
|
| 322 |
if __name__ == "__main__":
|
| 323 |
-
app()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
|
|
|
| 3 |
import os
|
| 4 |
+
import pandas as pd
|
|
|
|
|
|
|
| 5 |
import torch
|
| 6 |
+
import faiss
|
| 7 |
import streamlit as st
|
|
|
|
|
|
|
| 8 |
from time import time
|
| 9 |
+
from openai import OpenAI
|
| 10 |
+
from sentence_transformers import SentenceTransformer
|
| 11 |
+
from datasets import load_dataset
|
| 12 |
from datasets.download import DownloadManager
|
|
|
|
|
|
|
| 13 |
|
| 14 |
WIKIPEDIA_JA_DS = "singletongue/wikipedia-utils"
|
| 15 |
WIKIPEDIA_JS_DS_NAME = "passages-c400-jawiki-20230403"
|
|
|
|
| 28 |
OPENAI_MODEL_NAMES = [
|
| 29 |
"gpt-3.5-turbo-1106",
|
| 30 |
"gpt-4-1106-preview",
|
| 31 |
+
"Search Only",
|
| 32 |
]
|
| 33 |
|
| 34 |
E5_QUERY_TYPES = [
|
|
|
|
| 53 |
{question}
|
| 54 |
""".strip()
|
| 55 |
|
|
|
|
| 56 |
if os.getenv("SPACE_ID"):
|
| 57 |
USE_HF_SPACE = True
|
| 58 |
os.environ["HF_HOME"] = "/data/.huggingface"
|
|
|
|
| 60 |
else:
|
| 61 |
USE_HF_SPACE = False
|
| 62 |
|
|
|
|
| 63 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
| 64 |
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
|
| 65 |
|
| 66 |
|
|
|
|
| 71 |
device = "cuda"
|
| 72 |
elif torch.backends.mps.is_available():
|
| 73 |
device = "mps"
|
| 74 |
+
|
| 75 |
model = SentenceTransformer(name, device=device)
|
| 76 |
model.max_seq_length = max_seq_length
|
| 77 |
return model
|
|
|
|
| 84 |
|
| 85 |
|
| 86 |
@st.cache_resource
|
| 87 |
+
def get_faiss_index(index_name: str, ja_emb_ds: str = WIKIPEDIA_JA_EMB_DS, name=WIKIPEDIA_JS_DS_NAME):
|
|
|
|
|
|
|
| 88 |
target_path = f"faiss_indexes/{name}/{index_name}"
|
| 89 |
dm = DownloadManager()
|
| 90 |
index_local_path = dm.download(
|
|
|
|
| 99 |
return model.encode([prefix + text], normalize_embeddings=True)
|
| 100 |
|
| 101 |
|
| 102 |
+
def search(faiss_index, emb_model, ds, question: str, search_text_prefix: str, top_k: int):
|
|
|
|
|
|
|
| 103 |
start_time = time()
|
| 104 |
emb = text_to_emb(emb_model, question, search_text_prefix)
|
| 105 |
emb_exec_time = time() - start_time
|
|
|
|
| 108 |
scores = scores[0]
|
| 109 |
indexes = indexes[0]
|
| 110 |
results = []
|
| 111 |
+
for idx, score in zip(indexes, scores):
|
| 112 |
idx = int(idx)
|
| 113 |
passage = ds[idx]
|
| 114 |
results.append((score, passage))
|
|
|
|
| 120 |
for passage in passages:
|
| 121 |
title = passage["title"]
|
| 122 |
text = passage["text"]
|
|
|
|
| 123 |
contexts += f"- {title}: {text}\n"
|
| 124 |
return contexts
|
| 125 |
|
|
|
|
| 197 |
key="question",
|
| 198 |
value="楽曲『約束はいらない』でデビューした、声優は誰?",
|
| 199 |
)
|
| 200 |
+
st.text_input(
|
| 201 |
+
"OpenAI API Key",
|
| 202 |
+
key="openai_api_key",
|
| 203 |
+
type="password",
|
| 204 |
+
value=OPENAI_API_KEY if OPENAI_API_KEY else "",
|
| 205 |
+
placeholder="※ OpenAI API Key 未入力時は回答を生成せずに、検索のみ実行します",
|
| 206 |
+
)
|
|
|
|
|
|
|
| 207 |
|
| 208 |
with st.expander("オプション"):
|
| 209 |
option_cols_main = st.columns(2)
|
|
|
|
| 213 |
st.selectbox(
|
| 214 |
"OpenAI Model", OPENAI_MODEL_NAMES, index=0, key="openai_model_name"
|
| 215 |
)
|
| 216 |
+
if "emb_model_name" not in st.session_state:
|
| 217 |
+
st.session_state.emb_model_name = EMB_MODEL_NAMES[0]  # default to the first embedding model when the session has none yet
|
| 218 |
emb_model_name = st.session_state.emb_model_name
|
| 219 |
option_cols_sub = st.columns(2)
|
| 220 |
with option_cols_sub[0]:
|
|
|
|
| 286 |
st.dataframe(df, hide_index=True)
|
| 287 |
|
| 288 |
openai_api_key = st.session_state.openai_api_key
|
| 289 |
+
openai_model_name = st.session_state.openai_model_name
|
| 290 |
+
if openai_api_key and openai_model_name != "Search Only":
|
| 291 |
openai_api_key = openai_api_key.strip()
|
| 292 |
answer_header.subheader("Answer: ")
|
|
|
|
| 293 |
temperature = st.session_state.temperature
|
| 294 |
qa_prompt = st.session_state.qa_prompt
|
| 295 |
max_tokens = st.session_state.max_tokens
|
|
|
|
| 306 |
|
| 307 |
|
| 308 |
if __name__ == "__main__":
|
| 309 |
+
app()
|