|
|
""" |
|
|
RAG κ²μ μμ€ν
νκ° λꡬ |
|
|
- LangSmith Experiment μ€ν |
|
|
- Context Precision/Recall νκ° |
|
|
- μ€ν μΆμ λ° λΉκ΅ |
|
|
|
|
|
μ¬μ©λ²: |
|
|
python run_experiment.py # λνν λ©λ΄ |
|
|
python run_experiment.py --run # μ€ν μ€ν |
|
|
python run_experiment.py --compare # μ€ν λΉκ΅ |
|
|
""" |
|
|
|
|
|
import os |
|
|
import re |
|
|
import sys |
|
|
import argparse |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Any |
|
|
from langsmith import Client, evaluate |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
|
|
|
# Make the project root importable so that `src.*` packages resolve.
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(project_root))
|
|
|
|
|
from src.retriever.retriever import RAGRetriever |
|
|
from src.utils.config import RAGConfig |
|
|
from src.evaluation.experiment_tracker import ExperimentTracker |
|
|
|
|
|
|
|
|
|
|
|
load_dotenv() |
|
|
os.environ["LANGCHAIN_PROJECT"] = "RAG-Retriever-Eval" |
|
|
os.environ["LANGCHAIN_TRACING_V2"] = "true" |
|
|
|
|
|
|
|
|
|
|
|
retriever = None |
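# Module-level handle: populated by run_experiment() and read by
# retriever_target(), since evaluate() calls the target with inputs only.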
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def normalize_text(text: str) -> str: |
|
|
"""ν
μ€νΈ μ κ·ν""" |
|
|
|
|
|
normalized = text.lower() |
|
|
|
|
|
|
|
|
normalized = re.sub(r'[\r\n\t]+', ' ', normalized) |
|
|
|
|
|
|
|
|
normalized = ' '.join(normalized.split()) |
|
|
|
|
|
return normalized.strip() |
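# Example (illustrative): normalize_text("Hello\n\tWorld  ") -> "hello world"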
|
|
|
|
|
|
|
|
def is_matching_context(retrieved_text: str, ground_truth_text: str, threshold: float = 0.5) -> bool: |
|
|
"""λ λ¬Έμκ° κ°μ λ¬ΈμμΈμ§ νλ¨""" |
|
|
normalized_retrieved = normalize_text(retrieved_text) |
|
|
normalized_truth = normalize_text(ground_truth_text) |
|
|
|
|
|
|
|
|
    # Exact containment in either direction counts as a match.
    if normalized_truth in normalized_retrieved:
|
|
return True |
|
|
|
|
|
if normalized_retrieved in normalized_truth: |
|
|
return True |
|
|
|
|
|
|
|
|
    # Otherwise fall back to word-overlap coverage of the ground-truth text.
    truth_words = set(normalized_truth.split())
|
|
retrieved_words = set(normalized_retrieved.split()) |
|
|
|
|
|
if len(truth_words) == 0: |
|
|
return False |
|
|
|
|
|
matched_words = truth_words & retrieved_words |
|
|
coverage = len(matched_words) / len(truth_words) |
|
|
|
|
|
return coverage >= threshold |
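# Illustrative example: truth "retrieval augmented generation" vs. retrieved
# "retrieval generation pipeline" gives word coverage 2/3 >= 0.5, so they match.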
|
|
|
|
|
|
|
|
def count_matching_contexts( |
|
|
retrieved_contexts: List[str], |
|
|
ground_truth_contexts: List[str], |
|
|
threshold: float = 0.5 |
|
|
) -> int: |
|
|
"""λ§€μΉλλ λ¬Έμ κ°μ κ³μ°""" |
|
|
matched_count = 0 |
|
|
|
|
|
for retrieved in retrieved_contexts: |
|
|
for truth in ground_truth_contexts: |
|
|
if is_matching_context(retrieved, truth, threshold): |
|
|
matched_count += 1 |
|
|
break |
|
|
|
|
|
return matched_count |
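# Each retrieved context is counted at most once: the inner loop breaks on the
# first ground-truth match, so the count never exceeds len(retrieved_contexts).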
|
|
|
|
|
|
|
|
def context_precision_evaluator(run: Any, example: Any) -> Dict[str, float]: |
|
|
"""Context Precision νκ°""" |
|
|
try: |
|
|
|
|
|
if isinstance(run.outputs, dict): |
|
|
retrieved_results = run.outputs.get('output', []) |
|
|
else: |
|
|
retrieved_results = run.outputs |
|
|
|
|
|
|
|
|
retrieved_contexts = [] |
|
|
for result in retrieved_results: |
|
|
if isinstance(result, dict): |
|
|
text = result.get('content', '') |
|
|
if text: |
|
|
retrieved_contexts.append(text) |
|
|
|
|
|
|
|
|
ground_truth_contexts = example.outputs.get('ground_truth_contexts', []) |
|
|
|
|
|
|
|
|
if len(retrieved_contexts) == 0: |
|
|
return {"key": "context_precision", "score": 0.0, "comment": "κ²μ κ²°κ³Ό μμ"} |
|
|
|
|
|
if len(ground_truth_contexts) == 0: |
|
|
return {"key": "context_precision", "score": 0.0, "comment": "μ λ΅ μμ"} |
|
|
|
|
|
|
|
|
matched_count = count_matching_contexts( |
|
|
retrieved_contexts, |
|
|
ground_truth_contexts, |
|
|
threshold=0.5 |
|
|
) |
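        # Precision = matched retrieved contexts / total retrieved contexts.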
|
|
|
|
|
|
|
|
precision = matched_count / len(retrieved_contexts) |
|
|
|
|
|
return { |
|
|
"key": "context_precision", |
|
|
"score": precision, |
|
|
"comment": f"λ§€μΉ: {matched_count}/{len(retrieved_contexts)}" |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
print(f"Context Precision κ³μ° μ€λ₯: {e}") |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
return {"key": "context_precision", "score": 0.0, "comment": f"μ€λ₯: {str(e)}"} |
|
|
|
|
|
|
|
|
def context_recall_evaluator(run: Any, example: Any) -> Dict[str, float]: |
|
|
"""Context Recall νκ°""" |
|
|
try: |
|
|
|
|
|
if isinstance(run.outputs, dict): |
|
|
retrieved_results = run.outputs.get('output', []) |
|
|
else: |
|
|
retrieved_results = run.outputs |
|
|
|
|
|
retrieved_contexts = [] |
|
|
for result in retrieved_results: |
|
|
if isinstance(result, dict): |
|
|
text = result.get('content', '') |
|
|
if text: |
|
|
retrieved_contexts.append(text) |
|
|
|
|
|
|
|
|
ground_truth_contexts = example.outputs.get('ground_truth_contexts', []) |
|
|
|
|
|
|
|
|
if len(ground_truth_contexts) == 0: |
|
|
return {"key": "context_recall", "score": 0.0, "comment": "μ λ΅ μμ"} |
|
|
|
|
|
if len(retrieved_contexts) == 0: |
|
|
return {"key": "context_recall", "score": 0.0, "comment": "κ²μ κ²°κ³Ό μμ"} |
|
|
|
|
|
|
|
|
        # Recall: fraction of ground-truth contexts found among the retrieved ones.
        matched_count = 0
|
|
for truth in ground_truth_contexts: |
|
|
for retrieved in retrieved_contexts: |
|
|
if is_matching_context(retrieved, truth, threshold=0.5): |
|
|
matched_count += 1 |
|
|
break |
|
|
|
|
|
|
|
|
recall = matched_count / len(ground_truth_contexts) |
|
|
|
|
|
return { |
|
|
"key": "context_recall", |
|
|
"score": recall, |
|
|
"comment": f"λ°κ²¬: {matched_count}/{len(ground_truth_contexts)}" |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
print(f"Context Recall κ³μ° μ€λ₯: {e}") |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
return {"key": "context_recall", "score": 0.0, "comment": f"μ€λ₯: {str(e)}"} |
|
|
|
|
|
|
|
|
def retrieval_time_evaluator(run: Any, example: Any) -> Dict[str, float]: |
|
|
"""κ²μ μκ° μΈ‘μ """ |
|
|
try: |
|
|
latency = run.execution_time |
|
|
return { |
|
|
"key": "retrieval_time", |
|
|
"score": latency, |
|
|
"comment": f"{latency:.3f}μ΄" |
|
|
} |
|
|
    except Exception:
        return {"key": "retrieval_time", "score": 0.0, "comment": "latency measurement failed"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def retriever_target(inputs: dict) -> dict: |
|
|
"""LangSmith Experimentμ© κ²μ ν¨μ""" |
|
|
question = inputs.get("question", "") |
|
|
|
|
|
if not question: |
|
|
return {"output": []} |
|
|
|
|
|
|
|
|
results = retriever.search_with_mode( |
|
|
query=question, |
|
|
top_k=None, |
|
|
mode="hybrid_rerank", |
|
|
alpha=0.5 |
|
|
) |
|
|
|
|
|
return {"output": results} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_experiment( |
|
|
experiment_name: str, |
|
|
config: dict, |
|
|
dataset_name: str = "RAG-Retriever-TestSet-v1", |
|
|
notes: str = "" |
|
|
) -> dict: |
|
|
""" |
|
|
μ€ν μ€ν λ° μλ μΆμ |
|
|
|
|
|
Args: |
|
|
experiment_name: μ€ν μ΄λ¦ |
|
|
config: μ€ν μ€μ |
|
|
dataset_name: Dataset μ΄λ¦ |
|
|
notes: λ©λͺ¨ |
|
|
|
|
|
Returns: |
|
|
μ€ν κ²°κ³Ό |
|
|
""" |
|
|
global retriever |
|
|
|
|
|
print("\n" + "="*80) |
|
|
print(f"π μ€ν μμ: {experiment_name}") |
|
|
print("="*80) |
|
|
|
|
|
|
|
|
print("\nπ§ κ²μκΈ° μ΄κΈ°ν...") |
|
|
rag_config = RAGConfig() |
|
|
|
|
|
|
|
|
    # Apply per-experiment overrides on top of the base config.
    if 'embedding_model' in config:
|
|
rag_config.EMBEDDING_MODEL_NAME = config['embedding_model'] |
|
|
if 'top_k' in config: |
|
|
rag_config.DEFAULT_TOP_K = config['top_k'] |
|
|
|
|
|
retriever = RAGRetriever(config=rag_config) |
|
|
|
|
|
print(f"β
μ€μ μλ£:") |
|
|
print(f" μλ² λ© λͺ¨λΈ: {rag_config.EMBEDDING_MODEL_NAME}") |
|
|
print(f" Top-K: {rag_config.DEFAULT_TOP_K}") |
|
|
|
|
|
|
|
|
    # Score evaluators; average latency is read from the results dataframe below,
    # so retrieval_time_evaluator is not registered here.
    evaluators_list = [
|
|
context_precision_evaluator, |
|
|
context_recall_evaluator, |
|
|
] |
|
|
|
|
|
|
|
|
client = Client() |
|
|
|
|
|
|
|
|
print(f"\nβ³ Experiment μ€ν μ€...") |
|
|
|
|
|
try: |
|
|
results = evaluate( |
|
|
retriever_target, |
|
|
data=dataset_name, |
|
|
evaluators=evaluators_list, |
|
|
experiment_prefix=experiment_name, |
|
|
max_concurrency=1, |
|
|
) |
|
|
|
|
|
print(f"\nβ
Experiment μλ£!") |
|
|
|
|
|
|
|
|
df = results.to_pandas() |
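        # Column names follow LangSmith's to_pandas() convention:
        # "feedback.<evaluator key>" for scores and "execution_time" for latency.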
|
|
|
|
|
metrics = { |
|
|
"precision": df["feedback.context_precision"].mean(), |
|
|
"recall": df["feedback.context_recall"].mean(), |
|
|
"avg_time": df["execution_time"].mean(), |
|
|
} |
|
|
|
|
|
|
|
|
tracker = ExperimentTracker() |
|
|
|
|
|
langsmith_url = "https://smith.langchain.com/" |
|
|
|
|
|
tracker.log_experiment( |
|
|
experiment_name=experiment_name, |
|
|
config=config, |
|
|
metrics=metrics, |
|
|
langsmith_url=langsmith_url, |
|
|
notes=notes |
|
|
) |
|
|
|
|
|
|
|
|
print("\n" + "="*80) |
|
|
print("π μ€ν κ²°κ³Ό") |
|
|
print("="*80) |
|
|
print(f"Precision: {metrics['precision']:.4f}") |
|
|
print(f"Recall: {metrics['recall']:.4f}") |
|
|
|
|
|
f1 = 0 |
|
|
if (metrics['precision'] + metrics['recall']) > 0: |
|
|
f1 = 2 * metrics['precision'] * metrics['recall'] / (metrics['precision'] + metrics['recall']) |
|
|
print(f"F1: {f1:.4f}") |
|
|
print(f"νκ· κ²μ μκ°: {metrics['avg_time']:.3f}μ΄") |
|
|
print("="*80) |
|
|
|
|
|
return results |
|
|
|
|
|
except Exception as e: |
|
|
print(f"\nβ μ€ν μ€ν¨: {e}") |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
raise |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def interactive_run(): |
|
|
"""λνν μ€ν μ€ν""" |
|
|
print("\n" + "="*80) |
|
|
print("π§ͺ RAG κ²μ μμ€ν
μ±λ₯ μ€ν") |
|
|
print("="*80) |
|
|
|
|
|
|
|
|
print("\nμ€ν μ€μ μ μ
λ ₯νμΈμ:") |
|
|
|
|
|
    experiment_name = input("Experiment name (e.g., baseline, hybrid-rerank): ").strip()
|
|
if not experiment_name: |
|
|
experiment_name = "experiment" |
|
|
|
|
|
    embedding_model = input("Embedding model (Enter for text-embedding-3-small): ").strip()
|
|
if not embedding_model: |
|
|
embedding_model = "text-embedding-3-small" |
|
|
|
|
|
    top_k_input = input("Top-K (Enter for 10): ").strip()
|
|
top_k = int(top_k_input) if top_k_input else 10 |
|
|
|
|
|
    notes = input("Notes (optional): ").strip()
|
|
|
|
|
|
|
|
config = { |
|
|
"embedding_model": embedding_model, |
|
|
"top_k": top_k, |
|
|
} |
|
|
|
|
|
|
|
|
print("\n" + "="*80) |
|
|
print("π μ€ν μ 보 νμΈ") |
|
|
print("="*80) |
|
|
print(f"μ€ν μ΄λ¦: {experiment_name}") |
|
|
print(f"μλ² λ© λͺ¨λΈ: {embedding_model}") |
|
|
print(f"Top-K: {top_k}") |
|
|
if notes: |
|
|
print(f"λ©λͺ¨: {notes}") |
|
|
print("="*80) |
|
|
|
|
|
    confirm = input("\nStart the experiment? (y/n): ").strip().lower()
|
|
if confirm != 'y': |
|
|
print("β μ·¨μλ¨") |
|
|
return |
|
|
|
|
|
|
|
|
run_experiment( |
|
|
experiment_name=experiment_name, |
|
|
config=config, |
|
|
notes=notes |
|
|
) |
|
|
|
|
|
|
|
|
def interactive_compare(): |
|
|
"""λνν μ€ν λΉκ΅""" |
|
|
tracker = ExperimentTracker() |
|
|
|
|
|
print("\n" + "="*80) |
|
|
print("π μ€ν λΉκ΅ λꡬ") |
|
|
print("="*80) |
|
|
|
|
|
while True: |
|
|
print("\nλ©λ΄:") |
|
|
print(" 1. λͺ¨λ μ€ν λͺ©λ‘ 보기") |
|
|
print(" 2. μ΅κ·Ό μ€ν λΉκ΅ (μ΅κ·Ό 5κ°)") |
|
|
print(" 3. νΉμ μ€ν λΉκ΅") |
|
|
print(" 4. κ°μ ν¨κ³Ό νμΈ") |
|
|
print(" 5. μ°¨νΈ μμ±") |
|
|
print(" 6. μ΅μ μ€μ μΆμ²") |
|
|
print(" 0. μ’
λ£") |
|
|
|
|
|
        choice = input("\nSelect: ").strip()
|
|
|
|
|
if choice == "1": |
|
|
tracker.list_experiments() |
|
|
|
|
|
elif choice == "2": |
|
|
tracker.compare_experiments(top_n=5) |
|
|
|
|
|
elif choice == "3": |
|
|
            names = input("Experiment names (comma-separated): ").strip()
|
|
if names: |
|
|
experiment_names = [n.strip() for n in names.split(',')] |
|
|
tracker.compare_experiments(experiment_names=experiment_names) |
|
|
|
|
|
elif choice == "4": |
|
|
            baseline = input("Baseline experiment name: ").strip()
|
|
            current = input("Experiment name to compare: ").strip()
|
|
|
|
|
if baseline and current: |
|
|
tracker.show_improvement(baseline, current) |
|
|
|
|
|
elif choice == "5": |
|
|
            names_input = input("Experiment names (comma-separated, Enter for all): ").strip()
|
|
|
|
|
if names_input: |
|
|
experiment_names = [n.strip() for n in names_input.split(',')] |
|
|
else: |
|
|
experiment_names = None |
|
|
|
|
|
tracker.plot_metrics(experiment_names=experiment_names) |
|
|
|
|
|
elif choice == "6": |
|
|
            metric = input("Metric (precision/recall/f1, Enter for f1): ").strip()
|
|
if not metric: |
|
|
metric = "f1" |
|
|
|
|
|
tracker.recommend_best(metric=metric) |
|
|
|
|
|
elif choice == "0": |
|
|
print("π μ’
λ£ν©λλ€") |
|
|
break |
|
|
|
|
|
else: |
|
|
print("β μλͺ»λ μ νμ
λλ€") |
|
|
|
|
|
|
|
|
def main_menu(): |
|
|
"""λ©μΈ λ©λ΄""" |
|
|
print("\n" + "="*80) |
|
|
print("π¬ RAG νκ° μμ€ν
") |
|
|
print("="*80) |
|
|
|
|
|
while True: |
|
|
print("\nλ©λ΄:") |
|
|
print(" 1. μ€ν μ€ν") |
|
|
print(" 2. μ€ν λΉκ΅") |
|
|
print(" 0. μ’
λ£") |
|
|
|
|
|
        choice = input("\nSelect: ").strip()
|
|
|
|
|
if choice == "1": |
|
|
interactive_run() |
|
|
|
|
|
elif choice == "2": |
|
|
interactive_compare() |
|
|
|
|
|
elif choice == "0": |
|
|
print("π μ’
λ£ν©λλ€") |
|
|
break |
|
|
|
|
|
else: |
|
|
print("β μλͺ»λ μ νμ
λλ€") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(): |
|
|
"""λ©μΈ μ€ν""" |
|
|
    parser = argparse.ArgumentParser(description='RAG evaluation system')
|
|
|
|
|
parser.add_argument( |
|
|
'--run', |
|
|
action='store_true', |
|
|
        help='run an experiment'
|
|
) |
|
|
|
|
|
parser.add_argument( |
|
|
'--compare', |
|
|
action='store_true', |
|
|
        help='compare experiments'
|
|
) |
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
try: |
|
|
if args.run: |
|
|
interactive_run() |
|
|
elif args.compare: |
|
|
interactive_compare() |
|
|
else: |
|
|
main_menu() |
|
|
|
|
|
except KeyboardInterrupt: |
|
|
print("\n\nβ οΈ μ€λ¨λ¨") |
|
|
except Exception as e: |
|
|
print(f"\nβ μ€λ₯: {e}") |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |