File size: 5,580 Bytes
03cbd8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""

Evaluation harness for testing the Context Thread Agent

"""

from typing import List, Dict
from pathlib import Path
from src.parser import NotebookParser
from src.dependencies import ContextThreadBuilder
from src.indexing import FAISSIndexer
from src.retrieval import RetrievalEngine
from src.reasoning import ContextualAnsweringSystem
from src.intent import ContextThreadEnricher


class EvaluationHarness:
    """Evaluation harness for notebook Q&A.

    For each ``*.ipynb`` file the harness runs the full pipeline
    (parse -> build context thread -> enrich -> index -> retrieve -> answer)
    and scores every generated query by how many citations the answer
    produced (0.2 per citation, capped at 1.0).
    """

    def __init__(self, notebooks_dir: str):
        """Create a harness over *notebooks_dir* (scanned for ``*.ipynb``)."""
        self.notebooks_dir = Path(notebooks_dir)
        # Per-notebook result dicts accumulated by evaluate_all().
        self.results = []

    def evaluate_all(self, queries_per_notebook: int = 3) -> Dict:
        """Evaluate all notebooks in the directory and return aggregate stats.

        Returns an empty dict when the directory contains no notebooks.
        """
        notebook_files = list(self.notebooks_dir.glob("*.ipynb"))

        for nb_file in notebook_files:
            print(f"Evaluating {nb_file.name}...")
            result = self.evaluate_notebook(str(nb_file), queries_per_notebook)
            self.results.append(result)

        return self._aggregate_results()

    def evaluate_notebook(self, notebook_path: str, num_queries: int) -> Dict:
        """Evaluate a single notebook with up to *num_queries* sample queries.

        Returns a dict with the notebook name, cell count, number of queries
        tested, the per-query scores and their mean.
        """
        # Load notebook
        parser = NotebookParser()
        result = parser.parse_file(notebook_path)
        cells = result['cells']

        # Build context thread. Key the thread id on the notebook stem so
        # each notebook gets a distinct id — f"eval_{id(self)}" was identical
        # for every notebook evaluated by the same harness instance.
        stem = Path(notebook_path).stem
        builder = ContextThreadBuilder(
            notebook_name=stem,
            thread_id=f"eval_{stem}"
        )
        builder.add_cells(cells)
        thread = builder.build()

        # Enrich
        enricher = ContextThreadEnricher()
        thread = enricher.enrich(thread)

        # Index
        indexer = FAISSIndexer()
        indexer.add_multiple(thread.units)

        # Setup systems
        engine = RetrievalEngine(thread, indexer)
        answering_system = ContextualAnsweringSystem(engine)

        # Generate sample queries (simplified)
        queries = self._generate_sample_queries(cells, num_queries)

        # Evaluate. A failing query deliberately scores 0 rather than
        # aborting the whole notebook evaluation (best-effort harness).
        scores = []
        for query in queries:
            try:
                response = answering_system.answer_question(query)
                # Simple scoring: 0.2 per citation, capped at 1.0.
                score = min(len(response.citations) * 0.2, 1.0)
                scores.append(score)
            except Exception:
                scores.append(0.0)

        return {
            'notebook': Path(notebook_path).name,
            'total_cells': len(cells),
            'queries_tested': len(queries),
            'avg_score': sum(scores) / len(scores) if scores else 0,
            'scores': scores
        }

    def _generate_sample_queries(self, cells: List, num_queries: int) -> List[str]:
        """Generate up to *num_queries* queries from simple keyword heuristics.

        Content-specific queries (Q4 / plot / model mentions) come first;
        generic fallback queries pad the list up to *num_queries*.
        """
        queries = []

        # Simple keyword-triggered query generation
        if any('Q4' in cell.source for cell in cells):
            queries.append("Why did we remove Q4 data?")

        if any('plot' in cell.source.lower() for cell in cells):
            queries.append("What does the visualization show?")

        if any('model' in cell.source.lower() for cell in cells):
            queries.append("What model was used?")

        # Fill with generic queries
        generic_queries = [
            "What is the main purpose of this analysis?",
            "What data was used?",
            "What were the key findings?"
        ]

        while len(queries) < num_queries and generic_queries:
            queries.append(generic_queries.pop(0))

        return queries[:num_queries]

    def _aggregate_results(self) -> Dict:
        """Aggregate per-notebook results; empty dict when nothing was run."""
        if not self.results:
            return {}

        total_notebooks = len(self.results)
        avg_score = sum(r['avg_score'] for r in self.results) / total_notebooks
        total_queries = sum(r['queries_tested'] for r in self.results)

        return {
            'total_notebooks': total_notebooks,
            'total_queries': total_queries,
            'average_score': avg_score,
            'results': self.results
        }

    def print_summary(self, summary: Dict):
        """Print a human-readable evaluation summary to stdout."""
        print("\n" + "="*50)
        print("EVALUATION SUMMARY")
        print("="*50)
        print(f"Notebooks evaluated: {summary.get('total_notebooks', 0)}")
        print(f"Total queries: {summary.get('total_queries', 0)}")
        # BUG FIX: the original printed the literal string ".2%" — a broken
        # f-string conversion artifact — instead of the formatted score.
        print(f"Average score: {summary.get('average_score', 0):.2%}")

        for result in summary.get('results', []):
            print(f"\n{result['notebook']}:")
            print(f"  Cells: {result['total_cells']}")
            print(f"  Queries: {result['queries_tested']}")
            print(f"  Avg score: {result['avg_score']:.2%}")

    def save_results(self, output_file: str):
        """Save per-notebook results (without raw scores) to a CSV file."""
        import csv

        with open(output_file, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=['notebook', 'total_cells', 'queries_tested', 'avg_score'])
            writer.writeheader()
            for result in self.results:
                writer.writerow({
                    'notebook': result['notebook'],
                    'total_cells': result['total_cells'],
                    'queries_tested': result['queries_tested'],
                    'avg_score': result['avg_score']
                })