barakb21 commited on
Commit
83bd1fa
·
verified ·
1 Parent(s): 819022a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +403 -0
app.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Simple Interactive UI (Gradio) for: search → unlike → synthetic #4 =========
2
+ # Requires that you already ran the embedding/index step and have:
3
+ # ./emb_index_e5/faiss.index and ./emb_index_e5/data.parquet
4
+ # Also expects you kept the improved generation functions we wrote earlier
5
+ # (generate_names, generate_tagline_and_desc(name, query_context), pick_best_synthetic_name).
6
+ # If not, this file includes compact versions below.
7
+
8
+ import os, re, random, numpy as np, pandas as pd
9
+ from pathlib import Path
10
+ import gradio as gr
11
+
12
+ # --- Imports for models ---
13
+ import torch
14
+ import faiss
15
+ from sentence_transformers import SentenceTransformer
16
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
17
+
18
# --- Paths / config ---
OUT_DIR = Path("./emb_index_e5")
FAISS_PATH = OUT_DIR / "faiss.index"   # dense vector index built by the offline embedding step
DATA_PATH = OUT_DIR / "data.parquet"   # startup metadata, row-aligned with the FAISS index

# Toggle: use flan-t5-large only for description richness (if GPU T4 available)
USE_LARGE_FOR_DESCRIPTION = True
MODEL_BASE = "google/flan-t5-base"
MODEL_LARGE = "google/flan-t5-large"
EMBED_MODEL = "intfloat/e5-base-v2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Load artifacts once (fail fast with an actionable message) ---
assert FAISS_PATH.exists(), f"Missing FAISS index at {FAISS_PATH}. Run the embedding/index step first."
assert DATA_PATH.exists(), f"Missing dataset at {DATA_PATH}. Run the embedding/index step first."

index = faiss.read_index(str(FAISS_PATH))
df_local = pd.read_parquet(DATA_PATH)
for col in ["name","tagline","description"]:
    if col in df_local.columns:
        # BUGFIX: fill missing values BEFORE casting to str. The original code ran
        # astype(str).fillna(""), but astype(str) converts NaN into the literal
        # string "nan", so the subsequent fillna("") never matched anything.
        df_local[col] = df_local[col].fillna("").astype(str)

# --- Load embedding model once ---
embed_model = SentenceTransformer(EMBED_MODEL, device=DEVICE)

# --- Load FLAN models (base + optional large) once ---
tok_base = AutoTokenizer.from_pretrained(MODEL_BASE)
mod_base = AutoModelForSeq2SeqLM.from_pretrained(MODEL_BASE).to(DEVICE)
if USE_LARGE_FOR_DESCRIPTION:
    tok_large = AutoTokenizer.from_pretrained(MODEL_LARGE)
    mod_large = AutoModelForSeq2SeqLM.from_pretrained(MODEL_LARGE).to(DEVICE)
else:
    # Alias the base model so downstream code can use tok_large/mod_large unconditionally.
    tok_large, mod_large = tok_base, mod_base
51
+
52
+ # ===================== Small helpers reused from your project =====================
53
+
54
def _generate_text(model, tokenizer, prompt, max_new_tokens=30, temperature=0.9, top_p=0.95, num_return_sequences=1):
    """Sample one or more generations from a seq2seq model for *prompt*.

    Returns a list of decoded, stripped strings (length = num_return_sequences).
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    generated = model.generate(
        **encoded,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        num_return_sequences=num_return_sequences,
    )
    decoded = []
    for seq in generated:
        decoded.append(tokenizer.decode(seq, skip_special_tokens=True).strip())
    return decoded
65
+
66
def _embed_query(q: str) -> np.ndarray:
    """Embed a search query with the E5 'query:' prefix; returns one L2-normalized float32 vector."""
    prefixed = [f"query: {q}"]
    vectors = embed_model.encode(prefixed, convert_to_numpy=True, normalize_embeddings=True)
    return vectors.astype("float32")[0]
68
+
69
def _embed_passages(texts) -> np.ndarray:
    """Embed candidate texts with the E5 'passage:' prefix; returns L2-normalized float32 rows."""
    prefixed = [f"passage: {t}" for t in texts]
    matrix = embed_model.encode(prefixed, convert_to_numpy=True, normalize_embeddings=True)
    return matrix.astype("float32")
72
+
73
+ # ======== 1-click Examples ========
74
+
75
# Canned example queries rendered as one-click buttons in the UI.
EXAMPLES = [
    "AI tool to analyze customer feedback",
    "Social network for jobs",
    "Mobile fintech app for cross-border payments",
    "AI learning tool for students",
    "Marketplace for eco-friendly products",
]
82
+
83
+ # Helper that applies an example: fills the textbox, resets unlikes, runs search
84
def _apply_example(example_text, state_unlikes):
    """Apply a one-click example: fill the query box, reset unlikes, and run the search."""
    # Reuse your existing ui_search() so behavior stays identical
    results, state_unlikes, msg = ui_search(example_text, state_unlikes)
    # Return: set the query box text, update table & state & status
    return example_text, results, state_unlikes, f"Example selected: “{example_text}” — {msg}"
89
+
90
+ # ---------- Search (with per-session unlikes) ----------
91
def search_topk_filtered_session(query: str, k: int, unliked_ids: set):
    """Return top-K rows excluding row_idx values in unliked_ids."""
    qv = _embed_query(query)
    # Over-fetch so that k survivors remain after dropping unliked rows;
    # capped at the index size so FAISS never gets asked for more than it holds.
    fetch = min(index.ntotal, max(k * 20, 50, k + len(unliked_ids)))
    scores, inds = index.search(qv[None, :], fetch)
    inds = inds[0].tolist()
    scores = scores[0].tolist()

    res = df_local.iloc[inds][["name","tagline","description"]].copy()
    # row_idx: the dataframe's original index label; the UI uses it for unlikes.
    res.insert(0, "row_idx", df_local.iloc[inds].index)
    res.insert(1, "score", [float(s) for s in scores])

    # filter-out unlikes
    mask = ~res["row_idx"].isin(unliked_ids)
    res = res[mask].head(k).reset_index(drop=True)
    res.insert(0, "rank", range(1, len(res)+1))
    return res
108
+
109
+ # ---------- Synthetic generation (uses the improved functions) ----------
110
+ # Length targets (from your EDA)
111
# Length targets (from your EDA)
TAG_CHAR_TARGET, TAG_CHAR_TOL = 40, 6
TAG_WORD_TARGET, TAG_WORD_TOL = 6, 2
DESC_CHAR_MIN, DESC_CHAR_MAX = 170, 230
DESC_WORD_MIN, DESC_WORD_MAX = 27, 35

# Words too generic to count as "content" when comparing names to the query.
_STOPWORDS = {
    "the","a","an","for","and","or","to","of","in","on","with","by","from",
    "my","our","your","their","at","as","about","into","over","under","this","that",
    "idea","startup","company","product","service","app","platform","factory","labs","tech"
}

_WORD_RE = re.compile(r"[a-z]+")


def _words(s: str):
    """Lowercase alphabetic tokens of *s* (digits/punctuation act as separators)."""
    return _WORD_RE.findall(str(s).lower())


def _content_words(s: str):
    """Tokens of *s* that are at least 3 letters and not in the stop list."""
    return [w for w in _words(s) if len(w) >= 3 and w not in _STOPWORDS]


def _normalize_name(s: str) -> str:
    """Canonical key for dedup: lowercase with all non-alphanumerics removed."""
    return re.sub(r"[^a-z0-9]+", "", str(s).lower())


def _has_vowel(s: str) -> bool:
    """True when *s* contains at least one of a/e/i/o/u (case-insensitive)."""
    return bool(re.search(r"[aeiou]", str(s).lower()))


def _overlap_ratio(name_tokens, banned):
    """Jaccard overlap between two token collections; 0.0 when either is empty."""
    if not name_tokens or not banned:
        return 0.0
    left, right = set(name_tokens), set(banned)
    return len(left & right) / max(len(left | right), 1)


def _len_ok(text: str, target_chars: int, tol: int, min_words: int, max_words: int):
    """True when *text* is within ±tol chars of target_chars AND within the word window."""
    n_chars = len(text)
    n_words = len(text.split())
    chars_ok = (target_chars - tol) <= n_chars <= (target_chars + tol)
    words_ok = min_words <= n_words <= max_words
    return chars_ok and words_ok


def _theme_hints(query: str, k: int = 6):
    """First *k* distinct content words of *query* as a comma list (with a generic fallback)."""
    seen = set()
    hints = []
    for token in _content_words(query):
        if token in seen:
            continue
        seen.add(token)
        hints.append(token)
    if not hints:
        return "education, learning, students, AI"
    return ", ".join(hints[:k])


# Target shape of a generated brand name.
NAME_CHAR_TARGET, NAME_CHAR_TOL = 12, 3
NAME_WORDS_MIN, NAME_WORDS_MAX = 1, 3
145
+
146
def generate_names(base_idea: str, n: int = 10, oversample: int = 80, max_retries: int = 3):
    """Generate up to *n* brandable startup-name candidates for *base_idea*.

    Oversamples raw name lines from the base FLAN model, dedups them, and
    filters with progressively looser rules on each retry. Falls back to the
    unfiltered pool if the filters never yield *n* names.
    """
    # Content words from the idea itself are banned so names don't parrot the query.
    banned = sorted(set(_content_words(base_idea)))
    avoid_str = ", ".join(banned[:12]) if banned else "previous words"
    hints = _theme_hints(base_idea)

    def _prompt(osz):
        # One instruction prompt asking for `osz` names, one per line.
        return (
            f"Create {osz} brandable startup names for this idea:\n"
            f"\"{base_idea}\"\n\n"
            f"Guidance:\n"
            f"- Evoke these themes (without literally using the words): {hints}\n"
            f"- 1 or 2 or 3 words; aim ~{NAME_CHAR_TARGET} characters total (±{NAME_CHAR_TOL})\n"
            f"- Portmanteau/blends welcome (e.g., Coursera, Udacity, Grammarly)\n"
            f"- Do NOT use: {avoid_str}\n"
            f"- Avoid generic phrases (e.g., 'Plastic Bottles', 'Online Store')\n"
            f"- Output one name per line; no numbering, no quotes."
        )

    all_candidates = []
    for attempt in range(max_retries):
        # Slightly hotter sampling on each retry for more variety.
        raw = _generate_text(
            mod_base, tok_base, _prompt(oversample),
            num_return_sequences=1, max_new_tokens=240,
            temperature=1.0 + 0.05*attempt, top_p=0.95
        )[0]

        batch = []
        for line in raw.splitlines():
            # Strip bullets/numbering the model may add despite the instructions.
            nm = line.strip().lstrip("-•*0123456789. ").strip()
            if not nm: continue
            nm = re.sub(r"[^\w\s-]+$", "", nm).strip()  # drop trailing punctuation
            batch.append(nm)
        all_candidates.extend(batch)

        # dedup (case/punctuation-insensitive via _normalize_name)
        uniq, seen = [], set()
        for nm in all_candidates:
            key = _normalize_name(nm)
            if not key or key in seen: continue
            seen.add(key); uniq.append(nm)
        all_candidates = uniq

        # progressive filtering: acceptance criteria loosen with each attempt
        def ok(nm: str, overlap_cap: float, tol_boost: int):
            if not _has_vowel(nm): return False  # unpronounceable without a vowel
            if not _len_ok(nm, NAME_CHAR_TARGET, NAME_CHAR_TOL+tol_boost, NAME_WORDS_MIN, NAME_WORDS_MAX):
                return False
            toks = _content_words(nm)
            # Reject names that reuse too many of the query's own words.
            if _overlap_ratio(toks, banned) > overlap_cap: return False
            if " ".join(toks) in {"plastic bottles","bottles plastic"}: return False
            return True

        overlap_caps = [0.25, 0.35, 0.5]
        tol_boosts = [0, 1, 2]
        filtered = [nm for nm in all_candidates if ok(nm, overlap_caps[min(attempt,2)], tol_boosts[min(attempt,2)])]
        if len(filtered) >= n:
            return filtered[:n]

    # Filters never produced n names: return the unfiltered pool (may be short or empty).
    return all_candidates[:n] if all_candidates else []
205
+
206
def _trim_to_words(text: str, max_words: int) -> str:
    """Return *text* (stripped) if within the word budget; otherwise cut it to
    max_words words, drop trailing list punctuation, and close with a period."""
    pieces = text.split()
    if len(pieces) > max_words:
        clipped = " ".join(pieces[:max_words])
        return clipped.rstrip(",;:") + "."
    return text.strip()
210
+
211
def _snap_sentence_boundary(text: str, min_chars: int, max_chars: int):
    """Clamp *text* toward [min_chars, max_chars], preferring to cut at the last
    sentence-ending punctuation mark inside the allowed span."""
    text = text.strip()
    if min_chars <= len(text) <= max_chars:
        return text
    candidate = text[: min(max_chars, len(text))]
    # Locate the final ., ! or ? within the truncated span.
    last_punct = re.search(r"[\.!\?](?!.*[\.!\?])", candidate)
    if last_punct:
        snapped = candidate[: last_punct.end()].strip()
        if len(snapped) >= min_chars:
            return snapped
    # No usable boundary: trim dangling punctuation and terminate with a period.
    tail = "" if candidate.endswith((".", "!", "?")) else "."
    return candidate.rstrip(",;: ").strip() + tail
219
+
220
def _within_ranges(text: str, cmin: int, cmax: int, wmin: int, wmax: int) -> bool:
    """True when *text* falls inside both the character and the word-count windows."""
    n_chars = len(text)
    n_words = len(text.split())
    return cmin <= n_chars <= cmax and wmin <= n_words <= wmax
222
+
223
def generate_tagline_and_desc(name: str, query_context: str):
    """Produce a (tagline, description) pair for *name*, steered toward the
    character/word targets derived from the dataset EDA.

    Each piece is generated once, normalized and trimmed, then regenerated a
    single time if it misses its length window; the closer attempt wins.
    """
    # Tagline (≈40 chars, ≈6 words)
    tag_prompt = (
        f"Write a short, benefit-driven tagline for a startup called '{name}'. "
        f"Audience & domain: {query_context}. "
        f"Target ~{TAG_CHAR_TARGET} characters and ~{TAG_WORD_TARGET} words. Avoid clichés."
    )
    tagline = _generate_text(mod_base, tok_base, tag_prompt, max_new_tokens=28, temperature=0.9, top_p=0.95)[0]
    tagline = re.sub(r"\s+", " ", tagline).strip()  # collapse internal whitespace
    tagline = _trim_to_words(tagline, TAG_WORD_TARGET + TAG_WORD_TOL)
    if len(tagline) > TAG_CHAR_TARGET + TAG_CHAR_TOL:
        # Hard character cap: cut and mark the truncation with an ellipsis.
        tagline = tagline[:TAG_CHAR_TARGET + TAG_CHAR_TOL].rstrip(",;: -") + "…"
    if not _within_ranges(tagline, TAG_CHAR_TARGET - TAG_CHAR_TOL, TAG_CHAR_TARGET + TAG_CHAR_TOL,
                          TAG_WORD_TARGET - TAG_WORD_TOL, TAG_WORD_TARGET + TAG_WORD_TOL):
        # One retry with different sampling; keep whichever attempt lands
        # closer to the target character count.
        tagline2 = _generate_text(mod_base, tok_base, tag_prompt, max_new_tokens=30, temperature=1.0, top_p=0.9)[0]
        tagline2 = _trim_to_words(re.sub(r"\s+", " ", tagline2).strip(), TAG_WORD_TARGET + TAG_WORD_TOL)
        if abs(len(tagline2) - TAG_CHAR_TARGET) < abs(len(tagline) - TAG_CHAR_TARGET):
            tagline = tagline2

    # Description (≈170–230 chars, 27–35 words)
    desc_prompt = (
        f"Write a concise product description for the startup '{name}'. "
        f"Context: {query_context}. "
        f"Explain who it's for, what it does, and the main benefit. "
        f"Target {DESC_CHAR_MIN}–{DESC_CHAR_MAX} characters and {DESC_WORD_MIN}–{DESC_WORD_MAX} words. "
        f"Avoid fluff; keep it clear."
    )
    # The larger model is used for descriptions only (richer prose), when enabled.
    model, tok = (mod_large, tok_large) if USE_LARGE_FOR_DESCRIPTION else (mod_base, tok_base)
    description = _generate_text(model, tok, desc_prompt, max_new_tokens=110, temperature=1.05, top_p=0.95)[0]
    description = re.sub(r"\s+", " ", description).strip()
    if len(description.split()) > DESC_WORD_MAX:
        description = _trim_to_words(description, DESC_WORD_MAX)
    description = _snap_sentence_boundary(description, DESC_CHAR_MIN, DESC_CHAR_MAX)

    if not _within_ranges(description, DESC_CHAR_MIN, DESC_CHAR_MAX, DESC_WORD_MIN, DESC_WORD_MAX):
        # One retry; keep whichever attempt is closer to the middle of the window.
        description2 = _generate_text(model, tok, desc_prompt, max_new_tokens=120, temperature=1.05, top_p=0.9)[0]
        description2 = re.sub(r"\s+", " ", description2).strip()
        if len(description2.split()) > DESC_WORD_MAX:
            description2 = _trim_to_words(description2, DESC_WORD_MAX)
        description2 = _snap_sentence_boundary(description2, DESC_CHAR_MIN, DESC_CHAR_MAX)
        target_mid = (DESC_CHAR_MIN + DESC_CHAR_MAX) / 2
        if abs(len(description2) - target_mid) < abs(len(description) - target_mid):
            description = description2

    return tagline, description
268
+
269
def pick_best_synthetic_name(query: str, n_candidates: int = 10, include_copy=False):
    """Generate name candidates, score them against *query*, and return the
    winner as a one-row DataFrame with rank 4 (slots in after the top-3 hits).

    When include_copy is True, also generates a tagline and description.
    """
    # Generate candidates & score vs query (cosine)
    names = generate_names(query, n=n_candidates, oversample=max(80, 8*n_candidates), max_retries=3)
    if len(names) == 0:
        # permissive retry
        names = generate_names(query, n=n_candidates, oversample=140, max_retries=1)
    if len(names) == 0:
        # Last-resort deterministic seeds built from the query's own tokens.
        toks = _content_words(query) or ["nova","learn","edu","mento"]
        seeds = set()
        for t in toks:
            seeds.add((t[:4] + "ify"))
            seeds.add((t[:3] + "ora"))
            seeds.add((t[:4] + "io"))
        names = list(seeds)[:n_candidates]

    qv = _embed_query(query)
    embs = _embed_passages(names)
    # Embeddings are L2-normalized, so the dot product is cosine similarity.
    cos = embs @ qv

    banned = sorted(set(_content_words(query)))
    final_scores = []
    for nm, s in zip(names, cos):
        toks = _content_words(nm)
        overlap = _overlap_ratio(toks, banned)  # penalize parroting the query
        length_pen = 0.0
        L = len(_normalize_name(nm))
        if L < 4: length_pen += 0.3   # very short names are penalized hardest
        if L > 16: length_pen += 0.2  # overly long names lose a little
        score = float(s) - 0.35*overlap - length_pen
        final_scores.append(score)

    best_idx = int(np.argmax(final_scores))
    best_name = names[best_idx]
    best_score = float(final_scores[best_idx])

    tagline, description = ("","")
    if include_copy:
        tagline, description = generate_tagline_and_desc(best_name, query_context=query)

    # rank=4 so the synthetic row displays after the three retrieved startups.
    row = pd.DataFrame([{
        "rank": 4,
        "score": best_score,
        "name": best_name,
        "tagline": tagline,
        "description": description
    }])
    return row
316
+
317
+ # ============================= Gradio UI logic =============================
318
+
319
def ui_search(query, state_unlikes):
    """Start a new search, reset unlikes if it's a new query."""
    query = (query or "").strip()
    if not query:
        return gr.update(value=pd.DataFrame()), state_unlikes, "Please enter a short idea/description."
    # For a new search, reset unlikes
    state_unlikes = set()
    results = search_topk_filtered_session(query, k=3, unliked_ids=state_unlikes)
    # NOTE(review): the status text assumes exactly 3 rows came back — verify for tiny indexes.
    return results, list(state_unlikes), "Found 3 similar items. You can unlike by row_idx, then 'Refresh'."
328
+
329
def ui_unlike(query, unlike_ids_csv, state_unlikes):
    """Add comma-separated row_idx values to the session's unlike set and re-run the search."""
    query = (query or "").strip()
    if not query:
        return gr.update(value=pd.DataFrame()), state_unlikes, "Enter a query first."
    add_ids = set()
    for tok in (unlike_ids_csv or "").split(","):
        tok = tok.strip()
        if tok.isdigit():  # non-numeric tokens are silently ignored
            add_ids.add(int(tok))
    # State arrives as a list (gr.State), so coerce to a set before the union.
    state_unlikes = set(state_unlikes) | add_ids
    results = search_topk_filtered_session(query, k=3, unliked_ids=state_unlikes)
    msg = f"Excluded {sorted(add_ids)}. Currently unliked: {sorted(state_unlikes)}"
    return results, list(state_unlikes), msg
342
+
343
def ui_clear_unlikes(query):
    """Drop every recorded unlike for this session and re-run the top-3 search."""
    q = (query or "").strip()
    if not q:
        return gr.update(value=pd.DataFrame()), [], "Enter a query first."
    refreshed = search_topk_filtered_session(q, k=3, unliked_ids=set())
    return refreshed, [], "Cleared unlikes."
349
+
350
def ui_generate_synth(query, include_copy):
    """Generate the synthetic '#4' AI name (and optional tagline/description) for the query."""
    query = (query or "").strip()
    if not query:
        return gr.update(value=pd.DataFrame()), "Enter a query first."
    synth = pick_best_synthetic_name(query, n_candidates=10, include_copy=include_copy)
    return synth, "Generated AI option as #4. Combine it with your top-3."
356
+
357
# Layout + event wiring for the whole app. `app` is launched from the __main__ guard.
with gr.Blocks(title="Startup Recommender + Synthetic Name") as app:
    gr.Markdown("## Startup Recommender → Unlike → AI Name\nEnter a short idea. Get 3 similar startups, unlike what doesn’t fit, then generate an AI name (and optional tagline + description).")

    with gr.Row():
        query = gr.Textbox(label="Your idea (short description)", placeholder="e.g., AI tool to analyze student essays and give feedback")
    with gr.Row():
        gr.Markdown("**Try an example:**")
        # One button per canned example query.
        example_buttons = []
        for ex in EXAMPLES:
            example_buttons.append(gr.Button(ex, variant="secondary", scale=1))
    with gr.Row():
        btn_search = gr.Button("Search Top-3")
        unlike_ids = gr.Textbox(label="Unlike by row_idx (comma-separated)", placeholder="e.g., 123, 456")
        btn_unlike = gr.Button("Refresh after Unlike")
        btn_clear = gr.Button("Clear Unlikes")
    with gr.Row():
        results_tbl = gr.Dataframe(label="Top-3 Similar (after excludes)", interactive=False, wrap=True)

    gr.Markdown("### AI-Generated Option (#4)")
    with gr.Row():
        include_copy = gr.Checkbox(label="Also generate tagline & description", value=True)
        btn_synth = gr.Button("Generate #4 (AI)")

    synth_tbl = gr.Dataframe(label="Synthetic #4", interactive=False, wrap=True)
    status = gr.Markdown("")  # free-form status line updated by every handler

    # Session state: list of unliked row_idx
    state_unlikes = gr.State([])

    # Wiring
    btn_search.click(ui_search, inputs=[query, state_unlikes], outputs=[results_tbl, state_unlikes, status])
    btn_unlike.click(ui_unlike, inputs=[query, unlike_ids, state_unlikes], outputs=[results_tbl, state_unlikes, status])
    btn_clear.click(ui_clear_unlikes, inputs=[query], outputs=[results_tbl, state_unlikes, status])
    btn_synth.click(ui_generate_synth, inputs=[query, include_copy], outputs=[synth_tbl, status])
    for btn, ex in zip(example_buttons, EXAMPLES):
        # When clicked: fill the query box, run search, reset unlikes.
        # The ex_=ex default binds the loop variable early (avoids the
        # late-binding closure pitfall where every button would share the last example).
        btn.click(
            lambda st, ex_=ex: _apply_example(ex_, st),
            inputs=[state_unlikes],  # current unlike state
            outputs=[query, results_tbl, state_unlikes, status]
        )
398
+
399
+
400
+
401
# For local dev; on Spaces this is ignored.
if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from outside a container.
    app.launch(server_name="0.0.0.0")