Spaces:
Build error
Build error
| import streamlit as st | |
| import os | |
| import pickle | |
| import ipaddress | |
| import tiktoken | |
| from pathlib import Path | |
| from streamlit import runtime | |
| from streamlit.runtime.scriptrunner import get_script_run_ctx | |
| from streamlit.web.server.websocket_headers import _get_websocket_headers | |
| from llama_index import SimpleDirectoryReader | |
| # from llama_index import Prompt | |
| from llama_index.prompts.base import PromptTemplate | |
| from llama_index.chat_engine import CondenseQuestionChatEngine; | |
| from llama_index.response_synthesizers import get_response_synthesizer | |
| from llama_index import ServiceContext, SimpleDirectoryReader | |
| from llama_index.node_parser import SimpleNodeParser | |
| from llama_index.langchain_helpers.text_splitter import TokenTextSplitter | |
| from llama_index.constants import DEFAULT_CHUNK_OVERLAP | |
| from llama_index.response_synthesizers import get_response_synthesizer | |
| from llama_index.callbacks import CallbackManager | |
| from llama_index.llms import OpenAI | |
| from log import logger | |
| from llama_index.llms.base import ChatMessage, MessageRole | |
| from llama_index.prompts.base import ChatPromptTemplate | |
| # 接続元制御 | |
| ALLOW_IP_ADDRESS = os.environ["ALLOW_IP_ADDRESS"] | |
| # Azure AD app registration details | |
| CLIENT_ID = os.environ["CLIENT_ID"] | |
| CLIENT_SECRET = os.environ["CLIENT_SECRET"] | |
| TENANT_ID = os.environ["TENANT_ID"] | |
| # Azure API | |
| REDIRECT_URI = os.environ["REDIRECT_URI"] | |
| AUTHORITY = f"https://login.microsoftonline.com/{TENANT_ID}" | |
| SCOPES = ["openid", "profile", "User.Read"] | |
| # 接続元IP取得 | |
| def get_remote_ip(): | |
| ctx = get_script_run_ctx() | |
| session_info = runtime.get_instance().get_client(ctx.session_id) | |
| headers = _get_websocket_headers() | |
| return session_info.request.remote_ip, headers.get("X-Forwarded-For") | |
| # 接続元IP許可判定 | |
| def is_allow_ip_address(): | |
| remote_ip, x_forwarded_for = get_remote_ip() | |
| logger.info("remote_ip:"+remote_ip) | |
| if x_forwarded_for is not None: | |
| remote_ip = x_forwarded_for | |
| # localhost | |
| if remote_ip == "::1": | |
| return True | |
| # プライベートIP | |
| ipaddr = ipaddress.IPv4Address(remote_ip) | |
| logger.info("ipaddr:"+str(ipaddr)) | |
| if ipaddr.is_private: | |
| return True | |
| # その他(許可リスト判定) | |
| return remote_ip in ALLOW_IP_ADDRESS | |
| #ログインの確認 | |
| def check_login(): | |
| if not is_allow_ip_address(): | |
| st.title("HTTP 403 Forbidden") | |
| st.stop() | |
| if "login_token" not in st.session_state or not st.session_state.login_token: | |
| st.warning("**ログインしてください**") | |
| st.stop() | |
| INDEX_NAME = os.environ["INDEX_NAME"] | |
| PKL_NAME = os.environ["PKL_NAME"] | |
| # デバッグ用 | |
| llm = OpenAI(model='gpt-3.5-turbo', temperature=0.8, max_tokens=256) | |
| text_splitter = TokenTextSplitter(separator="。", chunk_size=1500 | |
| , chunk_overlap=DEFAULT_CHUNK_OVERLAP | |
| , tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode) | |
| node_parser = SimpleNodeParser(text_splitter=text_splitter) | |
| custom_prompt = PromptTemplate("""\ | |
| 以下はこれまでの会話履歴と、ドキュメントを検索して回答する必要がある、ユーザーからの会話文です。 | |
| 会話と新しい会話文に基づいて、検索クエリを作成します。 | |
| <Chat History> | |
| {chat_history} | |
| <Follow Up Message> | |
| {question} | |
| <Standalone question> | |
| """) | |
| TEXT_QA_SYSTEM_PROMPT = ChatMessage( | |
| content=( | |
| "あなたは世界中で信頼されているQAシステムです。\n" | |
| "事前知識ではなく、常に提供されたコンテキスト情報を使用してクエリに回答してください。\n" | |
| "従うべきいくつかのルール:\n" | |
| "1. 回答内で指定されたコンテキストを直接参照しないでください。\n" | |
| "2. 「コンテキストに基づいて、...」や「コンテキスト情報は...」、またはそれに類するような記述は避けてください。" | |
| ), | |
| role=MessageRole.SYSTEM, | |
| ) | |
| # QAプロンプトテンプレートメッセージ | |
| TEXT_QA_PROMPT_TMPL_MSGS = [ | |
| TEXT_QA_SYSTEM_PROMPT, | |
| ChatMessage( | |
| content=( | |
| "コンテキスト情報は以下のとおりです。\n" | |
| "---------------------\n" | |
| "{context_str}\n" | |
| "---------------------\n" | |
| "事前知識ではなくコンテキスト情報を考慮して、クエリに答えます。\n" | |
| "Query: {query_str}\n" | |
| "Answer: " | |
| ), | |
| role=MessageRole.USER, | |
| ), | |
| ] | |
| CHAT_TEXT_QA_PROMPT = ChatPromptTemplate(message_templates=TEXT_QA_PROMPT_TMPL_MSGS) | |
| CHAT_REFINE_PROMPT_TMPL_MSGS = [ | |
| ChatMessage( | |
| content=( | |
| "あなたは、既存の回答を改良する際に2つのモードで厳密に動作するQAシステムのエキスパートです。\n" | |
| "1. 新しいコンテキストを使用して元の回答を**書き直す**。\n" | |
| "2. 新しいコンテキストが役に立たない場合は、元の回答を**繰り返す**。\n" | |
| "回答内で元の回答やコンテキストを直接参照しないでください。\n" | |
| "疑問がある場合は、元の答えを繰り返してください。" | |
| "New Context: {context_msg}\n" | |
| "Query: {query_str}\n" | |
| "Original Answer: {existing_answer}\n" | |
| "New Answer: " | |
| ), | |
| role=MessageRole.USER, | |
| ) | |
| ] | |
| # チャットRefineプロンプト | |
| CHAT_REFINE_PROMPT = ChatPromptTemplate(message_templates=CHAT_REFINE_PROMPT_TMPL_MSGS) | |
| def setChatEngine(): | |
| callback_manager = CallbackManager([st.session_state.llama_debug_handler]) | |
| service_context = ServiceContext.from_defaults(llm=llm,node_parser=node_parser,callback_manager=callback_manager) | |
| response_synthesizer = get_response_synthesizer( | |
| response_mode='refine', | |
| text_qa_template= CHAT_TEXT_QA_PROMPT, | |
| refine_template=CHAT_REFINE_PROMPT, | |
| ) | |
| st.session_state.query_engine = st.session_state.index.as_query_engine( | |
| response_synthesizer=response_synthesizer, | |
| service_context=service_context, | |
| ) | |
| st.session_state.chat_engine = CondenseQuestionChatEngine.from_defaults( | |
| query_engine=st.session_state.query_engine, | |
| condense_question_prompt=custom_prompt, | |
| verbose=True | |
| ) | |