# smart-summarizer / run_evaluation.py
# Author: Rajak13
# Commit cf5d247 (verified): "Add comprehensive CNN/DailyMail evaluation system -
# dataset loading, model evaluation, topic analysis, and comparison" (4.93 kB)
#!/usr/bin/env python3
"""
Simple script to run model evaluation on CNN/DailyMail dataset
"""
import os
import sys
import logging
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from evaluation.dataset_loader import CNNDailyMailLoader
from evaluation.model_evaluator import ModelEvaluator
from evaluation.results_analyzer import ResultsAnalyzer
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def main():
"""Run comprehensive evaluation"""
# Configuration
SAMPLE_SIZE = 50 # Number of samples to evaluate
OUTPUT_DIR = "evaluation_results"
# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)
logger.info("Starting Smart Summarizer Evaluation")
logger.info(f"Sample size: {SAMPLE_SIZE}")
try:
# Step 1: Load dataset
logger.info("Step 1: Loading CNN/DailyMail dataset...")
loader = CNNDailyMailLoader()
dataset = loader.load_dataset()
# Step 2: Create evaluation subset
logger.info("Step 2: Creating evaluation subset...")
eval_data = loader.create_evaluation_subset(size=SAMPLE_SIZE)
loader.save_evaluation_data(eval_data, f"{OUTPUT_DIR}/eval_data.json")
# Step 3: Categorize by topics
logger.info("Step 3: Categorizing by topics...")
categorized_data = loader.categorize_by_topic(eval_data)
# Save categorized data
for topic, data in categorized_data.items():
if data:
loader.save_evaluation_data(data, f"{OUTPUT_DIR}/data_{topic}.json")
logger.info(f" {topic}: {len(data)} articles")
# Step 4: Initialize models
logger.info("Step 4: Initializing models...")
evaluator = ModelEvaluator()
evaluator.initialize_models()
# Step 5: Run overall evaluation
logger.info("Step 5: Running overall evaluation...")
overall_results = evaluator.evaluate_all_models(eval_data, max_samples=SAMPLE_SIZE)
# Save overall results
evaluator.save_results(overall_results, f"{OUTPUT_DIR}/results_overall.json")
# Create overall comparison
comparison_df = evaluator.compare_models(overall_results)
comparison_df.to_csv(f"{OUTPUT_DIR}/comparison_overall.csv", index=False)
print("\n" + "="*60)
print("OVERALL EVALUATION RESULTS")
print("="*60)
print(comparison_df.to_string(index=False))
# Step 6: Run topic-based evaluation
logger.info("Step 6: Running topic-based evaluation...")
topic_results = {}
for topic, data in categorized_data.items():
if len(data) >= 5: # Only evaluate topics with sufficient data
logger.info(f" Evaluating topic: {topic}")
topic_results[topic] = evaluator.evaluate_all_models(data, max_samples=20)
# Save topic results
evaluator.save_results(topic_results[topic], f"{OUTPUT_DIR}/results_{topic}.json")
# Create topic comparison
topic_comparison = evaluator.compare_models(topic_results[topic])
topic_comparison.to_csv(f"{OUTPUT_DIR}/comparison_{topic}.csv", index=False)
print(f"\n{topic.upper()} TOPIC RESULTS:")
print("-" * 40)
print(topic_comparison.to_string(index=False))
# Step 7: Create visualizations and analysis
logger.info("Step 7: Creating analysis and visualizations...")
analyzer = ResultsAnalyzer()
# Overall performance charts
analyzer.create_performance_charts(overall_results, OUTPUT_DIR)
# Topic analysis if we have topic results
if topic_results:
analyzer.analyze_topic_performance(topic_results, OUTPUT_DIR)
# Detailed report
analyzer.create_detailed_report(overall_results, OUTPUT_DIR)
print(f"\n" + "="*60)
print("EVALUATION COMPLETE")
print("="*60)
print(f"Results saved to: {OUTPUT_DIR}/")
print("Files created:")
print(f" - results_overall.json (detailed results)")
print(f" - comparison_overall.csv (summary table)")
print(f" - performance_comparison.png (charts)")
print(f" - evaluation_report.md (detailed report)")
if topic_results:
print(f" - topic_performance_heatmap.png (topic analysis)")
print(f" - topic_summary.csv (topic breakdown)")
except Exception as e:
logger.error(f"Evaluation failed: {e}")
raise
if __name__ == "__main__":
main()