badminton001 committed on
Commit
70722a2
·
verified ·
1 Parent(s): 5e5ed91

Update retrieval/retrieve_books_50000.py

Browse files
Files changed (1) hide show
  1. retrieval/retrieve_books_50000.py +225 -228
retrieval/retrieve_books_50000.py CHANGED
@@ -1,229 +1,226 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
-
4
- import json
5
- import pickle
6
- import numpy as np
7
- import faiss
8
- from pathlib import Path
9
- from sentence_transformers import SentenceTransformer
10
- from scipy.sparse import load_npz
11
- from typing import List, Dict, Any, Optional, Tuple
12
-
13
- # Import custom utility functions
14
- # Ensure utils/query_parser.py is the latest version for accurate tag extraction
15
- # Using the optimized query_parser
16
- from utils.query_parser import parse_user_query
17
- from utils.books_explanation import generate_explanation
18
-
19
# ── Path Configurations ──────────────────────────────────────────────
# Root of the project (one level up from this 'retrieval' folder).
ROOT = Path(__file__).parent.parent
# Directory holding vectorized assets (TF-IDF matrix, SBERT embeddings, etc.).
VEC_DIR = ROOT / "data" / "book" / "vectorized"
# Preprocessed book records (updated for the 50,000-record build).
PREPROCESSED_DATA_PATH = ROOT / "data" / "book" / "preprocessed" / "books_preprocessed_50000.json"

# ── Load Preprocessed Data ───────────────────────────────────────────
book_records: List[Dict[str, Any]] = []
try:
    with open(PREPROCESSED_DATA_PATH, encoding="utf-8") as f:
        book_records = json.load(f)
    print(f"Loaded {len(book_records)} book records.")
except FileNotFoundError:
    print(f"Error: Preprocessed book data not found at {PREPROCESSED_DATA_PATH}")
except json.JSONDecodeError:
    print(f"Error: Could not decode JSON from {PREPROCESSED_DATA_PATH}")

# ── Load TF-IDF Index and Vectorizer ─────────────────────────────────
tfidf_vectorizer = None
tfidf_matrix = np.array([])
try:
    # Updated TF-IDF asset filenames for 50,000 records.
    # Use a context manager so the pickle file handle is closed promptly
    # (a bare pickle.load(open(...)) leaks the handle).
    with open(VEC_DIR / "books_tfidf_vectorizer_50000.pkl", "rb") as f:
        tfidf_vectorizer = pickle.load(f)
    tfidf_matrix = load_npz(VEC_DIR / "books_tfidf_matrix_50000.npz").toarray().astype("float32")
    # L2-normalize rows so inner-product search equals cosine similarity.
    faiss.normalize_L2(tfidf_matrix)
    print("TF-IDF assets loaded.")
except (FileNotFoundError, pickle.UnpicklingError, ValueError) as e:
    print(f"Error loading TF-IDF assets: {e}")

# ── Load SBERT Index and Model ───────────────────────────────────────
sbert_embeddings = np.array([])
sbert_model = None
try:
    # Updated SBERT asset filenames for 50,000 records.
    with open(VEC_DIR / "books_sbert_embeddings_50000.pkl", "rb") as f:
        sbert_embeddings = np.array(pickle.load(f)).astype("float32")
    # The model name is stored as plain text next to the embeddings.
    sbert_model_name = (VEC_DIR / "books_sbert_model_50000.txt").read_text().strip()
    sbert_model = SentenceTransformer(sbert_model_name)
    print("SBERT assets loaded.")
except (FileNotFoundError, pickle.UnpicklingError, OSError) as e:
    print(f"Error loading SBERT assets: {e}")
# ── Main Recommendation Function ─────────────────────────────────────
def get_recommendations(
    query: str,
    top_k: int = 5,
    method: str = "sbert",
    parsed_query_tags: Optional[Dict[str, Any]] = None  # Parameter for parsed tags
) -> List[Dict[str, Any]]:
    """
    Retrieves book recommendations based on user query, with enhanced filtering and re-ranking.

    Args:
        query (str): The user's input query.
        top_k (int): Number of top recommendations to return. Defaults to 5.
        method (str): The retrieval method to use ("sbert" for semantic, "tfidf" for keyword-based).
        parsed_query_tags (Optional[Dict[str, Any]]): Dictionary of parsed query tags (from query_parser.py).

    Returns:
        list: A list of dictionaries, where each dictionary represents a recommended book
              and includes its details, score, and an explanation.
    """
    if not book_records:
        print("Warning: Book records not loaded. Returning empty list.")
        return []

    # Parse the query only when the caller did not already supply parsed tags
    # (e.g. a direct call from an external script).
    if parsed_query_tags is None:
        parsed_query_tags = parse_user_query(query)

    # --- 1) Initial Candidate Selection (from full dataset) ---
    # Retrieve more candidates than the requested top_k so that the strict
    # (hard) filters below still leave enough results to return.
    CANDIDATE_MULTIPLIER = 20
    initial_search_k = top_k * CANDIDATE_MULTIPLIER

    hits: List[Tuple[int, float]] = []  # List of (original_index, similarity_score)

    if method == "tfidf" and tfidf_matrix.size > 0 and tfidf_vectorizer:
        query_vector = tfidf_vectorizer.transform([query]).toarray().astype("float32")
        faiss.normalize_L2(query_vector)

        faiss_idx_tfidf_full = faiss.IndexFlatIP(tfidf_matrix.shape[1])
        faiss_idx_tfidf_full.add(tfidf_matrix)
        distances, original_indices = faiss_idx_tfidf_full.search(query_vector, initial_search_k)
        # FAISS pads results with index -1 when fewer than initial_search_k
        # vectors exist; skip those sentinels so a -1 never reaches
        # book_records (where it would silently select the LAST record).
        hits = [(idx, float(distances[0][j]))
                for j, idx in enumerate(original_indices[0]) if idx >= 0]

    elif method == "sbert" and sbert_embeddings.size > 0 and sbert_model:
        query_vector = sbert_model.encode([query], convert_to_numpy=True).astype("float32")

        faiss_idx_sbert_full = faiss.IndexFlatL2(sbert_embeddings.shape[1])
        faiss_idx_sbert_full.add(sbert_embeddings)
        distances, original_indices = faiss_idx_sbert_full.search(query_vector, initial_search_k)
        # For L2 distance, smaller is better, so negate to make larger scores
        # better for sorting. Skip FAISS's -1 padding as above.
        hits = [(idx, -float(distances[0][j]))
                for j, idx in enumerate(original_indices[0]) if idx >= 0]
    else:
        print(f"Error: Invalid method '{method}' or required index/model is not available.")
        return []

    # --- 2) Filter and Re-rank based on parsed_query_tags ---
    filtered_and_scored_results: List[Dict[str, Any]] = []

    # Extract parsed query tags for easier access.
    target_genres = set(parsed_query_tags.get("genres", []))
    target_moods = set(parsed_query_tags.get("mood", []))
    target_audience = parsed_query_tags.get("target_audience")
    target_era = parsed_query_tags.get("era")
    target_decade = parsed_query_tags.get("decade")
    specific_author = parsed_query_tags.get("specific_person")  # Mapped to specific_person in parser
    # Lower-case once here instead of inside the per-author inner loop.
    specific_author_lower = specific_author.lower() if specific_author else None

    # Moods that trigger a "hard exclusion" when the user implies they want to
    # avoid negativity. A more robust solution would use sentiment analysis.
    negative_exclusion_moods = {"sad", "dark", "grim", "bleak", "depressing", "gloomy", "somber", "disturbing", "heavy", "angry"}

    for original_idx, base_score in hits:
        book_data = book_records[original_idx].copy()
        item_score = base_score  # Start with the base similarity score from vector search
        is_suitable = True  # Flag: the book meets every HARD filter so far

        # --- HARD FILTERS (failing any of these excludes the item) ---

        # 1. Specific Author (mandatory if requested): partial,
        #    case-insensitive match against any of the book's author names.
        if specific_author_lower:
            item_authors = [a.lower() for a in book_data.get("authors", [])]
            found_author = any(specific_author_lower in author_name for author_name in item_authors)
            if not found_author:
                is_suitable = False  # Requested author not found -> exclude
            else:
                item_score += 0.5  # High boost for an exact or strong author match

        # 2. Target Audience (mandatory if requested): exclude only when the
        #    item carries an audience tag that contradicts the request.
        if target_audience:
            item_audience = book_data.get("target_audience")
            if item_audience and item_audience != target_audience:
                is_suitable = False

        # 3. Era (mandatory if requested and available in item data),
        #    compared case-insensitively.
        if target_era:
            item_era = book_data.get("era")
            if item_era and item_era.lower() != target_era.lower():
                is_suitable = False

        # 4. Decade (mandatory if requested). Derived from the publication
        #    year; items without a year cannot match a decade -> exclude.
        if target_decade:
            item_publish_year = book_data.get("first_publish_year")
            if item_publish_year:
                item_decade_str = f"{(item_publish_year // 10) * 10}s"
                if item_decade_str != target_decade:
                    is_suitable = False
            else:
                is_suitable = False

        # 5. Mood Exclusion (hard filter): if the user asked only for
        #    positive/neutral moods, drop books tagged with a negative mood.
        #    Simplified heuristic; full robustness would need richer NLU.
        if target_moods and not any(m in negative_exclusion_moods for m in target_moods):
            item_moods = set(book_data.get("mood", []))
            if any(m in negative_exclusion_moods for m in item_moods):
                is_suitable = False

        # Skip this candidate if any hard filter failed.
        if not is_suitable:
            continue

        # --- SOFT FILTERS (boost the score, never exclude) ---

        # 1. Genres: small boost per overlapping genre.
        if target_genres:
            item_genres = set(book_data.get("genres", []))
            item_score += 0.1 * len(target_genres.intersection(item_genres))

        # 2. Moods: larger boost per overlapping mood (mood matters more).
        if target_moods:
            item_moods = set(book_data.get("mood", []))
            item_score += 0.2 * len(target_moods.intersection(item_moods))

        book_data["score"] = item_score
        filtered_and_scored_results.append(book_data)

    # Sort by final score, best first. The .get default guards against a
    # (theoretically impossible) missing 'score' key.
    filtered_and_scored_results.sort(key=lambda x: x.get("score", -float('inf')), reverse=True)

    # --- 3) Prepare final results ---
    # Keep only the top_k results after filtering and re-ranking, attaching a
    # textual explanation to each.
    final_results = []
    for item in filtered_and_scored_results[:top_k]:
        item["explanation"] = generate_explanation(parsed_query_tags, item)
        final_results.append(item)

    return final_results
 
1
+ import json
2
+ import pickle
3
+ import numpy as np
4
+ import faiss
5
+ from pathlib import Path
6
+ from sentence_transformers import SentenceTransformer
7
+ from scipy.sparse import load_npz
8
+ from typing import List, Dict, Any, Optional, Tuple
9
+
10
+ # Import custom utility functions
11
+ # Ensure utils/query_parser.py is the latest version for accurate tag extraction
12
+ # Using the optimized query_parser
13
+ from utils.query_parser import parse_user_query
14
+ from utils.books_explanation import generate_explanation
15
+
16
# ── Path Configurations ──────────────────────────────────────────────
# Root of the project (one level up from this 'retrieval' folder).
ROOT = Path(__file__).parent.parent
# Directory holding vectorized assets (TF-IDF matrix, SBERT embeddings, etc.).
VEC_DIR = ROOT / "data" / "book" / "vectorized"
# Preprocessed book records (updated for the 50,000-record build).
PREPROCESSED_DATA_PATH = ROOT / "data" / "book" / "preprocessed" / "books_preprocessed_50000.json"

# ── Load Preprocessed Data ───────────────────────────────────────────
book_records: List[Dict[str, Any]] = []
try:
    with open(PREPROCESSED_DATA_PATH, encoding="utf-8") as f:
        book_records = json.load(f)
    print(f"Loaded {len(book_records)} book records.")
except FileNotFoundError:
    print(f"Error: Preprocessed book data not found at {PREPROCESSED_DATA_PATH}")
except json.JSONDecodeError:
    print(f"Error: Could not decode JSON from {PREPROCESSED_DATA_PATH}")

# ── Load TF-IDF Index and Vectorizer ─────────────────────────────────
tfidf_vectorizer = None
tfidf_matrix = np.array([])
try:
    # Updated TF-IDF asset filenames for 50,000 records.
    # Use a context manager so the pickle file handle is closed promptly
    # (a bare pickle.load(open(...)) leaks the handle).
    with open(VEC_DIR / "books_tfidf_vectorizer_50000.pkl", "rb") as f:
        tfidf_vectorizer = pickle.load(f)
    tfidf_matrix = load_npz(VEC_DIR / "books_tfidf_matrix_50000.npz").toarray().astype("float32")
    # L2-normalize rows so inner-product search equals cosine similarity.
    faiss.normalize_L2(tfidf_matrix)
    print("TF-IDF assets loaded.")
except (FileNotFoundError, pickle.UnpicklingError, ValueError) as e:
    print(f"Error loading TF-IDF assets: {e}")

# ── Load SBERT Index and Model ───────────────────────────────────────
sbert_embeddings = np.array([])
sbert_model = None
try:
    # Updated SBERT asset filenames for 50,000 records.
    with open(VEC_DIR / "books_sbert_embeddings_50000.pkl", "rb") as f:
        sbert_embeddings = np.array(pickle.load(f)).astype("float32")
    # The model name is stored as plain text next to the embeddings.
    sbert_model_name = (VEC_DIR / "books_sbert_model_50000.txt").read_text().strip()
    sbert_model = SentenceTransformer(sbert_model_name)
    print("SBERT assets loaded.")
except (FileNotFoundError, pickle.UnpicklingError, OSError) as e:
    print(f"Error loading SBERT assets: {e}")
60
# ── Main Recommendation Function ─────────────────────────────────────
def get_recommendations(
    query: str,
    top_k: int = 5,
    method: str = "sbert",
    parsed_query_tags: Optional[Dict[str, Any]] = None  # Parameter for parsed tags
) -> List[Dict[str, Any]]:
    """
    Retrieves book recommendations based on user query, with enhanced filtering and re-ranking.

    Args:
        query (str): The user's input query.
        top_k (int): Number of top recommendations to return. Defaults to 5.
        method (str): The retrieval method to use ("sbert" for semantic, "tfidf" for keyword-based).
        parsed_query_tags (Optional[Dict[str, Any]]): Dictionary of parsed query tags (from query_parser.py).

    Returns:
        list: A list of dictionaries, where each dictionary represents a recommended book
              and includes its details, score, and an explanation.
    """
    if not book_records:
        print("Warning: Book records not loaded. Returning empty list.")
        return []

    # Parse the query only when the caller did not already supply parsed tags
    # (e.g. a direct call from an external script).
    if parsed_query_tags is None:
        parsed_query_tags = parse_user_query(query)

    # --- 1) Initial Candidate Selection (from full dataset) ---
    # Retrieve more candidates than the requested top_k so that the strict
    # (hard) filters below still leave enough results to return.
    CANDIDATE_MULTIPLIER = 20
    initial_search_k = top_k * CANDIDATE_MULTIPLIER

    hits: List[Tuple[int, float]] = []  # List of (original_index, similarity_score)

    if method == "tfidf" and tfidf_matrix.size > 0 and tfidf_vectorizer:
        query_vector = tfidf_vectorizer.transform([query]).toarray().astype("float32")
        faiss.normalize_L2(query_vector)

        faiss_idx_tfidf_full = faiss.IndexFlatIP(tfidf_matrix.shape[1])
        faiss_idx_tfidf_full.add(tfidf_matrix)
        distances, original_indices = faiss_idx_tfidf_full.search(query_vector, initial_search_k)
        # FAISS pads results with index -1 when fewer than initial_search_k
        # vectors exist; skip those sentinels so a -1 never reaches
        # book_records (where it would silently select the LAST record).
        hits = [(idx, float(distances[0][j]))
                for j, idx in enumerate(original_indices[0]) if idx >= 0]

    elif method == "sbert" and sbert_embeddings.size > 0 and sbert_model:
        query_vector = sbert_model.encode([query], convert_to_numpy=True).astype("float32")

        faiss_idx_sbert_full = faiss.IndexFlatL2(sbert_embeddings.shape[1])
        faiss_idx_sbert_full.add(sbert_embeddings)
        distances, original_indices = faiss_idx_sbert_full.search(query_vector, initial_search_k)
        # For L2 distance, smaller is better, so negate to make larger scores
        # better for sorting. Skip FAISS's -1 padding as above.
        hits = [(idx, -float(distances[0][j]))
                for j, idx in enumerate(original_indices[0]) if idx >= 0]
    else:
        print(f"Error: Invalid method '{method}' or required index/model is not available.")
        return []

    # --- 2) Filter and Re-rank based on parsed_query_tags ---
    filtered_and_scored_results: List[Dict[str, Any]] = []

    # Extract parsed query tags for easier access.
    target_genres = set(parsed_query_tags.get("genres", []))
    target_moods = set(parsed_query_tags.get("mood", []))
    target_audience = parsed_query_tags.get("target_audience")
    target_era = parsed_query_tags.get("era")
    target_decade = parsed_query_tags.get("decade")
    specific_author = parsed_query_tags.get("specific_person")  # Mapped to specific_person in parser
    # Lower-case once here instead of inside the per-author inner loop.
    specific_author_lower = specific_author.lower() if specific_author else None

    # Moods that trigger a "hard exclusion" when the user implies they want to
    # avoid negativity. A more robust solution would use sentiment analysis.
    negative_exclusion_moods = {"sad", "dark", "grim", "bleak", "depressing", "gloomy", "somber", "disturbing", "heavy", "angry"}

    for original_idx, base_score in hits:
        book_data = book_records[original_idx].copy()
        item_score = base_score  # Start with the base similarity score from vector search
        is_suitable = True  # Flag: the book meets every HARD filter so far

        # --- HARD FILTERS (failing any of these excludes the item) ---

        # 1. Specific Author (mandatory if requested): partial,
        #    case-insensitive match against any of the book's author names.
        if specific_author_lower:
            item_authors = [a.lower() for a in book_data.get("authors", [])]
            found_author = any(specific_author_lower in author_name for author_name in item_authors)
            if not found_author:
                is_suitable = False  # Requested author not found -> exclude
            else:
                item_score += 0.5  # High boost for an exact or strong author match

        # 2. Target Audience (mandatory if requested): exclude only when the
        #    item carries an audience tag that contradicts the request.
        if target_audience:
            item_audience = book_data.get("target_audience")
            if item_audience and item_audience != target_audience:
                is_suitable = False

        # 3. Era (mandatory if requested and available in item data),
        #    compared case-insensitively.
        if target_era:
            item_era = book_data.get("era")
            if item_era and item_era.lower() != target_era.lower():
                is_suitable = False

        # 4. Decade (mandatory if requested). Derived from the publication
        #    year; items without a year cannot match a decade -> exclude.
        if target_decade:
            item_publish_year = book_data.get("first_publish_year")
            if item_publish_year:
                item_decade_str = f"{(item_publish_year // 10) * 10}s"
                if item_decade_str != target_decade:
                    is_suitable = False
            else:
                is_suitable = False

        # 5. Mood Exclusion (hard filter): if the user asked only for
        #    positive/neutral moods, drop books tagged with a negative mood.
        #    Simplified heuristic; full robustness would need richer NLU.
        if target_moods and not any(m in negative_exclusion_moods for m in target_moods):
            item_moods = set(book_data.get("mood", []))
            if any(m in negative_exclusion_moods for m in item_moods):
                is_suitable = False

        # Skip this candidate if any hard filter failed.
        if not is_suitable:
            continue

        # --- SOFT FILTERS (boost the score, never exclude) ---

        # 1. Genres: small boost per overlapping genre.
        if target_genres:
            item_genres = set(book_data.get("genres", []))
            item_score += 0.1 * len(target_genres.intersection(item_genres))

        # 2. Moods: larger boost per overlapping mood (mood matters more).
        if target_moods:
            item_moods = set(book_data.get("mood", []))
            item_score += 0.2 * len(target_moods.intersection(item_moods))

        book_data["score"] = item_score
        filtered_and_scored_results.append(book_data)

    # Sort by final score, best first. The .get default guards against a
    # (theoretically impossible) missing 'score' key.
    filtered_and_scored_results.sort(key=lambda x: x.get("score", -float('inf')), reverse=True)

    # --- 3) Prepare final results ---
    # Keep only the top_k results after filtering and re-ranking, attaching a
    # textual explanation to each.
    final_results = []
    for item in filtered_and_scored_results[:top_k]:
        item["explanation"] = generate_explanation(parsed_query_tags, item)
        final_results.append(item)

    return final_results