"""
RAG 검색 μ‹œμŠ€ν…œ 평가 도ꡬ
- LangSmith Experiment μ‹€ν–‰
- Context Precision/Recall 평가
- μ‹€ν—˜ 좔적 및 비ꡐ
μ‚¬μš©λ²•:
python run_experiment.py # λŒ€ν™”ν˜• 메뉴
python run_experiment.py --run # μ‹€ν—˜ μ‹€ν–‰
python run_experiment.py --compare # μ‹€ν—˜ 비ꡐ
"""
import os
import re
import sys
import argparse
from pathlib import Path
from typing import Dict, List, Any
from langsmith import Client, evaluate
from dotenv import load_dotenv
# ν”„λ‘œμ νŠΈ 경둜 μΆ”κ°€
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(project_root))
from src.retriever.retriever import RAGRetriever
from src.utils.config import RAGConfig
from src.evaluation.experiment_tracker import ExperimentTracker
# === Environment setup ===
load_dotenv()
os.environ["LANGCHAIN_PROJECT"] = "RAG-Retriever-Eval"
os.environ["LANGCHAIN_TRACING_V2"] = "true"
# === Global state ===
retriever = None  # initialized in run_experiment()
# ============================================================
# Evaluator functions
# ============================================================
def normalize_text(text: str) -> str:
    """Normalize text: lowercase and collapse whitespace."""
    # Lowercase
    normalized = text.lower()
    # Replace line breaks and tabs with spaces
    normalized = re.sub(r'[\r\n\t]+', ' ', normalized)
    # Collapse consecutive whitespace into a single space
    normalized = ' '.join(normalized.split())
    return normalized.strip()
def is_matching_context(retrieved_text: str, ground_truth_text: str, threshold: float = 0.5) -> bool:
    """Decide whether two texts refer to the same document."""
    normalized_retrieved = normalize_text(retrieved_text)
    normalized_truth = normalize_text(ground_truth_text)
    # Full-containment check (either direction)
    if normalized_truth in normalized_retrieved:
        return True
    if normalized_retrieved in normalized_truth:
        return True
    # Word-coverage check: fraction of ground-truth words present in the retrieved text
    truth_words = set(normalized_truth.split())
    retrieved_words = set(normalized_retrieved.split())
    if len(truth_words) == 0:
        return False
    matched_words = truth_words & retrieved_words
    coverage = len(matched_words) / len(truth_words)
    return coverage >= threshold
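
# A quick sanity check of the matching heuristic above (illustrative strings,
# not part of any evaluation run):
#
#   >>> is_matching_context("The quick brown fox jumps", "quick brown fox")
#   True   # full containment after normalization
#   >>> is_matching_context("lorem ipsum dolor", "quick brown fox")
#   False  # word coverage 0/3 is below the 0.5 threshold
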
def count_matching_contexts(
    retrieved_contexts: List[str],
    ground_truth_contexts: List[str],
    threshold: float = 0.5
) -> int:
    """Count retrieved documents that match at least one ground-truth document."""
    matched_count = 0
    for retrieved in retrieved_contexts:
        for truth in ground_truth_contexts:
            if is_matching_context(retrieved, truth, threshold):
                matched_count += 1
                break  # count each retrieved document at most once
    return matched_count
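
# The evaluators below turn these counts into the standard metrics:
#   precision = matched retrieved docs / total retrieved docs
#   recall    = matched ground-truth docs / total ground-truth docs
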
def context_precision_evaluator(run: Any, example: Any) -> Dict[str, Any]:
    """Context Precision: fraction of retrieved documents that match the ground truth."""
    try:
        # Extract the retrieval results
        if isinstance(run.outputs, dict):
            retrieved_results = run.outputs.get('output', [])
        else:
            retrieved_results = run.outputs
        # Keep only the text content
        retrieved_contexts = []
        for result in retrieved_results:
            if isinstance(result, dict):
                text = result.get('content', '')
                if text:
                    retrieved_contexts.append(text)
        # Extract the ground truth
        ground_truth_contexts = example.outputs.get('ground_truth_contexts', [])
        # Validate
        if len(retrieved_contexts) == 0:
            return {"key": "context_precision", "score": 0.0, "comment": "no retrieval results"}
        if len(ground_truth_contexts) == 0:
            return {"key": "context_precision", "score": 0.0, "comment": "no ground truth"}
        # Count matches
        matched_count = count_matching_contexts(
            retrieved_contexts,
            ground_truth_contexts,
            threshold=0.5
        )
        # Precision = matched / retrieved
        precision = matched_count / len(retrieved_contexts)
        return {
            "key": "context_precision",
            "score": precision,
            "comment": f"matched: {matched_count}/{len(retrieved_contexts)}"
        }
    except Exception as e:
        print(f"Context Precision computation error: {e}")
        import traceback
        traceback.print_exc()
        return {"key": "context_precision", "score": 0.0, "comment": f"error: {str(e)}"}
def context_recall_evaluator(run: Any, example: Any) -> Dict[str, Any]:
    """Context Recall: fraction of ground-truth documents found among the retrieved ones."""
    try:
        # Extract the retrieval results
        if isinstance(run.outputs, dict):
            retrieved_results = run.outputs.get('output', [])
        else:
            retrieved_results = run.outputs
        # Keep only the text content
        retrieved_contexts = []
        for result in retrieved_results:
            if isinstance(result, dict):
                text = result.get('content', '')
                if text:
                    retrieved_contexts.append(text)
        # Extract the ground truth
        ground_truth_contexts = example.outputs.get('ground_truth_contexts', [])
        # Validate
        if len(ground_truth_contexts) == 0:
            return {"key": "context_recall", "score": 0.0, "comment": "no ground truth"}
        if len(retrieved_contexts) == 0:
            return {"key": "context_recall", "score": 0.0, "comment": "no retrieval results"}
        # Count ground-truth documents that were retrieved
        matched_count = 0
        for truth in ground_truth_contexts:
            for retrieved in retrieved_contexts:
                if is_matching_context(retrieved, truth, threshold=0.5):
                    matched_count += 1
                    break  # count each ground-truth document at most once
        # Recall = found / total ground truth
        recall = matched_count / len(ground_truth_contexts)
        return {
            "key": "context_recall",
            "score": recall,
            "comment": f"found: {matched_count}/{len(ground_truth_contexts)}"
        }
    except Exception as e:
        print(f"Context Recall computation error: {e}")
        import traceback
        traceback.print_exc()
        return {"key": "context_recall", "score": 0.0, "comment": f"error: {str(e)}"}
def retrieval_time_evaluator(run: Any, example: Any) -> Dict[str, Any]:
    """Measure retrieval latency (optional; not in evaluators_list by default)."""
    try:
        # The Run schema exposes start/end timestamps; derive the latency from
        # them rather than relying on a non-standard attribute.
        latency = (run.end_time - run.start_time).total_seconds()
        return {
            "key": "retrieval_time",
            "score": latency,
            "comment": f"{latency:.3f}s"
        }
    except Exception as e:
        return {"key": "retrieval_time", "score": 0.0, "comment": f"time measurement failed: {e}"}
# ============================================================
# Target function
# ============================================================
def retriever_target(inputs: dict) -> dict:
    """Retrieval target invoked by the LangSmith experiment for each example."""
    question = inputs.get("question", "")
    if not question:
        return {"output": []}
    # Hybrid search + re-ranker (top_k=None falls back to the configured default)
    results = retriever.search_with_mode(
        query=question,
        top_k=None,
        mode="hybrid_rerank",
        alpha=0.5
    )
    return {"output": results}
# ============================================================
# Experiment runner
# ============================================================
def run_experiment(
    experiment_name: str,
    config: dict,
    dataset_name: str = "RAG-Retriever-TestSet-v1",
    notes: str = ""
) -> Any:
    """
    Run an experiment and log it for tracking.

    Args:
        experiment_name: experiment name
        config: experiment configuration overrides
        dataset_name: LangSmith dataset name
        notes: free-form notes

    Returns:
        The LangSmith evaluation results.
    """
    global retriever
    print("\n" + "="*80)
    print(f"πŸš€ Starting experiment: {experiment_name}")
    print("="*80)

    # 1. Initialize the retriever
    print("\nπŸ”§ Initializing retriever...")
    rag_config = RAGConfig()
    # Apply the experiment's config overrides
    if 'embedding_model' in config:
        rag_config.EMBEDDING_MODEL_NAME = config['embedding_model']
    if 'top_k' in config:
        rag_config.DEFAULT_TOP_K = config['top_k']
    retriever = RAGRetriever(config=rag_config)
    print("βœ… Configuration applied:")
    print(f"   Embedding model: {rag_config.EMBEDDING_MODEL_NAME}")
    print(f"   Top-K: {rag_config.DEFAULT_TOP_K}")
    # 2. Configure evaluators (retrieval_time_evaluator could be appended here
    #    too; average latency is read from the results dataframe below)
    evaluators_list = [
        context_precision_evaluator,
        context_recall_evaluator,
    ]

    # 3. LangSmith client (evaluate() uses the default client when none is passed)
    client = Client()

    # 4. Run the experiment
    print("\n⏳ Running experiment...")
    try:
        results = evaluate(
            retriever_target,
            data=dataset_name,
            evaluators=evaluators_list,
            experiment_prefix=experiment_name,
            max_concurrency=1,
        )
        print("\nβœ… Experiment complete!")

        # 5. Extract aggregate metrics from the results dataframe
        df = results.to_pandas()
        metrics = {
            "precision": df["feedback.context_precision"].mean(),
            "recall": df["feedback.context_recall"].mean(),
            "avg_time": df["execution_time"].mean(),
        }

        # 6. Log the experiment for later comparison
        tracker = ExperimentTracker()
        langsmith_url = "https://smith.langchain.com/"
        tracker.log_experiment(
            experiment_name=experiment_name,
            config=config,
            metrics=metrics,
            langsmith_url=langsmith_url,
            notes=notes
        )

        # 7. Print the results
        print("\n" + "="*80)
        print("πŸ“Š Experiment results")
        print("="*80)
        print(f"Precision: {metrics['precision']:.4f}")
        print(f"Recall: {metrics['recall']:.4f}")
        # F1 is the harmonic mean of precision and recall
        f1 = 0
        if (metrics['precision'] + metrics['recall']) > 0:
            f1 = 2 * metrics['precision'] * metrics['recall'] / (metrics['precision'] + metrics['recall'])
        print(f"F1: {f1:.4f}")
        print(f"Average retrieval time: {metrics['avg_time']:.3f}s")
        print("="*80)
        return results
    except Exception as e:
        print(f"\n❌ Experiment failed: {e}")
        import traceback
        traceback.print_exc()
        raise
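
# Programmatic use (a sketch with hypothetical values; the LangSmith dataset
# "RAG-Retriever-TestSet-v1" must already exist with "question" inputs and
# "ground_truth_contexts" outputs):
#
#   run_experiment(
#       experiment_name="hybrid-rerank-k10",
#       config={"embedding_model": "text-embedding-3-small", "top_k": 10},
#       notes="hybrid search + re-ranker baseline",
#   )
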
# ============================================================
# Interactive menus
# ============================================================
def interactive_run():
    """Interactively configure and run an experiment."""
    print("\n" + "="*80)
    print("πŸ§ͺ RAG retrieval system performance experiment")
    print("="*80)

    # Collect the experiment settings
    print("\nEnter the experiment settings:")
    experiment_name = input("Experiment name (e.g. baseline, hybrid-rerank): ").strip()
    if not experiment_name:
        experiment_name = "experiment"
    embedding_model = input("Embedding model (Enter for text-embedding-3-small): ").strip()
    if not embedding_model:
        embedding_model = "text-embedding-3-small"
    top_k_input = input("Top-K (Enter for 10): ").strip()
    try:
        top_k = int(top_k_input) if top_k_input else 10
    except ValueError:
        print("⚠️ Invalid Top-K; falling back to 10")
        top_k = 10
    notes = input("Notes (optional): ").strip()

    # Build the config
    config = {
        "embedding_model": embedding_model,
        "top_k": top_k,
    }

    # Confirm before running
    print("\n" + "="*80)
    print("πŸ“‹ Experiment summary")
    print("="*80)
    print(f"Experiment name: {experiment_name}")
    print(f"Embedding model: {embedding_model}")
    print(f"Top-K: {top_k}")
    if notes:
        print(f"Notes: {notes}")
    print("="*80)
    confirm = input("\nStart the experiment? (y/n): ").strip().lower()
    if confirm != 'y':
        print("❌ Cancelled")
        return

    # Run it
    run_experiment(
        experiment_name=experiment_name,
        config=config,
        notes=notes
    )
def interactive_compare():
    """Interactively compare logged experiments."""
    tracker = ExperimentTracker()
    print("\n" + "="*80)
    print("πŸ” Experiment comparison tool")
    print("="*80)
    while True:
        print("\nMenu:")
        print("  1. List all experiments")
        print("  2. Compare recent experiments (last 5)")
        print("  3. Compare specific experiments")
        print("  4. Show improvement over a baseline")
        print("  5. Generate charts")
        print("  6. Recommend the best configuration")
        print("  0. Exit")
        choice = input("\nSelect: ").strip()
        if choice == "1":
            tracker.list_experiments()
        elif choice == "2":
            tracker.compare_experiments(top_n=5)
        elif choice == "3":
            names = input("Experiment names (comma-separated): ").strip()
            if names:
                experiment_names = [n.strip() for n in names.split(',')]
                tracker.compare_experiments(experiment_names=experiment_names)
        elif choice == "4":
            baseline = input("Baseline experiment name: ").strip()
            current = input("Experiment to compare: ").strip()
            if baseline and current:
                tracker.show_improvement(baseline, current)
        elif choice == "5":
            names_input = input("Experiment names (comma-separated, Enter for all): ").strip()
            if names_input:
                experiment_names = [n.strip() for n in names_input.split(',')]
            else:
                experiment_names = None
            tracker.plot_metrics(experiment_names=experiment_names)
        elif choice == "6":
            metric = input("Metric (precision/recall/f1, Enter for f1): ").strip()
            if not metric:
                metric = "f1"
            tracker.recommend_best(metric=metric)
        elif choice == "0":
            print("πŸ‘‹ Goodbye")
            break
        else:
            print("❌ Invalid selection")
def main_menu():
    """Top-level interactive menu."""
    print("\n" + "="*80)
    print("πŸ”¬ RAG evaluation system")
    print("="*80)
    while True:
        print("\nMenu:")
        print("  1. Run an experiment")
        print("  2. Compare experiments")
        print("  0. Exit")
        choice = input("\nSelect: ").strip()
        if choice == "1":
            interactive_run()
        elif choice == "2":
            interactive_compare()
        elif choice == "0":
            print("πŸ‘‹ Goodbye")
            break
        else:
            print("❌ Invalid selection")
# ============================================================
# Entry point
# ============================================================
def main():
    """Entry point."""
    parser = argparse.ArgumentParser(description='RAG evaluation system')
    parser.add_argument(
        '--run',
        action='store_true',
        help='run an experiment'
    )
    parser.add_argument(
        '--compare',
        action='store_true',
        help='compare experiments'
    )
    args = parser.parse_args()
    try:
        if args.run:
            interactive_run()
        elif args.compare:
            interactive_compare()
        else:
            main_menu()
    except KeyboardInterrupt:
        print("\n\n⚠️ Interrupted")
    except Exception as e:
        print(f"\n❌ Error: {e}")
        import traceback
        traceback.print_exc()
if __name__ == "__main__":
main()