badminton001 committed on
Commit
3cfe3ca
·
verified ·
1 Parent(s): 70722a2

Update retrieval/retrieve_movies_50000.py

Browse files
Files changed (1) hide show
  1. retrieval/retrieve_movies_50000.py +219 -222
retrieval/retrieve_movies_50000.py CHANGED
@@ -1,222 +1,219 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
-
4
- import json
5
- import pickle
6
- import numpy as np
7
- import faiss
8
- from pathlib import Path
9
- from sentence_transformers import SentenceTransformer
10
- from scipy.sparse import load_npz
11
- from typing import List, Dict, Any, Optional, Tuple
12
-
13
- # Import custom utility functions
14
- # Ensure utils/query_parser.py is the latest version for accurate tag extraction
15
- from utils.query_parser import parse_user_query
16
- from utils.movies_explanation import generate_explanation
17
-
18
- # ── Path Configurations ──────────────────────────────────────────────
19
- # Define the root directory of the project (one level up from 'retrieval' folder)
20
- ROOT = Path(__file__).parent.parent
21
- # Path to vectorized data (TF-IDF matrix, SBERT embeddings, etc.)
22
- VEC_DIR = ROOT / "data" / "movie" / "vectorized"
23
- # Path to preprocessed movie records (updated for 50,000 records)
24
- PREPROCESSED_DATA_PATH = ROOT / "data" / "movie" / "preprocessed" / "movies_preprocessed_50000.json"
25
-
26
- # ── Load Preprocessed Data ───────────────────────────────────────────
27
- movie_records: List[Dict[str, Any]] = []
28
- try:
29
- with open(PREPROCESSED_DATA_PATH, encoding="utf-8") as f:
30
- movie_records = json.load(f)
31
- print(f"Loaded {len(movie_records)} movie records.")
32
- except FileNotFoundError:
33
- print(f"Error: Preprocessed movie data not found at {PREPROCESSED_DATA_PATH}")
34
- except json.JSONDecodeError:
35
- print(f"Error: Could not decode JSON from {PREPROCESSED_DATA_PATH}")
36
-
37
- # ── Load TF-IDF Index and Vectorizer ─────────────────────────────────
38
- tfidf_vectorizer = None
39
- tfidf_matrix = np.array([])
40
- try:
41
- # Updated TF-IDF asset filenames for 50,000 records
42
- tfidf_vectorizer = pickle.load(open(VEC_DIR / "movies_tfidf_vectorizer_50000.pkl", "rb"))
43
- tfidf_matrix = load_npz(VEC_DIR / "movies_tfidf_matrix_50000.npz").toarray().astype("float32")
44
- faiss.normalize_L2(tfidf_matrix)
45
- print("TF-IDF assets loaded.")
46
- except (FileNotFoundError, pickle.UnpicklingError, ValueError) as e:
47
- print(f"Error loading TF-IDF assets: {e}")
48
-
49
- # ── Load SBERT Index and Model ───────────────────────────────────────
50
- sbert_embeddings = np.array([])
51
- sbert_model = None
52
- try:
53
- # Updated SBERT asset filenames for 50,000 records
54
- sbert_embeddings = np.array(pickle.load(open(VEC_DIR / "movies_sbert_embeddings_50000.pkl", "rb"))).astype("float32")
55
- sbert_model_name = open(VEC_DIR / "movies_sbert_model_50000.txt").read().strip()
56
- sbert_model = SentenceTransformer(sbert_model_name)
57
- print("SBERT assets loaded.")
58
- except (FileNotFoundError, pickle.UnpicklingError, OSError) as e:
59
- print(f"Error loading SBERT assets: {e}")
60
-
61
-
62
- # ── Main Recommendation Function ─────────────────────────────────────
63
- def get_recommendations(
64
- query: str,
65
- top_k: int = 5,
66
- method: str = "sbert",
67
- parsed_query_tags: Optional[Dict[str, Any]] = None # Parameter for parsed tags
68
- ) -> List[Dict[str, Any]]:
69
- """
70
- Retrieves movie recommendations based on user query, with enhanced filtering and re-ranking.
71
-
72
- Args:
73
- query (str): The user's input query.
74
- top_k (int): Number of top recommendations to return. Defaults to 5.
75
- method (str): The retrieval method to use ("sbert" for semantic, "tfidf" for keyword-based).
76
- parsed_query_tags (Optional[Dict[str, Any]]): Dictionary of parsed query tags (from query_parser.py).
77
-
78
- Returns:
79
- list: A list of dictionaries, where each dictionary represents a recommended movie
80
- and includes its details, score, and an explanation.
81
- """
82
- if not movie_records:
83
- print("Warning: Movie records not loaded. Returning empty list.")
84
- return []
85
-
86
- # Parse query if tags are not already provided (e.g., direct call from an external script)
87
- if parsed_query_tags is None:
88
- parsed_query_tags = parse_user_query(query)
89
-
90
- # --- 1) Initial Candidate Selection (from full dataset) ---
91
- # Retrieve more candidates than requested top_k to allow for strict filtering
92
- CANDIDATE_MULTIPLIER = 20
93
- initial_search_k = top_k * CANDIDATE_MULTIPLIER
94
-
95
- hits: List[Tuple[int, float]] = [] # List of (original_index, similarity_score)
96
-
97
- if method == "tfidf" and tfidf_matrix.size > 0 and tfidf_vectorizer:
98
- query_vector = tfidf_vectorizer.transform([query]).toarray().astype("float32")
99
- faiss.normalize_L2(query_vector)
100
-
101
- faiss_idx_tfidf_full = faiss.IndexFlatIP(tfidf_matrix.shape[1])
102
- faiss_idx_tfidf_full.add(tfidf_matrix)
103
- distances, original_indices = faiss_idx_tfidf_full.search(query_vector, initial_search_k)
104
- hits = [(idx, float(distances[0][j])) for j, idx in enumerate(original_indices[0])]
105
-
106
- elif method == "sbert" and sbert_embeddings.size > 0 and sbert_model:
107
- query_vector = sbert_model.encode([query], convert_to_numpy=True).astype("float32")
108
-
109
- faiss_idx_sbert_full = faiss.IndexFlatL2(sbert_embeddings.shape[1])
110
- faiss_idx_sbert_full.add(sbert_embeddings)
111
- distances, original_indices = faiss_idx_sbert_full.search(query_vector, initial_search_k)
112
- # For L2 distance, smaller is better, so negate to make larger scores better for sorting
113
- hits = [(idx, -float(distances[0][j])) for j, idx in enumerate(original_indices[0])]
114
- else:
115
- print(f"Error: Invalid method '{method}' or required index/model is not available.")
116
- return []
117
-
118
- # --- 2) Filter and Re-rank based on parsed_query_tags ---
119
- filtered_and_scored_results: List[Dict[str, Any]] = []
120
-
121
- # Extract parsed query tags for easier access
122
- target_genres = set(parsed_query_tags.get("genres", []))
123
- target_moods = set(parsed_query_tags.get("mood", []))
124
- target_audience = parsed_query_tags.get("target_audience")
125
- target_era = parsed_query_tags.get("era")
126
- target_decade = parsed_query_tags.get("decade")
127
- specific_director = parsed_query_tags.get("specific_person") # Mapped to specific_person in parser
128
-
129
- # Define moods that should trigger a "hard exclusion" if the user implies negativity
130
- # This is a simple example; a more robust solution would involve sentiment analysis
131
- negative_exclusion_moods = {"sad", "dark", "grim", "bleak", "depressing", "gloomy", "somber", "disturbing", "heavy", "angry", "chilling"}
132
-
133
-
134
- for original_idx, base_score in hits:
135
- movie_data = movie_records[original_idx].copy()
136
- item_score = base_score # Start with the base similarity score from vector search
137
- is_suitable = True # Flag to mark if the movie meets all HARD filters
138
-
139
- # --- HARD FILTERS (If any of these conditions are not met, the item is excluded) ---
140
-
141
- # 1. Specific Director (Mandatory if requested)
142
- if specific_director:
143
- item_director = movie_data.get("director")
144
- # Check for existence and then case-insensitive partial match
145
- if not item_director or specific_director.lower() not in item_director.lower():
146
- is_suitable = False # Exclude if specific director is requested but not found
147
- else:
148
- item_score += 0.5 # High boost for an exact or strong director match
149
-
150
- # 2. Target Audience (Mandatory if requested)
151
- if target_audience:
152
- item_audience = movie_data.get("target_audience")
153
- # If item has an audience tag and it doesn't match the target, exclude
154
- if item_audience and item_audience != target_audience:
155
- is_suitable = False
156
-
157
- # 3. Era (Mandatory if requested and available in item data)
158
- if target_era:
159
- item_era = movie_data.get("era")
160
- # Convert both to lower for case-insensitive comparison
161
- if item_era and item_era.lower() != target_era.lower():
162
- is_suitable = False
163
-
164
- # 4. Decade (Mandatory if requested and able to be determined from item data)
165
- if target_decade:
166
- item_release_date = movie_data.get("release_date", "")
167
- if item_release_date and len(item_release_date) >= 4:
168
- item_year = int(item_release_date[:4]) # Extract year from release_date
169
- # Calculate the decade of the movie's release year
170
- item_decade_str = f"{(item_year // 10) * 10}s"
171
- if item_decade_str != target_decade:
172
- is_suitable = False
173
- else: # If no release date, it cannot match a specific decade, so exclude
174
- is_suitable = False
175
-
176
- # 5. Mood Exclusion (New Hard Filter): If user explicitly asks for a non-negative mood
177
- # and an item has a negative mood, exclude it.
178
- # For this example, we assume if ANY target mood is NOT in negative_exclusion_moods
179
- # AND the movie has a negative_exclusion_mood, we exclude.
180
- if target_moods and not any(m in negative_exclusion_moods for m in target_moods): # User wants a positive/neutral mood
181
- item_moods = set(movie_data.get("mood", []))
182
- if any(m in negative_exclusion_moods for m in item_moods): # Movie has a negative mood
183
- is_suitable = False # Exclude if user avoids negative moods and movie is negative
184
-
185
-
186
- # If any hard filter failed, this movie is not suitable, skip to the next candidate
187
- if not is_suitable:
188
- continue
189
-
190
- # --- SOFT FILTERS (These conditions boost the score but do not strictly exclude) ---
191
- # Only apply soft filters if the item passed all hard filters
192
-
193
- # 1. Genres: Boost score based on the number of overlapping genres
194
- if target_genres:
195
- item_genres = set(movie_data.get("genres", []))
196
- genre_matches = len(target_genres.intersection(item_genres))
197
- item_score += 0.1 * genre_matches # Small boost for each matching genre
198
-
199
- # 2. Moods: Boost score based on the number of overlapping moods (Increased weight for mood)
200
- if target_moods:
201
- item_moods = set(movie_data.get("mood", []))
202
- mood_matches = len(target_moods.intersection(item_moods))
203
- item_score += 0.2 * mood_matches # Increased boost for each matching mood, reflecting importance
204
-
205
- # Add the movie to results if it passed all hard filters
206
- # and include its calculated score
207
- movie_data["score"] = item_score
208
- filtered_and_scored_results.append(movie_data)
209
-
210
- # Sort the results by the final calculated score (higher score is better)
211
- # Using .get("score", -float('inf')) handles cases where 'score' might be missing (shouldn't happen here)
212
- filtered_and_scored_results.sort(key=lambda x: x.get("score", -float('inf')), reverse=True)
213
-
214
- # --- 3) Prepare final results ---
215
- # Take only the top_k results after filtering and re-ranking
216
- final_results = []
217
- for item in filtered_and_scored_results[:top_k]:
218
- # Generate a textual explanation for each recommendation
219
- item["explanation"] = generate_explanation(parsed_query_tags, item)
220
- final_results.append(item)
221
-
222
- return final_results
 
1
import json
import pickle
import numpy as np
import faiss
from pathlib import Path
from sentence_transformers import SentenceTransformer
from scipy.sparse import load_npz
from typing import List, Dict, Any, Optional, Tuple

# Import custom utility functions
# Ensure utils/query_parser.py is the latest version for accurate tag extraction
from utils.query_parser import parse_user_query
from utils.movies_explanation import generate_explanation

# ── Path Configurations ──────────────────────────────────────────────
# Root directory of the project (one level up from the 'retrieval' folder)
ROOT = Path(__file__).parent.parent
# Vectorized data (TF-IDF matrix, SBERT embeddings, etc.)
VEC_DIR = ROOT / "data" / "movie" / "vectorized"
# Preprocessed movie records (updated for 50,000 records)
PREPROCESSED_DATA_PATH = ROOT / "data" / "movie" / "preprocessed" / "movies_preprocessed_50000.json"

# ── Load Preprocessed Data ───────────────────────────────────────────
movie_records: List[Dict[str, Any]] = []
try:
    with open(PREPROCESSED_DATA_PATH, encoding="utf-8") as f:
        movie_records = json.load(f)
    print(f"Loaded {len(movie_records)} movie records.")
except FileNotFoundError:
    print(f"Error: Preprocessed movie data not found at {PREPROCESSED_DATA_PATH}")
except json.JSONDecodeError:
    print(f"Error: Could not decode JSON from {PREPROCESSED_DATA_PATH}")

# ── Load TF-IDF Index and Vectorizer ─────────────────────────────────
tfidf_vectorizer = None
tfidf_matrix = np.array([])
try:
    # Updated TF-IDF asset filenames for 50,000 records.
    # Context managers guarantee the file handles are closed even when
    # unpickling fails (the previous pickle.load(open(...)) leaked them).
    with open(VEC_DIR / "movies_tfidf_vectorizer_50000.pkl", "rb") as f:
        tfidf_vectorizer = pickle.load(f)
    tfidf_matrix = load_npz(VEC_DIR / "movies_tfidf_matrix_50000.npz").toarray().astype("float32")
    faiss.normalize_L2(tfidf_matrix)
    print("TF-IDF assets loaded.")
except (FileNotFoundError, pickle.UnpicklingError, ValueError) as e:
    print(f"Error loading TF-IDF assets: {e}")

# ── Load SBERT Index and Model ───────────────────────────────────────
sbert_embeddings = np.array([])
sbert_model = None
try:
    # Updated SBERT asset filenames for 50,000 records
    with open(VEC_DIR / "movies_sbert_embeddings_50000.pkl", "rb") as f:
        sbert_embeddings = np.array(pickle.load(f)).astype("float32")
    with open(VEC_DIR / "movies_sbert_model_50000.txt", encoding="utf-8") as f:
        sbert_model_name = f.read().strip()
    sbert_model = SentenceTransformer(sbert_model_name)
    print("SBERT assets loaded.")
except (FileNotFoundError, pickle.UnpicklingError, OSError) as e:
    print(f"Error loading SBERT assets: {e}")
59
# ── Main Recommendation Function ─────────────────────────────────────
def get_recommendations(
    query: str,
    top_k: int = 5,
    method: str = "sbert",
    parsed_query_tags: Optional[Dict[str, Any]] = None  # Pre-parsed tags, if the caller already ran the parser
) -> List[Dict[str, Any]]:
    """
    Retrieve movie recommendations for a user query with hard filtering and re-ranking.

    Args:
        query (str): The user's input query.
        top_k (int): Number of top recommendations to return. Defaults to 5.
        method (str): Retrieval backend: "sbert" (semantic) or "tfidf" (keyword-based).
        parsed_query_tags (Optional[Dict[str, Any]]): Dictionary of parsed query
            tags (from query_parser.py); parsed from `query` when omitted.

    Returns:
        list: Up to `top_k` movie dicts, each augmented with a "score" and an
        "explanation" field, sorted by descending score.
    """
    if not movie_records:
        print("Warning: Movie records not loaded. Returning empty list.")
        return []

    # Parse query if tags are not already provided (e.g., direct external call)
    if parsed_query_tags is None:
        parsed_query_tags = parse_user_query(query)

    # --- 1) Initial Candidate Selection (from full dataset) ---
    # Over-fetch so that strict (hard) filtering still leaves enough candidates
    CANDIDATE_MULTIPLIER = 20
    initial_search_k = top_k * CANDIDATE_MULTIPLIER

    hits: List[Tuple[int, float]] = []  # List of (original_index, similarity_score)

    if method == "tfidf" and tfidf_matrix.size > 0 and tfidf_vectorizer:
        query_vector = tfidf_vectorizer.transform([query]).toarray().astype("float32")
        faiss.normalize_L2(query_vector)

        # NOTE(perf): the FAISS index is rebuilt on every call; consider
        # building it once at module load if this becomes a hot path.
        faiss_idx_tfidf_full = faiss.IndexFlatIP(tfidf_matrix.shape[1])
        faiss_idx_tfidf_full.add(tfidf_matrix)
        distances, original_indices = faiss_idx_tfidf_full.search(query_vector, initial_search_k)
        # FAISS pads with -1 when fewer than k vectors exist; a -1 index would
        # silently wrap to the LAST record via Python indexing, so skip it.
        hits = [(idx, float(distances[0][j]))
                for j, idx in enumerate(original_indices[0]) if idx >= 0]

    elif method == "sbert" and sbert_embeddings.size > 0 and sbert_model:
        query_vector = sbert_model.encode([query], convert_to_numpy=True).astype("float32")

        faiss_idx_sbert_full = faiss.IndexFlatL2(sbert_embeddings.shape[1])
        faiss_idx_sbert_full.add(sbert_embeddings)
        distances, original_indices = faiss_idx_sbert_full.search(query_vector, initial_search_k)
        # For L2 distance, smaller is better, so negate to make larger scores
        # better for the descending sort below. Skip -1 padding indices too.
        hits = [(idx, -float(distances[0][j]))
                for j, idx in enumerate(original_indices[0]) if idx >= 0]
    else:
        print(f"Error: Invalid method '{method}' or required index/model is not available.")
        return []

    # --- 2) Filter and Re-rank based on parsed_query_tags ---
    filtered_and_scored_results: List[Dict[str, Any]] = []

    # Extract parsed query tags for easier access
    target_genres = set(parsed_query_tags.get("genres", []))
    target_moods = set(parsed_query_tags.get("mood", []))
    target_audience = parsed_query_tags.get("target_audience")
    target_era = parsed_query_tags.get("era")
    target_decade = parsed_query_tags.get("decade")
    specific_director = parsed_query_tags.get("specific_person")  # Mapped to specific_person in parser

    # Moods that trigger a "hard exclusion" when the user implies positivity.
    # This is a simple example; a more robust solution would involve sentiment analysis.
    negative_exclusion_moods = {"sad", "dark", "grim", "bleak", "depressing", "gloomy", "somber", "disturbing", "heavy", "angry", "chilling"}

    for original_idx, base_score in hits:
        movie_data = movie_records[original_idx].copy()
        item_score = base_score  # Start with the vector-search similarity
        is_suitable = True  # Cleared by any failed HARD filter

        # --- HARD FILTERS (failing any one excludes the item) ---

        # 1. Specific Director (mandatory if requested); case-insensitive
        #    partial match against the item's director field.
        if specific_director:
            item_director = movie_data.get("director")
            if not item_director or specific_director.lower() not in item_director.lower():
                is_suitable = False  # Requested director not found
            else:
                item_score += 0.5  # High boost for a director match

        # 2. Target Audience (mandatory if requested and tagged on the item)
        if target_audience:
            item_audience = movie_data.get("target_audience")
            if item_audience and item_audience != target_audience:
                is_suitable = False

        # 3. Era (mandatory if requested; case-insensitive comparison)
        if target_era:
            item_era = movie_data.get("era")
            if item_era and item_era.lower() != target_era.lower():
                is_suitable = False

        # 4. Decade (mandatory if requested and derivable from release_date)
        if target_decade:
            item_release_date = movie_data.get("release_date", "")
            item_year = None
            if item_release_date and len(item_release_date) >= 4:
                try:
                    item_year = int(item_release_date[:4])  # Year prefix of release_date
                except ValueError:
                    item_year = None  # Malformed date: treat year as unknown
            if item_year is None:
                # No usable release year -> cannot match a specific decade
                is_suitable = False
            elif f"{(item_year // 10) * 10}s" != target_decade:
                is_suitable = False

        # 5. Mood Exclusion (hard filter): if the user asked only for
        #    non-negative moods, drop movies tagged with any negative mood.
        if target_moods and not any(m in negative_exclusion_moods for m in target_moods):  # User wants a positive/neutral mood
            item_moods = set(movie_data.get("mood", []))
            if any(m in negative_exclusion_moods for m in item_moods):  # Movie has a negative mood
                is_suitable = False

        # If any hard filter failed, skip to the next candidate
        if not is_suitable:
            continue

        # --- SOFT FILTERS (boost the score but never exclude) ---

        # 1. Genres: small boost per overlapping genre
        if target_genres:
            item_genres = set(movie_data.get("genres", []))
            item_score += 0.1 * len(target_genres.intersection(item_genres))

        # 2. Moods: larger boost per overlapping mood (mood weighted higher)
        if target_moods:
            item_moods = set(movie_data.get("mood", []))
            item_score += 0.2 * len(target_moods.intersection(item_moods))

        # Passed all hard filters: record with its calculated score
        movie_data["score"] = item_score
        filtered_and_scored_results.append(movie_data)

    # Sort by final score, highest first; .get guards a missing "score"
    # (shouldn't happen here, but keeps the sort total either way).
    filtered_and_scored_results.sort(key=lambda x: x.get("score", -float('inf')), reverse=True)

    # --- 3) Prepare final results ---
    # Take only the top_k results after filtering and re-ranking
    final_results = []
    for item in filtered_and_scored_results[:top_k]:
        # Generate a textual explanation for each recommendation
        item["explanation"] = generate_explanation(parsed_query_tags, item)
        final_results.append(item)

    return final_results