RFP_summary_chatbot / src /generator /generator_gguf.py
Dongjin1203's picture
์ปจํ…์ŠคํŠธ ๊ธธ์ด ์ฆ๊ฐ€
15c1ef1
from llama_cpp import Llama
from typing import Optional, Dict, Any, List
import logging
import time
import os
from src.utils.config import RAGConfig
from src.router.query_router import QueryRouter
from src.prompts.dynamic_prompts import PromptManager
# Logging setup: configure the root logger at INFO level when this module is
# imported, and create the module-level logger used by both classes below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class GGUFGenerator:
    """
    Llama-3 generator backed by a GGUF model file.

    Loads a GGUF-format model through llama.cpp (``llama_cpp.Llama``) and
    answers bid-related questions. The constructor only stores settings;
    call ``load_model()`` before ``generate()`` or ``chat()``.
    """

    # Default stop sequences handed to the model on every generate() call.
    # NOTE(review): generation ends as soon as ANY of these strings appears,
    # so "?" and "\n\n" cut every answer at the first question mark or
    # paragraph break, and several entries are subsumed by "###" / "?".
    # The list is kept as-is to preserve behavior; pass ``stop=`` to
    # generate()/chat() to override it per call.
    DEFAULT_STOP_SEQUENCES = (
        # Section delimiters
        "###", "\n\n###",
        "### ์‚ฌ์šฉ์ž", "\n์‚ฌ์šฉ์ž:",
        "</s>",
        # Block meta-text
        "ํ•œ๊ตญ์–ด ๋‹ต๋ณ€", "ํ•œ๊ตญ์–ด๋กœ ๋‹ต๋ณ€", "์ง€์นจ:",
        "๋ฌธ์žฅ", "(๋ฌธ์žฅ",
        # Block question patterns (keeps the model from producing a
        # follow-up question after the answer)
        "\n\n",  # paragraph break
        "?",  # question mark
        "์š”?", "๊นŒ?", "๋‚˜์š”?", "์Šต๋‹ˆ๊นŒ?",  # question endings
    )

    def __init__(
        self,
        model_path: str,
        n_gpu_layers: int = 0,
        n_ctx: int = 8192,
        n_threads: int = 8,
        config=None,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        system_prompt: str = "๋‹น์‹ ์€ RFP(์ œ์•ˆ์š”์ฒญ์„œ) ๋ถ„์„ ๋ฐ ์š”์•ฝ ์ „๋ฌธ๊ฐ€์ž…๋‹ˆ๋‹ค."
    ):
        """
        Store generator settings (the model itself is loaded later).

        Args:
            model_path: Path to the GGUF model file (used when the model hub
                is disabled in the config).
            n_gpu_layers: Number of layers to place on the GPU
                (0 = CPU only, 35 = everything on the GPU).
            n_ctx: Maximum context length.
            n_threads: Number of CPU threads.
            config: Configuration object; a fresh ``RAGConfig()`` is created
                when omitted.
            max_new_tokens: Default maximum number of generated tokens.
            temperature: Default sampling temperature (0.0~1.0).
            top_p: Default nucleus-sampling parameter.
            system_prompt: Default system prompt.
        """
        self.config = config or RAGConfig()
        self.model_path = model_path
        self.n_gpu_layers = n_gpu_layers
        self.n_ctx = n_ctx
        self.n_threads = n_threads
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.system_prompt = system_prompt

        # Model handle; populated by load_model().
        self.model = None

        logger.info("GGUFGenerator ์ดˆ๊ธฐํ™” ์™„๋ฃŒ")

    def load_model(self) -> None:
        """
        Load the GGUF model.

        Steps:
            1. Read ``USE_MODEL_HUB`` from the config (defaults to True).
            2-A. True  -> download the file from the Hugging Face Hub.
            2-B. False -> use the local file given to the constructor.
            3. Instantiate ``Llama`` with the resolved path.

        Raises:
            FileNotFoundError: The local model file does not exist.
            RuntimeError: Any other failure while downloading or loading;
                the original exception is attached as ``__cause__``.
        """
        # Avoid loading twice.
        if self.model is not None:
            logger.info("๋ชจ๋ธ์ด ์ด๋ฏธ ๋กœ๋“œ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.")
            return

        try:
            use_model_hub = getattr(self.config, 'USE_MODEL_HUB', True)

            if use_model_hub:
                # === Download from the model hub ===
                model_hub_repo = getattr(self.config, 'MODEL_HUB_REPO', 'beomi/Llama-3-Open-Ko-8B-gguf')
                model_hub_filename = getattr(self.config, 'MODEL_HUB_FILENAME', 'ggml-model-Q4_K_M.gguf')
                model_cache_dir = getattr(self.config, 'MODEL_CACHE_DIR', '.cache/models')

                logger.info(f"๐Ÿ“ฅ Model Hub์—์„œ ๋‹ค์šด๋กœ๋“œ: {model_hub_repo}")

                # Imported lazily so the dependency is only needed in hub mode.
                from huggingface_hub import hf_hub_download

                # NOTE(review): local_dir_use_symlinks appears to be deprecated
                # in recent huggingface_hub releases - confirm against the
                # pinned version before removing it.
                model_path = hf_hub_download(
                    repo_id=model_hub_repo,
                    filename=model_hub_filename,
                    cache_dir=model_cache_dir,
                    local_dir=model_cache_dir,
                    local_dir_use_symlinks=False  # real copy instead of a symlink
                )
                logger.info(f"โœ… ๋‹ค์šด๋กœ๋“œ ์™„๋ฃŒ: {model_path}")
            else:
                # === Use the local file ===
                model_path = self.model_path  # path received by the constructor
                if not os.path.exists(model_path):
                    raise FileNotFoundError(
                        f"โŒ ๋กœ์ปฌ ๋ชจ๋ธ ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {model_path}\n"
                        f" USE_MODEL_HUB=true๋กœ ์„ค์ •ํ•˜๊ฑฐ๋‚˜ ๋ชจ๋ธ ํŒŒ์ผ์„ ์ค€๋น„ํ•˜์„ธ์š”."
                    )
                logger.info(f"๐Ÿ“‚ ๋กœ์ปฌ ๋ชจ๋ธ ์‚ฌ์šฉ: {model_path}")

            # === Common: load the model ===
            logger.info("๐Ÿš€ GGUF ๋ชจ๋ธ ๋กœ๋“œ ์ค‘...")
            logger.info(f" GPU ๋ ˆ์ด์–ด: {self.n_gpu_layers}")
            logger.info(f" ์ปจํ…์ŠคํŠธ: {self.n_ctx}")

            self.model = Llama(
                model_path=model_path,
                n_gpu_layers=self.n_gpu_layers,
                n_ctx=self.n_ctx,
                n_threads=self.n_threads,
                verbose=True,  # enable llama.cpp debug output
            )

            # Compare the context size actually applied with the requested one.
            actual_n_ctx = self.model.n_ctx()
            logger.info("โœ… GGUF ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ!")
            logger.info(f" - ์„ค์ •ํ•œ n_ctx: {self.n_ctx}")
            logger.info(f" - ์‹ค์ œ n_ctx: {actual_n_ctx}")
            if actual_n_ctx < self.n_ctx:
                logger.warning(f"โš ๏ธ n_ctx๊ฐ€ ์˜ˆ์ƒ๋ณด๋‹ค ์ž‘์Šต๋‹ˆ๋‹ค: {actual_n_ctx} < {self.n_ctx}")
                logger.warning(" ๋ฉ”๋ชจ๋ฆฌ ๋ถ€์กฑ์ผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. n_gpu_layers๋ฅผ ์ค„์—ฌ๋ณด์„ธ์š”.")

        except FileNotFoundError as e:
            logger.error(f"โŒ ๋ชจ๋ธ ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {e}")
            raise
        except Exception as e:
            logger.error(f"โŒ ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ: {e}")
            # Chain explicitly so the root cause stays attached.
            raise RuntimeError(f"๋ชจ๋ธ ๋กœ๋“œ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}") from e

    def format_prompt(
        self,
        question: str,
        context: Optional[str] = None,
        system_prompt: Optional[str] = None
    ) -> str:
        """
        Build the plain-text prompt for the GGUF model.

        Uses a simple text template instead of Llama-3 special tokens.

        Args:
            question: User question.
            context: Optional reference text placed before the question.
            system_prompt: Optional override; the instance default is used
                when None.

        Returns:
            The formatted prompt, ending with the answer header.
        """
        if system_prompt is None:
            system_prompt = self.system_prompt

        # Prepend the reference documents only when a context was supplied.
        if context is not None:
            user_message = f"์ฐธ๊ณ  ๋ฌธ์„œ:\n{context}\n\n์งˆ๋ฌธ: {question}"
        else:
            user_message = question

        # Simple Korean template (no special tokens).
        formatted_prompt = f"""### ์‹œ์Šคํ…œ
{system_prompt}
### ์‚ฌ์šฉ์ž
{user_message}
### ๋‹ต๋ณ€
"""
        return formatted_prompt

    def generate(
        self,
        prompt: str,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
    ) -> str:
        """
        Generate a response for an already formatted prompt.

        Args:
            prompt: Formatted prompt.
            max_new_tokens: Maximum number of generated tokens
                (instance default when None).
            temperature: Sampling temperature (instance default when None).
            top_p: Nucleus sampling (instance default when None).
            stop: Stop sequences; ``DEFAULT_STOP_SEQUENCES`` when None.

        Returns:
            The generated text with surrounding whitespace stripped.

        Raises:
            RuntimeError: The model is not loaded, or generation failed
                (the original exception is attached as ``__cause__``).
        """
        if self.model is None:
            raise RuntimeError(
                "๋ชจ๋ธ์ด ๋กœ๋“œ๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. load_model()์„ ๋จผ์ € ํ˜ธ์ถœํ•˜์„ธ์š”."
            )

        # Fall back to the instance defaults for anything not supplied.
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        if stop is None:
            stop = list(self.DEFAULT_STOP_SEQUENCES)

        try:
            logger.info(f"๐Ÿ”„ ์ƒ์„ฑ ์‹œ์ž‘ (max_tokens={max_new_tokens}, temp={temperature})")
            start_time = time.time()

            output = self.model(
                prompt,
                max_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                echo=False,  # do not repeat the prompt in the output
                stop=stop,
            )

            elapsed = time.time() - start_time
            logger.info(f"โœ… ์ƒ์„ฑ ์™„๋ฃŒ: {elapsed:.2f}์ดˆ")

            # Extract the text of the first choice.
            response = output['choices'][0]['text'].strip()
            logger.info(f"๐Ÿ“ ์‘๋‹ต ๊ธธ์ด: {len(response)} ๊ธ€์ž")
            return response

        except Exception as e:
            logger.error(f"โŒ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
            raise RuntimeError(f"ํ…์ŠคํŠธ ์ƒ์„ฑ ์‹คํŒจ: {e}") from e

    def chat(
        self,
        question: str,
        context: Optional[str] = None,
        system_prompt=None,
        **kwargs
    ) -> str:
        """
        Answer a question (format the prompt, then generate).

        Args:
            question: User question.
            context: Optional context.
            system_prompt: Optional system prompt.
            **kwargs: Extra parameters forwarded to ``generate()``.

        Returns:
            The generated response.
        """
        prompt = self.format_prompt(
            question=question,
            context=context,
            system_prompt=system_prompt
        )
        return self.generate(prompt, **kwargs)
class GGUFRAGPipeline:
    """
    Pipeline combining the GGUF generator with RAG retrieval.

    Exposes an interface compatible with chatbot_app.py.
    """

    def __init__(
        self,
        config=None,
        model: Optional[str] = None,  # kept for compatibility (unused)
        top_k: Optional[int] = None,
        # GPU settings (optional, override the config)
        n_gpu_layers: Optional[int] = None,
        n_ctx: Optional[int] = None,
        n_threads: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        search_mode: Optional[str] = None,
        alpha: Optional[float] = None
    ):
        """
        Build the retriever and the generator, then load the model.

        Args:
            config: RAGConfig object; a fresh ``RAGConfig()`` when omitted.
            model: Model name (unused, kept for compatibility).
            top_k: Default number of documents to retrieve.
            n_gpu_layers: GPU layer count (overrides the config).
            n_ctx: Context length (overrides the config).
            n_threads: CPU thread count (overrides the config).
            max_new_tokens: Maximum generated tokens (overrides the config).
            temperature: Sampling temperature (overrides the config).
            top_p: Nucleus sampling (overrides the config).
            search_mode: Search mode.
            alpha: Embedding weight.
        """
        self.config = config or RAGConfig()

        # Defaults come from the config, with hard-coded fallbacks.
        self.top_k = top_k or getattr(self.config, 'DEFAULT_TOP_K', 10)

        # Search settings
        self.search_mode = search_mode or getattr(self.config, 'DEFAULT_SEARCH_MODE', 'hybrid_rerank')
        self.alpha = alpha if alpha is not None else getattr(self.config, 'DEFAULT_ALPHA', 0.5)

        # Retriever (imported lazily inside the constructor).
        logger.info("RAGRetriever ์ดˆ๊ธฐํ™” ์ค‘...")
        from src.retriever.retriever import RAGRetriever
        self.retriever = RAGRetriever(config=self.config)

        # GGUF settings: explicit argument > config attribute > fallback.
        gguf_n_gpu_layers = n_gpu_layers if n_gpu_layers is not None else getattr(self.config, 'GGUF_N_GPU_LAYERS', 35)
        gguf_n_ctx = n_ctx if n_ctx is not None else getattr(self.config, 'GGUF_N_CTX', 2048)
        gguf_n_threads = n_threads if n_threads is not None else getattr(self.config, 'GGUF_N_THREADS', 4)
        gguf_max_new_tokens = max_new_tokens if max_new_tokens is not None else getattr(self.config, 'GGUF_MAX_NEW_TOKENS', 512)
        gguf_temperature = temperature if temperature is not None else getattr(self.config, 'GGUF_TEMPERATURE', 0.7)
        gguf_top_p = top_p if top_p is not None else getattr(self.config, 'GGUF_TOP_P', 0.9)

        # Model path (fallback)
        gguf_model_path = getattr(self.config, 'GGUF_MODEL_PATH', '.cache/models/llama-3-ko-8b.gguf')

        # System prompt (fallback)
        system_prompt = getattr(self.config, 'SYSTEM_PROMPT', '๋‹น์‹ ์€ ํ•œ๊ตญ ๊ณต๊ณต๊ธฐ๊ด€ ์‚ฌ์—…์ œ์•ˆ์„œ ๋ถ„์„ ์ „๋ฌธ๊ฐ€์ž…๋‹ˆ๋‹ค.')

        logger.info("GGUFGenerator ์ดˆ๊ธฐํ™” ์ค‘...")
        logger.info(f" GPU ๋ ˆ์ด์–ด: {gguf_n_gpu_layers}")
        logger.info(f" ์ปจํ…์ŠคํŠธ: {gguf_n_ctx}")
        logger.info(f" ์Šค๋ ˆ๋“œ: {gguf_n_threads}")
        logger.info(f" ๋ชจ๋ธ ๊ฒฝ๋กœ: {gguf_model_path}")

        self.generator = GGUFGenerator(
            model_path=gguf_model_path,
            n_gpu_layers=gguf_n_gpu_layers,
            n_ctx=gguf_n_ctx,
            n_threads=gguf_n_threads,
            config=self.config,
            max_new_tokens=gguf_max_new_tokens,
            temperature=gguf_temperature,
            top_p=gguf_top_p,
            system_prompt=system_prompt
        )

        # Load the model (takes time).
        logger.info("GGUF ๋ชจ๋ธ ๋กœ๋“œ ์ค‘...")
        self.generator.load_model()

        # Conversation history
        self.chat_history: List[Dict] = []

        # Last retrieval result (used to build the returned sources).
        self._last_retrieved_docs = []

        logger.info("โœ… GGUFRAGPipeline ์ดˆ๊ธฐํ™” ์™„๋ฃŒ")
        logger.info(f" - ๊ฒ€์ƒ‰ ๋ชจ๋“œ: {self.search_mode}")
        logger.info(f" - ๊ธฐ๋ณธ top_k: {self.top_k}")

    def _retrieve_and_format(self, query: str) -> str:
        """Run retrieval for ``query`` and return the formatted context."""
        mode = self.search_mode
        if mode == "embedding_rerank":
            docs = self.retriever.search_with_rerank(query, top_k=self.top_k)
        elif mode == "hybrid":
            docs = self.retriever.hybrid_search(
                query, top_k=self.top_k, alpha=self.alpha
            )
        elif mode == "hybrid_rerank":
            docs = self.retriever.hybrid_search_with_rerank(
                query, top_k=self.top_k, alpha=self.alpha
            )
        else:
            # "embedding" and any unrecognized mode use plain embedding search.
            docs = self.retriever.search(query, top_k=self.top_k)

        # Remember the result so generate_answer() can report the sources.
        self._last_retrieved_docs = docs

        return self._format_context(docs)

    def _format_context(self, retrieved_docs: list, max_context_chars: int = 8000) -> str:
        """
        Turn retrieved documents into a single context string.

        Documents are appended in order until the character budget would be
        exceeded; the remaining documents are dropped. If even the first
        document does not fit, it is truncated to the budget instead of
        returning an empty context.

        Args:
            retrieved_docs: Documents as dicts with a ``'content'`` key.
            max_context_chars: Character budget for the context. The default
                of 8000 was described by the original author as roughly
                2000 tokens.
                NOTE(review): the real token count depends on the
                tokenizer/language - verify against the configured n_ctx.

        Returns:
            The joined context, or a fixed "no documents" message when
            ``retrieved_docs`` is empty.
        """
        if not retrieved_docs:
            return "๊ด€๋ จ ๋ฌธ์„œ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."

        context_parts = []
        current_length = 0

        for i, doc in enumerate(retrieved_docs, 1):
            doc_text = f"[๋ฌธ์„œ {i}]\n{doc['content']}\n"
            doc_length = len(doc_text)

            if current_length + doc_length > max_context_chars:
                if not context_parts:
                    # A single oversized document would otherwise leave the
                    # model with no context at all; keep its leading part.
                    context_parts.append(doc_text[:max_context_chars])
                logger.warning(f"โš ๏ธ ์ปจํ…์ŠคํŠธ ๊ธธ์ด ์ œํ•œ: {len(context_parts)}๊ฐœ ๋ฌธ์„œ๋งŒ ์‚ฌ์šฉ (์ตœ๋Œ€ {max_context_chars}์ž)")
                break

            context_parts.append(doc_text)
            current_length += doc_length

        return "\n".join(context_parts)

    def _format_sources(self, retrieved_docs: list) -> list:
        """
        Convert retrieved documents into the ``sources`` list format.

        Each entry carries content, metadata, filename, organization and a
        score whose origin depends on which score key the document has
        (rerank > hybrid > relevance; 0/'unknown' when none is present).
        """
        sources = []
        for doc in retrieved_docs:
            source_info = {
                'content': doc['content'],
                'metadata': doc['metadata'],
                'filename': doc.get('filename', 'N/A'),
                'organization': doc.get('organization', 'N/A')
            }

            # The score field differs by search mode.
            if 'rerank_score' in doc:
                source_info['score'] = doc['rerank_score']
                source_info['score_type'] = 'rerank'
            elif 'hybrid_score' in doc:
                source_info['score'] = doc['hybrid_score']
                source_info['score_type'] = 'hybrid'
            elif 'relevance_score' in doc:
                source_info['score'] = doc['relevance_score']
                source_info['score_type'] = 'embedding'
            else:
                source_info['score'] = 0
                source_info['score_type'] = 'unknown'

            sources.append(source_info)
        return sources

    def _estimate_usage(self, query: str, answer: str) -> dict:
        """
        Estimate token usage.

        Rough heuristic: two tokens per whitespace-separated word of the
        query and of the answer (context and system prompt are not counted).
        """
        prompt_tokens = len(query.split()) * 2
        completion_tokens = len(answer.split()) * 2
        return {
            'total_tokens': prompt_tokens + completion_tokens,
            'prompt_tokens': prompt_tokens,
            'completion_tokens': completion_tokens
        }

    def generate_answer(
        self,
        query: str,
        top_k: Optional[int] = None,
        search_mode: Optional[str] = None,
        alpha: Optional[float] = None
    ) -> dict:
        """
        Generate an answer (main entry point, chatbot_app.py compatible).

        Args:
            query: Question.
            top_k: Number of documents to retrieve.
            search_mode: Search mode.
            alpha: Embedding weight.
            Any of the three overrides that is given is stored on the
            instance, so it also applies to later calls.

        Returns:
            dict with keys: answer, sources, used_retrieval, query_type,
            search_mode, routing_info, elapsed_time, usage.

        Raises:
            RuntimeError: Any failure during routing, retrieval or
                generation (original exception attached as ``__cause__``).
        """
        try:
            start_time = time.time()

            # Apply overrides before searching.
            if top_k is not None:
                self.top_k = top_k
            if search_mode is not None:
                self.search_mode = search_mode
            if alpha is not None:
                self.alpha = alpha

            # ===== Let the router decide whether to retrieve =====
            router = QueryRouter()
            classification = router.classify(query)
            query_type = classification['type']  # 'greeting'/'thanks'/'document'/'out_of_scope'
            logger.info(f"๐Ÿ“ ๋ถ„๋ฅ˜: {query_type} "
                        f"(์‹ ๋ขฐ๋„: {classification['confidence']:.2f})")

            if query_type in ['greeting', 'thanks', 'out_of_scope']:
                # Skip retrieval.
                context = None
                used_retrieval = False
                self._last_retrieved_docs = []
                # Dynamic prompt for this query type (GGUF variant).
                system_prompt = PromptManager.get_prompt(query_type, model_type="gguf")
                logger.info(f"โญ๏ธ RAG ์Šคํ‚ต: {query_type}")
            else:
                # 'document' and any unexpected type run RAG, so that
                # context/system_prompt are always bound.
                context = self._retrieve_and_format(query)
                used_retrieval = True
                # Dynamic prompt (GGUF variant, used together with context).
                system_prompt = PromptManager.get_prompt('document', model_type="gguf")
                logger.info(f"๐Ÿ” RAG ์ˆ˜ํ–‰: {len(self._last_retrieved_docs)}๊ฐœ ๋ฌธ์„œ")

            # Generate the answer with the selected system prompt.
            answer = self.generator.chat(
                question=query,
                context=context,
                system_prompt=system_prompt
            )

            elapsed_time = time.time() - start_time

            # Record the exchange in the conversation history.
            self.chat_history.append({"role": "user", "content": query})
            self.chat_history.append({"role": "assistant", "content": answer})

            # Same result shape as RAGPipeline.
            return {
                'answer': answer,
                'sources': self._format_sources(self._last_retrieved_docs),
                'used_retrieval': used_retrieval,
                'query_type': query_type,
                'search_mode': self.search_mode if used_retrieval else 'direct',
                'routing_info': classification,
                'elapsed_time': elapsed_time,
                'usage': self._estimate_usage(query, answer)
            }

        except Exception as e:
            # logger.exception records the traceback through logging.
            logger.exception(f"โŒ ๋‹ต๋ณ€ ์ƒ์„ฑ ์‹คํŒจ: {e}")
            raise RuntimeError(f"๋‹ต๋ณ€ ์ƒ์„ฑ ์‹คํŒจ: {str(e)}") from e

    def chat(self, query: str) -> str:
        """Simple chat interface: return only the answer text."""
        result = self.generate_answer(query)
        return result['answer']

    def clear_history(self):
        """Reset the conversation history."""
        self.chat_history = []
        logger.info("๐Ÿ—‘๏ธ ๋Œ€ํ™” ํžˆ์Šคํ† ๋ฆฌ๊ฐ€ ์ดˆ๊ธฐํ™”๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")

    def get_history(self) -> List[Dict]:
        """Return a shallow copy of the conversation history."""
        return self.chat_history.copy()

    def set_search_config(
        self,
        search_mode: Optional[str] = None,
        top_k: Optional[int] = None,
        alpha: Optional[float] = None
    ):
        """Change the search settings; arguments left as None are unchanged."""
        if search_mode is not None:
            self.search_mode = search_mode
        if top_k is not None:
            self.top_k = top_k
        if alpha is not None:
            self.alpha = alpha

        logger.info(
            f"๐Ÿ”ง ๊ฒ€์ƒ‰ ์„ค์ • ๋ณ€๊ฒฝ: mode={self.search_mode}, "
            f"top_k={self.top_k}, alpha={self.alpha}"
        )