# RFP_summary_chatbot / src/evaluation/experiment_tracker.py
# (Hugging Face Spaces page residue, preserved as comments:
#  contributor "Dongjin1203", commit "Initial commit for HF Spaces deployment",
#  revision 4739096)
# ===== experiment_tracker.py =====
"""
RAG ๊ฒ€์ƒ‰ ์‹œ์Šคํ…œ ์‹คํ—˜ ์ถ”์  ๋ฐ ๋น„๊ต ๋„๊ตฌ
๊ธฐ๋Šฅ:
1. ์‹คํ—˜ ๊ฒฐ๊ณผ ์ž๋™ ์ €์žฅ
2. ์ด์ „ ์‹คํ—˜๊ณผ ๋น„๊ต
3. ์„ฑ๋Šฅ ์ฐจํŠธ ์ƒ์„ฑ
4. ์ตœ์  ์„ค์ • ์ถ”์ฒœ
"""
import json
import pandas as pd
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Any, Optional
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg') # ์„œ๋ฒ„ ํ™˜๊ฒฝ ๋Œ€์‘
class ExperimentTracker:
    """Track, compare, visualize, and rank RAG retrieval experiments.

    Every experiment run (config + metrics) is appended to a JSON log file
    and mirrored into a CSV summary under ``log_dir``, so results persist
    across process restarts.
    """

    def __init__(self, log_dir: str = "src/evaluation/results/experiments"):
        """
        Args:
            log_dir: Directory where experiment logs are stored
                (created on demand, including parents).
        """
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        self.log_file = self.log_dir / "experiments_log.json"
        self.summary_file = self.log_dir / "experiments_summary.csv"

        # Initialize the log file on first use so later reads never fail.
        if not self.log_file.exists():
            self._save_log([])

    # === 1. Saving experiment results ===
    def log_experiment(
        self,
        experiment_name: str,
        config: Dict[str, Any],
        metrics: Dict[str, float],
        langsmith_url: Optional[str] = None,
        notes: str = ""
    ) -> None:
        """Append one experiment run to the persistent log.

        Args:
            experiment_name: Experiment name (e.g. "baseline", "embedding-small").
            config: Configuration info (embedding model, Top-K, etc.).
            metrics: Evaluation metrics (precision, recall, etc.).
            langsmith_url: Optional LangSmith result URL.
            notes: Free-form memo.
        """
        experiment_data = {
            "timestamp": datetime.now().isoformat(),
            "experiment_name": experiment_name,
            "config": config,
            "metrics": metrics,
            "langsmith_url": langsmith_url,
            "notes": notes
        }

        # Load existing log, append the new run, then persist both
        # the JSON log and the CSV summary.
        logs = self._load_log()
        logs.append(experiment_data)
        self._save_log(logs)
        self._update_summary()

        print(f"✅ 실험 '{experiment_name}' 저장 완료")
        print(f" Precision: {metrics.get('precision', 0):.4f}")
        print(f" Recall: {metrics.get('recall', 0):.4f}")

    # === 2. Comparing experiments ===
    def compare_experiments(
        self,
        experiment_names: Optional[List[str]] = None,
        top_n: int = 5
    ) -> pd.DataFrame:
        """Compare stored experiments and print a side-by-side table.

        Args:
            experiment_names: Experiment names to compare
                (None = most recent runs).
            top_n: How many recent runs to compare when
                experiment_names is None.

        Returns:
            Comparison table as a DataFrame (empty if nothing matched).
        """
        logs = self._load_log()
        if not logs:
            print("⚠️ 저장된 실험이 없습니다")
            return pd.DataFrame()

        # Select which runs to compare.
        if experiment_names is None:
            # Most recent N runs.
            selected_logs = logs[-top_n:]
        else:
            # Only the explicitly named runs.
            selected_logs = [
                log for log in logs
                if log['experiment_name'] in experiment_names
            ]

        if not selected_logs:
            print("⚠️ 비교할 실험을 찾을 수 없습니다")
            return pd.DataFrame()

        # Build one row per run; F1 is derived rather than stored.
        comparison_data = []
        for log in selected_logs:
            row = {
                "실험명": log['experiment_name'],
                "날짜": log['timestamp'][:10],
                "임베딩": log['config'].get('embedding_model', 'N/A'),
                "Top-K": log['config'].get('top_k', 'N/A'),
                "Precision": log['metrics'].get('precision', 0),
                "Recall": log['metrics'].get('recall', 0),
                "F1": self._calculate_f1(
                    log['metrics'].get('precision', 0),
                    log['metrics'].get('recall', 0)
                ),
                "검색시간(초)": log['metrics'].get('avg_time', 0)
            }
            comparison_data.append(row)

        df = pd.DataFrame(comparison_data)

        print("\n" + "="*80)
        print("📊 실험 비교 결과")
        print("="*80)
        print(df.to_string(index=False))
        print("="*80)
        return df

    def show_improvement(self, baseline_name: str, current_name: str) -> None:
        """Print the improvement of one experiment over a baseline.

        Args:
            baseline_name: Name of the reference experiment.
            current_name: Name of the experiment to compare against it.
        """
        logs = self._load_log()

        # First run matching each name (None if absent).
        baseline = next((log for log in logs if log['experiment_name'] == baseline_name), None)
        current = next((log for log in logs if log['experiment_name'] == current_name), None)

        if not baseline or not current:
            print("⚠️ 실험을 찾을 수 없습니다")
            return

        baseline_precision = baseline['metrics'].get('precision', 0)
        baseline_recall = baseline['metrics'].get('recall', 0)
        current_precision = current['metrics'].get('precision', 0)
        current_recall = current['metrics'].get('recall', 0)

        # Relative improvement in percent; guard against division by zero.
        precision_improvement = (current_precision - baseline_precision) / baseline_precision * 100 if baseline_precision > 0 else 0
        recall_improvement = (current_recall - baseline_recall) / baseline_recall * 100 if baseline_recall > 0 else 0

        print("\n" + "="*80)
        print(f"📈 개선 효과: {baseline_name} → {current_name}")
        print("="*80)
        print(f"\nPrecision:")
        print(f" {baseline_name}: {baseline_precision:.4f}")
        print(f" {current_name}: {current_precision:.4f}")
        print(f" 개선율: {precision_improvement:+.2f}% {'✅' if precision_improvement > 0 else '❌'}")
        print(f"\nRecall:")
        print(f" {baseline_name}: {baseline_recall:.4f}")
        print(f" {current_name}: {current_recall:.4f}")
        print(f" 개선율: {recall_improvement:+.2f}% {'✅' if recall_improvement > 0 else '❌'}")
        print("\n" + "="*80)

    # === 3. Visualization ===
    def plot_metrics(
        self,
        experiment_names: Optional[List[str]] = None,
        save_path: Optional[str] = None
    ) -> None:
        """Render a grouped bar chart of precision/recall per experiment.

        Args:
            experiment_names: Experiments to include (None = all).
            save_path: Where to save the chart; if None, saves to a
                default path inside ``log_dir`` (Agg backend — never
                shown on screen).
        """
        logs = self._load_log()
        if not logs:
            print("⚠️ 저장된 실험이 없습니다")
            return

        # Optional filter by experiment name.
        if experiment_names is not None:
            logs = [log for log in logs if log['experiment_name'] in experiment_names]

        if not logs:
            print("⚠️ 차트를 그릴 실험이 없습니다")
            return

        names = [log['experiment_name'] for log in logs]
        precisions = [log['metrics'].get('precision', 0) for log in logs]
        recalls = [log['metrics'].get('recall', 0) for log in logs]

        # Two bars per experiment, offset around each tick.
        fig, ax = plt.subplots(figsize=(12, 6))
        x = range(len(names))
        width = 0.35
        ax.bar([i - width/2 for i in x], precisions, width, label='Precision', alpha=0.8)
        ax.bar([i + width/2 for i in x], recalls, width, label='Recall', alpha=0.8)
        ax.set_xlabel('실험')
        ax.set_ylabel('점수')
        ax.set_title('실험별 성능 비교')
        ax.set_xticks(x)
        ax.set_xticklabels(names, rotation=45, ha='right')
        ax.legend()
        ax.grid(axis='y', alpha=0.3)
        plt.tight_layout()

        # Save to the requested path, or to a default inside log_dir.
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"✅ 차트 저장: {save_path}")
        else:
            default_path = self.log_dir / "comparison_chart.png"
            plt.savefig(default_path, dpi=300, bbox_inches='tight')
            print(f"✅ 차트 저장: {default_path}")
        plt.close()

    # === 4. Best-configuration recommendation ===
    def recommend_best(self, metric: str = "f1") -> Dict[str, Any]:
        """Find and print the best experiment by a given metric.

        Args:
            metric: Ranking metric ("precision", "recall", "f1").

        Returns:
            The best experiment's log entry (empty dict if no logs).
        """
        logs = self._load_log()
        if not logs:
            print("⚠️ 저장된 실험이 없습니다")
            return {}

        # Derive F1 for entries that lack it (in-memory only; the
        # persisted log file is not modified here).
        for log in logs:
            if 'f1' not in log['metrics']:
                p = log['metrics'].get('precision', 0)
                r = log['metrics'].get('recall', 0)
                log['metrics']['f1'] = self._calculate_f1(p, r)

        best = max(logs, key=lambda x: x['metrics'].get(metric, 0))

        print("\n" + "="*80)
        print(f"🏆 최적 설정 ({metric.upper()} 기준)")
        print("="*80)
        print(f"실험명: {best['experiment_name']}")
        print(f"날짜: {best['timestamp'][:10]}")
        print(f"\n설정:")
        for key, value in best['config'].items():
            print(f" {key}: {value}")
        print(f"\n성능:")
        print(f" Precision: {best['metrics'].get('precision', 0):.4f}")
        print(f" Recall: {best['metrics'].get('recall', 0):.4f}")
        print(f" F1: {best['metrics'].get('f1', 0):.4f}")
        print("="*80)
        return best

    # === 5. Utilities ===
    def list_experiments(self) -> None:
        """Print a numbered list of all stored experiments."""
        logs = self._load_log()
        if not logs:
            print("⚠️ 저장된 실험이 없습니다")
            return

        print("\n" + "="*80)
        print("📋 저장된 실험 목록")
        print("="*80)
        for i, log in enumerate(logs, 1):
            print(f"\n{i}. {log['experiment_name']}")
            print(f" 날짜: {log['timestamp'][:10]}")
            print(f" Precision: {log['metrics'].get('precision', 0):.4f}")
            print(f" Recall: {log['metrics'].get('recall', 0):.4f}")
        print("="*80)

    def clear_experiments(self) -> None:
        """Delete ALL experiment logs after interactive confirmation (careful!)."""
        confirm = input("⚠️ 모든 실험 로그를 삭제하시겠습니까? (yes/no): ")
        if confirm.lower() == 'yes':
            self._save_log([])
            self._update_summary()
            print("✅ 모든 실험 로그 삭제 완료")
        else:
            print("❌ 취소됨")

    # === Internal helpers ===
    def _load_log(self) -> List[Dict]:
        """Load the JSON log file; returns [] when it does not exist."""
        if not self.log_file.exists():
            return []
        with open(self.log_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _save_log(self, logs: List[Dict]) -> None:
        """Write the full log list to the JSON log file."""
        with open(self.log_file, 'w', encoding='utf-8') as f:
            json.dump(logs, f, indent=2, ensure_ascii=False)

    def _update_summary(self) -> None:
        """Rebuild the CSV summary from the JSON log.

        Bugfix: when the log is empty (e.g. right after
        ``clear_experiments``), the stale summary CSV used to be left on
        disk because this method returned early. It is now removed so
        the summary always reflects the log.
        """
        logs = self._load_log()
        if not logs:
            # Remove any outdated summary instead of leaving it behind.
            self.summary_file.unlink(missing_ok=True)
            return

        summary_data = []
        for log in logs:
            row = {
                "timestamp": log['timestamp'],
                "experiment_name": log['experiment_name'],
                "embedding_model": log['config'].get('embedding_model', 'N/A'),
                "top_k": log['config'].get('top_k', 'N/A'),
                "precision": log['metrics'].get('precision', 0),
                "recall": log['metrics'].get('recall', 0),
                "f1": self._calculate_f1(
                    log['metrics'].get('precision', 0),
                    log['metrics'].get('recall', 0)
                ),
                "avg_time": log['metrics'].get('avg_time', 0)
            }
            summary_data.append(row)

        df = pd.DataFrame(summary_data)
        # utf-8-sig so Excel opens the Korean-free CSV headers correctly.
        df.to_csv(self.summary_file, index=False, encoding='utf-8-sig')

    @staticmethod
    def _calculate_f1(precision: float, recall: float) -> float:
        """Harmonic mean of precision and recall; 0.0 when both are 0.

        Bugfix: the zero branch now returns float 0.0 instead of int 0
        for a consistent return type.
        """
        if precision + recall == 0:
            return 0.0
        return 2 * (precision * recall) / (precision + recall)
# ===== ์‚ฌ์šฉ ์˜ˆ์‹œ =====
if __name__ == "__main__":
    # Demo: exercise the tracker end to end.
    demo_tracker = ExperimentTracker()

    # Example 1: record a single experiment run.
    demo_config = {
        "embedding_model": "text-embedding-3-small",
        "top_k": 5,
        "chunk_size": 1000,
    }
    demo_metrics = {
        "precision": 0.30,
        "recall": 0.65,
        "avg_time": 0.41,
    }
    demo_tracker.log_experiment(
        experiment_name="baseline",
        config=demo_config,
        metrics=demo_metrics,
        notes="초기 baseline 실험",
    )

    # Example 2: compare the stored experiments.
    demo_tracker.compare_experiments()

    # Example 3: improvement relative to the baseline.
    # demo_tracker.show_improvement("baseline", "embedding-small")

    # Example 4: generate the comparison chart.
    # demo_tracker.plot_metrics()

    # Example 5: recommend the best configuration.
    # demo_tracker.recommend_best(metric="f1")