Danielos100 committed on
Commit
228df34
·
verified ·
1 Parent(s): 640b232

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +313 -152
app.py CHANGED
@@ -1,11 +1,13 @@
1
- # app.py
2
- # 🎁 GIfty — Smart Gift Recommender (Embeddings + FAISS + LLM + Image Gen)
3
  # Data: ckandemir/amazon-products
4
- # Retrieval: MiniLM embeddings + FAISS (cosine)
5
- # Generation: Flan-T5-small (text), SD-Turbo (image)
6
- # UI: Gradio; Quick Examples on top; Budget range: RangeSlider if present, else two sliders
 
 
7
 
8
- import os, re, json, random
9
  from typing import Dict, List, Tuple
10
 
11
  import numpy as np
@@ -16,23 +18,37 @@ from datasets import load_dataset
16
  from sentence_transformers import SentenceTransformer
17
  import faiss
18
 
19
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
20
 
21
  import torch
22
  from diffusers import AutoPipelineForText2Image
23
 
24
  # --------------------- Config ---------------------
25
  MAX_ROWS = int(os.getenv("MAX_ROWS", "8000"))
26
- TITLE = "# 🎁 GIfty — Smart Gift Recommender\n*Top-3 similar picks + 1 invented gift (with image) + personalized message*"
27
 
28
- # ===== Updated Interests (exact) =====
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  INTEREST_OPTIONS = [
30
  "Sports","Travel","Cooking","Technology","Music","Art","Reading","Gardening","Fashion",
31
  "Gaming","Photography","Hiking","Movies","Crafts","Pets","Wellness","Collecting","Food",
32
  "Home decor","Science"
33
  ]
34
 
35
- # ===== Updated Occasions (exact) =====
36
  OCCASION_UI = [
37
  "Birthday","Wedding / Engagement","Anniversary","Graduation","New baby","Housewarming",
38
  "Retirement","Holidays","Valentine’s Day","Promotion / New job","Get well soon"
@@ -52,7 +68,6 @@ OCCASION_CANON = {
52
  "Get well soon":"get_well"
53
  }
54
 
55
- # ===== Updated Relationship & Tone =====
56
  RECIPIENT_RELATIONSHIPS = [
57
  "Family - Parent",
58
  "Family - Sibling",
@@ -68,15 +83,7 @@ RECIPIENT_RELATIONSHIPS = [
68
  ]
69
 
70
  MESSAGE_TONES = [
71
- "Formal",
72
- "Casual",
73
- "Funny",
74
- "Heartfelt",
75
- "Inspirational",
76
- "Playful",
77
- "Romantic",
78
- "Appreciative",
79
- "Encouraging",
80
  ]
81
 
82
  AGE_OPTIONS = {
@@ -211,9 +218,13 @@ def load_catalog() -> pd.DataFrame:
211
  ],
212
  "Category": ["Electronics | Audio","Grocery | Coffee","Toys & Games | Board Games"],
213
  "Selling Price": ["$59.00","$34.00","$39.00"],
214
- "Image": ["","",""],
215
  })
216
  df = map_amazon_to_schema(raw).drop_duplicates(subset=["name","short_desc"])
 
 
 
 
217
  if len(df) > MAX_ROWS:
218
  df = df.sample(n=MAX_ROWS, random_state=42).reset_index(drop=True)
219
  df["doc"] = df.apply(build_doc, axis=1)
@@ -221,38 +232,43 @@ def load_catalog() -> pd.DataFrame:
221
 
222
  CATALOG = load_catalog()
223
 
224
- # --------------------- Business filters ---------------------
225
- def _contains_ci(series: pd.Series, needle: str) -> pd.Series:
226
- if not needle: return pd.Series(True, index=series.index)
227
- return series.fillna("").str.contains(re.escape(needle), case=False, regex=True)
228
-
229
- def filter_business(df: pd.DataFrame, budget_min=None, budget_max=None,
230
- occasion_canon: str=None, age_range: str="any") -> pd.DataFrame:
231
- m = pd.Series(True, index=df.index)
232
- if budget_min is not None:
233
- m &= df["price_usd"].fillna(0) >= float(budget_min)
234
- if budget_max is not None:
235
- m &= df["price_usd"].fillna(1e9) <= float(budget_max)
236
- if occasion_canon:
237
- m &= _contains_ci(df["occasion_tags"], occasion_canon)
238
- if age_range and age_range != "any":
239
- m &= (df["age_range"].fillna("any").isin([age_range, "any"]))
240
- return df[m]
241
-
242
- # --------------------- Embeddings + FAISS ---------------------
243
  class EmbeddingIndex:
244
  def __init__(self, docs: List[str], model_id: str):
 
245
  self.model = SentenceTransformer(model_id)
246
- embs = self.model.encode(docs, convert_to_numpy=True, normalize_embeddings=True)
247
- self.index = faiss.IndexFlatIP(embs.shape[1]) # cosine via normalized vectors
248
- self.index.add(embs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
  def search(self, query: str, topn: int):
251
  qv = self.model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
252
  sims, idxs = self.index.search(qv, topn)
253
  return sims[0], idxs[0]
254
 
255
- EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2" # fast & solid on CPU
256
  EMB_INDEX = EmbeddingIndex(CATALOG["doc"].tolist(), EMBED_MODEL_ID)
257
 
258
  # --------------------- Query building ---------------------
@@ -281,9 +297,26 @@ def profile_to_query(profile: Dict) -> str:
281
  if g != "any": parts.append("women" if g=="female" else ("men" if g=="male" else "unisex"))
282
  return " | ".join(parts)
283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  def recommend_topk(profile: Dict, k: int=3) -> pd.DataFrame:
285
  query = profile_to_query(profile)
286
- sims, idxs = EMBB_INDEX.search(query, topn=min(max(k*80, k), len(CATALOG))) if False else EMB_INDEX.search(query, topn=min(max(k*80, k), len(CATALOG)))
287
  df_f = filter_business(
288
  CATALOG,
289
  budget_min=profile.get("budget_min"),
@@ -292,6 +325,7 @@ def recommend_topk(profile: Dict, k: int=3) -> pd.DataFrame:
292
  age_range=profile.get("age_range","any"),
293
  )
294
  if df_f.empty: df_f = CATALOG
 
295
 
296
  # soft gender boost
297
  def gender_tokens(g: str) -> List[str]:
@@ -305,7 +339,7 @@ def recommend_topk(profile: Dict, k: int=3) -> pd.DataFrame:
305
  cand = []
306
  for i, sim in zip(idxs, sims):
307
  i = int(i)
308
- if i in df_f.index:
309
  blob = f"{CATALOG.loc[i,'tags']} {CATALOG.loc[i,'short_desc']}".lower()
310
  boost = 0.08 if any(t in blob for t in gts) else 0.0
311
  cand.append((i, float(sim) + boost))
@@ -329,95 +363,192 @@ def recommend_topk(profile: Dict, k: int=3) -> pd.DataFrame:
329
  res["similarity"] = [dict(picks).get(int(i), np.nan) for i in sel]
330
  return res[["name","short_desc","price_usd","occasion_tags","persona_fit","age_range","image_url","similarity"]]
331
 
332
- # --------------------- LLM (text) ---------------------
333
- LLM_ID = "google/flan-t5-small"
 
 
 
 
 
 
 
 
 
 
 
334
  try:
335
- _tok = AutoTokenizer.from_pretrained(LLM_ID)
336
- _mdl = AutoModelForSeq2SeqLM.from_pretrained(LLM_ID)
337
- LLM = pipeline("text2text-generation", model=_mdl, tokenizer=_tok)
338
  except Exception as e:
339
- LLM = None
340
- print("LLM load failed, fallback to rule-based. Error:", e)
341
 
342
- def _run_llm(prompt: str, max_new_tokens=160) -> str:
343
- if LLM is None: return ""
344
- out = LLM(prompt, max_new_tokens=max_new_tokens, do_sample=False, temperature=0.0)
345
- return out[0]["generated_text"]
346
 
347
- def _parse_json_maybe(s: str) -> dict:
348
- try:
349
- return json.loads(s)
350
- except Exception:
351
- m = re.search(r"\{.*\}", s, flags=re.S)
352
- if m:
353
- try: return json.loads(m.group(0))
354
- except Exception: return {}
355
- return {}
356
-
357
- def llm_generate_item(profile: Dict) -> Dict:
358
- prompt = f"""
359
- You are GIfty. Invent ONE gift that matches the catalog style with keys:
360
- name, short_desc, price_usd, occasion_tags, persona_fit. Use JSON only.
361
- Constraints:
362
- - Fit the recipient profile and relationship.
363
- - price_usd must be numeric within the budget range.
364
- Profile:
365
- name={profile.get('recipient_name','Friend')}
366
- relationship={profile.get('relationship','Friend')}
367
- gender={profile.get('gender','any')}
368
- age_group={profile.get('age_range','any')}
369
- interests={profile.get('interests',[])}
370
- occasion={profile.get('occ_ui','Birthday')}
371
- budget_min={profile.get('budget_min',10)}
372
- budget_max={profile.get('budget_max',100)}
373
- """
374
- txt = _run_llm(prompt, max_new_tokens=180)
375
- data = _parse_json_maybe(txt)
376
- if not data:
377
- core = (profile.get("interests",["hobby"])[0] or "hobby").lower()
378
- return {
379
- "name": f"{core.title()} starter bundle ({profile.get('occ_ui','Birthday')})",
380
- "short_desc": f"A curated set to kickstart their {core} passion.",
381
- "price_usd": float(np.clip(profile.get("budget_max", 50) or 50, 10, 300)),
382
- "occasion_tags": OCCASION_CANON.get(profile.get("occ_ui","Birthday"), "birthday"),
383
- "persona_fit": ", ".join(profile.get("interests", [])) or "general",
384
- "age_range": profile.get("age_range","any"),
385
- "image_url": ""
386
- }
387
  try:
388
- p = float(data.get("price_usd", profile.get("budget_max", 50)))
389
  except Exception:
390
- p = float(profile.get("budget_max", 50) or 50)
391
- p = float(np.clip(p, profile.get("budget_min", 10) or 10, profile.get("budget_max", 300) or 300))
392
- return {
393
- "name": data.get("name","Gift Idea"),
394
- "short_desc": data.get("short_desc","A thoughtful idea."),
395
- "price_usd": p,
396
- "occasion_tags": data.get("occasion_tags", OCCASION_CANON.get(profile.get("occ_ui","Birthday"), "birthday")),
397
- "persona_fit": data.get("persona_fit", ", ".join(profile.get("interests", [])) or "general"),
398
- "age_range": profile.get("age_range","any"),
399
- "image_url": ""
400
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
 
402
- def llm_generate_message(profile: Dict) -> str:
403
- prompt = f"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
  Write a short greeting (2–3 sentences) in English for a gift card.
405
- Tone: {profile.get('tone','Heartfelt')}
406
- Use the relationship to set warmth/formality.
407
- Recipient: {profile.get('recipient_name','Friend')} ({profile.get('relationship','Friend')})
408
- Occasion: {profile.get('occ_ui','Birthday')}
409
- Interests: {', '.join(profile.get('interests', []))}
410
- Age group: {profile.get('age_range','any')}; Gender: {profile.get('gender','any')}
411
  Avoid emojis.
412
  """
413
- txt = _run_llm(prompt, max_new_tokens=90)
414
- if not txt:
415
- return (f"Dear {profile.get('recipient_name','Friend')}, "
416
- f"happy {profile.get('occ_ui','Birthday').lower()}! Wishing you joy and wonderful memories.")
417
- return txt.strip()
 
 
 
 
 
418
 
419
  # --------------------- Image generation (SD-Turbo) ---------------------
 
420
  def load_image_pipeline():
 
 
421
  try:
422
  device = "cuda" if torch.cuda.is_available() else "cpu"
423
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
@@ -430,12 +561,14 @@ def load_image_pipeline():
430
 
431
  IMG_PIPE = load_image_pipeline()
432
 
433
- def generate_gift_image(gift: Dict):
434
- if IMG_PIPE is None:
 
435
  return None
 
 
436
  prompt = (
437
- f"{gift.get('name','gift')}, {gift.get('short_desc','')}. "
438
- f"Style: product photo, soft studio lighting, minimal background, realistic, high detail."
439
  )
440
  try:
441
  img = IMG_PIPE(
@@ -450,6 +583,7 @@ def generate_gift_image(gift: Dict):
450
  return None
451
 
452
  # --------------------- Rendering ---------------------
 
453
  def md_escape(text: str) -> str:
454
  return str(text).replace("|","\\|").replace("*","\\*").replace("_","\\_")
455
 
@@ -482,6 +616,32 @@ def render_top3_html(df: pd.DataFrame) -> str:
482
  rows.append(card)
483
  return "\n".join(rows)
484
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
  # --------------------- Gradio UI ---------------------
486
  CSS = """
487
  #examples { order: 1; }
@@ -491,62 +651,61 @@ CSS = """
491
  with gr.Blocks(css=CSS) as demo:
492
  gr.Markdown(TITLE)
493
 
494
- # top section (examples placeholder)
495
  with gr.Column(elem_id="examples"):
496
  gr.Markdown("### Quick examples")
497
 
498
  with gr.Column(elem_id="form"):
499
  with gr.Row():
500
- recipient_name = gr.Textbox(label="Recipient name", value="Noa")
501
- relationship = gr.Dropdown(label="Relationship", choices=RECIPIENT_RELATIONSHIPS, value="Friend")
502
 
503
  with gr.Row():
504
  interests = gr.CheckboxGroup(
505
  label="Interests (select a few)", choices=INTEREST_OPTIONS,
506
- value=["Technology","Music"], interactive=True
507
  )
508
 
509
  with gr.Row():
510
- occasion = gr.Dropdown(label="Occasion", choices=OCCASION_UI, value="Birthday")
511
  age = gr.Dropdown(label="Age group", choices=list(AGE_OPTIONS.keys()), value="adult (18–64)")
512
- gender = gr.Dropdown(label="Recipient gender", choices=GENDER_OPTIONS, value="any")
513
 
514
- # Budget: try RangeSlider else two sliders
515
  RangeSlider = getattr(gr, "RangeSlider", None)
516
  if RangeSlider is not None:
517
- budget_range = RangeSlider(label="Budget range (USD)", minimum=5, maximum=500, step=1, value=[20, 60])
518
  budget_min, budget_max = None, None
519
  else:
520
  with gr.Row():
521
- budget_min = gr.Slider(label="Min budget (USD)", minimum=5, maximum=500, step=1, value=20)
522
  budget_max = gr.Slider(label="Max budget (USD)", minimum=5, maximum=500, step=1, value=60)
523
  budget_range = gr.State(value=None)
524
 
525
- tone = gr.Dropdown(label="Message tone", choices=MESSAGE_TONES, value="Heartfelt")
526
 
527
  go = gr.Button("Get GIfty 🎯")
528
 
529
  out_top3 = gr.HTML(label="Top-3 recommendations")
530
- out_gen_text = gr.Markdown(label="Invented gift")
531
- out_gen_img = gr.Image(label="Invented gift image", type="pil")
532
- out_msg = gr.Markdown(label="Personalized message")
 
533
 
534
  # examples (render on top via CSS)
535
  if RangeSlider:
536
  example_inputs = [interests, occasion, budget_range, recipient_name, relationship, age, gender, tone]
537
  EXAMPLES = [
538
- [["Technology","Music"], "Birthday", [20,60], "Noa", "Friend", "adult (18–64)", "any", "Heartfelt"],
539
- [["Home decor","Cooking"], "Housewarming", [25,45], "Daniel", "Neighbor", "adult (18–64)", "male", "Appreciative"],
540
  [["Gaming","Photography"], "Birthday", [30,120], "Omer", "Family - Sibling", "teen (13–17)", "male", "Playful"],
541
- [["Reading","Art"], "Graduation", [15,35], "Maya", "Romantic partner", "any", "female", "Romantic"],
542
  ]
543
  else:
544
  example_inputs = [interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone]
545
  EXAMPLES = [
546
- [["Technology","Music"], "Birthday", 20, 60, "Noa", "Friend", "adult (18–64)", "any", "Heartfelt"],
547
- [["Home decor","Cooking"], "Housewarming", 25, 45, "Daniel", "Neighbor", "adult (18–64)", "male", "Appreciative"],
548
  [["Gaming","Photography"], "Birthday", 30, 120, "Omer", "Family - Sibling", "teen (13–17)", "male", "Playful"],
549
- [["Reading","Art"], "Graduation", 15, 35, "Maya", "Romantic partner", "any", "female", "Romantic"],
550
  ]
551
 
552
  with gr.Column(elem_id="examples"):
@@ -601,27 +760,29 @@ with gr.Blocks(css=CSS) as demo:
601
  top3 = recommend_topk(profile, k=3)
602
  top3_html = render_top3_html(top3)
603
 
604
- # invented gift + image
605
- gen = llm_generate_item(profile)
606
- gen_md = f"**{md_escape(gen['name'])}**\n\n{md_escape(gen['short_desc'])}\n\n~${gen['price_usd']:.0f}"
607
- gen_img = generate_gift_image(gen)
 
 
608
 
609
  # greeting
610
  msg = llm_generate_message(profile)
611
 
612
- return top3_html, gen_md, gen_img, msg
613
 
614
  if RangeSlider:
615
  go.click(
616
  ui_predict,
617
  [interests, occasion, budget_range, recipient_name, relationship, age, gender, tone],
618
- [out_top3, out_gen_text, out_gen_img, out_msg]
619
  )
620
  else:
621
  go.click(
622
  ui_predict,
623
  [interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone],
624
- [out_top3, out_gen_text, out_gen_img, out_msg]
625
  )
626
 
627
  if __name__ == "__main__":
 
1
+ # app.py — Gifty (revised)
2
+ # 🎁 GIfty — Smart Gift Recommender
3
  # Data: ckandemir/amazon-products
4
+ # Retrieval: MiniLM-L12-v2 embeddings + FAISS (cosine), with simple on-disk cache
5
+ # DIY Generation: small instruct LMs via HF pipeline (default: flan-t5-small) with JSON validate+repair (no padding)
6
+ # Greeting: short LLM completion
7
+ # Image: SD-Turbo (optional)
8
+ # UI: Gradio; Quick Examples; Budget RangeSlider; DIY JSON + readable card
9
 
10
+ import os, re, json, random, hashlib, pathlib
11
  from typing import Dict, List, Tuple
12
 
13
  import numpy as np
 
18
  from sentence_transformers import SentenceTransformer
19
  import faiss
20
 
21
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, pipeline
22
 
23
  import torch
24
  from diffusers import AutoPipelineForText2Image
25
 
26
  # --------------------- Config ---------------------
27
  MAX_ROWS = int(os.getenv("MAX_ROWS", "8000"))
28
+ TITLE = "# 🎁 GIfty — Smart Gift Recommender\n*Top-3 catalog picks + 1 DIY gift (JSON) + personalized message*"
29
 
30
+ # Retrieval model (embedding)
31
+ EMBED_MODEL_ID = os.getenv("EMBED_MODEL_ID", "sentence-transformers/all-MiniLM-L12-v2")
32
+ EMBED_CACHE_DIR = os.getenv("EMBED_CACHE_DIR", "./.gifty_cache")
33
+ pathlib.Path(EMBED_CACHE_DIR).mkdir(parents=True, exist_ok=True)
34
+
35
+ # DIY generation model (text)
36
+ GEN_MODEL_ID = os.getenv("GEN_MODEL_ID", "google/flan-t5-small")
37
+ OUTPUT_LANG = os.getenv("OUTPUT_LANG", "en") # "en" or "he"
38
+ MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "360"))
39
+ MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "260"))
40
+ DIY_MAX_ATTEMPTS = int(os.getenv("DIY_MAX_ATTEMPTS", "4"))
41
+
42
+ # Image gen toggle
43
+ ENABLE_IMAGE = os.getenv("ENABLE_IMAGE", "1") == "1"
44
+
45
+ # ===== UI options =====
46
  INTEREST_OPTIONS = [
47
  "Sports","Travel","Cooking","Technology","Music","Art","Reading","Gardening","Fashion",
48
  "Gaming","Photography","Hiking","Movies","Crafts","Pets","Wellness","Collecting","Food",
49
  "Home decor","Science"
50
  ]
51
 
 
52
  OCCASION_UI = [
53
  "Birthday","Wedding / Engagement","Anniversary","Graduation","New baby","Housewarming",
54
  "Retirement","Holidays","Valentine’s Day","Promotion / New job","Get well soon"
 
68
  "Get well soon":"get_well"
69
  }
70
 
 
71
  RECIPIENT_RELATIONSHIPS = [
72
  "Family - Parent",
73
  "Family - Sibling",
 
83
  ]
84
 
85
  MESSAGE_TONES = [
86
+ "Formal","Casual","Funny","Heartfelt","Inspirational","Playful","Romantic","Appreciative","Encouraging",
 
 
 
 
 
 
 
 
87
  ]
88
 
89
  AGE_OPTIONS = {
 
218
  ],
219
  "Category": ["Electronics | Audio","Grocery | Coffee","Toys & Games | Board Games"],
220
  "Selling Price": ["$59.00","$34.00","$39.00"],
221
+ "Image": ["","",""]
222
  })
223
  df = map_amazon_to_schema(raw).drop_duplicates(subset=["name","short_desc"])
224
+ # EDA cleanups: drop missing price, cap to <= 500
225
+ df = df[pd.notna(df["price_usd"])].copy()
226
+ df = df[df["price_usd"] <= 500].reset_index(drop=True)
227
+ # limit rows
228
  if len(df) > MAX_ROWS:
229
  df = df.sample(n=MAX_ROWS, random_state=42).reset_index(drop=True)
230
  df["doc"] = df.apply(build_doc, axis=1)
 
232
 
233
  CATALOG = load_catalog()
234
 
235
+ # --------------------- Embeddings + FAISS (with simple cache) ---------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
class EmbeddingIndex:
    """Cosine-similarity search over document embeddings via FAISS.

    Embeddings are L2-normalized at encode time, so inner product
    (IndexFlatIP) equals cosine similarity.  Encoded vectors are cached
    on disk (keyed by model id + corpus size) to skip re-encoding on
    restart.
    """

    def __init__(self, docs: List[str], model_id: str):
        # model_id participates in the cache key so switching models
        # invalidates stale embeddings.
        self.model_id = model_id
        self.model = SentenceTransformer(model_id)
        self.embs = self._load_or_build(docs)
        self.index = faiss.IndexFlatIP(self.embs.shape[1]) # cosine via normalized vectors
        self.index.add(self.embs)

    def _cache_paths(self, n_docs: int) -> Tuple[str, str]:
        """Return (npy_path, faiss_path) derived from model id and corpus size.

        NOTE(review): the key ignores document *content* — editing docs
        without changing their count reuses a stale cache; verify acceptable.
        """
        h = hashlib.md5((self.model_id + f"|{n_docs}").encode()).hexdigest()[:10]
        npy = os.path.join(EMBED_CACHE_DIR, f"emb_{h}.npy")
        idx = os.path.join(EMBED_CACHE_DIR, f"faiss_{h}.index")
        return npy, idx

    def _load_or_build(self, docs: List[str]) -> np.ndarray:
        """Load cached embeddings if present and row-count matches; else encode and cache."""
        npy_path, _ = self._cache_paths(len(docs))
        if os.path.exists(npy_path):
            try:
                embs = np.load(npy_path)
                if embs.shape[0] == len(docs):
                    return embs
            except Exception:
                # Corrupt/unreadable cache: fall through and rebuild.
                pass
        # build
        embs = self.model.encode(docs, convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=True)
        try:
            np.save(npy_path, embs)
        except Exception:
            # Cache write is best-effort (e.g. read-only filesystem).
            pass
        return embs

    def search(self, query: str, topn: int):
        """Return (similarities, indices) of the topn nearest documents to *query*."""
        qv = self.model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
        sims, idxs = self.index.search(qv, topn)
        return sims[0], idxs[0]
271
 
 
272
  EMB_INDEX = EmbeddingIndex(CATALOG["doc"].tolist(), EMBED_MODEL_ID)
273
 
274
  # --------------------- Query building ---------------------
 
297
  if g != "any": parts.append("women" if g=="female" else ("men" if g=="male" else "unisex"))
298
  return " | ".join(parts)
299
 
300
+ def _contains_ci(series: pd.Series, needle: str) -> pd.Series:
301
+ if not needle: return pd.Series(True, index=series.index)
302
+ return series.fillna("").str.contains(re.escape(needle), case=False, regex=True)
303
+
304
def filter_business(df: pd.DataFrame, budget_min=None, budget_max=None,
                    occasion_canon: str=None, age_range: str="any") -> pd.DataFrame:
    """Apply budget / occasion / age-group constraints and return matching rows.

    Every filter is optional.  Rows with a missing age_range are treated
    as "any" and therefore pass the age filter.
    """
    mask = pd.Series(True, index=df.index)
    if budget_min is not None:
        # Missing prices default to 0, so they pass the lower bound.
        mask &= df["price_usd"].fillna(0) >= float(budget_min)
    if budget_max is not None:
        # Missing prices default to a huge value, so they fail the upper bound.
        mask &= df["price_usd"].fillna(1e9) <= float(budget_max)
    if occasion_canon:
        mask &= _contains_ci(df["occasion_tags"], occasion_canon)
    if age_range and age_range != "any":
        allowed = [age_range, "any"]
        mask &= df["age_range"].fillna("any").isin(allowed)
    return df[mask]
316
+
317
  def recommend_topk(profile: Dict, k: int=3) -> pd.DataFrame:
318
  query = profile_to_query(profile)
319
+ sims, idxs = EMB_INDEX.search(query, topn=min(max(k*80, k), len(CATALOG)))
320
  df_f = filter_business(
321
  CATALOG,
322
  budget_min=profile.get("budget_min"),
 
325
  age_range=profile.get("age_range","any"),
326
  )
327
  if df_f.empty: df_f = CATALOG
328
+ df_f_idx = set(df_f.index.tolist())
329
 
330
  # soft gender boost
331
  def gender_tokens(g: str) -> List[str]:
 
339
  cand = []
340
  for i, sim in zip(idxs, sims):
341
  i = int(i)
342
+ if i in df_f_idx:
343
  blob = f"{CATALOG.loc[i,'tags']} {CATALOG.loc[i,'short_desc']}".lower()
344
  boost = 0.08 if any(t in blob for t in gts) else 0.0
345
  cand.append((i, float(sim) + boost))
 
363
  res["similarity"] = [dict(picks).get(int(i), np.nan) for i in sel]
364
  return res[["name","short_desc","price_usd","occasion_tags","persona_fit","age_range","image_url","similarity"]]
365
 
366
+ # --------------------- LLM plumbing (DIY + Greeting) ---------------------
367
+
368
def load_text_pipeline(model_id: str):
    """Load a Hugging Face text pipeline appropriate for *model_id*.

    Seq2seq models (id contains "flan" or "t5") get a text2text-generation
    pipeline; anything else is treated as a causal LM (text-generation).

    SECURITY NOTE: trust_remote_code=True executes Python shipped inside
    the model repository; model_id comes from the GEN_MODEL_ID env var,
    so only point it at repositories you trust.
    """
    trust = True
    # Tokenizer loading was duplicated in both branches — hoisted once here.
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust)
    if "flan" in model_id or "t5" in model_id:
        mdl = AutoModelForSeq2SeqLM.from_pretrained(model_id, trust_remote_code=trust)
        task = "text2text-generation"
    else:
        mdl = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=trust)
        task = "text-generation"
    return pipeline(task, model=mdl, tokenizer=tok, device_map="auto", trust_remote_code=trust)
378
+
379
  try:
380
+ DIY_PIPE = load_text_pipeline(GEN_MODEL_ID)
 
 
381
  except Exception as e:
382
+ DIY_PIPE = None
383
+ print("DIY LLM load failed:", e)
384
 
385
+ # Small greeting model (can reuse DIY_PIPE)
386
+ GREETING_PIPE = DIY_PIPE
 
 
387
 
388
+ # ---- JSON helpers ----
389
+ GENERIC_NAMES = {"diy gift","gift","personalized gift","handmade gift","custom gift","מתנה","מתנה אישית","עשה זאת בעצמך"}
390
+
391
+ def _f(x, fb=0.0):
392
+ try: return float(x)
393
+ except: return float(fb)
394
+
395
def try_parse_json(text: str):
    """Extract and parse the first {...} JSON object embedded in *text*.

    Returns the parsed value, or None when nothing parsable is found.
    A best-effort repair pass removes trailing commas (",}" / ",]"),
    a common LLM output glitch, before giving up.  The repair now strips
    trailing commas before *any* closing brace/bracket, not only at the
    very end of the blob (the old "},\\s*}\\s*$" pattern missed interior
    objects); the pass only runs on blobs that already failed to parse.
    """
    if not text:
        return None
    m = re.search(r"(\{[\s\S]*\})", text.strip())
    if not m:
        return None
    blob = m.group(1)
    try:
        return json.loads(blob)
    except json.JSONDecodeError:
        blob = re.sub(r",\s*}", "}", blob)
        blob = re.sub(r",\s*\]", "]", blob)
        try:
            return json.loads(blob)
        except json.JSONDecodeError:
            return None
407
+
408
def truncate_prompt(pipe, text: str, max_tokens: int) -> str:
    """Clip *text* to at most *max_tokens* tokens using the pipeline's tokenizer.

    Returns the original text untouched when tokenization yields no ids.
    """
    tokenizer = pipe.tokenizer
    encoded = tokenizer(text, truncation=True, max_length=max_tokens, return_tensors=None)
    token_ids = encoded.get("input_ids", [])
    if not token_ids:
        return text
    return tokenizer.decode(token_ids, skip_special_tokens=True)
412
+
413
+ # ---- DIY prompt, validate & repair (no padding) ----
414
+
415
def diy_prompt(profile: Dict) -> str:
    """Build the first-pass (creative) DIY-gift prompt from the recipient profile.

    The prompt demands JSON-only output with a fixed key set; downstream
    diy_validate() enforces the same constraints stated here (budget range,
    list lengths, specific gift_name), so keep the two in sync.
    """
    # Output language is a module-level config switch (en/he).
    lang = "English" if OUTPUT_LANG == "en" else "Hebrew"
    name = profile.get("recipient_name","Recipient")
    rel = profile.get("relationship","Friend")
    age = profile.get("age_range","any")
    gen = profile.get("gender","any")
    ints = ", ".join(profile.get("interests",[])) or "general"
    occ = profile.get("occ_ui","Birthday")
    # Budget bounds are rendered as integers in the instruction text.
    lo, hi = int(profile.get("budget_min",10)), int(profile.get("budget_max",100))

    return "\n".join([
        f"Invent ONE original DIY gift idea from scratch for this recipient. Write all VALUES in {lang}.",
        "Return JSON ONLY with exactly these keys (and nothing else):",
        "gift_name, overview, materials_needed, step_by_step_instructions, estimated_cost_usd, estimated_time_minutes",
        "",
        "Hard requirements:",
        "- Strongly reflect the recipient's interests and the occasion.",
        "- overview MUST mention the recipient by NAME and include relationship, age_group, gender, and the occasion.",
        "- gift_name must be SPECIFIC (not generic), 4–10 words, include at least one interest keyword.",
        f"- estimated_cost_usd between ${lo}-${hi}; estimated_time_minutes 20–240.",
        "- materials_needed: at least 5 concise items with quantities.",
        "- step_by_step_instructions: at least 6 practical, ordered steps.",
        "Forbidden gift_name terms: DIY Gift, Gift, Personalized Gift, Handmade Gift, Custom Gift.",
        "",
        f"Recipient: name={name}; relationship={rel}; age_group={age}; gender={gen}.",
        f"Interests: {ints}. Occasion: {occ}.",
        "JSON:"
    ])
443
 
444
def diy_validate(g: dict, profile: Dict) -> Tuple[bool, List[str]]:
    """Validate a generated DIY-gift dict against the contract stated in diy_prompt().

    Returns (ok, errors): ok is True only when *errors* is empty.  The error
    strings are fed verbatim into diy_repair_prompt(), so keep them short and
    actionable.
    """
    errs=[]
    # keys — the exact key set demanded by the prompt
    req=["gift_name","overview","materials_needed","step_by_step_instructions","estimated_cost_usd","estimated_time_minutes"]
    for k in req:
        if k not in g: errs.append(f"missing key: {k}")
    # name — must be non-empty, non-generic, and at least 3 words
    n=str(g.get("gift_name",""))
    if not n.strip(): errs.append("gift_name empty")
    if any(b in n.strip().lower() for b in GENERIC_NAMES): errs.append("gift_name generic")
    if len(n.split())<3: errs.append("gift_name too short")
    # overview mentions — recipient name is matched case-sensitively
    ov=str(g.get("overview",""))
    if profile.get("recipient_name","") and profile.get("recipient_name") not in ov: errs.append("overview missing recipient name")
    # NOTE(review): only the FIRST word of each profile field is required in
    # the overview (e.g. "Family" for "Family - Sibling") — confirm intended.
    for field,label in [("relationship","relationship"),("age_range","age_group"),("gender","gender"),("occ_ui","occasion")]:
        val=str(profile.get(field,""))
        if val and (val.split()[0] not in ov): errs.append(f"overview missing {label}")
    # lists — minimum lengths mirror the prompt's hard requirements
    mats=g.get("materials_needed", [])
    steps=g.get("step_by_step_instructions", [])
    if not isinstance(mats, list) or len(mats)<5: errs.append("materials_needed len < 5")
    if not isinstance(steps, list) or len(steps)<6: errs.append("steps len < 6")
    # numbers — cost inside the budget window, time in 20..240 minutes;
    # _f() returns -1 on unparsable values, which fails both range checks
    lo, hi = _f(profile.get("budget_min",10),10), _f(profile.get("budget_max",100),100)
    cost=_f(g.get("estimated_cost_usd"), -1)
    if not (lo <= cost <= hi): errs.append(f"cost not in budget [{lo},{hi}]")
    mins=int(_f(g.get("estimated_time_minutes"), -1))
    if not (20 <= mins <= 240): errs.append("time not in 20..240")
    return (len(errs)==0), errs
473
+
474
def diy_repair_prompt(profile: Dict, last: dict, errors: List[str]) -> str:
    """Build the deterministic repair prompt: the validator's error list
    followed by the previous JSON candidate to be fixed in place."""
    lang = "English" if OUTPUT_LANG == "en" else "Hebrew"
    lines = [
        f"Fix ONLY the following problems in this JSON. Keep the same idea and style. Return JSON ONLY. Write all VALUES in {lang}.",
        "Errors:",
    ]
    lines.extend(f"- {e}" for e in errors)
    lines.append("JSON to fix:")
    lines.append(json.dumps(last, ensure_ascii=False))
    return "\n".join(lines)
483
+
484
def diy_generate(profile: Dict) -> Tuple[dict, str]:
    """Generate one DIY gift idea as a dict, validating and repairing model JSON.

    Returns (gift_dict, status): status is "ok" when a candidate passed
    diy_validate, "partial" when the repair budget ran out (the last,
    still-imperfect candidate is returned), or an error string when the
    model is unavailable.
    """
    if DIY_PIPE is None:
        return {}, "DIY model not loaded"
    # Attempt 1: creative sampling.
    prompt = diy_prompt(profile)
    pr = truncate_prompt(DIY_PIPE, prompt, MAX_INPUT_TOKENS)
    out = DIY_PIPE(pr, do_sample=True, temperature=0.9, top_p=0.95, max_new_tokens=MAX_NEW_TOKENS, truncation=True)
    if not isinstance(out, list):
        out = [out]
    texts = [o.get("generated_text", "") for o in out]
    candidates = [try_parse_json(t) or {} for t in texts]

    # Pick the first candidate that validates.
    # BUG FIX: `last` was unbound (NameError below) when the pipeline
    # returned no outputs; start from an empty dict instead.
    last = {}
    for cand in candidates:
        ok, errs = diy_validate(cand, profile)
        if ok:
            return cand, "ok"
        last = cand

    # Repair loop (deterministic): feed the validator's errors back to the
    # model; keep the previous candidate when the fix fails to parse.
    attempts = 1
    while attempts < DIY_MAX_ATTEMPTS:
        ok, errs = diy_validate(last, profile)
        if ok:
            return last, "ok"
        fix_pr = diy_repair_prompt(profile, last, errs)
        fix_pr = truncate_prompt(DIY_PIPE, fix_pr, MAX_INPUT_TOKENS)
        fixed = DIY_PIPE(fix_pr, do_sample=False, max_new_tokens=MAX_NEW_TOKENS, truncation=True)
        fixed = (fixed if isinstance(fixed, list) else [fixed])[0].get("generated_text", "")
        last = try_parse_json(fixed) or last
        attempts += 1
    return last, "partial"
516
+
517
+ # ---- Greeting generation ----
518
+
519
def greeting_prompt(profile: Dict) -> str:
    """Build the gift-card greeting prompt (2–3 sentences, no emojis) from the profile.

    All profile fields are optional; sensible defaults are used when absent.
    """
    tone = profile.get('tone','Heartfelt')
    name = profile.get('recipient_name','Friend')
    rel = profile.get('relationship','Friend')
    occ = profile.get('occ_ui','Birthday')
    ints = ", ".join(profile.get('interests', []))
    age = profile.get('age_range','any')
    gen = profile.get('gender','any')
    return f"""
Write a short greeting (2–3 sentences) in English for a gift card.
Tone: {tone}
Recipient: {name} ({rel})
Occasion: {occ}
Interests: {ints}
Age group: {age}; Gender: {gen}
Avoid emojis.
"""
536
+
537
def llm_generate_message(profile: Dict) -> str:
    """Produce a short gift-card greeting via the greeting pipeline.

    Falls back to a templated message when no model is loaded or when the
    model returns empty text.
    """
    recipient = profile.get('recipient_name','Friend')
    occasion = profile.get('occ_ui','Birthday').lower()
    if GREETING_PIPE is None:
        return (f"Dear {recipient}, happy {occasion}! "
                f"Wishing you joy and wonderful memories.")
    prompt = truncate_prompt(GREETING_PIPE, greeting_prompt(profile), MAX_INPUT_TOKENS)
    result = GREETING_PIPE(prompt, do_sample=False, max_new_tokens=90, truncation=True)
    if not isinstance(result, list):
        result = [result]
    text = result[0].get("generated_text","").strip()
    return text or f"Dear {recipient}, happy {occasion}!"
546
 
547
  # --------------------- Image generation (SD-Turbo) ---------------------
548
+
549
  def load_image_pipeline():
550
+ if not ENABLE_IMAGE:
551
+ return None
552
  try:
553
  device = "cuda" if torch.cuda.is_available() else "cpu"
554
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
561
 
562
  IMG_PIPE = load_image_pipeline()
563
 
564
+
565
+ def generate_gift_image_from_diy(diy: Dict):
566
+ if IMG_PIPE is None or not diy:
567
  return None
568
+ name = diy.get('gift_name','gift')
569
+ ov = diy.get('overview','product photo of handmade gift')
570
  prompt = (
571
+ f"{name}: {ov}. Style: product photo, soft studio lighting, minimal background, realistic, high detail."
 
572
  )
573
  try:
574
  img = IMG_PIPE(
 
583
  return None
584
 
585
  # --------------------- Rendering ---------------------
586
+
587
def md_escape(text: str) -> str:
    """Backslash-escape Markdown table/emphasis characters (| * _) in *text*."""
    escaped = str(text)
    for ch in ("|", "*", "_"):
        escaped = escaped.replace(ch, "\\" + ch)
    return escaped
589
 
 
616
  rows.append(card)
617
  return "\n".join(rows)
618
 
619
+
620
def render_diy_md(d: Dict) -> str:
    """Render a validated DIY-gift dict as a Markdown card.

    Returns an HTML <em> error snippet when *d* is empty (generation failed).
    Text fields pass through md_escape() so pipes/asterisks/underscores
    don't break the Markdown; cost/time default to an em dash when absent.
    """
    if not d:
        return "<em>DIY generation failed.</em>"
    name = md_escape(d.get("gift_name",""))
    ov = md_escape(d.get("overview",""))
    cost = d.get("estimated_cost_usd", "—")
    mins = d.get("estimated_time_minutes", "—")
    mats = d.get("materials_needed", [])
    steps= d.get("step_by_step_instructions", [])
    # Non-list values degrade to a single placeholder bullet/step.
    mats_md = "\n".join([f"- {md_escape(str(m))}" for m in mats]) if isinstance(mats, list) else "- —"
    steps_md= "\n".join([f"{i+1}. {md_escape(str(s))}" for i,s in enumerate(steps)]) if isinstance(steps, list) else "1. —"
    return f"""
### DIY Gift — {name}

{ov}

**Estimated cost:** ${cost} · **Estimated time:** {mins} min

**Materials needed:**
{mats_md}

**Step-by-step:**
{steps_md}
"""
+ """
644
+
645
  # --------------------- Gradio UI ---------------------
646
  CSS = """
647
  #examples { order: 1; }
 
651
  with gr.Blocks(css=CSS) as demo:
652
  gr.Markdown(TITLE)
653
 
 
654
  with gr.Column(elem_id="examples"):
655
  gr.Markdown("### Quick examples")
656
 
657
  with gr.Column(elem_id="form"):
658
  with gr.Row():
659
+ recipient_name = gr.Textbox(label="Recipient name", value="Rotem")
660
+ relationship = gr.Dropdown(label="Relationship", choices=RECIPIENT_RELATIONSHIPS, value="Romantic partner")
661
 
662
  with gr.Row():
663
  interests = gr.CheckboxGroup(
664
  label="Interests (select a few)", choices=INTEREST_OPTIONS,
665
+ value=["Reading","Fashion","Home decor"], interactive=True
666
  )
667
 
668
  with gr.Row():
669
+ occasion = gr.Dropdown(label="Occasion", choices=OCCASION_UI, value="Valentine’s Day")
670
  age = gr.Dropdown(label="Age group", choices=list(AGE_OPTIONS.keys()), value="adult (18–64)")
671
+ gender = gr.Dropdown(label="Recipient gender", choices=GENDER_OPTIONS, value="female")
672
 
 
673
  RangeSlider = getattr(gr, "RangeSlider", None)
674
  if RangeSlider is not None:
675
+ budget_range = RangeSlider(label="Budget range (USD)", minimum=5, maximum=500, step=1, value=[30, 60])
676
  budget_min, budget_max = None, None
677
  else:
678
  with gr.Row():
679
+ budget_min = gr.Slider(label="Min budget (USD)", minimum=5, maximum=500, step=1, value=30)
680
  budget_max = gr.Slider(label="Max budget (USD)", minimum=5, maximum=500, step=1, value=60)
681
  budget_range = gr.State(value=None)
682
 
683
+ tone = gr.Dropdown(label="Message tone", choices=MESSAGE_TONES, value="Romantic")
684
 
685
  go = gr.Button("Get GIfty 🎯")
686
 
687
  out_top3 = gr.HTML(label="Top-3 recommendations")
688
+ out_diy_json = gr.JSON(label="DIY Gift (JSON)")
689
+ out_diy_md = gr.Markdown(label="DIY Gift (readable)")
690
+ out_gen_img = gr.Image(label="DIY Gift image", type="pil")
691
+ out_msg = gr.Markdown(label="Personalized message")
692
 
693
  # examples (render on top via CSS)
694
  if RangeSlider:
695
  example_inputs = [interests, occasion, budget_range, recipient_name, relationship, age, gender, tone]
696
  EXAMPLES = [
697
+ [["Reading","Fashion","Home decor"], "Valentine’s Day", [30,60], "Rotem", "Romantic partner", "adult (18–64)", "female", "Romantic"],
698
+ [["Technology","Movies"], "Birthday", [25,45], "Daniel", "Friend", "adult (18–64)", "male", "Funny"],
699
  [["Gaming","Photography"], "Birthday", [30,120], "Omer", "Family - Sibling", "teen (13–17)", "male", "Playful"],
700
+ [["Home decor","Cooking"], "Housewarming", [25,45], "Noa", "Neighbor", "adult (18–64)", "any", "Appreciative"],
701
  ]
702
  else:
703
  example_inputs = [interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone]
704
  EXAMPLES = [
705
+ [["Reading","Fashion","Home decor"], "Valentine’s Day", 30, 60, "Rotem", "Romantic partner", "adult (18–64)", "female", "Romantic"],
706
+ [["Technology","Movies"], "Birthday", 25, 45, "Daniel", "Friend", "adult (18–64)", "male", "Funny"],
707
  [["Gaming","Photography"], "Birthday", 30, 120, "Omer", "Family - Sibling", "teen (13–17)", "male", "Playful"],
708
+ [["Home decor","Cooking"], "Housewarming", 25, 45, "Noa", "Neighbor", "adult (18–64)", "any", "Appreciative"],
709
  ]
710
 
711
  with gr.Column(elem_id="examples"):
 
760
  top3 = recommend_topk(profile, k=3)
761
  top3_html = render_top3_html(top3)
762
 
763
+ # DIY gift (generate-from-scratch, JSON)
764
+ diy_json, diy_status = diy_generate(profile)
765
+ diy_md = render_diy_md(diy_json)
766
+
767
+ # DIY image (optional)
768
+ diy_img = generate_gift_image_from_diy(diy_json)
769
 
770
  # greeting
771
  msg = llm_generate_message(profile)
772
 
773
+ return top3_html, diy_json, diy_md, diy_img, msg
774
 
775
  if RangeSlider:
776
  go.click(
777
  ui_predict,
778
  [interests, occasion, budget_range, recipient_name, relationship, age, gender, tone],
779
+ [out_top3, out_diy_json, out_diy_md, out_gen_img, out_msg]
780
  )
781
  else:
782
  go.click(
783
  ui_predict,
784
  [interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone],
785
+ [out_top3, out_diy_json, out_diy_md, out_gen_img, out_msg]
786
  )
787
 
788
  if __name__ == "__main__":