# NOTE(review): the lines below are Hugging Face blob-page artifacts captured
# along with the source (uploader caption, commit message, commit hash, view
# links, file size) — they are not code and are preserved here as comments.
# Dongjin1203's picture | 컨텍스트 길이 증가 | 15c1ef1 | raw | history blame | 19.3 kB
"""
๊ณต๊ณต๊ธฐ๊ด€ ์‚ฌ์—…์ œ์•ˆ์„œ RAG ์ฑ—๋ด‡
๊ธฐ๋Šฅ:
- ๋ชจ๋ธ ์„ ํƒ (API/๋กœ์ปฌ GGUF)
- Query Router (๊ฒ€์ƒ‰ vs ์ง์ ‘ ๋‹ต๋ณ€)
- RAG ๊ธฐ๋ฐ˜ ์งˆ์˜์‘๋‹ต (Hybrid Search + Re-ranker)
- ์กฐ๊ฑด๋ถ€ ์ฐธ๊ณ  ๋ฌธ์„œ ํ‘œ์‹œ
- ๋Œ€ํ™” ํžˆ์Šคํ† ๋ฆฌ ๊ด€๋ฆฌ
- ๊ฒ€์ƒ‰ ๋ชจ๋“œ ์„ ํƒ
"""
import streamlit as st
import sys
from pathlib import Path
from datetime import datetime
import json
# ํ”„๋กœ์ ํŠธ ๋ฃจํŠธ ์ถ”๊ฐ€
root_dir = Path(__file__).parent.parent.parent
sys.path.insert(0, str(root_dir))
from src.utils.config import RAGConfig
from src.utils.conversation_manager import ConversationManager
# ===== ํŽ˜์ด์ง€ ์„ค์ • =====
st.set_page_config(
page_title="๊ณต๊ณต๊ธฐ๊ด€ ์‚ฌ์—…์ œ์•ˆ์„œ ์ฑ—๋ด‡",
page_icon="๐Ÿค–",
layout="wide",
initial_sidebar_state="expanded"
)
# ===== ์Šคํƒ€์ผ =====
st.markdown("""
<style>
.main-header {
font-size: 2.5rem;
font-weight: bold;
color: #1f77b4;
margin-bottom: 0.5rem;
}
.sub-header {
font-size: 1.2rem;
color: #666;
margin-bottom: 2rem;
}
.chat-message {
padding: 1.5rem;
border-radius: 0.5rem;
margin-bottom: 1rem;
display: flex;
flex-direction: column;
}
.user-message {
background-color: #e3f2fd;
border-left: 5px solid #2196f3;
}
.assistant-message {
background-color: #f5f5f5;
border-left: 5px solid #4caf50;
}
.message-header {
font-weight: bold;
margin-bottom: 0.5rem;
display: flex;
align-items: center;
gap: 0.5rem;
}
.message-content {
line-height: 1.6;
}
.source-document {
background-color: #fff9c4;
padding: 1rem;
border-radius: 0.3rem;
margin: 0.5rem 0;
border-left: 3px solid #fbc02d;
}
.source-header {
font-weight: bold;
color: #f57f17;
margin-bottom: 0.5rem;
}
.metadata {
font-size: 0.85rem;
color: #666;
margin-top: 0.5rem;
}
.token-usage {
background-color: #e8f5e9;
padding: 0.5rem 1rem;
border-radius: 0.3rem;
font-size: 0.9rem;
margin-top: 0.5rem;
}
.search-mode-info {
background-color: #e3f2fd;
padding: 0.5rem 1rem;
border-radius: 0.3rem;
font-size: 0.9rem;
margin-top: 0.5rem;
}
.routing-info {
background-color: #fff3e0;
padding: 0.5rem 1rem;
border-radius: 0.3rem;
font-size: 0.9rem;
margin-top: 0.5rem;
border-left: 3px solid #ff9800;
}
.model-info {
background-color: #f3e5f5;
padding: 0.8rem 1rem;
border-radius: 0.3rem;
font-size: 0.9rem;
margin: 0.5rem 0;
border-left: 3px solid #9c27b0;
}
</style>
""", unsafe_allow_html=True)
# ===== ์„ธ์…˜ ์ƒํƒœ ์ดˆ๊ธฐํ™” =====
if 'conv_manager' not in st.session_state:
st.session_state.conv_manager = ConversationManager()
if 'rag_pipeline' not in st.session_state:
st.session_state.rag_pipeline = None
if 'model_type' not in st.session_state:
st.session_state.model_type = None
if 'show_routing_info' not in st.session_state:
st.session_state.show_routing_info = False
# ===== RAG pipeline initialization =====
@st.cache_resource
def initialize_rag(model_type):
    """Build (and cache) the RAG pipeline for the selected backend.

    Args:
        model_type: "API 모델 (GPT)" or "로컬 모델 (GGUF)".

    Returns:
        Tuple of (pipeline, error_message, model_name); on success the
        error_message is None, on failure the pipeline is None and the
        error_message carries the exception text plus traceback.
    """
    try:
        config = RAGConfig()
        if model_type == "API 모델 (GPT)":
            # OpenAI-backed pipeline.
            from src.generator.generator import RAGPipeline
            pipeline = RAGPipeline(config=config)
            return pipeline, None, "OpenAI GPT"
        if model_type == "로컬 모델 (GGUF)":
            # Local GGUF pipeline, tuned for a T4 GPU.
            from src.generator.generator_gguf import GGUFRAGPipeline
            pipeline = GGUFRAGPipeline(
                config=config,
                n_gpu_layers=35,     # offload all layers to GPU on a T4
                n_ctx=8192,          # context window length
                n_threads=4,         # keep CPU threads low when GPU is used
                max_new_tokens=512,  # generation cap
                temperature=0.7,
                top_p=0.9
            )
            return pipeline, None, "Llama-3-Ko-8B (GGUF)"
        # Neither known backend matched.
        return None, f"알 수 없는 모델 타입: {model_type}", None
    except Exception as e:
        import traceback
        return None, f"{str(e)}\n\n{traceback.format_exc()}", None
# ===== ๋‹ต๋ณ€ ์ƒ์„ฑ =====
def generate_answer(query: str, top_k: int = 10, search_mode: str = "hybrid_rerank", alpha: float = 0.5):
"""์งˆ์˜์— ๋Œ€ํ•œ ๋‹ต๋ณ€ ์ƒ์„ฑ"""
try:
result = st.session_state.rag_pipeline.generate_answer(
query=query,
top_k=top_k,
search_mode=search_mode,
alpha=alpha
)
return result
except Exception as e:
import traceback
error_detail = traceback.format_exc()
return {
'answer': f"โŒ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}\n\n{error_detail}",
'sources': [],
'used_retrieval': False,
'search_mode': search_mode,
'routing_info': None,
'usage': {'total_tokens': 0, 'prompt_tokens': 0, 'completion_tokens': 0}
}
# ===== ๋ฉ”์‹œ์ง€ ํ‘œ์‹œ =====
def display_message(
role: str,
content: str,
sources: list = None,
usage: dict = None,
search_mode: str = None,
used_retrieval: bool = None,
routing_info: dict = None
):
"""
๋ฉ”์‹œ์ง€๋ฅผ ํ™”๋ฉด์— ํ‘œ์‹œ
Args:
role: 'user' ๋˜๋Š” 'assistant'
content: ๋ฉ”์‹œ์ง€ ๋‚ด์šฉ
sources: ์ฐธ๊ณ  ๋ฌธ์„œ ๋ฆฌ์ŠคํŠธ (assistant๋งŒ)
usage: ํ† ํฐ ์‚ฌ์šฉ๋Ÿ‰ (assistant๋งŒ)
search_mode: ๊ฒ€์ƒ‰ ๋ชจ๋“œ (assistant๋งŒ)
used_retrieval: ๊ฒ€์ƒ‰ ์‚ฌ์šฉ ์—ฌ๋ถ€ (assistant๋งŒ)
routing_info: ๋ผ์šฐํŒ… ์ •๋ณด (assistant๋งŒ)
"""
if role == 'user':
st.markdown(f"""
<div class="chat-message user-message">
<div class="message-header">
๐Ÿ‘ค ์‚ฌ์šฉ์ž
</div>
<div class="message-content">
{content}
</div>
</div>
""", unsafe_allow_html=True)
else: # assistant
# ๋‹ต๋ณ€
st.markdown(f"""
<div class="chat-message assistant-message">
<div class="message-header">
๐Ÿค– ์ฑ—๋ด‡
</div>
<div class="message-content">
{content}
</div>
</div>
""", unsafe_allow_html=True)
# ===== ๋ผ์šฐํŒ… ์ •๋ณด (๊ฐœ๋ฐœ ๋ชจ๋“œ) =====
if st.session_state.show_routing_info and routing_info:
route_icon = "๐Ÿ”" if routing_info.get('route') == 'rag' else "๐Ÿ’ฌ"
st.markdown(f"""
<div class="routing-info">
{route_icon} ๋ผ์šฐํŒ…: {routing_info.get('route', 'N/A').upper()}
(์‹ ๋ขฐ๋„: {routing_info.get('confidence', 0):.2f}) -
{routing_info.get('reason', 'N/A')}
</div>
""", unsafe_allow_html=True)
# ===== ๊ฒ€์ƒ‰ ๋ชจ๋“œ ์ •๋ณด (๊ฒ€์ƒ‰ ์‚ฌ์šฉ ์‹œ๋งŒ) =====
if used_retrieval and search_mode:
mode_display = {
'hybrid_rerank': '๐Ÿ”„ Hybrid + Re-ranker',
'hybrid': '๐Ÿ”€ Hybrid Search',
'embedding_rerank': '๐Ÿ“Š ์ž„๋ฒ ๋”ฉ + Re-ranker',
'embedding': '๐Ÿ“Š ์ž„๋ฒ ๋”ฉ ๊ฒ€์ƒ‰',
'direct': '๐Ÿ’ฌ Direct (๊ฒ€์ƒ‰ ์—†์Œ)'
}
st.markdown(f"""
<div class="search-mode-info">
๊ฒ€์ƒ‰ ๋ชจ๋“œ: {mode_display.get(search_mode, search_mode)}
</div>
""", unsafe_allow_html=True)
# ===== ์ฐธ๊ณ  ๋ฌธ์„œ (๊ฒ€์ƒ‰ ์‚ฌ์šฉ ์‹œ๋งŒ) =====
if used_retrieval and sources and len(sources) > 0:
st.markdown("### ๐Ÿ“š ์ฐธ๊ณ  ๋ฌธ์„œ")
for i, source in enumerate(sources, 1):
metadata = source.get('metadata', {})
# ๊ด€๋ จ๋„ ์ ์ˆ˜
score = source.get('score', 0)
score_type = source.get('score_type', '')
# ๋ฌธ์„œ ๋‚ด์šฉ ๋ฏธ๋ฆฌ๋ณด๊ธฐ
content_preview = source.get('content', '')[:200] + "..."
st.markdown(f"""
<div class="source-document">
<div class="source-header">
๐Ÿ“„ ๋ฌธ์„œ {i} (์ ์ˆ˜: {score:.3f} / {score_type})
</div>
<div>
{content_preview}
</div>
<div class="metadata">
๐Ÿ“ ํŒŒ์ผ: {metadata.get('ํŒŒ์ผ๋ช…', 'N/A')}<br>
๐Ÿข ๋ฐœ์ฃผ๊ธฐ๊ด€: {metadata.get('๋ฐœ์ฃผ ๊ธฐ๊ด€', 'N/A')}<br>
๐Ÿ“‹ ์‚ฌ์—…๋ช…: {metadata.get('์‚ฌ์—…๋ช…', 'N/A')}
</div>
</div>
""", unsafe_allow_html=True)
elif not used_retrieval:
# ๊ฒ€์ƒ‰์„ ์‚ฌ์šฉํ•˜์ง€ ์•Š์€ ๊ฒฝ์šฐ ์•ˆ๋‚ด
st.info("๐Ÿ’ฌ ์ด ๋‹ต๋ณ€์€ ๋ฌธ์„œ ๊ฒ€์ƒ‰ ์—†์ด ์ƒ์„ฑ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")
# ===== ํ† ํฐ ์‚ฌ์šฉ๋Ÿ‰ =====
if usage:
st.markdown(f"""
<div class="token-usage">
๐Ÿ”ข ํ† ํฐ ์‚ฌ์šฉ๋Ÿ‰: {usage.get('total_tokens', 0)}
(ํ”„๋กฌํ”„ํŠธ: {usage.get('prompt_tokens', 0)},
์™„์„ฑ: {usage.get('completion_tokens', 0)})
</div>
""", unsafe_allow_html=True)
# ===== ๋ฉ”์ธ ์•ฑ =====
def main():
# ํ—ค๋”
st.markdown('<div class="main-header">๐Ÿค– ๊ณต๊ณต๊ธฐ๊ด€ ์‚ฌ์—…์ œ์•ˆ์„œ ์ฑ—๋ด‡</div>', unsafe_allow_html=True)
st.markdown('<div class="sub-header">Query Router + RAG ๊ธฐ๋ฐ˜ ์งˆ์˜์‘๋‹ต ์‹œ์Šคํ…œ</div>', unsafe_allow_html=True)
# ===== ์‚ฌ์ด๋“œ๋ฐ” =====
with st.sidebar:
st.header("โš™๏ธ ์„ค์ •")
# ๋ชจ๋ธ ์„ค์ •
st.markdown("### ๐Ÿค– ๋ชจ๋ธ ์„ค์ •")
model_type = st.selectbox(
"์ƒ์„ฑ ๋ชจ๋ธ ์„ ํƒ",
options=[
"API ๋ชจ๋ธ (GPT)",
"๋กœ์ปฌ ๋ชจ๋ธ (GGUF)"
],
index=0,
help="OpenAI API ๋˜๋Š” ๋กœ์ปฌ GGUF ๋ชจ๋ธ ์„ ํƒ"
)
# ๋ชจ๋ธ๋ณ„ ์ •๋ณด ํ‘œ์‹œ
if model_type == "API ๋ชจ๋ธ (GPT)":
st.markdown("""
<div class="model-info">
๐ŸŒ <b>OpenAI GPT ๋ชจ๋ธ</b><br>
โ€ข ๋น ๋ฅด๊ณ  ์•ˆ์ •์ <br>
โ€ข API ํ‚ค ํ•„์š”<br>
โ€ข ๋น„์šฉ ๋ฐœ์ƒ (ํ† ํฐ๋‹น)
</div>
""", unsafe_allow_html=True)
else:
st.markdown("""
<div class="model-info">
๐Ÿ–ฅ๏ธ <b>Llama-3-Ko-8B (GGUF)</b><br>
โ€ข T4 GPU ๊ฐ€์†<br>
โ€ข ๋กœ์ปฌ ์‹คํ–‰ (๋ฌด๋ฃŒ)<br>
โ€ข ์ดˆ๊ธฐ ๋กœ๋”ฉ ์‹œ๊ฐ„ ์†Œ์š”<br>
โ€ข 35๊ฐœ ๋ ˆ์ด์–ด GPU ์‚ฌ์šฉ
</div>
""", unsafe_allow_html=True)
st.markdown("---")
# ๊ฒ€์ƒ‰ ์„ค์ •
st.markdown("### ๐Ÿ” ๊ฒ€์ƒ‰ ์„ค์ •")
search_mode = st.selectbox(
"๊ฒ€์ƒ‰ ๋ชจ๋“œ",
options=["hybrid", "embedding"],
index=0,
format_func=lambda x: {
"hybrid": "๐Ÿ”€ Hybrid Search (BM25 + ์ž„๋ฒ ๋”ฉ)",
"embedding": "๐Ÿ“Š ์ž„๋ฒ ๋”ฉ ๊ฒ€์ƒ‰"
}[x],
help="Hybrid: ํ‚ค์›Œ๋“œ + ์˜๋ฏธ ๊ฒ€์ƒ‰ ๋ณ‘ํ–‰ (๊ถŒ์žฅ)"
)
# Reranker ํ† ๊ธ€
use_reranker = st.toggle(
"๐Ÿ”„ Re-ranker ์‚ฌ์šฉ",
value=True,
help="๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋ฅผ CrossEncoder๋กœ ์žฌ์ •๋ ฌํ•˜์—ฌ ์ •ํ™•๋„ ํ–ฅ์ƒ (๊ถŒ์žฅ)"
)
# ์‹ค์ œ ๊ฒ€์ƒ‰ ๋ชจ๋“œ ๊ฒฐ์ •
if use_reranker:
if search_mode == "hybrid":
actual_search_mode = "hybrid_rerank"
else: # embedding
actual_search_mode = "embedding_rerank"
else:
actual_search_mode = search_mode
top_k = st.slider(
"๊ฒ€์ƒ‰ํ•  ๋ฌธ์„œ ๊ฐœ์ˆ˜ (Top-K)",
min_value=1,
max_value=20,
value=10,
help="๊ฒ€์ƒ‰ํ•  ๋ฌธ์„œ ๊ฐœ์ˆ˜"
)
alpha = st.slider(
"์ž„๋ฒ ๋”ฉ ๊ฐ€์ค‘์น˜ (alpha)",
min_value=0.0,
max_value=1.0,
value=0.5,
step=0.1,
help="0: BM25๋งŒ, 1: ์ž„๋ฒ ๋”ฉ๋งŒ, 0.5: ๋™์ผ ๊ฐ€์ค‘์น˜ (Hybrid ๋ชจ๋“œ์—์„œ๋งŒ ์‚ฌ์šฉ)",
disabled=(search_mode == "embedding")
)
st.markdown("---")
# ๊ฐœ๋ฐœ์ž ์˜ต์…˜
st.markdown("### ๐Ÿ› ๏ธ ๊ฐœ๋ฐœ์ž ์˜ต์…˜")
show_routing = st.toggle(
"๐Ÿ” ๋ผ์šฐํŒ… ์ •๋ณด ํ‘œ์‹œ",
value=False,
help="Router์˜ ํŒ๋‹จ ๊ณผ์ •์„ ํ‘œ์‹œ (๋””๋ฒ„๊น…์šฉ)"
)
st.session_state.show_routing_info = show_routing
st.markdown("---")
# ๋Œ€ํ™” ๊ด€๋ฆฌ
st.markdown("### ๐Ÿ’ฌ ๋Œ€ํ™” ๊ด€๋ฆฌ")
if st.button("๐Ÿ—‘๏ธ ๋Œ€ํ™” ์ดˆ๊ธฐํ™”", use_container_width=True):
st.session_state.conv_manager.clear()
st.rerun()
if st.button("๐Ÿ’พ ๋Œ€ํ™” ๋‹ค์šด๋กœ๋“œ", use_container_width=True):
if len(st.session_state.conv_manager) > 0:
json_str = st.session_state.conv_manager.export_to_json()
st.download_button(
label="๐Ÿ“ฅ JSON ๋‹ค์šด๋กœ๋“œ",
data=json_str,
file_name=f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
mime="application/json",
use_container_width=True
)
st.markdown("---")
# ํ†ต๊ณ„
st.markdown("### ๐Ÿ“Š ํ†ต๊ณ„")
stats = st.session_state.conv_manager.get_statistics()
st.metric("์ด ๋Œ€ํ™” ์ˆ˜", stats.get('total', 0))
# ํ˜„์žฌ ์„ค์ • ํ‘œ์‹œ
st.markdown("---")
st.markdown("### ๐Ÿ“‹ ํ˜„์žฌ ์„ค์ •")
st.text(f"๋ชจ๋ธ: {model_type}")
st.text(f"๊ฒ€์ƒ‰ ๋ชจ๋“œ: {search_mode}")
st.text(f"Re-ranker: {'โœ… ON' if use_reranker else 'โŒ OFF'}")
st.text(f"์‹ค์ œ ๋ชจ๋“œ: {actual_search_mode}")
st.text(f"Top-K: {top_k}")
if search_mode == "hybrid":
st.text(f"Alpha: {alpha}")
st.text(f"Router Info: {'โœ… ON' if show_routing else 'โŒ OFF'}")
# ===== RAG ํŒŒ์ดํ”„๋ผ์ธ ์ดˆ๊ธฐํ™” =====
# ๋ชจ๋ธ ํƒ€์ž…์ด ๋ณ€๊ฒฝ๋˜์—ˆ๊ฑฐ๋‚˜ ํŒŒ์ดํ”„๋ผ์ธ์ด ์—†์œผ๋ฉด ์žฌ์ดˆ๊ธฐํ™”
if (st.session_state.rag_pipeline is None or
st.session_state.model_type != model_type):
with st.spinner(f"๐Ÿ”„ {model_type} ์ดˆ๊ธฐํ™” ์ค‘... (GGUF ๋ชจ๋ธ์€ 1~2๋ถ„ ์†Œ์š”๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค)"):
rag, error, rag_type = initialize_rag(model_type)
if error:
st.error(f"โŒ RAG ํŒŒ์ดํ”„๋ผ์ธ ์ดˆ๊ธฐํ™” ์‹คํŒจ")
with st.expander("๐Ÿ” ์—๋Ÿฌ ์ƒ์„ธ ์ •๋ณด"):
st.code(error)
st.info("""
### ๐Ÿ’ก ํ•ด๊ฒฐ ๋ฐฉ๋ฒ•
**GGUF ๋ชจ๋ธ ์‹คํŒจ ์‹œ:**
1. llama-cpp-python ์„ค์น˜ ํ™•์ธ:
```bash
pip install llama-cpp-python
```
2. GGUF ๋ชจ๋ธ ํŒŒ์ผ ํ™•์ธ:
- config.yaml์˜ GGUF_MODEL_PATH ๋˜๋Š”
- MODEL_HUB_REPO ์„ค์ • ํ™•์ธ
3. GPU ๋ฉ”๋ชจ๋ฆฌ ๋ถ€์กฑ ์‹œ:
- n_gpu_layers ๊ฐ’ ๊ฐ์†Œ (35 โ†’ 20)
**API ๋ชจ๋ธ ์‹คํŒจ ์‹œ:**
1. ChromaDB๊ฐ€ ์ƒ์„ฑ๋˜์—ˆ๋Š”์ง€ ํ™•์ธ:
```bash
python main.py --step embed
```
2. OpenAI API ํ‚ค ํ™•์ธ:
```bash
# .env ํŒŒ์ผ
OPENAI_API_KEY=your-key-here
```
3. ํ•„์š”ํ•œ ํŒจํ‚ค์ง€ ์„ค์น˜:
```bash
pip install rank-bm25 sentence-transformers
```
""")
return
st.session_state.rag_pipeline = rag
st.session_state.model_type = model_type
st.success(f"โœ… {rag_type} ๋ชจ๋ธ ์ค€๋น„ ์™„๋ฃŒ!")
# ===== ๋Œ€ํ™” ํžˆ์Šคํ† ๋ฆฌ ํ‘œ์‹œ =====
st.markdown("---")
if len(st.session_state.conv_manager) == 0:
st.info("""
### ๐Ÿ‘‹ ํ™˜์˜ํ•ฉ๋‹ˆ๋‹ค!
๊ณต๊ณต๊ธฐ๊ด€ ์‚ฌ์—…์ œ์•ˆ์„œ์— ๋Œ€ํ•ด ์งˆ๋ฌธํ•ด๋ณด์„ธ์š”.
**์˜ˆ์‹œ ์งˆ๋ฌธ:**
- "์•ˆ๋…•ํ•˜์„ธ์š”" (๊ฒ€์ƒ‰ ์•ˆ ํ•จ)
- "๋ฐ์ดํ„ฐ ํ‘œ์ค€ํ™” ์š”๊ตฌ์‚ฌํ•ญ์€ ๋ฌด์—‡์ธ๊ฐ€์š”?" (๊ฒ€์ƒ‰ ์ˆ˜ํ–‰)
- "๋ณด์•ˆ ๊ด€๋ จ ์š”๊ตฌ์‚ฌํ•ญ์„ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”" (๊ฒ€์ƒ‰ ์ˆ˜ํ–‰)
- "๊ณ ๋งˆ์›Œ์š”" (๊ฒ€์ƒ‰ ์•ˆ ํ•จ)
""")
# ๊ธฐ์กด ๋ฉ”์‹œ์ง€ ํ‘œ์‹œ
for msg in st.session_state.conv_manager.get_ui_history():
display_message(
role=msg['role'],
content=msg['content'],
sources=msg.get('sources'),
usage=msg.get('usage'),
search_mode=msg.get('search_mode'),
used_retrieval=msg.get('used_retrieval'),
routing_info=msg.get('routing_info')
)
# ===== ์งˆ๋ฌธ ์ž…๋ ฅ =====
st.markdown("---")
with st.form(key='question_form', clear_on_submit=True):
user_input = st.text_area(
"์งˆ๋ฌธ์„ ์ž…๋ ฅํ•˜์„ธ์š”:",
height=100,
placeholder="์˜ˆ: ๋ฐ์ดํ„ฐ ํ‘œ์ค€ํ™” ์š”๊ตฌ์‚ฌํ•ญ์€ ๋ฌด์—‡์ธ๊ฐ€์š”?"
)
col1, col2, col3 = st.columns([1, 1, 4])
with col1:
submit_button = st.form_submit_button("๐Ÿ“ค ์ „์†ก", use_container_width=True)
# ===== ์งˆ๋ฌธ ์ฒ˜๋ฆฌ =====
if submit_button and user_input:
# ๋‹ต๋ณ€ ์ƒ์„ฑ
with st.spinner("๐Ÿค” ๋‹ต๋ณ€ ์ƒ์„ฑ ์ค‘..."):
result = generate_answer(
query=user_input,
top_k=top_k,
search_mode=actual_search_mode,
alpha=alpha
)
# ์–ด์‹œ์Šคํ„ดํŠธ ๋ฉ”์‹œ์ง€ ์ถ”๊ฐ€
st.session_state.conv_manager.add_message(
user_msg=user_input,
ai_msg=result['answer'],
query_type=result.get('query_type', 'unknown'),
sources=result.get('sources', []),
usage=result.get('usage', {}),
search_mode=result.get('search_mode'),
used_retrieval=result.get('used_retrieval', False),
routing_info=result.get('routing_info')
)
# ํ™”๋ฉด ์ƒˆ๋กœ๊ณ ์นจ
st.rerun()
# Script entry point
if __name__ == "__main__":
    main()