"""
Compare two models: a band-aided one vs. a properly trained one.

Evaluates both models on the same test set and reports side-by-side metrics,
per-rate improvements, and a final verdict based on the valid-rate delta.

Usage:
    python scripts/compare_models.py \
        --model1 ./output/Se124M_700K_infix \
        --model2 ./output/Se124M_700K_infix_v2 \
        --num_samples 500
"""
|
|
| import argparse |
| import json |
| import os |
| import sys |
| from datetime import datetime |
|
|
| |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from evaluate import evaluate_model |
|
|
|
|
def format_metric(value, metric_type):
    """Render one metric value as a fixed-width display string.

    ``"rate"`` values are shown as percentages (value * 100), ``"float"``
    and ``"int"`` in 7-character numeric columns, and any other type falls
    back to plain 7-wide formatting.
    """
    if metric_type == "int":
        return f"{int(value):7d}"
    if metric_type == "float":
        return f"{value:7.2f}"
    if metric_type == "rate":
        return f"{value * 100:5.1f}%"
    return f"{value:7}"
|
|
|
|
def print_comparison_table(metrics1, metrics2, model1_name, model2_name):
    """Print a side-by-side metrics table plus an improvement summary.

    Rate metrics are compared both relatively (%) and absolutely
    (percentage points); a final verdict line is chosen from the
    absolute valid-rate delta between the two models.
    """
    bar = "=" * 80
    rule = "-" * 80

    print("\n" + bar)
    print("COMPARISON RESULTS")
    print(bar)

    print(f"{'Metric':<35} {model1_name:>20} {model2_name:>20}")
    print(rule)

    # (metrics-dict key, display label, formatting style)
    rows = (
        ("valid_rate", "Valid Rate", "rate"),
        ("parseable_rate", "Parseable Rate", "rate"),
        ("constraints_met_rate", "Constraints Met", "rate"),
        ("diversity_rate", "Diversity", "rate"),
        ("avg_expression_length", "Avg Expression Length", "float"),
        ("total_samples", "Total Samples", "int"),
        ("total_valid", "Total Valid", "int"),
    )

    rate_deltas = []
    for key, label, style in rows:
        a = metrics1.get(key, 0)
        b = metrics2.get(key, 0)
        print(f"{label:<35} {format_metric(a, style):>20} {format_metric(b, style):>20}")
        # Relative improvement is only meaningful for non-zero rates.
        if style == "rate" and a > 0:
            rate_deltas.append((label, ((b - a) / a) * 100, b - a))

    print(bar)

    print("\nIMPROVEMENTS (Model 2 vs Model 1):")
    print(rule)

    for label, rel, abs_diff in rate_deltas:
        rel_sign = "+" if rel > 0 else ""
        pp_sign = "+" if abs_diff > 0 else ""
        print(f"{label:<35} {rel_sign}{rel:>6.1f}% ({pp_sign}{abs_diff * 100:>5.1f} pp)")

    print(rule)

    # Verdict keyed off the absolute valid-rate delta (percentage points).
    delta = metrics2.get("valid_rate", 0) - metrics1.get("valid_rate", 0)

    print("\n" + bar)
    if delta > 0.20:
        headline = f"🎯 SIGNIFICANT IMPROVEMENT: Model 2 wins by {delta * 100:.1f} percentage points"
        detail = " The properly trained model significantly outperforms the band-aided version!"
    elif delta > 0.05:
        headline = f"✅ IMPROVEMENT: Model 2 wins by {delta * 100:.1f} percentage points"
        detail = " The properly trained model shows clear improvement."
    elif delta > 0:
        headline = f"📈 SLIGHT IMPROVEMENT: Model 2 wins by {delta * 100:.1f} percentage points"
        detail = " The properly trained model shows modest improvement."
    elif delta == 0:
        headline = "⚖️ TIE: Both models perform equally"
        detail = " No significant difference between models."
    else:
        headline = f"⚠️ REGRESSION: Model 1 wins by {-delta * 100:.1f} percentage points"
        detail = " The band-aided model performs better - retraining may need adjustment."
    print(headline)
    print(detail)

    print(bar)
|
|
|
|
def save_comparison_report(metrics1, metrics2, model1_name, model2_name, output_dir):
    """Write a timestamped JSON comparison report and return its path.

    The report records each model's name and raw metrics plus the
    rate differences (model 2 minus model 1). ``output_dir`` is created
    if it does not already exist.
    """
    os.makedirs(output_dir, exist_ok=True)

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    report_path = os.path.join(output_dir, f"comparison_{stamp}.json")

    # (report key, metrics-dict key) pairs for the diff section.
    diff_spec = (
        ("valid_rate_diff", "valid_rate"),
        ("parseable_rate_diff", "parseable_rate"),
        ("constraints_met_diff", "constraints_met_rate"),
        ("diversity_diff", "diversity_rate"),
    )

    payload = {
        "timestamp": stamp,
        "model1": {"name": model1_name, "metrics": metrics1},
        "model2": {"name": model2_name, "metrics": metrics2},
        "comparison": {
            out_key: metrics2.get(src_key, 0) - metrics1.get(src_key, 0)
            for out_key, src_key in diff_spec
        },
    }

    with open(report_path, "w") as fh:
        json.dump(payload, fh, indent=2)

    print(f"\n📄 Detailed comparison report saved to: {report_path}")
    return report_path
|
|
|
|
def _evaluate_single(model_path, label, num_samples, dataset_repo_id,
                     data_dir, data_column, output_dir):
    """Evaluate one model and return its metrics dict; exit(1) on failure.

    Builds the argparse.Namespace that ``evaluate_model`` expects (fixed
    generation settings: 1 generation, 128 new tokens, temp 0.7, top_p 0.9,
    seed 42) and funnels any evaluation error through a traceback + exit.
    ``label`` is the human-readable tag used in the error message
    (e.g. "Model 1").
    """
    eval_args = argparse.Namespace(
        model_path=model_path,
        base_model=None,
        dataset_repo_id=dataset_repo_id,
        data_dir=data_dir,
        data_column=data_column,
        num_samples=num_samples,
        num_generations=1,
        max_new_tokens=128,
        temperature=0.7,
        top_p=0.9,
        output_dir=output_dir,
        seed=42,
        device="auto",
    )
    try:
        return evaluate_model(eval_args)
    except Exception as e:
        print(f"\n❌ Error evaluating {label}: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


def compare_models(model1_path, model2_path, model1_name, model2_name,
                   num_samples=500, dataset_repo_id="augustocsc/sintetico_natural",
                   data_dir="700K", data_column="i_prompt_n", output_dir="./evaluation_results/comparison"):
    """Compare two models on the same test set.

    Evaluates each model with identical settings (via ``_evaluate_single``),
    prints a comparison table, and saves a JSON report under ``output_dir``.
    Exits the process with status 1 if either evaluation fails.

    Returns:
        Tuple ``(metrics1, metrics2)`` of the two models' metric dicts.
    """
    print("=" * 80)
    print("MODEL COMPARISON")
    print("=" * 80)
    print(f"Model 1 ({model1_name}): {model1_path}")
    print(f"Model 2 ({model2_name}): {model2_path}")
    print(f"Samples: {num_samples}")
    print(f"Dataset: {dataset_repo_id}/{data_dir}")
    print("=" * 80)

    os.makedirs(output_dir, exist_ok=True)

    # Both models are evaluated with the exact same configuration; only the
    # model path, display label, and output subdirectory differ.
    print(f"\n[1/2] Evaluating Model 1: {model1_name}")
    print("-" * 80)
    metrics1 = _evaluate_single(
        model1_path, "Model 1", num_samples, dataset_repo_id,
        data_dir, data_column, os.path.join(output_dir, "model1"),
    )

    print(f"\n[2/2] Evaluating Model 2: {model2_name}")
    print("-" * 80)
    metrics2 = _evaluate_single(
        model2_path, "Model 2", num_samples, dataset_repo_id,
        data_dir, data_column, os.path.join(output_dir, "model2"),
    )

    print_comparison_table(metrics1, metrics2, model1_name, model2_name)

    save_comparison_report(metrics1, metrics2, model1_name, model2_name, output_dir)

    return metrics1, metrics2
|
|
|
|
def main():
    """CLI entry point: parse arguments and run the model comparison."""
    cli = argparse.ArgumentParser(
        description="Compare two models on the same test set"
    )

    # (flag, keyword arguments) — declared as data to keep the wiring compact.
    arg_specs = (
        ("--model1", dict(type=str, required=True,
                          help="Path to first model (band-aided)")),
        ("--model2", dict(type=str, required=True,
                          help="Path to second model (properly trained)")),
        ("--model1_name", dict(type=str, default="Band-Aided",
                               help="Display name for model 1")),
        ("--model2_name", dict(type=str, default="Proper",
                               help="Display name for model 2")),
        ("--num_samples", dict(type=int, default=500,
                               help="Number of samples to evaluate")),
        ("--dataset_repo_id", dict(type=str, default="augustocsc/sintetico_natural",
                                   help="HuggingFace dataset repository")),
        ("--data_dir", dict(type=str, default="700K",
                            help="Data directory within dataset")),
        ("--data_column", dict(type=str, default="i_prompt_n",
                               help="Column name for prompts")),
        ("--output_dir", dict(type=str, default="./evaluation_results/comparison",
                              help="Directory to save comparison results")),
    )
    for flag, kwargs in arg_specs:
        cli.add_argument(flag, **kwargs)

    opts = cli.parse_args()

    try:
        compare_models(
            model1_path=opts.model1,
            model2_path=opts.model2,
            model1_name=opts.model1_name,
            model2_name=opts.model2_name,
            num_samples=opts.num_samples,
            dataset_repo_id=opts.dataset_repo_id,
            data_dir=opts.data_dir,
            data_column=opts.data_column,
            output_dir=opts.output_dir
        )

        print("\n✅ Comparison complete!")

    except Exception as e:
        print(f"\n❌ Error during comparison: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
|
|
|
|
# Script entry point: run the CLI when executed directly.
if __name__ == "__main__":
    main()
|
|