from llama_cpp import Llama
from typing import Optional, Dict, Any, List
import logging
import time
import os

from src.utils.config import RAGConfig
from src.router.query_router import QueryRouter
from src.prompts.dynamic_prompts import PromptManager

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class GGUFGenerator:
    """
    GGUF-based Llama-3 generator.

    Loads a GGUF-format model via llama.cpp and answers
    tender/RFP-related questions.
    """

    def __init__(
        self,
        model_path: str,
        n_gpu_layers: int = 0,
        n_ctx: int = 8192,
        n_threads: int = 8,
        config=None,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        system_prompt: str = "당신은 RFP(제안요청서) 분석 및 요약 전문가입니다."  # "You are an expert in RFP analysis and summarization."
    ):
        """Initialize the generator."""
        self.config = config or RAGConfig()
        self.model_path = model_path
        self.n_gpu_layers = n_gpu_layers
        self.n_ctx = n_ctx
        self.n_threads = n_threads
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.system_prompt = system_prompt

        # Model is loaded lazily via load_model()
        self.model = None

        logger.info("GGUFGenerator initialized (base model)")
    def load_model(self) -> None:
        """
        Load the GGUF model.

        Uses the base model: the repo is taken from BASE_MODEL_HUB_REPO in the config.
        """
        # Avoid loading the model twice
        if self.model is not None:
            logger.info("Model is already loaded.")
            return

        try:
            # Check USE_MODEL_HUB in the config
            use_model_hub = getattr(self.config, 'USE_MODEL_HUB', True)

            # Resolve the model path depending on whether the Hub is used
            if use_model_hub:
                # === Download from the Hugging Face Hub ===
                # Base model info comes from the config
                base_model_repo = getattr(
                    self.config,
                    'BASE_MODEL_HUB_REPO',
                    'Dongjin1203/Llama-3-Open-Ko-8B-GGUF'
                )
                base_model_filename = getattr(
                    self.config,
                    'BASE_MODEL_HUB_FILENAME',
                    'Llama-3-Open-Ko-8B-Q4_K_M.gguf'
                )
                model_cache_dir = getattr(self.config, 'MODEL_CACHE_DIR', '.cache/models')

                logger.info(f"Downloading base model: {base_model_repo}")
                logger.info(f"  Filename: {base_model_filename}")

                from huggingface_hub import hf_hub_download

                model_path = hf_hub_download(
                    repo_id=base_model_repo,
                    filename=base_model_filename,
                    cache_dir=model_cache_dir,
                    local_dir=model_cache_dir,
                    local_dir_use_symlinks=False
                )
                logger.info(f"Base model downloaded: {model_path}")
            else:
                # === Use a local file ===
                model_path = self.model_path
                if not os.path.exists(model_path):
                    raise FileNotFoundError(
                        f"Local model file not found: {model_path}\n"
                        f"Set USE_MODEL_HUB=true or provide the model file."
                    )
                logger.info(f"Using local base model: {model_path}")

            # === Common: load the model ===
            logger.info("Loading base GGUF model...")
            logger.info(f"  GPU layers: {self.n_gpu_layers}")
            logger.info(f"  Context size: {self.n_ctx}")

            self.model = Llama(
                model_path=model_path,
                n_gpu_layers=self.n_gpu_layers,
                n_ctx=self.n_ctx,
                n_threads=self.n_threads,
                verbose=True,
            )

            # Check the context size actually applied by llama.cpp
            actual_n_ctx = self.model.n_ctx()

            logger.info("Base GGUF model loaded!")
            logger.info(f"  - Model: {base_model_repo if use_model_hub else 'local'}")
            logger.info(f"  - Requested n_ctx: {self.n_ctx}")
            logger.info(f"  - Actual n_ctx: {actual_n_ctx}")

            if actual_n_ctx < self.n_ctx:
                logger.warning(f"n_ctx is smaller than requested: {actual_n_ctx} < {self.n_ctx}")

        except FileNotFoundError as e:
            logger.error(f"Model file not found: {e}")
            raise
        except Exception as e:
            logger.error(f"Model load failed: {e}")
            raise RuntimeError(f"Error while loading model: {e}")
    def format_prompt(
        self,
        question: str,
        context: Optional[str] = None,
        system_prompt: Optional[str] = None
    ) -> str:
        """Simple prompt formatting for the GGUF model."""
        if system_prompt is None:
            system_prompt = self.system_prompt

        if context is not None:
            # "참고 문서" = reference documents, "질문" = question
            user_message = f"참고 문서:\n{context}\n\n질문: {question}"
        else:
            user_message = question

        # Section headers: 시스템 = system, 사용자 = user, 답변 = answer
        formatted_prompt = f"""### 시스템
{system_prompt}
### 사용자
{user_message}
### 답변
"""
        return formatted_prompt
    def generate(
        self,
        prompt: str,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> str:
        """Generate a response for a fully formatted prompt."""
        if self.model is None:
            raise RuntimeError(
                "Model is not loaded. Call load_model() first."
            )

        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p

        try:
            logger.info(f"Generating (max_tokens={max_new_tokens}, temp={temperature})")
            start_time = time.time()

            output = self.model(
                prompt,
                max_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                echo=False,
                stop=[
                    "###", "\n\n###",
                    "### 사용자", "\n사용자:",              # user-turn headers
                    "</s>",
                    "한국어 답변", "한국어로 답변", "지침:",  # echoed instruction fragments
                    "문장", "(문장",
                    "\n\n",
                    "?",
                    "요?", "까?", "나요?", "습니까?"          # Korean interrogative endings
                ],
            )

            elapsed = time.time() - start_time
            logger.info(f"Generation finished in {elapsed:.2f}s")

            response = output['choices'][0]['text'].strip()
            logger.info(f"Response length: {len(response)} characters")
            return response

        except Exception as e:
            logger.error(f"Error during generation: {e}")
            raise RuntimeError(f"Text generation failed: {e}")
    def chat(
        self,
        question: str,
        context: Optional[str] = None,
        system_prompt: Optional[str] = None,
        **kwargs
    ) -> str:
        """Generate a response to a question, optionally grounded in retrieved context."""
        prompt = self.format_prompt(
            question=question,
            context=context,
            system_prompt=system_prompt
        )
        response = self.generate(prompt, **kwargs)
        return response
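
# Illustrative sketch (not part of the original module): how GGUFGenerator can be
# used on its own, without the RAG pipeline. The local path below is hypothetical
# and assumes a GGUF file is already present on disk.
def _example_generator_usage() -> str:
    """Load the base model with CPU-only settings and ask a single question."""
    generator = GGUFGenerator(
        model_path=".cache/models/Llama-3-Open-Ko-8B-Q4_K_M.gguf",  # hypothetical local path
        n_gpu_layers=0,
        n_ctx=4096,
    )
    generator.load_model()
    # "Summarize the evaluation criteria of this RFP."
    return generator.chat("이 RFP의 평가 기준을 요약해 주세요.")
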
class GGUFBaseRAGPipeline:
    """
    Base model + RAG pipeline.

    - Uses the base model (beomi/Llama-3-Open-Ko-8B)
    - Keeps retrieval (RAG)
    - Provides the same functionality as the original generator_gguf.py
    """

    def __init__(
        self,
        config=None,
        model: str = None,  # accepted but not used in this implementation
        top_k: int = None,
        n_gpu_layers: int = None,
        n_ctx: int = None,
        n_threads: int = None,
        max_new_tokens: int = None,
        temperature: float = None,
        top_p: float = None,
        search_mode: str = None,
        alpha: float = None
    ):
        """Initialize the pipeline."""
        self.config = config or RAGConfig()

        # Retrieval settings
        self.top_k = top_k or getattr(self.config, 'DEFAULT_TOP_K', 10)
        self.search_mode = search_mode or getattr(self.config, 'DEFAULT_SEARCH_MODE', 'hybrid_rerank')
        self.alpha = alpha if alpha is not None else getattr(self.config, 'DEFAULT_ALPHA', 0.5)

        # Initialize the retriever
        logger.info("Initializing RAGRetriever...")
        from src.retriever.retriever import RAGRetriever
        self.retriever = RAGRetriever(config=self.config)

        # GGUF settings
        gguf_n_gpu_layers = n_gpu_layers if n_gpu_layers is not None else getattr(self.config, 'GGUF_N_GPU_LAYERS', 35)
        gguf_n_ctx = n_ctx if n_ctx is not None else getattr(self.config, 'GGUF_N_CTX', 2048)
        gguf_n_threads = n_threads if n_threads is not None else getattr(self.config, 'GGUF_N_THREADS', 4)
        gguf_max_new_tokens = max_new_tokens if max_new_tokens is not None else getattr(self.config, 'GGUF_MAX_NEW_TOKENS', 512)
        gguf_temperature = temperature if temperature is not None else getattr(self.config, 'GGUF_TEMPERATURE', 0.7)
        gguf_top_p = top_p if top_p is not None else getattr(self.config, 'GGUF_TOP_P', 0.9)

        # Model path (not used when the model is downloaded from the Hub)
        gguf_model_path = getattr(self.config, 'GGUF_MODEL_PATH', '.cache/models/llama-3-ko-8b.gguf')

        # System prompt ("You are an expert in analyzing Korean public-sector project proposals.")
        system_prompt = getattr(self.config, 'SYSTEM_PROMPT', '당신은 한국 공공기관 사업 제안서 분석 전문가입니다.')

        # Initialize GGUFGenerator
        logger.info("Initializing GGUFGenerator... (base model)")
        logger.info(f"  GPU layers: {gguf_n_gpu_layers}")
        logger.info(f"  Context size: {gguf_n_ctx}")

        self.generator = GGUFGenerator(
            model_path=gguf_model_path,
            n_gpu_layers=gguf_n_gpu_layers,
            n_ctx=gguf_n_ctx,
            n_threads=gguf_n_threads,
            config=self.config,
            max_new_tokens=gguf_max_new_tokens,
            temperature=gguf_temperature,
            top_p=gguf_top_p,
            system_prompt=system_prompt
        )

        # Load the model
        logger.info("Loading base GGUF model...")
        self.generator.load_model()

        # Query router
        self.router = QueryRouter()

        # Chat history
        self.chat_history: List[Dict] = []

        # Last retrieval results
        self._last_retrieved_docs = []

        logger.info("GGUFBaseRAGPipeline initialized")
        logger.info(f"  - Search mode: {self.search_mode}")
        logger.info(f"  - Default top_k: {self.top_k}")
    def _retrieve_and_format(self, query: str) -> str:
        """Run retrieval and format the retrieved documents as context."""
        # Retrieve documents according to the search mode
        if self.search_mode == "embedding":
            docs = self.retriever.search(query, top_k=self.top_k)
        elif self.search_mode == "embedding_rerank":
            docs = self.retriever.search_with_rerank(query, top_k=self.top_k)
        elif self.search_mode == "hybrid":
            docs = self.retriever.hybrid_search(
                query, top_k=self.top_k, alpha=self.alpha
            )
        elif self.search_mode == "hybrid_rerank":
            docs = self.retriever.hybrid_search_with_rerank(
                query, top_k=self.top_k, alpha=self.alpha
            )
        else:
            docs = self.retriever.search(query, top_k=self.top_k)

        # Remember the last retrieval results
        self._last_retrieved_docs = docs

        # Format the context
        return self._format_context(docs)
    def _format_context(self, retrieved_docs: list) -> str:
        """Convert retrieved documents into a context string."""
        if not retrieved_docs:
            return "관련 문서를 찾을 수 없습니다."  # "No relevant documents were found."

        context_parts = []
        max_context_chars = 8000
        current_length = 0

        for i, doc in enumerate(retrieved_docs, 1):
            doc_text = f"[문서 {i}]\n{doc['content']}\n"  # "[문서 N]" = "[Document N]"
            doc_length = len(doc_text)
            if current_length + doc_length > max_context_chars:
                logger.warning(f"Context length limit reached: using only {i-1} documents")
                break
            context_parts.append(doc_text)
            current_length += doc_length

        return "\n".join(context_parts)
    def _format_sources(self, retrieved_docs: list) -> list:
        """Convert retrieved documents into the sources format."""
        sources = []
        for doc in retrieved_docs:
            source_info = {
                'content': doc['content'],
                'metadata': doc['metadata'],
                'filename': doc.get('filename', 'N/A'),
                'organization': doc.get('organization', 'N/A')
            }
            # Pick the most specific score available
            if 'rerank_score' in doc:
                source_info['score'] = doc['rerank_score']
                source_info['score_type'] = 'rerank'
            elif 'hybrid_score' in doc:
                source_info['score'] = doc['hybrid_score']
                source_info['score_type'] = 'hybrid'
            elif 'relevance_score' in doc:
                source_info['score'] = doc['relevance_score']
                source_info['score_type'] = 'embedding'
            else:
                source_info['score'] = 0
                source_info['score_type'] = 'unknown'
            sources.append(source_info)
        return sources
    def _estimate_usage(self, query: str, answer: str) -> dict:
        """Roughly estimate token usage (word count times two)."""
        prompt_tokens = len(query.split()) * 2
        completion_tokens = len(answer.split()) * 2
        return {
            'total_tokens': prompt_tokens + completion_tokens,
            'prompt_tokens': prompt_tokens,
            'completion_tokens': completion_tokens
        }
    def generate_answer(
        self,
        query: str,
        top_k: int = None,
        search_mode: str = None,
        alpha: float = None
    ) -> dict:
        """Generate an answer (base model + RAG)."""
        try:
            start_time = time.time()

            # Apply per-call parameter overrides
            if top_k is not None:
                self.top_k = top_k
            if search_mode is not None:
                self.search_mode = search_mode
            if alpha is not None:
                self.alpha = alpha

            # Decide with the router whether retrieval is needed
            classification = self.router.classify(query)
            query_type = classification['type']
            logger.info(f"Query classified as {query_type} (confidence: {classification['confidence']:.2f})")

            # Handle each query type
            if query_type in ['greeting', 'thanks', 'out_of_scope']:
                # Skip retrieval
                context = None
                used_retrieval = False
                self._last_retrieved_docs = []
                # Dynamic prompt
                system_prompt = PromptManager.get_prompt(query_type, model_type="gguf")
                logger.info(f"Skipping RAG: {query_type}")
            else:
                # Run RAG ('document' and any unrecognized type fall back to retrieval)
                context = self._retrieve_and_format(query)
                used_retrieval = True
                # Dynamic prompt
                system_prompt = PromptManager.get_prompt('document', model_type="gguf")
                logger.info(f"RAG retrieved {len(self._last_retrieved_docs)} documents")

            # Generate the answer
            answer = self.generator.chat(
                question=query,
                context=context,
                system_prompt=system_prompt
            )

            elapsed_time = time.time() - start_time

            # Append to chat history
            self.chat_history.append({"role": "user", "content": query})
            self.chat_history.append({"role": "assistant", "content": answer})

            # Return the result
            return {
                'answer': answer,
                'sources': self._format_sources(self._last_retrieved_docs),
                'used_retrieval': used_retrieval,
                'query_type': query_type,
                'search_mode': self.search_mode if used_retrieval else 'direct',
                'routing_info': classification,
                'elapsed_time': elapsed_time,
                'usage': self._estimate_usage(query, answer)
            }

        except Exception as e:
            logger.error(f"Answer generation failed: {e}")
            import traceback
            traceback.print_exc()
            raise RuntimeError(f"Answer generation failed: {str(e)}") from e
    def chat(self, query: str) -> str:
        """Simple chat interface."""
        result = self.generate_answer(query)
        return result['answer']

    def clear_history(self):
        """Reset the chat history."""
        self.chat_history = []
        logger.info("Chat history cleared.")

    def get_history(self) -> List[Dict]:
        """Return a copy of the chat history."""
        return self.chat_history.copy()

    def set_search_config(
        self,
        search_mode: str = None,
        top_k: int = None,
        alpha: float = None
    ):
        """Update retrieval settings."""
        if search_mode is not None:
            self.search_mode = search_mode
        if top_k is not None:
            self.top_k = top_k
        if alpha is not None:
            self.alpha = alpha
        logger.info(
            f"Search settings updated: mode={self.search_mode}, "
            f"top_k={self.top_k}, alpha={self.alpha}"
        )
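
# Hedged usage sketch (illustrative, not part of the original module): a minimal
# end-to-end smoke test of the pipeline. It assumes a valid RAGConfig and an
# already-populated vector store; the example question is hypothetical.
if __name__ == "__main__":
    pipeline = GGUFBaseRAGPipeline()
    # "When is the proposal submission deadline for recently announced projects?"
    result = pipeline.generate_answer("최근 공고된 사업의 제안서 제출 마감일은 언제인가요?")
    print(result['answer'])
    for source in result['sources'][:3]:
        print(f"- {source['filename']} ({source['score_type']}: {source['score']})")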