import logging
import os
import time
from typing import Optional, Dict, Any, List

from llama_cpp import Llama

from src.utils.config import RAGConfig
from src.router.query_router import QueryRouter
from src.prompts.dynamic_prompts import PromptManager

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class GGUFGenerator:
    """
    GGUF-based Llama-3 generator.

    Loads a GGUF-format model via llama.cpp and answers
    questions about public procurement bids (RFPs).
    """

    def __init__(
        self,
        model_path: str,
        n_gpu_layers: int = 0,
        n_ctx: int = 8192,
        n_threads: int = 8,
        config=None,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        # Korean: "You are an expert in RFP (Request for Proposal) analysis and summarization."
        system_prompt: str = "당신은 RFP(제안요청서) 분석 및 요약 전문가입니다."
    ):
        """
        Initialize the generator.

        Args:
            model_path: Path to the GGUF model file
            n_gpu_layers: Number of layers to offload to the GPU (0 = CPU only, 35 = all layers on GPU)
            n_ctx: Maximum context length
            n_threads: Number of CPU threads
            config: Optional RAGConfig object
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (0.0-1.0)
            top_p: Nucleus sampling parameter
            system_prompt: System prompt (kept in Korean, since the target model is Korean)
        """
        self.config = config or RAGConfig()
        self.model_path = model_path
        self.n_gpu_layers = n_gpu_layers
        self.n_ctx = n_ctx
        self.n_threads = n_threads
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.system_prompt = system_prompt

        # Loaded lazily via load_model().
        self.model = None

        logger.info("GGUFGenerator initialized")

    def load_model(self) -> None:
        """
        Load the GGUF model.

        Logic:
        1. Check USE_MODEL_HUB.
        2-A. True  -> download from the Hugging Face Hub.
        2-B. False -> use a local file.
        3. Load the model.
        """
        if self.model is not None:
            logger.info("Model is already loaded.")
            return

        try:
            use_model_hub = getattr(self.config, 'USE_MODEL_HUB', True)

            if use_model_hub:
                model_hub_repo = getattr(self.config, 'MODEL_HUB_REPO', 'beomi/Llama-3-Open-Ko-8B-gguf')
                model_hub_filename = getattr(self.config, 'MODEL_HUB_FILENAME', 'ggml-model-Q4_K_M.gguf')
                model_cache_dir = getattr(self.config, 'MODEL_CACHE_DIR', '.cache/models')

                logger.info(f"Downloading from the Model Hub: {model_hub_repo}")

                from huggingface_hub import hf_hub_download

                model_path = hf_hub_download(
                    repo_id=model_hub_repo,
                    filename=model_hub_filename,
                    cache_dir=model_cache_dir,
                    local_dir=model_cache_dir,
                    local_dir_use_symlinks=False
                )

                logger.info(f"Download complete: {model_path}")

            else:
                model_path = self.model_path

                if not os.path.exists(model_path):
                    raise FileNotFoundError(
                        f"Local model file not found: {model_path}\n"
                        f"Set USE_MODEL_HUB=true or provide the model file."
                    )

                logger.info(f"Using local model: {model_path}")

            logger.info("Loading GGUF model...")
            logger.info(f"  GPU layers: {self.n_gpu_layers}")
            logger.info(f"  Context length: {self.n_ctx}")

            self.model = Llama(
                model_path=model_path,
                n_gpu_layers=self.n_gpu_layers,
                n_ctx=self.n_ctx,
                n_threads=self.n_threads,
                verbose=True,
            )

            # llama.cpp may clamp the context window, e.g. when memory is tight.
            actual_n_ctx = self.model.n_ctx()
            logger.info("GGUF model loaded!")
            logger.info(f"  - requested n_ctx: {self.n_ctx}")
            logger.info(f"  - actual n_ctx: {actual_n_ctx}")

            if actual_n_ctx < self.n_ctx:
                logger.warning(f"n_ctx is smaller than requested: {actual_n_ctx} < {self.n_ctx}")
                logger.warning("This may indicate insufficient memory; try lowering n_gpu_layers.")

        except FileNotFoundError as e:
            logger.error(f"Model file not found: {e}")
            raise
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise RuntimeError(f"Error while loading model: {e}") from e

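    # load_model() reads its Hub settings only via getattr() with fallbacks, so
    # any object exposing these attributes works as a config. A minimal sketch
    # (the class name is hypothetical; the values shown are the getattr()
    # defaults above, not guaranteed RAGConfig contents):
    #
    #   class HubConfig:
    #       USE_MODEL_HUB = True
    #       MODEL_HUB_REPO = 'beomi/Llama-3-Open-Ko-8B-gguf'
    #       MODEL_HUB_FILENAME = 'ggml-model-Q4_K_M.gguf'
    #       MODEL_CACHE_DIR = '.cache/models'
    #
    #   generator = GGUFGenerator(model_path='ignored-when-hub-enabled.gguf',
    #                             config=HubConfig())
    #   generator.load_model()
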
    def format_prompt(
        self,
        question: str,
        context: Optional[str] = None,
        system_prompt: Optional[str] = None
    ) -> str:
        """
        Simple prompt formatting for the GGUF model.

        Uses a plain-text template instead of Llama-3 special tokens.
        """
        if system_prompt is None:
            system_prompt = self.system_prompt

        # Korean labels: "참고 문서" = reference documents, "질문" = question.
        if context is not None:
            user_message = f"참고 문서:\n{context}\n\n질문: {question}"
        else:
            user_message = question

        # Section headers: "시스템" = system, "사용자" = user, "답변" = answer.
        formatted_prompt = f"""### 시스템
{system_prompt}

### 사용자
{user_message}

### 답변
"""

        return formatted_prompt

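    # For a context-free question, format_prompt() renders roughly the
    # following (illustrative; section labels stay Korean for the Korean model):
    #
    #   ### 시스템
    #   당신은 RFP(제안요청서) 분석 및 요약 전문가입니다.
    #
    #   ### 사용자
    #   <question text>
    #
    #   ### 답변
    #
    # With a context, the user section becomes "참고 문서:\n<docs>\n\n질문: <question>".
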
    def generate(
        self,
        prompt: str,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> str:
        """
        Generate a response for a formatted prompt.

        Args:
            prompt: Formatted prompt
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter

        Returns:
            Generated response text

        Raises:
            RuntimeError: If the model has not been loaded
        """
        if self.model is None:
            raise RuntimeError(
                "Model is not loaded. Call load_model() first."
            )

        # Fall back to the instance defaults.
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p

        try:
            logger.info(f"Generation started (max_tokens={max_new_tokens}, temp={temperature})")
            start_time = time.time()

            output = self.model(
                prompt,
                max_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                echo=False,
                stop=[
                    # Template section markers.
                    "###", "\n\n###",
                    "### 사용자", "\n사용자:",  # "### User", "\nUser:"
                    "</s>",
                    # Instruction-echo artifacts ("Korean answer",
                    # "answer in Korean", "instructions:", "sentence").
                    "한국어 답변", "한국어로 답변", "지침:",
                    "문장", "(문장",
                    # Cut off at paragraph breaks and trailing questions
                    # (Korean interrogative sentence endings).
                    "\n\n",
                    "?",
                    "요?", "까?", "나요?", "습니까?"
                ],
            )

            elapsed = time.time() - start_time
            logger.info(f"Generation finished in {elapsed:.2f}s")

            response = output['choices'][0]['text'].strip()

            logger.info(f"Response length: {len(response)} characters")
            return response

        except Exception as e:
            logger.error(f"Error during generation: {e}")
            raise RuntimeError(f"Text generation failed: {e}") from e

    def chat(
        self,
        question: str,
        context: Optional[str] = None,
        system_prompt: Optional[str] = None,
        **kwargs
    ) -> str:
        """
        Generate a response to a question (unified convenience method).

        Args:
            question: User question
            context: Optional context
            system_prompt: Optional system prompt
            **kwargs: Extra parameters forwarded to generate()

        Returns:
            Generated response
        """
        prompt = self.format_prompt(
            question=question,
            context=context,
            system_prompt=system_prompt
        )

        response = self.generate(prompt, **kwargs)

        return response

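
# A minimal usage sketch for the generator on its own (the model path and the
# question are illustrative, not project defaults):
#
#   generator = GGUFGenerator(model_path=".cache/models/llama-3-ko-8b.gguf",
#                             n_gpu_layers=0)  # CPU-only
#   generator.load_model()
#   answer = generator.chat("RFP 제출 마감일은 언제인가요?")  # "When is the RFP deadline?"
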
class GGUFRAGPipeline:
    """
    Pipeline combining the GGUF generator with RAG retrieval.

    Provides an interface compatible with chatbot_app.py.
    """

    def __init__(
        self,
        config=None,
        model: Optional[str] = None,
        top_k: Optional[int] = None,
        n_gpu_layers: Optional[int] = None,
        n_ctx: Optional[int] = None,
        n_threads: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        search_mode: Optional[str] = None,
        alpha: Optional[float] = None
    ):
        """
        Initialize the pipeline.

        Args:
            config: RAGConfig object
            model: Model name (unused, kept for compatibility)
            top_k: Default number of documents to retrieve
            n_gpu_layers: Number of GPU layers (overrides config)
            n_ctx: Context length (overrides config)
            n_threads: Number of CPU threads (overrides config)
            max_new_tokens: Maximum tokens to generate (overrides config)
            temperature: Sampling temperature (overrides config)
            top_p: Nucleus sampling parameter (overrides config)
            search_mode: Retrieval mode
            alpha: Embedding weight for hybrid search
        """
        self.config = config or RAGConfig()

        self.top_k = top_k or getattr(self.config, 'DEFAULT_TOP_K', 10)

        self.search_mode = search_mode or getattr(self.config, 'DEFAULT_SEARCH_MODE', 'hybrid_rerank')
        self.alpha = alpha if alpha is not None else getattr(self.config, 'DEFAULT_ALPHA', 0.5)

        logger.info("Initializing RAGRetriever...")
        from src.retriever.retriever import RAGRetriever
        self.retriever = RAGRetriever(config=self.config)

        # Explicit arguments take precedence over config values.
        gguf_n_gpu_layers = n_gpu_layers if n_gpu_layers is not None else getattr(self.config, 'GGUF_N_GPU_LAYERS', 35)
        gguf_n_ctx = n_ctx if n_ctx is not None else getattr(self.config, 'GGUF_N_CTX', 2048)
        gguf_n_threads = n_threads if n_threads is not None else getattr(self.config, 'GGUF_N_THREADS', 4)
        gguf_max_new_tokens = max_new_tokens if max_new_tokens is not None else getattr(self.config, 'GGUF_MAX_NEW_TOKENS', 512)
        gguf_temperature = temperature if temperature is not None else getattr(self.config, 'GGUF_TEMPERATURE', 0.7)
        gguf_top_p = top_p if top_p is not None else getattr(self.config, 'GGUF_TOP_P', 0.9)

        gguf_model_path = getattr(self.config, 'GGUF_MODEL_PATH', '.cache/models/llama-3-ko-8b.gguf')

        # Korean: "You are an expert analyst of Korean public-sector project proposals."
        system_prompt = getattr(self.config, 'SYSTEM_PROMPT', '당신은 한국 공공기관 사업 제안서 분석 전문가입니다.')

        logger.info("Initializing GGUFGenerator...")
        logger.info(f"  GPU layers: {gguf_n_gpu_layers}")
        logger.info(f"  Context length: {gguf_n_ctx}")
        logger.info(f"  Threads: {gguf_n_threads}")
        logger.info(f"  Model path: {gguf_model_path}")

        self.generator = GGUFGenerator(
            model_path=gguf_model_path,
            n_gpu_layers=gguf_n_gpu_layers,
            n_ctx=gguf_n_ctx,
            n_threads=gguf_n_threads,
            config=self.config,
            max_new_tokens=gguf_max_new_tokens,
            temperature=gguf_temperature,
            top_p=gguf_top_p,
            system_prompt=system_prompt
        )

        logger.info("Loading GGUF model...")
        self.generator.load_model()

        self.chat_history: List[Dict] = []

        # Raw hits from the most recent retrieval, kept for source reporting.
        self._last_retrieved_docs = []

        logger.info("GGUFRAGPipeline initialized")
        logger.info(f"  - search mode: {self.search_mode}")
        logger.info(f"  - default top_k: {self.top_k}")

    def _retrieve_and_format(self, query: str) -> str:
        """Run retrieval and format the resulting context."""
        if self.search_mode == "embedding":
            docs = self.retriever.search(query, top_k=self.top_k)
        elif self.search_mode == "embedding_rerank":
            docs = self.retriever.search_with_rerank(query, top_k=self.top_k)
        elif self.search_mode == "hybrid":
            docs = self.retriever.hybrid_search(
                query, top_k=self.top_k, alpha=self.alpha
            )
        elif self.search_mode == "hybrid_rerank":
            docs = self.retriever.hybrid_search_with_rerank(
                query, top_k=self.top_k, alpha=self.alpha
            )
        else:
            # Unknown mode: fall back to plain embedding search.
            docs = self.retriever.search(query, top_k=self.top_k)

        # Keep the raw hits so generate_answer() can report sources.
        self._last_retrieved_docs = docs

        return self._format_context(docs)

    def _format_context(self, retrieved_docs: list) -> str:
        """
        Convert retrieved documents into a context string.

        If the context grows too long it is truncated automatically
        (a character budget standing in for a token limit).
        """
        if not retrieved_docs:
            # Korean: "No relevant documents found."
            return "관련 문서를 찾을 수 없습니다."

        context_parts = []
        max_context_chars = 8000

        current_length = 0
        for i, doc in enumerate(retrieved_docs, 1):
            doc_text = f"[문서 {i}]\n{doc['content']}\n"  # "[Document {i}]"
            doc_length = len(doc_text)

            if current_length + doc_length > max_context_chars:
                logger.warning(f"Context length limit reached: using only the first {i-1} documents (max {max_context_chars} chars)")
                break

            context_parts.append(doc_text)
            current_length += doc_length

        return "\n".join(context_parts)

    def _format_sources(self, retrieved_docs: list) -> list:
        """Convert retrieved documents into the sources format."""
        sources = []
        for doc in retrieved_docs:
            source_info = {
                'content': doc['content'],
                'metadata': doc['metadata'],
                'filename': doc.get('filename', 'N/A'),
                'organization': doc.get('organization', 'N/A')
            }

            # Report the most specific score available for the hit.
            if 'rerank_score' in doc:
                source_info['score'] = doc['rerank_score']
                source_info['score_type'] = 'rerank'
            elif 'hybrid_score' in doc:
                source_info['score'] = doc['hybrid_score']
                source_info['score_type'] = 'hybrid'
            elif 'relevance_score' in doc:
                source_info['score'] = doc['relevance_score']
                source_info['score_type'] = 'embedding'
            else:
                source_info['score'] = 0
                source_info['score_type'] = 'unknown'

            sources.append(source_info)

        return sources

    def _estimate_usage(self, query: str, answer: str) -> dict:
        """Roughly estimate token usage (whitespace words x 2; no tokenizer involved)."""
        prompt_tokens = len(query.split()) * 2
        completion_tokens = len(answer.split()) * 2

        return {
            'total_tokens': prompt_tokens + completion_tokens,
            'prompt_tokens': prompt_tokens,
            'completion_tokens': completion_tokens
        }

    def generate_answer(
        self,
        query: str,
        top_k: Optional[int] = None,
        search_mode: Optional[str] = None,
        alpha: Optional[float] = None
    ) -> dict:
        """
        Generate an answer (main method, compatible with chatbot_app.py).

        Args:
            query: Question
            top_k: Number of documents to retrieve
            search_mode: Retrieval mode
            alpha: Embedding weight for hybrid search

        Returns:
            dict: answer, sources, search_mode, usage, elapsed_time, used_retrieval
        """
        try:
            start_time = time.time()

            # Per-call overrides persist for subsequent calls.
            if top_k is not None:
                self.top_k = top_k
            if search_mode is not None:
                self.search_mode = search_mode
            if alpha is not None:
                self.alpha = alpha

            # Route the query so small talk does not trigger retrieval.
            router = QueryRouter()
            classification = router.classify(query)
            query_type = classification['type']

            logger.info(f"Query classified as: {query_type} "
                        f"(confidence: {classification['confidence']:.2f})")

            if query_type in ['greeting', 'thanks', 'out_of_scope']:
                # No retrieval needed; answer directly.
                context = None
                used_retrieval = False
                self._last_retrieved_docs = []

                system_prompt = PromptManager.get_prompt(query_type, model_type="gguf")
                logger.info(f"Skipping RAG: {query_type}")

            else:
                # 'document' or any unrecognized type: retrieve supporting
                # documents (also avoids unbound locals for unexpected types).
                context = self._retrieve_and_format(query)
                used_retrieval = True

                system_prompt = PromptManager.get_prompt('document', model_type="gguf")
                logger.info(f"RAG performed: {len(self._last_retrieved_docs)} documents")

            answer = self.generator.chat(
                question=query,
                context=context,
                system_prompt=system_prompt
            )

            elapsed_time = time.time() - start_time

            self.chat_history.append({"role": "user", "content": query})
            self.chat_history.append({"role": "assistant", "content": answer})

            return {
                'answer': answer,
                'sources': self._format_sources(self._last_retrieved_docs),
                'used_retrieval': used_retrieval,
                'query_type': query_type,
                'search_mode': self.search_mode if used_retrieval else 'direct',
                'routing_info': classification,
                'elapsed_time': elapsed_time,
                'usage': self._estimate_usage(query, answer)
            }

        except Exception as e:
            logger.error(f"Failed to generate answer: {e}")
            import traceback
            traceback.print_exc()
            raise RuntimeError(f"Failed to generate answer: {str(e)}") from e

    def chat(self, query: str) -> str:
        """Simple conversational interface."""
        result = self.generate_answer(query)
        return result['answer']

    def clear_history(self):
        """Reset the chat history."""
        self.chat_history = []
        logger.info("Chat history cleared.")

    def get_history(self) -> List[Dict]:
        """Return a copy of the chat history."""
        return self.chat_history.copy()

    def set_search_config(
        self,
        search_mode: Optional[str] = None,
        top_k: Optional[int] = None,
        alpha: Optional[float] = None
    ):
        """Update the retrieval settings."""
        if search_mode is not None:
            self.search_mode = search_mode
        if top_k is not None:
            self.top_k = top_k
        if alpha is not None:
            self.alpha = alpha

        logger.info(
            f"Search settings updated: mode={self.search_mode}, "
            f"top_k={self.top_k}, alpha={self.alpha}"
        )
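

# A minimal end-to-end sketch (assumes the src.* modules and a populated
# retrieval index are available; the query is illustrative):
if __name__ == "__main__":
    pipeline = GGUFRAGPipeline()

    # Ask a document question; top_k, search_mode, and alpha can be tuned per call.
    result = pipeline.generate_answer("사업 예산은 얼마인가요?", top_k=5)  # "What is the project budget?"
    print(result['answer'])
    print(f"retrieved {len(result['sources'])} sources in {result['elapsed_time']:.1f}s")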