Iris314 committed on
Commit
c81cc13
·
verified ·
1 Parent(s): 69dd7d5

Update recipe_recommendation/src/coldstart.py

Browse files
Files changed (1) hide show
  1. recipe_recommendation/src/coldstart.py +386 -386
recipe_recommendation/src/coldstart.py CHANGED
@@ -1,387 +1,387 @@
1
- import os
2
- import ast
3
- import json
4
- import random
5
- import pandas as pd
6
- import numpy as np
7
- from tqdm import tqdm
8
- import warnings
9
-
10
- from .candidate import coarse_rank_candidates, hard_filter, rule_generate_candidates
11
- from .feature import build_features
12
- from .io import load_recipes_csv, load_ingredient_map
13
-
14
- RECIPES_PATH = load_recipes_csv()
15
- INGREDIENT_MAP = load_ingredient_map()
16
- PARENTS = INGREDIENT_MAP["parents"]
17
- CHILDREN = INGREDIENT_MAP["children"]
18
-
19
- def parse_list(x):
20
- """Convert a stringified list into a Python list safely."""
21
- if pd.isna(x) or x == "":
22
- return []
23
- if isinstance(x, list):
24
- return x
25
- try:
26
- return ast.literal_eval(x)
27
- except Exception:
28
- return []
29
-
30
- def parse_set(x):
31
- """Convert a stringified collection into a Python set safely."""
32
- if pd.isna(x) or x == "":
33
- return set()
34
- if isinstance(x, set):
35
- return x
36
- if isinstance(x, (list, tuple)):
37
- return set(x)
38
- if isinstance(x, str):
39
- try:
40
- v = ast.literal_eval(x)
41
- if isinstance(v, (list, tuple, set)):
42
- return set(v)
43
- return {v}
44
- except Exception:
45
- return {x.strip()}
46
- return {x}
47
-
48
- def _parents_pool_from_df(df: pd.DataFrame):
49
- cols = ["main_parent", "staple_parent", "other_parent", "seasoning_parent"]
50
- pool = set()
51
- for c in cols:
52
- if c in df.columns:
53
- for s in df[c]:
54
- pool |= set(s) if isinstance(s, (set, list, tuple)) else set()
55
- return sorted(pool)
56
-
57
-
58
- def sample_user_parents(parents_pool,
59
- user_profile=None,
60
- prev_inventory=None,
61
- min_items=3, max_items=10,
62
- keep_ratio=0.6, reset_interval=20, round_idx=0):
63
- liked = set((user_profile or {}).get("other_preferences", {}).get("preferred_main", []))
64
- disliked = set((user_profile or {}).get("other_preferences", {}).get("disliked_main", []))
65
- forbidden = set((user_profile or {}).get("forbidden_parents", [])) | disliked
66
-
67
- pool, weights = [], []
68
- for p in parents_pool:
69
- if p in forbidden:
70
- continue
71
- w = 3.0 if p in liked else 1.0
72
- pool.append(p); weights.append(w)
73
- if not pool:
74
- pool, weights = parents_pool[:], [1.0] * len(parents_pool)
75
-
76
- inventory = set()
77
- force_reset = (round_idx % reset_interval == 0)
78
- if prev_inventory and not force_reset:
79
- prev_list = list(prev_inventory); random.shuffle(prev_list)
80
- keep_k = max(0, int(len(prev_list) * keep_ratio))
81
- inventory |= set(prev_list[:keep_k])
82
-
83
- k = random.randint(min_items, max_items)
84
- remain = max(0, k - len(inventory))
85
- for _ in range(min(remain, len(pool))):
86
- idx = random.choices(range(len(pool)), weights=weights, k=1)[0]
87
- inventory.add(pool[idx])
88
- return list(inventory)
89
-
90
-
91
- def _weighted_pick3(indexes, scores, temperature=1.0):
92
- idxs = list(indexes)
93
- scs = np.array(scores, dtype=float)
94
- if np.any(scs < 0):
95
- scs = scs - scs.min()
96
- if scs.sum() == 0:
97
- scs = np.ones_like(scs)
98
- picks = []
99
- for _ in range(min(3, len(idxs))):
100
- probs = np.exp(scs / max(temperature, 1e-6))
101
- probs = probs / probs.sum()
102
- choice = np.random.choice(len(idxs), p=probs)
103
- picks.append(idxs[choice])
104
- idxs.pop(choice)
105
- scs = np.delete(scs, choice)
106
- if len(idxs) == 0:
107
- break
108
- return picks
109
-
110
-
111
- # ---------- Main cold-start ----------
112
- # ---------- Main cold-start ----------
113
- def cold_start_ranker(user_id: str,
114
- n_rounds: int = 2000,
115
- topn_coarse: int = 5000,
116
- topk_rule: int = 3,
117
- batch_size: int = 5000,
118
- switch_interval: int = 100):
119
- """
120
- Cold-start data generation for learning-to-rank.
121
- Top-5 selection prioritizes user pantry coverage deterministically:
122
- 1. Fully covered recipes first (missing_count == 0)
123
- 2. Then few missing (esp. staple/other)
124
- 3. Heavy penalty for missing main ingredients.
125
- """
126
-
127
- base_dir = os.path.join("recipe_recommendation", "user_data", user_id)
128
- if not os.path.exists(base_dir):
129
- base_dir = os.path.join("recipe_recommendation", "input_user_data", user_id)
130
-
131
- if not os.path.exists(base_dir):
132
- raise FileNotFoundError(
133
- f"❌ User profile not found for '{user_id}' in either 'recipe_recommendation/user_data' or 'recipe_recommendation/input_user_data'."
134
- )
135
-
136
- print(f"[cold_start_ranker] Using base_dir = {base_dir}")
137
-
138
- profile_path = os.path.join(base_dir, "user_profile.json")
139
- features_path = os.path.join(base_dir, "user_features_rank.csv")
140
-
141
- if os.path.exists(features_path):
142
- print(f"[cold_start] Features already exist at {features_path}")
143
- return features_path
144
-
145
- with open(profile_path, "r", encoding="utf-8") as f:
146
- user_profile = json.load(f)
147
-
148
- # Load and parse recipes
149
- df_all = pd.read_csv(RECIPES_PATH)
150
- to_set = ["main_parent", "staple_parent", "other_parent", "seasoning_parent", "cuisine_attr"]
151
- to_list = ["ingredients"]
152
- for c in to_set:
153
- if c in df_all.columns:
154
- df_all[c] = df_all[c].apply(parse_set)
155
- for c in to_list:
156
- if c in df_all.columns:
157
- df_all[c] = df_all[c].apply(parse_list)
158
-
159
- # Step 1 hard filter
160
- if hard_filter is not None:
161
- try:
162
- before = len(df_all)
163
- mask = df_all.apply(lambda r: hard_filter(r.to_dict(), user_profile), axis=1)
164
- df_all = df_all[mask]
165
- after = len(df_all)
166
- print(f"[cold_start] Step1 hard filter applied: {before} -> {after}")
167
- except Exception as e:
168
- warnings.warn(f"[cold_start] hard_filter failed, skip. err={e}")
169
-
170
- n_chunks = (len(df_all) // batch_size) + 1
171
- chunks = np.array_split(df_all, n_chunks)
172
- parents_pool = _parents_pool_from_df(df_all)
173
- rows = []
174
- prev_inventory = None
175
-
176
- for i in tqdm(range(n_rounds), desc="Cold-start rounds"):
177
- chunk_id = (i // switch_interval) % n_chunks
178
- df_chunk = chunks[chunk_id].copy()
179
-
180
- # pantry sampling
181
- user_parents = sample_user_parents(
182
- parents_pool,
183
- user_profile=user_profile,
184
- prev_inventory=prev_inventory,
185
- round_idx=i
186
- )
187
- prev_inventory = user_parents
188
-
189
- # Step 2: coarse recall
190
- coarse_list = coarse_rank_candidates(
191
- recipes=df_chunk.to_dict(orient="records"),
192
- user_parents=user_parents,
193
- user_profile=user_profile,
194
- top_n=min(topn_coarse, len(df_chunk))
195
- )
196
- if not coarse_list:
197
- continue
198
-
199
- coarse_df = pd.DataFrame(coarse_list)
200
-
201
- # Step 3: rule rerank → Top-5 candidates (just for selecting the 5)
202
- rule_df = rule_generate_candidates(
203
- coarse_df,
204
- user_parents=user_parents,
205
- user_profile=user_profile
206
- )
207
- if rule_df.empty or len(rule_df) < topk_rule:
208
- continue
209
-
210
- top5 = rule_df.head(topk_rule).copy()
211
-
212
- # ===== Deterministic scoring with feasibility + region + soft constraints =====
213
- user_set = set(user_parents)
214
- scored_candidates = []
215
-
216
- # Nutrition goals (from profile)
217
- ng = user_profile.get("nutritional_goals", {})
218
- cal_min = ng.get("calories", {}).get("min", 0)
219
- cal_max = ng.get("calories", {}).get("max", 1e9)
220
- pro_min = ng.get("protein", {}).get("min", 0)
221
- pro_max = ng.get("protein", {}).get("max", 1e9)
222
-
223
- # Preferences
224
- liked = set(user_profile.get("other_preferences", {}).get("preferred_main", []))
225
- disliked = set(user_profile.get("other_preferences", {}).get("disliked_main", []))
226
- max_cooking_time = user_profile.get("other_preferences", {}).get("cooking_time_max", None)
227
-
228
- for idx, row in top5.iterrows():
229
- main_set = set(row.get("main_parent", set()))
230
- staple_set = set(row.get("staple_parent", set()))
231
- other_set = set(row.get("other_parent", set()))
232
-
233
- main_total = len(main_set)
234
- staple_total = len(staple_set)
235
- main_match = len(main_set & user_set)
236
- staple_match = len(staple_set & user_set)
237
-
238
- # === 1) Feasibility check ===
239
- total_needed = max(1, main_total + staple_total)
240
- total_have = main_match + staple_match
241
- coverage_ratio = total_have / total_needed
242
-
243
- if coverage_ratio < 0.5:
244
- continue
245
-
246
- # === 2) Region preference ===
247
- region_score = 1.0 if row.get("region_match", 0) else 0.0
248
-
249
- # === 3) Cooking time soft constraint ===
250
- time_val = row.get("minutes", None)
251
- time_score = 0.0
252
- if max_cooking_time and time_val is not None:
253
- try:
254
- t_val = float(time_val)
255
- t_max = float(max_cooking_time)
256
- lower_bound = 0.8 * t_max
257
- upper_bound = 1.2 * t_max
258
- if lower_bound <= t_val <= upper_bound:
259
- time_score = 1.0
260
- else:
261
- deviation = abs(t_val - t_max) / t_max
262
- time_score = max(0.0, 1.0 - deviation)
263
- except (TypeError, ValueError):
264
- time_score = 0.0
265
- else:
266
- time_score = 1.0
267
-
268
- # === 4) Calories soft constraint ===
269
- cal_val = row.get("calories", None)
270
- cal_score = 1.0
271
- if cal_val is not None and cal_min < cal_max:
272
- try:
273
- c_val = float(cal_val)
274
- cal_center = 0.5 * (cal_min + cal_max)
275
- tol = 0.3 * cal_center
276
- lower_bound = cal_center - tol
277
- upper_bound = cal_center + tol
278
- if lower_bound <= c_val <= upper_bound:
279
- cal_score = 1.0
280
- else:
281
- deviation = abs(c_val - cal_center) / cal_center
282
- cal_score = max(0.0, 1.0 - deviation)
283
- except (TypeError, ValueError):
284
- cal_score = 0.0
285
-
286
- # === 4b) Protein soft constraint ===
287
- protein_val = row.get("protein", None)
288
- protein_score = 1.0
289
- if protein_val is not None and pro_min < pro_max:
290
- try:
291
- p_val = float(protein_val)
292
- pro_center = 0.5 * (pro_min + pro_max)
293
- tol = 0.2 * pro_center
294
- lower_bound = pro_center - tol
295
- upper_bound = pro_center + tol
296
- if lower_bound <= p_val <= upper_bound:
297
- protein_score = 1.0
298
- else:
299
- deviation = abs(p_val - pro_center) / pro_center
300
- protein_score = max(0.0, 1.0 - deviation)
301
- except (TypeError, ValueError):
302
- protein_score = 0.0
303
-
304
- # === 5) Liked / Disliked main ===
305
- like_bonus = 1.0 if main_set & liked else 0.0
306
- dislike_penalty = 1.0 if main_set & disliked else 0.0
307
-
308
- # === 6) Final scoring ===
309
- score = (
310
- 0.5 * coverage_ratio +
311
- 0.15 * region_score +
312
- 0.1 * time_score +
313
- 0.1 * cal_score +
314
- 0.05 * protein_score +
315
- 0.05 * like_bonus -
316
- 0.05 * dislike_penalty
317
- )
318
-
319
- scored_candidates.append((idx, score))
320
-
321
- # Sort and pick top3 for relevance
322
- scored_candidates.sort(key=lambda x: x[1], reverse=True)
323
- picked_idxs = [idx for idx, _ in scored_candidates[:3]]
324
-
325
- # relevance labels 3 / 2 / 1
326
- labels = {idx: 0 for idx in top5.index}
327
- if len(picked_idxs) > 0:
328
- labels[picked_idxs[0]] = 3
329
- if len(picked_idxs) > 1:
330
- labels[picked_idxs[1]] = 2
331
- if len(picked_idxs) > 2:
332
- labels[picked_idxs[2]] = 1
333
-
334
- # build features for all 5 candidates
335
- for idx, row in top5.iterrows():
336
- up = set(user_parents)
337
- main_set = set(row.get("main_parent", set()))
338
- staple_set = set(row.get("staple_parent", set()))
339
- other_set = set(row.get("other_parent", set()))
340
-
341
- recipe_dict = {
342
- "main": main_set,
343
- "staple": staple_set,
344
- "other": other_set,
345
- "seasoning": set(row.get("seasoning_parent", set())),
346
- "matched_main": len(main_set & up),
347
- "matched_staple": len(staple_set & up),
348
- "matched_other": len(other_set & up),
349
- "calories": row.get("calories", 0),
350
- "protein": row.get("protein", 0),
351
- "fat": row.get("fat", 0),
352
- "region": row.get("region", ""),
353
- "cuisine_attr": row.get("cuisine_attr", []),
354
- "ingredients": row.get("ingredients", []),
355
- "minutes": row.get("minutes", None),
356
- }
357
-
358
- feats = build_features(recipe_dict, user_profile)
359
- feats["relevance"] = float(labels[idx])
360
- feats["qid"] = int(i)
361
- rows.append(feats)
362
-
363
- out = pd.DataFrame(rows)
364
- if "qid" not in out.columns or out.empty:
365
- print(f"[cold_start] No valid training data generated for {user_id}, skipping save.")
366
- return None
367
-
368
- valid_qids = out.groupby("qid").size()
369
- keep_qids = valid_qids[valid_qids > 1].index
370
- out = out[out["qid"].isin(keep_qids)].reset_index(drop=True)
371
-
372
- os.makedirs(base_dir, exist_ok=True)
373
- out_path = os.path.join(base_dir, "user_features_rank.csv")
374
- out.to_csv(out_path, index=False)
375
- print(f"[cold_start] Saved {len(out)} rows to {out_path}")
376
- return out_path
377
-
378
-
379
- if __name__ == "__main__":
380
- cold_start_ranker(
381
- user_id="user_1",
382
- n_rounds=10000,
383
- topn_coarse=20000,
384
- topk_rule=5,
385
- coverage_penalty=0.15,
386
- temperature=0.5
387
  )
 
1
+ import os
2
+ import ast
3
+ import json
4
+ import random
5
+ import pandas as pd
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ import warnings
9
+
10
+ from .candidate import coarse_rank_candidates, hard_filter, rule_generate_candidates
11
+ from .feature import build_features
12
+ from .io import load_recipes_csv, load_ingredient_map
13
+
14
+ RECIPES_PATH = load_recipes_csv()
15
+ INGREDIENT_MAP = load_ingredient_map()
16
+ PARENTS = INGREDIENT_MAP["parents"]
17
+ CHILDREN = INGREDIENT_MAP["children"]
18
+
19
def parse_list(x):
    """Convert a stringified list into a Python list safely.

    Args:
        x: A list (returned unchanged), another built-in collection
            (tuple / set / frozenset, converted to a list), a string holding
            a Python literal, or a missing value (None / NaN / anything else).

    Returns:
        A list. Missing, empty, non-string scalar and unparsable inputs all
        yield ``[]``. A string that parses to a single scalar literal yields
        a one-element list so callers can always iterate the result.
    """
    # Containers are handled first: calling ``pd.isna`` on a list returns an
    # element-wise array whose truth value is ambiguous and raises ValueError.
    if isinstance(x, list):
        return x
    if isinstance(x, (tuple, set, frozenset)):
        return list(x)
    # None, NaN and any other non-string scalar cannot be parsed as a literal.
    if not isinstance(x, str) or x == "":
        return []
    try:
        parsed = ast.literal_eval(x)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return []
    if isinstance(parsed, list):
        return parsed
    if isinstance(parsed, (tuple, set, frozenset)):
        return list(parsed)
    if parsed is None:
        return []
    return [parsed]
29
+
30
def parse_set(x):
    """Convert a stringified collection into a Python set safely.

    Args:
        x: A set (returned unchanged), a list / tuple / frozenset (converted),
            a string holding a Python literal or a bare token, a missing
            value (None / NaN / NA), or any other hashable scalar.

    Returns:
        A set. Missing values and empty / whitespace-only strings yield an
        empty set. A string that is not a valid literal (or whose literal
        cannot be placed in a set) yields a one-element set holding the
        stripped string. Any other scalar yields ``{x}``.
    """
    # Containers are handled first: calling ``pd.isna`` on a list returns an
    # element-wise array whose truth value is ambiguous and raises ValueError.
    if isinstance(x, set):
        return x
    if isinstance(x, (list, tuple, frozenset)):
        return set(x)
    if isinstance(x, str):
        token = x.strip()
        if not token:
            return set()
        try:
            parsed = ast.literal_eval(x)
            if isinstance(parsed, (list, tuple, set, frozenset)):
                return set(parsed)
            return {parsed}
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # Not a literal (e.g. a bare ingredient name) or an unhashable
            # literal: fall back to treating the raw text as a single item.
            return {token}
    # Only scalars are safe to pass to ``pd.isna`` in a boolean context.
    if pd.api.types.is_scalar(x) and pd.isna(x):
        return set()
    return {x}
47
+
48
+ def _parents_pool_from_df(df: pd.DataFrame):
49
+ cols = ["main_parent", "staple_parent", "other_parent", "seasoning_parent"]
50
+ pool = set()
51
+ for c in cols:
52
+ if c in df.columns:
53
+ for s in df[c]:
54
+ pool |= set(s) if isinstance(s, (set, list, tuple)) else set()
55
+ return sorted(pool)
56
+
57
+
58
def sample_user_parents(parents_pool,
                        user_profile=None,
                        prev_inventory=None,
                        min_items=3, max_items=10,
                        keep_ratio=0.6, reset_interval=20, round_idx=0):
    """Sample a simulated pantry (a list of parent ingredients).

    Args:
        parents_pool: Iterable of candidate parent ingredients.
        user_profile: Optional profile dict. ``forbidden_parents`` and
            ``other_preferences.disliked_main`` are excluded from sampling;
            ``other_preferences.preferred_main`` items get weight 3.0 instead
            of 1.0.
        prev_inventory: Pantry from the previous round. Unless this round is
            a reset round, a random ``keep_ratio`` share of it is carried over.
        min_items: Lower bound (inclusive) of the target pantry size.
        max_items: Upper bound (inclusive) of the target pantry size.
        keep_ratio: Fraction of ``prev_inventory`` that is carried over.
        reset_interval: Every ``reset_interval``-th round starts from an empty
            pantry. A non-positive value disables resets.
        round_idx: Index of the current round.

    Returns:
        A list of distinct parents. Its length equals the sampled target size
        whenever enough distinct candidates exist; it can be larger only when
        the carried-over items alone already exceed the target.
    """
    profile = user_profile or {}
    preferences = profile.get("other_preferences", {})
    liked = set(preferences.get("preferred_main", []))
    disliked = set(preferences.get("disliked_main", []))
    forbidden = set(profile.get("forbidden_parents", [])) | disliked

    # dict.fromkeys de-duplicates while keeping the caller's ordering.
    candidates = list(dict.fromkeys(p for p in parents_pool if p not in forbidden))
    if not candidates:
        # Best effort: if everything is forbidden, sample from the full pool
        # rather than returning an empty pantry.
        candidates = list(dict.fromkeys(parents_pool))

    inventory = set()
    force_reset = reset_interval > 0 and round_idx % reset_interval == 0
    if prev_inventory and not force_reset:
        carried = list(prev_inventory)
        random.shuffle(carried)
        keep_k = max(0, int(len(carried) * keep_ratio))
        inventory.update(carried[:keep_k])

    target = random.randint(min_items, max_items)

    # Draw WITHOUT replacement and skip what is already in the pantry, so each
    # draw adds a new item and the pantry actually reaches the target size.
    remaining = [p for p in candidates if p not in inventory]
    weights = [3.0 if p in liked else 1.0 for p in remaining]
    n_draws = min(max(0, target - len(inventory)), len(remaining))
    for _ in range(n_draws):
        pick = random.choices(range(len(remaining)), weights=weights, k=1)[0]
        inventory.add(remaining.pop(pick))
        weights.pop(pick)
    return list(inventory)
89
+
90
+
91
+ def _weighted_pick3(indexes, scores, temperature=1.0):
92
+ idxs = list(indexes)
93
+ scs = np.array(scores, dtype=float)
94
+ if np.any(scs < 0):
95
+ scs = scs - scs.min()
96
+ if scs.sum() == 0:
97
+ scs = np.ones_like(scs)
98
+ picks = []
99
+ for _ in range(min(3, len(idxs))):
100
+ probs = np.exp(scs / max(temperature, 1e-6))
101
+ probs = probs / probs.sum()
102
+ choice = np.random.choice(len(idxs), p=probs)
103
+ picks.append(idxs[choice])
104
+ idxs.pop(choice)
105
+ scs = np.delete(scs, choice)
106
+ if len(idxs) == 0:
107
+ break
108
+ return picks
109
+
110
+
111
# ---------- Main cold-start ----------
def _band_score(value, center, tolerance):
    """Score closeness of ``value`` to ``center`` on a 0..1 scale.

    Returns 1.0 when ``value`` lies within ``center ± tolerance * center``,
    otherwise ``1 - |value - center| / center`` floored at 0.0. Values that
    cannot be converted to float (or are NaN) score 0.0. A non-positive
    center cannot define a relative band and is treated as neutral (1.0).
    """
    try:
        val = float(value)
        mid = float(center)
    except (TypeError, ValueError):
        return 0.0
    if mid <= 0:
        return 1.0
    if val != val:  # NaN, e.g. a missing cell in the recipes table
        return 0.0
    tol = tolerance * mid
    if mid - tol <= val <= mid + tol:
        return 1.0
    return max(0.0, 1.0 - abs(val - mid) / mid)


def _goal_score(value, goal_min, goal_max, tolerance):
    """Soft-constraint score of ``value`` against a ``[goal_min, goal_max]`` goal.

    The band is centred on the midpoint of the goal range. When the recipe has
    no value, the goal has no upper bound, or the range is empty/invalid, the
    constraint is neutral (1.0) so it cannot influence the ranking.
    """
    if value is None or goal_max is None:
        return 1.0
    lower = 0 if goal_min is None else goal_min
    try:
        if not lower < goal_max:
            return 1.0
        center = 0.5 * (lower + goal_max)
    except TypeError:
        return 1.0
    return _band_score(value, center, tolerance)


def _score_candidate(row, user_set, prefs):
    """Return the heuristic relevance score of one candidate row.

    Returns ``None`` when fewer than half of the recipe's main + staple
    parents are present in ``user_set`` (the recipe is considered infeasible).
    """
    main_set = set(row.get("main_parent", set()))
    staple_set = set(row.get("staple_parent", set()))

    # 1) Feasibility: share of main + staple parents the pantry covers.
    total_needed = max(1, len(main_set) + len(staple_set))
    total_have = len(main_set & user_set) + len(staple_set & user_set)
    coverage_ratio = total_have / total_needed
    if coverage_ratio < 0.5:
        return None

    # 2) Region preference flag carried on the row.
    region_score = 1.0 if row.get("region_match", 0) else 0.0

    # 3) Cooking time: neutral when the user set no limit or the row has none.
    time_val = row.get("minutes", None)
    max_cooking_time = prefs["max_cooking_time"]
    if max_cooking_time and time_val is not None:
        time_score = _band_score(time_val, max_cooking_time, 0.2)
    else:
        time_score = 1.0

    # 4) Nutrition goals (calories ±30 %, protein ±20 % around the midpoint).
    cal_score = _goal_score(row.get("calories", None), *prefs["calories"], 0.3)
    protein_score = _goal_score(row.get("protein", None), *prefs["protein"], 0.2)

    # 5) Liked / disliked main ingredients.
    like_bonus = 1.0 if main_set & prefs["liked"] else 0.0
    dislike_penalty = 1.0 if main_set & prefs["disliked"] else 0.0

    # 6) Weighted sum.
    return (
        0.5 * coverage_ratio +
        0.15 * region_score +
        0.1 * time_score +
        0.1 * cal_score +
        0.05 * protein_score +
        0.05 * like_bonus -
        0.05 * dislike_penalty
    )


def _recipe_payload(row, user_set):
    """Build the recipe dict handed to ``build_features`` for one candidate row."""
    main_set = set(row.get("main_parent", set()))
    staple_set = set(row.get("staple_parent", set()))
    other_set = set(row.get("other_parent", set()))
    return {
        "main": main_set,
        "staple": staple_set,
        "other": other_set,
        "seasoning": set(row.get("seasoning_parent", set())),
        "matched_main": len(main_set & user_set),
        "matched_staple": len(staple_set & user_set),
        "matched_other": len(other_set & user_set),
        "calories": row.get("calories", 0),
        "protein": row.get("protein", 0),
        "fat": row.get("fat", 0),
        "region": row.get("region", ""),
        "cuisine_attr": row.get("cuisine_attr", []),
        "ingredients": row.get("ingredients", []),
        "minutes": row.get("minutes", None),
    }


def cold_start_ranker(user_id: str,
                      n_rounds: int = 1000,
                      topn_coarse: int = 5000,
                      topk_rule: int = 3,
                      batch_size: int = 5000,
                      switch_interval: int = 100):
    """Generate synthetic learning-to-rank training data for a new user.

    Each round samples a simulated pantry, runs coarse recall and the rule
    re-ranker on one chunk of the (hard-filtered) recipe table, keeps the
    first ``topk_rule`` candidates and labels them deterministically: the
    three best-scoring feasible candidates get relevance 3 / 2 / 1, every
    other candidate gets 0. The score combines pantry coverage of main and
    staple parents (weight 0.5), region match, cooking time, calorie and
    protein goals, and liked / disliked main ingredients.

    Args:
        user_id: Name of the user's directory under
            ``recipe_recommendation/user_data`` or
            ``recipe_recommendation/input_user_data``.
        n_rounds: Number of simulated rounds (each becomes one ``qid``).
        topn_coarse: Maximum number of candidates requested from coarse recall.
        topk_rule: Number of rule-ranked candidates kept per round. Rounds
            with fewer candidates are skipped.
        batch_size: Maximum number of recipes per chunk.
        switch_interval: Number of consecutive rounds spent on one chunk
            before moving to the next.

    Returns:
        Path of ``user_features_rank.csv`` (either pre-existing or newly
        written), or ``None`` when no usable training rows were produced.

    Raises:
        ValueError: If ``batch_size`` or ``switch_interval`` is not positive.
        FileNotFoundError: If no directory exists for ``user_id``.
    """
    if batch_size <= 0 or switch_interval <= 0:
        raise ValueError("batch_size and switch_interval must be positive")

    base_dir = os.path.join("recipe_recommendation", "user_data", user_id)
    if not os.path.exists(base_dir):
        base_dir = os.path.join("recipe_recommendation", "input_user_data", user_id)

    if not os.path.exists(base_dir):
        raise FileNotFoundError(
            f"❌ User profile not found for '{user_id}' in either 'recipe_recommendation/user_data' or 'recipe_recommendation/input_user_data'."
        )

    print(f"[cold_start_ranker] Using base_dir = {base_dir}")

    profile_path = os.path.join(base_dir, "user_profile.json")
    features_path = os.path.join(base_dir, "user_features_rank.csv")

    # An existing features file is treated as a finished cold start.
    if os.path.exists(features_path):
        print(f"[cold_start] Features already exist at {features_path}")
        return features_path

    with open(profile_path, "r", encoding="utf-8") as f:
        user_profile = json.load(f)

    # Load recipes and turn the stringified collection columns into objects.
    df_all = pd.read_csv(RECIPES_PATH)
    to_set = ["main_parent", "staple_parent", "other_parent", "seasoning_parent", "cuisine_attr"]
    to_list = ["ingredients"]
    for c in to_set:
        if c in df_all.columns:
            df_all[c] = df_all[c].apply(parse_set)
    for c in to_list:
        if c in df_all.columns:
            df_all[c] = df_all[c].apply(parse_list)

    # Step 1: hard filter. Best effort: a failing filter is reported and skipped.
    if hard_filter is not None and not df_all.empty:
        try:
            before = len(df_all)
            mask = df_all.apply(lambda r: hard_filter(r.to_dict(), user_profile), axis=1)
            df_all = df_all[mask]
            after = len(df_all)
            print(f"[cold_start] Step1 hard filter applied: {before} -> {after}")
        except Exception as e:
            warnings.warn(f"[cold_start] hard_filter failed, skip. err={e}")

    if df_all.empty:
        print(f"[cold_start] No valid training data generated for {user_id}, skipping save.")
        return None

    # Ceil division: an exact multiple of batch_size must not create an extra
    # chunk. Positions (not the frame itself) are split, because
    # np.array_split on a DataFrame is deprecated in recent pandas.
    n_chunks = max(1, -(-len(df_all) // batch_size))
    chunks = [
        df_all.iloc[positions]
        for positions in np.array_split(np.arange(len(df_all)), n_chunks)
    ]
    parents_pool = _parents_pool_from_df(df_all)

    # Profile-derived scoring inputs do not change between rounds. A goal
    # without an explicit "max" stays None so the constraint is neutral
    # instead of being centred on an arbitrary huge default.
    other_prefs = user_profile.get("other_preferences", {})
    goals = user_profile.get("nutritional_goals", {})
    cal_goal = goals.get("calories", {})
    pro_goal = goals.get("protein", {})
    prefs = {
        "liked": set(other_prefs.get("preferred_main", [])),
        "disliked": set(other_prefs.get("disliked_main", [])),
        "max_cooking_time": other_prefs.get("cooking_time_max", None),
        "calories": (cal_goal.get("min", 0), cal_goal.get("max", None)),
        "protein": (pro_goal.get("min", 0), pro_goal.get("max", None)),
    }

    rows = []
    prev_inventory = None

    for i in tqdm(range(n_rounds), desc="Cold-start rounds"):
        df_chunk = chunks[(i // switch_interval) % n_chunks]

        # Pantry sampling (partly carried over from the previous round).
        user_parents = sample_user_parents(
            parents_pool,
            user_profile=user_profile,
            prev_inventory=prev_inventory,
            round_idx=i
        )
        prev_inventory = user_parents

        # Step 2: coarse recall
        coarse_list = coarse_rank_candidates(
            recipes=df_chunk.to_dict(orient="records"),
            user_parents=user_parents,
            user_profile=user_profile,
            top_n=min(topn_coarse, len(df_chunk))
        )
        if not coarse_list:
            continue

        # Step 3: rule rerank, then keep the first topk_rule candidates.
        rule_df = rule_generate_candidates(
            pd.DataFrame(coarse_list),
            user_parents=user_parents,
            user_profile=user_profile
        )
        if rule_df.empty or len(rule_df) < topk_rule:
            continue

        candidates = rule_df.head(topk_rule)
        user_set = set(user_parents)

        # Deterministic scoring; infeasible candidates are left unscored.
        scored_candidates = []
        for idx, row in candidates.iterrows():
            score = _score_candidate(row, user_set, prefs)
            if score is not None:
                scored_candidates.append((idx, score))
        scored_candidates.sort(key=lambda pair: pair[1], reverse=True)

        # Relevance labels 3 / 2 / 1 for the best three, 0 for the rest.
        labels = {idx: 0 for idx in candidates.index}
        for grade, (idx, _) in zip((3, 2, 1), scored_candidates):
            labels[idx] = grade

        # Build features for every kept candidate; the round index is the qid.
        for idx, row in candidates.iterrows():
            feats = build_features(_recipe_payload(row, user_set), user_profile)
            feats["relevance"] = float(labels[idx])
            feats["qid"] = int(i)
            rows.append(feats)

    out = pd.DataFrame(rows)
    if "qid" not in out.columns or out.empty:
        print(f"[cold_start] No valid training data generated for {user_id}, skipping save.")
        return None

    # A ranking group needs at least two rows to carry any ordering signal.
    valid_qids = out.groupby("qid").size()
    keep_qids = valid_qids[valid_qids > 1].index
    out = out[out["qid"].isin(keep_qids)].reset_index(drop=True)
    if out.empty:
        # Writing an empty file would make every later call return early via
        # the "features already exist" check above.
        print(f"[cold_start] No valid training data generated for {user_id}, skipping save.")
        return None

    os.makedirs(base_dir, exist_ok=True)
    out_path = os.path.join(base_dir, "user_features_rank.csv")
    out.to_csv(out_path, index=False)
    print(f"[cold_start] Saved {len(out)} rows to {out_path}")
    return out_path
377
+
378
+
379
if __name__ == "__main__":
    # Manual entry point. Only keyword arguments that cold_start_ranker
    # actually accepts are passed: the former ``coverage_penalty`` and
    # ``temperature`` arguments are not part of its signature and made this
    # call fail with TypeError.
    # NOTE(review): the module uses relative imports, so this presumably has
    # to be started with ``python -m`` from the package root — confirm.
    cold_start_ranker(
        user_id="user_1",
        n_rounds=10000,
        topn_coarse=20000,
        topk_rule=5,
    )