barakb21 commited on
Commit
3dd33d9
·
verified ·
1 Parent(s): 680855f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +204 -119
app.py CHANGED
@@ -1,53 +1,58 @@
1
- # app.py Startup recommender + Unlike + AI name (optional tagline/description)
2
- import os, re, numpy as np, pandas as pd
 
 
 
 
 
 
3
  from pathlib import Path
4
  import gradio as gr
5
- import torch, faiss
 
 
 
6
  from sentence_transformers import SentenceTransformer
7
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
8
 
9
- # ---------- Paths / artifacts ----------
10
  OUT_DIR = Path("./emb_index_e5")
11
  FAISS_PATH = OUT_DIR / "faiss.index"
12
  DATA_PATH = OUT_DIR / "data.parquet"
13
- assert FAISS_PATH.exists(), f"Missing {FAISS_PATH}. Build & upload embeddings/index."
14
- assert DATA_PATH.exists(), f"Missing {DATA_PATH}. Build & upload data parquet."
15
 
16
- # ---------- Devices ----------
17
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
18
- DEVICE_EMBED = "cuda" if torch.cuda.is_available() else "cpu" # e5 on GPU if available
19
- DEVICE_GEN = "cpu" # FLAN on CPU (avoid OOM)
20
- print(f"Embed device: {DEVICE_EMBED} | Gen device: {DEVICE_GEN}")
 
 
 
 
 
21
 
22
- # ---------- Load artifacts ----------
23
  index = faiss.read_index(str(FAISS_PATH))
24
  df_local = pd.read_parquet(DATA_PATH)
25
- for c in ["name","tagline","description"]:
26
- if c in df_local.columns:
27
- df_local[c] = df_local[c].astype(str).fillna("")
28
 
29
- # ---------- Load models ----------
30
- EMBED_MODEL = "intfloat/e5-base-v2"
31
- embed_model = SentenceTransformer(EMBED_MODEL, device=DEVICE_EMBED)
32
-
33
- MODEL_BASE = "google/flan-t5-base"
34
- MODEL_LARGE = "google/flan-t5-large"
35
- USE_LARGE_FOR_DESCRIPTION = False # keep False on Spaces unless you switch GEN to "cuda"
36
 
 
37
  tok_base = AutoTokenizer.from_pretrained(MODEL_BASE)
38
- base_kwargs = {"torch_dtype": torch.float16} if DEVICE_GEN == "cuda" else {}
39
- mod_base = AutoModelForSeq2SeqLM.from_pretrained(MODEL_BASE, **base_kwargs).to(DEVICE_GEN)
40
-
41
  if USE_LARGE_FOR_DESCRIPTION:
42
  tok_large = AutoTokenizer.from_pretrained(MODEL_LARGE)
43
- large_kwargs = {"torch_dtype": torch.float16} if DEVICE_GEN == "cuda" else {}
44
- mod_large = AutoModelForSeq2SeqLM.from_pretrained(MODEL_LARGE, **large_kwargs).to(DEVICE_GEN)
45
  else:
46
  tok_large, mod_large = tok_base, mod_base
47
 
48
- # ---------- Helpers (embedding + generation) ----------
 
49
  def _generate_text(model, tokenizer, prompt, max_new_tokens=30, temperature=0.9, top_p=0.95, num_return_sequences=1):
50
- inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE_GEN)
51
  outputs = model.generate(
52
  **inputs,
53
  max_new_tokens=max_new_tokens,
@@ -65,20 +70,49 @@ def _embed_passages(texts) -> np.ndarray:
65
  texts = [f"passage: {t}" for t in texts]
66
  return embed_model.encode(texts, convert_to_numpy=True, normalize_embeddings=True).astype("float32")
67
 
68
- # ---------- Search with per-session unlikes ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  def search_topk_filtered_session(query: str, k: int, unliked_ids: set):
 
70
  qv = _embed_query(query)
71
  fetch = min(index.ntotal, max(k * 20, 50, k + len(unliked_ids)))
72
  scores, inds = index.search(qv[None, :], fetch)
73
- inds = inds[0].tolist(); scores = scores[0].tolist()
 
 
74
  res = df_local.iloc[inds][["name","tagline","description"]].copy()
75
  res.insert(0, "row_idx", df_local.iloc[inds].index)
76
  res.insert(1, "score", [float(s) for s in scores])
77
- res = res[~res["row_idx"].isin(unliked_ids)].head(k).reset_index(drop=True)
 
 
 
78
  res.insert(0, "rank", range(1, len(res)+1))
79
  return res
80
 
81
- # ---------- Synthetic generation (length-aware) ----------
 
 
 
 
 
 
82
  _STOPWORDS = {
83
  "the","a","an","for","and","or","to","of","in","on","with","by","from",
84
  "my","our","your","their","at","as","about","into","over","under","this","that",
@@ -90,90 +124,104 @@ def _normalize_name(s: str) -> str: return re.sub(r"[^a-z0-9]+", "", str(s).lowe
90
  def _has_vowel(s: str) -> bool: return bool(re.search(r"[aeiou]", str(s).lower()))
91
  def _overlap_ratio(name_tokens, banned):
92
  if not name_tokens or not banned: return 0.0
93
- inter = len(set(name_tokens) & set(banned)); union = len(set(name_tokens) | set(banned))
 
94
  return inter / max(union, 1)
95
 
96
- NAME_CHAR_TARGET, NAME_CHAR_TOL = 12, 3
97
- NAME_WORDS_MIN, NAME_WORDS_MAX = 1, 2
98
  def _len_ok(text: str, target_chars: int, tol: int, min_words: int, max_words: int):
99
  c = len(text); w = len(text.split())
100
  return (target_chars - tol) <= c <= (target_chars + tol) and (min_words <= w <= max_words)
101
 
102
  def _theme_hints(query: str, k: int = 6):
103
- kws = _content_words(query); seen, hints = set(), []
 
104
  for t in kws:
105
- if t not in seen: hints.append(t); seen.add(t)
 
106
  return ", ".join(hints[:k]) if hints else "education, learning, students, AI"
107
 
 
 
 
108
  def generate_names(base_idea: str, n: int = 10, oversample: int = 80, max_retries: int = 3):
109
  banned = sorted(set(_content_words(base_idea)))
110
  avoid_str = ", ".join(banned[:12]) if banned else "previous words"
111
  hints = _theme_hints(base_idea)
112
- all_candidates = []
113
  def _prompt(osz):
114
  return (
115
  f"Create {osz} brandable startup names for this idea:\n"
116
  f"\"{base_idea}\"\n\n"
117
  f"Guidance:\n"
118
  f"- Evoke these themes (without literally using the words): {hints}\n"
119
- f"- 1 or 2 words; aim ~{NAME_CHAR_TARGET} characters (±{NAME_CHAR_TOL})\n"
120
  f"- Portmanteau/blends welcome (e.g., Coursera, Udacity, Grammarly)\n"
121
  f"- Do NOT use: {avoid_str}\n"
122
  f"- Avoid generic phrases (e.g., 'Plastic Bottles', 'Online Store')\n"
123
  f"- Output one name per line; no numbering, no quotes."
124
  )
 
 
125
  for attempt in range(max_retries):
126
- raw = _generate_text(mod_base, tok_base, _prompt(oversample),
127
- num_return_sequences=1, max_new_tokens=240,
128
- temperature=1.0 + 0.05*attempt, top_p=0.95)[0]
129
- # collect
 
 
 
130
  for line in raw.splitlines():
131
  nm = line.strip().lstrip("-•*0123456789. ").strip()
132
- if nm:
133
- nm = re.sub(r"[^\w\s-]+$", "", nm).strip()
134
- all_candidates.append(nm)
 
 
135
  # dedup
136
  uniq, seen = [], set()
137
  for nm in all_candidates:
138
  key = _normalize_name(nm)
139
- if key and key not in seen:
140
- seen.add(key); uniq.append(nm)
141
  all_candidates = uniq
142
- # progressive filter
 
143
  def ok(nm: str, overlap_cap: float, tol_boost: int):
144
  if not _has_vowel(nm): return False
145
- if not _len_ok(nm, NAME_CHAR_TARGET, NAME_CHAR_TOL+tol_boost, NAME_WORDS_MIN, NAME_WORDS_MAX): return False
 
146
  toks = _content_words(nm)
147
  if _overlap_ratio(toks, banned) > overlap_cap: return False
148
  if " ".join(toks) in {"plastic bottles","bottles plastic"}: return False
149
  return True
150
- overlap_caps = [0.25, 0.35, 0.5]; tol_boosts = [0, 1, 2]
 
 
151
  filtered = [nm for nm in all_candidates if ok(nm, overlap_caps[min(attempt,2)], tol_boosts[min(attempt,2)])]
152
- if len(filtered) >= n: return filtered[:n]
153
- return all_candidates[:n] if all_candidates else []
154
 
155
- # Tagline/description length targets (from your EDA)
156
- TAG_CHAR_TARGET, TAG_CHAR_TOL = 40, 6
157
- TAG_WORD_TARGET, TAG_WORD_TOL = 6, 1
158
- DESC_CHAR_MIN, DESC_CHAR_MAX = 180, 215
159
- DESC_WORD_MIN, DESC_WORD_MAX = 29, 33
160
 
161
  def _trim_to_words(text: str, max_words: int) -> str:
162
  toks = text.split()
163
- return text.strip() if len(toks) <= max_words else " ".join(toks[:max_words]).rstrip(",;:") + "."
 
164
 
165
  def _snap_sentence_boundary(text: str, min_chars: int, max_chars: int):
166
  text = text.strip()
167
  if len(text) <= max_chars and len(text) >= min_chars: return text
168
  cutoff = min(max_chars, len(text)); candidate = text[:cutoff]
169
  m = re.search(r"[\.!\?](?!.*[\.!\?])", candidate)
170
- if m and (len(candidate[:m.end()].strip()) >= min_chars): return candidate[:m.end()].strip()
 
171
  return candidate.rstrip(",;: ").strip() + ("." if not candidate.endswith((".", "!", "?")) else "")
172
 
173
  def _within_ranges(text: str, cmin: int, cmax: int, wmin: int, wmax: int) -> bool:
174
  c = len(text); w = len(text.split()); return (cmin <= c <= cmax) and (wmin <= w <= wmax)
175
 
176
  def generate_tagline_and_desc(name: str, query_context: str):
 
177
  tag_prompt = (
178
  f"Write a short, benefit-driven tagline for a startup called '{name}'. "
179
  f"Audience & domain: {query_context}. "
@@ -188,8 +236,10 @@ def generate_tagline_and_desc(name: str, query_context: str):
188
  TAG_WORD_TARGET - TAG_WORD_TOL, TAG_WORD_TARGET + TAG_WORD_TOL):
189
  tagline2 = _generate_text(mod_base, tok_base, tag_prompt, max_new_tokens=30, temperature=1.0, top_p=0.9)[0]
190
  tagline2 = _trim_to_words(re.sub(r"\s+", " ", tagline2).strip(), TAG_WORD_TARGET + TAG_WORD_TOL)
191
- if abs(len(tagline2) - TAG_CHAR_TARGET) < abs(len(tagline) - TAG_CHAR_TARGET): tagline = tagline2
 
192
 
 
193
  desc_prompt = (
194
  f"Write a concise product description for the startup '{name}'. "
195
  f"Context: {query_context}. "
@@ -198,121 +248,156 @@ def generate_tagline_and_desc(name: str, query_context: str):
198
  f"Avoid fluff; keep it clear."
199
  )
200
  model, tok = (mod_large, tok_large) if USE_LARGE_FOR_DESCRIPTION else (mod_base, tok_base)
201
- description = _generate_text(model, tok, desc_prompt, max_new_tokens=110, temperature=0.95, top_p=0.95)[0]
202
  description = re.sub(r"\s+", " ", description).strip()
203
- if len(description.split()) > DESC_WORD_MAX: description = _trim_to_words(description, DESC_WORD_MAX)
 
204
  description = _snap_sentence_boundary(description, DESC_CHAR_MIN, DESC_CHAR_MAX)
 
205
  if not _within_ranges(description, DESC_CHAR_MIN, DESC_CHAR_MAX, DESC_WORD_MIN, DESC_WORD_MAX):
206
  description2 = _generate_text(model, tok, desc_prompt, max_new_tokens=120, temperature=1.05, top_p=0.9)[0]
207
  description2 = re.sub(r"\s+", " ", description2).strip()
208
- if len(description2.split()) > DESC_WORD_MAX: description2 = _trim_to_words(description2, DESC_WORD_MAX)
 
209
  description2 = _snap_sentence_boundary(description2, DESC_CHAR_MIN, DESC_CHAR_MAX)
210
  target_mid = (DESC_CHAR_MIN + DESC_CHAR_MAX) / 2
211
- if abs(len(description2) - target_mid) < abs(len(description) - target_mid): description = description2
 
 
212
  return tagline, description
213
 
214
  def pick_best_synthetic_name(query: str, n_candidates: int = 10, include_copy=False):
 
215
  names = generate_names(query, n=n_candidates, oversample=max(80, 8*n_candidates), max_retries=3)
216
  if len(names) == 0:
 
217
  names = generate_names(query, n=n_candidates, oversample=140, max_retries=1)
218
  if len(names) == 0:
219
  toks = _content_words(query) or ["nova","learn","edu","mento"]
220
- seeds = list({t[:4]+"ify" for t in toks} | {t[:3]+"ora" for t in toks} | {t[:4]+"io" for t in toks})
221
- names = seeds[:n_candidates]
222
- qv = _embed_query(query); embs = _embed_passages(names); cos = embs @ qv
 
 
 
 
 
 
 
 
223
  banned = sorted(set(_content_words(query)))
224
  final_scores = []
225
  for nm, s in zip(names, cos):
226
- toks = _content_words(nm); overlap = _overlap_ratio(toks, banned)
227
- length_pen = 0.0; L = len(_normalize_name(nm))
228
- if L < 4: length_pen += 0.3
 
 
229
  if L > 16: length_pen += 0.2
230
- final_scores.append(float(s) - 0.35*overlap - length_pen)
231
- best_idx = int(np.argmax(final_scores)); best_name = names[best_idx]; best_score = float(final_scores[best_idx])
 
 
 
 
 
232
  tagline, description = ("","")
233
- if include_copy: tagline, description = generate_tagline_and_desc(best_name, query_context=query)
234
- row = pd.DataFrame([{"rank":4,"score":best_score,"name":best_name,"tagline":tagline,"description":description}])
 
 
 
 
 
 
 
 
235
  return row
236
 
237
- # ---------- UI glue ----------
238
- EXAMPLES = [
239
- "AI tool to analyze customer feedback",
240
- "Social network for jobs",
241
- "Mobile fintech app for cross-border payments",
242
- "AI learning tool for students",
243
- "Marketplace for eco-friendly products",
244
- ]
245
 
246
  def ui_search(query, state_unlikes):
 
247
  query = (query or "").strip()
248
- if not query: return gr.update(value=pd.DataFrame()), state_unlikes, "Please enter a short idea."
249
- state_unlikes = [] # reset for new query
250
- res = search_topk_filtered_session(query, k=3, unliked_ids=set())
251
- return res, state_unlikes, "Found 3 similar items. You can unlike by row_idx, then Refresh."
 
 
252
 
253
  def ui_unlike(query, unlike_ids_csv, state_unlikes):
254
  query = (query or "").strip()
255
- if not query: return gr.update(value=pd.DataFrame()), state_unlikes, "Enter a query first."
 
256
  add_ids = set()
257
  for tok in (unlike_ids_csv or "").split(","):
258
  tok = tok.strip()
259
- if tok.isdigit(): add_ids.add(int(tok))
260
- cur = set(state_unlikes) | add_ids
261
- res = search_topk_filtered_session(query, k=3, unliked_ids=cur)
262
- return res, list(cur), f"Excluded {sorted(add_ids)}. Currently unliked: {sorted(cur)}"
 
 
263
 
264
  def ui_clear_unlikes(query):
265
  query = (query or "").strip()
266
- if not query: return gr.update(value=pd.DataFrame()), [], "Enter a query first."
267
- res = search_topk_filtered_session(query, k=3, unliked_ids=set())
268
- return res, [], "Cleared unlikes."
 
269
 
270
  def ui_generate_synth(query, include_copy):
271
  query = (query or "").strip()
272
- if not query: return gr.update(value=pd.DataFrame()), "Enter a query first."
 
273
  synth = pick_best_synthetic_name(query, n_candidates=10, include_copy=include_copy)
274
  return synth, "Generated AI option as #4. Combine it with your top-3."
275
 
276
- def _apply_example(example_text, state_unlikes):
277
- results, state_unlikes, msg = ui_search(example_text, state_unlikes)
278
- return example_text, results, state_unlikes, f"Example selected: “{example_text}”. {msg}"
279
-
280
- with gr.Blocks(title="Startup Recommender + AI Name") as app:
281
- gr.Markdown("## Startup Recommender → Unlike → AI Name\nEnter a short idea. Get 3 similar startups, unlike what doesn’t fit, then generate an AI name (and optional tagline & description).")
282
-
283
- query = gr.Textbox(label="Your idea (short description)", placeholder="e.g., AI tool to analyze student essays and give feedback")
284
 
 
 
285
  with gr.Row():
286
  gr.Markdown("**Try an example:**")
287
- example_buttons = [gr.Button(ex, variant="secondary") for ex in EXAMPLES]
288
-
 
289
  with gr.Row():
290
  btn_search = gr.Button("Search Top-3")
291
  unlike_ids = gr.Textbox(label="Unlike by row_idx (comma-separated)", placeholder="e.g., 123, 456")
292
  btn_unlike = gr.Button("Refresh after Unlike")
293
  btn_clear = gr.Button("Clear Unlikes")
294
-
295
- results_tbl = gr.Dataframe(label="Top-3 Similar (after excludes)", interactive=False, wrap=True)
296
 
297
  gr.Markdown("### AI-Generated Option (#4)")
298
- include_copy = gr.Checkbox(label="Also generate tagline & description", value=True)
299
- btn_synth = gr.Button("Generate #4 (AI)")
300
- synth_tbl = gr.Dataframe(label="Synthetic #4", interactive=False, wrap=True)
301
 
 
302
  status = gr.Markdown("")
 
 
303
  state_unlikes = gr.State([])
304
 
305
- # wiring
306
  btn_search.click(ui_search, inputs=[query, state_unlikes], outputs=[results_tbl, state_unlikes, status])
307
  btn_unlike.click(ui_unlike, inputs=[query, unlike_ids, state_unlikes], outputs=[results_tbl, state_unlikes, status])
308
  btn_clear.click(ui_clear_unlikes, inputs=[query], outputs=[results_tbl, state_unlikes, status])
309
-
310
  for btn, ex in zip(example_buttons, EXAMPLES):
311
- btn.click(lambda st, ex_=ex: _apply_example(ex_, st),
312
- inputs=[state_unlikes], outputs=[query, results_tbl, state_unlikes, status])
 
 
 
 
313
 
314
- btn_synth.click(ui_generate_synth, inputs=[query, include_copy], outputs=[synth_tbl, status])
315
 
316
- # On Spaces, just calling launch() is fine; no explicit port.
317
  if __name__ == "__main__":
318
- app.queue(concurrency_count=2).launch()
 
1
+ # ========= Simple Interactive UI (Gradio) for: search unlike synthetic #4 =========
2
+ # Requires that you already ran the embedding/index step and have:
3
+ # ./emb_index_e5/faiss.index and ./emb_index_e5/data.parquet
4
+ # Also expects you kept the improved generation functions we wrote earlier
5
+ # (generate_names, generate_tagline_and_desc(name, query_context), pick_best_synthetic_name).
6
+ # If not, this file includes compact versions below.
7
+
8
+ import os, re, random, numpy as np, pandas as pd
9
  from pathlib import Path
10
  import gradio as gr
11
+
12
+ # --- Imports for models ---
13
+ import torch
14
+ import faiss
15
  from sentence_transformers import SentenceTransformer
16
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
17
 
18
+ # --- Paths / config ---
19
  OUT_DIR = Path("./emb_index_e5")
20
  FAISS_PATH = OUT_DIR / "faiss.index"
21
  DATA_PATH = OUT_DIR / "data.parquet"
 
 
22
 
23
+ # Toggle: use flan-t5-large only for description richness (if GPU T4 available)
24
+ USE_LARGE_FOR_DESCRIPTION = True
25
+ MODEL_BASE = "google/flan-t5-base"
26
+ MODEL_LARGE = "google/flan-t5-large"
27
+ EMBED_MODEL = "intfloat/e5-base-v2"
28
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
+
30
+ # --- Load artifacts once ---
31
+ assert FAISS_PATH.exists(), f"Missing FAISS index at {FAISS_PATH}. Run the embedding/index step first."
32
+ assert DATA_PATH.exists(), f"Missing dataset at {DATA_PATH}. Run the embedding/index step first."
33
 
 
34
  index = faiss.read_index(str(FAISS_PATH))
35
  df_local = pd.read_parquet(DATA_PATH)
36
+ for col in ["name","tagline","description"]:
37
+ if col in df_local.columns:
38
+ df_local[col] = df_local[col].astype(str).fillna("")
39
 
40
+ # --- Load embedding model once ---
41
+ embed_model = SentenceTransformer(EMBED_MODEL, device=DEVICE)
 
 
 
 
 
42
 
43
+ # --- Load FLAN models (base + optional large) once ---
44
  tok_base = AutoTokenizer.from_pretrained(MODEL_BASE)
45
+ mod_base = AutoModelForSeq2SeqLM.from_pretrained(MODEL_BASE).to(DEVICE)
 
 
46
  if USE_LARGE_FOR_DESCRIPTION:
47
  tok_large = AutoTokenizer.from_pretrained(MODEL_LARGE)
48
+ mod_large = AutoModelForSeq2SeqLM.from_pretrained(MODEL_LARGE).to(DEVICE)
 
49
  else:
50
  tok_large, mod_large = tok_base, mod_base
51
 
52
+ # ===================== Small helpers reused from your project =====================
53
+
54
  def _generate_text(model, tokenizer, prompt, max_new_tokens=30, temperature=0.9, top_p=0.95, num_return_sequences=1):
55
+ inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
56
  outputs = model.generate(
57
  **inputs,
58
  max_new_tokens=max_new_tokens,
 
70
  texts = [f"passage: {t}" for t in texts]
71
  return embed_model.encode(texts, convert_to_numpy=True, normalize_embeddings=True).astype("float32")
72
 
73
+ # ======== 1-click Examples ========
74
+
75
+ EXAMPLES = [
76
+ "AI tool to analyze customer feedback",
77
+ "Social network for jobs",
78
+ "Mobile fintech app for cross-border payments",
79
+ "AI learning tool for students",
80
+ "Marketplace for eco-friendly products",
81
+ ]
82
+
83
+ # Helper that applies an example: fills the textbox, resets unlikes, runs search
84
+ def _apply_example(example_text, state_unlikes):
85
+ # Reuse your existing ui_search() so behavior stays identical
86
+ results, state_unlikes, msg = ui_search(example_text, state_unlikes)
87
+ # Return: set the query box text, update table & state & status
88
+ return example_text, results, state_unlikes, f"Example selected: “{example_text}” — {msg}"
89
+
90
+ # ---------- Search (with per-session unlikes) ----------
91
  def search_topk_filtered_session(query: str, k: int, unliked_ids: set):
92
+ """Return top-K rows excluding row_idx values in unliked_ids."""
93
  qv = _embed_query(query)
94
  fetch = min(index.ntotal, max(k * 20, 50, k + len(unliked_ids)))
95
  scores, inds = index.search(qv[None, :], fetch)
96
+ inds = inds[0].tolist()
97
+ scores = scores[0].tolist()
98
+
99
  res = df_local.iloc[inds][["name","tagline","description"]].copy()
100
  res.insert(0, "row_idx", df_local.iloc[inds].index)
101
  res.insert(1, "score", [float(s) for s in scores])
102
+
103
+ # filter-out unlikes
104
+ mask = ~res["row_idx"].isin(unliked_ids)
105
+ res = res[mask].head(k).reset_index(drop=True)
106
  res.insert(0, "rank", range(1, len(res)+1))
107
  return res
108
 
109
+ # ---------- Synthetic generation (uses the improved functions) ----------
110
+ # Length targets (from your EDA)
111
+ TAG_CHAR_TARGET, TAG_CHAR_TOL = 40, 6
112
+ TAG_WORD_TARGET, TAG_WORD_TOL = 6, 2
113
+ DESC_CHAR_MIN, DESC_CHAR_MAX = 170, 230
114
+ DESC_WORD_MIN, DESC_WORD_MAX = 27, 35
115
+
116
  _STOPWORDS = {
117
  "the","a","an","for","and","or","to","of","in","on","with","by","from",
118
  "my","our","your","their","at","as","about","into","over","under","this","that",
 
124
  def _has_vowel(s: str) -> bool: return bool(re.search(r"[aeiou]", str(s).lower()))
125
  def _overlap_ratio(name_tokens, banned):
126
  if not name_tokens or not banned: return 0.0
127
+ inter = len(set(name_tokens) & set(banned))
128
+ union = len(set(name_tokens) | set(banned))
129
  return inter / max(union, 1)
130
 
 
 
131
  def _len_ok(text: str, target_chars: int, tol: int, min_words: int, max_words: int):
132
  c = len(text); w = len(text.split())
133
  return (target_chars - tol) <= c <= (target_chars + tol) and (min_words <= w <= max_words)
134
 
135
  def _theme_hints(query: str, k: int = 6):
136
+ kws = _content_words(query)
137
+ seen, hints = set(), []
138
  for t in kws:
139
+ if t not in seen:
140
+ hints.append(t); seen.add(t)
141
  return ", ".join(hints[:k]) if hints else "education, learning, students, AI"
142
 
143
+ NAME_CHAR_TARGET, NAME_CHAR_TOL = 12, 3
144
+ NAME_WORDS_MIN, NAME_WORDS_MAX = 1, 3
145
+
146
  def generate_names(base_idea: str, n: int = 10, oversample: int = 80, max_retries: int = 3):
147
  banned = sorted(set(_content_words(base_idea)))
148
  avoid_str = ", ".join(banned[:12]) if banned else "previous words"
149
  hints = _theme_hints(base_idea)
150
+
151
  def _prompt(osz):
152
  return (
153
  f"Create {osz} brandable startup names for this idea:\n"
154
  f"\"{base_idea}\"\n\n"
155
  f"Guidance:\n"
156
  f"- Evoke these themes (without literally using the words): {hints}\n"
157
+ f"- 1 or 2 or 3 words; aim ~{NAME_CHAR_TARGET} characters total (±{NAME_CHAR_TOL})\n"
158
  f"- Portmanteau/blends welcome (e.g., Coursera, Udacity, Grammarly)\n"
159
  f"- Do NOT use: {avoid_str}\n"
160
  f"- Avoid generic phrases (e.g., 'Plastic Bottles', 'Online Store')\n"
161
  f"- Output one name per line; no numbering, no quotes."
162
  )
163
+
164
+ all_candidates = []
165
  for attempt in range(max_retries):
166
+ raw = _generate_text(
167
+ mod_base, tok_base, _prompt(oversample),
168
+ num_return_sequences=1, max_new_tokens=240,
169
+ temperature=1.0 + 0.05*attempt, top_p=0.95
170
+ )[0]
171
+
172
+ batch = []
173
  for line in raw.splitlines():
174
  nm = line.strip().lstrip("-•*0123456789. ").strip()
175
+ if not nm: continue
176
+ nm = re.sub(r"[^\w\s-]+$", "", nm).strip()
177
+ batch.append(nm)
178
+ all_candidates.extend(batch)
179
+
180
  # dedup
181
  uniq, seen = [], set()
182
  for nm in all_candidates:
183
  key = _normalize_name(nm)
184
+ if not key or key in seen: continue
185
+ seen.add(key); uniq.append(nm)
186
  all_candidates = uniq
187
+
188
+ # progressive filtering
189
  def ok(nm: str, overlap_cap: float, tol_boost: int):
190
  if not _has_vowel(nm): return False
191
+ if not _len_ok(nm, NAME_CHAR_TARGET, NAME_CHAR_TOL+tol_boost, NAME_WORDS_MIN, NAME_WORDS_MAX):
192
+ return False
193
  toks = _content_words(nm)
194
  if _overlap_ratio(toks, banned) > overlap_cap: return False
195
  if " ".join(toks) in {"plastic bottles","bottles plastic"}: return False
196
  return True
197
+
198
+ overlap_caps = [0.25, 0.35, 0.5]
199
+ tol_boosts = [0, 1, 2]
200
  filtered = [nm for nm in all_candidates if ok(nm, overlap_caps[min(attempt,2)], tol_boosts[min(attempt,2)])]
201
+ if len(filtered) >= n:
202
+ return filtered[:n]
203
 
204
+ return all_candidates[:n] if all_candidates else []
 
 
 
 
205
 
206
  def _trim_to_words(text: str, max_words: int) -> str:
207
  toks = text.split()
208
+ if len(toks) <= max_words: return text.strip()
209
+ return " ".join(toks[:max_words]).rstrip(",;:") + "."
210
 
211
  def _snap_sentence_boundary(text: str, min_chars: int, max_chars: int):
212
  text = text.strip()
213
  if len(text) <= max_chars and len(text) >= min_chars: return text
214
  cutoff = min(max_chars, len(text)); candidate = text[:cutoff]
215
  m = re.search(r"[\.!\?](?!.*[\.!\?])", candidate)
216
+ if m and (len(candidate[:m.end()].strip()) >= min_chars):
217
+ return candidate[:m.end()].strip()
218
  return candidate.rstrip(",;: ").strip() + ("." if not candidate.endswith((".", "!", "?")) else "")
219
 
220
  def _within_ranges(text: str, cmin: int, cmax: int, wmin: int, wmax: int) -> bool:
221
  c = len(text); w = len(text.split()); return (cmin <= c <= cmax) and (wmin <= w <= wmax)
222
 
223
  def generate_tagline_and_desc(name: str, query_context: str):
224
+ # Tagline (≈40 chars, ≈6 words)
225
  tag_prompt = (
226
  f"Write a short, benefit-driven tagline for a startup called '{name}'. "
227
  f"Audience & domain: {query_context}. "
 
236
  TAG_WORD_TARGET - TAG_WORD_TOL, TAG_WORD_TARGET + TAG_WORD_TOL):
237
  tagline2 = _generate_text(mod_base, tok_base, tag_prompt, max_new_tokens=30, temperature=1.0, top_p=0.9)[0]
238
  tagline2 = _trim_to_words(re.sub(r"\s+", " ", tagline2).strip(), TAG_WORD_TARGET + TAG_WORD_TOL)
239
+ if abs(len(tagline2) - TAG_CHAR_TARGET) < abs(len(tagline) - TAG_CHAR_TARGET):
240
+ tagline = tagline2
241
 
242
+ # Description (≈170–230 chars, 27–35 words)
243
  desc_prompt = (
244
  f"Write a concise product description for the startup '{name}'. "
245
  f"Context: {query_context}. "
 
248
  f"Avoid fluff; keep it clear."
249
  )
250
  model, tok = (mod_large, tok_large) if USE_LARGE_FOR_DESCRIPTION else (mod_base, tok_base)
251
+ description = _generate_text(model, tok, desc_prompt, max_new_tokens=110, temperature=1.05, top_p=0.95)[0]
252
  description = re.sub(r"\s+", " ", description).strip()
253
+ if len(description.split()) > DESC_WORD_MAX:
254
+ description = _trim_to_words(description, DESC_WORD_MAX)
255
  description = _snap_sentence_boundary(description, DESC_CHAR_MIN, DESC_CHAR_MAX)
256
+
257
  if not _within_ranges(description, DESC_CHAR_MIN, DESC_CHAR_MAX, DESC_WORD_MIN, DESC_WORD_MAX):
258
  description2 = _generate_text(model, tok, desc_prompt, max_new_tokens=120, temperature=1.05, top_p=0.9)[0]
259
  description2 = re.sub(r"\s+", " ", description2).strip()
260
+ if len(description2.split()) > DESC_WORD_MAX:
261
+ description2 = _trim_to_words(description2, DESC_WORD_MAX)
262
  description2 = _snap_sentence_boundary(description2, DESC_CHAR_MIN, DESC_CHAR_MAX)
263
  target_mid = (DESC_CHAR_MIN + DESC_CHAR_MAX) / 2
264
+ if abs(len(description2) - target_mid) < abs(len(description) - target_mid):
265
+ description = description2
266
+
267
  return tagline, description
268
 
269
  def pick_best_synthetic_name(query: str, n_candidates: int = 10, include_copy=False):
270
+ # Generate candidates & score vs query (cosine)
271
  names = generate_names(query, n=n_candidates, oversample=max(80, 8*n_candidates), max_retries=3)
272
  if len(names) == 0:
273
+ # permissive retry
274
  names = generate_names(query, n=n_candidates, oversample=140, max_retries=1)
275
  if len(names) == 0:
276
  toks = _content_words(query) or ["nova","learn","edu","mento"]
277
+ seeds = set()
278
+ for t in toks:
279
+ seeds.add((t[:4] + "ify"))
280
+ seeds.add((t[:3] + "ora"))
281
+ seeds.add((t[:4] + "io"))
282
+ names = list(seeds)[:n_candidates]
283
+
284
+ qv = _embed_query(query)
285
+ embs = _embed_passages(names)
286
+ cos = embs @ qv
287
+
288
  banned = sorted(set(_content_words(query)))
289
  final_scores = []
290
  for nm, s in zip(names, cos):
291
+ toks = _content_words(nm)
292
+ overlap = _overlap_ratio(toks, banned)
293
+ length_pen = 0.0
294
+ L = len(_normalize_name(nm))
295
+ if L < 4: length_pen += 0.3
296
  if L > 16: length_pen += 0.2
297
+ score = float(s) - 0.35*overlap - length_pen
298
+ final_scores.append(score)
299
+
300
+ best_idx = int(np.argmax(final_scores))
301
+ best_name = names[best_idx]
302
+ best_score = float(final_scores[best_idx])
303
+
304
  tagline, description = ("","")
305
+ if include_copy:
306
+ tagline, description = generate_tagline_and_desc(best_name, query_context=query)
307
+
308
+ row = pd.DataFrame([{
309
+ "rank": 4,
310
+ "score": best_score,
311
+ "name": best_name,
312
+ "tagline": tagline,
313
+ "description": description
314
+ }])
315
  return row
316
 
317
+ # ============================= Gradio UI logic =============================
 
 
 
 
 
 
 
318
 
319
  def ui_search(query, state_unlikes):
320
+ """Start a new search, reset unlikes if it's a new query."""
321
  query = (query or "").strip()
322
+ if not query:
323
+ return gr.update(value=pd.DataFrame()), state_unlikes, "Please enter a short idea/description."
324
+ # For a new search, reset unlikes
325
+ state_unlikes = set()
326
+ results = search_topk_filtered_session(query, k=3, unliked_ids=state_unlikes)
327
+ return results, list(state_unlikes), "Found 3 similar items. You can unlike by row_idx, then 'Refresh'."
328
 
329
  def ui_unlike(query, unlike_ids_csv, state_unlikes):
330
  query = (query or "").strip()
331
+ if not query:
332
+ return gr.update(value=pd.DataFrame()), state_unlikes, "Enter a query first."
333
  add_ids = set()
334
  for tok in (unlike_ids_csv or "").split(","):
335
  tok = tok.strip()
336
+ if tok.isdigit():
337
+ add_ids.add(int(tok))
338
+ state_unlikes = set(state_unlikes) | add_ids
339
+ results = search_topk_filtered_session(query, k=3, unliked_ids=state_unlikes)
340
+ msg = f"Excluded {sorted(add_ids)}. Currently unliked: {sorted(state_unlikes)}"
341
+ return results, list(state_unlikes), msg
342
 
343
  def ui_clear_unlikes(query):
344
  query = (query or "").strip()
345
+ if not query:
346
+ return gr.update(value=pd.DataFrame()), [], "Enter a query first."
347
+ results = search_topk_filtered_session(query, k=3, unliked_ids=set())
348
+ return results, [], "Cleared unlikes."
349
 
350
  def ui_generate_synth(query, include_copy):
351
  query = (query or "").strip()
352
+ if not query:
353
+ return gr.update(value=pd.DataFrame()), "Enter a query first."
354
  synth = pick_best_synthetic_name(query, n_candidates=10, include_copy=include_copy)
355
  return synth, "Generated AI option as #4. Combine it with your top-3."
356
 
357
+ with gr.Blocks(title="Startup Recommender + Synthetic Name") as app:
358
+ gr.Markdown("## Startup Recommender → Unlike → AI Name\nEnter a short idea. Get 3 similar startups, unlike what doesn’t fit, then generate an AI name (and optional tagline + description).")
 
 
 
 
 
 
359
 
360
+ with gr.Row():
361
+ query = gr.Textbox(label="Your idea (short description)", placeholder="e.g., AI tool to analyze student essays and give feedback")
362
  with gr.Row():
363
  gr.Markdown("**Try an example:**")
364
+ example_buttons = []
365
+ for ex in EXAMPLES:
366
+ example_buttons.append(gr.Button(ex, variant="secondary", scale=1))
367
  with gr.Row():
368
  btn_search = gr.Button("Search Top-3")
369
  unlike_ids = gr.Textbox(label="Unlike by row_idx (comma-separated)", placeholder="e.g., 123, 456")
370
  btn_unlike = gr.Button("Refresh after Unlike")
371
  btn_clear = gr.Button("Clear Unlikes")
372
+ with gr.Row():
373
+ results_tbl = gr.Dataframe(label="Top-3 Similar (after excludes)", interactive=False, wrap=True)
374
 
375
  gr.Markdown("### AI-Generated Option (#4)")
376
+ with gr.Row():
377
+ include_copy = gr.Checkbox(label="Also generate tagline & description", value=True)
378
+ btn_synth = gr.Button("Generate #4 (AI)")
379
 
380
+ synth_tbl = gr.Dataframe(label="Synthetic #4", interactive=False, wrap=True)
381
  status = gr.Markdown("")
382
+
383
+ # Session state: list of unliked row_idx
384
  state_unlikes = gr.State([])
385
 
386
+ # Wiring
387
  btn_search.click(ui_search, inputs=[query, state_unlikes], outputs=[results_tbl, state_unlikes, status])
388
  btn_unlike.click(ui_unlike, inputs=[query, unlike_ids, state_unlikes], outputs=[results_tbl, state_unlikes, status])
389
  btn_clear.click(ui_clear_unlikes, inputs=[query], outputs=[results_tbl, state_unlikes, status])
390
+ btn_synth.click(ui_generate_synth, inputs=[query, include_copy], outputs=[synth_tbl, status])
391
  for btn, ex in zip(example_buttons, EXAMPLES):
392
+ # When clicked: fill the query box, run search, reset unlikes
393
+ btn.click(
394
+ lambda st, ex_=ex: _apply_example(ex_, st),
395
+ inputs=[state_unlikes], # current unlike state
396
+ outputs=[query, results_tbl, state_unlikes, status]
397
+ )
398
 
399
+
400
 
401
+ # For local dev; on Spaces this is ignored.
402
  if __name__ == "__main__":
403
+ app.launch(server_name="0.0.0.0")