"""
3๊ฐ€์ง€ ๋ชจ๋ธ ๋น„๊ต ์‹คํ—˜
๋น„๊ต ๋Œ€์ƒ:
1. QLoRA + RAG (๊ธฐ์กด ์„œ๋น„์Šค)
2. QLoRA ๋‹จ๋… (RAG ์ œ๊ฑฐ)
3. Base + RAG (PEFT ์ œ๊ฑฐ)
์ธก์ • ์ง€ํ‘œ:
- ๊ณผ์ ํ•ฉ ์—ฌ๋ถ€ (In-Distribution vs Out-Distribution)
- ๋‹ต๋ณ€ ์†๋„ (elapsed_time, retrieval_time, generation_time)
- ํ† ํฐ ๊ฐœ์ˆ˜ (total_tokens, prompt_tokens, completion_tokens)
"""
import os
import sys
import time
import json
import logging
from typing import Dict, List, Any
from datetime import datetime
from pathlib import Path
# Add the project root to sys.path so `src` imports resolve
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from src.utils.config import RAGConfig
from src.eval_dataset import EvalDataset
# Logging setup
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class ModelComparison:
"""๋ชจ๋ธ ๋น„๊ต ์‹คํ—˜ ํด๋ž˜์Šค"""
def __init__(self, config=None, output_dir: str = "./results"):
"""์ดˆ๊ธฐํ™”"""
self.config = config or RAGConfig()
self.output_dir = Path(output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
# ํƒ€์ž„์Šคํƒฌํ”„
self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# ๋ฐ์ดํ„ฐ์…‹
self.dataset = EvalDataset()
# ๋ชจ๋ธ ํŒŒ์ดํ”„๋ผ์ธ
self.pipelines = {}
logger.info(f"โœ… ModelComparison ์ดˆ๊ธฐํ™” ์™„๋ฃŒ")
logger.info(f" ๊ฒฐ๊ณผ ์ €์žฅ ๊ฒฝ๋กœ: {self.output_dir}")
def load_models(self):
"""2๊ฐ€์ง€ ๋ชจ๋ธ ๋กœ๋“œ (Base๋Š” ์ถ”ํ›„ GGUF ๋ณ€ํ™˜ ํ›„ ์ถ”๊ฐ€ ์˜ˆ์ •)"""
logger.info("\n" + "="*60)
logger.info("๋ชจ๋ธ ๋กœ๋”ฉ ์‹œ์ž‘ (2๊ฐœ ๋ชจ๋ธ)")
logger.info("="*60)
try:
            # 1. QLoRA + RAG (existing service)
            logger.info("\n[1/3] Loading QLoRA + RAG model...")
            from src.generator.generator_gguf import GGUFRAGPipeline
            self.pipelines['qlora_rag'] = GGUFRAGPipeline(config=self.config)
            logger.info("✅ QLoRA + RAG loaded")

            # 2. QLoRA only (RAG removed)
            logger.info("\n[2/3] Loading QLoRA-only model...")
            from src.generator.generator_gguf_no_rag import GGUFNoRAGPipeline
            self.pipelines['qlora_only'] = GGUFNoRAGPipeline(config=self.config)
            logger.info("✅ QLoRA-only loaded")

            # 3. Base + RAG (PEFT removed)
            logger.info("\n[3/3] Loading Base + RAG model...")
            from src.generator.generator_gguf_base import GGUFBaseRAGPipeline
            self.pipelines['base_rag'] = GGUFBaseRAGPipeline(config=self.config)
            logger.info("✅ Base + RAG loaded")
except Exception as e:
logger.error(f"โŒ ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ: {e}")
import traceback
traceback.print_exc()
raise
def run_single_query(
self,
model_name: str,
query: str,
query_info: Dict[str, Any]
) -> Dict[str, Any]:
"""๋‹จ์ผ ์งˆ๋ฌธ์— ๋Œ€ํ•œ ๋ชจ๋ธ ์‹คํ–‰"""
pipeline = self.pipelines[model_name]
try:
start_time = time.time()
result = pipeline.generate_answer(query)
total_time = time.time() - start_time
            # Collect and normalize the result fields
return {
'model': model_name,
'query': query,
'category': query_info.get('category', 'unknown'),
'expected_type': query_info.get('expected_type', 'unknown'),
'answer': result['answer'],
'used_retrieval': result.get('used_retrieval', False),
'query_type': result.get('query_type', 'unknown'),
'search_mode': result.get('search_mode', 'none'),
'elapsed_time': total_time,
'model_elapsed_time': result.get('elapsed_time', 0),
'usage': result.get('usage', {}),
'sources_count': len(result.get('sources', [])),
'success': True,
'error': None
}
except Exception as e:
logger.error(f"โŒ ์งˆ๋ฌธ ์‹คํ–‰ ์‹คํŒจ [{model_name}]: {e}")
return {
'model': model_name,
'query': query,
'category': query_info.get('category', 'unknown'),
'expected_type': query_info.get('expected_type', 'unknown'),
'answer': None,
'used_retrieval': False,
'query_type': 'error',
'search_mode': 'none',
'elapsed_time': 0,
'model_elapsed_time': 0,
'usage': {},
'sources_count': 0,
'success': False,
'error': str(e)
}
def run_experiment(
self,
distribution: str = 'all',
save_results: bool = True
) -> Dict[str, List[Dict[str, Any]]]:
"""
์‹คํ—˜ ์‹คํ–‰
Args:
distribution: 'in', 'out', 'all'
save_results: ๊ฒฐ๊ณผ ์ €์žฅ ์—ฌ๋ถ€
"""
logger.info("\n" + "="*60)
logger.info("์‹คํ—˜ ์‹œ์ž‘")
logger.info("="*60)
        # Prepare the query set
if distribution == 'in':
queries_dict = {'in_distribution': self.dataset.get_in_distribution()}
elif distribution == 'out':
queries_dict = {'out_distribution': self.dataset.get_out_distribution()}
else: # 'all'
queries_dict = self.dataset.get_all_queries()
        # Results container
all_results = {
'metadata': {
'timestamp': self.timestamp,
'distribution': distribution,
'models': list(self.pipelines.keys()),
'total_queries': sum(len(v) for v in queries_dict.values())
},
'results': {}
}
        # Run each distribution
for dist_type, queries in queries_dict.items():
logger.info(f"\n{'='*60}")
logger.info(f"{dist_type.upper()} ์‹คํ—˜ ({len(queries)}๊ฐœ ์งˆ๋ฌธ)")
logger.info(f"{'='*60}")
dist_results = []
            # For each query
for i, query_info in enumerate(queries, 1):
query = query_info['query']
logger.info(f"\n[{i}/{len(queries)}] ์งˆ๋ฌธ: {query}")
                # For each model
for model_name in self.pipelines.keys():
logger.info(f" โ†’ {model_name} ์‹คํ–‰ ์ค‘...")
result = self.run_single_query(model_name, query, query_info)
dist_results.append(result)
                    if result['success']:
                        logger.info(f"    ✅ Done ({result['elapsed_time']:.2f}s)")
                    else:
                        logger.warning(f"    ❌ Failed: {result['error']}")
all_results['results'][dist_type] = dist_results
        # Save results
if save_results:
self._save_results(all_results)
logger.info("\n" + "="*60)
logger.info("โœ… ์‹คํ—˜ ์™„๋ฃŒ")
logger.info("="*60 + "\n")
return all_results
def _save_results(self, results: Dict[str, Any]):
"""๊ฒฐ๊ณผ ์ €์žฅ"""
        # Save full results as JSON
output_file = self.output_dir / f"results_{self.timestamp}.json"
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(results, f, ensure_ascii=False, indent=2)
logger.info(f"๐Ÿ“ ๊ฒฐ๊ณผ ์ €์žฅ: {output_file}")
# ์š”์•ฝ ํ†ต๊ณ„ ์ €์žฅ
summary_file = self.output_dir / f"summary_{self.timestamp}.txt"
self._save_summary(results, summary_file)
logger.info(f"๐Ÿ“Š ์š”์•ฝ ์ €์žฅ: {summary_file}")
def _save_summary(self, results: Dict[str, Any], output_file: Path):
"""์š”์•ฝ ํ†ต๊ณ„ ์ €์žฅ"""
with open(output_file, 'w', encoding='utf-8') as f:
f.write("="*60 + "\n")
f.write("์‹คํ—˜ ๊ฒฐ๊ณผ ์š”์•ฝ\n")
f.write("="*60 + "\n\n")
            # Metadata
metadata = results['metadata']
f.write(f"ํƒ€์ž„์Šคํƒฌํ”„: {metadata['timestamp']}\n")
f.write(f"๋ถ„ํฌ: {metadata['distribution']}\n")
f.write(f"๋ชจ๋ธ: {', '.join(metadata['models'])}\n")
f.write(f"์ด ์งˆ๋ฌธ ์ˆ˜: {metadata['total_queries']}\n\n")
            # Per-distribution statistics
for dist_type, dist_results in results['results'].items():
f.write(f"\n{'='*60}\n")
f.write(f"{dist_type.upper()} ๊ฒฐ๊ณผ\n")
f.write(f"{'='*60}\n\n")
                # Group results by model
model_stats = {}
for result in dist_results:
model = result['model']
if model not in model_stats:
model_stats[model] = []
model_stats[model].append(result)
                # Per-model statistics
for model, model_results in model_stats.items():
f.write(f"\n[{model}]\n")
                    # Success/failure counts
                    success_count = sum(1 for r in model_results if r['success'])
                    f.write(f"  Success: {success_count}/{len(model_results)}\n")

                    # Average latency over successful runs
                    avg_time = sum(r['elapsed_time'] for r in model_results if r['success']) / max(success_count, 1)
                    f.write(f"  Avg. time: {avg_time:.3f}s\n")

                    # Average token usage over successful runs
                    total_tokens = sum(r['usage'].get('total_tokens', 0) for r in model_results if r['success'])
                    avg_tokens = total_tokens / max(success_count, 1)
                    f.write(f"  Avg. tokens: {avg_tokens:.1f}\n")

                    # RAG usage rate
                    rag_count = sum(1 for r in model_results if r['used_retrieval'])
                    f.write(f"  RAG usage: {rag_count}/{len(model_results)} ({rag_count/len(model_results)*100:.1f}%)\n")
def main():
"""๋ฉ”์ธ ํ•จ์ˆ˜"""
logger.info("="*60)
logger.info("RFPilot ๋ชจ๋ธ ๋น„๊ต ์‹คํ—˜")
logger.info("="*60)
    # Load config
    config = RAGConfig()

    # Initialize the experiment
    experiment = ModelComparison(config=config, output_dir="./experiments/results")

    # Inspect the dataset
    experiment.dataset.print_summary()
    experiment.dataset.print_samples(n=3)

    # Load models
    experiment.load_models()

    # Run the experiment
    # Option 1: all queries
    results = experiment.run_experiment(distribution='all', save_results=True)
    # Option 2: In-Distribution only
    # results = experiment.run_experiment(distribution='in', save_results=True)
    # Option 3: Out-of-Distribution only
    # results = experiment.run_experiment(distribution='out', save_results=True)

    logger.info("\n✅ All experiments complete!")
    logger.info(f"   Results saved to: {experiment.output_dir}")
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
logger.info("\nโš ๏ธ ์‚ฌ์šฉ์ž์— ์˜ํ•ด ์ค‘๋‹จ๋จ")
except Exception as e:
logger.error(f"\nโŒ ์‹คํ—˜ ์‹คํŒจ: {e}")
import traceback
traceback.print_exc()