# clovax-tax-chatbot / run_rag_tests.py
# Last commit 3a15ede by bissal: "large modification"
# run_rag_tests.py - RAG test runner script (optimized for Hugging Face Spaces)
"""
Unified runner script for testing and tuning the RAG system.

Usage:
    python run_rag_tests.py              # basic tests
    python run_rag_tests.py --rebuild    # rebuild the RAG system, then test
    python run_rag_tests.py --tune       # configuration tuning
    python run_rag_tests.py --benchmark  # performance benchmark only
    python run_rag_tests.py --quick      # quick test
    python run_rag_tests.py --info       # show system information only
"""
import os
import sys
import argparse
import subprocess
from datetime import datetime
def print_banner():
    """Print the tool banner: a rule, the title, the current time and the cwd."""
    rule = "=" * 80
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    print(rule)
    print("RAG μ‹œμŠ€ν…œ ν…ŒμŠ€νŠΈ 및 νŠœλ‹ 도ꡬ")
    print("μ‹€ν–‰ μ‹œκ°„:", timestamp)
    print("μž‘μ—… 디렉토리:", os.getcwd())
    print(rule)
def check_dependencies():
    """Verify that the project files this runner relies on exist in the cwd.

    Returns:
        bool: True when every required file is present; False (after
        printing the missing names) otherwise.
    """
    print("쒅속성 확인 쀑...")

    required = (
        'rag_system.py',
        'config.py',
        'law_fetcher.py',
        'test_rag.py',
        'tune_rag_config.py',
    )
    # Paths are relative, so the check is made against the current directory.
    absent = [name for name in required if not os.path.exists(name)]

    if not absent:
        print("λͺ¨λ“  ν•„μˆ˜ 파일 확인 μ™„λ£Œ")
        return True

    print(f"λˆ„λ½λœ 파일: {', '.join(absent)}")
    return False
def run_basic_tests(rebuild=False):
    """Run ``test_rag.py`` in a child interpreter and echo its output.

    Args:
        rebuild: When truthy, ``--rebuild`` is appended to the child's
            command line.

    Returns:
        bool: True when the child exits with status 0; False on a non-zero
        exit status or when running the child raises.
    """
    print("\nπŸ§ͺ κΈ°λ³Έ RAG ν…ŒμŠ€νŠΈ μ‹€ν–‰")
    print("-" * 40)

    command = [sys.executable, 'test_rag.py'] + (['--rebuild'] if rebuild else [])

    try:
        completed = subprocess.run(
            command,
            capture_output=True,
            text=True,
            encoding='utf-8',
            errors='replace',
        )

        print("πŸ“€ ν…ŒμŠ€νŠΈ 좜λ ₯:")
        if completed.stdout:
            print(completed.stdout)
        if completed.stderr:
            print("⚠️ κ²½κ³ /였λ₯˜:")
            print(completed.stderr)

        passed = completed.returncode == 0
        if passed:
            print("βœ… κΈ°λ³Έ ν…ŒμŠ€νŠΈ 성곡")
        else:
            print(f"❌ κΈ°λ³Έ ν…ŒμŠ€νŠΈ μ‹€νŒ¨ (exit code: {completed.returncode})")
        return passed
    except Exception as exc:
        # Best-effort runner: any failure to launch or report counts as a failed test.
        print(f"πŸ’₯ ν…ŒμŠ€νŠΈ μ‹€ν–‰ 였λ₯˜: {exc}")
        return False
def run_config_tuning():
    """Run ``tune_rag_config.py`` in a child interpreter and echo its output.

    The vector DB files must already exist in the current directory;
    otherwise a hint is printed and nothing is executed.

    Returns:
        bool: True when the tuning script exits with status 0; False when
        vector DB files are missing, the exit status is non-zero, or
        running the child raises.
    """
    print("\nπŸŽ›οΈ RAG μ„€μ • νŠœλ‹ μ‹€ν–‰")
    print("-" * 40)

    # Tuning needs a built vector DB, so bail out early when any part is absent.
    vector_db_parts = ('vector_db.faiss', 'documents.pkl', 'metadata.json')
    absent = [name for name in vector_db_parts if not os.path.exists(name)]
    if absent:
        print(f"❌ 벑터 DB 파일 λˆ„λ½: {', '.join(absent)}")
        print("πŸ“‹ λ¨Όμ € RAG μ‹œμŠ€ν…œμ„ κ΅¬μΆ•ν•΄μ£Όμ„Έμš”:")
        print(" python run_rag_tests.py --rebuild")
        return False

    try:
        completed = subprocess.run(
            [sys.executable, 'tune_rag_config.py'],
            capture_output=True,
            text=True,
            encoding='utf-8',
            errors='replace',
        )

        print("πŸ“€ νŠœλ‹ 좜λ ₯:")
        if completed.stdout:
            print(completed.stdout)
        if completed.stderr:
            print("⚠️ κ²½κ³ /였λ₯˜:")
            print(completed.stderr)

        if completed.returncode != 0:
            print(f"❌ μ„€μ • νŠœλ‹ μ‹€νŒ¨ (exit code: {completed.returncode})")
            return False

        print("βœ… μ„€μ • νŠœλ‹ 성곡")
        # Mention the optimized config file if one is present after the run.
        if os.path.exists('optimized_config.py'):
            print("πŸ“„ μ΅œμ ν™”λœ μ„€μ • 파일 생성: optimized_config.py")
        return True
    except Exception as exc:
        print(f"πŸ’₯ νŠœλ‹ μ‹€ν–‰ 였λ₯˜: {exc}")
        return False
def run_benchmark_only():
    """Run only the performance benchmark of ``test_rag.RAGTester``.

    A small program is passed to a child interpreter via ``-c``. It
    creates a ``RAGTester``, calls ``initialize_rag()`` and, when that
    succeeds, ``performance_benchmark(iterations=10)``. The child's
    output is echoed.

    Returns:
        bool: True when the child exits with status 0. The child exits
        with status 1 when RAG initialization fails, so that case is
        reported as a failure instead of a success. False is also
        returned when running the child raises.
    """
    print("\nπŸƒ μ„±λŠ₯ 벀치마크 μ‹€ν–‰")
    print("-" * 40)
    try:
        # Standalone script that exercises only the benchmark part of test_rag.py.
        # It must signal an initialization failure through its exit status,
        # because the exit status is the only thing the parent inspects.
        benchmark_code = '''
import sys
sys.path.append('.')
from test_rag import RAGTester
tester = RAGTester()
if tester.initialize_rag():
    tester.performance_benchmark(iterations=10)
else:
    print("❌ RAG μ‹œμŠ€ν…œ μ΄ˆκΈ°ν™” μ‹€νŒ¨")
    sys.exit(1)
'''
        result = subprocess.run([sys.executable, '-c', benchmark_code],
                                capture_output=True, text=True,
                                encoding='utf-8', errors='replace')
        print("πŸ“€ 벀치마크 κ²°κ³Ό:")
        if result.stdout:
            print(result.stdout)
        if result.stderr:
            print("⚠️ κ²½κ³ /였λ₯˜:")
            print(result.stderr)
        return result.returncode == 0
    except Exception as e:
        print(f"πŸ’₯ 벀치마크 μ‹€ν–‰ 였λ₯˜: {e}")
        return False
def run_quick_test():
    """Run a shortened version of the basic RAG tests.

    A small program is passed to a child interpreter via ``-c``. It
    creates a ``test_rag.RAGTester``, calls ``initialize_rag()``, trims
    ``test_queries`` to the first three entries and calls
    ``run_basic_tests()``. The child's output is echoed.

    Returns:
        bool: True when the child exits with status 0. The child exits
        with status 1 when RAG initialization fails or when
        ``run_basic_tests()`` returns a falsy value (assumed to mean
        "tests failed", matching the message the child prints), so
        failures are no longer reported as success. False is also
        returned when running the child raises.
    """
    print("\n⚑ λΉ λ₯Έ RAG ν…ŒμŠ€νŠΈ μ‹€ν–‰")
    print("-" * 40)
    try:
        # The exit status is the only result the parent inspects, so the
        # child has to map both failure paths onto a non-zero status.
        quick_test_code = '''
import sys
sys.path.append('.')
from test_rag import RAGTester
tester = RAGTester()
if tester.initialize_rag():
    # 3개 쿼리만 ν…ŒμŠ€νŠΈ
    tester.test_queries = tester.test_queries[:3]
    success = tester.run_basic_tests()
    print(f"\\nλΉ λ₯Έ ν…ŒμŠ€νŠΈ κ²°κ³Ό: {'성곡' if success else 'μ‹€νŒ¨'}")
    sys.exit(0 if success else 1)
else:
    print("❌ RAG μ‹œμŠ€ν…œ μ΄ˆκΈ°ν™” μ‹€νŒ¨")
    sys.exit(1)
'''
        # errors='replace' matches the other runners: undecodable child
        # output is shown with replacement characters instead of raising.
        result = subprocess.run([sys.executable, '-c', quick_test_code],
                                capture_output=True, text=True,
                                encoding='utf-8', errors='replace')
        print("πŸ“€ λΉ λ₯Έ ν…ŒμŠ€νŠΈ κ²°κ³Ό:")
        if result.stdout:
            print(result.stdout)
        if result.stderr:
            print("⚠️ κ²½κ³ /였λ₯˜:")
            print(result.stderr)
        return result.returncode == 0
    except Exception as e:
        print(f"πŸ’₯ λΉ λ₯Έ ν…ŒμŠ€νŠΈ μ‹€ν–‰ 였λ₯˜: {e}")
        return False
def show_system_info():
"""μ‹œμŠ€ν…œ 정보 ν‘œμ‹œ"""
print("\nπŸ’» μ‹œμŠ€ν…œ 정보")
print("-" * 40)
print(f"🐍 Python 버전: {sys.version}")
print(f"πŸ“ ν˜„μž¬ 디렉토리: {os.getcwd()}")
# 벑터 DB μƒνƒœ 확인
vector_files = ['vector_db.faiss', 'documents.pkl', 'metadata.json']
vector_status = []
for file in vector_files:
if os.path.exists(file):
size = os.path.getsize(file) / 1024 # KB
vector_status.append(f"βœ… {file} ({size:.1f}KB)")
else:
vector_status.append(f"❌ {file} (μ—†μŒ)")
print("πŸ“Š 벑터 DB μƒνƒœ:")
for status in vector_status:
print(f" {status}")
# μΊμ‹œ μƒνƒœ 확인
cache_dir = './law_cache'
if os.path.exists(cache_dir):
cache_files = len([f for f in os.listdir(cache_dir) if f.endswith('.json')])
print(f"πŸ’Ύ 법령 μΊμ‹œ: {cache_files}개 파일")
else:
print("πŸ’Ύ 법령 μΊμ‹œ: μ—†μŒ")
def main():
    """Parse the command line, run the selected tasks and exit with a status.

    ``--info`` prints the system information and returns. Otherwise the
    required files are checked (exit status 1 when any is missing) and one
    mode is chosen in the priority order ``--quick``, ``--benchmark``,
    ``--tune``. With none of those flags the basic tests run, followed by
    configuration tuning when all vector DB files already exist. The
    process exits with status 0 when every scheduled task succeeded and
    with status 1 otherwise.
    """
    parser = argparse.ArgumentParser(description='RAG μ‹œμŠ€ν…œ ν…ŒμŠ€νŠΈ 및 νŠœλ‹ 도ꡬ')
    for flag, help_text in (
        ('--rebuild', 'RAG μ‹œμŠ€ν…œ μž¬κ΅¬μΆ•'),
        ('--tune', 'μ„€μ • νŠœλ‹ μ‹€ν–‰'),
        ('--benchmark', 'μ„±λŠ₯ 벀치마크만 μ‹€ν–‰'),
        ('--quick', 'λΉ λ₯Έ ν…ŒμŠ€νŠΈ μ‹€ν–‰'),
        ('--info', 'μ‹œμŠ€ν…œ μ •λ³΄λ§Œ ν‘œμ‹œ'),
    ):
        parser.add_argument(flag, action='store_true', help=help_text)
    args = parser.parse_args()

    print_banner()

    # Info-only mode: show the system information and stop.
    if args.info:
        show_system_info()
        return

    # Dependency check.
    if not check_dependencies():
        print("❌ 쒅속성 확인 μ‹€νŒ¨")
        sys.exit(1)

    show_system_info()

    # Decide which tasks to run; the first matching flag wins.
    if args.quick:
        tasks = [('λΉ λ₯Έ ν…ŒμŠ€νŠΈ', run_quick_test)]
    elif args.benchmark:
        tasks = [('벀치마크', run_benchmark_only)]
    elif args.tune:
        tasks = [('μ„€μ • νŠœλ‹', run_config_tuning)]
    else:
        # Default: basic tests.
        tasks = [('κΈ°λ³Έ ν…ŒμŠ€νŠΈ', lambda: run_basic_tests(args.rebuild))]
        # Tuning is added only when the vector DB is already present at this point.
        vector_db_parts = ['vector_db.faiss', 'documents.pkl', 'metadata.json']
        if all(os.path.exists(part) for part in vector_db_parts):
            print("\nπŸ“‹ 벑터 DBκ°€ μ‘΄μž¬ν•˜λ―€λ‘œ μ„€μ • νŠœλ‹λ„ μ‹€ν–‰ν•©λ‹ˆλ‹€.")
            tasks.append(('μ„€μ • νŠœλ‹', run_config_tuning))

    # Run the tasks in order, counting the ones that report success.
    succeeded = 0
    for label, runner in tasks:
        print(f"\nπŸš€ {label} μ‹œμž‘...")
        try:
            if runner():
                succeeded += 1
                print(f"βœ… {label} μ™„λ£Œ")
            else:
                print(f"❌ {label} μ‹€νŒ¨")
        except KeyboardInterrupt:
            # Stop scheduling further tasks but still print the summary.
            print(f"\nπŸ›‘ μ‚¬μš©μžμ— μ˜ν•΄ {label}이 μ€‘λ‹¨λ˜μ—ˆμŠ΅λ‹ˆλ‹€.")
            break
        except Exception as exc:
            print(f"πŸ’₯ {label} 쀑 였λ₯˜: {exc}")

    # Result summary.
    print("\n" + "=" * 80)
    print("πŸ“Š μ‹€ν–‰ κ²°κ³Ό μš”μ•½")
    print("=" * 80)
    print(f"βœ… μ„±κ³΅ν•œ μž‘μ—…: {succeeded}/{len(tasks)}")

    if succeeded != len(tasks):
        print("⚠️ 일뢀 μž‘μ—…μ—μ„œ λ¬Έμ œκ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€.")
        sys.exit(1)

    print("πŸŽ‰ λͺ¨λ“  μž‘μ—…μ΄ μ„±κ³΅μ μœΌλ‘œ μ™„λ£Œλ˜μ—ˆμŠ΅λ‹ˆλ‹€!")
    # List the result files that exist in the current directory.
    import glob
    patterns = ['test_results_*.json', 'rag_tuning_results_*.json', 'optimized_config.py']
    result_files = [path for pattern in patterns for path in glob.glob(pattern)]
    if result_files:
        print(f"πŸ“„ μƒμ„±λœ κ²°κ³Ό 파일: {', '.join(result_files)}")
    sys.exit(0)
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()