|
|
|
|
|
""" |
|
|
RAG ๊ฒ์ ์์คํ
์คํ ์ถ์ ๋ฐ ๋น๊ต ๋๊ตฌ |
|
|
|
|
|
๊ธฐ๋ฅ: |
|
|
1. ์คํ ๊ฒฐ๊ณผ ์๋ ์ ์ฅ |
|
|
2. ์ด์ ์คํ๊ณผ ๋น๊ต |
|
|
3. ์ฑ๋ฅ ์ฐจํธ ์์ฑ |
|
|
4. ์ต์ ์ค์ ์ถ์ฒ |
|
|
""" |
|
|
|
|
|
import json |
|
|
import pandas as pd |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Any, Optional |
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib |
|
|
matplotlib.use('Agg') |
|
|
|
|
|
|
|
|
class ExperimentTracker:
    """Track, compare, and visualize RAG retrieval experiments.

    Each run is appended to a JSON log (``experiments_log.json``) and the
    full history is mirrored into a CSV summary
    (``experiments_summary.csv``) so results can be inspected and diffed
    across runs.
    """

    def __init__(self, log_dir: str = "src/evaluation/results/experiments"):
        """
        Args:
            log_dir: Directory where experiment logs are stored; created
                (with parents) if it does not exist.
        """
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        self.log_file = self.log_dir / "experiments_log.json"
        self.summary_file = self.log_dir / "experiments_summary.csv"

        # Seed an empty log on first use so later reads never hit a missing file.
        if not self.log_file.exists():
            self._save_log([])

    def log_experiment(
        self,
        experiment_name: str,
        config: Dict[str, Any],
        metrics: Dict[str, float],
        langsmith_url: Optional[str] = None,
        notes: str = ""
    ) -> None:
        """Persist a single experiment run and refresh the CSV summary.

        Args:
            experiment_name: Experiment label (e.g. "baseline", "embedding-small").
            config: Settings used (embedding model, top-k, etc.).
            metrics: Evaluation metrics (precision, recall, etc.).
            langsmith_url: Optional LangSmith result URL.
            notes: Free-form memo.
        """
        experiment_data = {
            # NOTE(review): naive local time; switch to an aware timestamp
            # (datetime.now(tz=...)) if runs span machines/timezones.
            "timestamp": datetime.now().isoformat(),
            "experiment_name": experiment_name,
            "config": config,
            "metrics": metrics,
            "langsmith_url": langsmith_url,
            "notes": notes
        }

        logs = self._load_log()
        logs.append(experiment_data)
        self._save_log(logs)
        self._update_summary()

        # FIX: these messages were split mid-character in the original file
        # (broken multi-byte emoji); rejoined onto single lines.
        print(f"โ ์คํ '{experiment_name}' ์ ์ฅ ์๋ฃ")
        print(f" Precision: {metrics.get('precision', 0):.4f}")
        print(f" Recall: {metrics.get('recall', 0):.4f}")

    def compare_experiments(
        self,
        experiment_names: Optional[List[str]] = None,
        top_n: int = 5
    ) -> pd.DataFrame:
        """Build and print a side-by-side comparison of logged experiments.

        Args:
            experiment_names: Names to compare; when None, the most recent
                ``top_n`` runs are compared.
            top_n: How many recent runs to compare when no names are given.

        Returns:
            Comparison table as a DataFrame (empty when nothing matches).
        """
        logs = self._load_log()

        if not logs:
            print("โ ๏ธ ์ ์ฅ๋ ์คํ์ด ์์ต๋๋ค")
            return pd.DataFrame()

        if experiment_names is None:
            # No explicit selection: take the most recent runs.
            selected_logs = logs[-top_n:]
        else:
            selected_logs = [
                log for log in logs
                if log['experiment_name'] in experiment_names
            ]

        if not selected_logs:
            print("โ ๏ธ ๋น๊ตํ ์คํ์ ์ฐพ์ ์ ์์ต๋๋ค")
            return pd.DataFrame()

        comparison_data = []
        for log in selected_logs:
            row = {
                "์คํ๋ช": log['experiment_name'],
                "๋ ์ง": log['timestamp'][:10],
                "์๋ฒ ๋ฉ": log['config'].get('embedding_model', 'N/A'),
                "Top-K": log['config'].get('top_k', 'N/A'),
                "Precision": log['metrics'].get('precision', 0),
                "Recall": log['metrics'].get('recall', 0),
                # F1 is derived on the fly so older logs without it still compare.
                "F1": self._calculate_f1(
                    log['metrics'].get('precision', 0),
                    log['metrics'].get('recall', 0)
                ),
                "๊ฒ์์๊ฐ(์ด)": log['metrics'].get('avg_time', 0)
            }
            comparison_data.append(row)

        df = pd.DataFrame(comparison_data)

        print("\n" + "=" * 80)
        print("๐ ์คํ ๋น๊ต ๊ฒฐ๊ณผ")
        print("=" * 80)
        print(df.to_string(index=False))
        print("=" * 80)

        return df

    def show_improvement(self, baseline_name: str, current_name: str) -> None:
        """Print the relative improvement of one experiment over a baseline.

        Args:
            baseline_name: Name of the reference experiment.
            current_name: Name of the experiment to compare against it.
        """
        logs = self._load_log()

        # First match wins if a name was logged more than once.
        baseline = next((log for log in logs if log['experiment_name'] == baseline_name), None)
        current = next((log for log in logs if log['experiment_name'] == current_name), None)

        if not baseline or not current:
            print("โ ๏ธ ์คํ์ ์ฐพ์ ์ ์์ต๋๋ค")
            return

        baseline_precision = baseline['metrics'].get('precision', 0)
        baseline_recall = baseline['metrics'].get('recall', 0)

        current_precision = current['metrics'].get('precision', 0)
        current_recall = current['metrics'].get('recall', 0)

        # Guard against division by zero when the baseline metric is 0.
        precision_improvement = (current_precision - baseline_precision) / baseline_precision * 100 if baseline_precision > 0 else 0
        recall_improvement = (current_recall - baseline_recall) / baseline_recall * 100 if baseline_recall > 0 else 0

        print("\n" + "=" * 80)
        print(f"๐ ๊ฐ์ ํจ๊ณผ: {baseline_name} โ {current_name}")
        print("=" * 80)
        print(f"\nPrecision:")
        print(f" {baseline_name}: {baseline_precision:.4f}")
        print(f" {current_name}: {current_precision:.4f}")
        # FIX: rejoined split string literal from the original file.
        print(f" ๊ฐ์ ์จ: {precision_improvement:+.2f}% {'โ' if precision_improvement > 0 else 'โ'}")

        print(f"\nRecall:")
        print(f" {baseline_name}: {baseline_recall:.4f}")
        print(f" {current_name}: {current_recall:.4f}")
        print(f" ๊ฐ์ ์จ: {recall_improvement:+.2f}% {'โ' if recall_improvement > 0 else 'โ'}")

        print("\n" + "=" * 80)

    def plot_metrics(
        self,
        experiment_names: Optional[List[str]] = None,
        save_path: Optional[str] = None
    ) -> None:
        """Render a grouped bar chart of precision/recall per experiment.

        The chart is always written to disk (Agg backend — no display);
        when ``save_path`` is None it goes to ``<log_dir>/comparison_chart.png``.

        Args:
            experiment_names: Experiments to include (None means all).
            save_path: Output path for the chart image.
        """
        logs = self._load_log()

        if not logs:
            print("โ ๏ธ ์ ์ฅ๋ ์คํ์ด ์์ต๋๋ค")
            return

        if experiment_names is not None:
            logs = [log for log in logs if log['experiment_name'] in experiment_names]

        if not logs:
            print("โ ๏ธ ์ฐจํธ๋ฅผ ๊ทธ๋ฆด ์คํ์ด ์์ต๋๋ค")
            return

        names = [log['experiment_name'] for log in logs]
        precisions = [log['metrics'].get('precision', 0) for log in logs]
        recalls = [log['metrics'].get('recall', 0) for log in logs]

        fig, ax = plt.subplots(figsize=(12, 6))

        x = range(len(names))
        width = 0.35

        # Two bars per experiment, offset around each tick.
        ax.bar([i - width / 2 for i in x], precisions, width, label='Precision', alpha=0.8)
        ax.bar([i + width / 2 for i in x], recalls, width, label='Recall', alpha=0.8)

        ax.set_xlabel('์คํ')
        ax.set_ylabel('์ ์')
        ax.set_title('์คํ๋ณ ์ฑ๋ฅ ๋น๊ต')
        ax.set_xticks(x)
        ax.set_xticklabels(names, rotation=45, ha='right')
        ax.legend()
        ax.grid(axis='y', alpha=0.3)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"โ ์ฐจํธ ์ ์ฅ: {save_path}")
        else:
            default_path = self.log_dir / "comparison_chart.png"
            plt.savefig(default_path, dpi=300, bbox_inches='tight')
            print(f"โ ์ฐจํธ ์ ์ฅ: {default_path}")

        # Release the figure so repeated calls don't accumulate memory.
        plt.close()

    def recommend_best(self, metric: str = "f1") -> Dict[str, Any]:
        """Find and print the best experiment by a given metric.

        Args:
            metric: Ranking metric ("precision", "recall", or "f1").

        Returns:
            The best experiment's log entry, or {} when no logs exist.
        """
        logs = self._load_log()

        if not logs:
            print("โ ๏ธ ์ ์ฅ๋ ์คํ์ด ์์ต๋๋ค")
            return {}

        # Backfill F1 in memory (not persisted) so ranking by "f1" works
        # even for entries logged without it.
        for log in logs:
            if 'f1' not in log['metrics']:
                p = log['metrics'].get('precision', 0)
                r = log['metrics'].get('recall', 0)
                log['metrics']['f1'] = self._calculate_f1(p, r)

        best = max(logs, key=lambda x: x['metrics'].get(metric, 0))

        print("\n" + "=" * 80)
        print(f"๐ ์ต์ ์ค์ ({metric.upper()} ๊ธฐ์ค)")
        print("=" * 80)
        # FIX: rejoined split string literal from the original file.
        print(f"์คํ๋ช: {best['experiment_name']}")
        print(f"๋ ์ง: {best['timestamp'][:10]}")
        print(f"\n์ค์ :")
        for key, value in best['config'].items():
            print(f" {key}: {value}")
        print(f"\n์ฑ๋ฅ:")
        print(f" Precision: {best['metrics'].get('precision', 0):.4f}")
        print(f" Recall: {best['metrics'].get('recall', 0):.4f}")
        print(f" F1: {best['metrics'].get('f1', 0):.4f}")
        print("=" * 80)

        return best

    def list_experiments(self) -> None:
        """Print a numbered list of all saved experiments."""
        logs = self._load_log()

        if not logs:
            print("โ ๏ธ ์ ์ฅ๋ ์คํ์ด ์์ต๋๋ค")
            return

        print("\n" + "=" * 80)
        print("๐ ์ ์ฅ๋ ์คํ ๋ชฉ๋ก")
        print("=" * 80)

        for i, log in enumerate(logs, 1):
            print(f"\n{i}. {log['experiment_name']}")
            print(f" ๋ ์ง: {log['timestamp'][:10]}")
            print(f" Precision: {log['metrics'].get('precision', 0):.4f}")
            print(f" Recall: {log['metrics'].get('recall', 0):.4f}")

        print("=" * 80)

    def clear_experiments(self) -> None:
        """Delete all experiment logs after an interactive confirmation."""
        confirm = input("โ ๏ธ ๋ชจ๋ ์คํ ๋ก๊ทธ๋ฅผ ์ญ์ ํ์๊ฒ ์ต๋๊น? (yes/no): ")
        if confirm.lower() == 'yes':
            self._save_log([])
            # _update_summary now also removes the stale CSV for empty logs.
            self._update_summary()
            print("โ ๋ชจ๋ ์คํ ๋ก๊ทธ ์ญ์ ์๋ฃ")
        else:
            print("โ ์ทจ์๋จ")

    def _load_log(self) -> List[Dict]:
        """Load the JSON log; return [] when the file does not exist."""
        if not self.log_file.exists():
            return []

        with open(self.log_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _save_log(self, logs: List[Dict]) -> None:
        """Write the full log list back to the JSON file."""
        with open(self.log_file, 'w', encoding='utf-8') as f:
            json.dump(logs, f, indent=2, ensure_ascii=False)

    def _update_summary(self) -> None:
        """Rebuild the CSV summary from the JSON log.

        FIX: the original returned early when the log was empty, which left
        a stale CSV behind after ``clear_experiments``. The summary file is
        now deleted in that case so it always mirrors the log.
        """
        logs = self._load_log()

        if not logs:
            self.summary_file.unlink(missing_ok=True)
            return

        summary_data = []
        for log in logs:
            row = {
                "timestamp": log['timestamp'],
                "experiment_name": log['experiment_name'],
                "embedding_model": log['config'].get('embedding_model', 'N/A'),
                "top_k": log['config'].get('top_k', 'N/A'),
                "precision": log['metrics'].get('precision', 0),
                "recall": log['metrics'].get('recall', 0),
                "f1": self._calculate_f1(
                    log['metrics'].get('precision', 0),
                    log['metrics'].get('recall', 0)
                ),
                "avg_time": log['metrics'].get('avg_time', 0)
            }
            summary_data.append(row)

        df = pd.DataFrame(summary_data)
        # utf-8-sig so Excel opens the Korean headers correctly.
        df.to_csv(self.summary_file, index=False, encoding='utf-8-sig')

    @staticmethod
    def _calculate_f1(precision: float, recall: float) -> float:
        """Return the harmonic mean of precision and recall (0.0 when both are 0)."""
        if precision + recall == 0:
            # FIX: return a float to match the annotated return type.
            return 0.0
        return 2 * (precision * recall) / (precision + recall)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
tracker = ExperimentTracker() |
|
|
|
|
|
|
|
|
tracker.log_experiment( |
|
|
experiment_name="baseline", |
|
|
config={ |
|
|
"embedding_model": "text-embedding-3-small", |
|
|
"top_k": 5, |
|
|
"chunk_size": 1000 |
|
|
}, |
|
|
metrics={ |
|
|
"precision": 0.30, |
|
|
"recall": 0.65, |
|
|
"avg_time": 0.41 |
|
|
}, |
|
|
notes="์ด๊ธฐ baseline ์คํ" |
|
|
) |
|
|
|
|
|
|
|
|
tracker.compare_experiments() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|