| | """Systematic model comparison framework.""" |
| |
|
| | import logging |
| | import json |
| | from typing import Dict, List, Optional, Any, Tuple, Callable |
| | from pathlib import Path |
| | import pandas as pd |
| | import torch |
| | from torch.utils.data import Dataset |
| |
|
| | from evaluation.metrics import ( |
| | precision, |
| | recall, |
| | f1_score, |
| | exact_match, |
| | get_predict, |
| | per_class_metrics |
| | ) |
| | from experiments.experiment_tracker import ExperimentTracker |
| |
|
| | logging.basicConfig(level=logging.INFO) |
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class ModelComparison:
    """
    Systematic comparison of multiple model architectures.

    Trains and evaluates multiple models on the same dataset,
    tracks results, and generates comparison reports.
    """

    def __init__(
        self,
        tracker: Optional[ExperimentTracker] = None,
        results_dir: str = "experiments/comparisons"
    ):
        """
        Initialize the model comparison framework.

        Args:
            tracker: ExperimentTracker instance (a new one is created if None)
            results_dir: Directory in which to store comparison results
        """
        self.tracker = tracker if tracker is not None else ExperimentTracker()

        self.results_dir = Path(results_dir)
        self.results_dir.mkdir(parents=True, exist_ok=True)

        self.comparison_results = []

    def compare_models(
        self,
        models_config: List[Dict[str, Any]],
        train_dataset: Dataset,
        val_dataset: Dataset,
        test_dataset: Optional[Dataset] = None,
        train_func: Optional[Callable] = None,
        epochs: int = 3,
        batch_size: int = 16
    ) -> pd.DataFrame:
        """
        Compare multiple models on the same datasets.

        Args:
            models_config: List of model configurations.
                Each config should have: model_name, model_class, model_kwargs
            train_dataset: Training dataset
            val_dataset: Validation dataset
            test_dataset: Optional test dataset
            train_func: Optional custom training function
            epochs: Number of training epochs
            batch_size: Batch size for training

        Returns:
            DataFrame with comparison results

        Example:
            >>> comparison = ModelComparison()
            >>> models = [
            ...     {
            ...         "model_name": "RussianBERT",
            ...         "model_class": RussianNewsClassifier,
            ...         "model_kwargs": {"num_labels": 100}
            ...     },
            ...     {
            ...         "model_name": "RoBERTa",
            ...         "model_class": RoBERTaNewsClassifier,
            ...         "model_kwargs": {"num_labels": 100}
            ...     }
            ... ]
            >>> results = comparison.compare_models(models, train_ds, val_ds)
        """
        logger.info("=" * 80)
        logger.info("Starting Model Comparison")
        logger.info("=" * 80)
        logger.info(f"Comparing {len(models_config)} models")

        results = []

        for i, model_config in enumerate(models_config, 1):
            model_name = model_config.get("model_name", f"model_{i}")
            logger.info(f"\n{'=' * 80}")
            logger.info(f"Model {i}/{len(models_config)}: {model_name}")
            logger.info(f"{'=' * 80}")

            try:
                # Register the run with the experiment tracker
                experiment_id = self.tracker.start_experiment(
                    experiment_name=f"comparison_{model_name}",
                    model_name=model_name,
                    config={
                        "epochs": epochs,
                        "batch_size": batch_size,
                        **model_config.get("model_kwargs", {})
                    },
                    tags=["model_comparison", model_name]
                )

                # Train the model (if a training function was provided)
                model = None
                if train_func:
                    logger.info(f"Training {model_name}...")
                    model = train_func(
                        model_config=model_config,
                        train_dataset=train_dataset,
                        val_dataset=val_dataset,
                        epochs=epochs,
                        batch_size=batch_size
                    )
                else:
                    logger.warning("No training function provided, skipping training")

                # Evaluate on the validation set
                val_metrics = {}
                if model is not None and val_dataset is not None:
                    val_metrics = self._evaluate_model(
                        model,
                        val_dataset,
                        model_config.get("use_snippet", False),
                        prefix="val_"
                    )
                    self.tracker.log_metrics(experiment_id, val_metrics)
                    logger.info(f"Validation metrics: {val_metrics}")

                # Evaluate on the test set (if provided)
                test_metrics = {}
                if model is not None and test_dataset is not None:
                    test_metrics = self._evaluate_model(
                        model,
                        test_dataset,
                        model_config.get("use_snippet", False),
                        prefix="test_"
                    )
                    self.tracker.log_metrics(experiment_id, test_metrics)
                    logger.info(f"Test metrics: {test_metrics}")

                # Close out the tracked experiment
                self.tracker.finish_experiment(experiment_id, test_metrics)

                # Record the outcome for this model
                result = {
                    "model_name": model_name,
                    "experiment_id": experiment_id,
                    **val_metrics,
                    **test_metrics,
                    "status": "completed"
                }
                results.append(result)
                self.comparison_results.append(result)

            except Exception as e:
                logger.error(f"Error comparing {model_name}: {e}")
                result = {
                    "model_name": model_name,
                    "status": "failed",
                    "error": str(e)
                }
                results.append(result)

        # Collect all results into a single DataFrame
        comparison_df = pd.DataFrame(results)

        # Persist the results as a timestamped CSV
        comparison_file = self.results_dir / f"comparison_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.csv"
        comparison_df.to_csv(comparison_file, index=False)
        logger.info(f"\nComparison results saved to: {comparison_file}")

        # Write a human-readable text report alongside the CSV
        self._generate_report(comparison_df)

        return comparison_df

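    # ``train_func`` is only defined by its call site above, so the following
    # is an assumed sketch of a compatible implementation (every name in it
    # is hypothetical): it takes the keyword arguments passed by
    # ``compare_models`` and returns a trained ``torch.nn.Module``.
    #
    #     def my_train_func(model_config, train_dataset, val_dataset,
    #                       epochs, batch_size):
    #         model = model_config["model_class"](**model_config["model_kwargs"])
    #         ...  # project-specific training loop
    #         return model
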
    def _evaluate_model(
        self,
        model: torch.nn.Module,
        dataset: Dataset,
        use_snippet: bool = False,
        prefix: str = ""
    ) -> Dict[str, float]:
        """
        Evaluate a model on a dataset.

        Args:
            model: Trained model
            dataset: Dataset to evaluate on
            use_snippet: Whether the model uses snippets
            prefix: Prefix for metric names (e.g., "val_", "test_")

        Returns:
            Dictionary of metrics
        """
        # Get predicted probabilities and ground-truth targets
        pred_prob, target = get_predict(model, dataset, use_snippet)

        # Binarize probabilities at a fixed decision threshold
        threshold = 0.5
        y_pred = (pred_prob > threshold).float()

        # Compute multi-label classification metrics
        metrics = {
            f"{prefix}precision": precision(target, y_pred),
            f"{prefix}recall": recall(target, y_pred),
            f"{prefix}f1": f1_score(target, y_pred),
            f"{prefix}exact_match": exact_match(target, y_pred)
        }

        return metrics

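    # A small illustration of the thresholding step above, assuming
    # ``get_predict`` returns per-label probabilities as a float tensor
    # (values made up for illustration):
    #
    #     >>> pred_prob = torch.tensor([[0.9, 0.2], [0.4, 0.7]])
    #     >>> (pred_prob > 0.5).float()
    #     tensor([[1., 0.],
    #             [0., 1.]])
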
    def _generate_report(self, comparison_df: pd.DataFrame) -> None:
        """
        Generate a comparison report.

        Args:
            comparison_df: DataFrame with comparison results
        """
        report_file = self.results_dir / f"report_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.txt"

        with open(report_file, 'w') as f:
            f.write("=" * 80 + "\n")
            f.write("MODEL COMPARISON REPORT\n")
            f.write("=" * 80 + "\n\n")

            f.write(f"Generated: {pd.Timestamp.now()}\n")
            f.write(f"Models compared: {len(comparison_df)}\n\n")

            # Rank models on every recorded metric
            f.write("OVERALL METRICS\n")
            f.write("-" * 80 + "\n")

            metric_cols = [col for col in comparison_df.columns if any(
                m in col for m in ["precision", "recall", "f1", "exact_match"]
            )]

            for metric_col in metric_cols:
                f.write(f"\n{metric_col.upper()}:\n")
                sorted_df = comparison_df.sort_values(
                    metric_col, ascending=False, na_position="last"
                )
                for _, row in sorted_df.iterrows():
                    model_name = row.get("model_name", "Unknown")
                    value = row.get(metric_col, "N/A")
                    f.write(f"  {model_name}: {value}\n")

            # Highlight the best model by validation F1
            f.write("\n" + "=" * 80 + "\n")
            f.write("BEST MODEL\n")
            f.write("=" * 80 + "\n")

            if "val_f1" in comparison_df.columns:
                best = comparison_df.nlargest(1, "val_f1")
                if not best.empty:
                    best_model = best.iloc[0]
                    f.write(f"Model: {best_model['model_name']}\n")
                    f.write(f"Validation F1: {best_model.get('val_f1', 'N/A')}\n")
                    f.write(f"Validation Precision: {best_model.get('val_precision', 'N/A')}\n")
                    f.write(f"Validation Recall: {best_model.get('val_recall', 'N/A')}\n")

        logger.info(f"Comparison report saved to: {report_file}")

    def get_best_model(
        self,
        metric_name: str = "val_f1",
        comparison_df: Optional[pd.DataFrame] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Get the best model from a comparison.

        Args:
            metric_name: Metric to use for selection
            comparison_df: Optional comparison DataFrame (uses stored results if None)

        Returns:
            Dictionary with best model information
        """
        if comparison_df is None:
            if not self.comparison_results:
                logger.warning("No comparison results available")
                return None
            comparison_df = pd.DataFrame(self.comparison_results)

        if metric_name not in comparison_df.columns:
            logger.warning(f"Metric {metric_name} not found in comparison results")
            return None

        # Only consider models that completed training and evaluation
        completed = comparison_df[comparison_df["status"] == "completed"]
        if completed.empty:
            logger.warning("No completed models found")
            return None

        # Pick the row with the highest value of the selection metric
        best = completed.nlargest(1, metric_name)
        if best.empty:
            return None

        best_model = best.iloc[0].to_dict()
        logger.info(f"Best model: {best_model['model_name']} ({metric_name}={best_model.get(metric_name, 'N/A')})")

        return best_model

    def compare_from_checkpoints(
        self,
        checkpoint_paths: List[Dict[str, str]],
        test_dataset: Dataset,
        model_classes: Dict[str, type]
    ) -> pd.DataFrame:
        """
        Compare models from saved checkpoints.

        Args:
            checkpoint_paths: List of dicts with model_name and checkpoint_path
            test_dataset: Test dataset for evaluation
            model_classes: Dictionary mapping model_name to model class

        Returns:
            DataFrame with comparison results
        """
        logger.info("=" * 80)
        logger.info("Comparing Models from Checkpoints")
        logger.info("=" * 80)

        results = []

        for checkpoint_info in checkpoint_paths:
            model_name = checkpoint_info["model_name"]
            checkpoint_path = checkpoint_info["checkpoint_path"]

            logger.info(f"\nEvaluating {model_name} from {checkpoint_path}")

            try:
                # Load the checkpoint onto the CPU
                checkpoint = torch.load(checkpoint_path, map_location="cpu")

                # Look up the model class for this checkpoint
                model_class = model_classes.get(model_name)
                if model_class is None:
                    logger.warning(f"Model class not found for {model_name}, skipping")
                    continue

                # Rebuild the constructor arguments from the checkpoint
                model_kwargs = {
                    "num_labels": checkpoint.get("num_labels", 1000),
                    "use_snippet": checkpoint.get("use_snippet", False),
                    "dropout": checkpoint.get("dropout", 0.3)
                }

                if "model_name" in checkpoint:
                    model_kwargs["model_name"] = checkpoint["model_name"]

                model = model_class(**model_kwargs)
                model.load_state_dict(checkpoint["state_dict"])
                model.eval()

                # Evaluate on the test set
                use_snippet = checkpoint.get("use_snippet", False)
                test_metrics = self._evaluate_model(
                    model,
                    test_dataset,
                    use_snippet,
                    prefix="test_"
                )

                result = {
                    "model_name": model_name,
                    "checkpoint_path": checkpoint_path,
                    **test_metrics,
                    "status": "completed"
                }
                results.append(result)

            except Exception as e:
                logger.error(f"Error evaluating {model_name}: {e}")
                results.append({
                    "model_name": model_name,
                    "status": "failed",
                    "error": str(e)
                })

        comparison_df = pd.DataFrame(results)

        # Persist the results as a timestamped CSV
        comparison_file = self.results_dir / f"checkpoint_comparison_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.csv"
        comparison_df.to_csv(comparison_file, index=False)

        return comparison_df
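

if __name__ == "__main__":
    # Minimal, self-contained usage sketch (not part of the library API).
    # The model names and metric values below are made up for illustration;
    # a real run would obtain this table from compare_models() or
    # compare_from_checkpoints(). Note that instantiating ModelComparison
    # creates the results directory and a default ExperimentTracker.
    demo_df = pd.DataFrame([
        {"model_name": "model_a", "status": "completed", "val_f1": 0.71},
        {"model_name": "model_b", "status": "completed", "val_f1": 0.68},
        {"model_name": "model_c", "status": "failed", "val_f1": None},
    ])
    comparison = ModelComparison()
    best = comparison.get_best_model(metric_name="val_f1", comparison_df=demo_df)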