barakb21's picture
Update app.py
0dac04c verified
# app.py — Startup recommender + Unlike + AI name (optional tagline/description)
import os, re, numpy as np, pandas as pd
from pathlib import Path
import gradio as gr
import torch, faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# ---------- Paths / artifacts ----------
# Pre-built artifacts: a FAISS index plus the parquet table it was built from.
# The search code indexes the table positionally with FAISS ids, so rows are
# assumed to be stored in the same order as the index vectors.
OUT_DIR = Path("./emb_index_e5")
FAISS_PATH = OUT_DIR / "faiss.index"
DATA_PATH = OUT_DIR / "data.parquet"
# Fail fast with a real exception: `assert` statements are removed when Python
# runs with -O, which would defer the failure to an obscure read error later.
if not FAISS_PATH.exists():
    raise FileNotFoundError(f"Missing {FAISS_PATH}. Build & upload embeddings/index.")
if not DATA_PATH.exists():
    raise FileNotFoundError(f"Missing {DATA_PATH}. Build & upload data parquet.")
# ---------- Devices ----------
# Ask PyTorch's CUDA caching allocator to use expandable segments. setdefault
# keeps any value the operator already exported instead of overwriting it.
# The setting only takes effect if it is in place before the first CUDA allocation.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
DEVICE_EMBED = "cuda" if torch.cuda.is_available() else "cpu"  # e5 on GPU if available
DEVICE_GEN = "cpu"  # FLAN on CPU (avoid OOM)
print(f"Embed device: {DEVICE_EMBED} | Gen device: {DEVICE_GEN}")
# ---------- Load artifacts ----------
index = faiss.read_index(str(FAISS_PATH))
df_local = pd.read_parquet(DATA_PATH)
# Normalise the text columns to plain strings. Missing values must be filled
# *before* the str cast: astype(str) turns NaN/None into the literal strings
# "nan"/"None", after which fillna("") has nothing left to replace.
for c in ["name", "tagline", "description"]:
    if c in df_local.columns:
        df_local[c] = df_local[c].fillna("").astype(str)
# ---------- Load models ----------
# Retrieval encoder (e5), placed on DEVICE_EMBED.
EMBED_MODEL = "intfloat/e5-base-v2"
embed_model = SentenceTransformer(EMBED_MODEL, device=DEVICE_EMBED)

# Text generators (FLAN-T5), placed on DEVICE_GEN.
MODEL_BASE = "google/flan-t5-base"
MODEL_LARGE = "google/flan-t5-large"
USE_LARGE_FOR_DESCRIPTION = False  # keep False on Spaces unless you switch GEN to "cuda"

tok_base = AutoTokenizer.from_pretrained(MODEL_BASE)
# float16 weights are requested only for GPU generation; on CPU no dtype override is passed.
base_kwargs = dict(torch_dtype=torch.float16) if DEVICE_GEN == "cuda" else dict()
mod_base = AutoModelForSeq2SeqLM.from_pretrained(MODEL_BASE, **base_kwargs).to(DEVICE_GEN)

if not USE_LARGE_FOR_DESCRIPTION:
    # Alias the base pair so tok_large / mod_large are always defined.
    tok_large, mod_large = tok_base, mod_base
else:
    tok_large = AutoTokenizer.from_pretrained(MODEL_LARGE)
    large_kwargs = dict(torch_dtype=torch.float16) if DEVICE_GEN == "cuda" else dict()
    mod_large = AutoModelForSeq2SeqLM.from_pretrained(MODEL_LARGE, **large_kwargs).to(DEVICE_GEN)
# ---------- Helpers (embedding + generation) ----------
def _generate_text(model, tokenizer, prompt, max_new_tokens=30, temperature=0.9, top_p=0.95, num_return_sequences=1):
    """Sample completions of *prompt* from a seq2seq model.

    Sampling is always enabled (``do_sample=True``) with the given temperature and
    nucleus ``top_p``. The tokenised prompt is moved to ``DEVICE_GEN``.

    Returns:
        A list of ``num_return_sequences`` decoded strings, with special tokens
        removed and surrounding whitespace stripped.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(DEVICE_GEN)
    sampling = dict(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=num_return_sequences,
    )
    generated = model.generate(**encoded, **sampling)
    decoded = []
    for seq in generated:
        decoded.append(tokenizer.decode(seq, skip_special_tokens=True).strip())
    return decoded
def _embed_query(q: str) -> np.ndarray:
    """Return the L2-normalised float32 e5 embedding of one search query.

    The text is prefixed with ``"query: "``, the e5 convention for the query side.
    """
    vectors = embed_model.encode(
        [f"query: {q}"],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    return vectors.astype("float32")[0]
def _embed_passages(texts) -> np.ndarray:
    """Embed *texts* as e5 passages: one L2-normalised float32 row per input text.

    Each text is prefixed with ``"passage: "``, the e5 convention for documents.
    """
    prefixed = list(map(lambda t: f"passage: {t}", texts))
    vectors = embed_model.encode(prefixed, convert_to_numpy=True, normalize_embeddings=True)
    return vectors.astype("float32")
# ---------- Search with per-session unlikes ----------
def search_topk_filtered_session(query: str, k: int, unliked_ids: set):
    """Return the top-``k`` rows of ``df_local`` most similar to *query*, skipping unliked rows.

    The FAISS index is over-fetched so that enough hits survive after removing rows
    whose ``row_idx`` (the ``df_local`` index label) appears in *unliked_ids*.

    Returns:
        DataFrame with columns ``rank`` (1-based), ``row_idx``, ``score``, ``name``,
        ``tagline`` and ``description``. It may hold fewer than ``k`` rows.
    """
    qv = _embed_query(query)
    # Over-fetch so that enough hits remain after dropping unliked rows.
    fetch = min(index.ntotal, max(k * 20, 50, k + len(unliked_ids)))
    if fetch <= 0:
        # Empty index: nothing to search (FAISS would be asked for 0 neighbours).
        return pd.DataFrame(columns=["rank", "row_idx", "score", "name", "tagline", "description"])
    scores, inds = index.search(qv[None, :], fetch)
    # FAISS pads missing results with id -1 (e.g. approximate indexes that find
    # fewer than `fetch` neighbours). iloc[-1] would silently pick the last row.
    hits = [(int(i), float(s)) for i, s in zip(inds[0], scores[0]) if i >= 0]
    rows = df_local.iloc[[i for i, _ in hits]]
    res = rows[["name", "tagline", "description"]].copy()
    res.insert(0, "row_idx", rows.index)
    res.insert(1, "score", [s for _, s in hits])
    res = res[~res["row_idx"].isin(unliked_ids)].head(k).reset_index(drop=True)
    res.insert(0, "rank", range(1, len(res) + 1))
    return res
# ---------- Synthetic generation (length-aware) ----------
_STOPWORDS = {
"the","a","an","for","and","or","to","of","in","on","with","by","from",
"my","our","your","their","at","as","about","into","over","under","this","that",
"idea","startup","company","product","service","app","platform","factory","labs","tech"
}
def _words(s: str): return [w for w in re.findall(r"[a-z]+", str(s).lower()) if w]
def _content_words(s: str):
    """Return the words of *s* (as split by ``_words``) that have 3+ letters and are not stop-words."""
    kept = []
    for word in _words(s):
        if len(word) < 3 or word in _STOPWORDS:
            continue
        kept.append(word)
    return kept
def _normalize_name(s: str) -> str: return re.sub(r"[^a-z0-9]+", "", str(s).lower())
def _has_vowel(s: str) -> bool: return bool(re.search(r"[aeiou]", str(s).lower()))
def _overlap_ratio(name_tokens, banned):
if not name_tokens or not banned: return 0.0
inter = len(set(name_tokens) & set(banned)); union = len(set(name_tokens) | set(banned))
return inter / max(union, 1)
# Length window for generated names: about 12 characters (±3), 1 to 3 words.
NAME_CHAR_TARGET = 12
NAME_CHAR_TOL = 3
NAME_WORDS_MIN = 1
NAME_WORDS_MAX = 3
def _len_ok(text: str, target_chars: int, tol: int, min_words: int, max_words: int):
c = len(text); w = len(text.split())
return (target_chars - tol) <= c <= (target_chars + tol) and (min_words <= w <= max_words)
def _theme_hints(query: str, k: int = 6):
    """Return up to *k* distinct content words of *query*, comma-joined in first-seen order.

    Falls back to a fixed education-themed hint string when *query* has no content words.
    """
    unique = list(dict.fromkeys(_content_words(query)))
    if not unique:
        return "education, learning, students, AI"
    return ", ".join(unique[:k])
def _split_name_candidates(raw: str):
    """Split one model completion into cleaned candidate names.

    The prompt asks for one name per line, but T5's SentencePiece vocabulary has
    no newline piece, so FLAN-T5 answers usually arrive on a single line. Splitting
    on newlines alone would then treat the whole answer as one "name". Here the
    text is split on newlines, commas, semicolons and ``|``. Each piece loses a
    leading bullet or ``1.``/``1)`` number, surrounding quotes and trailing
    punctuation. Empty pieces are dropped.
    """
    names = []
    for piece in re.split(r"[\r\n,;|]+", str(raw)):
        # Strip list markers only; keep digits that are part of a name ("99designs").
        nm = re.sub(r"^\s*(?:[-•*]+|\d+\s*[.)])\s*", "", piece)
        nm = nm.strip().strip("\"'“”‘’`").strip()
        nm = re.sub(r"[^\w\s-]+$", "", nm).strip()
        if nm:
            names.append(nm)
    return names


def generate_names(base_idea: str, n: int = 10, oversample: int = 80, max_retries: int = 3):
    """Ask FLAN-T5 for brandable startup names for *base_idea* and return up to *n*.

    Each attempt does the following:
    - Samples one completion, at a slightly higher temperature than the previous attempt.
    - Splits the completion into candidates.
    - Merges them with earlier candidates, de-duplicated by ``_normalize_name``.
    - Filters all candidates with progressively looser rules: a vowel is present,
      the length fits the NAME_* window, and word overlap with the idea is small.

    As soon as at least *n* candidates pass the filter, the first *n* are returned.
    Otherwise the fallback returns the names that passed the last filter first,
    then the remaining candidates, truncated to *n*.

    Args:
        base_idea: The user's idea text. Its content words are discouraged in names.
        n: Number of names wanted.
        oversample: How many names the prompt asks the model for.
        max_retries: Number of generation attempts.

    Returns:
        A list of at most *n* names (empty if the model produced nothing usable).
    """
    banned = sorted(set(_content_words(base_idea)))
    avoid_str = ", ".join(banned[:12]) if banned else "previous words"
    hints = _theme_hints(base_idea)

    def _prompt(osz):
        return (
            f"Create {osz} brandable startup names for this idea:\n"
            f"\"{base_idea}\"\n\n"
            f"Guidance:\n"
            f"- Evoke these themes (without literally using the words): {hints}\n"
            f"- 1 or 2 words; aim ~{NAME_CHAR_TARGET} characters (±{NAME_CHAR_TOL})\n"
            f"- Portmanteau/blends welcome (e.g., Coursera, Udacity, Grammarly)\n"
            f"- Do NOT use: {avoid_str}\n"
            f"- Avoid generic phrases (e.g., 'Plastic Bottles', 'Online Store')\n"
            f"- Output one name per line; no numbering, no quotes."
        )

    def ok(nm: str, overlap_cap: float, tol_boost: int):
        # Accept only names that contain a vowel, fit the length window, and do
        # not mostly repeat the idea's own words.
        if not _has_vowel(nm):
            return False
        if not _len_ok(nm, NAME_CHAR_TARGET, NAME_CHAR_TOL + tol_boost, NAME_WORDS_MIN, NAME_WORDS_MAX):
            return False
        toks = _content_words(nm)
        if _overlap_ratio(toks, banned) > overlap_cap:
            return False
        if " ".join(toks) in {"plastic bottles", "bottles plastic"}:
            return False
        return True

    overlap_caps = [0.25, 0.35, 0.5]  # allowed Jaccard overlap; loosens per attempt
    tol_boosts = [0, 1, 2]  # extra character tolerance per attempt
    all_candidates, seen = [], set()
    filtered = []
    for attempt in range(max_retries):
        raw = _generate_text(mod_base, tok_base, _prompt(oversample),
                             num_return_sequences=1, max_new_tokens=240,
                             temperature=1.0 + 0.05 * attempt, top_p=0.95)[0]
        # Collect and de-duplicate (case/punctuation-insensitive, first one wins).
        for nm in _split_name_candidates(raw):
            key = _normalize_name(nm)
            if key and key not in seen:
                seen.add(key)
                all_candidates.append(nm)
        level = min(attempt, len(overlap_caps) - 1)
        filtered = [nm for nm in all_candidates if ok(nm, overlap_caps[level], tol_boosts[level])]
        if len(filtered) >= n:
            return filtered[:n]
    # Not enough names passed: prefer those that did, then the rest.
    passed = set(filtered)
    fallback = filtered + [nm for nm in all_candidates if nm not in passed]
    return fallback[:n]
# Tagline/description length targets (from your EDA)
# Tagline: about 40 characters (±6) and about 6 words (±2).
TAG_CHAR_TARGET = 40
TAG_CHAR_TOL = 6
TAG_WORD_TARGET = 6
TAG_WORD_TOL = 2
# Description: 170-230 characters and 27-35 words (inclusive ranges).
DESC_CHAR_MIN = 170
DESC_CHAR_MAX = 230
DESC_WORD_MIN = 27
DESC_WORD_MAX = 35
def _trim_to_words(text: str, max_words: int) -> str:
toks = text.split()
return text.strip() if len(toks) <= max_words else " ".join(toks[:max_words]).rstrip(",;:") + "."
def _snap_sentence_boundary(text: str, min_chars: int, max_chars: int):
text = text.strip()
if len(text) <= max_chars and len(text) >= min_chars: return text
cutoff = min(max_chars, len(text)); candidate = text[:cutoff]
m = re.search(r"[\.!\?](?!.*[\.!\?])", candidate)
if m and (len(candidate[:m.end()].strip()) >= min_chars): return candidate[:m.end()].strip()
return candidate.rstrip(",;: ").strip() + ("." if not candidate.endswith((".", "!", "?")) else "")
def _within_ranges(text: str, cmin: int, cmax: int, wmin: int, wmax: int) -> bool:
c = len(text); w = len(text.split()); return (cmin <= c <= cmax) and (wmin <= w <= wmax)
def generate_tagline_and_desc(name: str, query_context: str):
    """Generate a tagline and a short product description for startup *name*.

    Both texts are sampled from FLAN-T5 and post-processed towards the TAG_* and
    DESC_* length targets. If the first draft of either text falls outside its
    target ranges, one more draft is sampled. The draft closer to the target
    character length is kept.

    Returns:
        ``(tagline, description)``. Neither is guaranteed to be inside its range.
    """
    tag_word_max = TAG_WORD_TARGET + TAG_WORD_TOL
    tag_char_max = TAG_CHAR_TARGET + TAG_CHAR_TOL

    def _clean_tagline(raw: str) -> str:
        # Collapse whitespace, cap the word count, then the character count.
        text = _trim_to_words(re.sub(r"\s+", " ", raw).strip(), tag_word_max)
        if len(text) > tag_char_max:
            # Cut one character short so the appended ellipsis keeps the
            # tagline within tag_char_max.
            text = text[:tag_char_max - 1].rstrip(",;: -") + "…"
        return text

    def _clean_description(raw: str) -> str:
        # Collapse whitespace, cap the word count, then fit the character range.
        text = re.sub(r"\s+", " ", raw).strip()
        if len(text.split()) > DESC_WORD_MAX:
            text = _trim_to_words(text, DESC_WORD_MAX)
        return _snap_sentence_boundary(text, DESC_CHAR_MIN, DESC_CHAR_MAX)

    tag_prompt = (
        f"Write a short, benefit-driven tagline for a startup called '{name}'. "
        f"Audience & domain: {query_context}. "
        f"Target ~{TAG_CHAR_TARGET} characters and ~{TAG_WORD_TARGET} words. Avoid clichés."
    )
    tagline = _clean_tagline(
        _generate_text(mod_base, tok_base, tag_prompt, max_new_tokens=28, temperature=0.9, top_p=0.95)[0]
    )
    if not _within_ranges(tagline, TAG_CHAR_TARGET - TAG_CHAR_TOL, tag_char_max,
                          TAG_WORD_TARGET - TAG_WORD_TOL, tag_word_max):
        # One retry. Apply the same cleanup so both drafts compete on equal terms.
        tagline2 = _clean_tagline(
            _generate_text(mod_base, tok_base, tag_prompt, max_new_tokens=30, temperature=1.0, top_p=0.9)[0]
        )
        if abs(len(tagline2) - TAG_CHAR_TARGET) < abs(len(tagline) - TAG_CHAR_TARGET):
            tagline = tagline2

    desc_prompt = (
        f"Write a concise product description for the startup '{name}'. "
        f"Context: {query_context}. "
        f"Explain who it's for, what it does, and the main benefit. "
        # The ranges need a separator; without one the prompt read "170230 characters".
        f"Target {DESC_CHAR_MIN}-{DESC_CHAR_MAX} characters and {DESC_WORD_MIN}-{DESC_WORD_MAX} words. "
        f"Avoid fluff; keep it clear."
    )
    model, tok = (mod_large, tok_large) if USE_LARGE_FOR_DESCRIPTION else (mod_base, tok_base)
    description = _clean_description(
        _generate_text(model, tok, desc_prompt, max_new_tokens=110, temperature=1.05, top_p=0.95)[0]
    )
    if not _within_ranges(description, DESC_CHAR_MIN, DESC_CHAR_MAX, DESC_WORD_MIN, DESC_WORD_MAX):
        description2 = _clean_description(
            _generate_text(model, tok, desc_prompt, max_new_tokens=120, temperature=1.05, top_p=0.9)[0]
        )
        target_mid = (DESC_CHAR_MIN + DESC_CHAR_MAX) / 2
        if abs(len(description2) - target_mid) < abs(len(description) - target_mid):
            description = description2
    return tagline, description
def pick_best_synthetic_name(query: str, n_candidates: int = 10, include_copy=False):
    """Generate candidate names for *query* and return the best one as a one-row DataFrame.

    Candidates come from ``generate_names``, which is retried once with more
    oversampling. If the model yields nothing, deterministic suffix-based seeds
    are built from the query's content words instead.

    Each candidate's score is:
    - the dot product of its e5 passage embedding with the query embedding
      (cosine similarity, since both are L2-normalised),
    - minus 0.35 × its word overlap (Jaccard) with the query,
    - minus a penalty for names shorter than 4 or longer than 16 normalised characters.

    Returns:
        DataFrame with one row: ``rank`` (always 4), ``score``, ``name``,
        ``tagline`` and ``description``. The last two are empty unless
        *include_copy* is set.
    """
    names = generate_names(query, n=n_candidates, oversample=max(80, 8 * n_candidates), max_retries=3)
    if not names:
        names = generate_names(query, n=n_candidates, oversample=140, max_retries=1)
    if not names:
        toks = _content_words(query) or ["nova", "learn", "edu", "mento"]
        # Ordered de-duplication (instead of a set union) keeps the fallback
        # identical across runs: set order of str depends on per-process hash seeds.
        seeds = dict.fromkeys(
            [t[:4] + "ify" for t in toks]
            + [t[:3] + "ora" for t in toks]
            + [t[:4] + "io" for t in toks]
        )
        names = list(seeds)[:n_candidates]
    qv = _embed_query(query)
    embs = _embed_passages(names)
    cos = embs @ qv
    banned = sorted(set(_content_words(query)))
    final_scores = []
    for nm, s in zip(names, cos):
        overlap = _overlap_ratio(_content_words(nm), banned)
        length_pen = 0.0
        L = len(_normalize_name(nm))
        if L < 4:
            length_pen += 0.3
        if L > 16:
            length_pen += 0.2
        final_scores.append(float(s) - 0.35 * overlap - length_pen)
    best_idx = int(np.argmax(final_scores))
    best_name = names[best_idx]
    best_score = float(final_scores[best_idx])
    tagline, description = ("", "")
    if include_copy:
        tagline, description = generate_tagline_and_desc(best_name, query_context=query)
    return pd.DataFrame([{"rank": 4, "score": best_score, "name": best_name,
                          "tagline": tagline, "description": description}])
# ---------- UI glue ----------
# Sample ideas rendered as one-click example buttons in the UI.
EXAMPLES = list((
    "AI tool to analyze customer feedback",
    "Social network for jobs",
    "Mobile fintech app for cross-border payments",
    "AI learning tool for students",
    "Marketplace for eco-friendly products",
))
def ui_search(query, state_unlikes):
    """Gradio handler for "Search Top-3": run a fresh search and reset the unlike list.

    Returns:
        ``(results_table, unlikes_state, status_message)``, matching the outputs
        wired to the search button.
    """
    query = (query or "").strip()
    if not query:
        return gr.update(value=pd.DataFrame()), state_unlikes, "Please enter a short idea."
    # A new query starts over, so earlier unlikes no longer apply.
    state_unlikes = []
    res = search_topk_filtered_session(query, k=3, unliked_ids=set())
    # Report the real count; a small index can return fewer than 3 rows.
    return res, state_unlikes, f"Found {len(res)} similar items. You can unlike by row_idx, then Refresh."
def ui_unlike(query, unlike_ids_csv, state_unlikes):
    """Gradio handler for "Refresh after Unlike": add ids to the unlike list and re-search.

    *unlike_ids_csv* is a comma-separated list of ``row_idx`` values. Tokens that
    are not non-negative integers are ignored.

    Returns:
        ``(results_table, unlikes_state, status_message)``. The new state is a
        sorted list of all unliked ids.
    """
    query = (query or "").strip()
    if not query:
        return gr.update(value=pd.DataFrame()), state_unlikes, "Enter a query first."
    add_ids = set()
    for tok in (unlike_ids_csv or "").split(","):
        # str.isdigit() also accepts characters such as "²" that int() rejects,
        # so parse with int() directly and keep only non-negative values.
        try:
            row_id = int(tok.strip())
        except ValueError:
            continue
        if row_id >= 0:
            add_ids.add(row_id)
    cur = set(state_unlikes) | add_ids
    res = search_topk_filtered_session(query, k=3, unliked_ids=cur)
    return res, sorted(cur), f"Excluded {sorted(add_ids)}. Currently unliked: {sorted(cur)}"
def ui_clear_unlikes(query):
    """Gradio handler for "Clear Unlikes": re-search with no exclusions and empty the unlike list.

    Returns ``(results_table, [], status_message)``.
    """
    cleaned = (query or "").strip()
    if cleaned:
        table = search_topk_filtered_session(cleaned, k=3, unliked_ids=set())
        return table, [], "Cleared unlikes."
    return gr.update(value=pd.DataFrame()), [], "Enter a query first."
def ui_generate_synth(query, include_copy):
    """Gradio handler for "Generate #4 (AI)": build the synthetic fourth option.

    Returns ``(synthetic_row_table, status_message)``.
    """
    cleaned = (query or "").strip()
    if cleaned:
        synth = pick_best_synthetic_name(cleaned, n_candidates=10, include_copy=include_copy)
        return synth, "Generated AI option as #4. Combine it with your top-3."
    return gr.update(value=pd.DataFrame()), "Enter a query first."
def _apply_example(example_text, state_unlikes):
    """Put *example_text* into the query box and run the search for it.

    Returns ``(query_text, results_table, unlikes_state, status_message)``.
    """
    table, new_state, message = ui_search(example_text, state_unlikes)
    status_text = f"Example selected: “{example_text}”. {message}"
    return example_text, table, new_state, status_text
# Gradio page layout. Components render in creation order, with `with gr.Row()`
# grouping them horizontally, so the statement order here defines the layout.
# NOTE(review): indentation was lost in the source; the Row nesting below is
# reconstructed (results_tbl and everything after it assumed outside the Rows).
with gr.Blocks(title="Startup Recommender + AI Name") as app:
    gr.Markdown("## Startup Recommender → Unlike → AI Name\nEnter a short idea. Get 3 similar startups, unlike what doesn’t fit, then generate an AI name (and optional tagline & description).")
    query = gr.Textbox(label="Your idea (short description)", placeholder="e.g., AI tool to analyze student essays and give feedback")
    with gr.Row():
        gr.Markdown("**Try an example:**")
        # One button per entry in EXAMPLES, in the same order (zipped below).
        example_buttons = [gr.Button(ex, variant="secondary") for ex in EXAMPLES]
    with gr.Row():
        btn_search = gr.Button("Search Top-3")
        unlike_ids = gr.Textbox(label="Unlike by row_idx (comma-separated)", placeholder="e.g., 123, 456")
        btn_unlike = gr.Button("Refresh after Unlike")
        btn_clear = gr.Button("Clear Unlikes")
    results_tbl = gr.Dataframe(label="Top-3 Similar (after excludes)", interactive=False, wrap=True)
    gr.Markdown("### AI-Generated Option (#4)")
    include_copy = gr.Checkbox(label="Also generate tagline & description", value=True)
    btn_synth = gr.Button("Generate #4 (AI)")
    synth_tbl = gr.Dataframe(label="Synthetic #4", interactive=False, wrap=True)
    status = gr.Markdown("")
    # Gradio session state: the list of unliked row_idx values, read and
    # returned by the search/unlike/clear handlers.
    state_unlikes = gr.State([])
    # wiring
    btn_search.click(ui_search, inputs=[query, state_unlikes], outputs=[results_tbl, state_unlikes, status])
    btn_unlike.click(ui_unlike, inputs=[query, unlike_ids, state_unlikes], outputs=[results_tbl, state_unlikes, status])
    btn_clear.click(ui_clear_unlikes, inputs=[query], outputs=[results_tbl, state_unlikes, status])
    for btn, ex in zip(example_buttons, EXAMPLES):
        # `ex_=ex` binds the current example when the lambda is created; a plain
        # closure over `ex` would make every button use the last example.
        btn.click(lambda st, ex_=ex: _apply_example(ex_, st),
                  inputs=[state_unlikes], outputs=[query, results_tbl, state_unlikes, status])
    btn_synth.click(ui_generate_synth, inputs=[query, include_copy], outputs=[synth_tbl, status])
# Script entry point: start the Gradio server for the `app` Blocks defined above.
# On Spaces, just calling launch() is fine; no explicit port.
if __name__ == "__main__":
    app.launch()