Spaces:
Sleeping
Sleeping
"""
Three-model comparison experiment.

Models compared:
1. QLoRA + RAG (existing service)
2. QLoRA only (RAG removed)
3. Base + RAG (PEFT removed)

Measured metrics:
- Overfitting (In-Distribution vs Out-Distribution)
- Answer latency (elapsed_time, retrieval_time, generation_time)
- Token counts (total_tokens, prompt_tokens, completion_tokens)
"""
import os
import sys
import time
import json
import logging
from typing import Dict, List, Any
from datetime import datetime
from pathlib import Path

# Add the project root to sys.path so `src.*` imports resolve when this
# script is run directly.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from src.utils.config import RAGConfig
from src.eval_dataset import EvalDataset

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class ModelComparison:
    """Model comparison experiment.

    Loads three generation pipelines (QLoRA+RAG, QLoRA only, Base+RAG),
    runs each evaluation query through every loaded pipeline, and writes
    per-query results (JSON) plus aggregate statistics (text) to the
    output directory.
    """

    def __init__(self, config=None, output_dir: str = "./results"):
        """Initialize the experiment.

        Args:
            config: RAG configuration object; a new RAGConfig is created
                when None.
            output_dir: Directory where result/summary files are written
                (created if missing).
        """
        self.config = config or RAGConfig()
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Run timestamp; used to name every output file of this run.
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Evaluation queries (in-distribution / out-of-distribution).
        self.dataset = EvalDataset()

        # model name -> pipeline instance, populated by load_models().
        self.pipelines = {}

        logger.info("✅ ModelComparison 초기화 완료")
        logger.info(f"   결과 저장 경로: {self.output_dir}")

    def load_models(self):
        """Load the three pipelines to compare.

        Raises:
            Exception: re-raised (after logging) if any pipeline fails
                to load.
        """
        logger.info("\n" + "="*60)
        # NOTE: previous version said "2 models" and carried a stale
        # "Base + RAG skipped" warning although all three were loaded.
        logger.info("모델 로딩 시작 (3개 모델)")
        logger.info("="*60)
        try:
            # 1. QLoRA + RAG (the existing service configuration)
            logger.info("\n[1/3] QLoRA + RAG 모델 로딩...")
            from src.generator.generator_gguf import GGUFRAGPipeline
            self.pipelines['qlora_rag'] = GGUFRAGPipeline(config=self.config)
            logger.info("✅ QLoRA + RAG 로드 완료")

            # 2. QLoRA only (RAG removed)
            logger.info("\n[2/3] QLoRA 단독 모델 로딩...")
            from src.generator.generator_gguf_no_rag import GGUFNoRAGPipeline
            self.pipelines['qlora_only'] = GGUFNoRAGPipeline(config=self.config)
            logger.info("✅ QLoRA 단독 로드 완료")

            # 3. Base + RAG (PEFT removed)
            logger.info("\n[3/3] Base + RAG 모델 로딩...")
            from src.generator.generator_gguf_base import GGUFBaseRAGPipeline
            self.pipelines['base_rag'] = GGUFBaseRAGPipeline(config=self.config)
            logger.info("✅ Base + RAG 로드 완료")
        except Exception as e:
            logger.error(f"❌ 모델 로드 실패: {e}")
            import traceback
            traceback.print_exc()
            raise

    def run_single_query(
        self,
        model_name: str,
        query: str,
        query_info: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Run one query through one model and normalize the outcome.

        Args:
            model_name: Key into self.pipelines.
            query: The question text.
            query_info: Query metadata (category, expected_type).

        Returns:
            A flat result record; on failure 'success' is False and
            'error' carries the exception message.
        """
        pipeline = self.pipelines[model_name]
        try:
            start_time = time.time()
            result = pipeline.generate_answer(query)
            total_time = time.time() - start_time

            # Flatten the pipeline's answer into a uniform record.
            return {
                'model': model_name,
                'query': query,
                'category': query_info.get('category', 'unknown'),
                'expected_type': query_info.get('expected_type', 'unknown'),
                'answer': result['answer'],
                'used_retrieval': result.get('used_retrieval', False),
                'query_type': result.get('query_type', 'unknown'),
                'search_mode': result.get('search_mode', 'none'),
                'elapsed_time': total_time,
                'model_elapsed_time': result.get('elapsed_time', 0),
                'usage': result.get('usage', {}),
                'sources_count': len(result.get('sources', [])),
                'success': True,
                'error': None
            }
        except Exception as e:
            logger.error(f"❌ 질문 실행 실패 [{model_name}]: {e}")
            # Same schema on failure so downstream aggregation never KeyErrors.
            return {
                'model': model_name,
                'query': query,
                'category': query_info.get('category', 'unknown'),
                'expected_type': query_info.get('expected_type', 'unknown'),
                'answer': None,
                'used_retrieval': False,
                'query_type': 'error',
                'search_mode': 'none',
                'elapsed_time': 0,
                'model_elapsed_time': 0,
                'usage': {},
                'sources_count': 0,
                'success': False,
                'error': str(e)
            }

    def run_experiment(
        self,
        distribution: str = 'all',
        save_results: bool = True
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Run every selected query through every loaded pipeline.

        Args:
            distribution: Which query set to run: 'in', 'out', or 'all'.
            save_results: When True, write JSON results and a text summary.

        Returns:
            Dict with 'metadata' and per-distribution 'results'.
        """
        logger.info("\n" + "="*60)
        logger.info("실험 시작")
        logger.info("="*60)

        # Select the query set.
        if distribution == 'in':
            queries_dict = {'in_distribution': self.dataset.get_in_distribution()}
        elif distribution == 'out':
            queries_dict = {'out_distribution': self.dataset.get_out_distribution()}
        else:  # 'all'
            queries_dict = self.dataset.get_all_queries()

        all_results = {
            'metadata': {
                'timestamp': self.timestamp,
                'distribution': distribution,
                'models': list(self.pipelines.keys()),
                'total_queries': sum(len(v) for v in queries_dict.values())
            },
            'results': {}
        }

        # Run each distribution.
        for dist_type, queries in queries_dict.items():
            logger.info(f"\n{'='*60}")
            logger.info(f"{dist_type.upper()} 실험 ({len(queries)}개 질문)")
            logger.info(f"{'='*60}")
            dist_results = []

            # Each query ...
            for i, query_info in enumerate(queries, 1):
                query = query_info['query']
                logger.info(f"\n[{i}/{len(queries)}] 질문: {query}")
                # ... against each model.
                for model_name in self.pipelines.keys():
                    logger.info(f"  → {model_name} 실행 중...")
                    result = self.run_single_query(model_name, query, query_info)
                    dist_results.append(result)
                    if result['success']:
                        logger.info(f"  ✅ 완료 ({result['elapsed_time']:.2f}초)")
                    else:
                        logger.warning(f"  ❌ 실패: {result['error']}")

            all_results['results'][dist_type] = dist_results

        if save_results:
            self._save_results(all_results)

        logger.info("\n" + "="*60)
        logger.info("✅ 실험 완료")
        logger.info("="*60 + "\n")
        return all_results

    def _save_results(self, results: Dict[str, Any]):
        """Write full results as JSON plus a human-readable summary file."""
        output_file = self.output_dir / f"results_{self.timestamp}.json"
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        logger.info(f"📁 결과 저장: {output_file}")

        summary_file = self.output_dir / f"summary_{self.timestamp}.txt"
        self._save_summary(results, summary_file)
        logger.info(f"📊 요약 저장: {summary_file}")

    def _save_summary(self, results: Dict[str, Any], output_file: Path):
        """Write aggregate per-model statistics for each distribution."""
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write("="*60 + "\n")
            f.write("실험 결과 요약\n")
            f.write("="*60 + "\n\n")

            # Experiment metadata.
            metadata = results['metadata']
            f.write(f"타임스탬프: {metadata['timestamp']}\n")
            f.write(f"분포: {metadata['distribution']}\n")
            f.write(f"모델: {', '.join(metadata['models'])}\n")
            f.write(f"총 질문 수: {metadata['total_queries']}\n\n")

            # Per-distribution statistics.
            for dist_type, dist_results in results['results'].items():
                f.write(f"\n{'='*60}\n")
                f.write(f"{dist_type.upper()} 결과\n")
                f.write(f"{'='*60}\n\n")

                # Group records by model.
                model_stats = {}
                for result in dist_results:
                    model_stats.setdefault(result['model'], []).append(result)

                # Per-model aggregates.
                for model, model_results in model_stats.items():
                    f.write(f"\n[{model}]\n")

                    # Success / failure counts.
                    success_count = sum(1 for r in model_results if r['success'])
                    f.write(f"  성공: {success_count}/{len(model_results)}\n")

                    # Average wall-clock time over successful runs only;
                    # max(..., 1) guards against division by zero when all failed.
                    avg_time = sum(r['elapsed_time'] for r in model_results if r['success']) / max(success_count, 1)
                    f.write(f"  평균 시간: {avg_time:.3f}초\n")

                    # Average token usage over successful runs.
                    total_tokens = sum(r['usage'].get('total_tokens', 0) for r in model_results if r['success'])
                    avg_tokens = total_tokens / max(success_count, 1)
                    f.write(f"  평균 토큰: {avg_tokens:.1f}\n")

                    # Retrieval usage rate across all runs for this model.
                    rag_count = sum(1 for r in model_results if r['used_retrieval'])
                    f.write(f"  RAG 사용: {rag_count}/{len(model_results)} ({rag_count/len(model_results)*100:.1f}%)\n")
def main():
    """Entry point: build the experiment, load models, run all queries."""
    logger.info("="*60)
    logger.info("RFPilot 모델 비교 실험")
    logger.info("="*60)

    # Load configuration.
    config = RAGConfig()

    # Set up the experiment.
    experiment = ModelComparison(config=config, output_dir="./experiments/results")

    # Show what will be evaluated before running anything.
    experiment.dataset.print_summary()
    experiment.dataset.print_samples(n=3)

    # Load the pipelines.
    experiment.load_models()

    # Option 1: run everything.
    results = experiment.run_experiment(distribution='all', save_results=True)
    # Option 2: in-distribution only
    # results = experiment.run_experiment(distribution='in', save_results=True)
    # Option 3: out-of-distribution only
    # results = experiment.run_experiment(distribution='out', save_results=True)

    logger.info("\n✅ 모든 실험 완료!")
    logger.info(f"  결과 저장 위치: {experiment.output_dir}")
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Graceful message on Ctrl+C instead of a raw traceback.
        logger.info("\n⚠️ 사용자에 의해 중단됨")
    except Exception as e:
        # Top-level boundary: log the failure and print the full traceback.
        logger.error(f"\n❌ 실행 실패: {e}")
        import traceback
        traceback.print_exc()