# src/generator/generator_gguf_base.py
from llama_cpp import Llama
from typing import Optional, Dict, Any, List
import logging
import time
import os
from src.utils.config import RAGConfig
from src.router.query_router import QueryRouter
from src.prompts.dynamic_prompts import PromptManager
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class GGUFGenerator:
"""
GGUF ๊ธฐ๋ฐ˜ Llama-3 ์ƒ์„ฑ๊ธฐ
llama.cpp๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ GGUF ํฌ๋งท ๋ชจ๋ธ์„ ๋กœ๋“œํ•˜๊ณ 
์ž…์ฐฐ ๊ด€๋ จ ์งˆ์˜์‘๋‹ต์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.
"""
def __init__(
self,
model_path: str,
n_gpu_layers: int = 0,
n_ctx: int = 8192,
n_threads: int = 8,
        config: Optional[RAGConfig] = None,
max_new_tokens: int = 256,
temperature: float = 0.7,
top_p: float = 0.9,
        system_prompt: str = "당신은 RFP(제안요청서) 분석 및 요약 전문가입니다."  # "You are an expert in RFP (request-for-proposal) analysis and summarization." — kept in Korean for the Korean base model
):
"""์ƒ์„ฑ๊ธฐ ์ดˆ๊ธฐํ™”"""
self.config = config or RAGConfig()
self.model_path = model_path
self.n_gpu_layers = n_gpu_layers
self.n_ctx = n_ctx
self.n_threads = n_threads
self.max_new_tokens = max_new_tokens
self.temperature = temperature
self.top_p = top_p
self.system_prompt = system_prompt
        # Model handle (loaded lazily by load_model())
        self.model = None
        logger.info("GGUFGenerator initialized (base model)")
def load_model(self) -> None:
"""
GGUF ๋ชจ๋ธ ๋กœ๋“œ
โœ… Base ๋ชจ๋ธ ์‚ฌ์šฉ: Config์—์„œ BASE_MODEL_HUB_REPO ๊ฐ€์ ธ์˜ค๊ธฐ
"""
        # Guard against loading twice
        if self.model is not None:
            logger.info("Model is already loaded.")
            return
        try:
            # Check USE_MODEL_HUB in the config
            use_model_hub = getattr(self.config, 'USE_MODEL_HUB', True)

            # Choose the model path depending on whether the Model Hub is used
            if use_model_hub:
                # === Download from the Model Hub ===
                # Read the base-model info from the config
base_model_repo = getattr(
self.config,
'BASE_MODEL_HUB_REPO',
'Dongjin1203/Llama-3-Open-Ko-8B-GGUF'
)
base_model_filename = getattr(
self.config,
'BASE_MODEL_HUB_FILENAME',
'Llama-3-Open-Ko-8B-Q4_K_M.gguf'
)
                model_cache_dir = getattr(self.config, 'MODEL_CACHE_DIR', '.cache/models')
                logger.info(f"📥 Downloading base model: {base_model_repo}")
                logger.info(f"   filename: {base_model_filename}")
                from huggingface_hub import hf_hub_download
                # Note: local_dir_use_symlinks is deprecated in recent
                # huggingface_hub releases and is ignored when local_dir is set.
                model_path = hf_hub_download(
                    repo_id=base_model_repo,
                    filename=base_model_filename,
                    cache_dir=model_cache_dir,
                    local_dir=model_cache_dir,
                    local_dir_use_symlinks=False
                )
                logger.info(f"✅ Base model downloaded: {model_path}")
            else:
                # === Use a local file ===
                model_path = self.model_path
                if not os.path.exists(model_path):
                    raise FileNotFoundError(
                        f"❌ Local model file not found: {model_path}\n"
                        f"   Set USE_MODEL_HUB=true or put the model file in place."
                    )
                logger.info(f"📂 Using local base model: {model_path}")
            # === Common: load the model ===
            logger.info("🚀 Loading base GGUF model...")
            logger.info(f"   GPU layers: {self.n_gpu_layers}")
            logger.info(f"   context size: {self.n_ctx}")
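            # In llama.cpp, n_gpu_layers=0 keeps every layer on the CPU and
            # n_gpu_layers=-1 offloads all layers to the GPU; values in between
            # split the model between the two. verbose=True makes llama.cpp
            # print its load and timing details.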
self.model = Llama(
model_path=model_path,
n_gpu_layers=self.n_gpu_layers,
n_ctx=self.n_ctx,
n_threads=self.n_threads,
verbose=True,
)
            # Check the context size actually applied
            actual_n_ctx = self.model.n_ctx()
            logger.info("✅ Base GGUF model loaded!")
            logger.info(f"   - model: {base_model_repo if use_model_hub else 'local'}")
            logger.info(f"   - requested n_ctx: {self.n_ctx}")
            logger.info(f"   - actual n_ctx: {actual_n_ctx}")
            if actual_n_ctx < self.n_ctx:
                logger.warning(f"⚠️ n_ctx is smaller than requested: {actual_n_ctx} < {self.n_ctx}")
        except FileNotFoundError as e:
            logger.error(f"❌ Model file not found: {e}")
            raise
        except Exception as e:
            logger.error(f"❌ Model load failed: {e}")
            raise RuntimeError(f"Error while loading the model: {e}") from e
def format_prompt(
self,
question: str,
context: Optional[str] = None,
system_prompt: Optional[str] = None
) -> str:
"""GGUF ๋ชจ๋ธ์šฉ ๊ฐ„๋‹จํ•œ ํ”„๋กฌํ”„ํŠธ ํฌ๋งทํŒ…"""
if system_prompt is None:
system_prompt = self.system_prompt
if context is not None:
user_message = f"์ฐธ๊ณ  ๋ฌธ์„œ:\n{context}\n\n์งˆ๋ฌธ: {question}"
else:
user_message = question
        # Korean section markers ("### 시스템" = system, "### 사용자" = user,
        # "### 답변" = answer); the stop strings in generate() match these.
        formatted_prompt = f"""### 시스템
{system_prompt}
### 사용자
{user_message}
### 답변
"""
        return formatted_prompt
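    # Rendered example (illustrative) for question "마감일은?" with no context
    # and the default system prompt:
    #
    #   ### 시스템
    #   당신은 RFP(제안요청서) 분석 및 요약 전문가입니다.
    #   ### 사용자
    #   마감일은?
    #   ### 답변
    #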
def generate(
self,
prompt: str,
max_new_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
) -> str:
"""ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ž…๋ ฅ๋ฐ›์•„ ์‘๋‹ต ์ƒ์„ฑ"""
        if self.model is None:
            raise RuntimeError(
                "Model is not loaded. Call load_model() first."
            )
if max_new_tokens is None:
max_new_tokens = self.max_new_tokens
if temperature is None:
temperature = self.temperature
if top_p is None:
top_p = self.top_p
        try:
            logger.info(f"🔄 Generation started (max_tokens={max_new_tokens}, temp={temperature})")
start_time = time.time()
output = self.model(
prompt,
max_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
                echo=False,
                # Aggressive stop strings: besides the "###" section markers
                # emitted by format_prompt(), these cut generation at common
                # boilerplate phrases and at question-style sentence endings,
                # to keep the base model from rambling or asking follow-ups.
                stop=[
                    "###", "\n\n###",
                    "### 사용자", "\n사용자:",
                    "</s>",
                    "한국어 답변", "한국어로 답변", "지침:",
                    "문장", "(문장",
                    "\n\n",
                    "?",
                    "요?", "까?", "나요?", "습니까?"
                ],
            )
            elapsed = time.time() - start_time
            logger.info(f"✅ Generation finished in {elapsed:.2f}s")
            # llama-cpp-python returns an OpenAI-style completion dict
            response = output['choices'][0]['text'].strip()
            logger.info(f"📏 Response length: {len(response)} characters")
            return response
        except Exception as e:
            logger.error(f"❌ Error during generation: {e}")
            raise RuntimeError(f"Text generation failed: {e}") from e
def chat(
self,
question: str,
context: Optional[str] = None,
        system_prompt: Optional[str] = None,
**kwargs
) -> str:
"""์งˆ๋ฌธ์— ๋Œ€ํ•œ ์‘๋‹ต ์ƒ์„ฑ"""
prompt = self.format_prompt(
question=question,
context=context,
system_prompt=system_prompt
)
response = self.generate(prompt, **kwargs)
return response
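
# Standalone usage sketch for GGUFGenerator (illustrative path and question;
# with USE_MODEL_HUB enabled in the config, the file is pulled from the Hub
# and model_path is ignored):
#
#     gen = GGUFGenerator(model_path=".cache/models/llama-3-ko-8b.gguf")
#     gen.load_model()
#     print(gen.chat("이 사업의 입찰 마감일은 언제인가요?"))
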
class GGUFBaseRAGPipeline:
"""
Base ๋ชจ๋ธ + RAG ํŒŒ์ดํ”„๋ผ์ธ
โœ… Base ๋ชจ๋ธ ์‚ฌ์šฉ (beomi/Llama-3-Open-Ko-8B)
โœ… RAG ์œ ์ง€
โœ… ๊ธฐ์กด generator_gguf.py์™€ ๋™์ผํ•œ ๊ธฐ๋Šฅ
"""
def __init__(
self,
        config: Optional[RAGConfig] = None,
        model: Optional[str] = None,  # currently unused
        top_k: Optional[int] = None,
        n_gpu_layers: Optional[int] = None,
        n_ctx: Optional[int] = None,
        n_threads: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        search_mode: Optional[str] = None,
        alpha: Optional[float] = None
    ):
        """Initialize the pipeline."""
self.config = config or RAGConfig()
        # Retrieval settings
self.top_k = top_k or getattr(self.config, 'DEFAULT_TOP_K', 10)
self.search_mode = search_mode or getattr(self.config, 'DEFAULT_SEARCH_MODE', 'hybrid_rerank')
self.alpha = alpha if alpha is not None else getattr(self.config, 'DEFAULT_ALPHA', 0.5)
        # Initialize the retriever
        logger.info("Initializing RAGRetriever...")
        from src.retriever.retriever import RAGRetriever
        self.retriever = RAGRetriever(config=self.config)
        # GGUF settings (explicit args take precedence over config defaults)
gguf_n_gpu_layers = n_gpu_layers if n_gpu_layers is not None else getattr(self.config, 'GGUF_N_GPU_LAYERS', 35)
gguf_n_ctx = n_ctx if n_ctx is not None else getattr(self.config, 'GGUF_N_CTX', 2048)
gguf_n_threads = n_threads if n_threads is not None else getattr(self.config, 'GGUF_N_THREADS', 4)
gguf_max_new_tokens = max_new_tokens if max_new_tokens is not None else getattr(self.config, 'GGUF_MAX_NEW_TOKENS', 512)
gguf_temperature = temperature if temperature is not None else getattr(self.config, 'GGUF_TEMPERATURE', 0.7)
gguf_top_p = top_p if top_p is not None else getattr(self.config, 'GGUF_TOP_P', 0.9)
        # Model path (not used when downloading from the Hub)
        gguf_model_path = getattr(self.config, 'GGUF_MODEL_PATH', '.cache/models/llama-3-ko-8b.gguf')

        # System prompt (kept in Korean: "You are an expert analyst of Korean public-sector project proposals.")
        system_prompt = getattr(self.config, 'SYSTEM_PROMPT', '당신은 한국 공공기관 사업제안서 분석 전문가입니다.')
        # Initialize GGUFGenerator
        logger.info("Initializing GGUFGenerator... (base model)")
        logger.info(f"   GPU layers: {gguf_n_gpu_layers}")
        logger.info(f"   context size: {gguf_n_ctx}")
self.generator = GGUFGenerator(
model_path=gguf_model_path,
n_gpu_layers=gguf_n_gpu_layers,
n_ctx=gguf_n_ctx,
n_threads=gguf_n_threads,
config=self.config,
max_new_tokens=gguf_max_new_tokens,
temperature=gguf_temperature,
top_p=gguf_top_p,
system_prompt=system_prompt
)
        # Load the model
        logger.info("Loading base GGUF model...")
        self.generator.load_model()

        # Query router (decides whether retrieval is needed)
        self.router = QueryRouter()

        # Chat history
        self.chat_history: List[Dict] = []

        # Last retrieval results
        self._last_retrieved_docs = []

        logger.info("✅ GGUFBaseRAGPipeline initialized")
        logger.info(f"   - search mode: {self.search_mode}")
        logger.info(f"   - default top_k: {self.top_k}")
    def _retrieve_and_format(self, query: str) -> str:
        """Run retrieval and format the hits into a context string."""
        # Dispatch on the search mode:
        #   embedding        -> dense vector search
        #   embedding_rerank -> dense search + rerank
        #   hybrid           -> dense + keyword search blended by alpha
        #   hybrid_rerank    -> hybrid search + rerank (default)
if self.search_mode == "embedding":
docs = self.retriever.search(query, top_k=self.top_k)
elif self.search_mode == "embedding_rerank":
docs = self.retriever.search_with_rerank(query, top_k=self.top_k)
elif self.search_mode == "hybrid":
docs = self.retriever.hybrid_search(
query, top_k=self.top_k, alpha=self.alpha
)
elif self.search_mode == "hybrid_rerank":
docs = self.retriever.hybrid_search_with_rerank(
query, top_k=self.top_k, alpha=self.alpha
)
        else:
            # Unknown mode: fall back to plain embedding search
            docs = self.retriever.search(query, top_k=self.top_k)

        # Remember the last retrieval results
        self._last_retrieved_docs = docs

        # Format the context
        return self._format_context(docs)
    def _format_context(self, retrieved_docs: list) -> str:
        """Convert retrieved documents into a single context string."""
        if not retrieved_docs:
            # "No relevant documents found." (kept in Korean for the prompt)
            return "관련 문서를 찾을 수 없습니다."
        context_parts = []
        # Rough character budget: a crude proxy for the token budget, which
        # must stay below n_ctx minus the room reserved for generation
        max_context_chars = 8000
        current_length = 0
        for i, doc in enumerate(retrieved_docs, 1):
            # "[문서 N]" = "[Document N]" (label kept in Korean)
            doc_text = f"[문서 {i}]\n{doc['content']}\n"
            doc_length = len(doc_text)
            if current_length + doc_length > max_context_chars:
                logger.warning(f"⚠️ Context length limit reached: using only {i-1} documents")
                break
            context_parts.append(doc_text)
            current_length += doc_length
        return "\n".join(context_parts)
    def _format_sources(self, retrieved_docs: list) -> list:
        """Convert retrieved documents into the sources format."""
sources = []
for doc in retrieved_docs:
source_info = {
'content': doc['content'],
'metadata': doc['metadata'],
'filename': doc.get('filename', 'N/A'),
'organization': doc.get('organization', 'N/A')
}
if 'rerank_score' in doc:
source_info['score'] = doc['rerank_score']
source_info['score_type'] = 'rerank'
elif 'hybrid_score' in doc:
source_info['score'] = doc['hybrid_score']
source_info['score_type'] = 'hybrid'
elif 'relevance_score' in doc:
source_info['score'] = doc['relevance_score']
source_info['score_type'] = 'embedding'
else:
source_info['score'] = 0
source_info['score_type'] = 'unknown'
sources.append(source_info)
return sources
    def _estimate_usage(self, query: str, answer: str) -> dict:
        """Rough token-usage estimate (whitespace word count x 2)."""
        prompt_tokens = len(query.split()) * 2
        completion_tokens = len(answer.split()) * 2
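        # Sketch of a more faithful count once the model is loaded (assumes
        # llama-cpp-python's tokenizer, which takes UTF-8 bytes):
        #   len(self.generator.model.tokenize(query.encode("utf-8")))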
return {
'total_tokens': prompt_tokens + completion_tokens,
'prompt_tokens': prompt_tokens,
'completion_tokens': completion_tokens
}
    def generate_answer(
        self,
        query: str,
        top_k: Optional[int] = None,
        search_mode: Optional[str] = None,
        alpha: Optional[float] = None
    ) -> dict:
        """
        Generate an answer (base model + RAG).

        Returns a dict with 'answer', 'sources', 'used_retrieval',
        'query_type', 'search_mode', 'routing_info', 'elapsed_time'
        and 'usage'.
        """
        try:
            start_time = time.time()

            # Apply per-call overrides
            if top_k is not None:
                self.top_k = top_k
            if search_mode is not None:
                self.search_mode = search_mode
            if alpha is not None:
                self.alpha = alpha

            # Let the router decide whether to retrieve
            classification = self.router.classify(query)
            query_type = classification['type']
            logger.info(f"📍 Classification: {query_type} (confidence: {classification['confidence']:.2f})")
            # Handle by query type
            if query_type in ['greeting', 'thanks', 'out_of_scope']:
                # Skip retrieval
                context = None
                used_retrieval = False
                self._last_retrieved_docs = []
                # Dynamic prompt
                system_prompt = PromptManager.get_prompt(query_type, model_type="gguf")
                logger.info(f"⏭️ RAG skipped: {query_type}")
            else:
                # 'document' and any unrecognized type fall through to RAG,
                # so context and system_prompt are always defined below
                context = self._retrieve_and_format(query)
                used_retrieval = True
                # Dynamic prompt
                system_prompt = PromptManager.get_prompt('document', model_type="gguf")
                logger.info(f"🔍 RAG executed: {len(self._last_retrieved_docs)} documents")
            # Generate the answer
            answer = self.generator.chat(
                question=query,
                context=context,
                system_prompt=system_prompt
            )
            elapsed_time = time.time() - start_time

            # Append to the chat history
            self.chat_history.append({"role": "user", "content": query})
            self.chat_history.append({"role": "assistant", "content": answer})

            # Return the result
return {
'answer': answer,
'sources': self._format_sources(self._last_retrieved_docs),
'used_retrieval': used_retrieval,
'query_type': query_type,
'search_mode': self.search_mode if used_retrieval else 'direct',
'routing_info': classification,
'elapsed_time': elapsed_time,
'usage': self._estimate_usage(query, answer)
}
        except Exception as e:
            logger.error(f"❌ Answer generation failed: {e}")
            import traceback
            traceback.print_exc()
            raise RuntimeError(f"Answer generation failed: {str(e)}") from e
def chat(self, query: str) -> str:
"""๊ฐ„๋‹จํ•œ ๋Œ€ํ™” ์ธํ„ฐํŽ˜์ด์Šค"""
result = self.generate_answer(query)
return result['answer']
    def clear_history(self):
        """Clear the chat history."""
        self.chat_history = []
        logger.info("🗑️ Chat history cleared.")
def get_history(self) -> List[Dict]:
"""๋Œ€ํ™” ํžˆ์Šคํ† ๋ฆฌ ๋ฐ˜ํ™˜"""
return self.chat_history.copy()
def set_search_config(
self,
        search_mode: Optional[str] = None,
        top_k: Optional[int] = None,
        alpha: Optional[float] = None
    ):
        """Update the retrieval settings."""
if search_mode is not None:
self.search_mode = search_mode
if top_k is not None:
self.top_k = top_k
if alpha is not None:
self.alpha = alpha
        logger.info(
            f"🔧 Search config updated: mode={self.search_mode}, "
            f"top_k={self.top_k}, alpha={self.alpha}"
        )
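
# A minimal end-to-end sketch (illustrative: assumes a default-constructible
# RAGConfig, a populated retrieval index, and network access for the Hub
# download; the query below is an example, not part of the pipeline):
if __name__ == "__main__":
    pipeline = GGUFBaseRAGPipeline()
    result = pipeline.generate_answer("최근 공고된 RFP의 제출 마감일을 알려줘.")
    print(result['answer'])
    for src in result['sources'][:3]:
        print(f"- {src['filename']} (score={src['score']:.3f}, {src['score_type']})")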