Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,58 +1,53 @@
|
|
| 1 |
-
#
|
| 2 |
-
|
| 3 |
-
# ./emb_index_e5/faiss.index and ./emb_index_e5/data.parquet
|
| 4 |
-
# Also expects you kept the improved generation functions we wrote earlier
|
| 5 |
-
# (generate_names, generate_tagline_and_desc(name, query_context), pick_best_synthetic_name).
|
| 6 |
-
# If not, this file includes compact versions below.
|
| 7 |
-
|
| 8 |
-
import os, re, random, numpy as np, pandas as pd
|
| 9 |
from pathlib import Path
|
| 10 |
import gradio as gr
|
| 11 |
-
|
| 12 |
-
# --- Imports for models ---
|
| 13 |
-
import torch
|
| 14 |
-
import faiss
|
| 15 |
from sentence_transformers import SentenceTransformer
|
| 16 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 17 |
|
| 18 |
-
#
|
| 19 |
OUT_DIR = Path("./emb_index_e5")
|
| 20 |
FAISS_PATH = OUT_DIR / "faiss.index"
|
| 21 |
DATA_PATH = OUT_DIR / "data.parquet"
|
|
|
|
|
|
|
| 22 |
|
| 23 |
-
#
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 29 |
-
|
| 30 |
-
# --- Load artifacts once ---
|
| 31 |
-
assert FAISS_PATH.exists(), f"Missing FAISS index at {FAISS_PATH}. Run the embedding/index step first."
|
| 32 |
-
assert DATA_PATH.exists(), f"Missing dataset at {DATA_PATH}. Run the embedding/index step first."
|
| 33 |
|
|
|
|
| 34 |
index = faiss.read_index(str(FAISS_PATH))
|
| 35 |
df_local = pd.read_parquet(DATA_PATH)
|
| 36 |
-
for
|
| 37 |
-
if
|
| 38 |
-
df_local[
|
| 39 |
|
| 40 |
-
#
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
# --- Load FLAN models (base + optional large) once ---
|
| 44 |
tok_base = AutoTokenizer.from_pretrained(MODEL_BASE)
|
| 45 |
-
|
|
|
|
|
|
|
| 46 |
if USE_LARGE_FOR_DESCRIPTION:
|
| 47 |
tok_large = AutoTokenizer.from_pretrained(MODEL_LARGE)
|
| 48 |
-
|
|
|
|
| 49 |
else:
|
| 50 |
tok_large, mod_large = tok_base, mod_base
|
| 51 |
|
| 52 |
-
#
|
| 53 |
-
|
| 54 |
def _generate_text(model, tokenizer, prompt, max_new_tokens=30, temperature=0.9, top_p=0.95, num_return_sequences=1):
|
| 55 |
-
inputs = tokenizer(prompt, return_tensors="pt").to(
|
| 56 |
outputs = model.generate(
|
| 57 |
**inputs,
|
| 58 |
max_new_tokens=max_new_tokens,
|
|
@@ -70,49 +65,20 @@ def _embed_passages(texts) -> np.ndarray:
|
|
| 70 |
texts = [f"passage: {t}" for t in texts]
|
| 71 |
return embed_model.encode(texts, convert_to_numpy=True, normalize_embeddings=True).astype("float32")
|
| 72 |
|
| 73 |
-
#
|
| 74 |
-
|
| 75 |
-
EXAMPLES = [
|
| 76 |
-
"AI tool to analyze customer feedback",
|
| 77 |
-
"Social network for jobs",
|
| 78 |
-
"Mobile fintech app for cross-border payments",
|
| 79 |
-
"AI learning tool for students",
|
| 80 |
-
"Marketplace for eco-friendly products",
|
| 81 |
-
]
|
| 82 |
-
|
| 83 |
-
# Helper that applies an example: fills the textbox, resets unlikes, runs search
|
| 84 |
-
def _apply_example(example_text, state_unlikes):
|
| 85 |
-
# Reuse your existing ui_search() so behavior stays identical
|
| 86 |
-
results, state_unlikes, msg = ui_search(example_text, state_unlikes)
|
| 87 |
-
# Return: set the query box text, update table & state & status
|
| 88 |
-
return example_text, results, state_unlikes, f"Example selected: “{example_text}” — {msg}"
|
| 89 |
-
|
| 90 |
-
# ---------- Search (with per-session unlikes) ----------
|
| 91 |
def search_topk_filtered_session(query: str, k: int, unliked_ids: set):
|
| 92 |
-
"""Return top-K rows excluding row_idx values in unliked_ids."""
|
| 93 |
qv = _embed_query(query)
|
| 94 |
fetch = min(index.ntotal, max(k * 20, 50, k + len(unliked_ids)))
|
| 95 |
scores, inds = index.search(qv[None, :], fetch)
|
| 96 |
-
inds = inds[0].tolist()
|
| 97 |
-
scores = scores[0].tolist()
|
| 98 |
-
|
| 99 |
res = df_local.iloc[inds][["name","tagline","description"]].copy()
|
| 100 |
res.insert(0, "row_idx", df_local.iloc[inds].index)
|
| 101 |
res.insert(1, "score", [float(s) for s in scores])
|
| 102 |
-
|
| 103 |
-
# filter-out unlikes
|
| 104 |
-
mask = ~res["row_idx"].isin(unliked_ids)
|
| 105 |
-
res = res[mask].head(k).reset_index(drop=True)
|
| 106 |
res.insert(0, "rank", range(1, len(res)+1))
|
| 107 |
return res
|
| 108 |
|
| 109 |
-
# ---------- Synthetic generation (
|
| 110 |
-
# Length targets (from your EDA)
|
| 111 |
-
TAG_CHAR_TARGET, TAG_CHAR_TOL = 40, 6
|
| 112 |
-
TAG_WORD_TARGET, TAG_WORD_TOL = 6, 2
|
| 113 |
-
DESC_CHAR_MIN, DESC_CHAR_MAX = 170, 230
|
| 114 |
-
DESC_WORD_MIN, DESC_WORD_MAX = 27, 35
|
| 115 |
-
|
| 116 |
_STOPWORDS = {
|
| 117 |
"the","a","an","for","and","or","to","of","in","on","with","by","from",
|
| 118 |
"my","our","your","their","at","as","about","into","over","under","this","that",
|
|
@@ -124,104 +90,90 @@ def _normalize_name(s: str) -> str: return re.sub(r"[^a-z0-9]+", "", str(s).lowe
|
|
| 124 |
def _has_vowel(s: str) -> bool: return bool(re.search(r"[aeiou]", str(s).lower()))
|
| 125 |
def _overlap_ratio(name_tokens, banned):
|
| 126 |
if not name_tokens or not banned: return 0.0
|
| 127 |
-
inter = len(set(name_tokens) & set(banned))
|
| 128 |
-
union = len(set(name_tokens) | set(banned))
|
| 129 |
return inter / max(union, 1)
|
| 130 |
|
|
|
|
|
|
|
| 131 |
def _len_ok(text: str, target_chars: int, tol: int, min_words: int, max_words: int):
|
| 132 |
c = len(text); w = len(text.split())
|
| 133 |
return (target_chars - tol) <= c <= (target_chars + tol) and (min_words <= w <= max_words)
|
| 134 |
|
| 135 |
def _theme_hints(query: str, k: int = 6):
|
| 136 |
-
kws = _content_words(query)
|
| 137 |
-
seen, hints = set(), []
|
| 138 |
for t in kws:
|
| 139 |
-
if t not in seen:
|
| 140 |
-
hints.append(t); seen.add(t)
|
| 141 |
return ", ".join(hints[:k]) if hints else "education, learning, students, AI"
|
| 142 |
|
| 143 |
-
NAME_CHAR_TARGET, NAME_CHAR_TOL = 12, 3
|
| 144 |
-
NAME_WORDS_MIN, NAME_WORDS_MAX = 1, 3
|
| 145 |
-
|
| 146 |
def generate_names(base_idea: str, n: int = 10, oversample: int = 80, max_retries: int = 3):
|
| 147 |
banned = sorted(set(_content_words(base_idea)))
|
| 148 |
avoid_str = ", ".join(banned[:12]) if banned else "previous words"
|
| 149 |
hints = _theme_hints(base_idea)
|
| 150 |
-
|
| 151 |
def _prompt(osz):
|
| 152 |
return (
|
| 153 |
f"Create {osz} brandable startup names for this idea:\n"
|
| 154 |
f"\"{base_idea}\"\n\n"
|
| 155 |
f"Guidance:\n"
|
| 156 |
f"- Evoke these themes (without literally using the words): {hints}\n"
|
| 157 |
-
f"- 1 or 2
|
| 158 |
f"- Portmanteau/blends welcome (e.g., Coursera, Udacity, Grammarly)\n"
|
| 159 |
f"- Do NOT use: {avoid_str}\n"
|
| 160 |
f"- Avoid generic phrases (e.g., 'Plastic Bottles', 'Online Store')\n"
|
| 161 |
f"- Output one name per line; no numbering, no quotes."
|
| 162 |
)
|
| 163 |
-
|
| 164 |
-
all_candidates = []
|
| 165 |
for attempt in range(max_retries):
|
| 166 |
-
raw = _generate_text(
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
)[0]
|
| 171 |
-
|
| 172 |
-
batch = []
|
| 173 |
for line in raw.splitlines():
|
| 174 |
nm = line.strip().lstrip("-•*0123456789. ").strip()
|
| 175 |
-
if
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
all_candidates.extend(batch)
|
| 179 |
-
|
| 180 |
# dedup
|
| 181 |
uniq, seen = [], set()
|
| 182 |
for nm in all_candidates:
|
| 183 |
key = _normalize_name(nm)
|
| 184 |
-
if
|
| 185 |
-
|
| 186 |
all_candidates = uniq
|
| 187 |
-
|
| 188 |
-
# progressive filtering
|
| 189 |
def ok(nm: str, overlap_cap: float, tol_boost: int):
|
| 190 |
if not _has_vowel(nm): return False
|
| 191 |
-
if not _len_ok(nm, NAME_CHAR_TARGET, NAME_CHAR_TOL+tol_boost, NAME_WORDS_MIN, NAME_WORDS_MAX):
|
| 192 |
-
return False
|
| 193 |
toks = _content_words(nm)
|
| 194 |
if _overlap_ratio(toks, banned) > overlap_cap: return False
|
| 195 |
if " ".join(toks) in {"plastic bottles","bottles plastic"}: return False
|
| 196 |
return True
|
| 197 |
-
|
| 198 |
-
overlap_caps = [0.25, 0.35, 0.5]
|
| 199 |
-
tol_boosts = [0, 1, 2]
|
| 200 |
filtered = [nm for nm in all_candidates if ok(nm, overlap_caps[min(attempt,2)], tol_boosts[min(attempt,2)])]
|
| 201 |
-
if len(filtered) >= n:
|
| 202 |
-
return filtered[:n]
|
| 203 |
-
|
| 204 |
return all_candidates[:n] if all_candidates else []
|
| 205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
def _trim_to_words(text: str, max_words: int) -> str:
|
| 207 |
toks = text.split()
|
| 208 |
-
if len(toks) <= max_words
|
| 209 |
-
return " ".join(toks[:max_words]).rstrip(",;:") + "."
|
| 210 |
|
| 211 |
def _snap_sentence_boundary(text: str, min_chars: int, max_chars: int):
|
| 212 |
text = text.strip()
|
| 213 |
if len(text) <= max_chars and len(text) >= min_chars: return text
|
| 214 |
cutoff = min(max_chars, len(text)); candidate = text[:cutoff]
|
| 215 |
m = re.search(r"[\.!\?](?!.*[\.!\?])", candidate)
|
| 216 |
-
if m and (len(candidate[:m.end()].strip()) >= min_chars):
|
| 217 |
-
return candidate[:m.end()].strip()
|
| 218 |
return candidate.rstrip(",;: ").strip() + ("." if not candidate.endswith((".", "!", "?")) else "")
|
| 219 |
|
| 220 |
def _within_ranges(text: str, cmin: int, cmax: int, wmin: int, wmax: int) -> bool:
|
| 221 |
c = len(text); w = len(text.split()); return (cmin <= c <= cmax) and (wmin <= w <= wmax)
|
| 222 |
|
| 223 |
def generate_tagline_and_desc(name: str, query_context: str):
|
| 224 |
-
# Tagline (≈40 chars, ≈6 words)
|
| 225 |
tag_prompt = (
|
| 226 |
f"Write a short, benefit-driven tagline for a startup called '{name}'. "
|
| 227 |
f"Audience & domain: {query_context}. "
|
|
@@ -236,10 +188,8 @@ def generate_tagline_and_desc(name: str, query_context: str):
|
|
| 236 |
TAG_WORD_TARGET - TAG_WORD_TOL, TAG_WORD_TARGET + TAG_WORD_TOL):
|
| 237 |
tagline2 = _generate_text(mod_base, tok_base, tag_prompt, max_new_tokens=30, temperature=1.0, top_p=0.9)[0]
|
| 238 |
tagline2 = _trim_to_words(re.sub(r"\s+", " ", tagline2).strip(), TAG_WORD_TARGET + TAG_WORD_TOL)
|
| 239 |
-
if abs(len(tagline2) - TAG_CHAR_TARGET) < abs(len(tagline) - TAG_CHAR_TARGET):
|
| 240 |
-
tagline = tagline2
|
| 241 |
|
| 242 |
-
# Description (≈170–230 chars, 27–35 words)
|
| 243 |
desc_prompt = (
|
| 244 |
f"Write a concise product description for the startup '{name}'. "
|
| 245 |
f"Context: {query_context}. "
|
|
@@ -250,154 +200,119 @@ def generate_tagline_and_desc(name: str, query_context: str):
|
|
| 250 |
model, tok = (mod_large, tok_large) if USE_LARGE_FOR_DESCRIPTION else (mod_base, tok_base)
|
| 251 |
description = _generate_text(model, tok, desc_prompt, max_new_tokens=110, temperature=1.05, top_p=0.95)[0]
|
| 252 |
description = re.sub(r"\s+", " ", description).strip()
|
| 253 |
-
if len(description.split()) > DESC_WORD_MAX:
|
| 254 |
-
description = _trim_to_words(description, DESC_WORD_MAX)
|
| 255 |
description = _snap_sentence_boundary(description, DESC_CHAR_MIN, DESC_CHAR_MAX)
|
| 256 |
-
|
| 257 |
if not _within_ranges(description, DESC_CHAR_MIN, DESC_CHAR_MAX, DESC_WORD_MIN, DESC_WORD_MAX):
|
| 258 |
description2 = _generate_text(model, tok, desc_prompt, max_new_tokens=120, temperature=1.05, top_p=0.9)[0]
|
| 259 |
description2 = re.sub(r"\s+", " ", description2).strip()
|
| 260 |
-
if len(description2.split()) > DESC_WORD_MAX:
|
| 261 |
-
description2 = _trim_to_words(description2, DESC_WORD_MAX)
|
| 262 |
description2 = _snap_sentence_boundary(description2, DESC_CHAR_MIN, DESC_CHAR_MAX)
|
| 263 |
target_mid = (DESC_CHAR_MIN + DESC_CHAR_MAX) / 2
|
| 264 |
-
if abs(len(description2) - target_mid) < abs(len(description) - target_mid):
|
| 265 |
-
description = description2
|
| 266 |
-
|
| 267 |
return tagline, description
|
| 268 |
|
| 269 |
def pick_best_synthetic_name(query: str, n_candidates: int = 10, include_copy=False):
|
| 270 |
-
# Generate candidates & score vs query (cosine)
|
| 271 |
names = generate_names(query, n=n_candidates, oversample=max(80, 8*n_candidates), max_retries=3)
|
| 272 |
if len(names) == 0:
|
| 273 |
-
# permissive retry
|
| 274 |
names = generate_names(query, n=n_candidates, oversample=140, max_retries=1)
|
| 275 |
if len(names) == 0:
|
| 276 |
toks = _content_words(query) or ["nova","learn","edu","mento"]
|
| 277 |
-
seeds =
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
seeds.add((t[:3] + "ora"))
|
| 281 |
-
seeds.add((t[:4] + "io"))
|
| 282 |
-
names = list(seeds)[:n_candidates]
|
| 283 |
-
|
| 284 |
-
qv = _embed_query(query)
|
| 285 |
-
embs = _embed_passages(names)
|
| 286 |
-
cos = embs @ qv
|
| 287 |
-
|
| 288 |
banned = sorted(set(_content_words(query)))
|
| 289 |
final_scores = []
|
| 290 |
for nm, s in zip(names, cos):
|
| 291 |
-
toks = _content_words(nm)
|
| 292 |
-
|
| 293 |
-
length_pen
|
| 294 |
-
L = len(_normalize_name(nm))
|
| 295 |
-
if L < 4: length_pen += 0.3
|
| 296 |
if L > 16: length_pen += 0.2
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
best_idx = int(np.argmax(final_scores))
|
| 301 |
-
best_name = names[best_idx]
|
| 302 |
-
best_score = float(final_scores[best_idx])
|
| 303 |
-
|
| 304 |
tagline, description = ("","")
|
| 305 |
-
if include_copy:
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
row = pd.DataFrame([{
|
| 309 |
-
"rank": 4,
|
| 310 |
-
"score": best_score,
|
| 311 |
-
"name": best_name,
|
| 312 |
-
"tagline": tagline,
|
| 313 |
-
"description": description
|
| 314 |
-
}])
|
| 315 |
return row
|
| 316 |
|
| 317 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
|
| 319 |
def ui_search(query, state_unlikes):
|
| 320 |
-
"""Start a new search, reset unlikes if it's a new query."""
|
| 321 |
query = (query or "").strip()
|
| 322 |
-
if not query:
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
state_unlikes
|
| 326 |
-
results = search_topk_filtered_session(query, k=3, unliked_ids=state_unlikes)
|
| 327 |
-
return results, list(state_unlikes), "Found 3 similar items. You can unlike by row_idx, then 'Refresh'."
|
| 328 |
|
| 329 |
def ui_unlike(query, unlike_ids_csv, state_unlikes):
|
| 330 |
query = (query or "").strip()
|
| 331 |
-
if not query:
|
| 332 |
-
return gr.update(value=pd.DataFrame()), state_unlikes, "Enter a query first."
|
| 333 |
add_ids = set()
|
| 334 |
for tok in (unlike_ids_csv or "").split(","):
|
| 335 |
tok = tok.strip()
|
| 336 |
-
if tok.isdigit():
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
msg = f"Excluded {sorted(add_ids)}. Currently unliked: {sorted(state_unlikes)}"
|
| 341 |
-
return results, list(state_unlikes), msg
|
| 342 |
|
| 343 |
def ui_clear_unlikes(query):
|
| 344 |
query = (query or "").strip()
|
| 345 |
-
if not query:
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
return results, [], "Cleared unlikes."
|
| 349 |
|
| 350 |
def ui_generate_synth(query, include_copy):
|
| 351 |
query = (query or "").strip()
|
| 352 |
-
if not query:
|
| 353 |
-
return gr.update(value=pd.DataFrame()), "Enter a query first."
|
| 354 |
synth = pick_best_synthetic_name(query, n_candidates=10, include_copy=include_copy)
|
| 355 |
return synth, "Generated AI option as #4. Combine it with your top-3."
|
| 356 |
|
| 357 |
-
|
| 358 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
|
| 360 |
-
with gr.Row():
|
| 361 |
-
query = gr.Textbox(label="Your idea (short description)", placeholder="e.g., AI tool to analyze student essays and give feedback")
|
| 362 |
with gr.Row():
|
| 363 |
gr.Markdown("**Try an example:**")
|
| 364 |
-
example_buttons = []
|
| 365 |
-
|
| 366 |
-
example_buttons.append(gr.Button(ex, variant="secondary", scale=1))
|
| 367 |
with gr.Row():
|
| 368 |
btn_search = gr.Button("Search Top-3")
|
| 369 |
unlike_ids = gr.Textbox(label="Unlike by row_idx (comma-separated)", placeholder="e.g., 123, 456")
|
| 370 |
btn_unlike = gr.Button("Refresh after Unlike")
|
| 371 |
btn_clear = gr.Button("Clear Unlikes")
|
| 372 |
-
with gr.Row():
|
| 373 |
-
results_tbl = gr.Dataframe(label="Top-3 Similar (after excludes)", interactive=False, wrap=True)
|
| 374 |
|
| 375 |
-
gr.
|
| 376 |
-
with gr.Row():
|
| 377 |
-
include_copy = gr.Checkbox(label="Also generate tagline & description", value=True)
|
| 378 |
-
btn_synth = gr.Button("Generate #4 (AI)")
|
| 379 |
|
|
|
|
|
|
|
|
|
|
| 380 |
synth_tbl = gr.Dataframe(label="Synthetic #4", interactive=False, wrap=True)
|
| 381 |
-
status = gr.Markdown("")
|
| 382 |
|
| 383 |
-
|
| 384 |
state_unlikes = gr.State([])
|
| 385 |
|
| 386 |
-
#
|
| 387 |
btn_search.click(ui_search, inputs=[query, state_unlikes], outputs=[results_tbl, state_unlikes, status])
|
| 388 |
btn_unlike.click(ui_unlike, inputs=[query, unlike_ids, state_unlikes], outputs=[results_tbl, state_unlikes, status])
|
| 389 |
btn_clear.click(ui_clear_unlikes, inputs=[query], outputs=[results_tbl, state_unlikes, status])
|
| 390 |
-
|
| 391 |
for btn, ex in zip(example_buttons, EXAMPLES):
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
lambda st, ex_=ex: _apply_example(ex_, st),
|
| 395 |
-
inputs=[state_unlikes], # current unlike state
|
| 396 |
-
outputs=[query, results_tbl, state_unlikes, status]
|
| 397 |
-
)
|
| 398 |
|
| 399 |
-
|
| 400 |
|
| 401 |
-
#
|
| 402 |
if __name__ == "__main__":
|
| 403 |
-
app.
|
|
|
|
| 1 |
+
# app.py — Startup recommender + Unlike + AI name (optional tagline/description)
|
| 2 |
+
import os, re, numpy as np, pandas as pd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from pathlib import Path
|
| 4 |
import gradio as gr
|
| 5 |
+
import torch, faiss
|
|
|
|
|
|
|
|
|
|
| 6 |
from sentence_transformers import SentenceTransformer
|
| 7 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 8 |
|
| 9 |
+
# ---------- Paths / artifacts ----------
|
| 10 |
OUT_DIR = Path("./emb_index_e5")
|
| 11 |
FAISS_PATH = OUT_DIR / "faiss.index"
|
| 12 |
DATA_PATH = OUT_DIR / "data.parquet"
|
| 13 |
+
assert FAISS_PATH.exists(), f"Missing {FAISS_PATH}. Build & upload embeddings/index."
|
| 14 |
+
assert DATA_PATH.exists(), f"Missing {DATA_PATH}. Build & upload data parquet."
|
| 15 |
|
| 16 |
+
# ---------- Devices ----------
|
| 17 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 18 |
+
DEVICE_EMBED = "cuda" if torch.cuda.is_available() else "cpu" # e5 on GPU if available
|
| 19 |
+
DEVICE_GEN = "cpu" # FLAN on CPU (avoid OOM)
|
| 20 |
+
print(f"Embed device: {DEVICE_EMBED} | Gen device: {DEVICE_GEN}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
+
# ---------- Load artifacts ----------
|
| 23 |
index = faiss.read_index(str(FAISS_PATH))
|
| 24 |
df_local = pd.read_parquet(DATA_PATH)
|
| 25 |
+
for c in ["name","tagline","description"]:
|
| 26 |
+
if c in df_local.columns:
|
| 27 |
+
df_local[c] = df_local[c].astype(str).fillna("")
|
| 28 |
|
| 29 |
+
# ---------- Load models ----------
|
| 30 |
+
EMBED_MODEL = "intfloat/e5-base-v2"
|
| 31 |
+
embed_model = SentenceTransformer(EMBED_MODEL, device=DEVICE_EMBED)
|
| 32 |
+
|
| 33 |
+
MODEL_BASE = "google/flan-t5-base"
|
| 34 |
+
MODEL_LARGE = "google/flan-t5-large"
|
| 35 |
+
USE_LARGE_FOR_DESCRIPTION = False # keep False on Spaces unless you switch GEN to "cuda"
|
| 36 |
|
|
|
|
| 37 |
tok_base = AutoTokenizer.from_pretrained(MODEL_BASE)
|
| 38 |
+
base_kwargs = {"torch_dtype": torch.float16} if DEVICE_GEN == "cuda" else {}
|
| 39 |
+
mod_base = AutoModelForSeq2SeqLM.from_pretrained(MODEL_BASE, **base_kwargs).to(DEVICE_GEN)
|
| 40 |
+
|
| 41 |
if USE_LARGE_FOR_DESCRIPTION:
|
| 42 |
tok_large = AutoTokenizer.from_pretrained(MODEL_LARGE)
|
| 43 |
+
large_kwargs = {"torch_dtype": torch.float16} if DEVICE_GEN == "cuda" else {}
|
| 44 |
+
mod_large = AutoModelForSeq2SeqLM.from_pretrained(MODEL_LARGE, **large_kwargs).to(DEVICE_GEN)
|
| 45 |
else:
|
| 46 |
tok_large, mod_large = tok_base, mod_base
|
| 47 |
|
| 48 |
+
# ---------- Helpers (embedding + generation) ----------
|
|
|
|
| 49 |
def _generate_text(model, tokenizer, prompt, max_new_tokens=30, temperature=0.9, top_p=0.95, num_return_sequences=1):
|
| 50 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE_GEN)
|
| 51 |
outputs = model.generate(
|
| 52 |
**inputs,
|
| 53 |
max_new_tokens=max_new_tokens,
|
|
|
|
| 65 |
texts = [f"passage: {t}" for t in texts]
|
| 66 |
return embed_model.encode(texts, convert_to_numpy=True, normalize_embeddings=True).astype("float32")
|
| 67 |
|
| 68 |
+
# ---------- Search with per-session unlikes ----------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
def search_topk_filtered_session(query: str, k: int, unliked_ids: set):
    """Return the top-k FAISS neighbours of *query*, skipping unliked rows.

    Over-fetches (at least 50, or 20x k, or k + number of exclusions) so that
    filtering out `unliked_ids` still leaves k survivors, then re-ranks 1..k.
    """
    query_vec = _embed_query(query)
    n_fetch = min(index.ntotal, max(k * 20, 50, k + len(unliked_ids)))
    score_mat, idx_mat = index.search(query_vec[None, :], n_fetch)
    hit_ids = idx_mat[0].tolist()
    hit_scores = score_mat[0].tolist()

    hits = df_local.iloc[hit_ids][["name","tagline","description"]].copy()
    hits.insert(0, "row_idx", df_local.iloc[hit_ids].index)
    hits.insert(1, "score", [float(v) for v in hit_scores])

    # drop anything the user has unliked, then keep the best k
    keep_mask = ~hits["row_idx"].isin(unliked_ids)
    hits = hits[keep_mask].head(k).reset_index(drop=True)
    hits.insert(0, "rank", range(1, len(hits) + 1))
    return hits
|
| 80 |
|
| 81 |
+
# ---------- Synthetic generation (length-aware) ----------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
_STOPWORDS = {
|
| 83 |
"the","a","an","for","and","or","to","of","in","on","with","by","from",
|
| 84 |
"my","our","your","their","at","as","about","into","over","under","this","that",
|
|
|
|
| 90 |
def _has_vowel(s: str) -> bool: return bool(re.search(r"[aeiou]", str(s).lower()))
|
| 91 |
def _overlap_ratio(name_tokens, banned):
|
| 92 |
if not name_tokens or not banned: return 0.0
|
| 93 |
+
inter = len(set(name_tokens) & set(banned)); union = len(set(name_tokens) | set(banned))
|
|
|
|
| 94 |
return inter / max(union, 1)
|
| 95 |
|
| 96 |
+
NAME_CHAR_TARGET, NAME_CHAR_TOL = 12, 3
|
| 97 |
+
NAME_WORDS_MIN, NAME_WORDS_MAX = 1, 3
|
| 98 |
def _len_ok(text: str, target_chars: int, tol: int, min_words: int, max_words: int):
|
| 99 |
c = len(text); w = len(text.split())
|
| 100 |
return (target_chars - tol) <= c <= (target_chars + tol) and (min_words <= w <= max_words)
|
| 101 |
|
| 102 |
def _theme_hints(query: str, k: int = 6):
    """Comma-joined list of up to *k* unique content words from *query*.

    Used to steer name generation; falls back to a generic EdTech hint set
    when the query yields no content words.
    """
    # dict.fromkeys dedups while preserving first-seen order
    ordered_unique = list(dict.fromkeys(_content_words(query)))
    if not ordered_unique:
        return "education, learning, students, AI"
    return ", ".join(ordered_unique[:k])
|
| 107 |
|
|
|
|
|
|
|
|
|
|
| 108 |
def generate_names(base_idea: str, n: int = 10, oversample: int = 80, max_retries: int = 3):
    """Ask the base FLAN model for `oversample` brandable names, then dedupe and filter.

    Retries up to `max_retries` times with a slightly hotter temperature and a
    progressively looser acceptance filter; returns up to `n` names (possibly
    fewer, or [] when nothing usable was produced).
    """
    banned = sorted(set(_content_words(base_idea)))
    avoid_str = ", ".join(banned[:12]) if banned else "previous words"
    hints = _theme_hints(base_idea)

    def build_prompt(count):
        return (
            f"Create {count} brandable startup names for this idea:\n"
            f"\"{base_idea}\"\n\n"
            f"Guidance:\n"
            f"- Evoke these themes (without literally using the words): {hints}\n"
            f"- 1 or 2 words; aim ~{NAME_CHAR_TARGET} characters (±{NAME_CHAR_TOL})\n"
            f"- Portmanteau/blends welcome (e.g., Coursera, Udacity, Grammarly)\n"
            f"- Do NOT use: {avoid_str}\n"
            f"- Avoid generic phrases (e.g., 'Plastic Bottles', 'Online Store')\n"
            f"- Output one name per line; no numbering, no quotes."
        )

    def keeps(candidate, overlap_cap, tol_boost):
        # acceptance test; the caps/boosts loosen with each retry attempt
        if not _has_vowel(candidate):
            return False
        if not _len_ok(candidate, NAME_CHAR_TARGET, NAME_CHAR_TOL + tol_boost,
                       NAME_WORDS_MIN, NAME_WORDS_MAX):
            return False
        toks = _content_words(candidate)
        if _overlap_ratio(toks, banned) > overlap_cap:
            return False
        if " ".join(toks) in {"plastic bottles", "bottles plastic"}:
            return False
        return True

    overlap_caps = [0.25, 0.35, 0.5]
    tol_boosts = [0, 1, 2]
    pool = []
    for attempt in range(max_retries):
        raw = _generate_text(mod_base, tok_base, build_prompt(oversample),
                             num_return_sequences=1, max_new_tokens=240,
                             temperature=1.0 + 0.05 * attempt, top_p=0.95)[0]
        # harvest one name per line, stripping bullets/numbering and trailing punctuation
        for line in raw.splitlines():
            candidate = line.strip().lstrip("-•*0123456789. ").strip()
            if candidate:
                candidate = re.sub(r"[^\w\s-]+$", "", candidate).strip()
                pool.append(candidate)
        # dedupe on the normalized form, keeping the first-seen spelling
        deduped, seen_keys = [], set()
        for candidate in pool:
            key = _normalize_name(candidate)
            if key and key not in seen_keys:
                seen_keys.add(key)
                deduped.append(candidate)
        pool = deduped
        stage = min(attempt, 2)
        survivors = [c for c in pool if keeps(c, overlap_caps[stage], tol_boosts[stage])]
        if len(survivors) >= n:
            return survivors[:n]
    return pool[:n] if pool else []
|
| 154 |
|
| 155 |
+
# Tagline/description length targets (from your EDA)
|
| 156 |
+
TAG_CHAR_TARGET, TAG_CHAR_TOL = 40, 6
|
| 157 |
+
TAG_WORD_TARGET, TAG_WORD_TOL = 6, 2
|
| 158 |
+
DESC_CHAR_MIN, DESC_CHAR_MAX = 170, 230
|
| 159 |
+
DESC_WORD_MIN, DESC_WORD_MAX = 27, 35
|
| 160 |
+
|
| 161 |
def _trim_to_words(text: str, max_words: int) -> str:
|
| 162 |
toks = text.split()
|
| 163 |
+
return text.strip() if len(toks) <= max_words else " ".join(toks[:max_words]).rstrip(",;:") + "."
|
|
|
|
| 164 |
|
| 165 |
def _snap_sentence_boundary(text: str, min_chars: int, max_chars: int):
|
| 166 |
text = text.strip()
|
| 167 |
if len(text) <= max_chars and len(text) >= min_chars: return text
|
| 168 |
cutoff = min(max_chars, len(text)); candidate = text[:cutoff]
|
| 169 |
m = re.search(r"[\.!\?](?!.*[\.!\?])", candidate)
|
| 170 |
+
if m and (len(candidate[:m.end()].strip()) >= min_chars): return candidate[:m.end()].strip()
|
|
|
|
| 171 |
return candidate.rstrip(",;: ").strip() + ("." if not candidate.endswith((".", "!", "?")) else "")
|
| 172 |
|
| 173 |
def _within_ranges(text: str, cmin: int, cmax: int, wmin: int, wmax: int) -> bool:
|
| 174 |
c = len(text); w = len(text.split()); return (cmin <= c <= cmax) and (wmin <= w <= wmax)
|
| 175 |
|
| 176 |
def generate_tagline_and_desc(name: str, query_context: str):
|
|
|
|
| 177 |
tag_prompt = (
|
| 178 |
f"Write a short, benefit-driven tagline for a startup called '{name}'. "
|
| 179 |
f"Audience & domain: {query_context}. "
|
|
|
|
| 188 |
TAG_WORD_TARGET - TAG_WORD_TOL, TAG_WORD_TARGET + TAG_WORD_TOL):
|
| 189 |
tagline2 = _generate_text(mod_base, tok_base, tag_prompt, max_new_tokens=30, temperature=1.0, top_p=0.9)[0]
|
| 190 |
tagline2 = _trim_to_words(re.sub(r"\s+", " ", tagline2).strip(), TAG_WORD_TARGET + TAG_WORD_TOL)
|
| 191 |
+
if abs(len(tagline2) - TAG_CHAR_TARGET) < abs(len(tagline) - TAG_CHAR_TARGET): tagline = tagline2
|
|
|
|
| 192 |
|
|
|
|
| 193 |
desc_prompt = (
|
| 194 |
f"Write a concise product description for the startup '{name}'. "
|
| 195 |
f"Context: {query_context}. "
|
|
|
|
| 200 |
model, tok = (mod_large, tok_large) if USE_LARGE_FOR_DESCRIPTION else (mod_base, tok_base)
|
| 201 |
description = _generate_text(model, tok, desc_prompt, max_new_tokens=110, temperature=1.05, top_p=0.95)[0]
|
| 202 |
description = re.sub(r"\s+", " ", description).strip()
|
| 203 |
+
if len(description.split()) > DESC_WORD_MAX: description = _trim_to_words(description, DESC_WORD_MAX)
|
|
|
|
| 204 |
description = _snap_sentence_boundary(description, DESC_CHAR_MIN, DESC_CHAR_MAX)
|
|
|
|
| 205 |
if not _within_ranges(description, DESC_CHAR_MIN, DESC_CHAR_MAX, DESC_WORD_MIN, DESC_WORD_MAX):
|
| 206 |
description2 = _generate_text(model, tok, desc_prompt, max_new_tokens=120, temperature=1.05, top_p=0.9)[0]
|
| 207 |
description2 = re.sub(r"\s+", " ", description2).strip()
|
| 208 |
+
if len(description2.split()) > DESC_WORD_MAX: description2 = _trim_to_words(description2, DESC_WORD_MAX)
|
|
|
|
| 209 |
description2 = _snap_sentence_boundary(description2, DESC_CHAR_MIN, DESC_CHAR_MAX)
|
| 210 |
target_mid = (DESC_CHAR_MIN + DESC_CHAR_MAX) / 2
|
| 211 |
+
if abs(len(description2) - target_mid) < abs(len(description) - target_mid): description = description2
|
|
|
|
|
|
|
| 212 |
return tagline, description
|
| 213 |
|
| 214 |
def pick_best_synthetic_name(query: str, n_candidates: int = 10, include_copy=False):
    """Generate candidate names, score them against the query, return the winner.

    Score = cosine(query, name) - 0.35 * token-overlap - length penalty
    (penalize normalized names shorter than 4 or longer than 16 chars).
    Returns a one-row DataFrame with rank fixed at 4; when include_copy is True
    the row also carries a generated tagline and description.
    """
    names = generate_names(query, n=n_candidates, oversample=max(80, 8 * n_candidates), max_retries=3)
    if len(names) == 0:
        # permissive retry before giving up on the LLM
        names = generate_names(query, n=n_candidates, oversample=140, max_retries=1)
    if len(names) == 0:
        # deterministic suffix-based fallback so we always return something
        toks = _content_words(query) or ["nova", "learn", "edu", "mento"]
        seeds = {t[:4] + "ify" for t in toks} | {t[:3] + "ora" for t in toks} | {t[:4] + "io" for t in toks}
        names = list(seeds)[:n_candidates]

    qv = _embed_query(query)
    cos = _embed_passages(names) @ qv
    banned = sorted(set(_content_words(query)))

    final_scores = []
    for nm, sim in zip(names, cos):
        overlap = _overlap_ratio(_content_words(nm), banned)
        norm_len = len(_normalize_name(nm))
        penalty = 0.0
        if norm_len < 4:
            penalty += 0.3
        if norm_len > 16:
            penalty += 0.2
        final_scores.append(float(sim) - 0.35 * overlap - penalty)

    best = int(np.argmax(final_scores))
    tagline, description = "", ""
    if include_copy:
        tagline, description = generate_tagline_and_desc(names[best], query_context=query)
    return pd.DataFrame([{
        "rank": 4,
        "score": float(final_scores[best]),
        "name": names[best],
        "tagline": tagline,
        "description": description,
    }])
|
| 236 |
|
| 237 |
+
# ---------- UI glue ----------
|
| 238 |
+
EXAMPLES = [
|
| 239 |
+
"AI tool to analyze customer feedback",
|
| 240 |
+
"Social network for jobs",
|
| 241 |
+
"Mobile fintech app for cross-border payments",
|
| 242 |
+
"AI learning tool for students",
|
| 243 |
+
"Marketplace for eco-friendly products",
|
| 244 |
+
]
|
| 245 |
|
| 246 |
def ui_search(query, state_unlikes):
    """Run a fresh top-3 search; a new query always starts with an empty unlike set."""
    cleaned = (query or "").strip()
    if not cleaned:
        return gr.update(value=pd.DataFrame()), state_unlikes, "Please enter a short idea."
    # new search → forget previous unlikes
    table = search_topk_filtered_session(cleaned, k=3, unliked_ids=set())
    return table, [], "Found 3 similar items. You can unlike by row_idx, then Refresh."
|
|
|
|
|
|
|
| 252 |
|
| 253 |
def ui_unlike(query, unlike_ids_csv, state_unlikes):
    """Merge comma-separated row_idx values into the unlike set and re-run the search."""
    cleaned = (query or "").strip()
    if not cleaned:
        return gr.update(value=pd.DataFrame()), state_unlikes, "Enter a query first."
    # accept only purely numeric tokens; anything else is silently ignored
    pieces = [p.strip() for p in (unlike_ids_csv or "").split(",")]
    add_ids = {int(p) for p in pieces if p.isdigit()}
    merged = set(state_unlikes) | add_ids
    table = search_topk_filtered_session(cleaned, k=3, unliked_ids=merged)
    return table, list(merged), f"Excluded {sorted(add_ids)}. Currently unliked: {sorted(merged)}"
|
|
|
|
|
|
|
| 263 |
|
| 264 |
def ui_clear_unlikes(query):
    """Drop every unlike and re-run the current query from scratch."""
    cleaned = (query or "").strip()
    if not cleaned:
        return gr.update(value=pd.DataFrame()), [], "Enter a query first."
    table = search_topk_filtered_session(cleaned, k=3, unliked_ids=set())
    return table, [], "Cleared unlikes."
|
|
|
|
| 269 |
|
| 270 |
def ui_generate_synth(query, include_copy):
    """Produce the synthetic #4 row for the current query."""
    cleaned = (query or "").strip()
    if not cleaned:
        return gr.update(value=pd.DataFrame()), "Enter a query first."
    synth = pick_best_synthetic_name(cleaned, n_candidates=10, include_copy=include_copy)
    return synth, "Generated AI option as #4. Combine it with your top-3."
|
| 275 |
|
| 276 |
+
def _apply_example(example_text, state_unlikes):
    """Fill the query box with a canned example and run the normal search path."""
    table, new_state, msg = ui_search(example_text, state_unlikes)
    return example_text, table, new_state, f"Example selected: “{example_text}”. {msg}"
|
| 279 |
+
|
| 280 |
+
with gr.Blocks(title="Startup Recommender + AI Name") as app:
|
| 281 |
+
gr.Markdown("## Startup Recommender → Unlike → AI Name\nEnter a short idea. Get 3 similar startups, unlike what doesn’t fit, then generate an AI name (and optional tagline & description).")
|
| 282 |
+
|
| 283 |
+
query = gr.Textbox(label="Your idea (short description)", placeholder="e.g., AI tool to analyze student essays and give feedback")
|
| 284 |
|
|
|
|
|
|
|
| 285 |
with gr.Row():
|
| 286 |
gr.Markdown("**Try an example:**")
|
| 287 |
+
example_buttons = [gr.Button(ex, variant="secondary") for ex in EXAMPLES]
|
| 288 |
+
|
|
|
|
| 289 |
with gr.Row():
|
| 290 |
btn_search = gr.Button("Search Top-3")
|
| 291 |
unlike_ids = gr.Textbox(label="Unlike by row_idx (comma-separated)", placeholder="e.g., 123, 456")
|
| 292 |
btn_unlike = gr.Button("Refresh after Unlike")
|
| 293 |
btn_clear = gr.Button("Clear Unlikes")
|
|
|
|
|
|
|
| 294 |
|
| 295 |
+
results_tbl = gr.Dataframe(label="Top-3 Similar (after excludes)", interactive=False, wrap=True)
|
|
|
|
|
|
|
|
|
|
| 296 |
|
| 297 |
+
gr.Markdown("### AI-Generated Option (#4)")
|
| 298 |
+
include_copy = gr.Checkbox(label="Also generate tagline & description", value=True)
|
| 299 |
+
btn_synth = gr.Button("Generate #4 (AI)")
|
| 300 |
synth_tbl = gr.Dataframe(label="Synthetic #4", interactive=False, wrap=True)
|
|
|
|
| 301 |
|
| 302 |
+
status = gr.Markdown("")
|
| 303 |
state_unlikes = gr.State([])
|
| 304 |
|
| 305 |
+
# wiring
|
| 306 |
btn_search.click(ui_search, inputs=[query, state_unlikes], outputs=[results_tbl, state_unlikes, status])
|
| 307 |
btn_unlike.click(ui_unlike, inputs=[query, unlike_ids, state_unlikes], outputs=[results_tbl, state_unlikes, status])
|
| 308 |
btn_clear.click(ui_clear_unlikes, inputs=[query], outputs=[results_tbl, state_unlikes, status])
|
| 309 |
+
|
| 310 |
for btn, ex in zip(example_buttons, EXAMPLES):
|
| 311 |
+
btn.click(lambda st, ex_=ex: _apply_example(ex_, st),
|
| 312 |
+
inputs=[state_unlikes], outputs=[query, results_tbl, state_unlikes, status])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
+
btn_synth.click(ui_generate_synth, inputs=[query, include_copy], outputs=[synth_tbl, status])
|
| 315 |
|
| 316 |
+
# On Spaces, just calling launch() is fine; no explicit port.
|
| 317 |
if __name__ == "__main__":
|
| 318 |
+
app.queue(concurrency_count=2).launch()
|