# NOTE(review): removed Hugging Face Spaces status-page residue ("Spaces: Sleeping")
# that was scraped into this file; it was not Python code.
# app.py — Startup recommender + Unlike + AI name (optional tagline/description)
import os
import re
from pathlib import Path

import faiss
import gradio as gr
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# ---------- Paths / artifacts ----------
# Prebuilt FAISS index + row metadata, produced offline and uploaded with the app.
OUT_DIR = Path("./emb_index_e5")
FAISS_PATH = OUT_DIR / "faiss.index"
DATA_PATH = OUT_DIR / "data.parquet"
# Fail fast with an explicit exception: `assert` is stripped under `python -O`,
# so it must not be used to validate required runtime artifacts.
if not FAISS_PATH.exists():
    raise FileNotFoundError(f"Missing {FAISS_PATH}. Build & upload embeddings/index.")
if not DATA_PATH.exists():
    raise FileNotFoundError(f"Missing {DATA_PATH}. Build & upload data parquet.")
# ---------- Devices ----------
# Let the CUDA caching allocator grow segments instead of fragmenting memory.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Embeddings (e5) run on GPU when one exists; generation (FLAN) is pinned to
# CPU so the seq2seq model cannot OOM the GPU.
if torch.cuda.is_available():
    DEVICE_EMBED = "cuda"
else:
    DEVICE_EMBED = "cpu"
DEVICE_GEN = "cpu"
print(f"Embed device: {DEVICE_EMBED} | Gen device: {DEVICE_GEN}")
# ---------- Load artifacts ----------
index = faiss.read_index(str(FAISS_PATH))
df_local = pd.read_parquet(DATA_PATH)
# Clean the text columns used for display/search.
# Fill missing values BEFORE casting to str: `astype(str)` turns NaN into the
# literal string "nan", which made the original `fillna("")` afterwards a no-op.
for c in ["name", "tagline", "description"]:
    if c in df_local.columns:
        df_local[c] = df_local[c].fillna("").astype(str)
# ---------- Load models ----------
# e5 sentence-embedding model for query/passage vectors (used by _embed_* helpers).
EMBED_MODEL = "intfloat/e5-base-v2"
embed_model = SentenceTransformer(EMBED_MODEL, device=DEVICE_EMBED)
# FLAN-T5 checkpoints for name/tagline/description generation.
MODEL_BASE = "google/flan-t5-base"
MODEL_LARGE = "google/flan-t5-large"
USE_LARGE_FOR_DESCRIPTION = False  # keep False on Spaces unless you switch GEN to "cuda"
tok_base = AutoTokenizer.from_pretrained(MODEL_BASE)
# fp16 only makes sense on GPU; on CPU load in default precision.
base_kwargs = {"torch_dtype": torch.float16} if DEVICE_GEN == "cuda" else {}
mod_base = AutoModelForSeq2SeqLM.from_pretrained(MODEL_BASE, **base_kwargs).to(DEVICE_GEN)
if USE_LARGE_FOR_DESCRIPTION:
    tok_large = AutoTokenizer.from_pretrained(MODEL_LARGE)
    large_kwargs = {"torch_dtype": torch.float16} if DEVICE_GEN == "cuda" else {}
    mod_large = AutoModelForSeq2SeqLM.from_pretrained(MODEL_LARGE, **large_kwargs).to(DEVICE_GEN)
else:
    # Alias "large" to the base model so downstream code can reference
    # tok_large/mod_large unconditionally.
    tok_large, mod_large = tok_base, mod_base
| # ---------- Helpers (embedding + generation) ---------- | |
def _generate_text(model, tokenizer, prompt, max_new_tokens=30, temperature=0.9, top_p=0.95, num_return_sequences=1):
    """Sample from a seq2seq model; return a list of decoded, stripped strings."""
    encoded = tokenizer(prompt, return_tensors="pt").to(DEVICE_GEN)
    generated = model.generate(
        **encoded,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        num_return_sequences=num_return_sequences,
    )
    decoded = []
    for seq in generated:
        decoded.append(tokenizer.decode(seq, skip_special_tokens=True).strip())
    return decoded
def _embed_query(q: str) -> np.ndarray:
    """Encode a user query with the e5 'query:' prefix; returns one normalized float32 vector."""
    vecs = embed_model.encode(
        [f"query: {q}"], convert_to_numpy=True, normalize_embeddings=True
    )
    return vecs.astype("float32")[0]
def _embed_passages(texts) -> np.ndarray:
    """Encode candidate texts with the e5 'passage:' prefix; returns normalized float32 vectors."""
    prefixed = [f"passage: {t}" for t in texts]
    vecs = embed_model.encode(prefixed, convert_to_numpy=True, normalize_embeddings=True)
    return vecs.astype("float32")
| # ---------- Search with per-session unlikes ---------- | |
def search_topk_filtered_session(query: str, k: int, unliked_ids: set):
    """FAISS top-k search that drops rows the session has unliked.

    Over-fetches so that after filtering out `unliked_ids` there are still up
    to k rows left; returns a DataFrame with rank/row_idx/score/name/tagline/
    description columns.
    """
    query_vec = _embed_query(query)
    n_fetch = min(index.ntotal, max(k * 20, 50, k + len(unliked_ids)))
    score_mat, idx_mat = index.search(query_vec[None, :], n_fetch)
    hit_ids = idx_mat[0].tolist()
    hit_scores = [float(s) for s in score_mat[0].tolist()]
    hits = df_local.iloc[hit_ids]
    out = hits[["name", "tagline", "description"]].copy()
    out.insert(0, "row_idx", hits.index)
    out.insert(1, "score", hit_scores)
    out = out[~out["row_idx"].isin(unliked_ids)].head(k).reset_index(drop=True)
    out.insert(0, "rank", range(1, len(out) + 1))
    return out
# ---------- Synthetic generation (length-aware) ----------
# Words ignored when extracting "content words" from an idea or a candidate
# name: articles/prepositions/pronouns plus generic startup vocabulary that
# would make generated names feel derivative of the query.
_STOPWORDS = {
    "the","a","an","for","and","or","to","of","in","on","with","by","from",
    "my","our","your","their","at","as","about","into","over","under","this","that",
    "idea","startup","company","product","service","app","platform","factory","labs","tech"
}
| def _words(s: str): return [w for w in re.findall(r"[a-z]+", str(s).lower()) if w] | |
def _content_words(s: str):
    """Meaningful words of *s*: at least 3 letters long and not in _STOPWORDS."""
    keep = []
    for word in _words(s):
        if len(word) >= 3 and word not in _STOPWORDS:
            keep.append(word)
    return keep
| def _normalize_name(s: str) -> str: return re.sub(r"[^a-z0-9]+", "", str(s).lower()) | |
| def _has_vowel(s: str) -> bool: return bool(re.search(r"[aeiou]", str(s).lower())) | |
| def _overlap_ratio(name_tokens, banned): | |
| if not name_tokens or not banned: return 0.0 | |
| inter = len(set(name_tokens) & set(banned)); union = len(set(name_tokens) | set(banned)) | |
| return inter / max(union, 1) | |
# Target shape for generated names: ~12 characters (±3 tolerance), 1–3 words.
NAME_CHAR_TARGET, NAME_CHAR_TOL = 12, 3
NAME_WORDS_MIN, NAME_WORDS_MAX = 1, 3
| def _len_ok(text: str, target_chars: int, tol: int, min_words: int, max_words: int): | |
| c = len(text); w = len(text.split()) | |
| return (target_chars - tol) <= c <= (target_chars + tol) and (min_words <= w <= max_words) | |
def _theme_hints(query: str, k: int = 6):
    """Up to *k* unique content words from *query* (in first-seen order), comma-joined.

    Falls back to a generic education/AI hint string when the query has no
    content words.
    """
    unique_terms = list(dict.fromkeys(_content_words(query)))
    if not unique_terms:
        return "education, learning, students, AI"
    return ", ".join(unique_terms[:k])
def generate_names(base_idea: str, n: int = 10, oversample: int = 80, max_retries: int = 3):
    """Generate up to *n* brandable startup names for *base_idea* with FLAN-T5.

    Retries up to *max_retries* times, raising the sampling temperature and
    loosening the length/overlap filters on each attempt. If the filter never
    yields enough names, falls back to the raw deduplicated candidate pool.
    """
    # Content words of the idea itself are banned from appearing in names.
    banned = sorted(set(_content_words(base_idea)))
    avoid_str = ", ".join(banned[:12]) if banned else "previous words"
    hints = _theme_hints(base_idea)
    all_candidates = []
    def _prompt(osz):
        # Ask for osz candidates, one per line, avoiding the query's own words.
        return (
            f"Create {osz} brandable startup names for this idea:\n"
            f"\"{base_idea}\"\n\n"
            f"Guidance:\n"
            f"- Evoke these themes (without literally using the words): {hints}\n"
            f"- 1 or 2 words; aim ~{NAME_CHAR_TARGET} characters (±{NAME_CHAR_TOL})\n"
            f"- Portmanteau/blends welcome (e.g., Coursera, Udacity, Grammarly)\n"
            f"- Do NOT use: {avoid_str}\n"
            f"- Avoid generic phrases (e.g., 'Plastic Bottles', 'Online Store')\n"
            f"- Output one name per line; no numbering, no quotes."
        )
    for attempt in range(max_retries):
        # Slightly higher temperature per attempt → more diverse candidates.
        raw = _generate_text(mod_base, tok_base, _prompt(oversample),
                             num_return_sequences=1, max_new_tokens=240,
                             temperature=1.0 + 0.05*attempt, top_p=0.95)[0]
        # collect: strip list bullets/numbering and trailing punctuation
        for line in raw.splitlines():
            nm = line.strip().lstrip("-•*0123456789. ").strip()
            if nm:
                nm = re.sub(r"[^\w\s-]+$", "", nm).strip()
                all_candidates.append(nm)
        # dedup (order-preserving, keyed on the normalized name)
        uniq, seen = [], set()
        for nm in all_candidates:
            key = _normalize_name(nm)
            if key and key not in seen:
                seen.add(key); uniq.append(nm)
        all_candidates = uniq
        # progressive filter: later attempts tolerate more overlap / looser length
        def ok(nm: str, overlap_cap: float, tol_boost: int):
            if not _has_vowel(nm): return False
            if not _len_ok(nm, NAME_CHAR_TARGET, NAME_CHAR_TOL+tol_boost, NAME_WORDS_MIN, NAME_WORDS_MAX): return False
            toks = _content_words(nm)
            if _overlap_ratio(toks, banned) > overlap_cap: return False
            if " ".join(toks) in {"plastic bottles","bottles plastic"}: return False
            return True
        overlap_caps = [0.25, 0.35, 0.5]; tol_boosts = [0, 1, 2]
        filtered = [nm for nm in all_candidates if ok(nm, overlap_caps[min(attempt,2)], tol_boosts[min(attempt,2)])]
        if len(filtered) >= n: return filtered[:n]
    # Not enough survivors after all retries: return the raw (deduped) pool.
    return all_candidates[:n] if all_candidates else []
# Tagline/description length targets (from your EDA).
# Tagline: ~40 chars (±6), ~6 words (±2).
TAG_CHAR_TARGET, TAG_CHAR_TOL = 40, 6
TAG_WORD_TARGET, TAG_WORD_TOL = 6, 2
# Description: 170–230 characters, 27–35 words.
DESC_CHAR_MIN, DESC_CHAR_MAX = 170, 230
DESC_WORD_MIN, DESC_WORD_MAX = 27, 35
| def _trim_to_words(text: str, max_words: int) -> str: | |
| toks = text.split() | |
| return text.strip() if len(toks) <= max_words else " ".join(toks[:max_words]).rstrip(",;:") + "." | |
| def _snap_sentence_boundary(text: str, min_chars: int, max_chars: int): | |
| text = text.strip() | |
| if len(text) <= max_chars and len(text) >= min_chars: return text | |
| cutoff = min(max_chars, len(text)); candidate = text[:cutoff] | |
| m = re.search(r"[\.!\?](?!.*[\.!\?])", candidate) | |
| if m and (len(candidate[:m.end()].strip()) >= min_chars): return candidate[:m.end()].strip() | |
| return candidate.rstrip(",;: ").strip() + ("." if not candidate.endswith((".", "!", "?")) else "") | |
| def _within_ranges(text: str, cmin: int, cmax: int, wmin: int, wmax: int) -> bool: | |
| c = len(text); w = len(text.split()); return (cmin <= c <= cmax) and (wmin <= w <= wmax) | |
def generate_tagline_and_desc(name: str, query_context: str):
    """Generate a (tagline, description) pair for *name*, length-controlled.

    Each piece is sampled once, normalized/trimmed toward its length targets,
    and regenerated once more if it misses the target ranges — keeping
    whichever attempt lands closer to the target length.
    """
    tag_prompt = (
        f"Write a short, benefit-driven tagline for a startup called '{name}'. "
        f"Audience & domain: {query_context}. "
        f"Target ~{TAG_CHAR_TARGET} characters and ~{TAG_WORD_TARGET} words. Avoid clichés."
    )
    tagline = _generate_text(mod_base, tok_base, tag_prompt, max_new_tokens=28, temperature=0.9, top_p=0.95)[0]
    # Collapse whitespace, then enforce word and character caps.
    tagline = re.sub(r"\s+", " ", tagline).strip()
    tagline = _trim_to_words(tagline, TAG_WORD_TARGET + TAG_WORD_TOL)
    if len(tagline) > TAG_CHAR_TARGET + TAG_CHAR_TOL:
        tagline = tagline[:TAG_CHAR_TARGET + TAG_CHAR_TOL].rstrip(",;: -") + "…"
    if not _within_ranges(tagline, TAG_CHAR_TARGET - TAG_CHAR_TOL, TAG_CHAR_TARGET + TAG_CHAR_TOL,
                          TAG_WORD_TARGET - TAG_WORD_TOL, TAG_WORD_TARGET + TAG_WORD_TOL):
        # One retry; keep whichever tagline is closer to the target char count.
        tagline2 = _generate_text(mod_base, tok_base, tag_prompt, max_new_tokens=30, temperature=1.0, top_p=0.9)[0]
        tagline2 = _trim_to_words(re.sub(r"\s+", " ", tagline2).strip(), TAG_WORD_TARGET + TAG_WORD_TOL)
        if abs(len(tagline2) - TAG_CHAR_TARGET) < abs(len(tagline) - TAG_CHAR_TARGET): tagline = tagline2
    desc_prompt = (
        f"Write a concise product description for the startup '{name}'. "
        f"Context: {query_context}. "
        f"Explain who it's for, what it does, and the main benefit. "
        f"Target {DESC_CHAR_MIN}–{DESC_CHAR_MAX} characters and {DESC_WORD_MIN}–{DESC_WORD_MAX} words. "
        f"Avoid fluff; keep it clear."
    )
    # Optionally use the large model for descriptions (see USE_LARGE_FOR_DESCRIPTION).
    model, tok = (mod_large, tok_large) if USE_LARGE_FOR_DESCRIPTION else (mod_base, tok_base)
    description = _generate_text(model, tok, desc_prompt, max_new_tokens=110, temperature=1.05, top_p=0.95)[0]
    description = re.sub(r"\s+", " ", description).strip()
    if len(description.split()) > DESC_WORD_MAX: description = _trim_to_words(description, DESC_WORD_MAX)
    description = _snap_sentence_boundary(description, DESC_CHAR_MIN, DESC_CHAR_MAX)
    if not _within_ranges(description, DESC_CHAR_MIN, DESC_CHAR_MAX, DESC_WORD_MIN, DESC_WORD_MAX):
        # One retry; keep whichever description is closer to the range midpoint.
        description2 = _generate_text(model, tok, desc_prompt, max_new_tokens=120, temperature=1.05, top_p=0.9)[0]
        description2 = re.sub(r"\s+", " ", description2).strip()
        if len(description2.split()) > DESC_WORD_MAX: description2 = _trim_to_words(description2, DESC_WORD_MAX)
        description2 = _snap_sentence_boundary(description2, DESC_CHAR_MIN, DESC_CHAR_MAX)
        target_mid = (DESC_CHAR_MIN + DESC_CHAR_MAX) / 2
        if abs(len(description2) - target_mid) < abs(len(description) - target_mid): description = description2
    return tagline, description
def pick_best_synthetic_name(query: str, n_candidates: int = 10, include_copy=False):
    """Generate candidate names, score them against the query embedding, and
    return the best as a one-row DataFrame (rank fixed at 4 for the UI).

    Score = cosine(query, name) − 0.35 × token-overlap with the query,
    minus penalties for very short (<4) or very long (>16) normalized names.
    """
    names = generate_names(query, n=n_candidates, oversample=max(80, 8*n_candidates), max_retries=3)
    if len(names) == 0:
        # Retry once with a larger oversample before resorting to templates.
        names = generate_names(query, n=n_candidates, oversample=140, max_retries=1)
    if len(names) == 0:
        # Last resort: synthesize names from query fragments + common suffixes.
        toks = _content_words(query) or ["nova","learn","edu","mento"]
        seeds = list({t[:4]+"ify" for t in toks} | {t[:3]+"ora" for t in toks} | {t[:4]+"io" for t in toks})
        names = seeds[:n_candidates]
    qv = _embed_query(query); embs = _embed_passages(names); cos = embs @ qv
    banned = sorted(set(_content_words(query)))
    final_scores = []
    for nm, s in zip(names, cos):
        toks = _content_words(nm); overlap = _overlap_ratio(toks, banned)
        length_pen = 0.0; L = len(_normalize_name(nm))
        if L < 4: length_pen += 0.3
        if L > 16: length_pen += 0.2
        final_scores.append(float(s) - 0.35*overlap - length_pen)
    best_idx = int(np.argmax(final_scores)); best_name = names[best_idx]; best_score = float(final_scores[best_idx])
    tagline, description = ("","")
    if include_copy: tagline, description = generate_tagline_and_desc(best_name, query_context=query)
    # rank=4: displayed beneath the three retrieved startups in the UI.
    row = pd.DataFrame([{"rank":4,"score":best_score,"name":best_name,"tagline":tagline,"description":description}])
    return row
# ---------- UI glue ----------
# Canned queries surfaced as one-click example buttons in the UI.
EXAMPLES = [
    "AI tool to analyze customer feedback",
    "Social network for jobs",
    "Mobile fintech app for cross-border payments",
    "AI learning tool for students",
    "Marketplace for eco-friendly products",
]
def ui_search(query, state_unlikes):
    """Handle the Search button: show the top-3 matches for the query.

    Starting a new search resets the per-session unlike list.
    Returns (results dataframe, new unlike state, status message).
    """
    query = (query or "").strip()
    if not query: return gr.update(value=pd.DataFrame()), state_unlikes, "Please enter a short idea."
    state_unlikes = []  # reset for new query
    res = search_topk_filtered_session(query, k=3, unliked_ids=set())
    # Report the actual row count instead of a hard-coded "3": a small corpus
    # can return fewer results than requested.
    return res, state_unlikes, f"Found {len(res)} similar items. You can unlike by row_idx, then Refresh."
def ui_unlike(query, unlike_ids_csv, state_unlikes):
    """Handle the Refresh button: merge newly-unliked row_idx values and re-search."""
    query = (query or "").strip()
    if not query:
        return gr.update(value=pd.DataFrame()), state_unlikes, "Enter a query first."
    # Parse the comma-separated id list; non-numeric pieces are ignored.
    pieces = [p.strip() for p in (unlike_ids_csv or "").split(",")]
    add_ids = {int(p) for p in pieces if p.isdigit()}
    cur = set(state_unlikes) | add_ids
    res = search_topk_filtered_session(query, k=3, unliked_ids=cur)
    return res, list(cur), f"Excluded {sorted(add_ids)}. Currently unliked: {sorted(cur)}"
def ui_clear_unlikes(query):
    """Handle the Clear button: drop all unlikes and re-run the top-3 search."""
    text = (query or "").strip()
    if not text:
        return gr.update(value=pd.DataFrame()), [], "Enter a query first."
    refreshed = search_topk_filtered_session(text, k=3, unliked_ids=set())
    return refreshed, [], "Cleared unlikes."
def ui_generate_synth(query, include_copy):
    """Handle the Generate button: produce the AI option (#4), optionally with copy."""
    text = (query or "").strip()
    if not text:
        return gr.update(value=pd.DataFrame()), "Enter a query first."
    row = pick_best_synthetic_name(text, n_candidates=10, include_copy=include_copy)
    return row, "Generated AI option as #4. Combine it with your top-3."
def _apply_example(example_text, state_unlikes):
    """Run a canned example query and echo it back into the query textbox."""
    results, new_state, msg = ui_search(example_text, state_unlikes)
    status = f"Example selected: “{example_text}”. {msg}"
    return example_text, results, new_state, status
# ---------- Gradio layout + event wiring ----------
with gr.Blocks(title="Startup Recommender + AI Name") as app:
    gr.Markdown("## Startup Recommender → Unlike → AI Name\nEnter a short idea. Get 3 similar startups, unlike what doesn’t fit, then generate an AI name (and optional tagline & description).")
    query = gr.Textbox(label="Your idea (short description)", placeholder="e.g., AI tool to analyze student essays and give feedback")
    with gr.Row():
        gr.Markdown("**Try an example:**")
        example_buttons = [gr.Button(ex, variant="secondary") for ex in EXAMPLES]
    with gr.Row():
        btn_search = gr.Button("Search Top-3")
        unlike_ids = gr.Textbox(label="Unlike by row_idx (comma-separated)", placeholder="e.g., 123, 456")
        btn_unlike = gr.Button("Refresh after Unlike")
        btn_clear = gr.Button("Clear Unlikes")
    results_tbl = gr.Dataframe(label="Top-3 Similar (after excludes)", interactive=False, wrap=True)
    gr.Markdown("### AI-Generated Option (#4)")
    include_copy = gr.Checkbox(label="Also generate tagline & description", value=True)
    btn_synth = gr.Button("Generate #4 (AI)")
    synth_tbl = gr.Dataframe(label="Synthetic #4", interactive=False, wrap=True)
    status = gr.Markdown("")
    # Per-session list of unliked row_idx values.
    state_unlikes = gr.State([])
    # wiring
    btn_search.click(ui_search, inputs=[query, state_unlikes], outputs=[results_tbl, state_unlikes, status])
    btn_unlike.click(ui_unlike, inputs=[query, unlike_ids, state_unlikes], outputs=[results_tbl, state_unlikes, status])
    btn_clear.click(ui_clear_unlikes, inputs=[query], outputs=[results_tbl, state_unlikes, status])
    for btn, ex in zip(example_buttons, EXAMPLES):
        # ex_=ex binds the current example at lambda-definition time, avoiding
        # the late-binding-closure pitfall in a loop of lambdas.
        btn.click(lambda st, ex_=ex: _apply_example(ex_, st),
                  inputs=[state_unlikes], outputs=[query, results_tbl, state_unlikes, status])
    btn_synth.click(ui_generate_synth, inputs=[query, include_copy], outputs=[synth_tbl, status])
# On Spaces, just calling launch() is fine; no explicit port.
if __name__ == "__main__":
    app.launch()