Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,11 +1,13 @@
|
|
| 1 |
-
# app.py
|
| 2 |
-
# 🎁 GIfty — Smart Gift Recommender
|
| 3 |
# Data: ckandemir/amazon-products
|
| 4 |
-
# Retrieval: MiniLM embeddings + FAISS (cosine)
|
| 5 |
-
# Generation:
|
| 6 |
-
#
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
import os, re, json, random
|
| 9 |
from typing import Dict, List, Tuple
|
| 10 |
|
| 11 |
import numpy as np
|
|
@@ -16,23 +18,37 @@ from datasets import load_dataset
|
|
| 16 |
from sentence_transformers import SentenceTransformer
|
| 17 |
import faiss
|
| 18 |
|
| 19 |
-
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
| 20 |
|
| 21 |
import torch
|
| 22 |
from diffusers import AutoPipelineForText2Image
|
| 23 |
|
| 24 |
# --------------------- Config ---------------------
|
| 25 |
MAX_ROWS = int(os.getenv("MAX_ROWS", "8000"))
|
| 26 |
-
TITLE = "# 🎁 GIfty — Smart Gift Recommender\n*Top-3
|
| 27 |
|
| 28 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
INTEREST_OPTIONS = [
|
| 30 |
"Sports","Travel","Cooking","Technology","Music","Art","Reading","Gardening","Fashion",
|
| 31 |
"Gaming","Photography","Hiking","Movies","Crafts","Pets","Wellness","Collecting","Food",
|
| 32 |
"Home decor","Science"
|
| 33 |
]
|
| 34 |
|
| 35 |
-
# ===== Updated Occasions (exact) =====
|
| 36 |
OCCASION_UI = [
|
| 37 |
"Birthday","Wedding / Engagement","Anniversary","Graduation","New baby","Housewarming",
|
| 38 |
"Retirement","Holidays","Valentine’s Day","Promotion / New job","Get well soon"
|
|
@@ -52,7 +68,6 @@ OCCASION_CANON = {
|
|
| 52 |
"Get well soon":"get_well"
|
| 53 |
}
|
| 54 |
|
| 55 |
-
# ===== Updated Relationship & Tone =====
|
| 56 |
RECIPIENT_RELATIONSHIPS = [
|
| 57 |
"Family - Parent",
|
| 58 |
"Family - Sibling",
|
|
@@ -68,15 +83,7 @@ RECIPIENT_RELATIONSHIPS = [
|
|
| 68 |
]
|
| 69 |
|
| 70 |
MESSAGE_TONES = [
|
| 71 |
-
"Formal",
|
| 72 |
-
"Casual",
|
| 73 |
-
"Funny",
|
| 74 |
-
"Heartfelt",
|
| 75 |
-
"Inspirational",
|
| 76 |
-
"Playful",
|
| 77 |
-
"Romantic",
|
| 78 |
-
"Appreciative",
|
| 79 |
-
"Encouraging",
|
| 80 |
]
|
| 81 |
|
| 82 |
AGE_OPTIONS = {
|
|
@@ -211,9 +218,13 @@ def load_catalog() -> pd.DataFrame:
|
|
| 211 |
],
|
| 212 |
"Category": ["Electronics | Audio","Grocery | Coffee","Toys & Games | Board Games"],
|
| 213 |
"Selling Price": ["$59.00","$34.00","$39.00"],
|
| 214 |
-
"Image": ["","",""]
|
| 215 |
})
|
| 216 |
df = map_amazon_to_schema(raw).drop_duplicates(subset=["name","short_desc"])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
if len(df) > MAX_ROWS:
|
| 218 |
df = df.sample(n=MAX_ROWS, random_state=42).reset_index(drop=True)
|
| 219 |
df["doc"] = df.apply(build_doc, axis=1)
|
|
@@ -221,38 +232,43 @@ def load_catalog() -> pd.DataFrame:
|
|
| 221 |
|
| 222 |
CATALOG = load_catalog()
|
| 223 |
|
| 224 |
-
# ---------------------
|
| 225 |
-
def _contains_ci(series: pd.Series, needle: str) -> pd.Series:
|
| 226 |
-
if not needle: return pd.Series(True, index=series.index)
|
| 227 |
-
return series.fillna("").str.contains(re.escape(needle), case=False, regex=True)
|
| 228 |
-
|
| 229 |
-
def filter_business(df: pd.DataFrame, budget_min=None, budget_max=None,
|
| 230 |
-
occasion_canon: str=None, age_range: str="any") -> pd.DataFrame:
|
| 231 |
-
m = pd.Series(True, index=df.index)
|
| 232 |
-
if budget_min is not None:
|
| 233 |
-
m &= df["price_usd"].fillna(0) >= float(budget_min)
|
| 234 |
-
if budget_max is not None:
|
| 235 |
-
m &= df["price_usd"].fillna(1e9) <= float(budget_max)
|
| 236 |
-
if occasion_canon:
|
| 237 |
-
m &= _contains_ci(df["occasion_tags"], occasion_canon)
|
| 238 |
-
if age_range and age_range != "any":
|
| 239 |
-
m &= (df["age_range"].fillna("any").isin([age_range, "any"]))
|
| 240 |
-
return df[m]
|
| 241 |
-
|
| 242 |
-
# --------------------- Embeddings + FAISS ---------------------
|
| 243 |
class EmbeddingIndex:
|
| 244 |
def __init__(self, docs: List[str], model_id: str):
|
|
|
|
| 245 |
self.model = SentenceTransformer(model_id)
|
| 246 |
-
embs = self.
|
| 247 |
-
self.index = faiss.IndexFlatIP(embs.shape[1]) # cosine via normalized vectors
|
| 248 |
-
self.index.add(embs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
|
| 250 |
def search(self, query: str, topn: int):
|
| 251 |
qv = self.model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
|
| 252 |
sims, idxs = self.index.search(qv, topn)
|
| 253 |
return sims[0], idxs[0]
|
| 254 |
|
| 255 |
-
EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2" # fast & solid on CPU
|
| 256 |
EMB_INDEX = EmbeddingIndex(CATALOG["doc"].tolist(), EMBED_MODEL_ID)
|
| 257 |
|
| 258 |
# --------------------- Query building ---------------------
|
|
@@ -281,9 +297,26 @@ def profile_to_query(profile: Dict) -> str:
|
|
| 281 |
if g != "any": parts.append("women" if g=="female" else ("men" if g=="male" else "unisex"))
|
| 282 |
return " | ".join(parts)
|
| 283 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
def recommend_topk(profile: Dict, k: int=3) -> pd.DataFrame:
|
| 285 |
query = profile_to_query(profile)
|
| 286 |
-
sims, idxs =
|
| 287 |
df_f = filter_business(
|
| 288 |
CATALOG,
|
| 289 |
budget_min=profile.get("budget_min"),
|
|
@@ -292,6 +325,7 @@ def recommend_topk(profile: Dict, k: int=3) -> pd.DataFrame:
|
|
| 292 |
age_range=profile.get("age_range","any"),
|
| 293 |
)
|
| 294 |
if df_f.empty: df_f = CATALOG
|
|
|
|
| 295 |
|
| 296 |
# soft gender boost
|
| 297 |
def gender_tokens(g: str) -> List[str]:
|
|
@@ -305,7 +339,7 @@ def recommend_topk(profile: Dict, k: int=3) -> pd.DataFrame:
|
|
| 305 |
cand = []
|
| 306 |
for i, sim in zip(idxs, sims):
|
| 307 |
i = int(i)
|
| 308 |
-
if i in
|
| 309 |
blob = f"{CATALOG.loc[i,'tags']} {CATALOG.loc[i,'short_desc']}".lower()
|
| 310 |
boost = 0.08 if any(t in blob for t in gts) else 0.0
|
| 311 |
cand.append((i, float(sim) + boost))
|
|
@@ -329,95 +363,192 @@ def recommend_topk(profile: Dict, k: int=3) -> pd.DataFrame:
|
|
| 329 |
res["similarity"] = [dict(picks).get(int(i), np.nan) for i in sel]
|
| 330 |
return res[["name","short_desc","price_usd","occasion_tags","persona_fit","age_range","image_url","similarity"]]
|
| 331 |
|
| 332 |
-
# --------------------- LLM (
|
| 333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
try:
|
| 335 |
-
|
| 336 |
-
_mdl = AutoModelForSeq2SeqLM.from_pretrained(LLM_ID)
|
| 337 |
-
LLM = pipeline("text2text-generation", model=_mdl, tokenizer=_tok)
|
| 338 |
except Exception as e:
|
| 339 |
-
|
| 340 |
-
print("LLM load failed
|
| 341 |
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
out = LLM(prompt, max_new_tokens=max_new_tokens, do_sample=False, temperature=0.0)
|
| 345 |
-
return out[0]["generated_text"]
|
| 346 |
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
You are GIfty. Invent ONE gift that matches the catalog style with keys:
|
| 360 |
-
name, short_desc, price_usd, occasion_tags, persona_fit. Use JSON only.
|
| 361 |
-
Constraints:
|
| 362 |
-
- Fit the recipient profile and relationship.
|
| 363 |
-
- price_usd must be numeric within the budget range.
|
| 364 |
-
Profile:
|
| 365 |
-
name={profile.get('recipient_name','Friend')}
|
| 366 |
-
relationship={profile.get('relationship','Friend')}
|
| 367 |
-
gender={profile.get('gender','any')}
|
| 368 |
-
age_group={profile.get('age_range','any')}
|
| 369 |
-
interests={profile.get('interests',[])}
|
| 370 |
-
occasion={profile.get('occ_ui','Birthday')}
|
| 371 |
-
budget_min={profile.get('budget_min',10)}
|
| 372 |
-
budget_max={profile.get('budget_max',100)}
|
| 373 |
-
"""
|
| 374 |
-
txt = _run_llm(prompt, max_new_tokens=180)
|
| 375 |
-
data = _parse_json_maybe(txt)
|
| 376 |
-
if not data:
|
| 377 |
-
core = (profile.get("interests",["hobby"])[0] or "hobby").lower()
|
| 378 |
-
return {
|
| 379 |
-
"name": f"{core.title()} starter bundle ({profile.get('occ_ui','Birthday')})",
|
| 380 |
-
"short_desc": f"A curated set to kickstart their {core} passion.",
|
| 381 |
-
"price_usd": float(np.clip(profile.get("budget_max", 50) or 50, 10, 300)),
|
| 382 |
-
"occasion_tags": OCCASION_CANON.get(profile.get("occ_ui","Birthday"), "birthday"),
|
| 383 |
-
"persona_fit": ", ".join(profile.get("interests", [])) or "general",
|
| 384 |
-
"age_range": profile.get("age_range","any"),
|
| 385 |
-
"image_url": ""
|
| 386 |
-
}
|
| 387 |
try:
|
| 388 |
-
|
| 389 |
except Exception:
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
|
| 402 |
-
def
|
| 403 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
Write a short greeting (2–3 sentences) in English for a gift card.
|
| 405 |
-
Tone: {
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
Age group: {profile.get('age_range','any')}; Gender: {profile.get('gender','any')}
|
| 411 |
Avoid emojis.
|
| 412 |
"""
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
|
| 419 |
# --------------------- Image generation (SD-Turbo) ---------------------
|
|
|
|
| 420 |
def load_image_pipeline():
|
|
|
|
|
|
|
| 421 |
try:
|
| 422 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 423 |
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
|
@@ -430,12 +561,14 @@ def load_image_pipeline():
|
|
| 430 |
|
| 431 |
IMG_PIPE = load_image_pipeline()
|
| 432 |
|
| 433 |
-
|
| 434 |
-
|
|
|
|
| 435 |
return None
|
|
|
|
|
|
|
| 436 |
prompt = (
|
| 437 |
-
f"{
|
| 438 |
-
f"Style: product photo, soft studio lighting, minimal background, realistic, high detail."
|
| 439 |
)
|
| 440 |
try:
|
| 441 |
img = IMG_PIPE(
|
|
@@ -450,6 +583,7 @@ def generate_gift_image(gift: Dict):
|
|
| 450 |
return None
|
| 451 |
|
| 452 |
# --------------------- Rendering ---------------------
|
|
|
|
| 453 |
def md_escape(text: str) -> str:
|
| 454 |
return str(text).replace("|","\\|").replace("*","\\*").replace("_","\\_")
|
| 455 |
|
|
@@ -482,6 +616,32 @@ def render_top3_html(df: pd.DataFrame) -> str:
|
|
| 482 |
rows.append(card)
|
| 483 |
return "\n".join(rows)
|
| 484 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 485 |
# --------------------- Gradio UI ---------------------
|
| 486 |
CSS = """
|
| 487 |
#examples { order: 1; }
|
|
@@ -491,62 +651,61 @@ CSS = """
|
|
| 491 |
with gr.Blocks(css=CSS) as demo:
|
| 492 |
gr.Markdown(TITLE)
|
| 493 |
|
| 494 |
-
# top section (examples placeholder)
|
| 495 |
with gr.Column(elem_id="examples"):
|
| 496 |
gr.Markdown("### Quick examples")
|
| 497 |
|
| 498 |
with gr.Column(elem_id="form"):
|
| 499 |
with gr.Row():
|
| 500 |
-
recipient_name = gr.Textbox(label="Recipient name", value="
|
| 501 |
-
relationship = gr.Dropdown(label="Relationship", choices=RECIPIENT_RELATIONSHIPS, value="
|
| 502 |
|
| 503 |
with gr.Row():
|
| 504 |
interests = gr.CheckboxGroup(
|
| 505 |
label="Interests (select a few)", choices=INTEREST_OPTIONS,
|
| 506 |
-
value=["
|
| 507 |
)
|
| 508 |
|
| 509 |
with gr.Row():
|
| 510 |
-
occasion = gr.Dropdown(label="Occasion", choices=OCCASION_UI, value="
|
| 511 |
age = gr.Dropdown(label="Age group", choices=list(AGE_OPTIONS.keys()), value="adult (18–64)")
|
| 512 |
-
gender = gr.Dropdown(label="Recipient gender", choices=GENDER_OPTIONS, value="
|
| 513 |
|
| 514 |
-
# Budget: try RangeSlider else two sliders
|
| 515 |
RangeSlider = getattr(gr, "RangeSlider", None)
|
| 516 |
if RangeSlider is not None:
|
| 517 |
-
budget_range = RangeSlider(label="Budget range (USD)", minimum=5, maximum=500, step=1, value=[
|
| 518 |
budget_min, budget_max = None, None
|
| 519 |
else:
|
| 520 |
with gr.Row():
|
| 521 |
-
budget_min = gr.Slider(label="Min budget (USD)", minimum=5, maximum=500, step=1, value=
|
| 522 |
budget_max = gr.Slider(label="Max budget (USD)", minimum=5, maximum=500, step=1, value=60)
|
| 523 |
budget_range = gr.State(value=None)
|
| 524 |
|
| 525 |
-
tone = gr.Dropdown(label="Message tone", choices=MESSAGE_TONES, value="
|
| 526 |
|
| 527 |
go = gr.Button("Get GIfty 🎯")
|
| 528 |
|
| 529 |
out_top3 = gr.HTML(label="Top-3 recommendations")
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
|
|
|
| 533 |
|
| 534 |
# examples (render on top via CSS)
|
| 535 |
if RangeSlider:
|
| 536 |
example_inputs = [interests, occasion, budget_range, recipient_name, relationship, age, gender, tone]
|
| 537 |
EXAMPLES = [
|
| 538 |
-
[["
|
| 539 |
-
[["
|
| 540 |
[["Gaming","Photography"], "Birthday", [30,120], "Omer", "Family - Sibling", "teen (13–17)", "male", "Playful"],
|
| 541 |
-
[["
|
| 542 |
]
|
| 543 |
else:
|
| 544 |
example_inputs = [interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone]
|
| 545 |
EXAMPLES = [
|
| 546 |
-
[["
|
| 547 |
-
[["
|
| 548 |
[["Gaming","Photography"], "Birthday", 30, 120, "Omer", "Family - Sibling", "teen (13–17)", "male", "Playful"],
|
| 549 |
-
[["
|
| 550 |
]
|
| 551 |
|
| 552 |
with gr.Column(elem_id="examples"):
|
|
@@ -601,27 +760,29 @@ with gr.Blocks(css=CSS) as demo:
|
|
| 601 |
top3 = recommend_topk(profile, k=3)
|
| 602 |
top3_html = render_top3_html(top3)
|
| 603 |
|
| 604 |
-
#
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
|
|
|
|
|
|
| 608 |
|
| 609 |
# greeting
|
| 610 |
msg = llm_generate_message(profile)
|
| 611 |
|
| 612 |
-
return top3_html,
|
| 613 |
|
| 614 |
if RangeSlider:
|
| 615 |
go.click(
|
| 616 |
ui_predict,
|
| 617 |
[interests, occasion, budget_range, recipient_name, relationship, age, gender, tone],
|
| 618 |
-
[out_top3,
|
| 619 |
)
|
| 620 |
else:
|
| 621 |
go.click(
|
| 622 |
ui_predict,
|
| 623 |
[interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone],
|
| 624 |
-
[out_top3,
|
| 625 |
)
|
| 626 |
|
| 627 |
if __name__ == "__main__":
|
|
|
|
| 1 |
+
# app.py — Gifty (revised)
|
| 2 |
+
# 🎁 GIfty — Smart Gift Recommender
|
| 3 |
# Data: ckandemir/amazon-products
|
| 4 |
+
# Retrieval: MiniLM-L12-v2 embeddings + FAISS (cosine), with simple on-disk cache
|
| 5 |
+
# DIY Generation: small instruct LMs via HF pipeline (default: flan-t5-small) with JSON validate+repair (no padding)
|
| 6 |
+
# Greeting: short LLM completion
|
| 7 |
+
# Image: SD-Turbo (optional)
|
| 8 |
+
# UI: Gradio; Quick Examples; Budget RangeSlider; DIY JSON + readable card
|
| 9 |
|
| 10 |
+
import os, re, json, random, hashlib, pathlib
|
| 11 |
from typing import Dict, List, Tuple
|
| 12 |
|
| 13 |
import numpy as np
|
|
|
|
| 18 |
from sentence_transformers import SentenceTransformer
|
| 19 |
import faiss
|
| 20 |
|
| 21 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, pipeline
|
| 22 |
|
| 23 |
import torch
|
| 24 |
from diffusers import AutoPipelineForText2Image
|
| 25 |
|
| 26 |
# --------------------- Config ---------------------
|
| 27 |
MAX_ROWS = int(os.getenv("MAX_ROWS", "8000"))
|
| 28 |
+
TITLE = "# 🎁 GIfty — Smart Gift Recommender\n*Top-3 catalog picks + 1 DIY gift (JSON) + personalized message*"
|
| 29 |
|
| 30 |
+
# Retrieval model (embedding)
|
| 31 |
+
EMBED_MODEL_ID = os.getenv("EMBED_MODEL_ID", "sentence-transformers/all-MiniLM-L12-v2")
|
| 32 |
+
EMBED_CACHE_DIR = os.getenv("EMBED_CACHE_DIR", "./.gifty_cache")
|
| 33 |
+
pathlib.Path(EMBED_CACHE_DIR).mkdir(parents=True, exist_ok=True)
|
| 34 |
+
|
| 35 |
+
# DIY generation model (text)
|
| 36 |
+
GEN_MODEL_ID = os.getenv("GEN_MODEL_ID", "google/flan-t5-small")
|
| 37 |
+
OUTPUT_LANG = os.getenv("OUTPUT_LANG", "en") # "en" or "he"
|
| 38 |
+
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "360"))
|
| 39 |
+
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "260"))
|
| 40 |
+
DIY_MAX_ATTEMPTS = int(os.getenv("DIY_MAX_ATTEMPTS", "4"))
|
| 41 |
+
|
| 42 |
+
# Image gen toggle
|
| 43 |
+
ENABLE_IMAGE = os.getenv("ENABLE_IMAGE", "1") == "1"
|
| 44 |
+
|
| 45 |
+
# ===== UI options =====
|
| 46 |
INTEREST_OPTIONS = [
|
| 47 |
"Sports","Travel","Cooking","Technology","Music","Art","Reading","Gardening","Fashion",
|
| 48 |
"Gaming","Photography","Hiking","Movies","Crafts","Pets","Wellness","Collecting","Food",
|
| 49 |
"Home decor","Science"
|
| 50 |
]
|
| 51 |
|
|
|
|
| 52 |
OCCASION_UI = [
|
| 53 |
"Birthday","Wedding / Engagement","Anniversary","Graduation","New baby","Housewarming",
|
| 54 |
"Retirement","Holidays","Valentine’s Day","Promotion / New job","Get well soon"
|
|
|
|
| 68 |
"Get well soon":"get_well"
|
| 69 |
}
|
| 70 |
|
|
|
|
| 71 |
RECIPIENT_RELATIONSHIPS = [
|
| 72 |
"Family - Parent",
|
| 73 |
"Family - Sibling",
|
|
|
|
| 83 |
]
|
| 84 |
|
| 85 |
MESSAGE_TONES = [
|
| 86 |
+
"Formal","Casual","Funny","Heartfelt","Inspirational","Playful","Romantic","Appreciative","Encouraging",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
]
|
| 88 |
|
| 89 |
AGE_OPTIONS = {
|
|
|
|
| 218 |
],
|
| 219 |
"Category": ["Electronics | Audio","Grocery | Coffee","Toys & Games | Board Games"],
|
| 220 |
"Selling Price": ["$59.00","$34.00","$39.00"],
|
| 221 |
+
"Image": ["","",""]
|
| 222 |
})
|
| 223 |
df = map_amazon_to_schema(raw).drop_duplicates(subset=["name","short_desc"])
|
| 224 |
+
# EDA cleanups: drop missing price, cap to <= 500
|
| 225 |
+
df = df[pd.notna(df["price_usd"])].copy()
|
| 226 |
+
df = df[df["price_usd"] <= 500].reset_index(drop=True)
|
| 227 |
+
# limit rows
|
| 228 |
if len(df) > MAX_ROWS:
|
| 229 |
df = df.sample(n=MAX_ROWS, random_state=42).reset_index(drop=True)
|
| 230 |
df["doc"] = df.apply(build_doc, axis=1)
|
|
|
|
| 232 |
|
| 233 |
CATALOG = load_catalog()
|
| 234 |
|
| 235 |
+
# --------------------- Embeddings + FAISS (with simple cache) ---------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
class EmbeddingIndex:
|
| 237 |
def __init__(self, docs: List[str], model_id: str):
|
| 238 |
+
self.model_id = model_id
|
| 239 |
self.model = SentenceTransformer(model_id)
|
| 240 |
+
self.embs = self._load_or_build(docs)
|
| 241 |
+
self.index = faiss.IndexFlatIP(self.embs.shape[1]) # cosine via normalized vectors
|
| 242 |
+
self.index.add(self.embs)
|
| 243 |
+
|
| 244 |
+
def _cache_paths(self, n_docs: int) -> Tuple[str, str]:
|
| 245 |
+
h = hashlib.md5((self.model_id + f"|{n_docs}").encode()).hexdigest()[:10]
|
| 246 |
+
npy = os.path.join(EMBED_CACHE_DIR, f"emb_{h}.npy")
|
| 247 |
+
idx = os.path.join(EMBED_CACHE_DIR, f"faiss_{h}.index")
|
| 248 |
+
return npy, idx
|
| 249 |
+
|
| 250 |
+
def _load_or_build(self, docs: List[str]) -> np.ndarray:
|
| 251 |
+
npy_path, _ = self._cache_paths(len(docs))
|
| 252 |
+
if os.path.exists(npy_path):
|
| 253 |
+
try:
|
| 254 |
+
embs = np.load(npy_path)
|
| 255 |
+
if embs.shape[0] == len(docs):
|
| 256 |
+
return embs
|
| 257 |
+
except Exception:
|
| 258 |
+
pass
|
| 259 |
+
# build
|
| 260 |
+
embs = self.model.encode(docs, convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=True)
|
| 261 |
+
try:
|
| 262 |
+
np.save(npy_path, embs)
|
| 263 |
+
except Exception:
|
| 264 |
+
pass
|
| 265 |
+
return embs
|
| 266 |
|
| 267 |
def search(self, query: str, topn: int):
|
| 268 |
qv = self.model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
|
| 269 |
sims, idxs = self.index.search(qv, topn)
|
| 270 |
return sims[0], idxs[0]
|
| 271 |
|
|
|
|
| 272 |
EMB_INDEX = EmbeddingIndex(CATALOG["doc"].tolist(), EMBED_MODEL_ID)
|
| 273 |
|
| 274 |
# --------------------- Query building ---------------------
|
|
|
|
| 297 |
if g != "any": parts.append("women" if g=="female" else ("men" if g=="male" else "unisex"))
|
| 298 |
return " | ".join(parts)
|
| 299 |
|
| 300 |
+
def _contains_ci(series: pd.Series, needle: str) -> pd.Series:
|
| 301 |
+
if not needle: return pd.Series(True, index=series.index)
|
| 302 |
+
return series.fillna("").str.contains(re.escape(needle), case=False, regex=True)
|
| 303 |
+
|
| 304 |
+
def filter_business(df: pd.DataFrame, budget_min=None, budget_max=None,
|
| 305 |
+
occasion_canon: str=None, age_range: str="any") -> pd.DataFrame:
|
| 306 |
+
m = pd.Series(True, index=df.index)
|
| 307 |
+
if budget_min is not None:
|
| 308 |
+
m &= df["price_usd"].fillna(0) >= float(budget_min)
|
| 309 |
+
if budget_max is not None:
|
| 310 |
+
m &= df["price_usd"].fillna(1e9) <= float(budget_max)
|
| 311 |
+
if occasion_canon:
|
| 312 |
+
m &= _contains_ci(df["occasion_tags"], occasion_canon)
|
| 313 |
+
if age_range and age_range != "any":
|
| 314 |
+
m &= (df["age_range"].fillna("any").isin([age_range, "any"]))
|
| 315 |
+
return df[m]
|
| 316 |
+
|
| 317 |
def recommend_topk(profile: Dict, k: int=3) -> pd.DataFrame:
|
| 318 |
query = profile_to_query(profile)
|
| 319 |
+
sims, idxs = EMB_INDEX.search(query, topn=min(max(k*80, k), len(CATALOG)))
|
| 320 |
df_f = filter_business(
|
| 321 |
CATALOG,
|
| 322 |
budget_min=profile.get("budget_min"),
|
|
|
|
| 325 |
age_range=profile.get("age_range","any"),
|
| 326 |
)
|
| 327 |
if df_f.empty: df_f = CATALOG
|
| 328 |
+
df_f_idx = set(df_f.index.tolist())
|
| 329 |
|
| 330 |
# soft gender boost
|
| 331 |
def gender_tokens(g: str) -> List[str]:
|
|
|
|
| 339 |
cand = []
|
| 340 |
for i, sim in zip(idxs, sims):
|
| 341 |
i = int(i)
|
| 342 |
+
if i in df_f_idx:
|
| 343 |
blob = f"{CATALOG.loc[i,'tags']} {CATALOG.loc[i,'short_desc']}".lower()
|
| 344 |
boost = 0.08 if any(t in blob for t in gts) else 0.0
|
| 345 |
cand.append((i, float(sim) + boost))
|
|
|
|
| 363 |
res["similarity"] = [dict(picks).get(int(i), np.nan) for i in sel]
|
| 364 |
return res[["name","short_desc","price_usd","occasion_tags","persona_fit","age_range","image_url","similarity"]]
|
| 365 |
|
| 366 |
+
# --------------------- LLM plumbing (DIY + Greeting) ---------------------
|
| 367 |
+
|
| 368 |
+
def load_text_pipeline(model_id: str):
|
| 369 |
+
trust=True
|
| 370 |
+
if "flan" in model_id or "t5" in model_id:
|
| 371 |
+
tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust)
|
| 372 |
+
mdl = AutoModelForSeq2SeqLM.from_pretrained(model_id, trust_remote_code=trust)
|
| 373 |
+
return pipeline("text2text-generation", model=mdl, tokenizer=tok, device_map="auto", trust_remote_code=trust)
|
| 374 |
+
else:
|
| 375 |
+
tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust)
|
| 376 |
+
mdl = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=trust)
|
| 377 |
+
return pipeline("text-generation", model=mdl, tokenizer=tok, device_map="auto", trust_remote_code=trust)
|
| 378 |
+
|
| 379 |
try:
|
| 380 |
+
DIY_PIPE = load_text_pipeline(GEN_MODEL_ID)
|
|
|
|
|
|
|
| 381 |
except Exception as e:
|
| 382 |
+
DIY_PIPE = None
|
| 383 |
+
print("DIY LLM load failed:", e)
|
| 384 |
|
| 385 |
+
# Small greeting model (can reuse DIY_PIPE)
|
| 386 |
+
GREETING_PIPE = DIY_PIPE
|
|
|
|
|
|
|
| 387 |
|
| 388 |
+
# ---- JSON helpers ----
|
| 389 |
+
GENERIC_NAMES = {"diy gift","gift","personalized gift","handmade gift","custom gift","מתנה","מתנה אישית","עשה זאת בעצמך"}
|
| 390 |
+
|
| 391 |
+
def _f(x, fb=0.0):
|
| 392 |
+
try: return float(x)
|
| 393 |
+
except: return float(fb)
|
| 394 |
+
|
| 395 |
+
def try_parse_json(text: str):
|
| 396 |
+
if not text: return None
|
| 397 |
+
m = re.search(r"(\{[\s\S]*\})", text.strip())
|
| 398 |
+
if not m: return None
|
| 399 |
+
blob = m.group(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
try:
|
| 401 |
+
return json.loads(blob)
|
| 402 |
except Exception:
|
| 403 |
+
blob = re.sub(r",\s*}\s*$", "}", blob)
|
| 404 |
+
blob = re.sub(r",\s*\]", "]", blob)
|
| 405 |
+
try: return json.loads(blob)
|
| 406 |
+
except: return None
|
| 407 |
+
|
| 408 |
+
def truncate_prompt(pipe, text: str, max_tokens: int) -> str:
|
| 409 |
+
tok = pipe.tokenizer
|
| 410 |
+
ids = tok(text, truncation=True, max_length=max_tokens, return_tensors=None).get("input_ids", [])
|
| 411 |
+
return tok.decode(ids, skip_special_tokens=True) if ids else text
|
| 412 |
+
|
| 413 |
+
# ---- DIY prompt, validate & repair (no padding) ----
|
| 414 |
+
|
| 415 |
+
def diy_prompt(profile: Dict) -> str:
|
| 416 |
+
lang = "English" if OUTPUT_LANG == "en" else "Hebrew"
|
| 417 |
+
name = profile.get("recipient_name","Recipient")
|
| 418 |
+
rel = profile.get("relationship","Friend")
|
| 419 |
+
age = profile.get("age_range","any")
|
| 420 |
+
gen = profile.get("gender","any")
|
| 421 |
+
ints = ", ".join(profile.get("interests",[])) or "general"
|
| 422 |
+
occ = profile.get("occ_ui","Birthday")
|
| 423 |
+
lo, hi = int(profile.get("budget_min",10)), int(profile.get("budget_max",100))
|
| 424 |
+
|
| 425 |
+
return "\n".join([
|
| 426 |
+
f"Invent ONE original DIY gift idea from scratch for this recipient. Write all VALUES in {lang}.",
|
| 427 |
+
"Return JSON ONLY with exactly these keys (and nothing else):",
|
| 428 |
+
"gift_name, overview, materials_needed, step_by_step_instructions, estimated_cost_usd, estimated_time_minutes",
|
| 429 |
+
"",
|
| 430 |
+
"Hard requirements:",
|
| 431 |
+
"- Strongly reflect the recipient's interests and the occasion.",
|
| 432 |
+
"- overview MUST mention the recipient by NAME and include relationship, age_group, gender, and the occasion.",
|
| 433 |
+
"- gift_name must be SPECIFIC (not generic), 4–10 words, include at least one interest keyword.",
|
| 434 |
+
f"- estimated_cost_usd between ${lo}-${hi}; estimated_time_minutes 20–240.",
|
| 435 |
+
"- materials_needed: at least 5 concise items with quantities.",
|
| 436 |
+
"- step_by_step_instructions: at least 6 practical, ordered steps.",
|
| 437 |
+
"Forbidden gift_name terms: DIY Gift, Gift, Personalized Gift, Handmade Gift, Custom Gift.",
|
| 438 |
+
"",
|
| 439 |
+
f"Recipient: name={name}; relationship={rel}; age_group={age}; gender={gen}.",
|
| 440 |
+
f"Interests: {ints}. Occasion: {occ}.",
|
| 441 |
+
"JSON:"
|
| 442 |
+
])
|
| 443 |
|
| 444 |
+
def diy_validate(g: dict, profile: Dict) -> Tuple[bool, List[str]]:
|
| 445 |
+
errs=[]
|
| 446 |
+
# keys
|
| 447 |
+
req=["gift_name","overview","materials_needed","step_by_step_instructions","estimated_cost_usd","estimated_time_minutes"]
|
| 448 |
+
for k in req:
|
| 449 |
+
if k not in g: errs.append(f"missing key: {k}")
|
| 450 |
+
# name
|
| 451 |
+
n=str(g.get("gift_name",""))
|
| 452 |
+
if not n.strip(): errs.append("gift_name empty")
|
| 453 |
+
if any(b in n.strip().lower() for b in GENERIC_NAMES): errs.append("gift_name generic")
|
| 454 |
+
if len(n.split())<3: errs.append("gift_name too short")
|
| 455 |
+
# overview mentions
|
| 456 |
+
ov=str(g.get("overview",""))
|
| 457 |
+
if profile.get("recipient_name","") and profile.get("recipient_name") not in ov: errs.append("overview missing recipient name")
|
| 458 |
+
for field,label in [("relationship","relationship"),("age_range","age_group"),("gender","gender"),("occ_ui","occasion")]:
|
| 459 |
+
val=str(profile.get(field,""))
|
| 460 |
+
if val and (val.split()[0] not in ov): errs.append(f"overview missing {label}")
|
| 461 |
+
# lists
|
| 462 |
+
mats=g.get("materials_needed", [])
|
| 463 |
+
steps=g.get("step_by_step_instructions", [])
|
| 464 |
+
if not isinstance(mats, list) or len(mats)<5: errs.append("materials_needed len < 5")
|
| 465 |
+
if not isinstance(steps, list) or len(steps)<6: errs.append("steps len < 6")
|
| 466 |
+
# numbers
|
| 467 |
+
lo, hi = _f(profile.get("budget_min",10),10), _f(profile.get("budget_max",100),100)
|
| 468 |
+
cost=_f(g.get("estimated_cost_usd"), -1)
|
| 469 |
+
if not (lo <= cost <= hi): errs.append(f"cost not in budget [{lo},{hi}]")
|
| 470 |
+
mins=int(_f(g.get("estimated_time_minutes"), -1))
|
| 471 |
+
if not (20 <= mins <= 240): errs.append("time not in 20..240")
|
| 472 |
+
return (len(errs)==0), errs
|
| 473 |
+
|
| 474 |
+
def diy_repair_prompt(profile: Dict, last: dict, errors: List[str]) -> str:
|
| 475 |
+
lang = "English" if OUTPUT_LANG == "en" else "Hebrew"
|
| 476 |
+
return "\n".join([
|
| 477 |
+
f"Fix ONLY the following problems in this JSON. Keep the same idea and style. Return JSON ONLY. Write all VALUES in {lang}.",
|
| 478 |
+
"Errors:",
|
| 479 |
+
*[f"- {e}" for e in errors],
|
| 480 |
+
"JSON to fix:",
|
| 481 |
+
json.dumps(last, ensure_ascii=False)
|
| 482 |
+
])
|
| 483 |
+
|
| 484 |
+
def diy_generate(profile: Dict) -> Tuple[dict, str]:
|
| 485 |
+
if DIY_PIPE is None:
|
| 486 |
+
return {}, "DIY model not loaded"
|
| 487 |
+
# attempt 1: creative
|
| 488 |
+
prompt = diy_prompt(profile)
|
| 489 |
+
pr = truncate_prompt(DIY_PIPE, prompt, MAX_INPUT_TOKENS)
|
| 490 |
+
out = DIY_PIPE(pr, do_sample=True, temperature=0.9, top_p=0.95, max_new_tokens=MAX_NEW_TOKENS, truncation=True)
|
| 491 |
+
if not isinstance(out, list): out=[out]
|
| 492 |
+
texts = [o.get("generated_text","") for o in out]
|
| 493 |
+
candidates = [try_parse_json(t) or {} for t in texts]
|
| 494 |
+
|
| 495 |
+
# pick first valid
|
| 496 |
+
for cand in candidates:
|
| 497 |
+
ok, errs = diy_validate(cand, profile)
|
| 498 |
+
if ok:
|
| 499 |
+
return cand, "ok"
|
| 500 |
+
last = cand
|
| 501 |
+
|
| 502 |
+
# repair loop (deterministic)
|
| 503 |
+
attempts = 1
|
| 504 |
+
while attempts < DIY_MAX_ATTEMPTS:
|
| 505 |
+
ok, errs = diy_validate(last, profile)
|
| 506 |
+
if ok:
|
| 507 |
+
return last, "ok"
|
| 508 |
+
fix_pr = diy_repair_prompt(profile, last, errs)
|
| 509 |
+
fix_pr = truncate_prompt(DIY_PIPE, fix_pr, MAX_INPUT_TOKENS)
|
| 510 |
+
fixed = DIY_PIPE(fix_pr, do_sample=False, max_new_tokens=MAX_NEW_TOKENS, truncation=True)
|
| 511 |
+
fixed = (fixed if isinstance(fixed, list) else [fixed])[0].get("generated_text","")
|
| 512 |
+
fixed = try_parse_json(fixed) or last
|
| 513 |
+
last = fixed
|
| 514 |
+
attempts += 1
|
| 515 |
+
return last, "partial"
|
| 516 |
+
|
| 517 |
+
# ---- Greeting generation ----
|
| 518 |
+
|
| 519 |
+
def greeting_prompt(profile: Dict) -> str:
|
| 520 |
+
tone = profile.get('tone','Heartfelt')
|
| 521 |
+
name = profile.get('recipient_name','Friend')
|
| 522 |
+
rel = profile.get('relationship','Friend')
|
| 523 |
+
occ = profile.get('occ_ui','Birthday')
|
| 524 |
+
ints = ", ".join(profile.get('interests', []))
|
| 525 |
+
age = profile.get('age_range','any')
|
| 526 |
+
gen = profile.get('gender','any')
|
| 527 |
+
return f"""
|
| 528 |
Write a short greeting (2–3 sentences) in English for a gift card.
|
| 529 |
+
Tone: {tone}
|
| 530 |
+
Recipient: {name} ({rel})
|
| 531 |
+
Occasion: {occ}
|
| 532 |
+
Interests: {ints}
|
| 533 |
+
Age group: {age}; Gender: {gen}
|
|
|
|
| 534 |
Avoid emojis.
|
| 535 |
"""
|
| 536 |
+
|
| 537 |
+
def llm_generate_message(profile: Dict) -> str:
|
| 538 |
+
if GREETING_PIPE is None:
|
| 539 |
+
return (f"Dear {profile.get('recipient_name','Friend')}, happy {profile.get('occ_ui','Birthday').lower()}! "
|
| 540 |
+
f"Wishing you joy and wonderful memories.")
|
| 541 |
+
pr = truncate_prompt(GREETING_PIPE, greeting_prompt(profile), MAX_INPUT_TOKENS)
|
| 542 |
+
out = GREETING_PIPE(pr, do_sample=False, max_new_tokens=90, truncation=True)
|
| 543 |
+
out = out if isinstance(out, list) else [out]
|
| 544 |
+
txt = out[0].get("generated_text","")
|
| 545 |
+
return txt.strip() or (f"Dear {profile.get('recipient_name','Friend')}, happy {profile.get('occ_ui','Birthday').lower()}!")
|
| 546 |
|
| 547 |
# --------------------- Image generation (SD-Turbo) ---------------------
|
| 548 |
+
|
| 549 |
def load_image_pipeline():
|
| 550 |
+
if not ENABLE_IMAGE:
|
| 551 |
+
return None
|
| 552 |
try:
|
| 553 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 554 |
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
|
|
|
| 561 |
|
| 562 |
IMG_PIPE = load_image_pipeline()
|
| 563 |
|
| 564 |
+
|
| 565 |
+
def generate_gift_image_from_diy(diy: Dict):
|
| 566 |
+
if IMG_PIPE is None or not diy:
|
| 567 |
return None
|
| 568 |
+
name = diy.get('gift_name','gift')
|
| 569 |
+
ov = diy.get('overview','product photo of handmade gift')
|
| 570 |
prompt = (
|
| 571 |
+
f"{name}: {ov}. Style: product photo, soft studio lighting, minimal background, realistic, high detail."
|
|
|
|
| 572 |
)
|
| 573 |
try:
|
| 574 |
img = IMG_PIPE(
|
|
|
|
| 583 |
return None
|
| 584 |
|
| 585 |
# --------------------- Rendering ---------------------
|
| 586 |
+
|
| 587 |
def md_escape(text: str) -> str:
    """Backslash-escape Markdown-sensitive characters (|, *, _) in *text*."""
    escaped = str(text)
    for ch in ("|", "*", "_"):
        escaped = escaped.replace(ch, "\\" + ch)
    return escaped
|
| 589 |
|
|
|
|
| 616 |
rows.append(card)
|
| 617 |
return "\n".join(rows)
|
| 618 |
|
| 619 |
+
|
| 620 |
+
def render_diy_md(d: Dict) -> str:
    """Render the DIY-gift dict (parsed LLM JSON) as a readable Markdown card.

    Returns an ``<em>`` fallback snippet when *d* is empty/None, i.e. when
    DIY generation failed upstream.
    """
    if not d:
        return "<em>DIY generation failed.</em>"
    # Escape user/model-provided text so it cannot break the Markdown layout.
    name = md_escape(d.get("gift_name",""))
    ov = md_escape(d.get("overview",""))
    # Em-dash placeholders when the model omitted cost/time estimates.
    cost = d.get("estimated_cost_usd", "—")
    mins = d.get("estimated_time_minutes", "—")
    mats = d.get("materials_needed", [])
    steps= d.get("step_by_step_instructions", [])
    # Guard against the LLM returning a non-list; fall back to a placeholder row.
    mats_md = "\n".join([f"- {md_escape(str(m))}" for m in mats]) if isinstance(mats, list) else "- —"
    steps_md= "\n".join([f"{i+1}. {md_escape(str(s))}" for i,s in enumerate(steps)]) if isinstance(steps, list) else "1. —"
    return f"""
### DIY Gift — {name}

{ov}

**Estimated cost:** ${cost} · **Estimated time:** {mins} min

**Materials needed:**
{mats_md}

**Step-by-step:**
{steps_md}
"""
|
| 644 |
+
|
| 645 |
# --------------------- Gradio UI ---------------------
|
| 646 |
CSS = """
|
| 647 |
#examples { order: 1; }
|
|
|
|
| 651 |
with gr.Blocks(css=CSS) as demo:
|
| 652 |
gr.Markdown(TITLE)
|
| 653 |
|
|
|
|
| 654 |
with gr.Column(elem_id="examples"):
|
| 655 |
gr.Markdown("### Quick examples")
|
| 656 |
|
| 657 |
with gr.Column(elem_id="form"):
|
| 658 |
with gr.Row():
|
| 659 |
+
recipient_name = gr.Textbox(label="Recipient name", value="Rotem")
|
| 660 |
+
relationship = gr.Dropdown(label="Relationship", choices=RECIPIENT_RELATIONSHIPS, value="Romantic partner")
|
| 661 |
|
| 662 |
with gr.Row():
|
| 663 |
interests = gr.CheckboxGroup(
|
| 664 |
label="Interests (select a few)", choices=INTEREST_OPTIONS,
|
| 665 |
+
value=["Reading","Fashion","Home decor"], interactive=True
|
| 666 |
)
|
| 667 |
|
| 668 |
with gr.Row():
|
| 669 |
+
occasion = gr.Dropdown(label="Occasion", choices=OCCASION_UI, value="Valentine’s Day")
|
| 670 |
age = gr.Dropdown(label="Age group", choices=list(AGE_OPTIONS.keys()), value="adult (18–64)")
|
| 671 |
+
gender = gr.Dropdown(label="Recipient gender", choices=GENDER_OPTIONS, value="female")
|
| 672 |
|
|
|
|
| 673 |
RangeSlider = getattr(gr, "RangeSlider", None)
|
| 674 |
if RangeSlider is not None:
|
| 675 |
+
budget_range = RangeSlider(label="Budget range (USD)", minimum=5, maximum=500, step=1, value=[30, 60])
|
| 676 |
budget_min, budget_max = None, None
|
| 677 |
else:
|
| 678 |
with gr.Row():
|
| 679 |
+
budget_min = gr.Slider(label="Min budget (USD)", minimum=5, maximum=500, step=1, value=30)
|
| 680 |
budget_max = gr.Slider(label="Max budget (USD)", minimum=5, maximum=500, step=1, value=60)
|
| 681 |
budget_range = gr.State(value=None)
|
| 682 |
|
| 683 |
+
tone = gr.Dropdown(label="Message tone", choices=MESSAGE_TONES, value="Romantic")
|
| 684 |
|
| 685 |
go = gr.Button("Get GIfty 🎯")
|
| 686 |
|
| 687 |
out_top3 = gr.HTML(label="Top-3 recommendations")
|
| 688 |
+
out_diy_json = gr.JSON(label="DIY Gift (JSON)")
|
| 689 |
+
out_diy_md = gr.Markdown(label="DIY Gift (readable)")
|
| 690 |
+
out_gen_img = gr.Image(label="DIY Gift image", type="pil")
|
| 691 |
+
out_msg = gr.Markdown(label="Personalized message")
|
| 692 |
|
| 693 |
# examples (render on top via CSS)
|
| 694 |
if RangeSlider:
|
| 695 |
example_inputs = [interests, occasion, budget_range, recipient_name, relationship, age, gender, tone]
|
| 696 |
EXAMPLES = [
|
| 697 |
+
[["Reading","Fashion","Home decor"], "Valentine’s Day", [30,60], "Rotem", "Romantic partner", "adult (18–64)", "female", "Romantic"],
|
| 698 |
+
[["Technology","Movies"], "Birthday", [25,45], "Daniel", "Friend", "adult (18–64)", "male", "Funny"],
|
| 699 |
[["Gaming","Photography"], "Birthday", [30,120], "Omer", "Family - Sibling", "teen (13–17)", "male", "Playful"],
|
| 700 |
+
[["Home decor","Cooking"], "Housewarming", [25,45], "Noa", "Neighbor", "adult (18–64)", "any", "Appreciative"],
|
| 701 |
]
|
| 702 |
else:
|
| 703 |
example_inputs = [interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone]
|
| 704 |
EXAMPLES = [
|
| 705 |
+
[["Reading","Fashion","Home decor"], "Valentine’s Day", 30, 60, "Rotem", "Romantic partner", "adult (18–64)", "female", "Romantic"],
|
| 706 |
+
[["Technology","Movies"], "Birthday", 25, 45, "Daniel", "Friend", "adult (18–64)", "male", "Funny"],
|
| 707 |
[["Gaming","Photography"], "Birthday", 30, 120, "Omer", "Family - Sibling", "teen (13–17)", "male", "Playful"],
|
| 708 |
+
[["Home decor","Cooking"], "Housewarming", 25, 45, "Noa", "Neighbor", "adult (18–64)", "any", "Appreciative"],
|
| 709 |
]
|
| 710 |
|
| 711 |
with gr.Column(elem_id="examples"):
|
|
|
|
| 760 |
top3 = recommend_topk(profile, k=3)
|
| 761 |
top3_html = render_top3_html(top3)
|
| 762 |
|
| 763 |
+
# DIY gift (generate-from-scratch, JSON)
|
| 764 |
+
diy_json, diy_status = diy_generate(profile)
|
| 765 |
+
diy_md = render_diy_md(diy_json)
|
| 766 |
+
|
| 767 |
+
# DIY image (optional)
|
| 768 |
+
diy_img = generate_gift_image_from_diy(diy_json)
|
| 769 |
|
| 770 |
# greeting
|
| 771 |
msg = llm_generate_message(profile)
|
| 772 |
|
| 773 |
+
return top3_html, diy_json, diy_md, diy_img, msg
|
| 774 |
|
| 775 |
if RangeSlider:
|
| 776 |
go.click(
|
| 777 |
ui_predict,
|
| 778 |
[interests, occasion, budget_range, recipient_name, relationship, age, gender, tone],
|
| 779 |
+
[out_top3, out_diy_json, out_diy_md, out_gen_img, out_msg]
|
| 780 |
)
|
| 781 |
else:
|
| 782 |
go.click(
|
| 783 |
ui_predict,
|
| 784 |
[interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone],
|
| 785 |
+
[out_top3, out_diy_json, out_diy_md, out_gen_img, out_msg]
|
| 786 |
)
|
| 787 |
|
| 788 |
if __name__ == "__main__":
|