Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,25 @@
|
|
| 1 |
# app.py — GIftyPlus (lean)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import os, re, json, hashlib, pathlib, random
|
| 3 |
from typing import Dict, List, Tuple, Optional, Any
|
| 4 |
import numpy as np, pandas as pd, gradio as gr, torch
|
|
@@ -13,6 +34,8 @@ MAX_ROWS = int(os.getenv("MAX_ROWS", "12000"))
|
|
| 13 |
EMBED_MODEL_ID = os.getenv("EMBED_MODEL_ID", "sentence-transformers/all-MiniLM-L12-v2")
|
| 14 |
|
| 15 |
def resolve_cache_dir():
|
|
|
|
|
|
|
| 16 |
for p in [os.getenv("EMBED_CACHE_DIR"), os.path.join(os.getcwd(), ".gifty_cache"), "/tmp/.gifty_cache"]:
|
| 17 |
if not p: continue
|
| 18 |
pathlib.Path(p).mkdir(parents=True, exist_ok=True)
|
|
@@ -22,6 +45,7 @@ def resolve_cache_dir():
|
|
| 22 |
return os.getcwd()
|
| 23 |
EMBED_CACHE_DIR = resolve_cache_dir()
|
| 24 |
|
|
|
|
| 25 |
INTEREST_OPTIONS = ["Sports","Travel","Cooking","Technology","Music","Art","Reading","Gardening","Fashion","Gaming","Photography","Hiking","Movies","Crafts","Pets","Wellness","Collecting","Food","Home decor","Science"]
|
| 26 |
OCCASION_UI = ["Birthday","Wedding / Engagement","Anniversary","Graduation","New baby","Housewarming","Retirement","Holidays","Valentineโs Day","Promotion / New job","Get well soon"]
|
| 27 |
OCCASION_CANON = {"Birthday":"birthday","Wedding / Engagement":"wedding","Anniversary":"anniversary","Graduation":"graduation","New baby":"new_baby","Housewarming":"housewarming","Retirement":"retirement","Holidays":"holidays","Valentineโs Day":"valentines","Promotion / New job":"promotion","Get well soon":"get_well"}
|
|
@@ -30,9 +54,11 @@ MESSAGE_TONES = ["Formal","Casual","Funny","Heartfelt","Inspirational","Playful"
|
|
| 30 |
AGE_OPTIONS = {"any":"any","kid (3โ12)":"kids","teen (13โ17)":"teens","adult (18โ64)":"adult","senior (65+)":"senior"}
|
| 31 |
GENDER_OPTIONS = ["any","female","male","nonbinary"]
|
| 32 |
|
|
|
|
| 33 |
SYNONYMS = {"sports":["fitness","outdoor","training","yoga","run"],"travel":["luggage","passport","map","trip","vacation"],"cooking":["kitchen","cookware","chef","baking"],"technology":["electronics","gadgets","device","smart","computer"],"music":["audio","headphones","earbuds","speaker","vinyl"],"art":["painting","drawing","sketch","canvas"],"reading":["book","novel","literature"],"gardening":["plants","planter","seeds","garden","indoor"],"fashion":["style","accessory","jewelry"],"gaming":["board game","puzzle","video game","controller"],"photography":["camera","lens","tripod","film"],"hiking":["outdoor","camping","backpack","trek"],"movies":["film","cinema","blu-ray","poster"],"crafts":["diy","handmade","kit","knitting"],"pets":["dog","cat","pet"],"wellness":["relaxation","spa","aromatherapy","self-care"],"collecting":["display","collector","limited edition"],"food":["gourmet","snack","treats","chocolate"],"home decor":["home","decor","wall art","candle"],"science":["lab","experiment","STEM","microscope"]}
|
| 34 |
REL_TO_TOKENS = {"Family - Parent":["parent","family"],"Family - Sibling":["sibling","family"],"Family - Child":["kids","play","family"],"Family - Other relative":["family","relative"],"Friend":["friendly"],"Colleague":["office","work","professional"],"Boss":["executive","professional","premium"],"Romantic partner":["romantic","couple"],"Teacher / Mentor":["teacher","mentor","thank_you"],"Neighbor":["neighbor","housewarming"],"Client / Business partner":["professional","thank_you","premium"]}
|
| 35 |
|
|
|
|
| 36 |
_CURRENCY_RE = re.compile(r"[^\d.,\-]+"); _NUM_RE = re.compile(r"(\d+(?:[.,]\d+)?)"); _RANGE_SEP = re.compile(r"\s*(?:-|โ|โ|to)\s*")
|
| 37 |
def _to_price_usd(x):
|
| 38 |
if pd.isna(x): return np.nan
|
|
@@ -42,6 +68,7 @@ def _to_price_usd(x):
|
|
| 42 |
return float(m.group(1)) if m else np.nan
|
| 43 |
|
| 44 |
def _first_present(df, cands):
|
|
|
|
| 45 |
lower = {c.lower(): c for c in df.columns}
|
| 46 |
for c in cands:
|
| 47 |
if c in df.columns: return c
|
|
@@ -49,6 +76,7 @@ def _first_present(df, cands):
|
|
| 49 |
return None
|
| 50 |
|
| 51 |
def _auto_price_col(df):
|
|
|
|
| 52 |
for c in df.columns:
|
| 53 |
s = df[c]
|
| 54 |
if pd.api.types.is_numeric_dtype(s) and not s.dropna().empty and (s.dropna().between(0.5, 10000)).mean() > .6: return c
|
|
@@ -57,21 +85,25 @@ def _auto_price_col(df):
|
|
| 57 |
return None
|
| 58 |
|
| 59 |
def map_amazon_to_schema(raw: pd.DataFrame) -> pd.DataFrame:
    """Project an Amazon-like product table onto the app's compact schema.

    Source columns are located with ``_first_present`` (and, for price only,
    ``_auto_price_col`` as a fallback) from a fixed list of candidate names.

    Args:
        raw: Product table with arbitrary column names.

    Returns:
        A DataFrame with columns ``name``, ``short_desc``, ``tags``,
        ``price_usd`` and ``image_url``, one row per row of ``raw``.
        ``name`` is truncated to 160 characters and ``short_desc`` to 600;
        ``tags`` is lower-cased with ``|`` replaced by ``", "``;
        ``price_usd`` is passed through ``_to_price_usd``.
    """
    name_c = _first_present(raw, ["product name", "title", "name", "product_title"])
    desc_c = _first_present(raw, ["description", "product_description", "feature", "about"])
    cat_c = _first_present(raw, ["category", "categories", "main_cat", "product_category"])
    price_c = _first_present(
        raw,
        ["selling price", "price", "current_price", "list_price", "price_amount", "actual_price", "price_usd"],
    ) or _auto_price_col(raw)
    img_c = _first_present(raw, ["image", "image_url", "imageurl", "imUrl", "img", "img_url"])

    def _column(col, default):
        # A missing column becomes a constant Series aligned to raw.index.
        # Passing bare scalars for *every* field would make the DataFrame
        # constructor raise ("If using all scalar values, you must pass an index").
        if col is None:
            return pd.Series(default, index=raw.index)
        return raw[col]

    df = pd.DataFrame({
        "name": _column(name_c, ""),
        "short_desc": _column(desc_c, ""),
        "tags": _column(cat_c, ""),
        "price_usd": _column(price_c, np.nan),
        "image_url": _column(img_c, ""),
    })

    df["price_usd"] = df["price_usd"].map(_to_price_usd)
    df["name"] = df["name"].astype(str).str.strip().str.slice(0, 160)
    df["short_desc"] = df["short_desc"].astype(str).str.strip().str.slice(0, 600)
    # regex=False: "|" is a regex alternation operator, so the replacement must
    # be explicitly literal instead of depending on the pandas default.
    df["tags"] = df["tags"].astype(str).str.replace("|", ", ", regex=False).str.lower()
    return df
|
| 67 |
|
| 68 |
def extract_top_cat(tags: str) -> str:
    """Return the leading category of a tag string, lower-cased and trimmed.

    The string is cut at the first ``|`` if one is present, otherwise at the
    first ``>``, otherwise at the first ``,``.

    Args:
        tags: Category/tag text; ``None`` or ``""`` is treated as empty.

    Returns:
        The text before the first separator with surrounding whitespace
        removed, or ``""`` for empty input.
    """
    s = (tags or "").lower()
    for sep in ("|", ">"):
        if sep in s:
            return s.split(sep, 1)[0].strip()
    # Strip *after* splitting so "toys , games" yields "toys", not "toys ".
    return s.split(",", 1)[0].strip()
|
| 73 |
|
| 74 |
def load_catalog()->pd.DataFrame:
|
|
|
|
| 75 |
df=map_amazon_to_schema(load_dataset(DATASET_ID, split=DATASET_SPLIT).to_pandas()).drop_duplicates(subset=["name","short_desc"])
|
| 76 |
df=df[pd.notna(df["price_usd"])]; df=df[(df["price_usd"]>0)&(df["price_usd"]<=500)].reset_index(drop=True)
|
| 77 |
if len(df)>MAX_ROWS: df=df.sample(n=MAX_ROWS,random_state=42).reset_index(drop=True)
|
|
@@ -81,6 +113,9 @@ def load_catalog()->pd.DataFrame:
|
|
| 81 |
return df
|
| 82 |
CATALOG=load_catalog()
|
| 83 |
|
|
|
|
|
|
|
|
|
|
| 84 |
class EmbeddingBank:
|
| 85 |
def __init__(s, docs, model_id, dataset_tag):
|
| 86 |
s.model_id=model_id; s.dataset_tag=dataset_tag; s.model=SentenceTransformer(model_id); s.embs=s._load_or_build(docs)
|
|
@@ -95,10 +130,12 @@ class EmbeddingBank:
|
|
| 95 |
def query_vec(s,text): return s.model.encode([text], convert_to_numpy=True, normalize_embeddings=True)[0]
|
| 96 |
EMB=EmbeddingBank(CATALOG["doc"].tolist(), EMBED_MODEL_ID, DATASET_ID)
|
| 97 |
|
|
|
|
| 98 |
_tok_rx = re.compile(r"[a-z0-9][a-z0-9\-']*")
|
| 99 |
if "tok_set" not in CATALOG.columns:
|
| 100 |
CATALOG["tok_set"]=(CATALOG["name"].fillna("")+" "+CATALOG["tags"].fillna("")+" "+CATALOG["short_desc"].fillna("")).map(lambda t:set(_tok_rx.findall(str(t).lower())))
|
| 101 |
|
|
|
|
| 102 |
try:
|
| 103 |
from sentence_transformers import CrossEncoder
|
| 104 |
except:
|
|
@@ -111,6 +148,7 @@ def _load_cross_encoder():
|
|
| 111 |
_CE_MODEL=CrossEncoder(RERANK_MODEL_ID, device="cpu")
|
| 112 |
return _CE_MODEL
|
| 113 |
|
|
|
|
| 114 |
OCCASION_PRIORS={"valentines":[("jewelry",.12),("chocolate",.10),("candle",.08),("romantic",.08),("couple",.08),("heart",.06)],
|
| 115 |
"birthday":[("fun",.06),("game",.06),("personalized",.06),("gift set",.05),("surprise",.04)],
|
| 116 |
"anniversary":[("couple",.10),("jewelry",.10),("photo",.08),("frame",.06),("memory",.06),("candle",.06)],
|
|
@@ -123,6 +161,7 @@ OCCASION_PRIORS={"valentines":[("jewelry",.12),("chocolate",.10),("candle",.08),
|
|
| 123 |
"get_well":[("cozy",.10),("tea",.08),("soothing",.06),("care",.06)]}
|
| 124 |
|
| 125 |
def expand_with_synonyms(tokens: List[str])->List[str]:
|
|
|
|
| 126 |
out=[];
|
| 127 |
for t in tokens:
|
| 128 |
t=t.strip().lower()
|
|
@@ -130,12 +169,14 @@ def expand_with_synonyms(tokens: List[str])->List[str]:
|
|
| 130 |
return out
|
| 131 |
|
| 132 |
def profile_to_query(p: Dict) -> str:
    """Build the retrieval query string for a recipient profile.

    The query is a ``" | "``-joined sequence of: the synonym-expanded interests
    (repeated three times), the relationship tokens from ``REL_TO_TOKENS``, the
    canonical occasion from ``OCCASION_CANON``, and a natural-language tail
    sentence. Empty leading segments are dropped; the tail is always present.

    Args:
        p: Profile dict; reads ``interests``, ``relationship`` and ``occ_ui``.

    Returns:
        The assembled query text.
    """
    interests = [item.lower() for item in p.get("interests", []) if item]
    weighted = expand_with_synonyms(interests) * 3
    relationship = p.get("relationship", "Friend")
    occasion = OCCASION_CANON.get(p.get("occ_ui", "Birthday"), "birthday")

    head = [
        ", ".join(weighted),
        ", ".join(REL_TO_TOKENS.get(relationship, [])),
        occasion,
    ]
    likes = ", ".join(interests) or "general"
    tail = f"gift ideas for a {relationship} for {occasion}; likes {likes}"
    return " | ".join([segment for segment in head if segment]) + " | " + tail
|
| 137 |
|
| 138 |
def _gender_ok_mask(g:str)->np.ndarray:
|
|
|
|
| 139 |
g=(g or "any").lower(); bl=CATALOG["blob"]
|
| 140 |
has_m=bl.str.contains(r"\b(men|man's|mens|male|for men)\b",regex=True,na=False)
|
| 141 |
has_f=bl.str.contains(r"\b(women|woman's|womens|female|for women|dress)\b",regex=True,na=False)
|
|
@@ -145,6 +186,7 @@ def _gender_ok_mask(g:str)->np.ndarray:
|
|
| 145 |
return np.ones(len(bl),bool)
|
| 146 |
|
| 147 |
def _mask_by_age(age:str, blob:pd.Series)->np.ndarray:
|
|
|
|
| 148 |
kids=blob.str.contains(r"\b(?:kid|kids|child|children|toddler|baby|boys?|girls?|kid's|children's)\b",regex=True,na=False)
|
| 149 |
teen=blob.str.contains(r"\b(?:teen|teens|young adult|ya)\b",regex=True,na=False)
|
| 150 |
if age in ("adult","senior"): return (~kids).to_numpy()
|
|
@@ -153,11 +195,13 @@ def _mask_by_age(age:str, blob:pd.Series)->np.ndarray:
|
|
| 153 |
return np.ones(len(blob),bool)
|
| 154 |
|
| 155 |
def _interest_bonus(p: Dict, idx: np.ndarray) -> np.ndarray:
    """Compute a lexical bonus per candidate from interest/synonym token overlap.

    Args:
        p: Profile dict; only ``interests`` is read.
        idx: Positions into ``CATALOG`` (used with ``.iat``) of the candidates.

    Returns:
        A float32 array, one entry per element of ``idx``: 0.10 per matching
        token, with the match count clipped to the range 0..6. All zeros when
        there are no interests or no candidates.
    """
    interests = [i.lower() for i in p.get("interests", []) if i]
    # CATALOG["tok_set"] is built from lower-cased text, so a synonym such as
    # "STEM" has to be lower-cased here or it can never match.
    synonyms = [syn.lower() for it in interests for syn in SYNONYMS.get(it, [])]
    vocab = set(interests + synonyms)
    if not vocab or idx.size == 0:
        return np.zeros(len(idx), "float32")
    # NOTE(review): multi-word entries ("board game", "wall art") still cannot
    # match single-token sets; left unchanged to avoid altering the scoring.
    tok_sets = CATALOG["tok_set"]  # hoisted: one column lookup instead of one per row
    counts = np.array([len(tok_sets.iat[i] & vocab) for i in idx], "float32")
    return .10 * np.clip(counts, 0, 6)
|
| 159 |
|
| 160 |
def _occasion_bonus(idx:np.ndarray, occ_ui:str)->np.ndarray:
|
|
|
|
| 161 |
pri=OCCASION_PRIORS.get(OCCASION_CANON.get(occ_ui or "Birthday","birthday"),[])
|
| 162 |
if not pri or idx.size==0: return np.zeros(len(idx),"float32")
|
| 163 |
bl=CATALOG["blob"].to_numpy(); out=np.zeros(len(idx),"float32")
|
|
@@ -166,11 +210,13 @@ def _occasion_bonus(idx:np.ndarray, occ_ui:str)->np.ndarray:
|
|
| 166 |
return out
|
| 167 |
|
| 168 |
def _minmax(x:np.ndarray)->np.ndarray:
|
|
|
|
| 169 |
if x.size==0: return x
|
| 170 |
lo,hi=float(np.min(x)),float(np.max(x));
|
| 171 |
return np.zeros_like(x) if hi<=lo+1e-9 else (x-lo)/(hi-lo)
|
| 172 |
|
| 173 |
def _mmr_select(cand_idx:np.ndarray, scores:np.ndarray, k:int, lambda_:float=.7)->np.ndarray:
|
|
|
|
| 174 |
if cand_idx.size<=k: return cand_idx[np.argsort(-scores)][:k]
|
| 175 |
picked=[]; rest=list(range(len(cand_idx))); rel=_minmax(scores)
|
| 176 |
V=np.asarray(EMB.embs,"float32")[cand_idx]; V/=np.linalg.norm(V,axis=1,keepdims=True)+1e-8
|
|
@@ -180,44 +226,154 @@ def _mmr_select(cand_idx:np.ndarray, scores:np.ndarray, k:int, lambda_:float=.7)
|
|
| 180 |
j=int(np.argmax(lambda_*rel[rest]-(1-lambda_)*sim_to_sel)); picked.append(rest.pop(j))
|
| 181 |
return cand_idx[np.array(picked,int)]
|
| 182 |
|
| 183 |
-
def recommend_top3_budget_first(
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
q=profile_to_query(p); qv=EMB.query_vec(q).astype("float32")
|
| 195 |
emb_sims=np.asarray(EMB.embs,"float32")[idx]@qv
|
| 196 |
target=(lo+hi)/2.0 if hi>lo else hi; prices=CATALOG.iloc[idx]["price_usd"].to_numpy()
|
|
|
|
| 197 |
price_bonus=np.clip(.12-np.abs(prices-target)/max(target,1.0),0,.12).astype("float32")
|
| 198 |
int_bonus=_interest_bonus(p,idx); occ_bonus=_occasion_bonus(idx,p.get("occ_ui","Birthday"))
|
| 199 |
pre=emb_sims+price_bonus+int_bonus+occ_bonus
|
|
|
|
| 200 |
K1=min(48,idx.size); top_local=np.argpartition(-pre,K1-1)[:K1]; cand_idx=idx[top_local]
|
| 201 |
emb_n=_minmax(emb_sims[top_local]); price_n=_minmax(price_bonus[top_local]); int_n=_minmax(int_bonus[top_local]); occ_n=_minmax(occ_bonus[top_local])
|
| 202 |
ce=_load_cross_encoder();
|
| 203 |
if ce is not None:
|
|
|
|
| 204 |
docs=CATALOG.loc[cand_idx,"doc"].tolist(); pairs=[(q,d) for d in docs]
|
| 205 |
k_ce=min(24,len(pairs)); tl=np.argpartition(-emb_n,k_ce-1)[:k_ce]; ce_raw=np.array(ce.predict([pairs[i] for i in tl]),"float32"); ce_n=np.zeros_like(emb_n); ce_n[tl]=_minmax(ce_raw)
|
| 206 |
else:
|
| 207 |
ce_n=np.zeros_like(emb_n)
|
|
|
|
| 208 |
final=(.56*emb_n+.26*ce_n+.10*int_n+.05*occ_n+.03*price_n).astype("float32")
|
| 209 |
pick=_mmr_select(cand_idx,final,k=min(3,cand_idx.size))
|
| 210 |
res=CATALOG.loc[pick].copy(); pos={int(cand_idx[i]):i for i in range(len(cand_idx))}; res["similarity"]=[float(final[pos[int(i)]]) for i in pick]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
return res[["name","short_desc","price_usd","image_url","similarity"]].reset_index(drop=True)
|
| 212 |
|
| 213 |
# ===== DIY (FLAN-only) =====
|
| 214 |
DIY_MODEL_ID=os.getenv("DIY_MODEL_ID","google/flan-t5-small"); DIY_DEVICE=torch.device("cpu")
|
| 215 |
MAX_INPUT_TOKENS=int(os.getenv("MAX_INPUT_TOKENS","384")); DIY_MAX_NEW_TOKENS=int(os.getenv("DIY_MAX_NEW_TOKENS","120"))
|
|
|
|
| 216 |
INTEREST_ALIASES={"Reading":["book","novel","literary"],"Fashion":["style","chic","silk"],"Home decor":["candle","wall","jar"],"Technology":["tech","gadget","usb"],"Movies":["film","cinema","poster"]}
|
| 217 |
FALLBACK_NOUNS=["Kit","Set","Bundle","Box","Pack"]
|
| 218 |
|
| 219 |
_diy_cache_model={}
|
| 220 |
def _load_flan(mid:str):
|
|
|
|
| 221 |
if mid in _diy_cache_model: return _diy_cache_model[mid]
|
| 222 |
tok=AutoTokenizer.from_pretrained(mid, use_fast=True, trust_remote_code=True)
|
| 223 |
mdl=AutoModelForSeq2SeqLM.from_pretrained(mid, trust_remote_code=True, use_safetensors=True).to(DIY_DEVICE).eval()
|
|
@@ -225,6 +381,7 @@ def _load_flan(mid:str):
|
|
| 225 |
|
| 226 |
@torch.inference_mode()
|
| 227 |
def _gen(tok, mdl, prompt, max_new_tokens=64, do_sample=False, temperature=.9, top_p=.95, seed=None):
|
|
|
|
| 228 |
if seed is None: seed=random.randint(1,10_000_000)
|
| 229 |
random.seed(seed); torch.manual_seed(seed)
|
| 230 |
enc=tok(prompt, truncation=True, max_length=MAX_INPUT_TOKENS, return_tensors="pt"); enc={k:v.to(DIY_DEVICE) for k,v in enc.items()}
|
|
@@ -232,11 +389,13 @@ def _gen(tok, mdl, prompt, max_new_tokens=64, do_sample=False, temperature=.9, t
|
|
| 232 |
return tok.decode(out[0], skip_special_tokens=True).strip()
|
| 233 |
|
| 234 |
def _choose_interest_token(interests):
    """Pick one lower-case keyword representing the recipient's interests.

    The first interest that has a non-empty entry in ``INTEREST_ALIASES``
    wins, and one of its aliases is chosen at random. Otherwise the first word
    of the first interest is used, falling back to ``"gift"``.

    Args:
        interests: Sequence of interest labels; may be empty or ``None``.

    Returns:
        A single keyword string.
    """
    interests = interests or []
    for it in interests:
        aliases = INTEREST_ALIASES.get(it)
        if aliases:
            return random.choice(aliases)
    # Guard against a blank or None first entry, which previously raised
    # IndexError / AttributeError instead of using the "gift" fallback.
    words = interests[0].split() if interests and interests[0] else []
    return words[0].lower() if words else "gift"
|
| 238 |
# Collapse whitespace runs to single spaces, delete every character in the bracketed class
# (a straight double quote plus what appear to be mis-encoded curly quotes -- NOTE(review): confirm
# against the original file), then capitalize each word (str.capitalize lower-cases the remainder).
def _title_case(s): s=re.sub(r'\s+',' ',s).strip(); s=re.sub(r'["โโโโ]+','',s); return " ".join([w.capitalize() for w in s.split()])
|
| 239 |
def _sanitize_name(name, interests):
|
|
|
|
| 240 |
for b in [r"^the name\b",r"\bmember of the family\b",r"^name\b",r"^title\b"]: name=re.sub(b,"",name,flags=re.I).strip()
|
| 241 |
name=re.sub(r'[:\-โโ]+$',"",name).strip(); alias=_choose_interest_token(interests)
|
| 242 |
if alias not in name.lower():
|
|
@@ -247,6 +406,7 @@ def _sanitize_name(name, interests):
|
|
| 247 |
return name
|
| 248 |
|
| 249 |
def _split_list_text(s,seps):
|
|
|
|
| 250 |
s=s.strip()
|
| 251 |
for sep in seps:
|
| 252 |
if sep in s:
|
|
@@ -255,6 +415,7 @@ def _split_list_text(s,seps):
|
|
| 255 |
return [p.strip(" -โข*.,;:") for p in re.split(r"[\n\r;]+", s) if p.strip(" -โข*.,;:")]
|
| 256 |
|
| 257 |
def _coerce_materials(items):
|
|
|
|
| 258 |
out=[]
|
| 259 |
for it in items:
|
| 260 |
it=re.sub(r'\s+',' ',it).strip(" -โข*.,;:");
|
|
@@ -271,6 +432,7 @@ def _coerce_materials(items):
|
|
| 271 |
return out[:8]
|
| 272 |
|
| 273 |
def _coerce_steps(items):
|
|
|
|
| 274 |
out=[]
|
| 275 |
for it in items:
|
| 276 |
it=it.strip(" -โข*.,;:");
|
|
@@ -284,10 +446,12 @@ def _coerce_steps(items):
|
|
| 284 |
|
| 285 |
def _only_int(s): m=re.search(r"-?\d+",s); return int(m.group()) if m else None
|
| 286 |
def _clamp_num(v,lo,hi,default):
|
|
|
|
| 287 |
try: x=float(v); return int(min(max(x,lo),hi))
|
| 288 |
except: return int((lo+hi)/2 if default is None else default)
|
| 289 |
|
| 290 |
def diy_generate(profile:Dict)->Tuple[dict,str]:
|
|
|
|
| 291 |
tok,mdl=_load_flan(DIY_MODEL_ID)
|
| 292 |
p={"recipient_name":profile.get("recipient_name","Recipient"),"relationship":profile.get("relationship","Friend"),
|
| 293 |
"occ_ui":profile.get("occ_ui","Birthday"),"occasion":profile.get("occ_ui","Birthday"),"interests":profile.get("interests",[]),
|
|
@@ -311,8 +475,33 @@ def diy_generate(profile:Dict)->Tuple[dict,str]:
|
|
| 311 |
"estimated_cost_usd":_clamp_num(cost,p["budget_min"],p["budget_max"],None),"estimated_time_minutes":_clamp_num(minutes,20,180,60)}
|
| 312 |
return idea,"ok"
|
| 313 |
|
| 314 |
-
|
| 315 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
MSG_MODEL_ID = "google/flan-t5-small"
|
| 317 |
MSG_DEVICE = "cpu"
|
| 318 |
TEMP_RANGE = (0.88, 1.10)
|
|
@@ -414,6 +603,7 @@ CLOSERS = [
|
|
| 414 |
]
|
| 415 |
|
| 416 |
def _msg_load():
|
|
|
|
| 417 |
global _msg_tok, _msg_mdl
|
| 418 |
if _msg_tok is None or _msg_mdl is None:
|
| 419 |
_msg_tok = AutoTokenizer.from_pretrained(MSG_MODEL_ID)
|
|
@@ -422,16 +612,20 @@ def _msg_load():
|
|
| 422 |
return _msg_tok, _msg_mdl
|
| 423 |
|
| 424 |
def _norm(s: str) -> str:
|
|
|
|
| 425 |
return re.sub(r"\s+", " ", s or "").strip()
|
| 426 |
|
| 427 |
def _sentences_n(s: str) -> int:
|
|
|
|
| 428 |
return len([p for p in re.split(r"(?<=[.!?])\s+", s.strip()) if p])
|
| 429 |
|
| 430 |
def _contains_any(text: str, terms: List[str]) -> bool:
|
|
|
|
| 431 |
t = text.lower()
|
| 432 |
return any(term for term in terms if term) and any((term or "").lower() in t for term in terms)
|
| 433 |
|
| 434 |
def _too_similar(a: str, b: str, n=3, thr=0.85) -> bool:
|
|
|
|
| 435 |
def ngrams(txt):
|
| 436 |
toks = re.findall(r"[a-zA-Z']+", txt.lower())
|
| 437 |
return set(tuple(toks[i:i+n]) for i in range(max(0, len(toks)-n+1)))
|
|
@@ -441,9 +635,11 @@ def _too_similar(a: str, b: str, n=3, thr=0.85) -> bool:
|
|
| 441 |
return j >= thr
|
| 442 |
|
| 443 |
def _clean_occasion(occ: str) -> str:
|
|
|
|
| 444 |
return (occ or "").replace("โ","'").strip()
|
| 445 |
|
| 446 |
def _build_prompt(profile: Dict[str, Any]) -> Tuple[str, Dict[str,str]]:
|
|
|
|
| 447 |
name = profile.get("recipient_name", "Friend")
|
| 448 |
rel = profile.get("relationship", "Friend")
|
| 449 |
occ = _clean_occasion(profile.get("occ_ui") or profile.get("occasion") or "Birthday")
|
|
@@ -476,6 +672,7 @@ def _build_prompt(profile: Dict[str, Any]) -> Tuple[str, Dict[str,str]]:
|
|
| 476 |
|
| 477 |
@torch.inference_mode()
|
| 478 |
def generate_personal_message(profile: Dict[str, Any], seed: Optional[int]=None, previous_message: Optional[str]=None) -> Dict[str, Any]:
|
|
|
|
| 479 |
global _last_msg
|
| 480 |
tok, mdl = _msg_load()
|
| 481 |
if seed is None:
|
|
@@ -501,7 +698,7 @@ def generate_personal_message(profile: Dict[str, Any], seed: Optional[int]=None,
|
|
| 501 |
)
|
| 502 |
text = _norm(tok.decode(out_ids[0], skip_special_tokens=True))
|
| 503 |
|
| 504 |
-
# ===== Validators (
|
| 505 |
ok_len = 1 <= _sentences_n(text) <= 3
|
| 506 |
name_ok = _contains_any(text, [need["name"].lower()])
|
| 507 |
occ_ok = _contains_any(text, [need["occ"].lower(), need["occ"].split()[0].lower()])
|
|
@@ -516,6 +713,7 @@ def generate_personal_message(profile: Dict[str, Any], seed: Optional[int]=None,
|
|
| 516 |
"seed": seed, "attempt": attempt, "model": MSG_MODEL_ID}}
|
| 517 |
tried.append({"text": text}); seed += 17
|
| 518 |
|
|
|
|
| 519 |
fallback = tried[-1]["text"] if tried else f"Happy {(_clean_occasion(profile.get('occ_ui') or 'day')).lower()}, {profile.get('recipient_name','Friend')}!"
|
| 520 |
_last_msg = fallback
|
| 521 |
return {"message": fallback, "meta": {"failed": True, "model": MSG_MODEL_ID, "tone": profile.get("tone","Heartfelt")}}
|
|
@@ -524,23 +722,26 @@ def generate_personal_message(profile: Dict[str, Any], seed: Optional[int]=None,
|
|
| 524 |
|
| 525 |
# ===== Rendering & UI =====
|
| 526 |
def first_sentence(s,max_chars=140):
|
|
|
|
| 527 |
s=(s or "").strip();
|
| 528 |
if not s: return ""
|
| 529 |
cut=s.split(". ")[0];
|
| 530 |
return cut if len(cut)<=max_chars else cut[:max_chars-1]+"โฆ"
|
| 531 |
|
| 532 |
def render_top3_html(df, age_label):
|
|
|
|
| 533 |
if df is None or df.empty: return "<em>No results found within the current filters.</em>"
|
| 534 |
rows=[]
|
| 535 |
-
for
|
| 536 |
name=str(r.get("name","")).replace("|","\\|").replace("*","\\*").replace("_","\\_")
|
| 537 |
desc=str(first_sentence(r.get("short_desc",""))).replace("|","\\|").replace("*","\\*").replace("_","\\_")
|
| 538 |
price=r.get("price_usd"); sim=r.get("similarity"); img=r.get("image_url","") or ""
|
| 539 |
price_str=f"${price:.0f}" if pd.notna(price) else "N/A"; sim_str=f"{sim:.3f}" if pd.notna(sim) else "โ"
|
| 540 |
img_html=f'<img src="{img}" alt="" style="width:84px;height:84px;object-fit:cover;border-radius:10px;margin-left:12px;" />' if img else ""
|
|
|
|
| 541 |
rows.append(f"""
|
| 542 |
<div style="display:flex;align-items:flex-start;justify-content:space-between;gap:10px;padding:10px;border:1px solid #eee;border-radius:12px;margin-bottom:8px;background:#fff;">
|
| 543 |
-
<div style="flex:1;min-width:0;"><div style="font-weight:700;">{name}</div>
|
| 544 |
<div style="font-size:0.95em;margin-top:4px;">{desc}</div>
|
| 545 |
<div style="font-size:0.9em;margin-top:6px;opacity:0.8;">Price: <b>{price_str}</b> ยท Age: <code>{age_label}</code> ยท Score: <code>{sim_str}</code></div>
|
| 546 |
</div>{img_html}
|
|
@@ -586,17 +787,20 @@ with gr.Blocks(title="๐ GIfty โ Recommender + DIY", css="""
|
|
| 586 |
tone=gr.Dropdown(label="Message tone", choices=MESSAGE_TONES, value="Funny")
|
| 587 |
|
| 588 |
go=gr.Button("Get GIfty!")
|
|
|
|
| 589 |
gr.Markdown("### ๐ฏ Recommendations"); out_top3=gr.HTML()
|
| 590 |
gr.Markdown("### ๐ ๏ธ DIY Gift"); out_diy_md=gr.Markdown()
|
| 591 |
gr.Markdown("### ๐ Personalized Message"); out_msg=gr.Markdown()
|
| 592 |
run_token=gr.State(0)
|
| 593 |
|
| 594 |
def _on_example_select(evt: gr.SelectData):
|
|
|
|
| 595 |
r=int(evt.index[0] if isinstance(evt.index,(list,tuple)) else evt.index); row=EX_DF.iloc[r]; ints=[s.strip() for s in str(row["Interests"]).split("+")]
|
| 596 |
return (ints,row["Occasion"],int(row["Min $"]),int(row["Max $"]),row["Recipient"],row["Relationship"],row["Age group"],row["Gender"],row["Tone"])
|
| 597 |
ex_df.select(_on_example_select, outputs=[interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone])
|
| 598 |
|
| 599 |
def render_diy_md(j:dict)->str:
|
|
|
|
| 600 |
if not j: return "_DIY generation failed._"
|
| 601 |
steps=j.get('step_by_step_instructions', j.get('steps', []))
|
| 602 |
parts = [
|
|
@@ -607,26 +811,105 @@ with gr.Blocks(title="๐ GIfty โ Recommender + DIY", css="""
|
|
| 607 |
f"**Estimated cost:** ${j.get('estimated_cost_usd','?')} ยท **Time:** {j.get('estimated_time_minutes','?')} min"
|
| 608 |
]
|
| 609 |
return "\n".join(parts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 610 |
|
| 611 |
def _build_profile(ints, occ, bmin, bmax, name, rel, age_label, gender_val, tone_val):
    """Assemble the profile dict from the raw form values.

    Args:
        ints: Selected interests (falsy becomes ``[]``).
        occ: Occasion label (falsy becomes ``"Birthday"``).
        bmin: Minimum budget; converted with ``float``.
        bmax: Maximum budget; converted with ``float``.
        name: Recipient name (falsy becomes ``"Friend"``).
        rel: Relationship label (falsy becomes ``"Friend"``).
        age_label: Age label, mapped through ``AGE_OPTIONS`` (unknown -> ``"any"``).
        gender_val: Gender value, lower-cased (falsy becomes ``"any"``).
        tone_val: Message tone (falsy becomes ``"Heartfelt"``).

    Returns:
        Dict with keys ``recipient_name``, ``relationship``, ``interests``,
        ``occ_ui``, ``budget_min``, ``budget_max``, ``age_range``, ``gender``
        and ``tone``.
    """
    try:
        bmin = float(bmin)
        bmax = float(bmax)
    # If either bound is missing or non-numeric, both revert to the defaults.
    # Only conversion errors are caught; a bare except would also trap
    # KeyboardInterrupt and SystemExit.
    except (TypeError, ValueError):
        bmin, bmax = 5.0, 500.0
    if bmin > bmax:
        # Tolerate bounds entered in reverse order.
        bmin, bmax = bmax, bmin
    return {
        "recipient_name": name or "Friend",
        "relationship": rel or "Friend",
        "interests": ints or [],
        "occ_ui": occ or "Birthday",
        "budget_min": bmin,
        "budget_max": bmax,
        "age_range": AGE_OPTIONS.get(age_label, "any"),
        "gender": (gender_val or "any").lower(),
        "tone": tone_val or "Heartfelt",
    }
|
| 616 |
|
| 617 |
-
def start_run(curr):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 618 |
|
| 619 |
def predict_recs_only(rt, *args):
|
| 620 |
-
p=_build_profile(*args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
def predict_diy_only(rt, *args):
|
| 622 |
-
p=_build_profile(*args)
|
| 623 |
-
|
| 624 |
-
|
| 625 |
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 630 |
|
| 631 |
if __name__=="__main__":
|
| 632 |
demo.launch()
|
|
|
|
| 1 |
# app.py — GIftyPlus (lean)
|
| 2 |
+
# -----------------------------------------------------------------------------
|
| 3 |
+
# High-level overview
|
| 4 |
+
# -----------------------------------------------------------------------------
|
| 5 |
+
# GIftyPlus is a lightweight gift recommender + DIY generator.
|
| 6 |
+
# Pipeline:
|
| 7 |
+
# 1) Load & normalize an Amazon-like product dataset (name/desc/tags/price/img).
|
| 8 |
+
# 2) Build sentence embeddings for semantic retrieval (cached to .npy).
|
| 9 |
+
# 3) Rank items with a weighted score (embeddings + optional cross-encoder +
|
| 10 |
+
# interest/occasion/price bonuses) and diversify with MMR.
|
| 11 |
+
# 4) Generate a DIY gift idea (FLAN-T5), then embed 10 candidates and append
|
| 12 |
+
# the best one as a "Generated" #4 result.
|
| 13 |
+
# 5) Generate a short personalized message (FLAN-T5) with basic validators.
|
| 14 |
+
# 6) Gradio UI: input form, input summary, top-3 + generated #4, DIY section,
|
| 15 |
+
# and personalized message section.
|
| 16 |
+
#
|
| 17 |
+
# Env vars you can override:
|
| 18 |
+
# DATASET_ID, DATASET_SPLIT, MAX_ROWS,
|
| 19 |
+
# EMBED_MODEL_ID, RERANK_MODEL_ID,
|
| 20 |
+
# DIY_MODEL_ID, MAX_INPUT_TOKENS, DIY_MAX_NEW_TOKENS.
|
| 21 |
+
# -----------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
import os, re, json, hashlib, pathlib, random
|
| 24 |
from typing import Dict, List, Tuple, Optional, Any
|
| 25 |
import numpy as np, pandas as pd, gradio as gr, torch
|
|
|
|
| 34 |
EMBED_MODEL_ID = os.getenv("EMBED_MODEL_ID", "sentence-transformers/all-MiniLM-L12-v2")
|
| 35 |
|
| 36 |
def resolve_cache_dir():
|
| 37 |
+
# Choose the first writable cache directory:
|
| 38 |
+
# 1) EMBED_CACHE_DIR env, 2) project .gifty_cache, 3) /tmp/.gifty_cache
|
| 39 |
for p in [os.getenv("EMBED_CACHE_DIR"), os.path.join(os.getcwd(), ".gifty_cache"), "/tmp/.gifty_cache"]:
|
| 40 |
if not p: continue
|
| 41 |
pathlib.Path(p).mkdir(parents=True, exist_ok=True)
|
|
|
|
| 45 |
return os.getcwd()
|
| 46 |
EMBED_CACHE_DIR = resolve_cache_dir()
|
| 47 |
|
| 48 |
+
# UI vocab / options
|
| 49 |
INTEREST_OPTIONS = ["Sports","Travel","Cooking","Technology","Music","Art","Reading","Gardening","Fashion","Gaming","Photography","Hiking","Movies","Crafts","Pets","Wellness","Collecting","Food","Home decor","Science"]
|
| 50 |
OCCASION_UI = ["Birthday","Wedding / Engagement","Anniversary","Graduation","New baby","Housewarming","Retirement","Holidays","Valentineโs Day","Promotion / New job","Get well soon"]
|
| 51 |
OCCASION_CANON = {"Birthday":"birthday","Wedding / Engagement":"wedding","Anniversary":"anniversary","Graduation":"graduation","New baby":"new_baby","Housewarming":"housewarming","Retirement":"retirement","Holidays":"holidays","Valentineโs Day":"valentines","Promotion / New job":"promotion","Get well soon":"get_well"}
|
|
|
|
| 54 |
AGE_OPTIONS = {"any":"any","kid (3โ12)":"kids","teen (13โ17)":"teens","adult (18โ64)":"adult","senior (65+)":"senior"}
|
| 55 |
GENDER_OPTIONS = ["any","female","male","nonbinary"]
|
| 56 |
|
| 57 |
+
# Light synonym expansion for interests; used to enrich queries and "hit" checks
|
| 58 |
SYNONYMS = {"sports":["fitness","outdoor","training","yoga","run"],"travel":["luggage","passport","map","trip","vacation"],"cooking":["kitchen","cookware","chef","baking"],"technology":["electronics","gadgets","device","smart","computer"],"music":["audio","headphones","earbuds","speaker","vinyl"],"art":["painting","drawing","sketch","canvas"],"reading":["book","novel","literature"],"gardening":["plants","planter","seeds","garden","indoor"],"fashion":["style","accessory","jewelry"],"gaming":["board game","puzzle","video game","controller"],"photography":["camera","lens","tripod","film"],"hiking":["outdoor","camping","backpack","trek"],"movies":["film","cinema","blu-ray","poster"],"crafts":["diy","handmade","kit","knitting"],"pets":["dog","cat","pet"],"wellness":["relaxation","spa","aromatherapy","self-care"],"collecting":["display","collector","limited edition"],"food":["gourmet","snack","treats","chocolate"],"home decor":["home","decor","wall art","candle"],"science":["lab","experiment","STEM","microscope"]}
|
| 59 |
REL_TO_TOKENS = {"Family - Parent":["parent","family"],"Family - Sibling":["sibling","family"],"Family - Child":["kids","play","family"],"Family - Other relative":["family","relative"],"Friend":["friendly"],"Colleague":["office","work","professional"],"Boss":["executive","professional","premium"],"Romantic partner":["romantic","couple"],"Teacher / Mentor":["teacher","mentor","thank_you"],"Neighbor":["neighbor","housewarming"],"Client / Business partner":["professional","thank_you","premium"]}
|
| 60 |
|
| 61 |
+
# --- Price parsing helpers (robust to currency symbols and ranges) ---
|
| 62 |
_CURRENCY_RE = re.compile(r"[^\d.,\-]+"); _NUM_RE = re.compile(r"(\d+(?:[.,]\d+)?)"); _RANGE_SEP = re.compile(r"\s*(?:-|โ|โ|to)\s*")
|
| 63 |
def _to_price_usd(x):
|
| 64 |
if pd.isna(x): return np.nan
|
|
|
|
| 68 |
return float(m.group(1)) if m else np.nan
|
| 69 |
|
| 70 |
def _first_present(df, cands):
|
| 71 |
+
# Return the first column name that exists in df out of candidates (case-insensitive)
|
| 72 |
lower = {c.lower(): c for c in df.columns}
|
| 73 |
for c in cands:
|
| 74 |
if c in df.columns: return c
|
|
|
|
| 76 |
return None
|
| 77 |
|
| 78 |
def _auto_price_col(df):
|
| 79 |
+
# Heuristics for price column detection when column name is unknown
|
| 80 |
for c in df.columns:
|
| 81 |
s = df[c]
|
| 82 |
if pd.api.types.is_numeric_dtype(s) and not s.dropna().empty and (s.dropna().between(0.5, 10000)).mean() > .6: return c
|
|
|
|
| 85 |
return None
|
| 86 |
|
| 87 |
def map_amazon_to_schema(raw: pd.DataFrame) -> pd.DataFrame:
    """Map arbitrary Amazon-like columns into a compact schema suitable for retrieval.

    Source columns are located with ``_first_present`` (and, for price only,
    ``_auto_price_col`` as a fallback) from a fixed list of candidate names.

    Args:
        raw: Product table with arbitrary column names.

    Returns:
        A DataFrame with columns ``name``, ``short_desc``, ``tags``,
        ``price_usd`` and ``image_url``, one row per row of ``raw``.
    """
    name_c = _first_present(raw, ["product name", "title", "name", "product_title"])
    desc_c = _first_present(raw, ["description", "product_description", "feature", "about"])
    cat_c = _first_present(raw, ["category", "categories", "main_cat", "product_category"])
    price_c = _first_present(
        raw,
        ["selling price", "price", "current_price", "list_price", "price_amount", "actual_price", "price_usd"],
    ) or _auto_price_col(raw)
    img_c = _first_present(raw, ["image", "image_url", "imageurl", "imUrl", "img", "img_url"])

    def _column(col, default):
        # A missing column becomes a constant Series aligned to raw.index.
        # Passing bare scalars for *every* field would make the DataFrame
        # constructor raise ("If using all scalar values, you must pass an index").
        if col is None:
            return pd.Series(default, index=raw.index)
        return raw[col]

    df = pd.DataFrame({
        "name": _column(name_c, ""),
        "short_desc": _column(desc_c, ""),
        "tags": _column(cat_c, ""),
        "price_usd": _column(price_c, np.nan),
        "image_url": _column(img_c, ""),
    })

    # Light normalization / truncation to keep UI compact
    df["price_usd"] = df["price_usd"].map(_to_price_usd)
    df["name"] = df["name"].astype(str).str.strip().str.slice(0, 160)
    df["short_desc"] = df["short_desc"].astype(str).str.strip().str.slice(0, 600)
    # regex=False: "|" is a regex alternation operator, so the replacement must
    # be explicitly literal instead of depending on the pandas default.
    df["tags"] = df["tags"].astype(str).str.replace("|", ", ", regex=False).str.lower()
    return df
|
| 97 |
|
| 98 |
def extract_top_cat(tags: str) -> str:
    """Return the leading category token of a tag string, lower-cased.

    A ``|`` separator takes precedence over ``>``; when neither is present the
    text before the first comma is returned. Falsy input yields ``""``.
    """
    text = (tags or "").lower()
    if "|" in text:
        return text.partition("|")[0].strip()
    if ">" in text:
        return text.partition(">")[0].strip()
    if not text:
        return ""
    # Comma fallback: the outer text is stripped, the first field itself is not.
    return text.strip().split(",")[0]
|
| 104 |
|
| 105 |
def load_catalog()->pd.DataFrame:
|
| 106 |
+
# Load dataset -> normalize schema -> filter -> light feature engineering
|
| 107 |
df=map_amazon_to_schema(load_dataset(DATASET_ID, split=DATASET_SPLIT).to_pandas()).drop_duplicates(subset=["name","short_desc"])
|
| 108 |
df=df[pd.notna(df["price_usd"])]; df=df[(df["price_usd"]>0)&(df["price_usd"]<=500)].reset_index(drop=True)
|
| 109 |
if len(df)>MAX_ROWS: df=df.sample(n=MAX_ROWS,random_state=42).reset_index(drop=True)
|
|
|
|
| 113 |
return df
|
| 114 |
CATALOG=load_catalog()
|
| 115 |
|
| 116 |
+
# -----------------------------------------------------------------------------
|
| 117 |
+
# Embedding bank with on-disk caching
|
| 118 |
+
# -----------------------------------------------------------------------------
|
| 119 |
class EmbeddingBank:
|
| 120 |
def __init__(s, docs, model_id, dataset_tag):
|
| 121 |
s.model_id=model_id; s.dataset_tag=dataset_tag; s.model=SentenceTransformer(model_id); s.embs=s._load_or_build(docs)
|
|
|
|
| 130 |
def query_vec(s,text): return s.model.encode([text], convert_to_numpy=True, normalize_embeddings=True)[0]
|
| 131 |
EMB=EmbeddingBank(CATALOG["doc"].tolist(), EMBED_MODEL_ID, DATASET_ID)
|
| 132 |
|
| 133 |
+
# Token set for light lexical checks (used by interest Hit@k)
|
| 134 |
_tok_rx = re.compile(r"[a-z0-9][a-z0-9\-']*")
|
| 135 |
if "tok_set" not in CATALOG.columns:
|
| 136 |
CATALOG["tok_set"]=(CATALOG["name"].fillna("")+" "+CATALOG["tags"].fillna("")+" "+CATALOG["short_desc"].fillna("")).map(lambda t:set(_tok_rx.findall(str(t).lower())))
|
| 137 |
|
| 138 |
+
# Optional cross-encoder for re-ranking (small CPU-friendly model by default)
|
| 139 |
try:
|
| 140 |
from sentence_transformers import CrossEncoder
|
| 141 |
except:
|
|
|
|
| 148 |
_CE_MODEL=CrossEncoder(RERANK_MODEL_ID, device="cpu")
|
| 149 |
return _CE_MODEL
|
| 150 |
|
| 151 |
+
# Occasion-specific keyword priors (light bonus shaping)
|
| 152 |
OCCASION_PRIORS={"valentines":[("jewelry",.12),("chocolate",.10),("candle",.08),("romantic",.08),("couple",.08),("heart",.06)],
|
| 153 |
"birthday":[("fun",.06),("game",.06),("personalized",.06),("gift set",.05),("surprise",.04)],
|
| 154 |
"anniversary":[("couple",.10),("jewelry",.10),("photo",.08),("frame",.06),("memory",.06),("candle",.06)],
|
|
|
|
| 161 |
"get_well":[("cozy",.10),("tea",.08),("soothing",.06),("care",.06)]}
|
| 162 |
|
| 163 |
def expand_with_synonyms(tokens: List[str])->List[str]:
|
| 164 |
+
# Expand user-provided interests with synonyms to enrich the query
|
| 165 |
out=[];
|
| 166 |
for t in tokens:
|
| 167 |
t=t.strip().lower()
|
|
|
|
| 169 |
return out
|
| 170 |
|
| 171 |
def profile_to_query(p: Dict) -> str:
    """Build the dense retrieval query string for a recipient profile.

    The query is a ``" | "``-joined sequence of: the synonym-expanded interests
    (repeated three times), the relationship tokens from ``REL_TO_TOKENS``, the
    canonical occasion from ``OCCASION_CANON``, and a natural-language tail.
    Empty leading segments are dropped; the tail is always present.
    """
    interests = [item.lower() for item in p.get("interests", []) if item]
    # Tripling the expanded interest list repeats those terms in the query text.
    boosted = expand_with_synonyms(interests) * 3
    relationship_tokens = REL_TO_TOKENS.get(p.get("relationship", "Friend"), [])
    occasion = OCCASION_CANON.get(p.get("occ_ui", "Birthday"), "birthday")

    head = [", ".join(boosted), ", ".join(relationship_tokens), occasion]
    likes = ", ".join(interests) or "general"
    tail = f"gift ideas for a {p.get('relationship','Friend')} for {occasion}; likes {likes}"
    return " | ".join(segment for segment in head if segment) + " | " + tail
|
| 177 |
|
| 178 |
def _gender_ok_mask(g:str)->np.ndarray:
|
| 179 |
+
# Gender-aware filter: exclude items explicitly labeled for the opposite gender unless unisex
|
| 180 |
g=(g or "any").lower(); bl=CATALOG["blob"]
|
| 181 |
has_m=bl.str.contains(r"\b(men|man's|mens|male|for men)\b",regex=True,na=False)
|
| 182 |
has_f=bl.str.contains(r"\b(women|woman's|womens|female|for women|dress)\b",regex=True,na=False)
|
|
|
|
| 186 |
return np.ones(len(bl),bool)
|
| 187 |
|
| 188 |
def _mask_by_age(age:str, blob:pd.Series)->np.ndarray:
|
| 189 |
+
# Age-aware filter: crude regex to separate kids/teens/adults
|
| 190 |
kids=blob.str.contains(r"\b(?:kid|kids|child|children|toddler|baby|boys?|girls?|kid's|children's)\b",regex=True,na=False)
|
| 191 |
teen=blob.str.contains(r"\b(?:teen|teens|young adult|ya)\b",regex=True,na=False)
|
| 192 |
if age in ("adult","senior"): return (~kids).to_numpy()
|
|
|
|
| 195 |
return np.ones(len(blob),bool)
|
| 196 |
|
| 197 |
def _interest_bonus(p:Dict, idx:np.ndarray)->np.ndarray:
|
| 198 |
+
# Soft bonus if catalog tokens overlap with interest vocabulary (synonyms included)
|
| 199 |
ints=[i.lower() for i in p.get("interests",[]) if i]; syns=[s for it in ints for s in SYNONYMS.get(it,[])]; vocab=set(ints+syns)
|
| 200 |
if not vocab or idx.size==0: return np.zeros(len(idx),"float32")
|
| 201 |
counts=np.array([len(CATALOG["tok_set"].iat[i] & vocab) for i in idx],"float32"); return .10*np.clip(counts,0,6)
|
| 202 |
|
| 203 |
def _occasion_bonus(idx:np.ndarray, occ_ui:str)->np.ndarray:
|
| 204 |
+
# Soft bonus based on occasion priors (keywords found in item blob)
|
| 205 |
pri=OCCASION_PRIORS.get(OCCASION_CANON.get(occ_ui or "Birthday","birthday"),[])
|
| 206 |
if not pri or idx.size==0: return np.zeros(len(idx),"float32")
|
| 207 |
bl=CATALOG["blob"].to_numpy(); out=np.zeros(len(idx),"float32")
|
|
|
|
| 210 |
return out
|
| 211 |
|
| 212 |
def _minmax(x:np.ndarray)->np.ndarray:
|
| 213 |
+
# Normalize to [0,1] with safe guard for constant vectors
|
| 214 |
if x.size==0: return x
|
| 215 |
lo,hi=float(np.min(x)),float(np.max(x));
|
| 216 |
return np.zeros_like(x) if hi<=lo+1e-9 else (x-lo)/(hi-lo)
|
| 217 |
|
| 218 |
def _mmr_select(cand_idx:np.ndarray, scores:np.ndarray, k:int, lambda_:float=.7)->np.ndarray:
|
| 219 |
+
# MMR selection to maintain diversity in the final top-k
|
| 220 |
if cand_idx.size<=k: return cand_idx[np.argsort(-scores)][:k]
|
| 221 |
picked=[]; rest=list(range(len(cand_idx))); rel=_minmax(scores)
|
| 222 |
V=np.asarray(EMB.embs,"float32")[cand_idx]; V/=np.linalg.norm(V,axis=1,keepdims=True)+1e-8
|
|
|
|
| 226 |
j=int(np.argmax(lambda_*rel[rest]-(1-lambda_)*sim_to_sel)); picked.append(rest.pop(j))
|
| 227 |
return cand_idx[np.array(picked,int)]
|
| 228 |
|
| 229 |
+
def recommend_top3_budget_first(
    p: Dict,
    include_synth: bool = True,
    synth_n: int = 10,
    widen_budget_frac: float = 0.5
) -> pd.DataFrame:
    """
    Retrieve -> score -> diversify. Always returns semantically-ranked results
    from the catalog (no "cheapest-3" fallback). If strict filters empty the
    pool, we progressively relax them but still rank by embeddings + bonuses.
    Optionally appends a 4th 'Generated' item (DIY) when include_synth=True.

    Args:
        p: Profile dict; reads ``budget_min``, ``budget_max``, ``age_range``,
            ``gender`` and ``occ_ui`` here, and is passed on to the query,
            bonus and synthetic helpers.
        include_synth: When true, try to append the best synthetic candidate.
        synth_n: Number of synthetic candidates to generate (at least 1).
        widen_budget_frac: Fraction by which the budget window is widened in
            the relaxed-budget filter stage.

    Returns:
        DataFrame with columns ``name``, ``short_desc``, ``price_usd``,
        ``image_url`` and ``similarity``: up to three catalog rows, plus one
        synthetic row when requested and generation succeeded.
    """
    out_cols = ["name", "short_desc", "price_usd", "image_url", "similarity"]

    # ---------- Filters (progressive relaxations) ----------
    lo, hi = float(p.get("budget_min", 0)), float(p.get("budget_max", 1e9))
    blob = CATALOG["blob"]
    price = CATALOG["price_usd"].values
    age_ok = _mask_by_age(p.get("age_range", "any"), blob)
    gen_ok = _gender_ok_mask(p.get("gender", "any"))
    price_ok_strict = (price >= lo) & (price <= hi)
    # The upper bound is only widened for a finite budget (hi < 1e8).
    price_ok_wide = (price >= max(0, lo * (1 - widen_budget_frac))) & \
                    (price <= (hi * (1 + widen_budget_frac) if hi < 1e8 else hi))

    mask_chain = [
        price_ok_strict & age_ok & gen_ok,   # strictest: budget + age + gender
        price_ok_strict & gen_ok,            # drop the age filter
        price_ok_wide & gen_ok,              # widened budget window
        age_ok & gen_ok,                     # ignore budget
        gen_ok,                              # gender only
        np.ones(len(CATALOG), bool),         # everything
    ]
    idx = np.array([], dtype=int)
    for m in mask_chain:
        cand = np.where(m)[0]
        if cand.size:
            idx = cand
            break

    # ---------- Query & base similarities ----------
    q = profile_to_query(p)
    qv = EMB.query_vec(q).astype("float32")
    embs = np.asarray(EMB.embs, "float32")
    emb_sims = embs[idx] @ qv

    # ---------- Bonuses ----------
    # Small bonus for being close to the budget mid-point.
    target = (lo + hi) / 2.0 if hi > lo else hi
    prices = CATALOG.iloc[idx]["price_usd"].to_numpy()
    price_bonus = np.clip(.12 - np.abs(prices - target) / max(target, 1.0), 0, .12).astype("float32")
    int_bonus = _interest_bonus(p, idx)
    occ_bonus = _occasion_bonus(idx, p.get("occ_ui", "Birthday"))

    # Pre-score, guarded against NaN/Inf
    pre = np.nan_to_num(emb_sims + price_bonus + int_bonus + occ_bonus, nan=0.0, posinf=0.0, neginf=0.0)

    # ---------- Local candidate pool ----------
    K1 = max(1, min(48, idx.size))
    try:
        top_local = np.argpartition(-pre, K1 - 1)[:K1]
    except Exception:
        # argpartition rejects an out-of-range kth (e.g. empty pool); fall back to a full sort.
        top_local = np.argsort(-pre)[:K1]
    cand_idx = idx[top_local]

    # ---------- Feature normalization ----------
    emb_n = _minmax(np.nan_to_num(emb_sims[top_local], nan=0.0))
    price_n = _minmax(np.nan_to_num(price_bonus[top_local], nan=0.0))
    int_n = _minmax(np.nan_to_num(int_bonus[top_local], nan=0.0))
    occ_n = _minmax(np.nan_to_num(occ_bonus[top_local], nan=0.0))

    # ---------- Optional cross-encoder ----------
    ce = _load_cross_encoder()
    if ce is not None:
        docs = CATALOG.loc[cand_idx, "doc"].tolist()
        pairs = [(q, d) for d in docs]
        # Only the k_ce best-by-embedding candidates are re-scored; the rest keep 0.
        k_ce = min(24, len(pairs))
        tl = np.argpartition(-emb_n, k_ce - 1)[:k_ce]
        ce_raw = np.array(ce.predict([pairs[i] for i in tl]), "float32")
        ce_n = np.zeros_like(emb_n)
        ce_n[tl] = _minmax(ce_raw)
    else:
        ce_n = np.zeros_like(emb_n)

    # ---------- Final score ----------
    final = np.nan_to_num(.56*emb_n + .26*ce_n + .10*int_n + .05*occ_n + .03*price_n, nan=0.0)

    # ---------- Select top-3 with diversity ----------
    k = int(min(3, cand_idx.size))
    pick = _mmr_select(cand_idx, final, k=k) if k > 0 else np.array([], dtype=int)
    if pick.size == 0:
        pick = cand_idx[np.argsort(-final)[:min(3, cand_idx.size)]]

    # ---------- Build result ----------
    res = CATALOG.loc[pick].copy()
    pos = {int(c): i for i, c in enumerate(cand_idx)}
    res["similarity"] = [float(final[pos[int(i)]]) if int(i) in pos else np.nan for i in pick]

    # ---------- Optional synthetic #4 ----------
    if include_synth:
        try:
            synth = pick_best_synthetic(p, qv, generate_synthetic_candidates(p, n=int(max(1, synth_n))))
            if synth is not None:
                res = pd.concat(
                    [res, pd.DataFrame([synth])[out_cols]],
                    ignore_index=True
                )
        except Exception:
            # Best effort: a failed DIY generation must not break the catalog
            # results, but it is logged instead of being silently discarded.
            import logging
            logging.getLogger(__name__).exception("Synthetic gift generation failed; returning catalog picks only")

    # NOTE: the legacy scoring body that used to follow this return was
    # unreachable duplicate code and has been removed.
    return res[out_cols].reset_index(drop=True)
|
| 366 |
|
| 367 |
# ===== DIY (FLAN-only) =====
|
| 368 |
DIY_MODEL_ID=os.getenv("DIY_MODEL_ID","google/flan-t5-small"); DIY_DEVICE=torch.device("cpu")
|
| 369 |
MAX_INPUT_TOKENS=int(os.getenv("MAX_INPUT_TOKENS","384")); DIY_MAX_NEW_TOKENS=int(os.getenv("DIY_MAX_NEW_TOKENS","120"))
|
| 370 |
+
# Light aliases to seed the DIY gift title with an interest token
|
| 371 |
INTEREST_ALIASES={"Reading":["book","novel","literary"],"Fashion":["style","chic","silk"],"Home decor":["candle","wall","jar"],"Technology":["tech","gadget","usb"],"Movies":["film","cinema","poster"]}
|
| 372 |
FALLBACK_NOUNS=["Kit","Set","Bundle","Box","Pack"]
|
| 373 |
|
| 374 |
_diy_cache_model={}
|
| 375 |
def _load_flan(mid:str):
|
| 376 |
+
# Lazy-load and cache FLAN-T5 on CPU
|
| 377 |
if mid in _diy_cache_model: return _diy_cache_model[mid]
|
| 378 |
tok=AutoTokenizer.from_pretrained(mid, use_fast=True, trust_remote_code=True)
|
| 379 |
mdl=AutoModelForSeq2SeqLM.from_pretrained(mid, trust_remote_code=True, use_safetensors=True).to(DIY_DEVICE).eval()
|
|
|
|
| 381 |
|
| 382 |
@torch.inference_mode()
|
| 383 |
def _gen(tok, mdl, prompt, max_new_tokens=64, do_sample=False, temperature=.9, top_p=.95, seed=None):
|
| 384 |
+
# Small wrapper for deterministic/non-deterministic generation
|
| 385 |
if seed is None: seed=random.randint(1,10_000_000)
|
| 386 |
random.seed(seed); torch.manual_seed(seed)
|
| 387 |
enc=tok(prompt, truncation=True, max_length=MAX_INPUT_TOKENS, return_tensors="pt"); enc={k:v.to(DIY_DEVICE) for k,v in enc.items()}
|
|
|
|
| 389 |
return tok.decode(out[0], skip_special_tokens=True).strip()
|
| 390 |
|
| 391 |
def _choose_interest_token(interests):
|
| 392 |
+
# Pick a representative token to inject into the DIY name
|
| 393 |
for it in interests:
|
| 394 |
if INTEREST_ALIASES.get(it): return random.choice(INTEREST_ALIASES[it])
|
| 395 |
return (interests[0].split()[0].lower() if interests else "gift")
|
| 396 |
def _title_case(s): s=re.sub(r'\s+',' ',s).strip(); s=re.sub(r'["โโโโ]+','',s); return " ".join([w.capitalize() for w in s.split()])
|
| 397 |
def _sanitize_name(name, interests):
|
| 398 |
+
# Clean LLM-proposed name and enforce a short, interest-infused title
|
| 399 |
for b in [r"^the name\b",r"\bmember of the family\b",r"^name\b",r"^title\b"]: name=re.sub(b,"",name,flags=re.I).strip()
|
| 400 |
name=re.sub(r'[:\-โโ]+$',"",name).strip(); alias=_choose_interest_token(interests)
|
| 401 |
if alias not in name.lower():
|
|
|
|
| 406 |
return name
|
| 407 |
|
| 408 |
def _split_list_text(s,seps):
|
| 409 |
+
# Parse list-like text returned by LLM into clean items (fallback across separators)
|
| 410 |
s=s.strip()
|
| 411 |
for sep in seps:
|
| 412 |
if sep in s:
|
|
|
|
| 415 |
return [p.strip(" -โข*.,;:") for p in re.split(r"[\n\r;]+", s) if p.strip(" -โข*.,;:")]
|
| 416 |
|
| 417 |
def _coerce_materials(items):
|
| 418 |
+
# Normalize materials list: dedupe, keep short, ensure quantities, pad with basics
|
| 419 |
out=[]
|
| 420 |
for it in items:
|
| 421 |
it=re.sub(r'\s+',' ',it).strip(" -โข*.,;:");
|
|
|
|
| 432 |
return out[:8]
|
| 433 |
|
| 434 |
def _coerce_steps(items):
|
| 435 |
+
# Normalize step list: trim, remove numbering, enforce sentence case, pad to 6+
|
| 436 |
out=[]
|
| 437 |
for it in items:
|
| 438 |
it=it.strip(" -โข*.,;:");
|
|
|
|
| 446 |
|
| 447 |
def _only_int(s): m=re.search(r"-?\d+",s); return int(m.group()) if m else None
|
| 448 |
def _clamp_num(v,lo,hi,default):
|
| 449 |
+
# Clamp numeric values into a valid range; fallback to default or midpoint
|
| 450 |
try: x=float(v); return int(min(max(x,lo),hi))
|
| 451 |
except: return int((lo+hi)/2 if default is None else default)
|
| 452 |
|
| 453 |
def diy_generate(profile:Dict)->Tuple[dict,str]:
|
| 454 |
+
# Generate a DIY gift object (name, overview, materials, steps, cost, time)
|
| 455 |
tok,mdl=_load_flan(DIY_MODEL_ID)
|
| 456 |
p={"recipient_name":profile.get("recipient_name","Recipient"),"relationship":profile.get("relationship","Friend"),
|
| 457 |
"occ_ui":profile.get("occ_ui","Birthday"),"occasion":profile.get("occ_ui","Birthday"),"interests":profile.get("interests",[]),
|
|
|
|
| 475 |
"estimated_cost_usd":_clamp_num(cost,p["budget_min"],p["budget_max"],None),"estimated_time_minutes":_clamp_num(minutes,20,180,60)}
|
| 476 |
return idea,"ok"
|
| 477 |
|
| 478 |
+
def generate_synthetic_candidates(profile, n=10):
    """Create ``n`` lightweight synthetic gift candidates via ``diy_generate``.

    Args:
        profile: Recipient profile dict; ``budget_min`` / ``budget_max`` are
            read here (defaults 10 / 100) and the whole dict is passed to
            ``diy_generate``.
        n: Number of candidates to generate.

    Returns:
        List of dicts with keys ``name``, ``short_desc``, ``price_usd``,
        ``image_url`` (always ``""``) and ``doc`` (lower-cased text that the
        caller embeds).
    """
    lo = int(float(profile.get("budget_min", 10)))
    hi = int(float(profile.get("budget_max", 100)))
    if lo > hi:
        # random.randint raises on an inverted range.
        lo, hi = hi, lo

    cands = []
    for _ in range(n):
        idea, _status = diy_generate(profile)
        # A missing or zero cost falls back to a random price inside the budget.
        price = int(idea.get("estimated_cost_usd") or random.randint(lo, hi))
        # `or` also covers an explicit None/empty name, which cannot be sliced or shown.
        name = (idea.get("gift_name") or "Custom DIY Gift")[:160]
        desc = (idea.get("overview", "") or "").strip()[:300]
        doc = f"{name} | custom | {desc}".lower()
        cands.append({"name": name, "short_desc": desc, "price_usd": price, "image_url": "", "doc": doc})
    return cands
|
| 490 |
+
|
| 491 |
+
def pick_best_synthetic(profile, qv, candidates):
    """Return the synthetic candidate most similar to the query vector.

    Candidates' ``doc`` texts are embedded with ``EMB.model`` and scored by dot
    product against ``qv``. The winner is returned as a copy with an added
    ``similarity`` float; ``None`` is returned when there are no candidates.
    ``profile`` is accepted but not used here.
    """
    if not candidates:
        return None
    texts = [cand["doc"] for cand in candidates]
    matrix = EMB.model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
    scores = matrix @ qv
    winner = int(np.argmax(scores))
    chosen = candidates[winner].copy()
    chosen["similarity"] = float(scores[winner])
    return chosen
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
# --------------------- Personalized Message (FLAN + validation) ---------------------
|
| 504 |
+
# Implementation ported from the Colab; tone-specific constraints + simple checks.
|
| 505 |
MSG_MODEL_ID = "google/flan-t5-small"
|
| 506 |
MSG_DEVICE = "cpu"
|
| 507 |
TEMP_RANGE = (0.88, 1.10)
|
|
|
|
| 603 |
]
|
| 604 |
|
| 605 |
def _msg_load():
|
| 606 |
+
# Lazy-load FLAN for message generation (CPU)
|
| 607 |
global _msg_tok, _msg_mdl
|
| 608 |
if _msg_tok is None or _msg_mdl is None:
|
| 609 |
_msg_tok = AutoTokenizer.from_pretrained(MSG_MODEL_ID)
|
|
|
|
| 612 |
return _msg_tok, _msg_mdl
|
| 613 |
|
| 614 |
def _norm(s: str) -> str:
|
| 615 |
+
# Collapse whitespace for more reliable validators
|
| 616 |
return re.sub(r"\s+", " ", s or "").strip()
|
| 617 |
|
| 618 |
def _sentences_n(s: str) -> int:
|
| 619 |
+
# Count sentences via punctuation boundaries
|
| 620 |
return len([p for p in re.split(r"(?<=[.!?])\s+", s.strip()) if p])
|
| 621 |
|
| 622 |
def _contains_any(text: str, terms: List[str]) -> bool:
|
| 623 |
+
# Case-insensitive containment check for any of the given terms
|
| 624 |
t = text.lower()
|
| 625 |
return any(term for term in terms if term) and any((term or "").lower() in t for term in terms)
|
| 626 |
|
| 627 |
def _too_similar(a: str, b: str, n=3, thr=0.85) -> bool:
|
| 628 |
+
# Approximate de-duplication via n-gram Jaccard similarity
|
| 629 |
def ngrams(txt):
|
| 630 |
toks = re.findall(r"[a-zA-Z']+", txt.lower())
|
| 631 |
return set(tuple(toks[i:i+n]) for i in range(max(0, len(toks)-n+1)))
|
|
|
|
| 635 |
return j >= thr
|
| 636 |
|
| 637 |
def _clean_occasion(occ: str) -> str:
|
| 638 |
+
# Normalize typographic apostrophes to ASCII and trim
|
| 639 |
return (occ or "").replace("โ","'").strip()
|
| 640 |
|
| 641 |
def _build_prompt(profile: Dict[str, Any]) -> Tuple[str, Dict[str,str]]:
|
| 642 |
+
# Compose a guided prompt (tone + micro-rules) for the message LLM
|
| 643 |
name = profile.get("recipient_name", "Friend")
|
| 644 |
rel = profile.get("relationship", "Friend")
|
| 645 |
occ = _clean_occasion(profile.get("occ_ui") or profile.get("occasion") or "Birthday")
|
|
|
|
| 672 |
|
| 673 |
@torch.inference_mode()
|
| 674 |
def generate_personal_message(profile: Dict[str, Any], seed: Optional[int]=None, previous_message: Optional[str]=None) -> Dict[str, Any]:
|
| 675 |
+
# Sample multiple generations with slight sampling variance, validate, and return best
|
| 676 |
global _last_msg
|
| 677 |
tok, mdl = _msg_load()
|
| 678 |
if seed is None:
|
|
|
|
| 698 |
)
|
| 699 |
text = _norm(tok.decode(out_ids[0], skip_special_tokens=True))
|
| 700 |
|
| 701 |
+
# ===== Validators (mirrors the Colab logic) =====
|
| 702 |
ok_len = 1 <= _sentences_n(text) <= 3
|
| 703 |
name_ok = _contains_any(text, [need["name"].lower()])
|
| 704 |
occ_ok = _contains_any(text, [need["occ"].lower(), need["occ"].split()[0].lower()])
|
|
|
|
| 713 |
"seed": seed, "attempt": attempt, "model": MSG_MODEL_ID}}
|
| 714 |
tried.append({"text": text}); seed += 17
|
| 715 |
|
| 716 |
+
# Fallback if all attempts failed validation
|
| 717 |
fallback = tried[-1]["text"] if tried else f"Happy {(_clean_occasion(profile.get('occ_ui') or 'day')).lower()}, {profile.get('recipient_name','Friend')}!"
|
| 718 |
_last_msg = fallback
|
| 719 |
return {"message": fallback, "meta": {"failed": True, "model": MSG_MODEL_ID, "tone": profile.get("tone","Heartfelt")}}
|
|
|
|
| 722 |
|
| 723 |
# ===== Rendering & UI =====
|
| 724 |
def first_sentence(s, max_chars=140):
    """Return the first sentence of ``s``, truncated to ``max_chars`` characters.

    The first sentence is the text before the first ``". "``. When it is longer
    than ``max_chars`` it is cut and a single ellipsis character (U+2026) is
    appended so the result is exactly ``max_chars`` long. The ellipsis literal
    was mis-encoded as several characters in this view of the file, which made
    truncated text overflow the limit.

    Args:
        s: Source text; falsy values yield ``""``.
        max_chars: Maximum length of the returned string.
    """
    s = (s or "").strip()
    if not s:
        return ""
    cut = s.split(". ")[0]
    if len(cut) <= max_chars:
        return cut
    return cut[:max(max_chars - 1, 0)] + "\u2026"
|
| 730 |
|
| 731 |
def render_top3_html(df, age_label):
|
| 732 |
+
# Render the 3 catalog picks plus the optional 4th "Generated" item
|
| 733 |
if df is None or df.empty: return "<em>No results found within the current filters.</em>"
|
| 734 |
rows=[]
|
| 735 |
+
for i, r in df.iterrows():
|
| 736 |
name=str(r.get("name","")).replace("|","\\|").replace("*","\\*").replace("_","\\_")
|
| 737 |
desc=str(first_sentence(r.get("short_desc",""))).replace("|","\\|").replace("*","\\*").replace("_","\\_")
|
| 738 |
price=r.get("price_usd"); sim=r.get("similarity"); img=r.get("image_url","") or ""
|
| 739 |
price_str=f"${price:.0f}" if pd.notna(price) else "N/A"; sim_str=f"{sim:.3f}" if pd.notna(sim) else "โ"
|
| 740 |
img_html=f'<img src="{img}" alt="" style="width:84px;height:84px;object-fit:cover;border-radius:10px;margin-left:12px;" />' if img else ""
|
| 741 |
+
tag = "Generated" if i==3 else f"#{i+1}"
|
| 742 |
rows.append(f"""
|
| 743 |
<div style="display:flex;align-items:flex-start;justify-content:space-between;gap:10px;padding:10px;border:1px solid #eee;border-radius:12px;margin-bottom:8px;background:#fff;">
|
| 744 |
+
<div style="flex:1;min-width:0;"><div style="font-weight:700;">{name} <span style="font-size:.8em;opacity:.7;">({tag})</span></div>
|
| 745 |
<div style="font-size:0.95em;margin-top:4px;">{desc}</div>
|
| 746 |
<div style="font-size:0.9em;margin-top:6px;opacity:0.8;">Price: <b>{price_str}</b> ยท Age: <code>{age_label}</code> ยท Score: <code>{sim_str}</code></div>
|
| 747 |
</div>{img_html}
|
|
|
|
| 787 |
tone=gr.Dropdown(label="Message tone", choices=MESSAGE_TONES, value="Funny")
|
| 788 |
|
| 789 |
go=gr.Button("Get GIfty!")
|
| 790 |
+
gr.Markdown("### ๐ Input summary"); out_summary = gr.HTML(visible=False)
|
| 791 |
gr.Markdown("### ๐ฏ Recommendations"); out_top3=gr.HTML()
|
| 792 |
gr.Markdown("### ๐ ๏ธ DIY Gift"); out_diy_md=gr.Markdown()
|
| 793 |
gr.Markdown("### ๐ Personalized Message"); out_msg=gr.Markdown()
|
| 794 |
run_token=gr.State(0)
|
| 795 |
|
| 796 |
def _on_example_select(evt: gr.SelectData):
    """Fill the input widgets from the example row that was clicked.

    Returns a 9-tuple in widget order: interests, occasion, min budget, max
    budget, recipient, relationship, age group, gender, tone.
    """
    # evt.index may be a (row, col) pair or a bare row number.
    selected = evt.index
    if isinstance(selected, (list, tuple)):
        selected = selected[0]
    row = EX_DF.iloc[int(selected)]

    # Interests are stored "+"-separated in the examples table.
    interests_list = [part.strip() for part in str(row["Interests"]).split("+")]
    return (
        interests_list,
        row["Occasion"],
        int(row["Min $"]),
        int(row["Max $"]),
        row["Recipient"],
        row["Relationship"],
        row["Age group"],
        row["Gender"],
        row["Tone"],
    )
|
| 800 |
ex_df.select(_on_example_select, outputs=[interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone])
|
| 801 |
|
| 802 |
def render_diy_md(j:dict)->str:
|
| 803 |
+
# Nicely format the DIY object as markdown
|
| 804 |
if not j: return "_DIY generation failed._"
|
| 805 |
steps=j.get('step_by_step_instructions', j.get('steps', []))
|
| 806 |
parts = [
|
|
|
|
| 811 |
f"**Estimated cost:** ${j.get('estimated_cost_usd','?')} ยท **Time:** {j.get('estimated_time_minutes','?')} min"
|
| 812 |
]
|
| 813 |
return "\n".join(parts)
|
| 814 |
+
def input_summary_html(p, age_label):
    """Render a compact HTML summary of the current input above the results.

    All interpolated values are HTML-escaped because recipient name and the
    other fields come straight from user-editable widgets. The two dash
    literals were mis-encoded in this view of the file and are written as
    explicit escapes (em dash placeholder, en dash in the budget range).

    Args:
        p: Profile dict (``interests``, ``budget_min``, ``budget_max``,
            ``recipient_name``, ``relationship``, ``occ_ui``, ``gender``).
        age_label: Age label as shown in the UI.

    Returns:
        An HTML fragment string.
    """
    from html import escape

    ints = escape(", ".join(str(i) for i in p.get("interests", []))) or "\u2014"
    budget = f"${int(float(p.get('budget_min',0)))}\u2013${int(float(p.get('budget_max',0)))}"
    name = escape(str(p.get("recipient_name", "Friend")))
    rel = escape(str(p.get("relationship", "Friend")))
    occ = escape(str(p.get("occ_ui", "Birthday")))
    gender = escape((p.get("gender", "any") or "any").capitalize())
    age = escape(str(age_label))
    return f"""
<div style="padding:10px 12px;border:1px solid #e2e8f0;border-radius:12px;background:#f8fafc;margin-bottom:8px;">
<div style="display:flex;flex-wrap:wrap;gap:10px;align-items:center;">
<div><b>Recipient:</b> {name} ({rel})</div>
<div><b>Occasion:</b> {occ}</div>
<div><b>Age:</b> {age}</div>
<div><b>Gender:</b> {gender}</div>
<div><b>Budget:</b> {budget}</div>
<div style="flex-basis:100%;height:0;"></div>
<div><b>Interests:</b> {ints}</div>
</div>
</div>
"""
|
| 833 |
|
| 834 |
def _build_profile(ints, occ, bmin, bmax, name, rel, age_label, gender_val, tone_val):
    """Convert UI widget values into the internal profile dict.

    Budget bounds that cannot be parsed as floats fall back to 5.0-500.0, and
    inverted bounds are swapped. Empty widget values fall back to defaults
    ("Friend", "Birthday", "any", "Heartfelt"); the age label is mapped through
    ``AGE_OPTIONS``.
    """
    try:
        bmin = float(bmin)
        bmax = float(bmax)
    except (TypeError, ValueError):
        # Only conversion failures (None, non-numeric text) use the default range.
        bmin, bmax = 5.0, 500.0
    if bmin > bmax:
        bmin, bmax = bmax, bmin
    return {
        "recipient_name": name or "Friend",
        "relationship": rel or "Friend",
        "interests": ints or [],
        "occ_ui": occ or "Birthday",
        "budget_min": bmin,
        "budget_max": bmax,
        "age_range": AGE_OPTIONS.get(age_label, "any"),
        "gender": (gender_val or "any").lower(),
        "tone": tone_val or "Heartfelt",
    }
|
| 840 |
|
| 841 |
+
def start_run(curr):
    """Return the next run-token value; a falsy current value counts as 0."""
    previous = int(curr or 0)
    return previous + 1
|
| 844 |
+
|
| 845 |
+
def predict_summary_only(rt, *args):
    """Render the input summary card and pass the run token through.

    ``args`` mapping:
      0: interests, 1: occasion, 2: budget_min, 3: budget_max,
      4: recipient_name, 5: relationship, 6: age_label, 7: gender, 8: tone
    """
    profile = _build_profile(*args)
    age_label = args[6]
    summary = input_summary_html(profile, age_label)
    return gr.update(value=summary, visible=True), rt
|
| 851 |
|
| 852 |
def predict_recs_only(rt, *args):
    """Render the catalog-only Top-3 (no synthetic item) and pass the run token through."""
    profile = _build_profile(*args)
    # include_synth=False skips DIY generation entirely.
    catalog_picks = recommend_top3_budget_first(profile, include_synth=False)
    cards = render_top3_html(catalog_picks, args[6])
    return gr.update(value=cards, visible=True), rt
|
| 856 |
+
|
| 857 |
+
def predict_recs_with_synth(rt, *args):
    """Render the Top-3 plus the synthetic item and pass the run token through.

    The number of synthetic candidates comes from the ``SYNTH_N`` environment
    variable (default 2).
    """
    profile = _build_profile(*args)
    candidate_count = int(os.getenv("SYNTH_N", "2"))
    picks = recommend_top3_budget_first(profile, include_synth=True, synth_n=candidate_count)
    cards = render_top3_html(picks, args[6])
    return gr.update(value=cards, visible=True), rt
|
| 862 |
+
|
| 863 |
def predict_diy_only(rt, *args):
    """Generate the DIY gift, render it as markdown and pass the run token through."""
    profile = _build_profile(*args)
    # diy_generate returns (idea, status); the status is not used here.
    idea, _status = diy_generate(profile)
    markdown = render_diy_md(idea)
    return gr.update(value=markdown, visible=True), rt
|
| 867 |
|
| 868 |
+
def predict_msg_only(rt, *args):
    """Generate the personalized message and pass the run token through."""
    profile = _build_profile(*args)
    result = generate_personal_message(profile)
    message_text = result["message"]
    return gr.update(value=message_text, visible=True), rt
|
| 872 |
+
|
| 873 |
+
# Clicking the button first bumps the run token; every step below is chained
# off that event and receives the same widget inputs.
ev_start = go.click(start_run, inputs=[run_token], outputs=[run_token], queue=True)

# 1) Input summary
ev_start.then(
    predict_summary_only,
    inputs=[run_token, interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone],
    outputs=[out_summary, run_token],
    queue=True,
)

# 2) Fast recommendations (Top-3 without the synthetic item)
recs_fast = ev_start.then(
    predict_recs_only,
    inputs=[run_token, interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone],
    outputs=[out_top3, run_token],
    queue=True,
)

# 3) Chained after (2): recompute with the synthetic item and refresh the same out_top3
recs_fast.then(
    predict_recs_with_synth,
    inputs=[run_token, interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone],
    outputs=[out_top3, run_token],
    queue=True,
)

# 4) DIY and Message are chained directly off ev_start, not off step (3)
ev_start.then(
    predict_diy_only,
    inputs=[run_token, interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone],
    outputs=[out_diy_md, run_token],
    queue=True,
)
ev_start.then(
    predict_msg_only,
    inputs=[run_token, interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone],
    outputs=[out_msg, run_token],
    queue=True,
)
|
| 912 |
+
|
| 913 |
|
| 914 |
if __name__=="__main__":
|
| 915 |
demo.launch()
|