mozzic committed
Commit 03cbd8d · verified · 1 Parent(s): 5c7e137

Upload src\evaluation.py with huggingface_hub

Files changed (1):
  1. src/evaluation.py +152 -0

src/evaluation.py ADDED
@@ -0,0 +1,152 @@
+"""
+Evaluation harness for testing the Context Thread Agent
+"""
+
+from typing import List, Dict
+from pathlib import Path
+from src.parser import NotebookParser
+from src.dependencies import ContextThreadBuilder
+from src.indexing import FAISSIndexer
+from src.retrieval import RetrievalEngine
+from src.reasoning import ContextualAnsweringSystem
+from src.intent import ContextThreadEnricher
+
+
+class EvaluationHarness:
+    """Evaluation harness for notebook Q&A."""
+
+    def __init__(self, notebooks_dir: str):
+        self.notebooks_dir = Path(notebooks_dir)
+        self.results = []
+
+    def evaluate_all(self, queries_per_notebook: int = 3) -> Dict:
+        """Evaluate all notebooks in the directory."""
+        notebook_files = list(self.notebooks_dir.glob("*.ipynb"))
+
+        for nb_file in notebook_files:
+            print(f"Evaluating {nb_file.name}...")
+            result = self.evaluate_notebook(str(nb_file), queries_per_notebook)
+            self.results.append(result)
+
+        return self._aggregate_results()
+
+    def evaluate_notebook(self, notebook_path: str, num_queries: int) -> Dict:
+        """Evaluate a single notebook."""
+        # Load notebook
+        parser = NotebookParser()
+        result = parser.parse_file(notebook_path)
+        cells = result['cells']
+
+        # Build context thread
+        builder = ContextThreadBuilder(
+            notebook_name=Path(notebook_path).stem,
+            thread_id=f"eval_{id(self)}"
+        )
+        builder.add_cells(cells)
+        thread = builder.build()
+
+        # Enrich
+        enricher = ContextThreadEnricher()
+        thread = enricher.enrich(thread)
+
+        # Index
+        indexer = FAISSIndexer()
+        indexer.add_multiple(thread.units)
+
+        # Setup systems
+        engine = RetrievalEngine(thread, indexer)
+        answering_system = ContextualAnsweringSystem(engine)
+
+        # Generate sample queries (simplified)
+        queries = self._generate_sample_queries(cells, num_queries)
+
+        # Evaluate
+        scores = []
+        for query in queries:
+            try:
+                response = answering_system.answer_question(query)
+                # Simple scoring based on citations: 0.2 per citation, capped at 1.0
+                score = min(len(response.citations) * 0.2, 1.0)
+                scores.append(score)
+            except Exception:
+                scores.append(0.0)
+
+        return {
+            'notebook': Path(notebook_path).name,
+            'total_cells': len(cells),
+            'queries_tested': len(queries),
+            'avg_score': sum(scores) / len(scores) if scores else 0,
+            'scores': scores
+        }
+    def _generate_sample_queries(self, cells: List, num_queries: int) -> List[str]:
+        """Generate sample queries for evaluation."""
+        queries = []
+
+        # Simple query generation
+        if any('Q4' in cell.source for cell in cells):
+            queries.append("Why did we remove Q4 data?")
+
+        if any('plot' in cell.source.lower() for cell in cells):
+            queries.append("What does the visualization show?")
+
+        if any('model' in cell.source.lower() for cell in cells):
+            queries.append("What model was used?")
+
+        # Fill with generic queries
+        generic_queries = [
+            "What is the main purpose of this analysis?",
+            "What data was used?",
+            "What were the key findings?"
+        ]
+
+        while len(queries) < num_queries and generic_queries:
+            queries.append(generic_queries.pop(0))
+
+        return queries[:num_queries]
+
+    def _aggregate_results(self) -> Dict:
+        """Aggregate results across all notebooks."""
+        if not self.results:
+            return {}
+
+        total_notebooks = len(self.results)
+        avg_score = sum(r['avg_score'] for r in self.results) / total_notebooks
+        total_queries = sum(r['queries_tested'] for r in self.results)
+
+        return {
+            'total_notebooks': total_notebooks,
+            'total_queries': total_queries,
+            'average_score': avg_score,
+            'results': self.results
+        }
+
+    def print_summary(self, summary: Dict):
+        """Print evaluation summary."""
+        print("\n" + "="*50)
+        print("EVALUATION SUMMARY")
+        print("="*50)
+        print(f"Notebooks evaluated: {summary.get('total_notebooks', 0)}")
+        print(f"Total queries: {summary.get('total_queries', 0)}")
+        print(f"Average score: {summary.get('average_score', 0):.2%}")
+
+        for result in summary.get('results', []):
+            print(f"\n{result['notebook']}:")
+            print(f"  Cells: {result['total_cells']}")
+            print(f"  Queries: {result['queries_tested']}")
+            print(f"  Avg score: {result['avg_score']:.2%}")
+
+    def save_results(self, output_file: str):
+        """Save results to CSV."""
+        import csv
+
+        with open(output_file, 'w', newline='') as f:
+            writer = csv.DictWriter(f, fieldnames=['notebook', 'total_cells', 'queries_tested', 'avg_score'])
+            writer.writeheader()
+            for result in self.results:
+                writer.writerow({
+                    'notebook': result['notebook'],
+                    'total_cells': result['total_cells'],
+                    'queries_tested': result['queries_tested'],
+                    'avg_score': result['avg_score']
+                })
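For context, a minimal sketch of how this harness might be driven from a script. This is not part of the commit: the `notebooks` directory and `results.csv` filename are illustrative placeholders, and the script is assumed to run from the repository root so the `src` package imports in the file resolve.

    # Hypothetical driver script (not part of this commit).
    from src.evaluation import EvaluationHarness

    harness = EvaluationHarness("notebooks")  # placeholder: directory of .ipynb files
    summary = harness.evaluate_all(queries_per_notebook=3)
    harness.print_summary(summary)
    harness.save_results("results.csv")  # placeholder output path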