barakb21 commited on
Commit
35bdccf
·
verified ·
1 Parent(s): 3dd33d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -203
app.py CHANGED
@@ -1,58 +1,53 @@
1
- # ========= Simple Interactive UI (Gradio) for: search unlike synthetic #4 =========
2
- # Requires that you already ran the embedding/index step and have:
3
- # ./emb_index_e5/faiss.index and ./emb_index_e5/data.parquet
4
- # Also expects you kept the improved generation functions we wrote earlier
5
- # (generate_names, generate_tagline_and_desc(name, query_context), pick_best_synthetic_name).
6
- # If not, this file includes compact versions below.
7
-
8
- import os, re, random, numpy as np, pandas as pd
9
  from pathlib import Path
10
  import gradio as gr
11
-
12
- # --- Imports for models ---
13
- import torch
14
- import faiss
15
  from sentence_transformers import SentenceTransformer
16
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
17
 
18
- # --- Paths / config ---
19
  OUT_DIR = Path("./emb_index_e5")
20
  FAISS_PATH = OUT_DIR / "faiss.index"
21
  DATA_PATH = OUT_DIR / "data.parquet"
 
 
22
 
23
- # Toggle: use flan-t5-large only for description richness (if GPU T4 available)
24
- USE_LARGE_FOR_DESCRIPTION = True
25
- MODEL_BASE = "google/flan-t5-base"
26
- MODEL_LARGE = "google/flan-t5-large"
27
- EMBED_MODEL = "intfloat/e5-base-v2"
28
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
-
30
- # --- Load artifacts once ---
31
- assert FAISS_PATH.exists(), f"Missing FAISS index at {FAISS_PATH}. Run the embedding/index step first."
32
- assert DATA_PATH.exists(), f"Missing dataset at {DATA_PATH}. Run the embedding/index step first."
33
 
 
34
  index = faiss.read_index(str(FAISS_PATH))
35
  df_local = pd.read_parquet(DATA_PATH)
36
- for col in ["name","tagline","description"]:
37
- if col in df_local.columns:
38
- df_local[col] = df_local[col].astype(str).fillna("")
39
 
40
- # --- Load embedding model once ---
41
- embed_model = SentenceTransformer(EMBED_MODEL, device=DEVICE)
 
 
 
 
 
42
 
43
- # --- Load FLAN models (base + optional large) once ---
44
  tok_base = AutoTokenizer.from_pretrained(MODEL_BASE)
45
- mod_base = AutoModelForSeq2SeqLM.from_pretrained(MODEL_BASE).to(DEVICE)
 
 
46
  if USE_LARGE_FOR_DESCRIPTION:
47
  tok_large = AutoTokenizer.from_pretrained(MODEL_LARGE)
48
- mod_large = AutoModelForSeq2SeqLM.from_pretrained(MODEL_LARGE).to(DEVICE)
 
49
  else:
50
  tok_large, mod_large = tok_base, mod_base
51
 
52
- # ===================== Small helpers reused from your project =====================
53
-
54
  def _generate_text(model, tokenizer, prompt, max_new_tokens=30, temperature=0.9, top_p=0.95, num_return_sequences=1):
55
- inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
56
  outputs = model.generate(
57
  **inputs,
58
  max_new_tokens=max_new_tokens,
@@ -70,49 +65,20 @@ def _embed_passages(texts) -> np.ndarray:
70
  texts = [f"passage: {t}" for t in texts]
71
  return embed_model.encode(texts, convert_to_numpy=True, normalize_embeddings=True).astype("float32")
72
 
73
- # ======== 1-click Examples ========
74
-
75
- EXAMPLES = [
76
- "AI tool to analyze customer feedback",
77
- "Social network for jobs",
78
- "Mobile fintech app for cross-border payments",
79
- "AI learning tool for students",
80
- "Marketplace for eco-friendly products",
81
- ]
82
-
83
- # Helper that applies an example: fills the textbox, resets unlikes, runs search
84
- def _apply_example(example_text, state_unlikes):
85
- # Reuse your existing ui_search() so behavior stays identical
86
- results, state_unlikes, msg = ui_search(example_text, state_unlikes)
87
- # Return: set the query box text, update table & state & status
88
- return example_text, results, state_unlikes, f"Example selected: “{example_text}” — {msg}"
89
-
90
- # ---------- Search (with per-session unlikes) ----------
91
  def search_topk_filtered_session(query: str, k: int, unliked_ids: set):
92
- """Return top-K rows excluding row_idx values in unliked_ids."""
93
  qv = _embed_query(query)
94
  fetch = min(index.ntotal, max(k * 20, 50, k + len(unliked_ids)))
95
  scores, inds = index.search(qv[None, :], fetch)
96
- inds = inds[0].tolist()
97
- scores = scores[0].tolist()
98
-
99
  res = df_local.iloc[inds][["name","tagline","description"]].copy()
100
  res.insert(0, "row_idx", df_local.iloc[inds].index)
101
  res.insert(1, "score", [float(s) for s in scores])
102
-
103
- # filter-out unlikes
104
- mask = ~res["row_idx"].isin(unliked_ids)
105
- res = res[mask].head(k).reset_index(drop=True)
106
  res.insert(0, "rank", range(1, len(res)+1))
107
  return res
108
 
109
- # ---------- Synthetic generation (uses the improved functions) ----------
110
- # Length targets (from your EDA)
111
- TAG_CHAR_TARGET, TAG_CHAR_TOL = 40, 6
112
- TAG_WORD_TARGET, TAG_WORD_TOL = 6, 2
113
- DESC_CHAR_MIN, DESC_CHAR_MAX = 170, 230
114
- DESC_WORD_MIN, DESC_WORD_MAX = 27, 35
115
-
116
  _STOPWORDS = {
117
  "the","a","an","for","and","or","to","of","in","on","with","by","from",
118
  "my","our","your","their","at","as","about","into","over","under","this","that",
@@ -124,104 +90,90 @@ def _normalize_name(s: str) -> str: return re.sub(r"[^a-z0-9]+", "", str(s).lowe
124
  def _has_vowel(s: str) -> bool: return bool(re.search(r"[aeiou]", str(s).lower()))
125
  def _overlap_ratio(name_tokens, banned):
126
  if not name_tokens or not banned: return 0.0
127
- inter = len(set(name_tokens) & set(banned))
128
- union = len(set(name_tokens) | set(banned))
129
  return inter / max(union, 1)
130
 
 
 
131
  def _len_ok(text: str, target_chars: int, tol: int, min_words: int, max_words: int):
132
  c = len(text); w = len(text.split())
133
  return (target_chars - tol) <= c <= (target_chars + tol) and (min_words <= w <= max_words)
134
 
135
  def _theme_hints(query: str, k: int = 6):
136
- kws = _content_words(query)
137
- seen, hints = set(), []
138
  for t in kws:
139
- if t not in seen:
140
- hints.append(t); seen.add(t)
141
  return ", ".join(hints[:k]) if hints else "education, learning, students, AI"
142
 
143
- NAME_CHAR_TARGET, NAME_CHAR_TOL = 12, 3
144
- NAME_WORDS_MIN, NAME_WORDS_MAX = 1, 3
145
-
146
  def generate_names(base_idea: str, n: int = 10, oversample: int = 80, max_retries: int = 3):
147
  banned = sorted(set(_content_words(base_idea)))
148
  avoid_str = ", ".join(banned[:12]) if banned else "previous words"
149
  hints = _theme_hints(base_idea)
150
-
151
  def _prompt(osz):
152
  return (
153
  f"Create {osz} brandable startup names for this idea:\n"
154
  f"\"{base_idea}\"\n\n"
155
  f"Guidance:\n"
156
  f"- Evoke these themes (without literally using the words): {hints}\n"
157
- f"- 1 or 2 or 3 words; aim ~{NAME_CHAR_TARGET} characters total (±{NAME_CHAR_TOL})\n"
158
  f"- Portmanteau/blends welcome (e.g., Coursera, Udacity, Grammarly)\n"
159
  f"- Do NOT use: {avoid_str}\n"
160
  f"- Avoid generic phrases (e.g., 'Plastic Bottles', 'Online Store')\n"
161
  f"- Output one name per line; no numbering, no quotes."
162
  )
163
-
164
- all_candidates = []
165
  for attempt in range(max_retries):
166
- raw = _generate_text(
167
- mod_base, tok_base, _prompt(oversample),
168
- num_return_sequences=1, max_new_tokens=240,
169
- temperature=1.0 + 0.05*attempt, top_p=0.95
170
- )[0]
171
-
172
- batch = []
173
  for line in raw.splitlines():
174
  nm = line.strip().lstrip("-•*0123456789. ").strip()
175
- if not nm: continue
176
- nm = re.sub(r"[^\w\s-]+$", "", nm).strip()
177
- batch.append(nm)
178
- all_candidates.extend(batch)
179
-
180
  # dedup
181
  uniq, seen = [], set()
182
  for nm in all_candidates:
183
  key = _normalize_name(nm)
184
- if not key or key in seen: continue
185
- seen.add(key); uniq.append(nm)
186
  all_candidates = uniq
187
-
188
- # progressive filtering
189
  def ok(nm: str, overlap_cap: float, tol_boost: int):
190
  if not _has_vowel(nm): return False
191
- if not _len_ok(nm, NAME_CHAR_TARGET, NAME_CHAR_TOL+tol_boost, NAME_WORDS_MIN, NAME_WORDS_MAX):
192
- return False
193
  toks = _content_words(nm)
194
  if _overlap_ratio(toks, banned) > overlap_cap: return False
195
  if " ".join(toks) in {"plastic bottles","bottles plastic"}: return False
196
  return True
197
-
198
- overlap_caps = [0.25, 0.35, 0.5]
199
- tol_boosts = [0, 1, 2]
200
  filtered = [nm for nm in all_candidates if ok(nm, overlap_caps[min(attempt,2)], tol_boosts[min(attempt,2)])]
201
- if len(filtered) >= n:
202
- return filtered[:n]
203
-
204
  return all_candidates[:n] if all_candidates else []
205
 
 
 
 
 
 
 
206
  def _trim_to_words(text: str, max_words: int) -> str:
207
  toks = text.split()
208
- if len(toks) <= max_words: return text.strip()
209
- return " ".join(toks[:max_words]).rstrip(",;:") + "."
210
 
211
  def _snap_sentence_boundary(text: str, min_chars: int, max_chars: int):
212
  text = text.strip()
213
  if len(text) <= max_chars and len(text) >= min_chars: return text
214
  cutoff = min(max_chars, len(text)); candidate = text[:cutoff]
215
  m = re.search(r"[\.!\?](?!.*[\.!\?])", candidate)
216
- if m and (len(candidate[:m.end()].strip()) >= min_chars):
217
- return candidate[:m.end()].strip()
218
  return candidate.rstrip(",;: ").strip() + ("." if not candidate.endswith((".", "!", "?")) else "")
219
 
220
  def _within_ranges(text: str, cmin: int, cmax: int, wmin: int, wmax: int) -> bool:
221
  c = len(text); w = len(text.split()); return (cmin <= c <= cmax) and (wmin <= w <= wmax)
222
 
223
  def generate_tagline_and_desc(name: str, query_context: str):
224
- # Tagline (≈40 chars, ≈6 words)
225
  tag_prompt = (
226
  f"Write a short, benefit-driven tagline for a startup called '{name}'. "
227
  f"Audience & domain: {query_context}. "
@@ -236,10 +188,8 @@ def generate_tagline_and_desc(name: str, query_context: str):
236
  TAG_WORD_TARGET - TAG_WORD_TOL, TAG_WORD_TARGET + TAG_WORD_TOL):
237
  tagline2 = _generate_text(mod_base, tok_base, tag_prompt, max_new_tokens=30, temperature=1.0, top_p=0.9)[0]
238
  tagline2 = _trim_to_words(re.sub(r"\s+", " ", tagline2).strip(), TAG_WORD_TARGET + TAG_WORD_TOL)
239
- if abs(len(tagline2) - TAG_CHAR_TARGET) < abs(len(tagline) - TAG_CHAR_TARGET):
240
- tagline = tagline2
241
 
242
- # Description (≈170–230 chars, 27–35 words)
243
  desc_prompt = (
244
  f"Write a concise product description for the startup '{name}'. "
245
  f"Context: {query_context}. "
@@ -250,154 +200,119 @@ def generate_tagline_and_desc(name: str, query_context: str):
250
  model, tok = (mod_large, tok_large) if USE_LARGE_FOR_DESCRIPTION else (mod_base, tok_base)
251
  description = _generate_text(model, tok, desc_prompt, max_new_tokens=110, temperature=1.05, top_p=0.95)[0]
252
  description = re.sub(r"\s+", " ", description).strip()
253
- if len(description.split()) > DESC_WORD_MAX:
254
- description = _trim_to_words(description, DESC_WORD_MAX)
255
  description = _snap_sentence_boundary(description, DESC_CHAR_MIN, DESC_CHAR_MAX)
256
-
257
  if not _within_ranges(description, DESC_CHAR_MIN, DESC_CHAR_MAX, DESC_WORD_MIN, DESC_WORD_MAX):
258
  description2 = _generate_text(model, tok, desc_prompt, max_new_tokens=120, temperature=1.05, top_p=0.9)[0]
259
  description2 = re.sub(r"\s+", " ", description2).strip()
260
- if len(description2.split()) > DESC_WORD_MAX:
261
- description2 = _trim_to_words(description2, DESC_WORD_MAX)
262
  description2 = _snap_sentence_boundary(description2, DESC_CHAR_MIN, DESC_CHAR_MAX)
263
  target_mid = (DESC_CHAR_MIN + DESC_CHAR_MAX) / 2
264
- if abs(len(description2) - target_mid) < abs(len(description) - target_mid):
265
- description = description2
266
-
267
  return tagline, description
268
 
269
  def pick_best_synthetic_name(query: str, n_candidates: int = 10, include_copy=False):
270
- # Generate candidates & score vs query (cosine)
271
  names = generate_names(query, n=n_candidates, oversample=max(80, 8*n_candidates), max_retries=3)
272
  if len(names) == 0:
273
- # permissive retry
274
  names = generate_names(query, n=n_candidates, oversample=140, max_retries=1)
275
  if len(names) == 0:
276
  toks = _content_words(query) or ["nova","learn","edu","mento"]
277
- seeds = set()
278
- for t in toks:
279
- seeds.add((t[:4] + "ify"))
280
- seeds.add((t[:3] + "ora"))
281
- seeds.add((t[:4] + "io"))
282
- names = list(seeds)[:n_candidates]
283
-
284
- qv = _embed_query(query)
285
- embs = _embed_passages(names)
286
- cos = embs @ qv
287
-
288
  banned = sorted(set(_content_words(query)))
289
  final_scores = []
290
  for nm, s in zip(names, cos):
291
- toks = _content_words(nm)
292
- overlap = _overlap_ratio(toks, banned)
293
- length_pen = 0.0
294
- L = len(_normalize_name(nm))
295
- if L < 4: length_pen += 0.3
296
  if L > 16: length_pen += 0.2
297
- score = float(s) - 0.35*overlap - length_pen
298
- final_scores.append(score)
299
-
300
- best_idx = int(np.argmax(final_scores))
301
- best_name = names[best_idx]
302
- best_score = float(final_scores[best_idx])
303
-
304
  tagline, description = ("","")
305
- if include_copy:
306
- tagline, description = generate_tagline_and_desc(best_name, query_context=query)
307
-
308
- row = pd.DataFrame([{
309
- "rank": 4,
310
- "score": best_score,
311
- "name": best_name,
312
- "tagline": tagline,
313
- "description": description
314
- }])
315
  return row
316
 
317
- # ============================= Gradio UI logic =============================
 
 
 
 
 
 
 
318
 
319
  def ui_search(query, state_unlikes):
320
- """Start a new search, reset unlikes if it's a new query."""
321
  query = (query or "").strip()
322
- if not query:
323
- return gr.update(value=pd.DataFrame()), state_unlikes, "Please enter a short idea/description."
324
- # For a new search, reset unlikes
325
- state_unlikes = set()
326
- results = search_topk_filtered_session(query, k=3, unliked_ids=state_unlikes)
327
- return results, list(state_unlikes), "Found 3 similar items. You can unlike by row_idx, then 'Refresh'."
328
 
329
  def ui_unlike(query, unlike_ids_csv, state_unlikes):
330
  query = (query or "").strip()
331
- if not query:
332
- return gr.update(value=pd.DataFrame()), state_unlikes, "Enter a query first."
333
  add_ids = set()
334
  for tok in (unlike_ids_csv or "").split(","):
335
  tok = tok.strip()
336
- if tok.isdigit():
337
- add_ids.add(int(tok))
338
- state_unlikes = set(state_unlikes) | add_ids
339
- results = search_topk_filtered_session(query, k=3, unliked_ids=state_unlikes)
340
- msg = f"Excluded {sorted(add_ids)}. Currently unliked: {sorted(state_unlikes)}"
341
- return results, list(state_unlikes), msg
342
 
343
  def ui_clear_unlikes(query):
344
  query = (query or "").strip()
345
- if not query:
346
- return gr.update(value=pd.DataFrame()), [], "Enter a query first."
347
- results = search_topk_filtered_session(query, k=3, unliked_ids=set())
348
- return results, [], "Cleared unlikes."
349
 
350
  def ui_generate_synth(query, include_copy):
351
  query = (query or "").strip()
352
- if not query:
353
- return gr.update(value=pd.DataFrame()), "Enter a query first."
354
  synth = pick_best_synthetic_name(query, n_candidates=10, include_copy=include_copy)
355
  return synth, "Generated AI option as #4. Combine it with your top-3."
356
 
357
- with gr.Blocks(title="Startup Recommender + Synthetic Name") as app:
358
- gr.Markdown("## Startup Recommender → Unlike → AI Name\nEnter a short idea. Get 3 similar startups, unlike what doesn’t fit, then generate an AI name (and optional tagline + description).")
 
 
 
 
 
 
359
 
360
- with gr.Row():
361
- query = gr.Textbox(label="Your idea (short description)", placeholder="e.g., AI tool to analyze student essays and give feedback")
362
  with gr.Row():
363
  gr.Markdown("**Try an example:**")
364
- example_buttons = []
365
- for ex in EXAMPLES:
366
- example_buttons.append(gr.Button(ex, variant="secondary", scale=1))
367
  with gr.Row():
368
  btn_search = gr.Button("Search Top-3")
369
  unlike_ids = gr.Textbox(label="Unlike by row_idx (comma-separated)", placeholder="e.g., 123, 456")
370
  btn_unlike = gr.Button("Refresh after Unlike")
371
  btn_clear = gr.Button("Clear Unlikes")
372
- with gr.Row():
373
- results_tbl = gr.Dataframe(label="Top-3 Similar (after excludes)", interactive=False, wrap=True)
374
 
375
- gr.Markdown("### AI-Generated Option (#4)")
376
- with gr.Row():
377
- include_copy = gr.Checkbox(label="Also generate tagline & description", value=True)
378
- btn_synth = gr.Button("Generate #4 (AI)")
379
 
 
 
 
380
  synth_tbl = gr.Dataframe(label="Synthetic #4", interactive=False, wrap=True)
381
- status = gr.Markdown("")
382
 
383
- # Session state: list of unliked row_idx
384
  state_unlikes = gr.State([])
385
 
386
- # Wiring
387
  btn_search.click(ui_search, inputs=[query, state_unlikes], outputs=[results_tbl, state_unlikes, status])
388
  btn_unlike.click(ui_unlike, inputs=[query, unlike_ids, state_unlikes], outputs=[results_tbl, state_unlikes, status])
389
  btn_clear.click(ui_clear_unlikes, inputs=[query], outputs=[results_tbl, state_unlikes, status])
390
- btn_synth.click(ui_generate_synth, inputs=[query, include_copy], outputs=[synth_tbl, status])
391
  for btn, ex in zip(example_buttons, EXAMPLES):
392
- # When clicked: fill the query box, run search, reset unlikes
393
- btn.click(
394
- lambda st, ex_=ex: _apply_example(ex_, st),
395
- inputs=[state_unlikes], # current unlike state
396
- outputs=[query, results_tbl, state_unlikes, status]
397
- )
398
 
399
-
400
 
401
- # For local dev; on Spaces this is ignored.
402
  if __name__ == "__main__":
403
- app.launch(server_name="0.0.0.0")
 
1
# app.py Startup recommender + Unlike + AI name (optional tagline/description)
import os, re, numpy as np, pandas as pd
from pathlib import Path
import gradio as gr
import torch, faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# ---------- Paths / artifacts ----------
OUT_DIR = Path("./emb_index_e5")
FAISS_PATH = OUT_DIR / "faiss.index"
DATA_PATH = OUT_DIR / "data.parquet"
assert FAISS_PATH.exists(), f"Missing {FAISS_PATH}. Build & upload embeddings/index."
assert DATA_PATH.exists(), f"Missing {DATA_PATH}. Build & upload data parquet."

# ---------- Devices ----------
# Reduce CUDA memory fragmentation (helps on small Spaces GPUs).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
DEVICE_EMBED = "cuda" if torch.cuda.is_available() else "cpu"  # e5 on GPU if available
DEVICE_GEN = "cpu"  # FLAN on CPU (avoid OOM)
print(f"Embed device: {DEVICE_EMBED} | Gen device: {DEVICE_GEN}")

# ---------- Load artifacts ----------
index = faiss.read_index(str(FAISS_PATH))
df_local = pd.read_parquet(DATA_PATH)
for c in ["name","tagline","description"]:
    if c in df_local.columns:
        # BUGFIX: fill missing values BEFORE casting to str. The previous
        # astype(str).fillna("") turned NaN into the literal string "nan"
        # first, so fillna("") never fired.
        df_local[c] = df_local[c].fillna("").astype(str)

# ---------- Load models ----------
EMBED_MODEL = "intfloat/e5-base-v2"
embed_model = SentenceTransformer(EMBED_MODEL, device=DEVICE_EMBED)

MODEL_BASE = "google/flan-t5-base"
MODEL_LARGE = "google/flan-t5-large"
USE_LARGE_FOR_DESCRIPTION = False  # keep False on Spaces unless you switch GEN to "cuda"

tok_base = AutoTokenizer.from_pretrained(MODEL_BASE)
# fp16 only makes sense on GPU; the CPU path stays fp32.
base_kwargs = {"torch_dtype": torch.float16} if DEVICE_GEN == "cuda" else {}
mod_base = AutoModelForSeq2SeqLM.from_pretrained(MODEL_BASE, **base_kwargs).to(DEVICE_GEN)

if USE_LARGE_FOR_DESCRIPTION:
    tok_large = AutoTokenizer.from_pretrained(MODEL_LARGE)
    large_kwargs = {"torch_dtype": torch.float16} if DEVICE_GEN == "cuda" else {}
    mod_large = AutoModelForSeq2SeqLM.from_pretrained(MODEL_LARGE, **large_kwargs).to(DEVICE_GEN)
else:
    # Alias to base so downstream code can always reference tok_large/mod_large.
    tok_large, mod_large = tok_base, mod_base
48
+ # ---------- Helpers (embedding + generation) ----------
 
49
  def _generate_text(model, tokenizer, prompt, max_new_tokens=30, temperature=0.9, top_p=0.95, num_return_sequences=1):
50
+ inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE_GEN)
51
  outputs = model.generate(
52
  **inputs,
53
  max_new_tokens=max_new_tokens,
 
65
  texts = [f"passage: {t}" for t in texts]
66
  return embed_model.encode(texts, convert_to_numpy=True, normalize_embeddings=True).astype("float32")
67
 
68
+ # ---------- Search with per-session unlikes ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
def search_topk_filtered_session(query: str, k: int, unliked_ids: set):
    """Return the top-k nearest rows for `query`, excluding row_idx values in `unliked_ids`.

    Over-fetches from FAISS so that, after filtering out unliked rows,
    at least k results normally remain.
    """
    qv = _embed_query(query)
    # Fetch extra candidates so excluded rows don't leave us short of k.
    n_fetch = min(index.ntotal, max(k * 20, 50, k + len(unliked_ids)))
    sim, idx = index.search(qv[None, :], n_fetch)
    hits = df_local.iloc[idx[0]]
    out = hits[["name","tagline","description"]].copy()
    out.insert(0, "row_idx", hits.index)
    out.insert(1, "score", [float(v) for v in sim[0]])
    # Drop unliked rows, keep the best k, and number them 1..k.
    out = out[~out["row_idx"].isin(unliked_ids)].head(k).reset_index(drop=True)
    out.insert(0, "rank", range(1, len(out) + 1))
    return out
 
81
+ # ---------- Synthetic generation (length-aware) ----------
 
 
 
 
 
 
82
  _STOPWORDS = {
83
  "the","a","an","for","and","or","to","of","in","on","with","by","from",
84
  "my","our","your","their","at","as","about","into","over","under","this","that",
 
90
  def _has_vowel(s: str) -> bool: return bool(re.search(r"[aeiou]", str(s).lower()))
91
  def _overlap_ratio(name_tokens, banned):
92
  if not name_tokens or not banned: return 0.0
93
+ inter = len(set(name_tokens) & set(banned)); union = len(set(name_tokens) | set(banned))
 
94
  return inter / max(union, 1)
95
 
96
+ NAME_CHAR_TARGET, NAME_CHAR_TOL = 12, 3
97
+ NAME_WORDS_MIN, NAME_WORDS_MAX = 1, 3
98
  def _len_ok(text: str, target_chars: int, tol: int, min_words: int, max_words: int):
99
  c = len(text); w = len(text.split())
100
  return (target_chars - tol) <= c <= (target_chars + tol) and (min_words <= w <= max_words)
101
 
102
def _theme_hints(query: str, k: int = 6):
    """Comma-joined string of up to k unique content words from query
    (first-seen order preserved); generic fallback when none exist."""
    ordered_unique = list(dict.fromkeys(_content_words(query)))
    if not ordered_unique:
        return "education, learning, students, AI"
    return ", ".join(ordered_unique[:k])
107
 
 
 
 
108
def generate_names(base_idea: str, n: int = 10, oversample: int = 80, max_retries: int = 3):
    """Generate up to n brandable startup names for base_idea with FLAN-T5.

    Asks the base model for `oversample` candidates per attempt, then dedups
    and filters; each retry samples slightly hotter and relaxes the
    length/overlap constraints so the function degrades gracefully.
    Returns a (possibly shorter) list of names; [] if nothing usable.
    """
    # Words from the idea itself are banned so names don't parrot the query.
    banned = sorted(set(_content_words(base_idea)))
    avoid_str = ", ".join(banned[:12]) if banned else "previous words"
    hints = _theme_hints(base_idea)
    all_candidates = []
    def _prompt(osz):
        # One name per line keeps downstream parsing trivial.
        return (
            f"Create {osz} brandable startup names for this idea:\n"
            f"\"{base_idea}\"\n\n"
            f"Guidance:\n"
            f"- Evoke these themes (without literally using the words): {hints}\n"
            f"- 1 or 2 words; aim ~{NAME_CHAR_TARGET} characters (±{NAME_CHAR_TOL})\n"
            f"- Portmanteau/blends welcome (e.g., Coursera, Udacity, Grammarly)\n"
            f"- Do NOT use: {avoid_str}\n"
            f"- Avoid generic phrases (e.g., 'Plastic Bottles', 'Online Store')\n"
            f"- Output one name per line; no numbering, no quotes."
        )
    for attempt in range(max_retries):
        # Slightly hotter sampling on each retry for more variety.
        raw = _generate_text(mod_base, tok_base, _prompt(oversample),
                             num_return_sequences=1, max_new_tokens=240,
                             temperature=1.0 + 0.05*attempt, top_p=0.95)[0]
        # collect: strip bullet/numbering prefixes and trailing punctuation
        for line in raw.splitlines():
            nm = line.strip().lstrip("-•*0123456789. ").strip()
            if nm:
                nm = re.sub(r"[^\w\s-]+$", "", nm).strip()
                all_candidates.append(nm)
        # dedup (case/punctuation-insensitive via _normalize_name), first spelling wins
        uniq, seen = [], set()
        for nm in all_candidates:
            key = _normalize_name(nm)
            if key and key not in seen:
                seen.add(key); uniq.append(nm)
        all_candidates = uniq
        # progressive filter: overlap cap and length tolerance loosen per attempt
        def ok(nm: str, overlap_cap: float, tol_boost: int):
            if not _has_vowel(nm): return False
            if not _len_ok(nm, NAME_CHAR_TARGET, NAME_CHAR_TOL+tol_boost, NAME_WORDS_MIN, NAME_WORDS_MAX): return False
            toks = _content_words(nm)
            if _overlap_ratio(toks, banned) > overlap_cap: return False
            if " ".join(toks) in {"plastic bottles","bottles plastic"}: return False
            return True
        overlap_caps = [0.25, 0.35, 0.5]; tol_boosts = [0, 1, 2]
        filtered = [nm for nm in all_candidates if ok(nm, overlap_caps[min(attempt,2)], tol_boosts[min(attempt,2)])]
        if len(filtered) >= n: return filtered[:n]
    # All retries exhausted: fall back to the unfiltered (deduped) candidates.
    return all_candidates[:n] if all_candidates else []
154
 
155
+ # Tagline/description length targets (from your EDA)
156
+ TAG_CHAR_TARGET, TAG_CHAR_TOL = 40, 6
157
+ TAG_WORD_TARGET, TAG_WORD_TOL = 6, 2
158
+ DESC_CHAR_MIN, DESC_CHAR_MAX = 170, 230
159
+ DESC_WORD_MIN, DESC_WORD_MAX = 27, 35
160
+
161
  def _trim_to_words(text: str, max_words: int) -> str:
162
  toks = text.split()
163
+ return text.strip() if len(toks) <= max_words else " ".join(toks[:max_words]).rstrip(",;:") + "."
 
164
 
165
  def _snap_sentence_boundary(text: str, min_chars: int, max_chars: int):
166
  text = text.strip()
167
  if len(text) <= max_chars and len(text) >= min_chars: return text
168
  cutoff = min(max_chars, len(text)); candidate = text[:cutoff]
169
  m = re.search(r"[\.!\?](?!.*[\.!\?])", candidate)
170
+ if m and (len(candidate[:m.end()].strip()) >= min_chars): return candidate[:m.end()].strip()
 
171
  return candidate.rstrip(",;: ").strip() + ("." if not candidate.endswith((".", "!", "?")) else "")
172
 
173
  def _within_ranges(text: str, cmin: int, cmax: int, wmin: int, wmax: int) -> bool:
174
  c = len(text); w = len(text.split()); return (cmin <= c <= cmax) and (wmin <= w <= wmax)
175
 
176
  def generate_tagline_and_desc(name: str, query_context: str):
 
177
  tag_prompt = (
178
  f"Write a short, benefit-driven tagline for a startup called '{name}'. "
179
  f"Audience & domain: {query_context}. "
 
188
  TAG_WORD_TARGET - TAG_WORD_TOL, TAG_WORD_TARGET + TAG_WORD_TOL):
189
  tagline2 = _generate_text(mod_base, tok_base, tag_prompt, max_new_tokens=30, temperature=1.0, top_p=0.9)[0]
190
  tagline2 = _trim_to_words(re.sub(r"\s+", " ", tagline2).strip(), TAG_WORD_TARGET + TAG_WORD_TOL)
191
+ if abs(len(tagline2) - TAG_CHAR_TARGET) < abs(len(tagline) - TAG_CHAR_TARGET): tagline = tagline2
 
192
 
 
193
  desc_prompt = (
194
  f"Write a concise product description for the startup '{name}'. "
195
  f"Context: {query_context}. "
 
200
  model, tok = (mod_large, tok_large) if USE_LARGE_FOR_DESCRIPTION else (mod_base, tok_base)
201
  description = _generate_text(model, tok, desc_prompt, max_new_tokens=110, temperature=1.05, top_p=0.95)[0]
202
  description = re.sub(r"\s+", " ", description).strip()
203
+ if len(description.split()) > DESC_WORD_MAX: description = _trim_to_words(description, DESC_WORD_MAX)
 
204
  description = _snap_sentence_boundary(description, DESC_CHAR_MIN, DESC_CHAR_MAX)
 
205
  if not _within_ranges(description, DESC_CHAR_MIN, DESC_CHAR_MAX, DESC_WORD_MIN, DESC_WORD_MAX):
206
  description2 = _generate_text(model, tok, desc_prompt, max_new_tokens=120, temperature=1.05, top_p=0.9)[0]
207
  description2 = re.sub(r"\s+", " ", description2).strip()
208
+ if len(description2.split()) > DESC_WORD_MAX: description2 = _trim_to_words(description2, DESC_WORD_MAX)
 
209
  description2 = _snap_sentence_boundary(description2, DESC_CHAR_MIN, DESC_CHAR_MAX)
210
  target_mid = (DESC_CHAR_MIN + DESC_CHAR_MAX) / 2
211
+ if abs(len(description2) - target_mid) < abs(len(description) - target_mid): description = description2
 
 
212
  return tagline, description
213
 
214
def pick_best_synthetic_name(query: str, n_candidates: int = 10, include_copy=False):
    """Generate candidate names, score them against the query embedding, and
    return a one-row DataFrame (rank fixed at 4) for the best candidate.

    Score = cosine(query, name) - 0.35 * token-overlap penalty - length penalty.
    When include_copy is True, also generates a tagline and description.
    """
    names = generate_names(query, n=n_candidates, oversample=max(80, 8*n_candidates), max_retries=3)
    if len(names) == 0:
        # Permissive retry: larger oversample, single attempt.
        names = generate_names(query, n=n_candidates, oversample=140, max_retries=1)
    if len(names) == 0:
        # Last resort: mechanical seed names built from the query's content words.
        toks = _content_words(query) or ["nova","learn","edu","mento"]
        seeds = list({t[:4]+"ify" for t in toks} | {t[:3]+"ora" for t in toks} | {t[:4]+"io" for t in toks})
        names = seeds[:n_candidates]
    # Cosine similarities: embeddings are L2-normalized, so dot product == cosine.
    qv = _embed_query(query); embs = _embed_passages(names); cos = embs @ qv
    banned = sorted(set(_content_words(query)))
    final_scores = []
    for nm, s in zip(names, cos):
        toks = _content_words(nm); overlap = _overlap_ratio(toks, banned)
        length_pen = 0.0; L = len(_normalize_name(nm))
        if L < 4: length_pen += 0.3   # too short to be brandable
        if L > 16: length_pen += 0.2  # unwieldy
        final_scores.append(float(s) - 0.35*overlap - length_pen)
    best_idx = int(np.argmax(final_scores)); best_name = names[best_idx]; best_score = float(final_scores[best_idx])
    tagline, description = ("","")
    if include_copy: tagline, description = generate_tagline_and_desc(best_name, query_context=query)
    # rank hard-coded to 4: this row is shown after the top-3 search hits.
    row = pd.DataFrame([{"rank":4,"score":best_score,"name":best_name,"tagline":tagline,"description":description}])
    return row
236
 
237
+ # ---------- UI glue ----------
238
+ EXAMPLES = [
239
+ "AI tool to analyze customer feedback",
240
+ "Social network for jobs",
241
+ "Mobile fintech app for cross-border payments",
242
+ "AI learning tool for students",
243
+ "Marketplace for eco-friendly products",
244
+ ]
245
 
246
def ui_search(query, state_unlikes):
    """Run a fresh top-3 search; a new query always starts with an empty unlike list."""
    cleaned = (query or "").strip()
    if not cleaned:
        return gr.update(value=pd.DataFrame()), state_unlikes, "Please enter a short idea."
    # New search: discard any previously unliked rows.
    fresh_unlikes = []
    table = search_topk_filtered_session(cleaned, k=3, unliked_ids=set())
    return table, fresh_unlikes, "Found 3 similar items. You can unlike by row_idx, then Refresh."
 
 
252
 
253
def ui_unlike(query, unlike_ids_csv, state_unlikes):
    """Add comma-separated row_idx values to the session's unlike set and re-run the search."""
    cleaned = (query or "").strip()
    if not cleaned:
        return gr.update(value=pd.DataFrame()), state_unlikes, "Enter a query first."
    # Only purely numeric tokens are accepted; anything else is ignored.
    pieces = (tok.strip() for tok in (unlike_ids_csv or "").split(","))
    add_ids = {int(tok) for tok in pieces if tok.isdigit()}
    combined = set(state_unlikes) | add_ids
    table = search_topk_filtered_session(cleaned, k=3, unliked_ids=combined)
    return table, list(combined), f"Excluded {sorted(add_ids)}. Currently unliked: {sorted(combined)}"
 
 
263
 
264
def ui_clear_unlikes(query):
    """Reset the unlike list and show the unfiltered top-3 for the current query."""
    cleaned = (query or "").strip()
    if not cleaned:
        return gr.update(value=pd.DataFrame()), [], "Enter a query first."
    table = search_topk_filtered_session(cleaned, k=3, unliked_ids=set())
    return table, [], "Cleared unlikes."
 
269
 
270
def ui_generate_synth(query, include_copy):
    """Produce the AI-generated option (#4), optionally with tagline/description copy."""
    cleaned = (query or "").strip()
    if not cleaned:
        return gr.update(value=pd.DataFrame()), "Enter a query first."
    table = pick_best_synthetic_name(cleaned, n_candidates=10, include_copy=include_copy)
    return table, "Generated AI option as #4. Combine it with your top-3."
275
 
276
def _apply_example(example_text, state_unlikes):
    """Fill the query box with an example and immediately run the standard search."""
    table, new_state, msg = ui_search(example_text, state_unlikes)
    return example_text, table, new_state, f"Example selected: “{example_text}”. {msg}"
279
+
280
# Build the Gradio interface: search controls, results table, and the
# AI-generated option, wired to the ui_* handlers above.
with gr.Blocks(title="Startup Recommender + AI Name") as app:
    gr.Markdown("## Startup Recommender → Unlike → AI Name\nEnter a short idea. Get 3 similar startups, unlike what doesn’t fit, then generate an AI name (and optional tagline & description).")

    query = gr.Textbox(label="Your idea (short description)", placeholder="e.g., AI tool to analyze student essays and give feedback")

    with gr.Row():
        gr.Markdown("**Try an example:**")
        example_buttons = [gr.Button(ex, variant="secondary") for ex in EXAMPLES]

    with gr.Row():
        btn_search = gr.Button("Search Top-3")
        unlike_ids = gr.Textbox(label="Unlike by row_idx (comma-separated)", placeholder="e.g., 123, 456")
        btn_unlike = gr.Button("Refresh after Unlike")
        btn_clear = gr.Button("Clear Unlikes")

    results_tbl = gr.Dataframe(label="Top-3 Similar (after excludes)", interactive=False, wrap=True)

    gr.Markdown("### AI-Generated Option (#4)")
    include_copy = gr.Checkbox(label="Also generate tagline & description", value=True)
    btn_synth = gr.Button("Generate #4 (AI)")
    synth_tbl = gr.Dataframe(label="Synthetic #4", interactive=False, wrap=True)

    status = gr.Markdown("")
    # Per-session list of unliked row_idx values.
    state_unlikes = gr.State([])

    # wiring
    btn_search.click(ui_search, inputs=[query, state_unlikes], outputs=[results_tbl, state_unlikes, status])
    btn_unlike.click(ui_unlike, inputs=[query, unlike_ids, state_unlikes], outputs=[results_tbl, state_unlikes, status])
    btn_clear.click(ui_clear_unlikes, inputs=[query], outputs=[results_tbl, state_unlikes, status])

    # ex_=ex binds the loop variable as a default, avoiding the classic
    # late-binding closure bug (all buttons firing the last example).
    for btn, ex in zip(example_buttons, EXAMPLES):
        btn.click(lambda st, ex_=ex: _apply_example(ex_, st),
                  inputs=[state_unlikes], outputs=[query, results_tbl, state_unlikes, status])

    btn_synth.click(ui_generate_synth, inputs=[query, include_copy], outputs=[synth_tbl, status])
315
 
316
# On Spaces, just calling launch() is fine; no explicit port.
if __name__ == "__main__":
    # NOTE(review): queue(concurrency_count=...) is the Gradio 3.x signature;
    # Gradio 4+ renamed it to default_concurrency_limit — confirm the pinned version.
    app.queue(concurrency_count=2).launch()