|
|
""" |
|
|
๊ณต๊ณต๊ธฐ๊ด ์ฌ์
์ ์์ RAG ์ฑ๋ด |
|
|
|
|
|
๊ธฐ๋ฅ: |
|
|
- ๋ชจ๋ธ ์ ํ (API/๋ก์ปฌ GGUF) |
|
|
- Query Router (๊ฒ์ vs ์ง์ ๋ต๋ณ) |
|
|
- RAG ๊ธฐ๋ฐ ์ง์์๋ต (Hybrid Search + Re-ranker) |
|
|
- ์กฐ๊ฑด๋ถ ์ฐธ๊ณ ๋ฌธ์ ํ์ |
|
|
- ๋ํ ํ์คํ ๋ฆฌ ๊ด๋ฆฌ |
|
|
- ๊ฒ์ ๋ชจ๋ ์ ํ |
|
|
""" |
|
|
|
|
|
import streamlit as st |
|
|
import sys |
|
|
from pathlib import Path |
|
|
from datetime import datetime |
|
|
import json |
|
|
|
|
|
|
|
|
# Make the project root importable so the `src.*` packages resolve when this
# page is launched directly by Streamlit (the file lives 3 levels below root).
root_dir = Path(__file__).parent.parent.parent
sys.path.insert(0, str(root_dir))
|
|
|
|
|
from src.utils.config import RAGConfig |
|
|
from src.utils.conversation_manager import ConversationManager |
|
|
|
|
|
|
|
|
|
|
|
# Streamlit page metadata — must be the first Streamlit command in the script.
st.set_page_config(
    page_title="๊ณต๊ณต๊ธฐ๊ด ์ฌ์์ ์์ ์ฑ๋ด",
    page_icon="๐ค",
    layout="wide",
    initial_sidebar_state="expanded"
)
|
|
|
|
|
|
|
|
|
|
|
# Page-level CSS: chat bubbles (.user-message / .assistant-message), source
# cards, and the info badges referenced by display_message() and the headers
# rendered in main().
st.markdown("""
<style>
.main-header {
    font-size: 2.5rem;
    font-weight: bold;
    color: #1f77b4;
    margin-bottom: 0.5rem;
}
.sub-header {
    font-size: 1.2rem;
    color: #666;
    margin-bottom: 2rem;
}
.chat-message {
    padding: 1.5rem;
    border-radius: 0.5rem;
    margin-bottom: 1rem;
    display: flex;
    flex-direction: column;
}
.user-message {
    background-color: #e3f2fd;
    border-left: 5px solid #2196f3;
}
.assistant-message {
    background-color: #f5f5f5;
    border-left: 5px solid #4caf50;
}
.message-header {
    font-weight: bold;
    margin-bottom: 0.5rem;
    display: flex;
    align-items: center;
    gap: 0.5rem;
}
.message-content {
    line-height: 1.6;
}
.source-document {
    background-color: #fff9c4;
    padding: 1rem;
    border-radius: 0.3rem;
    margin: 0.5rem 0;
    border-left: 3px solid #fbc02d;
}
.source-header {
    font-weight: bold;
    color: #f57f17;
    margin-bottom: 0.5rem;
}
.metadata {
    font-size: 0.85rem;
    color: #666;
    margin-top: 0.5rem;
}
.token-usage {
    background-color: #e8f5e9;
    padding: 0.5rem 1rem;
    border-radius: 0.3rem;
    font-size: 0.9rem;
    margin-top: 0.5rem;
}
.search-mode-info {
    background-color: #e3f2fd;
    padding: 0.5rem 1rem;
    border-radius: 0.3rem;
    font-size: 0.9rem;
    margin-top: 0.5rem;
}
.routing-info {
    background-color: #fff3e0;
    padding: 0.5rem 1rem;
    border-radius: 0.3rem;
    font-size: 0.9rem;
    margin-top: 0.5rem;
    border-left: 3px solid #ff9800;
}
.model-info {
    background-color: #f3e5f5;
    padding: 0.8rem 1rem;
    border-radius: 0.3rem;
    font-size: 0.9rem;
    margin: 0.5rem 0;
    border-left: 3px solid #9c27b0;
}
</style>
""", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
# --- Session-state bootstrap -------------------------------------------------
# Streamlit re-executes the whole script on every interaction, so each key is
# created only when absent, allowing the values to survive reruns.
if 'conv_manager' not in st.session_state:
    # Created lazily: constructing a ConversationManager on every rerun would
    # discard the accumulated chat history.
    st.session_state.conv_manager = ConversationManager()

for _key, _default in (
    ('rag_pipeline', None),        # lazily-initialized RAG pipeline instance
    ('model_type', None),          # model choice the pipeline was built for
    ('show_routing_info', False),  # debug toggle: show router decisions
):
    if _key not in st.session_state:
        st.session_state[_key] = _default
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource
def initialize_rag(model_type):
    """Build (and cache) the RAG pipeline for the selected backend.

    Args:
        model_type: "API ๋ชจ๋ธ (GPT)" for the OpenAI-backed pipeline, or
            "๋ก์ปฌ ๋ชจ๋ธ (GGUF)" for the local llama.cpp pipeline.

    Returns:
        Tuple of (pipeline, error_message, model_display_name).  On failure
        the pipeline and name are None and error_message carries the
        exception text followed by the full traceback.
    """
    try:
        config = RAGConfig()

        if model_type == "API ๋ชจ๋ธ (GPT)":
            # Imported lazily so the GGUF path never pulls in the API deps.
            from src.generator.generator import RAGPipeline
            return RAGPipeline(config=config), None, "OpenAI GPT"

        if model_type == "๋ก์ปฌ ๋ชจ๋ธ (GGUF)":
            from src.generator.generator_gguf import GGUFRAGPipeline

            # Settings tuned for a T4-class GPU: 35 offloaded layers, 8K ctx.
            pipeline = GGUFRAGPipeline(
                config=config,
                n_gpu_layers=35,
                n_ctx=8192,
                n_threads=4,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
            )
            return pipeline, None, "Llama-3-Ko-8B (GGUF)"

        # Unknown selection: report it without raising.
        return None, f"์ ์ ์๋ ๋ชจ๋ธ ํ์: {model_type}", None

    except Exception as exc:
        # Surface the full traceback to the UI instead of crashing the app.
        import traceback
        return None, f"{exc}\n\n{traceback.format_exc()}", None
|
|
|
|
|
|
|
|
|
|
|
def generate_answer(query: str, top_k: int = 10, search_mode: str = "hybrid_rerank", alpha: float = 0.5):
    """Run the session's RAG pipeline on *query* and return its result dict.

    Any exception (including an uninitialized pipeline) is converted into a
    well-formed result payload so the chat UI can render it as a normal
    assistant turn instead of crashing the page.
    """
    try:
        # Attribute access stays inside the try-block on purpose: a missing
        # or None pipeline must also fall through to the error payload.
        return st.session_state.rag_pipeline.generate_answer(
            query=query,
            top_k=top_k,
            search_mode=search_mode,
            alpha=alpha,
        )
    except Exception as exc:
        import traceback
        return {
            'answer': f"โ ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค: {exc}\n\n{traceback.format_exc()}",
            'sources': [],
            'used_retrieval': False,
            'search_mode': search_mode,
            'routing_info': None,
            'usage': {'total_tokens': 0, 'prompt_tokens': 0, 'completion_tokens': 0},
        }
|
|
|
|
|
|
|
|
|
|
|
def display_message(
    role: str,
    content: str,
    sources: list = None,
    usage: dict = None,
    search_mode: str = None,
    used_retrieval: bool = None,
    routing_info: dict = None
) -> None:
    """
    Render one chat turn as styled HTML (classes defined in the page CSS).

    Args:
        role: 'user' or 'assistant'
        content: message text; interpolated into HTML without escaping
        sources: retrieved reference documents (assistant messages only)
        usage: token-usage counters (assistant messages only)
        search_mode: retrieval mode that produced the answer (assistant only)
        used_retrieval: whether document retrieval ran (assistant only)
        routing_info: query-router decision details (assistant only)
    """
    if role == 'user':
        # User bubble (blue accent).  NOTE(review): `content` is injected
        # into HTML unescaped via unsafe_allow_html — confirm inputs are
        # trusted or escape them.
        st.markdown(f"""
        <div class="chat-message user-message">
            <div class="message-header">
                ๐ค ์ฌ์ฉ์
            </div>
            <div class="message-content">
                {content}
            </div>
        </div>
        """, unsafe_allow_html=True)

    else:
        # Assistant bubble (green accent).
        st.markdown(f"""
        <div class="chat-message assistant-message">
            <div class="message-header">
                ๐ค ์ฑ๋ด
            </div>
            <div class="message-content">
                {content}
            </div>
        </div>
        """, unsafe_allow_html=True)

        # Router decision badge — only when the developer toggle is on and
        # the turn actually carries routing details.
        if st.session_state.show_routing_info and routing_info:
            route_icon = "๐" if routing_info.get('route') == 'rag' else "๐ฌ"
            st.markdown(f"""
            <div class="routing-info">
                {route_icon} ๋ผ์ฐํ: {routing_info.get('route', 'N/A').upper()}
                (์ ๋ขฐ๋: {routing_info.get('confidence', 0):.2f}) -
                {routing_info.get('reason', 'N/A')}
            </div>
            """, unsafe_allow_html=True)

        # Which retrieval mode produced this answer.
        if used_retrieval and search_mode:
            mode_display = {
                'hybrid_rerank': '๐ Hybrid + Re-ranker',
                'hybrid': '๐ Hybrid Search',
                'embedding_rerank': '๐ ์๋ฒ ๋ฉ + Re-ranker',
                'embedding': '๐ ์๋ฒ ๋ฉ ๊ฒ์',
                'direct': '๐ฌ Direct (๊ฒ์ ์์)'
            }
            st.markdown(f"""
            <div class="search-mode-info">
                ๊ฒ์ ๋ชจ๋: {mode_display.get(search_mode, search_mode)}
            </div>
            """, unsafe_allow_html=True)

        # Reference-document cards, shown only for retrieval-backed answers.
        if used_retrieval and sources and len(sources) > 0:
            st.markdown("### ๐ ์ฐธ๊ณ ๋ฌธ์")

            for i, source in enumerate(sources, 1):
                metadata = source.get('metadata', {})

                score = source.get('score', 0)
                score_type = source.get('score_type', '')

                # NOTE(review): "..." is appended even when the content is
                # shorter than 200 chars — confirm this is acceptable.
                content_preview = source.get('content', '')[:200] + "..."

                st.markdown(f"""
                <div class="source-document">
                    <div class="source-header">
                        ๐ ๋ฌธ์ {i} (์ ์: {score:.3f} / {score_type})
                    </div>
                    <div>
                        {content_preview}
                    </div>
                    <div class="metadata">
                        ๐ ํ์ผ: {metadata.get('ํ์ผ๋ช', 'N/A')}<br>
                        ๐ข ๋ฐ์ฃผ๊ธฐ๊ด: {metadata.get('๋ฐ์ฃผ ๊ธฐ๊ด', 'N/A')}<br>
                        ๐ ์ฌ์๋ช: {metadata.get('์ฌ์๋ช', 'N/A')}
                    </div>
                </div>
                """, unsafe_allow_html=True)

        elif not used_retrieval:
            # Direct (non-RAG) answers get an informational banner instead.
            st.info("๐ฌ ์ด ๋ต๋ณ์ ๋ฌธ์ ๊ฒ์ ์์ด ์์ฑ๋์์ต๋๋ค.")

        # Token-usage footer.
        if usage:
            st.markdown(f"""
            <div class="token-usage">
                ๐ข ํ ํฐ ์ฌ์ฉ๋: {usage.get('total_tokens', 0)}
                (ํ๋กฌํํธ: {usage.get('prompt_tokens', 0)},
                ์์ฑ: {usage.get('completion_tokens', 0)})
            </div>
            """, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
def main() -> None:
    """Top-level page flow: header, sidebar settings, pipeline bootstrap,
    chat-history replay, and the question form."""
    # --- Page header --------------------------------------------------------
    st.markdown('<div class="main-header">๐ค ๊ณต๊ณต๊ธฐ๊ด ์ฌ์์ ์์ ์ฑ๋ด</div>', unsafe_allow_html=True)
    st.markdown('<div class="sub-header">Query Router + RAG ๊ธฐ๋ฐ ์ง์์๋ต ์์คํ</div>', unsafe_allow_html=True)

    # --- Sidebar: model / retrieval / developer settings --------------------
    with st.sidebar:
        st.header("โ๏ธ ์ค์ ")

        # Generation backend: OpenAI API vs local GGUF model.
        st.markdown("### ๐ค ๋ชจ๋ธ ์ค์ ")

        model_type = st.selectbox(
            "์์ฑ ๋ชจ๋ธ ์ ํ",
            options=[
                "API ๋ชจ๋ธ (GPT)",
                "๋ก์ปฌ ๋ชจ๋ธ (GGUF)"
            ],
            index=0,
            help="OpenAI API ๋๋ ๋ก์ปฌ GGUF ๋ชจ๋ธ ์ ํ"
        )

        # Info card describing the trade-offs of the chosen backend.
        if model_type == "API ๋ชจ๋ธ (GPT)":
            st.markdown("""
            <div class="model-info">
            ๐ <b>OpenAI GPT ๋ชจ๋ธ</b><br>
            โข ๋น ๋ฅด๊ณ ์์ ์ <br>
            โข API ํค ํ์<br>
            โข ๋น์ฉ ๋ฐ์ (ํ ํฐ๋น)
            </div>
            """, unsafe_allow_html=True)
        else:
            st.markdown("""
            <div class="model-info">
            ๐ฅ๏ธ <b>Llama-3-Ko-8B (GGUF)</b><br>
            โข T4 GPU ๊ฐ์<br>
            โข ๋ก์ปฌ ์คํ (๋ฌด๋ฃ)<br>
            โข ์ด๊ธฐ ๋ก๋ฉ ์๊ฐ ์์<br>
            โข 35๊ฐ ๋ ์ด์ด GPU ์ฌ์ฉ
            </div>
            """, unsafe_allow_html=True)

        st.markdown("---")

        # Retrieval settings.
        st.markdown("### ๐ ๊ฒ์ ์ค์ ")

        search_mode = st.selectbox(
            "๊ฒ์ ๋ชจ๋",
            options=["hybrid", "embedding"],
            index=0,
            format_func=lambda x: {
                "hybrid": "๐ Hybrid Search (BM25 + ์๋ฒ ๋ฉ)",
                "embedding": "๐ ์๋ฒ ๋ฉ ๊ฒ์"
            }[x],
            help="Hybrid: ํค์๋ + ์๋ฏธ ๊ฒ์ ๋ณํ (๊ถ์ฅ)"
        )

        use_reranker = st.toggle(
            "๐ Re-ranker ์ฌ์ฉ",
            value=True,
            help="๊ฒ์ ๊ฒฐ๊ณผ๋ฅผ CrossEncoder๋ก ์ฌ์ ๋ ฌํ์ฌ ์ ํ๋ ํฅ์ (๊ถ์ฅ)"
        )

        # Map the two UI choices onto the pipeline's four search-mode names.
        if use_reranker:
            if search_mode == "hybrid":
                actual_search_mode = "hybrid_rerank"
            else:
                actual_search_mode = "embedding_rerank"
        else:
            actual_search_mode = search_mode

        top_k = st.slider(
            "๊ฒ์ํ ๋ฌธ์ ๊ฐ์ (Top-K)",
            min_value=1,
            max_value=20,
            value=10,
            help="๊ฒ์ํ ๋ฌธ์ ๊ฐ์"
        )

        # BM25-vs-embedding mix; only meaningful for hybrid search, so the
        # slider is disabled in pure embedding mode.
        alpha = st.slider(
            "์๋ฒ ๋ฉ ๊ฐ์ค์น (alpha)",
            min_value=0.0,
            max_value=1.0,
            value=0.5,
            step=0.1,
            help="0: BM25๋ง, 1: ์๋ฒ ๋ฉ๋ง, 0.5: ๋์ผ ๊ฐ์ค์น (Hybrid ๋ชจ๋์์๋ง ์ฌ์ฉ)",
            disabled=(search_mode == "embedding")
        )

        st.markdown("---")

        # Developer options.
        st.markdown("### ๐ ๏ธ ๊ฐ๋ฐ์ ์ต์")

        show_routing = st.toggle(
            "๐ ๋ผ์ฐํ์ ๋ณด ํ์",
            value=False,
            help="Router์ ํ๋จ ๊ณผ์ ์ ํ์ (๋๋ฒ๊น์ฉ)"
        )
        # Persisted in session state so display_message() (called outside the
        # sidebar) can read it.
        st.session_state.show_routing_info = show_routing

        st.markdown("---")

        # Conversation management.
        st.markdown("### ๐ฌ ๋ํ ๊ด๋ฆฌ")

        if st.button("๐๏ธ ๋ํ ์ด๊ธฐํ", use_container_width=True):
            st.session_state.conv_manager.clear()
            st.rerun()

        if st.button("๐พ ๋ํ ๋ค์ด๋ก๋", use_container_width=True):
            # Only offer a download when there is at least one turn to export.
            if len(st.session_state.conv_manager) > 0:
                json_str = st.session_state.conv_manager.export_to_json()

                st.download_button(
                    label="๐ฅ JSON ๋ค์ด๋ก๋",
                    data=json_str,
                    file_name=f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
                    mime="application/json",
                    use_container_width=True
                )

        st.markdown("---")

        # Conversation statistics.
        st.markdown("### ๐ ํต๊ณ")
        stats = st.session_state.conv_manager.get_statistics()

        st.metric("์ด ๋ํ ์", stats.get('total', 0))

        st.markdown("---")
        # Plain-text echo of the effective settings for quick inspection.
        st.markdown("### ๐ ํ์ฌ ์ค์ ")
        st.text(f"๋ชจ๋ธ: {model_type}")
        st.text(f"๊ฒ์ ๋ชจ๋: {search_mode}")
        st.text(f"Re-ranker: {'โ ON' if use_reranker else 'โ OFF'}")
        st.text(f"์ค์ ๋ชจ๋: {actual_search_mode}")
        st.text(f"Top-K: {top_k}")
        if search_mode == "hybrid":
            st.text(f"Alpha: {alpha}")
        st.text(f"Router Info: {'โ ON' if show_routing else 'โ OFF'}")

    # --- Pipeline bootstrap -------------------------------------------------
    # (Re)initialize when no pipeline exists yet or the user switched models.
    if (st.session_state.rag_pipeline is None or
        st.session_state.model_type != model_type):

        with st.spinner(f"๐ {model_type} ์ด๊ธฐํ ์ค... (GGUF ๋ชจ๋ธ์ 1~2๋ถ ์์๋ ์ ์์ต๋๋ค)"):
            rag, error, rag_type = initialize_rag(model_type)

        if error:
            # NOTE(review): st.cache_resource memoizes initialize_rag's return
            # value, so a transient failure may keep being served from cache on
            # later reruns until the cache is cleared — confirm intended.
            st.error(f"โ RAG ํ์ดํ๋ผ์ธ ์ด๊ธฐํ ์คํจ")

            with st.expander("๐ ์๋ฌ ์์ธ ์ ๋ณด"):
                st.code(error)

            st.info("""
            ### ๐ก ํด๊ฒฐ ๋ฐฉ๋ฒ

            **GGUF ๋ชจ๋ธ ์คํจ ์:**
            1. llama-cpp-python ์ค์น ํ์ธ:
            ```bash
            pip install llama-cpp-python
            ```

            2. GGUF ๋ชจ๋ธ ํ์ผ ํ์ธ:
            - config.yaml์ GGUF_MODEL_PATH ๋๋
            - MODEL_HUB_REPO ์ค์ ํ์ธ

            3. GPU ๋ฉ๋ชจ๋ฆฌ ๋ถ์กฑ ์:
            - n_gpu_layers ๊ฐ ๊ฐ์ (35 โ 20)

            **API ๋ชจ๋ธ ์คํจ ์:**
            1. ChromaDB๊ฐ ์์ฑ๋์๋์ง ํ์ธ:
            ```bash
            python main.py --step embed
            ```

            2. OpenAI API ํค ํ์ธ:
            ```bash
            # .env ํ์ผ
            OPENAI_API_KEY=your-key-here
            ```

            3. ํ์ํ ํจํค์ง ์ค์น:
            ```bash
            pip install rank-bm25 sentence-transformers
            ```
            """)
            # Abort rendering: nothing below works without a pipeline.
            return

        st.session_state.rag_pipeline = rag
        st.session_state.model_type = model_type
        st.success(f"โ {rag_type} ๋ชจ๋ธ ์ค๋น ์๋ฃ!")

    st.markdown("---")

    # First-visit greeting with example prompts (only when history is empty).
    if len(st.session_state.conv_manager) == 0:
        st.info("""
        ### ๐ ํ์ํฉ๋๋ค!

        ๊ณต๊ณต๊ธฐ๊ด ์ฌ์์ ์์์ ๋ํด ์ง๋ฌธํด๋ณด์ธ์.

        **์์ ์ง๋ฌธ:**
        - "์๋ํ์ธ์" (๊ฒ์ ์ ํจ)
        - "๋ฐ์ดํฐ ํ์คํ ์๊ตฌ์ฌํญ์ ๋ฌด์์ธ๊ฐ์?" (๊ฒ์ ์ํ)
        - "๋ณด์ ๊ด๋ จ ์๊ตฌ์ฌํญ์ ์ค๋ชํด์ฃผ์ธ์" (๊ฒ์ ์ํ)
        - "๊ณ ๋ง์์" (๊ฒ์ ์ ํจ)
        """)

    # Replay the stored conversation, turn by turn.
    for msg in st.session_state.conv_manager.get_ui_history():
        display_message(
            role=msg['role'],
            content=msg['content'],
            sources=msg.get('sources'),
            usage=msg.get('usage'),
            search_mode=msg.get('search_mode'),
            used_retrieval=msg.get('used_retrieval'),
            routing_info=msg.get('routing_info')
        )

    st.markdown("---")

    # Question input; clear_on_submit empties the textarea after sending.
    with st.form(key='question_form', clear_on_submit=True):
        user_input = st.text_area(
            "์ง๋ฌธ์ ์๋ ฅํ์ธ์:",
            height=100,
            placeholder="์: ๋ฐ์ดํฐ ํ์คํ ์๊ตฌ์ฌํญ์ ๋ฌด์์ธ๊ฐ์?"
        )

        # col2/col3 only reserve layout space; col1 holds the submit button.
        col1, col2, col3 = st.columns([1, 1, 4])

        with col1:
            submit_button = st.form_submit_button("๐ค ์ ์ก", use_container_width=True)

    if submit_button and user_input:
        with st.spinner("๐ค ๋ต๋ณ ์์ฑ ์ค..."):
            result = generate_answer(
                query=user_input,
                top_k=top_k,
                search_mode=actual_search_mode,
                alpha=alpha
            )

        # Persist the full turn (including retrieval metadata) to history.
        st.session_state.conv_manager.add_message(
            user_msg=user_input,
            ai_msg=result['answer'],
            query_type=result.get('query_type', 'unknown'),
            sources=result.get('sources', []),
            usage=result.get('usage', {}),
            search_mode=result.get('search_mode'),
            used_retrieval=result.get('used_retrieval', False),
            routing_info=result.get('routing_info')
        )

        # Rerun so the new turn renders from history above the form.
        st.rerun()
|
|
|
|
|
|
|
|
# Script entry point (Streamlit executes the module top to bottom).
if __name__ == "__main__":
    main()