badminton001 commited on
Commit
f15ad29
·
verified ·
1 Parent(s): 29812d2

Update evaluation/evaluate_books.py

Browse files
Files changed (1) hide show
  1. evaluation/evaluate_books.py +349 -352
evaluation/evaluate_books.py CHANGED
@@ -1,352 +1,349 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
-
4
- import os
5
- import sys
6
- import json
7
- import time
8
- import matplotlib.pyplot as plt
9
- import numpy as np # Import numpy for better array handling
10
- from typing import List, Tuple, Dict, Any, Callable
11
- from sklearn.metrics import precision_score, recall_score, f1_score
12
-
13
- # Set project root
14
- this_dir = os.path.dirname(__file__)
15
- project_root = os.path.abspath(os.path.join(this_dir, ".."))
16
- sys.path.append(project_root)
17
-
18
- # Assuming these imports are correct and available in your project structure
19
- from retrieval.retrieve_books_50000 import get_recommendations as book_recs, book_records
20
- # Import the updated query parser (renamed to user_query_parser)
21
- from utils.query_parser import parse_user_query
22
-
23
-
24
- # ---------- 1. Load Evaluation Data ----------
25
- def load_eval_data(test_file: str, gt_file: str) -> Tuple[List[str], List[List[int]]]:
26
- """
27
- Loads evaluation queries and ground truth data for books.
28
-
29
- Args:
30
- test_file (str): Filename for the test queries JSON.
31
- gt_file (str): Filename for the ground truth JSON.
32
-
33
- Returns:
34
- Tuple[List[str], List[List[int]]]: A tuple containing a list of queries
35
- and a list of corresponding ground truth indices.
36
- """
37
- base = os.path.join(project_root, "evaluation")
38
-
39
- # Read and parse test queries after removing comments
40
- test_path = os.path.join(base, test_file)
41
- try:
42
- with open(test_path, 'r', encoding='utf-8') as f:
43
- lines = f.readlines()
44
- # Remove lines that start with '//' as comments
45
- content = ''.join([l for l in lines if not l.strip().startswith('//')])
46
- queries_raw = json.loads(content)
47
- except FileNotFoundError:
48
- print(f"Error: Test queries file not found at {test_path}")
49
- return [], []
50
- except json.JSONDecodeError:
51
- print(f"Error: Could not decode JSON from {test_path}")
52
- return [], []
53
-
54
- # Read and parse Ground-truth after removing comments
55
- gt_path = os.path.join(base, gt_file)
56
- try:
57
- with open(gt_path, 'r', encoding='utf-8') as f:
58
- gt_lines = f.readlines()
59
- # Remove lines that start with '//' as comments
60
- gt_content = ''.join([l for l in gt_lines if not l.strip().startswith('//')])
61
- gt_map = json.loads(gt_content)
62
- except FileNotFoundError:
63
- print(f"Error: Ground truth file not found at {gt_path}")
64
- return [], []
65
- except json.JSONDecodeError:
66
- print(f"Error: Could not decode JSON from {gt_path}")
67
- return [], []
68
-
69
- # Build query_id -> query text map
70
- id_to_query = {item['query_id']: item['query'] for item in queries_raw}
71
- # Build query_id -> [ground truths] map
72
- id_to_gt = {int(qid): vals for qid, vals in gt_map.items()}
73
-
74
- # Align and return queries and truths
75
- queries, truths = [], []
76
- for qid, qtext in id_to_query.items():
77
- if qid in id_to_gt:
78
- queries.append(qtext)
79
- truths.append(id_to_gt[qid])
80
- print(f"✅ Loaded {len(queries)} queries with ground truths.")
81
- return queries, truths
82
-
83
-
84
- # ---------- 2. Retrieval Function Factory ----------
85
- def retrieval_func_factory(params: Dict[str, Any]) -> Callable[[str], List[Tuple[int, float]]]:
86
- """
87
- Creates a retrieval function based on specified parameters for books.
88
-
89
- Args:
90
- params (Dict[str, Any]): A dictionary containing 'top_k' and 'method' for retrieval.
91
-
92
- Returns:
93
- Callable[[str], List[Tuple[int, float]]]: A function that takes a query string
94
- and returns a list of (index, score) tuples.
95
- """
96
-
97
- def fn(query: str) -> List[Tuple[int, float]]:
98
- # Parse the query to extract tags using the user_query_parser
99
- parsed_tags = parse_user_query(query)
100
-
101
- # Call the book recommendation function with parsed_query_tags
102
- results = book_recs(query, top_k=params['top_k'], method=params['method'], parsed_query_tags=parsed_tags)
103
-
104
- # Create an index map from book title to its original index in book_records
105
- # This assumes book titles are unique enough for mapping. If not, a unique identifier
106
- # (like source_key or a dedicated ID) should be used.
107
- # Ensure book_records is correctly populated and contains 'title' and 'source_key' or similar unique ID
108
- index_map = {item.get('title'): idx for idx, item in enumerate(book_records) if item.get('title')}
109
- # As a fallback or if titles are not unique, consider using 'source_key'
110
- # index_map = {item.get('source_key'): idx for idx, item in enumerate(book_records) if item.get('source_key')}
111
-
112
-
113
- retrieved_items_with_indices = []
114
- for r in results:
115
- # Try to get the original index using the title (or other unique ID)
116
- # Use .get() with a default to avoid KeyError if title not in map
117
- original_idx = index_map.get(r.get('title'))
118
- if original_idx is not None and 'score' in r:
119
- retrieved_items_with_indices.append((original_idx, r['score']))
120
- # If original_idx is None, it means the recommended item's title wasn't found in the index map.
121
- # This could indicate an issue with how index_map is created or how results are structured.
122
-
123
- return retrieved_items_with_indices
124
-
125
- return fn
126
-
127
-
128
- # ---------- 3. Accuracy Evaluation ----------
129
- def evaluate_accuracy(retrieval_func: Callable[[str], List[Tuple[int, float]]], queries: List[str],
130
- truths: List[List[int]]) -> float:
131
- """
132
- Evaluates the accuracy (Top-1 Hit Rate) of the retrieval function for books.
133
-
134
- Args:
135
- retrieval_func (Callable): The retrieval function to evaluate.
136
- queries (List[str]): List of test queries.
137
- truths (List[List[int]]): List of ground truth indices for each query.
138
-
139
- Returns:
140
- float: The Top-1 accuracy score.
141
- """
142
- correct = 0
143
- for q, gt in zip(queries, truths):
144
- results = retrieval_func(q)
145
- # Check if the top-1 result is in the ground truth
146
- if results and results[0][0] in gt:
147
- correct += 1
148
- return correct / len(queries) if queries else 0.0
149
-
150
-
151
- # ---------- 4. Timing ----------
152
- def measure_response_time(retrieval_func: Callable[[str], Any], queries: List[str]) -> float:
153
- """
154
- Measures the average response time per query for the retrieval function.
155
-
156
- Args:
157
- retrieval_func (Callable): The retrieval function to measure.
158
- queries (List[str]): List of test queries.
159
-
160
- Returns:
161
- float: Average response time in seconds per query.
162
- """
163
- start = time.time()
164
- for q in queries:
165
- retrieval_func(q)
166
- end = time.time()
167
- return (end - start) / len(queries) if queries else 0.0
168
-
169
-
170
- # ---------- 5. Visualization ----------
171
- def plot_optimization_report(metrics_data: Dict[str, Dict[str, List[float]]],
172
- param_grid: Dict[str, List[Any]],
173
- save_path_prefix: str = 'optimization_report_books'):
174
- """
175
- Plots the optimization report for retrieval metrics for books, separating plots by metric
176
- and grouping lines by retrieval method.
177
-
178
- Args:
179
- metrics_data (Dict[str, Dict[str, List[float]]]): Structured dictionary
180
- {metric_name: {method_name: [list_of_values_for_top_ks]}}
181
- param_grid (Dict[str, List[Any]]): The original parameter grid used for evaluation.
182
- save_path_prefix (str): Prefix for saving the plot images (e.g., 'optimization_report_books_accuracy.png').
183
- """
184
- top_k_values = sorted(param_grid['top_k'])
185
- methods = param_grid['method']
186
-
187
- for metric_name, method_metrics in metrics_data.items():
188
- plt.figure(figsize=(10, 6))
189
- # Use a consistent color cycle for different methods
190
- colors = plt.cm.get_cmap('viridis', len(methods))
191
-
192
- for i, method in enumerate(methods):
193
- values = method_metrics.get(method, [])
194
- if values:
195
- plt.plot(top_k_values, values, label=f'{method} Method',
196
- marker='o', linestyle='-', linewidth=2, color=colors(i))
197
- # Add text labels for values on the plot
198
- for x, y in zip(top_k_values, values):
199
- plt.text(x, y, f'{y:.3f}', ha='center', va='bottom', fontsize=8)
200
-
201
-
202
- plt.xlabel('Top-k Value')
203
- plt.ylabel(metric_name.replace('_', ' ').title())
204
- plt.title(f'Book Retrieval Optimization - {metric_name.replace("_", " ").title()} by Top-k and Method')
205
- plt.xticks(top_k_values)
206
- plt.legend(title='Retrieval Method')
207
- plt.grid(True, linestyle='--', alpha=0.7)
208
- plt.tight_layout()
209
- save_file = f"{save_path_prefix}_{metric_name}.png"
210
- plt.savefig(save_file)
211
- print(f"✅ Plot saved to {save_file}")
212
- plt.close()
213
-
214
-
215
- # ---------- 6. Top-k and Binary Metrics ----------
216
- def compute_topk_metrics(retrieval_func: Callable, queries: List[str], truths: List[List[int]],
217
- k_values: List[int] = [1, 3, 5]):
218
- """
219
- Computes Top-k Hit Rates, and average Precision, Recall, and F1-score for books.
220
-
221
- Args:
222
- retrieval_func (Callable): The retrieval function to evaluate.
223
- queries (List[str]): List of test queries.
224
- truths (List[List[int]]): List of ground truth indices for each query.
225
- k_values (List[int]): List of k values for Top-k Hit Rate calculation.
226
- """
227
- hit_rates = {k: 0 for k in k_values}
228
- precisions, recalls, f1s = [], [], []
229
- total_queries = len(queries)
230
-
231
- # Total size of the book corpus, needed for binary metrics
232
- corpus_size = len(book_records)
233
-
234
- if total_queries == 0:
235
- print("No queries to evaluate.")
236
- return
237
-
238
- for q_idx, (q, gt) in enumerate(zip(queries, truths)):
239
- retrieved = retrieval_func(q)
240
- # Extract only the indices of retrieved items
241
- retrieved_ids = [idx for idx, _ in retrieved]
242
-
243
- # Calculate Top-k Hit Rate
244
- for k in k_values:
245
- if any(pred_id in gt for pred_id in retrieved_ids[:k]):
246
- hit_rates[k] += 1
247
-
248
- # Prepare for Precision, Recall, F1-score (binary classification for each item in corpus)
249
- # y_true: A binary list where 1 means the item is a ground truth, 0 otherwise
250
- y_true = [1 if i in gt else 0 for i in range(corpus_size)]
251
-
252
- # y_pred: A binary list where 1 means the item was retrieved, 0 otherwise
253
- # Only consider items that were actually retrieved by the system
254
- y_pred = [0] * corpus_size
255
- for idx in retrieved_ids:
256
- if idx < corpus_size: # Ensure index is within bounds
257
- y_pred[idx] = 1
258
-
259
- # Compute metrics for the current query and append
260
- precisions.append(precision_score(y_true, y_pred, zero_division=0))
261
- recalls.append(recall_score(y_true, y_pred, zero_division=0))
262
- f1s.append(f1_score(y_true, y_pred, zero_division=0))
263
-
264
- print("\n--- Book Retrieval Metrics ---")
265
- for k in k_values:
266
- print(f"Top@{k} Hit Rate: {hit_rates[k] / total_queries:.4f}")
267
-
268
- print(f"Avg Precision: {sum(precisions) / total_queries:.4f}")
269
- print(f"Avg Recall: {sum(recalls) / total_queries:.4f}")
270
- print(f"Avg F1: {sum(f1s) / total_queries:.4f}")
271
-
272
-
273
- # ---------- 7. Main Execution ----------
274
- if __name__ == '__main__':
275
- # Load book specific evaluation data
276
- # Make sure these files exist in your 'evaluation' directory
277
- queries_books, truths_books = load_eval_data('test/test_queries_books_100.json', 'test/ground_truth_books_100.json')
278
-
279
- if not queries_books:
280
- print("Exiting evaluation due to no queries loaded.")
281
- sys.exit(1)
282
-
283
- print(f"✅ {len(queries_books)} book queries loaded.\n")
284
-
285
- # Define parameter grid for optimization
286
- param_grid = {
287
- 'top_k': [1, 3, 5, 10], # Added 10 for more granular evaluation
288
- 'method': ['tfidf', 'sbert']
289
- }
290
-
291
- best_score, best_params = -1.0, {}
292
-
293
- # Store metrics in a more structured way for easier plotting
294
- metrics_for_plotting = {
295
- 'accuracy': {method: [] for method in param_grid['method']},
296
- 'response_time': {method: [] for method in param_grid['method']}
297
- }
298
-
299
- from itertools import product
300
-
301
- print("--- Starting Book Retrieval Evaluation ---")
302
- # Sort top_k values to ensure consistent plotting order
303
- sorted_top_k = sorted(param_grid['top_k'])
304
-
305
- # Temporary storage to build sorted lists for plotting
306
- temp_metrics_by_method_topk = {
307
- method: {k: {'accuracy': 0, 'response_time': 0} for k in sorted_top_k}
308
- for method in param_grid['method']
309
- }
310
-
311
- for combo in product(sorted_top_k, param_grid['method']):
312
- params = {
313
- 'top_k': combo[0],
314
- 'method': combo[1]
315
- }
316
-
317
- # Create retrieval function for current parameters
318
- func = retrieval_func_factory(params)
319
-
320
- # Evaluate accuracy and response time
321
- score = evaluate_accuracy(func, queries_books, truths_books)
322
- avg_time = measure_response_time(func, queries_books)
323
-
324
- # Store data for plotting
325
- temp_metrics_by_method_topk[params['method']][params['top_k']]['accuracy'] = score
326
- temp_metrics_by_method_topk[params['method']][params['top_k']]['response_time'] = avg_time
327
-
328
- print(f"Params {params} -> Acc: {score:.4f}, Time: {avg_time:.4f}s")
329
-
330
- # Track the best parameters based on accuracy (still global best acc)
331
- if score > best_score:
332
- best_score, best_params = score, params
333
-
334
- # Populate the metrics_for_plotting dictionary after all evaluations are done
335
- # This ensures the lists are in the correct order based on sorted_top_k
336
- for method in param_grid['method']:
337
- for k in sorted_top_k:
338
- metrics_for_plotting['accuracy'][method].append(temp_metrics_by_method_topk[method][k]['accuracy'])
339
- metrics_for_plotting['response_time'][method].append(temp_metrics_by_method_topk[method][k]['response_time'])
340
-
341
-
342
- print(f"\n✨ Best Params for Books: {best_params}, Accuracy: {best_score:.4f}")
343
-
344
- # Plot the optimization report using the improved function
345
- plot_optimization_report(metrics_for_plotting, param_grid,
346
- save_path_prefix='optimization_report_books')
347
-
348
- # Compute and print Top-k and binary metrics for the best performing model
349
- print("\n--- Detailed Metrics for Best Book Retrieval Model ---")
350
- compute_topk_metrics(retrieval_func_factory(best_params), queries_books, truths_books)
351
-
352
- print("\nBook evaluation complete.")
 
1
+ import os
2
+ import sys
3
+ import json
4
+ import time
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np # Import numpy for better array handling
7
+ from typing import List, Tuple, Dict, Any, Callable
8
+ from sklearn.metrics import precision_score, recall_score, f1_score
9
+
10
+ # Set project root
11
+ this_dir = os.path.dirname(__file__)
12
+ project_root = os.path.abspath(os.path.join(this_dir, ".."))
13
+ sys.path.append(project_root)
14
+
15
+ # Assuming these imports are correct and available in your project structure
16
+ from retrieval.retrieve_books_50000 import get_recommendations as book_recs, book_records
17
+ # Import the updated query parser (renamed to user_query_parser)
18
+ from utils.query_parser import parse_user_query
19
+
20
+
21
+ # ---------- 1. Load Evaluation Data ----------
22
+ def load_eval_data(test_file: str, gt_file: str) -> Tuple[List[str], List[List[int]]]:
23
+ """
24
+ Loads evaluation queries and ground truth data for books.
25
+
26
+ Args:
27
+ test_file (str): Filename for the test queries JSON.
28
+ gt_file (str): Filename for the ground truth JSON.
29
+
30
+ Returns:
31
+ Tuple[List[str], List[List[int]]]: A tuple containing a list of queries
32
+ and a list of corresponding ground truth indices.
33
+ """
34
+ base = os.path.join(project_root, "evaluation")
35
+
36
+ # Read and parse test queries after removing comments
37
+ test_path = os.path.join(base, test_file)
38
+ try:
39
+ with open(test_path, 'r', encoding='utf-8') as f:
40
+ lines = f.readlines()
41
+ # Remove lines that start with '//' as comments
42
+ content = ''.join([l for l in lines if not l.strip().startswith('//')])
43
+ queries_raw = json.loads(content)
44
+ except FileNotFoundError:
45
+ print(f"Error: Test queries file not found at {test_path}")
46
+ return [], []
47
+ except json.JSONDecodeError:
48
+ print(f"Error: Could not decode JSON from {test_path}")
49
+ return [], []
50
+
51
+ # Read and parse Ground-truth after removing comments
52
+ gt_path = os.path.join(base, gt_file)
53
+ try:
54
+ with open(gt_path, 'r', encoding='utf-8') as f:
55
+ gt_lines = f.readlines()
56
+ # Remove lines that start with '//' as comments
57
+ gt_content = ''.join([l for l in gt_lines if not l.strip().startswith('//')])
58
+ gt_map = json.loads(gt_content)
59
+ except FileNotFoundError:
60
+ print(f"Error: Ground truth file not found at {gt_path}")
61
+ return [], []
62
+ except json.JSONDecodeError:
63
+ print(f"Error: Could not decode JSON from {gt_path}")
64
+ return [], []
65
+
66
+ # Build query_id -> query text map
67
+ id_to_query = {item['query_id']: item['query'] for item in queries_raw}
68
+ # Build query_id -> [ground truths] map
69
+ id_to_gt = {int(qid): vals for qid, vals in gt_map.items()}
70
+
71
+ # Align and return queries and truths
72
+ queries, truths = [], []
73
+ for qid, qtext in id_to_query.items():
74
+ if qid in id_to_gt:
75
+ queries.append(qtext)
76
+ truths.append(id_to_gt[qid])
77
+ print(f"✅ Loaded {len(queries)} queries with ground truths.")
78
+ return queries, truths
79
+
80
+
81
+ # ---------- 2. Retrieval Function Factory ----------
82
+ def retrieval_func_factory(params: Dict[str, Any]) -> Callable[[str], List[Tuple[int, float]]]:
83
+ """
84
+ Creates a retrieval function based on specified parameters for books.
85
+
86
+ Args:
87
+ params (Dict[str, Any]): A dictionary containing 'top_k' and 'method' for retrieval.
88
+
89
+ Returns:
90
+ Callable[[str], List[Tuple[int, float]]]: A function that takes a query string
91
+ and returns a list of (index, score) tuples.
92
+ """
93
+
94
+ def fn(query: str) -> List[Tuple[int, float]]:
95
+ # Parse the query to extract tags using the user_query_parser
96
+ parsed_tags = parse_user_query(query)
97
+
98
+ # Call the book recommendation function with parsed_query_tags
99
+ results = book_recs(query, top_k=params['top_k'], method=params['method'], parsed_query_tags=parsed_tags)
100
+
101
+ # Create an index map from book title to its original index in book_records
102
+ # This assumes book titles are unique enough for mapping. If not, a unique identifier
103
+ # (like source_key or a dedicated ID) should be used.
104
+ # Ensure book_records is correctly populated and contains 'title' and 'source_key' or similar unique ID
105
+ index_map = {item.get('title'): idx for idx, item in enumerate(book_records) if item.get('title')}
106
+ # As a fallback or if titles are not unique, consider using 'source_key'
107
+ # index_map = {item.get('source_key'): idx for idx, item in enumerate(book_records) if item.get('source_key')}
108
+
109
+
110
+ retrieved_items_with_indices = []
111
+ for r in results:
112
+ # Try to get the original index using the title (or other unique ID)
113
+ # Use .get() with a default to avoid KeyError if title not in map
114
+ original_idx = index_map.get(r.get('title'))
115
+ if original_idx is not None and 'score' in r:
116
+ retrieved_items_with_indices.append((original_idx, r['score']))
117
+ # If original_idx is None, it means the recommended item's title wasn't found in the index map.
118
+ # This could indicate an issue with how index_map is created or how results are structured.
119
+
120
+ return retrieved_items_with_indices
121
+
122
+ return fn
123
+
124
+
125
+ # ---------- 3. Accuracy Evaluation ----------
126
+ def evaluate_accuracy(retrieval_func: Callable[[str], List[Tuple[int, float]]], queries: List[str],
127
+ truths: List[List[int]]) -> float:
128
+ """
129
+ Evaluates the accuracy (Top-1 Hit Rate) of the retrieval function for books.
130
+
131
+ Args:
132
+ retrieval_func (Callable): The retrieval function to evaluate.
133
+ queries (List[str]): List of test queries.
134
+ truths (List[List[int]]): List of ground truth indices for each query.
135
+
136
+ Returns:
137
+ float: The Top-1 accuracy score.
138
+ """
139
+ correct = 0
140
+ for q, gt in zip(queries, truths):
141
+ results = retrieval_func(q)
142
+ # Check if the top-1 result is in the ground truth
143
+ if results and results[0][0] in gt:
144
+ correct += 1
145
+ return correct / len(queries) if queries else 0.0
146
+
147
+
148
+ # ---------- 4. Timing ----------
149
+ def measure_response_time(retrieval_func: Callable[[str], Any], queries: List[str]) -> float:
150
+ """
151
+ Measures the average response time per query for the retrieval function.
152
+
153
+ Args:
154
+ retrieval_func (Callable): The retrieval function to measure.
155
+ queries (List[str]): List of test queries.
156
+
157
+ Returns:
158
+ float: Average response time in seconds per query.
159
+ """
160
+ start = time.time()
161
+ for q in queries:
162
+ retrieval_func(q)
163
+ end = time.time()
164
+ return (end - start) / len(queries) if queries else 0.0
165
+
166
+
167
+ # ---------- 5. Visualization ----------
168
def plot_optimization_report(metrics_data: Dict[str, Dict[str, List[float]]],
                             param_grid: Dict[str, List[Any]],
                             save_path_prefix: str = 'optimization_report_books'):
    """
    Plots the optimization report for retrieval metrics for books, one figure
    per metric, with one line per retrieval method.

    Args:
        metrics_data (Dict[str, Dict[str, List[float]]]): Structured dictionary
            {metric_name: {method_name: [values ordered by sorted top_k]}}
        param_grid (Dict[str, List[Any]]): The parameter grid used for
            evaluation ('top_k' and 'method' keys).
        save_path_prefix (str): Prefix for the saved plot images
            (e.g. 'optimization_report_books_accuracy.png').
    """
    top_k_values = sorted(param_grid['top_k'])
    methods = param_grid['method']

    # Sample one distinct viridis color per method by indexing the colormap
    # directly. This replaces plt.cm.get_cmap('viridis', n), which was
    # deprecated in Matplotlib 3.7 and removed in 3.9; the colors are also
    # computed once instead of once per metric figure.
    colors = plt.cm.viridis(np.linspace(0.0, 1.0, max(len(methods), 1)))

    for metric_name, method_metrics in metrics_data.items():
        plt.figure(figsize=(10, 6))

        for i, method in enumerate(methods):
            values = method_metrics.get(method, [])
            if values:
                plt.plot(top_k_values, values, label=f'{method} Method',
                         marker='o', linestyle='-', linewidth=2, color=colors[i])
                # Annotate every data point with its value for easy reading.
                for x, y in zip(top_k_values, values):
                    plt.text(x, y, f'{y:.3f}', ha='center', va='bottom', fontsize=8)

        plt.xlabel('Top-k Value')
        plt.ylabel(metric_name.replace('_', ' ').title())
        plt.title(f'Book Retrieval Optimization - {metric_name.replace("_", " ").title()} by Top-k and Method')
        plt.xticks(top_k_values)
        plt.legend(title='Retrieval Method')
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.tight_layout()
        save_file = f"{save_path_prefix}_{metric_name}.png"
        plt.savefig(save_file)
        print(f"✅ Plot saved to {save_file}")
        plt.close()
210
+
211
+
212
+ # ---------- 6. Top-k and Binary Metrics ----------
213
def compute_topk_metrics(retrieval_func: Callable, queries: List[str], truths: List[List[int]],
                         k_values: List[int] = [1, 3, 5], corpus_size: int = None):
    """
    Computes and prints Top-k hit rates and average Precision/Recall/F1.

    Precision, recall and F1 treat each query as a binary classification over
    the corpus (retrieved vs. not, relevant vs. not). They are computed
    directly from set intersections, which is mathematically equivalent to —
    but far cheaper than — materialising two corpus-sized 0/1 vectors per
    query and handing them to sklearn (the corpus holds tens of thousands
    of books). As before, 0/0 cases yield 0.0.

    Args:
        retrieval_func (Callable): The retrieval function to evaluate; maps a
            query string to a list of (corpus_index, score) tuples.
        queries (List[str]): List of test queries.
        truths (List[List[int]]): List of ground truth indices per query.
        k_values (List[int]): k values for the Top-k hit-rate calculation
            (read-only; the default list is never mutated).
        corpus_size (int, optional): Total number of items in the corpus.
            Defaults to len(book_records) when omitted (backward compatible).
    """
    total_queries = len(queries)
    if total_queries == 0:
        print("No queries to evaluate.")
        return

    if corpus_size is None:
        # Total size of the book corpus, needed for the binary metrics.
        corpus_size = len(book_records)

    hit_rates = {k: 0 for k in k_values}
    precisions, recalls, f1s = [], [], []

    for q, gt in zip(queries, truths):
        retrieved = retrieval_func(q)
        # Extract only the indices of retrieved items
        retrieved_ids = [idx for idx, _ in retrieved]

        # Top-k hit rate: a hit when any of the first k results is relevant.
        for k in k_values:
            if any(pred_id in gt for pred_id in retrieved_ids[:k]):
                hit_rates[k] += 1

        # Binary metrics via set arithmetic. Indices at or beyond corpus_size
        # are ignored, matching the old bounds check on the 0/1 vectors.
        predicted = {idx for idx in retrieved_ids if idx < corpus_size}
        relevant = {g for g in gt if g < corpus_size}
        true_positives = len(predicted & relevant)

        precision = true_positives / len(predicted) if predicted else 0.0
        recall = true_positives / len(relevant) if relevant else 0.0
        denom = precision + recall
        f1 = (2 * precision * recall / denom) if denom else 0.0

        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)

    print("\n--- Book Retrieval Metrics ---")
    for k in k_values:
        print(f"Top@{k} Hit Rate: {hit_rates[k] / total_queries:.4f}")

    print(f"Avg Precision: {sum(precisions) / total_queries:.4f}")
    print(f"Avg Recall: {sum(recalls) / total_queries:.4f}")
    print(f"Avg F1: {sum(f1s) / total_queries:.4f}")
268
+
269
+
270
+ # ---------- 7. Main Execution ----------
271
if __name__ == '__main__':
    # Load book-specific evaluation data.
    # Both files must exist under the project's 'evaluation' directory.
    queries_books, truths_books = load_eval_data('test/test_queries_books_100.json', 'test/ground_truth_books_100.json')

    if not queries_books:
        print("Exiting evaluation due to no queries loaded.")
        sys.exit(1)

    print(f" {len(queries_books)} book queries loaded.\n")

    # Parameter grid: every (top_k, method) combination is evaluated below.
    param_grid = {
        'top_k': [1, 3, 5, 10],  # Added 10 for more granular evaluation
        'method': ['tfidf', 'sbert']
    }

    best_score, best_params = -1.0, {}

    # Shape expected by plot_optimization_report:
    # {metric: {method: [values ordered by sorted top_k]}}
    metrics_for_plotting = {
        'accuracy': {method: [] for method in param_grid['method']},
        'response_time': {method: [] for method in param_grid['method']}
    }

    from itertools import product

    print("--- Starting Book Retrieval Evaluation ---")
    # Sort top_k values so plot x-axes and result lists share one order.
    sorted_top_k = sorted(param_grid['top_k'])

    # Temporary storage keyed by (method, top_k); flattened into ordered
    # lists for plotting once the grid search completes.
    temp_metrics_by_method_topk = {
        method: {k: {'accuracy': 0, 'response_time': 0} for k in sorted_top_k}
        for method in param_grid['method']
    }

    for combo in product(sorted_top_k, param_grid['method']):
        params = {
            'top_k': combo[0],
            'method': combo[1]
        }

        # Build a retrieval function bound to this parameter combination.
        func = retrieval_func_factory(params)

        # Evaluate Top-1 accuracy and average per-query latency.
        score = evaluate_accuracy(func, queries_books, truths_books)
        avg_time = measure_response_time(func, queries_books)

        # Record results for later plotting.
        temp_metrics_by_method_topk[params['method']][params['top_k']]['accuracy'] = score
        temp_metrics_by_method_topk[params['method']][params['top_k']]['response_time'] = avg_time

        print(f"Params {params} -> Acc: {score:.4f}, Time: {avg_time:.4f}s")

        # Track the globally best parameters by accuracy.
        if score > best_score:
            best_score, best_params = score, params

    # Flatten the per-(method, top_k) results into lists ordered by
    # sorted_top_k, which is the order plot_optimization_report expects.
    for method in param_grid['method']:
        for k in sorted_top_k:
            metrics_for_plotting['accuracy'][method].append(temp_metrics_by_method_topk[method][k]['accuracy'])
            metrics_for_plotting['response_time'][method].append(temp_metrics_by_method_topk[method][k]['response_time'])


    print(f"\n✨ Best Params for Books: {best_params}, Accuracy: {best_score:.4f}")

    # Plot one optimization-report figure per collected metric.
    plot_optimization_report(metrics_for_plotting, param_grid,
                             save_path_prefix='optimization_report_books')

    # Detailed Top-k and binary metrics for the best configuration found.
    print("\n--- Detailed Metrics for Best Book Retrieval Model ---")
    compute_topk_metrics(retrieval_func_factory(best_params), queries_books, truths_books)

    print("\nBook evaluation complete.")