# Author: Solareva Taisia
# chore(release): initial public snapshot
# Commit: 198ccb0
"""Systematic model comparison framework."""
import logging
import json
from typing import Dict, List, Optional, Any, Tuple, Callable
from pathlib import Path
import pandas as pd
import torch
from torch.utils.data import Dataset
from evaluation.metrics import (
precision,
recall,
f1_score,
exact_match,
get_predict,
per_class_metrics
)
from experiments.experiment_tracker import ExperimentTracker
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ModelComparison:
    """
    Systematic comparison of multiple model architectures.

    Trains and evaluates multiple models on the same dataset,
    tracks results, and generates comparison reports.
    """

    def __init__(
        self,
        tracker: Optional[ExperimentTracker] = None,
        results_dir: str = "experiments/comparisons"
    ):
        """
        Initialize model comparison framework.

        Args:
            tracker: ExperimentTracker instance (creates new if None)
            results_dir: Directory to store comparison results
        """
        self.tracker = ExperimentTracker() if tracker is None else tracker
        self.results_dir = Path(results_dir)
        self.results_dir.mkdir(parents=True, exist_ok=True)
        # Accumulated result rows from every compare_models() call; used by
        # get_best_model() when no explicit DataFrame is passed in.
        self.comparison_results: List[Dict[str, Any]] = []

    def compare_models(
        self,
        models_config: List[Dict[str, Any]],
        train_dataset: Dataset,
        val_dataset: Dataset,
        test_dataset: Optional[Dataset] = None,
        train_func: Optional[Callable] = None,
        epochs: int = 3,
        batch_size: int = 16
    ) -> pd.DataFrame:
        """
        Compare multiple models on the same datasets.

        Args:
            models_config: List of model configurations.
                Each config should have: model_name, model_class, model_kwargs
            train_dataset: Training dataset
            val_dataset: Validation dataset
            test_dataset: Optional test dataset
            train_func: Optional custom training function. If omitted, no
                training (and therefore no evaluation) happens for any model.
            epochs: Number of training epochs
            batch_size: Batch size for training

        Returns:
            DataFrame with one row per model. Successful rows carry the
            evaluation metrics and status="completed"; failed rows carry
            status="failed" plus an "error" message.

        Example:
            >>> comparison = ModelComparison()
            >>> models = [
            ...     {
            ...         "model_name": "RussianBERT",
            ...         "model_class": RussianNewsClassifier,
            ...         "model_kwargs": {"num_labels": 100}
            ...     },
            ...     {
            ...         "model_name": "RoBERTa",
            ...         "model_class": RoBERTaNewsClassifier,
            ...         "model_kwargs": {"num_labels": 100}
            ...     }
            ... ]
            >>> results = comparison.compare_models(models, train_ds, val_ds)
        """
        logger.info("=" * 80)
        logger.info("Starting Model Comparison")
        logger.info("=" * 80)
        logger.info(f"Comparing {len(models_config)} models")

        results = []
        for i, model_config in enumerate(models_config, 1):
            model_name = model_config.get("model_name", f"model_{i}")
            logger.info(f"\n{'=' * 80}")
            logger.info(f"Model {i}/{len(models_config)}: {model_name}")
            logger.info(f"{'=' * 80}")
            try:
                # Register the run with the tracker before training so a
                # crash mid-training is still attributable to an experiment.
                experiment_id = self.tracker.start_experiment(
                    experiment_name=f"comparison_{model_name}",
                    model_name=model_name,
                    config={
                        "epochs": epochs,
                        "batch_size": batch_size,
                        **model_config.get("model_kwargs", {})
                    },
                    tags=["model_comparison", model_name]
                )

                # Train model (if train_func provided)
                model = None
                if train_func:
                    logger.info(f"Training {model_name}...")
                    model = train_func(
                        model_config=model_config,
                        train_dataset=train_dataset,
                        val_dataset=val_dataset,
                        epochs=epochs,
                        batch_size=batch_size
                    )
                else:
                    logger.warning("No training function provided, skipping training")

                # Bind both metric dicts up front. Previously val_metrics was
                # only assigned inside the branch below, so skipping evaluation
                # (model is None) raised NameError at the result-row build and
                # the run was wrongly reported as "failed".
                val_metrics: Dict[str, float] = {}
                test_metrics: Dict[str, float] = {}

                # Evaluate on validation set
                if model and val_dataset:
                    val_metrics = self._evaluate_model(
                        model,
                        val_dataset,
                        model_config.get("use_snippet", False),
                        prefix="val_"
                    )
                    self.tracker.log_metrics(experiment_id, val_metrics)
                    logger.info(f"Validation metrics: {val_metrics}")

                # Evaluate on test set
                if model and test_dataset:
                    test_metrics = self._evaluate_model(
                        model,
                        test_dataset,
                        model_config.get("use_snippet", False),
                        prefix="test_"
                    )
                    self.tracker.log_metrics(experiment_id, test_metrics)
                    logger.info(f"Test metrics: {test_metrics}")

                # Finish experiment
                self.tracker.finish_experiment(experiment_id, test_metrics)

                # Store results (metric keys are already prefixed val_/test_,
                # so the two dicts cannot collide when merged into one row)
                result = {
                    "model_name": model_name,
                    "experiment_id": experiment_id,
                    **val_metrics,
                    **test_metrics,
                    "status": "completed"
                }
                results.append(result)
                self.comparison_results.append(result)
            except Exception as e:
                # Keep going: one broken model must not abort the whole sweep.
                logger.error(f"Error comparing {model_name}: {e}")
                result = {
                    "model_name": model_name,
                    "status": "failed",
                    "error": str(e)
                }
                results.append(result)

        # Create comparison DataFrame
        comparison_df = pd.DataFrame(results)

        # Save comparison results
        comparison_file = self.results_dir / f"comparison_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.csv"
        comparison_df.to_csv(comparison_file, index=False)
        logger.info(f"\nComparison results saved to: {comparison_file}")

        # Generate comparison report
        self._generate_report(comparison_df)
        return comparison_df

    def _evaluate_model(
        self,
        model: torch.nn.Module,
        dataset: Dataset,
        use_snippet: bool = False,
        prefix: str = ""
    ) -> Dict[str, float]:
        """
        Evaluate model on dataset.

        Args:
            model: Trained model
            dataset: Dataset to evaluate on
            use_snippet: Whether model uses snippets
            prefix: Prefix for metric names (e.g., "val_", "test_")

        Returns:
            Dictionary of metrics (precision, recall, f1, exact_match),
            each key prepended with ``prefix``.
        """
        # Get predicted probabilities and targets from the shared helper.
        pred_prob, target = get_predict(model, dataset, use_snippet)

        # Use default threshold for now (can be optimized per model/class).
        threshold = 0.5
        y_pred = (pred_prob > threshold).float()

        # Calculate metrics
        metrics = {
            f"{prefix}precision": precision(target, y_pred),
            f"{prefix}recall": recall(target, y_pred),
            f"{prefix}f1": f1_score(target, y_pred),
            f"{prefix}exact_match": exact_match(target, y_pred)
        }
        return metrics

    def _generate_report(self, comparison_df: pd.DataFrame) -> None:
        """
        Generate a plain-text comparison report next to the CSV results.

        Args:
            comparison_df: DataFrame with comparison results
        """
        report_file = self.results_dir / f"report_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.txt"
        with open(report_file, 'w') as f:
            f.write("=" * 80 + "\n")
            f.write("MODEL COMPARISON REPORT\n")
            f.write("=" * 80 + "\n\n")
            f.write(f"Generated: {pd.Timestamp.now()}\n")
            f.write(f"Models compared: {len(comparison_df)}\n\n")

            # Overall metrics: any column whose name contains a metric token.
            f.write("OVERALL METRICS\n")
            f.write("-" * 80 + "\n")
            metric_cols = [col for col in comparison_df.columns if any(
                m in col for m in ["precision", "recall", "f1", "exact_match"]
            )]
            for metric_col in metric_cols:
                f.write(f"\n{metric_col.upper()}:\n")
                # Fixed: sort_values has no `na_last` keyword; the correct
                # spelling is na_position="last" (failed rows have NaN metrics
                # and should sink to the bottom of each ranking).
                sorted_df = comparison_df.sort_values(
                    metric_col, ascending=False, na_position="last"
                )
                for _, row in sorted_df.iterrows():
                    model_name = row.get("model_name", "Unknown")
                    value = row.get(metric_col, "N/A")
                    f.write(f"  {model_name}: {value}\n")

            # Best model (by validation F1, when that column exists)
            f.write("\n" + "=" * 80 + "\n")
            f.write("BEST MODEL\n")
            f.write("=" * 80 + "\n")
            if "val_f1" in comparison_df.columns:
                best = comparison_df.nlargest(1, "val_f1")
                if not best.empty:
                    best_model = best.iloc[0]
                    f.write(f"Model: {best_model['model_name']}\n")
                    f.write(f"Validation F1: {best_model.get('val_f1', 'N/A')}\n")
                    f.write(f"Validation Precision: {best_model.get('val_precision', 'N/A')}\n")
                    f.write(f"Validation Recall: {best_model.get('val_recall', 'N/A')}\n")
        logger.info(f"Comparison report saved to: {report_file}")

    def get_best_model(
        self,
        metric_name: str = "val_f1",
        comparison_df: Optional[pd.DataFrame] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Get best model from comparison.

        Args:
            metric_name: Metric to use for selection (higher is better)
            comparison_df: Optional comparison DataFrame (uses stored if None)

        Returns:
            Dictionary with best model information, or None when no usable
            results exist (nothing compared yet, unknown metric, or all
            models failed).
        """
        if comparison_df is None:
            if not self.comparison_results:
                logger.warning("No comparison results available")
                return None
            comparison_df = pd.DataFrame(self.comparison_results)
        if metric_name not in comparison_df.columns:
            logger.warning(f"Metric {metric_name} not found in comparison results")
            return None

        # Only rank models that actually completed evaluation.
        completed = comparison_df[comparison_df["status"] == "completed"]
        if completed.empty:
            logger.warning("No completed models found")
            return None

        # Get best model (nlargest keeps a single top row)
        best = completed.nlargest(1, metric_name)
        if best.empty:
            return None
        best_model = best.iloc[0].to_dict()
        logger.info(f"Best model: {best_model['model_name']} ({metric_name}={best_model.get(metric_name, 'N/A')})")
        return best_model

    def compare_from_checkpoints(
        self,
        checkpoint_paths: List[Dict[str, str]],
        test_dataset: Dataset,
        model_classes: Dict[str, type]
    ) -> pd.DataFrame:
        """
        Compare models from saved checkpoints.

        Args:
            checkpoint_paths: List of dicts with model_name and checkpoint_path
            test_dataset: Test dataset for evaluation
            model_classes: Dictionary mapping model_name to model class

        Returns:
            DataFrame with comparison results (test_* metrics per checkpoint)
        """
        logger.info("=" * 80)
        logger.info("Comparing Models from Checkpoints")
        logger.info("=" * 80)

        results = []
        for checkpoint_info in checkpoint_paths:
            model_name = checkpoint_info["model_name"]
            checkpoint_path = checkpoint_info["checkpoint_path"]
            logger.info(f"\nEvaluating {model_name} from {checkpoint_path}")
            try:
                # NOTE(security): torch.load deserializes via pickle — only
                # load checkpoints from trusted sources.
                checkpoint = torch.load(checkpoint_path, map_location="cpu")

                # Get model class; skip checkpoints we cannot reconstruct.
                model_class = model_classes.get(model_name)
                if model_class is None:
                    logger.warning(f"Model class not found for {model_name}, skipping")
                    continue

                # Reconstruct model from the hyperparameters stored in the
                # checkpoint, falling back to the training-time defaults.
                model_kwargs = {
                    "num_labels": checkpoint.get("num_labels", 1000),
                    "use_snippet": checkpoint.get("use_snippet", False),
                    "dropout": checkpoint.get("dropout", 0.3)
                }
                if "model_name" in checkpoint:
                    model_kwargs["model_name"] = checkpoint["model_name"]
                model = model_class(**model_kwargs)
                model.load_state_dict(checkpoint["state_dict"])
                model.eval()

                # Evaluate on the held-out test set.
                use_snippet = checkpoint.get("use_snippet", False)
                test_metrics = self._evaluate_model(
                    model,
                    test_dataset,
                    use_snippet,
                    prefix="test_"
                )
                result = {
                    "model_name": model_name,
                    "checkpoint_path": checkpoint_path,
                    **test_metrics,
                    "status": "completed"
                }
                results.append(result)
            except Exception as e:
                # One unloadable checkpoint must not abort the sweep.
                logger.error(f"Error evaluating {model_name}: {e}")
                results.append({
                    "model_name": model_name,
                    "status": "failed",
                    "error": str(e)
                })

        comparison_df = pd.DataFrame(results)

        # Save results
        comparison_file = self.results_dir / f"checkpoint_comparison_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.csv"
        comparison_df.to_csv(comparison_file, index=False)
        return comparison_df