|
|
""" |
|
|
EmpathemeBot - Hugging Face Spaces用統合版Streamlitアプリ(セキュア版v2) |
|
|
APIキーをセッション単位で管理し、ユーザー間で共有されないようにする |
|
|
「新しいチャット」ボタンでAPIキーを維持する改良版 |
|
|
""" |
|
|
|
|
|
import html |
|
|
import logging |
|
|
import re |
|
|
import time |
|
|
import uuid |
|
|
from datetime import datetime |
|
|
from typing import List, Dict, Optional |
|
|
from pathlib import Path |
|
|
import sys |
|
|
|
|
|
import streamlit as st |
|
|
from dotenv import load_dotenv |
|
|
import os |
|
|
|
|
|
|
|
|
load_dotenv() |
|
|
|
|
|
|
|
|
sys.path.append(str(Path(__file__).parent)) |
|
|
|
|
|
|
|
|
from src.qa.chain import QAChain |
|
|
|
|
|
|
|
|
logging.basicConfig( |
|
|
level=logging.INFO, |
|
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' |
|
|
) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
st.set_page_config( |
|
|
page_title="EmpathemeBot QA System", |
|
|
page_icon="", |
|
|
layout="wide", |
|
|
initial_sidebar_state="collapsed", |
|
|
menu_items={} |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
class EmpathemeBotUI: |
|
|
"""Hugging Face Spaces用セキュア版EmpathemeBot UIクラス""" |
|
|
|
|
|
def __init__(self): |
|
|
|
|
|
if 'session_id' not in st.session_state: |
|
|
st.session_state.session_id = str(uuid.uuid4()) |
|
|
if 'messages' not in st.session_state: |
|
|
st.session_state.messages = [] |
|
|
if 'qa_chain' not in st.session_state: |
|
|
st.session_state.qa_chain = None |
|
|
if 'last_activity' not in st.session_state: |
|
|
st.session_state.last_activity = datetime.now() |
|
|
if 'vector_store_initialized' not in st.session_state: |
|
|
st.session_state.vector_store_initialized = False |
|
|
|
|
|
if 'current_api_key' not in st.session_state: |
|
|
st.session_state.current_api_key = "" |
|
|
|
|
|
def initialize_qa_chain(self, api_key: str) -> bool: |
|
|
""" |
|
|
QAChainを初期化(APIキーを環境変数に設定してQAChainを使用) |
|
|
|
|
|
Args: |
|
|
api_key: OpenAI APIキー |
|
|
|
|
|
Returns: |
|
|
初期化成功の場合True |
|
|
""" |
|
|
try: |
|
|
|
|
|
os.environ["OPENAI_API_KEY"] = api_key |
|
|
|
|
|
|
|
|
vector_store_path = Path("data/vector_store") |
|
|
|
|
|
|
|
|
logger.info("QAChainを初期化中...") |
|
|
|
|
|
|
|
|
st.session_state.qa_chain = QAChain( |
|
|
persist_dir=str(vector_store_path), |
|
|
verbose=False, |
|
|
max_history_turns=10, |
|
|
max_history_chars=10000 |
|
|
|
|
|
) |
|
|
|
|
|
st.session_state.vector_store_initialized = True |
|
|
st.session_state.current_api_key = api_key |
|
|
logger.info("QAChain初期化完了") |
|
|
return True |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"QAChain初期化エラー: {e}") |
|
|
st.error(f"初期化エラー: {str(e)}") |
|
|
return False |
|
|
|
|
|
def ask_question(self, question: str) -> Optional[Dict]: |
|
|
""" |
|
|
質問を処理して回答を取得 |
|
|
|
|
|
Args: |
|
|
question: ユーザーの質問 |
|
|
|
|
|
Returns: |
|
|
回答データ |
|
|
""" |
|
|
try: |
|
|
if st.session_state.qa_chain is None: |
|
|
st.error("システムが初期化されていません。APIキーを入力してください。") |
|
|
return None |
|
|
|
|
|
|
|
|
logger.info(f"質問処理開始: {question[:100]}...") |
|
|
answer, source_docs = st.session_state.qa_chain.ask_with_history(question) |
|
|
|
|
|
|
|
|
source_urls = [] |
|
|
for doc in source_docs: |
|
|
if hasattr(doc, 'metadata'): |
|
|
url = doc.metadata.get('source_url', '') |
|
|
if url and url not in source_urls: |
|
|
source_urls.append(url) |
|
|
|
|
|
result = { |
|
|
"answer": answer, |
|
|
"source_count": len(source_docs), |
|
|
"source_urls": source_urls |
|
|
} |
|
|
|
|
|
logger.info(f"回答生成成功: {len(source_docs)}件のソース参照") |
|
|
return result |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"エラー発生: {e}") |
|
|
st.error(f"予期しないエラーが発生しました: {str(e)}") |
|
|
return None |
|
|
|
|
|
def clear_history(self): |
|
|
"""会話履歴をクリア""" |
|
|
try: |
|
|
if st.session_state.qa_chain: |
|
|
st.session_state.qa_chain.clear_history() |
|
|
st.session_state.messages = [] |
|
|
st.success("会話履歴をクリアしました") |
|
|
logger.info("履歴クリア成功") |
|
|
except Exception as e: |
|
|
logger.error(f"履歴クリアエラー: {e}") |
|
|
st.error("エラーが発生しました") |
|
|
|
|
|
def create_new_session(self): |
|
|
"""新しいセッションIDを生成(APIキーは維持)""" |
|
|
st.session_state.session_id = str(uuid.uuid4()) |
|
|
st.session_state.messages = [] |
|
|
|
|
|
|
|
|
if st.session_state.current_api_key: |
|
|
|
|
|
if st.session_state.qa_chain: |
|
|
st.session_state.qa_chain.clear_history() |
|
|
|
|
|
logger.info(f"新しいチャット開始(APIキー維持): {st.session_state.session_id}") |
|
|
else: |
|
|
|
|
|
st.session_state.qa_chain = None |
|
|
st.session_state.vector_store_initialized = False |
|
|
logger.info(f"新しいセッション作成: {st.session_state.session_id}") |
|
|
|
|
|
def main(): |
|
|
"""メイン関数""" |
|
|
|
|
|
|
|
|
st.markdown(""" |
|
|
<style> |
|
|
/* メインコンテナのスタイル */ |
|
|
.main { |
|
|
padding-top: 1rem; |
|
|
max-width: 1000px; |
|
|
margin: 0 auto; |
|
|
} |
|
|
|
|
|
.block-container { |
|
|
padding: 1rem 2rem; |
|
|
max-width: 100%; |
|
|
} |
|
|
|
|
|
/* チャット入力のスタイル */ |
|
|
.stChatInput { |
|
|
border: none !important; |
|
|
box-shadow: none !important; |
|
|
background: transparent !important; |
|
|
position: fixed; |
|
|
bottom: 0; |
|
|
padding-bottom: 1rem; |
|
|
background: white !important; |
|
|
z-index: 999; |
|
|
} |
|
|
|
|
|
/* チャット入力のテキストエリア */ |
|
|
.stChatInput textarea { |
|
|
font-size: 14px; |
|
|
border: 1px solid #E5E7EB !important; |
|
|
border-radius: 8px !important; |
|
|
padding: 0.6rem 1rem !important; |
|
|
background: #FAFAFA !important; |
|
|
transition: all 0.2s ease; |
|
|
} |
|
|
|
|
|
.stChatInput textarea:focus { |
|
|
background: white !important; |
|
|
border-color: #4F46E5 !important; |
|
|
outline: none !important; |
|
|
box-shadow: 0 0 0 2px rgba(79, 70, 229, 0.1) !important; |
|
|
} |
|
|
|
|
|
/* ボタンのスタイル */ |
|
|
.stButton > button { |
|
|
background: #4F46E5; |
|
|
color: white; |
|
|
border: none; |
|
|
border-radius: 6px; |
|
|
padding: 0.5rem 1rem; |
|
|
font-weight: 500; |
|
|
font-size: 13px; |
|
|
transition: all 0.15s ease; |
|
|
} |
|
|
|
|
|
.stButton > button:hover { |
|
|
background: #4338CA; |
|
|
} |
|
|
|
|
|
/* タイトルのスタイル */ |
|
|
h1 { |
|
|
color: #111827; |
|
|
font-weight: 600; |
|
|
text-align: center; |
|
|
font-size: 1.75rem; |
|
|
margin-bottom: 0.5rem; |
|
|
} |
|
|
|
|
|
/* サイドバーのスタイル */ |
|
|
section[data-testid="stSidebar"] { |
|
|
background: #FAFAFB; |
|
|
} |
|
|
|
|
|
/* 吹き出し内のコンテンツスタイル */ |
|
|
.bubble-content { |
|
|
font-family: inherit; |
|
|
font-size: inherit; |
|
|
white-space: pre-wrap; |
|
|
word-wrap: break-word; |
|
|
margin: 0; |
|
|
padding: 0; |
|
|
color: inherit; |
|
|
} |
|
|
</style> |
|
|
""", unsafe_allow_html=True) |
|
|
|
|
|
|
|
|
bot = EmpathemeBotUI() |
|
|
|
|
|
|
|
|
with st.sidebar: |
|
|
st.markdown("## 設定") |
|
|
|
|
|
|
|
|
st.markdown("### OpenAI API キー") |
|
|
|
|
|
|
|
|
if st.session_state.current_api_key: |
|
|
st.success("APIキー設定済み") |
|
|
|
|
|
masked_key = st.session_state.current_api_key[:7] + "..." + st.session_state.current_api_key[-4:] |
|
|
st.caption(f"現在のキー: {masked_key}") |
|
|
else: |
|
|
st.info("APIキーを入力してください") |
|
|
|
|
|
|
|
|
with st.form("api_key_form"): |
|
|
api_key_input = st.text_input( |
|
|
"APIキーを入力", |
|
|
type="password", |
|
|
placeholder="sk-...", |
|
|
help="このセッション専用のAPIキーです。ブラウザを閉じると消去されます。", |
|
|
key="api_key_input_field" |
|
|
) |
|
|
|
|
|
submit_button = st.form_submit_button("APIキーを設定", use_container_width=True) |
|
|
|
|
|
if submit_button and api_key_input: |
|
|
if api_key_input.startswith("sk-"): |
|
|
|
|
|
if bot.initialize_qa_chain(api_key_input): |
|
|
st.success("APIキーが設定されました") |
|
|
st.rerun() |
|
|
else: |
|
|
st.error("初期化に失敗しました") |
|
|
else: |
|
|
st.error("有効なAPIキー(sk-で始まる)を入力してください") |
|
|
|
|
|
st.markdown("---") |
|
|
|
|
|
|
|
|
st.markdown("### コントロール") |
|
|
if st.button("新しいチャット", use_container_width=True): |
|
|
bot.create_new_session() |
|
|
st.rerun() |
|
|
|
|
|
if st.button("履歴クリア", use_container_width=True): |
|
|
bot.clear_history() |
|
|
st.rerun() |
|
|
|
|
|
st.markdown("---") |
|
|
|
|
|
|
|
|
st.markdown("### ステータス") |
|
|
if st.session_state.vector_store_initialized: |
|
|
st.success("システム準備完了") |
|
|
else: |
|
|
st.info("システム待機中") |
|
|
|
|
|
|
|
|
st.markdown("---") |
|
|
st.caption(f"セッションID: {st.session_state.session_id[:8]}...") |
|
|
if st.session_state.current_api_key: |
|
|
st.caption("APIキー設定済み(新しいチャットでも維持)") |
|
|
|
|
|
|
|
|
st.markdown( |
|
|
""" |
|
|
<div style='text-align: center; margin-bottom: 2rem;'> |
|
|
<h1 style='margin-bottom: 0.25rem;'>EmpathemeBot</h1> |
|
|
<p style='color: #6B7280; font-size: 0.9rem;'>Potionベースの質問応答システム</p> |
|
|
</div> |
|
|
""", |
|
|
unsafe_allow_html=True |
|
|
) |
|
|
|
|
|
|
|
|
if not st.session_state.current_api_key: |
|
|
st.markdown( |
|
|
""" |
|
|
<div style="background: #FEF3C7; border: 2px solid #F59E0B; border-radius: 12px; padding: 1.5rem; margin: 2rem 0;"> |
|
|
<h3 style="color: #92400E; margin-top: 0;">APIキーの入力が必要です</h3> |
|
|
<p style="color: #78350F; margin-bottom: 1rem;"> |
|
|
EmpathemeBotを使用するには、OpenAI APIキーが必要です。 |
|
|
</p> |
|
|
<ol style="color: #78350F; margin-left: 1.5rem;"> |
|
|
<li>左上の「>」ボタンをクリックしてサイドバーを開く</li> |
|
|
<li>「OpenAI API キー」セクションにAPIキー(sk-...)を入力</li> |
|
|
<li>「APIキーを設定」ボタンをクリック</li> |
|
|
</ol> |
|
|
<p style="color: #78350F; font-size: 0.9rem; margin-top: 1rem;"> |
|
|
APIキーは <a href="https://platform.openai.com/api-keys" target="__blank" style="color: #F59E0B;">OpenAIのダッシュボード</a> から取得できます。 |
|
|
</p> |
|
|
<p style="color: #78350F; font-size: 0.85rem; margin-top: 0.5rem; font-style: italic;"> |
|
|
※ APIキーは各ブラウザセッション専用です。他のユーザーと共有されません。 |
|
|
</p> |
|
|
</div> |
|
|
""", |
|
|
unsafe_allow_html=True |
|
|
) |
|
|
st.stop() |
|
|
|
|
|
|
|
|
if len(st.session_state.messages) == 0: |
|
|
st.markdown( |
|
|
""" |
|
|
<div style="text-align: center; padding: 3rem 0; color: #6B7280;"> |
|
|
<p style="font-size: 0.95rem;">こんにちは、KurageSan®だよ!何か英語学習に関して困っていることはありますか?</p> |
|
|
</div> |
|
|
""", |
|
|
unsafe_allow_html=True |
|
|
) |
|
|
|
|
|
|
|
|
for message in st.session_state.messages: |
|
|
if message["role"] == "user": |
|
|
|
|
|
st.markdown( |
|
|
f""" |
|
|
<div style="display: flex; justify-content: flex-end; margin: 1rem 0; padding-right: 1rem;"> |
|
|
<div style="background: linear-gradient(135deg, #4F46E5 0%, #6366F1 100%); |
|
|
color: white; |
|
|
padding: 0.75rem 1.25rem; |
|
|
border-radius: 18px 18px 4px 18px; |
|
|
max-width: 60%; |
|
|
box-shadow: 0 2px 10px rgba(79, 70, 229, 0.2); |
|
|
word-wrap: break-word;"> |
|
|
<pre class="bubble-content">{html.escape(message['content'])}</pre> |
|
|
<div style="font-size: 0.7rem; opacity: 0.8; margin-top: 0.3rem; text-align: right;"> |
|
|
{message.get('timestamp', '')} |
|
|
</div> |
|
|
</div> |
|
|
</div> |
|
|
""", |
|
|
unsafe_allow_html=True |
|
|
) |
|
|
else: |
|
|
|
|
|
st.markdown( |
|
|
f""" |
|
|
<div style="display: flex; justify-content: flex-start; margin: 1rem 0; padding-left: 1rem;"> |
|
|
<div style="background: #F3F4F6; |
|
|
color: #111827; |
|
|
padding: 0.75rem 1.25rem; |
|
|
border-radius: 18px 18px 18px 4px; |
|
|
max-width: 60%; |
|
|
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.08); |
|
|
word-wrap: break-word;"> |
|
|
<pre class="bubble-content">{html.escape(message['content'])}</pre> |
|
|
<div style="font-size: 0.7rem; opacity: 0.6; margin-top: 0.3rem;"> |
|
|
{message.get('timestamp', '')} |
|
|
</div> |
|
|
</div> |
|
|
</div> |
|
|
""", |
|
|
unsafe_allow_html=True |
|
|
) |
|
|
|
|
|
|
|
|
if prompt := st.chat_input("質問を入力してください...", key="chat_input"): |
|
|
|
|
|
timestamp = datetime.now().strftime("%H:%M") |
|
|
|
|
|
|
|
|
st.session_state.messages.append({ |
|
|
"role": "user", |
|
|
"content": prompt, |
|
|
"timestamp": timestamp |
|
|
}) |
|
|
|
|
|
|
|
|
st.markdown( |
|
|
f""" |
|
|
<div style="display: flex; justify-content: flex-end; margin: 1rem 0; padding-right: 1rem;"> |
|
|
<div style="background: linear-gradient(135deg, #4F46E5 0%, #6366F1 100%); |
|
|
color: white; |
|
|
padding: 0.75rem 1.25rem; |
|
|
border-radius: 18px 18px 4px 18px; |
|
|
max-width: 60%; |
|
|
box-shadow: 0 2px 10px rgba(79, 70, 229, 0.2); |
|
|
word-wrap: break-word;"> |
|
|
<pre class="bubble-content">{html.escape(prompt)}</pre> |
|
|
<div style="font-size: 0.7rem; opacity: 0.8; margin-top: 0.3rem; text-align: right;"> |
|
|
{timestamp} |
|
|
</div> |
|
|
</div> |
|
|
</div> |
|
|
""", |
|
|
unsafe_allow_html=True |
|
|
) |
|
|
|
|
|
|
|
|
with st.spinner("考えています..."): |
|
|
response_timestamp = datetime.now().strftime("%H:%M") |
|
|
response_data = bot.ask_question(prompt) |
|
|
|
|
|
if response_data: |
|
|
answer = response_data["answer"] |
|
|
|
|
|
|
|
|
st.session_state.messages.append({ |
|
|
"role": "assistant", |
|
|
"content": answer, |
|
|
"timestamp": response_timestamp, |
|
|
"metadata": { |
|
|
"source_count": response_data.get("source_count", 0) |
|
|
} |
|
|
}) |
|
|
|
|
|
|
|
|
st.markdown( |
|
|
f""" |
|
|
<div style="display: flex; justify-content: flex-start; margin: 1rem 0; padding-left: 1rem;"> |
|
|
<div style="background: #F3F4F6; |
|
|
color: #111827; |
|
|
padding: 0.75rem 1.25rem; |
|
|
border-radius: 18px 18px 18px 4px; |
|
|
max-width: 60%; |
|
|
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.08); |
|
|
word-wrap: break-word;"> |
|
|
<pre class="bubble-content">{html.escape(answer)}</pre> |
|
|
<div style="font-size: 0.7rem; opacity: 0.6; margin-top: 0.3rem;"> |
|
|
{response_timestamp} |
|
|
</div> |
|
|
</div> |
|
|
</div> |
|
|
""", |
|
|
unsafe_allow_html=True |
|
|
) |
|
|
else: |
|
|
|
|
|
error_message = "申し訳ございません。回答の生成に失敗しました。もう一度お試しください。" |
|
|
|
|
|
st.session_state.messages.append({ |
|
|
"role": "assistant", |
|
|
"content": error_message, |
|
|
"timestamp": response_timestamp |
|
|
}) |
|
|
|
|
|
|
|
|
st.markdown( |
|
|
f""" |
|
|
<div style="display: flex; justify-content: flex-start; margin: 1rem 0; padding-left: 1rem;"> |
|
|
<div style="background: #F3F4F6; |
|
|
color: #111827; |
|
|
padding: 0.75rem 1.25rem; |
|
|
border-radius: 18px 18px 18px 4px; |
|
|
max-width: 60%; |
|
|
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.08); |
|
|
word-wrap: break-word;"> |
|
|
<pre class="bubble-content">{html.escape(error_message)}</pre> |
|
|
<div style="font-size: 0.7rem; opacity: 0.6; margin-top: 0.3rem;"> |
|
|
{response_timestamp} |
|
|
</div> |
|
|
</div> |
|
|
</div> |
|
|
""", |
|
|
unsafe_allow_html=True |
|
|
) |
|
|
|
|
|
|
|
|
st.session_state.last_activity = datetime.now() |
|
|
|
|
|
|
|
|
st.markdown( |
|
|
f""" |
|
|
<div style="text-align: center; margin-top: 3rem; padding: 1rem 0; |
|
|
border-top: 1px solid #E5E7EB; color: #9CA3AF; font-size: 0.8rem;"> |
|
|
EmpathemeBot · セッション: {st.session_state.session_id[:8]} |
|
|
</div> |
|
|
""", |
|
|
unsafe_allow_html=True |
|
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|