Spaces:
Running
Running
| import gradio as gr | |
| import os | |
| import logging | |
| import time | |
| import json | |
| import csv | |
| import re | |
| import base64 | |
| from datetime import datetime | |
| from functools import lru_cache | |
| from PIL import Image | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Set, Tuple | |
| from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError | |
| from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words | |
| from psq_rag.llm.rewrite import llm_rewrite_prompt | |
| from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup | |
| from psq_rag.llm.select import llm_select_indices, llm_infer_structural_tags, llm_infer_probe_tags | |
| from psq_rag.retrieval.state import ( | |
| expand_tags_via_implications, | |
| get_tag_type_name, | |
| get_tag_implications, | |
| get_tag_counts, | |
| ) | |
| from psq_rag.ui.group_ranked_display import rank_groups_from_tfidf, _load_enabled_groups | |
# Filesystem anchors for bundled documentation assets, resolved relative to
# this module so they work regardless of the process working directory.
APP_DIR = Path(__file__).parent
DOCS_DIR = APP_DIR / "docs"
ARCH_DIAGRAM_FILE = DOCS_DIR / "assets" / "architecture_overview.png"
# Explicit placeholder token in the docs markdown where the diagram is injected.
ARCH_DIAGRAM_MARKER = "{{ARCHITECTURE_DIAGRAM}}"
# Fallback insertion point when no marker is present in the docs.
ARCH_DIAGRAM_INSERT_BEFORE_HEADING = "## What Each Step Does"
# Regexes that unconditionally block tags from corporate-safety recommendations.
# Tags use "_" as the word separator, so each pattern anchors on (^|_) ... (_|$)
# to match whole underscore-delimited words (e.g. "sex" matches, "sussex" does not).
_CORPORATE_HARDBLOCK_PATTERNS = [
    # Rating-like explicitness markers.
    re.compile(r"(^|_)(nsfw|explicit|questionable)(_|$)", re.IGNORECASE),
    # Unambiguous sexual anatomy.
    re.compile(
        r"(^|_)(breast|breasts|boob|boobs|nipple|nipples|penis|vagina|pussy|clit|testicle|scrotum|genital|crotch|anus|anal|areola)(_|$)",
        re.IGNORECASE,
    ),
    # Unambiguous sexual activity.
    re.compile(
        r"(^|_)(sex|sexual|fucking|fuck|blowjob|handjob|masturbat|penetrat|thrust|orgasm|cum|ejaculat|creampie|nude|naked|topless|bottomless|moan|sexy)(_|$)",
        re.IGNORECASE,
    ),
    # Common kink/fetish markers.
    re.compile(r"(^|_)(fetish|bdsm|bondage|dominatrix|submission|vore|inflation|watersports)(_|$)", re.IGNORECASE),
]
| def _split_prompt_commas(s: str) -> List[str]: | |
| return [p.strip() for p in (s or "").split(",") if p.strip()] | |
def _norm_for_dedupe(tag: str) -> str:
    """Return the canonical (lowercased, lookup-normalized) form used for dedupe."""
    lowered = tag.lower()
    return _norm_tag_for_lookup(lowered)
def compose_final_prompt(rewritten_prompt: str, selected_tags: List[str]) -> str:
    """Join rewritten-prompt phrases and selected tags into one prompt string.

    Duplicates are dropped by canonical form; the first occurrence wins and
    input order is preserved.
    """
    pieces = _split_prompt_commas(rewritten_prompt) + list(selected_tags)
    emitted_keys: set = set()
    kept: List[str] = []
    for piece in pieces:
        canon = _norm_for_dedupe(piece)
        if canon not in emitted_keys:
            emitted_keys.add(canon)
            kept.append(piece)
    return ", ".join(kept)
| def _display_tag_text(tag: str) -> str: | |
| return tag.replace("_", " ") | |
| def _display_row_label(name: str) -> str: | |
| n = (name or "").strip() | |
| if not n: | |
| return "" | |
| if n == "selected_other": | |
| return "Selected (Other)" | |
| return n.replace("_", " ").title() | |
| def _normalize_selection_origin(origin: str) -> str: | |
| o = (origin or "").strip().lower() | |
| if o in {"rewrite", "selection", "probe", "structural", "user", "candidate"}: | |
| return o | |
| return "selection" | |
| def _choice_label_with_source_meta(tag: str, *, origin: str, preselected: bool) -> str: | |
| # Keep labels plain to avoid frontend text/value desynchronization. | |
| return _display_tag_text(tag) | |
| def _load_tag_wiki_defs() -> Dict[str, str]: | |
| p = Path("data/tag_wiki_defs.json") | |
| if not p.exists(): | |
| return {} | |
| try: | |
| with p.open("r", encoding="utf-8") as f: | |
| data = json.load(f) | |
| out: Dict[str, str] = {} | |
| if isinstance(data, dict): | |
| for k, v in data.items(): | |
| tag = _norm_tag_for_lookup(str(k)) | |
| text = " ".join(str(v or "").split()) | |
| if tag and text: | |
| out[tag] = text | |
| return out | |
| except Exception: | |
| return {} | |
def _load_about_docs_markdown() -> str:
    """Return About-tab markdown from the first readable candidate doc file.

    Tries ``docs/space_overview.md`` then ``PROJECT_SUMMARY.md``, skipping
    missing/unreadable/empty files. A leading YAML front-matter block (between
    the first two ``---`` fences) is stripped. Falls back to a short
    "unavailable" message when nothing usable is found.
    """
    for path in (DOCS_DIR / "space_overview.md", APP_DIR / "PROJECT_SUMMARY.md"):
        if not path.exists():
            continue
        try:
            content = path.read_text(encoding="utf-8")
        except Exception:
            continue
        body = content.strip()
        if not body:
            continue
        if body.startswith("---"):
            # Drop YAML front matter if present.
            fenced = body.split("---", 2)
            if len(fenced) >= 3:
                body = fenced[2].strip()
        if body:
            return body
    return (
        "Documentation is unavailable.\n\n"
        "Expected file: `docs/space_overview.md`"
    )
def _tooltip_text_for_tag(tag: str) -> str:
    """Compose tooltip text for *tag*: post count (when known) plus wiki definition."""
    norm = _norm_tag_for_lookup(tag)
    lines: List[str] = []
    try:
        count = get_tag_counts().get(norm)
    except Exception:
        # Missing/broken count data degrades to "no count line".
        count = None
    if isinstance(count, int):
        lines.append(f"Count: {count:,}")
    definition = _load_tag_wiki_defs().get(norm, "")
    if definition:
        lines.append(definition)
    return "\n".join(lines).strip()
def _load_arch_diagram_data_uri() -> str:
    """Return the architecture diagram PNG as a data: URI, or "" when unavailable."""
    if not ARCH_DIAGRAM_FILE.exists():
        return ""
    try:
        payload = ARCH_DIAGRAM_FILE.read_bytes()
    except Exception:
        return ""
    if not payload:
        # Zero-byte file: treat as missing.
        return ""
    encoded = base64.b64encode(payload).decode("ascii")
    return "data:image/png;base64," + encoded
def _split_about_docs_for_diagram(md: str) -> Tuple[str, str, bool]:
    """Split the about-docs markdown around the diagram insertion point.

    Returns ``(before, after, found)``. Insertion points are tried in order:
    the explicit marker (last occurrence), a legacy "Architecture At A Glance"
    heading (the heading line itself is consumed), then the "What Each Step
    Does" heading (kept at the start of *after* so the diagram lands just
    before it). With no insertion point, everything goes in *before*.
    """
    text = (md or "").strip()
    if ARCH_DIAGRAM_MARKER in text:
        head, tail = text.rsplit(ARCH_DIAGRAM_MARKER, 1)
        return head.strip(), tail.strip(), True
    # Backward compatibility if an explicit architecture heading exists in docs.
    legacy = re.search(r"(?m)^##\s+Architecture At A Glance\s*$", text)
    if legacy:
        return text[: legacy.start()].strip(), text[legacy.end() :].strip(), True
    # Preferred insertion point: inject diagram right before "What Each Step Does".
    steps = re.search(r"(?m)^##\s+What Each Step Does\s*$", text)
    if steps:
        return text[: steps.start()].strip(), text[steps.start() :].strip(), True
    return text, "", False
def _build_arch_diagram_html() -> str:
    """Render the architecture-diagram HTML section, or a placeholder when missing."""
    data_uri = _load_arch_diagram_data_uri()
    if not data_uri:
        return "<p><code>(architecture diagram unavailable)</code></p>"
    return f"""
<div class="arch-diagram-wrap">
<h2>Architecture At A Glance</h2>
<img src="{data_uri}" alt="Architecture diagram" />
</div>
"""
| def _selection_source_rank(origin: str) -> int: | |
| o = _normalize_selection_origin(origin) | |
| if o == "structural": | |
| return 0 | |
| if o == "probe": | |
| return 1 | |
| # Keep rewrite/user in the same priority band as general selection for row ordering. | |
| return 2 | |
def _build_implied_parent_map(
    direct_tags_ordered: List[str],
    implied_tags: List[str],
) -> Dict[str, str]:
    """Map each implied tag to the first direct tag that (transitively) implies it.

    Walks the tag-implication graph outward from each direct tag, in the given
    order, so an implied tag is attributed to the earliest direct tag that
    reaches it. Returns {} when there is nothing to attribute.
    """
    implied_set = {_norm_tag_for_lookup(t) for t in (implied_tags or []) if t}
    if not implied_set or not direct_tags_ordered:
        return {}
    impl = get_tag_implications()
    parent_by_implied: Dict[str, str] = {}
    for direct in direct_tags_ordered:
        d = _norm_tag_for_lookup(direct)
        if not d:
            continue
        # Stack-based (depth-first) walk over transitive implications of `d`;
        # `seen` guards against cycles in the implication graph.
        queue = [d]
        seen = {d}
        while queue:
            t = queue.pop()
            for parent in impl.get(t, ()):
                p = _norm_tag_for_lookup(parent)
                if not p or p in seen:
                    continue
                seen.add(p)
                # First direct tag to reach an implied tag wins the attribution.
                if p in implied_set and p not in parent_by_implied:
                    parent_by_implied[p] = d
                queue.append(p)
    return parent_by_implied
def _order_selected_tags_for_row(
    *,
    row_selected_tags: List[str],
    selected_index: Dict[str, int],
    tag_selection_origins: Dict[str, str],
    implied_parent_map: Dict[str, str],
) -> List[str]:
    """Order a row's selected tags for display.

    Direct (non-implied) tags come first, sorted by origin rank then original
    selection order, each immediately followed by the implied tags it
    introduced. Implied tags whose parent is not in this row are appended at
    the end, ordered by their parent's rank/position and then their own.

    Args:
        row_selected_tags: Selected tags belonging to this display row.
        selected_index: tag -> position in the overall selection order.
        tag_selection_origins: tag -> origin label (structural/probe/...).
        implied_parent_map: implied tag -> direct tag that implied it.
    """
    row_selected_norm = [_norm_tag_for_lookup(t) for t in (row_selected_tags or []) if t]
    implied_in_row = {t for t in row_selected_norm if t in implied_parent_map}
    # Direct tags: everything in the row that is not an implied child.
    base_tags = [t for t in row_selected_norm if t not in implied_in_row]
    base_tags.sort(
        key=lambda t: (
            _selection_source_rank(tag_selection_origins.get(t, "selection")),
            selected_index.get(t, 10**9),  # unknown tags sort after known ones
            t,
        )
    )
    # Group implied tags under the direct parent that introduced them.
    children_by_parent: Dict[str, List[str]] = {}
    for implied in implied_in_row:
        parent = implied_parent_map.get(implied)
        if parent:
            children_by_parent.setdefault(parent, []).append(implied)
    for parent, children in children_by_parent.items():
        children.sort(key=lambda t: (selected_index.get(t, 10**9), t))
    # Emit each direct tag followed immediately by its implied children.
    ordered: List[str] = []
    emitted: Set[str] = set()
    for tag in base_tags:
        if tag in emitted:
            continue
        ordered.append(tag)
        emitted.add(tag)
        for child in children_by_parent.get(tag, []):
            if child not in emitted:
                ordered.append(child)
                emitted.add(child)
    # Stray implied tags (parent not present in this row) go last.
    remaining_implied = [t for t in row_selected_norm if t not in emitted]
    remaining_implied.sort(
        key=lambda t: (
            _selection_source_rank(tag_selection_origins.get(implied_parent_map.get(t, ""), "selection")),
            selected_index.get(implied_parent_map.get(t, ""), 10**9),
            selected_index.get(t, 10**9),
            t,
        )
    )
    for t in remaining_implied:
        if t not in emitted:
            ordered.append(t)
            emitted.add(t)
    return ordered
| def _escape_prompt_tag(tag: str) -> str: | |
| return ( | |
| tag.replace("_", " ") | |
| .replace("(", "\\(") | |
| .replace(")", "\\)") | |
| ) | |
| def _ordered_selected_for_prompt(selected: Set[str], row_defs: List[Dict[str, Any]]) -> List[str]: | |
| out: List[str] = [] | |
| seen: Set[str] = set() | |
| for row in row_defs: | |
| for tag in row.get("tags", []): | |
| if tag in selected and tag not in seen: | |
| out.append(tag) | |
| seen.add(tag) | |
| return out | |
def _compose_toggle_prompt_text(selected_tags: List[str], row_defs: List[Dict[str, Any]]) -> str:
    """Compose the prompt string from selected tags, ordered by display rows."""
    active = {tag for tag in (selected_tags or []) if tag}
    ordered = _ordered_selected_for_prompt(active, row_defs or [])
    return ", ".join(_escape_prompt_tag(tag) for tag in ordered)
def _is_artist_tag(tag: str) -> bool:
    """True when *tag* is typed as an artist tag.

    The "by_" prefix check is a resilient fallback for malformed/missing tag
    typing metadata.
    """
    normalized = _norm_tag_for_lookup(str(tag))
    if not normalized:
        return False
    return get_tag_type_name(normalized) == "artist" or normalized.startswith("by_")
| def _load_excluded_recommendation_tags() -> Set[str]: | |
| out: Set[str] = set() | |
| # Existing category-registry driven exclusions. | |
| csv_path = Path("data/category_registry.csv") | |
| if not csv_path.exists(): | |
| csv_path = Path("data/analysis/category_registry.csv") | |
| if csv_path.exists(): | |
| try: | |
| with csv_path.open("r", encoding="utf-8", newline="") as f: | |
| reader = csv.DictReader(f) | |
| for row in reader: | |
| tag = _norm_tag_for_lookup(str(row.get("tag") or "")) | |
| if not tag: | |
| continue | |
| status = str(row.get("category_status") or "").strip().lower() | |
| if status == "excluded": | |
| out.add(tag) | |
| except Exception: | |
| pass | |
| # Corporate-safety exclusions (editable runtime list). | |
| corp_path = Path("data/corporate_excluded_tags.csv") | |
| if corp_path.exists(): | |
| try: | |
| with corp_path.open("r", encoding="utf-8", newline="") as f: | |
| reader = csv.DictReader(f) | |
| for row in reader: | |
| tag = _norm_tag_for_lookup(str(row.get("tag") or "")) | |
| if not tag: | |
| continue | |
| enabled_raw = str(row.get("enabled", "1")).strip().lower() | |
| enabled = enabled_raw not in {"0", "false", "no", "off"} | |
| if enabled: | |
| out.add(tag) | |
| except Exception: | |
| pass | |
| return out | |
def _is_hardblocked_corporate_tag(tag: str) -> bool:
    """True when the normalized tag matches any corporate hard-block regex."""
    normalized = _norm_tag_for_lookup(str(tag))
    if not normalized:
        return False
    for pattern in _CORPORATE_HARDBLOCK_PATTERNS:
        if pattern.search(normalized):
            return True
    return False
def _is_excluded_recommendation_tag(tag: str) -> bool:
    """True when *tag* is hard-blocked or on a configured exclusion list."""
    normalized = _norm_tag_for_lookup(str(tag))
    if not normalized:
        return False
    # Hard-block patterns take precedence; the CSV lists are checked second.
    return (
        _is_hardblocked_corporate_tag(normalized)
        or normalized in _load_excluded_recommendation_tags()
    )
| def _get_min_tag_count() -> int: | |
| try: | |
| return max(0, int(os.environ.get("PSQ_MIN_TAG_COUNT", "100"))) | |
| except Exception: | |
| return 100 | |
def _filter_min_count_tags(tags: List[str], min_count: int) -> Tuple[List[str], List[str]]:
    """Partition tags into (kept, removed) by the per-tag post-count threshold.

    Kept tags are normalized and deduped in first-seen order; removed tags are
    returned sorted and deduped. A non-positive threshold keeps everything.
    """
    if min_count <= 0:
        deduped = list(dict.fromkeys(_norm_tag_for_lookup(t) for t in (tags or []) if t))
        return deduped, []
    counts = get_tag_counts()
    kept: List[str] = []
    dropped: List[str] = []
    emitted: Set[str] = set()
    for item in (tags or []):
        norm = _norm_tag_for_lookup(str(item))
        if not norm:
            continue
        if int(counts.get(norm, 0) or 0) < min_count:
            dropped.append(norm)
        elif norm not in emitted:
            emitted.add(norm)
            kept.append(norm)
    return kept, sorted(set(dropped))
def _filter_excluded_recommendation_tags(tags: List[str]) -> Tuple[List[str], List[str]]:
    """Partition tags into (kept, removed) against the configured exclusion lists.

    Kept tags are normalized and deduped in first-seen order; removed tags are
    returned sorted and deduped. With no exclusions configured, everything is
    kept (after normalization/dedupe).
    """
    excluded = _load_excluded_recommendation_tags()
    if not excluded:
        return list(dict.fromkeys(_norm_tag_for_lookup(t) for t in (tags or []) if t)), []
    kept: List[str] = []
    dropped: List[str] = []
    emitted: Set[str] = set()
    for item in (tags or []):
        norm = _norm_tag_for_lookup(str(item))
        if not norm:
            continue
        if norm in excluded:
            dropped.append(norm)
        elif norm not in emitted:
            emitted.add(norm)
            kept.append(norm)
    return kept, sorted(set(dropped))
def _filter_excluded_candidates(candidates: List[Any]) -> Tuple[List[Any], List[str]]:
    """Drop candidate objects whose ``.tag`` attribute is excluded.

    Returns (kept candidate objects, sorted unique removed tag names). With no
    exclusions configured, a shallow copy of the input is kept unchanged.
    """
    excluded = _load_excluded_recommendation_tags()
    if not excluded:
        return list(candidates or []), []
    kept: List[Any] = []
    dropped: List[str] = []
    for cand in (candidates or []):
        norm = _norm_tag_for_lookup(str(getattr(cand, "tag", "") or ""))
        if norm and norm in excluded:
            dropped.append(norm)
        else:
            kept.append(cand)
    return kept, sorted(set(dropped))
def _dedupe_norm_tags(tags: List[str]) -> List[str]:
    """Normalize tags and drop blanks/duplicates, preserving first-seen order."""
    emitted: Set[str] = set()
    result: List[str] = []
    for item in (tags or []):
        norm = _norm_tag_for_lookup(str(item))
        if norm and norm not in emitted:
            emitted.add(norm)
            result.append(norm)
    return result
def _collect_visible_tags(row_defs: List[Dict[str, Any]]) -> Set[str]:
    """Union of all normalized tags across the given row definitions."""
    visible: Set[str] = set()
    for row in (row_defs or []):
        row_tags = row.get("tags", []) if isinstance(row, dict) else []
        visible.update(_dedupe_norm_tags(row_tags))
    return visible
def _collect_selected_from_state(
    selected_tags_state: List[str],
    row_defs: List[Dict[str, Any]],
) -> List[str]:
    """Filter the persisted selection down to tags visible in *row_defs*.

    Each stored entry is matched verbatim against the visible tags, falling
    back to normalized matching; unmatched and duplicate entries are dropped.
    Order follows the stored state.
    """
    visible_tags = _collect_visible_tags(row_defs)
    if not visible_tags:
        return []
    by_norm = {_norm_tag_for_lookup(t): t for t in visible_tags}
    selected: List[str] = []
    emitted: Set[str] = set()
    for item in (selected_tags_state or []):
        norm = _norm_tag_for_lookup(str(item))
        if not norm:
            continue
        mapped = norm if norm in visible_tags else by_norm.get(norm)
        if mapped and mapped not in emitted:
            emitted.add(mapped)
            selected.append(mapped)
    return selected
def _collect_selected_from_row_values(
    row_defs: List[Dict[str, Any]],
    row_values_state: List[List[str]],
) -> List[str]:
    """Collect selected tags from per-row checkbox values, in row order.

    Each raw value is accepted verbatim when it matches a row tag; otherwise it
    is normalized and mapped back onto the row's canonical tag (tolerating
    display-label payloads). Duplicates across rows are dropped (first wins).
    """
    selected: List[str] = []
    seen: Set[str] = set()
    values = list(row_values_state or [])
    for idx, row in enumerate(row_defs or []):
        row_tags = _dedupe_norm_tags(row.get("tags", []) if isinstance(row, dict) else [])
        if not row_tags:
            continue
        row_tag_set = set(row_tags)
        row_tag_by_norm = {_norm_tag_for_lookup(t): t for t in row_tags}
        # Guard against a values list shorter than the row list.
        raw_vals = values[idx] if 0 <= idx < len(values) else []
        for raw in (raw_vals or []):
            if raw in row_tag_set:
                if raw not in seen:
                    seen.add(raw)
                    selected.append(raw)
                continue
            # Fallback: normalize and map back onto the row's canonical tag.
            raw_norm = _norm_tag_for_lookup(str(raw))
            mapped = row_tag_by_norm.get(raw_norm)
            if mapped and mapped not in seen:
                seen.add(mapped)
                selected.append(mapped)
    return selected
def _build_toggle_rows(
    *,
    seed_terms: List[str],
    selected_tags: List[str],
    retrieved_candidate_tags: List[str],
    tag_selection_origins: Dict[str, str],
    implied_parent_map: Dict[str, str],
    top_groups: int,
    top_tags_per_group: int,
    group_rank_top_k: int,
) -> List[Dict[str, Any]]:
    """Assemble the row definitions for the tag-toggle UI.

    Row layout (in order): a "Selected (Other)" row for selected tags not shown
    anywhere else, one row per TF-IDF-ranked tag group, and an optional
    "Other (Retrieved)" row at the bottom for retrieved tags that belong to no
    enabled group. Artist tags and excluded/hard-blocked tags are filtered out
    everywhere. Each row dict carries name, label, tags, and per-tag metadata
    (origin, preselected).
    """
    ranked_rows = rank_groups_from_tfidf(
        seed_terms=seed_terms,
        top_groups=max(1, int(top_groups)),
        top_tags_per_group=max(1, int(top_tags_per_group)),
        group_rank_top_k=max(1, int(group_rank_top_k)),
    )
    groups_map = _load_enabled_groups()
    # Active selection: normalized, deduped, with artist/excluded tags removed.
    selected_active = list(
        dict.fromkeys(
            _norm_tag_for_lookup(t)
            for t in selected_tags
            if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
        )
    )
    selected_index: Dict[str, int] = {t: i for i, t in enumerate(selected_active)}
    row_defs: List[Dict[str, Any]] = []
    # Per-group tag sets (artist tags stripped), plus the union over all groups.
    enabled_group_tag_sets: Dict[str, Set[str]] = {
        name: {t for t in tags if not _is_artist_tag(t)}
        for name, tags in groups_map.items()
    }
    tags_in_any_enabled_group: Set[str] = set()
    for tag_set in enabled_group_tag_sets.values():
        tags_in_any_enabled_group.update(tag_set)
    # Only the ranked (i.e. displayed) groups count toward duplicate avoidance.
    displayed_group_names = [r.group_name for r in ranked_rows]
    displayed_group_tag_sets: Dict[str, Set[str]] = {
        name: enabled_group_tag_sets.get(name, set())
        for name in displayed_group_names
    }
    tags_in_any_displayed_group: Set[str] = set()
    for tag_set in displayed_group_tag_sets.values():
        tags_in_any_displayed_group.update(tag_set)
    # Retrieved tags that belong to no enabled group at all ("uncategorized").
    retrieved_uncategorized_ranked = list(
        dict.fromkeys(
            _norm_tag_for_lookup(t)
            for t in (retrieved_candidate_tags or [])
            if t
            and not _is_artist_tag(t)
            and not _is_excluded_recommendation_tag(t)
            and _norm_tag_for_lookup(t) not in tags_in_any_enabled_group
        )
    )
    retrieved_other_row: Dict[str, Any] | None = None
    if retrieved_uncategorized_ranked:
        retrieved_uncategorized_set = set(retrieved_uncategorized_ranked)
        selected_in_retrieved_other_raw = [
            t for t in selected_active if t in retrieved_uncategorized_set
        ]
        selected_in_retrieved_other = _order_selected_tags_for_row(
            row_selected_tags=selected_in_retrieved_other_raw,
            selected_index=selected_index,
            tag_selection_origins=tag_selection_origins,
            implied_parent_map=implied_parent_map,
        )
        # Selected tags lead the row; unselected retrieved tags follow.
        merged_retrieved_other = selected_in_retrieved_other + [
            t for t in retrieved_uncategorized_ranked if t not in selected_in_retrieved_other
        ]
        merged_retrieved_other = _dedupe_norm_tags(merged_retrieved_other)
        # Never truncate below the number of already-selected tags.
        keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_retrieved_other))
        merged_retrieved_other = merged_retrieved_other[:keep_n]
        retrieved_other_meta = {
            t: {
                "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
                "preselected": t in selected_active,
            }
            for t in merged_retrieved_other
        }
        retrieved_other_row = {
            "name": "other_retrieved",
            "label": "Other (Retrieved)",
            "tags": merged_retrieved_other,
            "tag_meta": retrieved_other_meta,
        }
    # "Selected (Other)" should contain selected tags not already shown in any displayed row.
    # Include "Other (Retrieved)" in that displayed-row set to avoid duplicates across those rows.
    tags_in_displayed_rows = set(tags_in_any_displayed_group)
    if retrieved_other_row:
        tags_in_displayed_rows.update(retrieved_other_row.get("tags", []))
    selected_other_raw = [t for t in selected_active if t not in tags_in_displayed_rows]
    selected_other = _order_selected_tags_for_row(
        row_selected_tags=selected_other_raw,
        selected_index=selected_index,
        tag_selection_origins=tag_selection_origins,
        implied_parent_map=implied_parent_map,
    )
    selected_other = _dedupe_norm_tags(selected_other)
    selected_other_meta = {
        t: {
            "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
            "preselected": True,
        }
        for t in selected_other
    }
    row_defs.append(
        {
            "name": "selected_other",
            "label": _display_row_label("selected_other"),
            "tags": selected_other,
            "tag_meta": selected_other_meta,
        }
    )
    # One row per ranked group: selected tags first, then ranked fill-ins.
    for row in ranked_rows:
        group_name = row.group_name
        group_tag_set = displayed_group_tag_sets.get(group_name, set())
        selected_in_group_raw = [t for t in selected_active if t in group_tag_set]
        selected_in_group = _order_selected_tags_for_row(
            row_selected_tags=selected_in_group_raw,
            selected_index=selected_index,
            tag_selection_origins=tag_selection_origins,
            implied_parent_map=implied_parent_map,
        )
        ranked_tags = [
            _norm_tag_for_lookup(t)
            for t, _ in row.tags
            if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
        ]
        ranked_tags = _dedupe_norm_tags(ranked_tags)
        merged = selected_in_group + [t for t in ranked_tags if t not in selected_in_group]
        merged = _dedupe_norm_tags(merged)
        # Never truncate below the number of already-selected tags.
        keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_group))
        merged = merged[:keep_n]
        tag_meta = {
            t: {
                "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
                "preselected": t in selected_active,
            }
            for t in merged
        }
        row_defs.append(
            {
                "name": group_name,
                "label": _display_row_label(group_name),
                "tags": merged,
                "tag_meta": tag_meta,
            }
        )
    # Keep this row at the bottom so category/group rows remain contiguous.
    if retrieved_other_row:
        row_defs.append(retrieved_other_row)
    return row_defs
def _build_display_audit_line(
    row_defs: List[Dict[str, Any]],
    *,
    active_selected_tags: List[str],
    direct_selected_tags: List[str],
    implied_selected_tags: List[str],
) -> str:
    """Build a one-line JSON audit of every displayed tag.

    For each tag shown in any row, records which row labels display it and
    which sources apply (row kind plus active/direct/implied selection
    membership). Artist tags are excluded from the selection sets.
    """
    active_set = {
        _norm_tag_for_lookup(t)
        for t in (active_selected_tags or [])
        if t and not _is_artist_tag(t)
    }
    direct_set = {
        _norm_tag_for_lookup(t)
        for t in (direct_selected_tags or [])
        if t and not _is_artist_tag(t)
    }
    implied_set = {
        _norm_tag_for_lookup(t)
        for t in (implied_selected_tags or [])
        if t and not _is_artist_tag(t)
    }
    # tag -> {"rows": [labels...], "sources": {source markers...}}
    info_by_tag: Dict[str, Dict[str, Any]] = {}
    for row in row_defs or []:
        row_name = row.get("name", "")
        row_label = row.get("label", row_name)
        for tag in row.get("tags", []):
            rec = info_by_tag.setdefault(tag, {"rows": [], "sources": set()})
            rec["rows"].append(row_label)
            # Classify the row kind the tag appears in.
            if row_name == "selected_other":
                rec["sources"].add("selected_other_row")
            elif row_name == "other_retrieved":
                rec["sources"].add("other_retrieved_row")
            else:
                rec["sources"].add("ranked_group_row")
            if tag in active_set:
                rec["sources"].add("selected_active")
            if tag in direct_set:
                rec["sources"].add("selected_direct")
            if tag in implied_set:
                rec["sources"].add("selected_implied")
    payload = {
        "n_tags": len(info_by_tag),
        "tags": [
            {
                "tag": tag,
                "rows": rec["rows"],
                "sources": sorted(rec["sources"]),
            }
            for tag, rec in sorted(info_by_tag.items())
        ],
    }
    return "Display Tag Audit: " + json.dumps(payload, ensure_ascii=True)
def _build_tooltip_payload(row_defs: List[Dict[str, Any]], max_rows: int) -> str:
    """Serialize per-row tag lists and a tag -> tooltip map as a JSON string."""
    visible_rows = (row_defs or [])[: max(0, int(max_rows))]
    tips: Dict[str, str] = {}
    rows: List[List[str]] = []
    for row in visible_rows:
        row_tags = _dedupe_norm_tags(row.get("tags", []) if isinstance(row, dict) else [])
        rows.append(row_tags)
        for tag in row_tags:
            # Compute each tooltip once per unique tag.
            if tag not in tips:
                tips[tag] = _tooltip_text_for_tag(tag)
    return json.dumps({"rows": rows, "tips": tips}, ensure_ascii=True)
def _build_row_component_updates(
    row_defs: List[Dict[str, Any]],
    selected_tags: List[str],
    max_rows: int,
):
    """Compute Gradio updates for the fixed pool of row header/checkbox widgets.

    Rows beyond the data are hidden and cleared. Returns a 4-tuple:
    (prompt text, per-row selected-values state, header updates, checkbox
    updates) — the two update lists always have exactly ``max_rows`` entries.
    """
    selected = {t for t in (selected_tags or []) if t}
    row_defs_ui = (row_defs or [])[: max(0, int(max_rows))]
    row_values_state: List[List[str]] = []
    header_updates = []
    checkbox_updates = []
    for idx in range(max_rows):
        if idx < len(row_defs_ui):
            row = row_defs_ui[idx]
            tags = _dedupe_norm_tags(row.get("tags", []))
            values = [t for t in tags if t in selected]
            row_values_state.append(values)
            # An empty-tag row is kept in the state but hidden in the UI.
            visible = bool(tags)
            header_updates.append(gr.update(value=row.get("label", ""), visible=visible))
            tag_meta = row.get("tag_meta", {}) if isinstance(row.get("tag_meta", {}), dict) else {}
            # Build (label, value) choice pairs; labels stay plain text.
            choices = []
            for t in tags:
                meta = tag_meta.get(t, {}) if isinstance(tag_meta.get(t, {}), dict) else {}
                origin = _normalize_selection_origin(str(meta.get("origin", "selection")))
                preselected = bool(meta.get("preselected", False))
                choices.append((_choice_label_with_source_meta(t, origin=origin, preselected=preselected), t))
            checkbox_updates.append(
                gr.update(
                    choices=choices,
                    value=values,
                    visible=visible,
                )
            )
        else:
            # Unused widget slots: hide and clear.
            header_updates.append(gr.update(value="", visible=False))
            checkbox_updates.append(gr.update(choices=[], value=[], visible=False))
    prompt_text = _compose_toggle_prompt_text(list(selected), row_defs_ui)
    return prompt_text, row_values_state, header_updates, checkbox_updates
def _on_toggle_row(
    row_idx: int,
    changed_values: List[str],
    selected_tags_state: List[str],
    rows_dirty_state: bool,
    row_defs_state: List[Dict[str, Any]],
    row_values_state: List[List[str]],
    max_rows: int,
):
    """Handle a checkbox change on one display row.

    Reconciles the edited row's values with the global selection, propagates
    the change to any other rows that show the same tags, and recomputes the
    prompt text. No-op events (programmatic re-sets that change nothing) are
    detected and short-circuited. Returns a positional payload:
    [selected tags, dirty flag, apply-button update, row values state,
    prompt text, *per-row checkbox updates].
    """
    row_defs = row_defs_state or []
    row_defs_ui = row_defs[: max(0, int(max_rows))]
    prev_values = list(row_values_state or [])
    selected_from_state = _collect_selected_from_state(selected_tags_state, row_defs_ui)
    selected_from_rows = _collect_selected_from_row_values(row_defs_ui, prev_values)
    # Prefer row-value state as source-of-truth (closest to visible UI), with selected-state as fallback.
    selected: Set[str] = set(selected_from_rows or selected_from_state)
    # Guard against out-of-range row indices from stale events.
    row = row_defs_ui[row_idx] if 0 <= row_idx < len(row_defs_ui) else {}
    row_tags = _dedupe_norm_tags(row.get("tags", []))
    row_label = str(row.get("label", ""))
    row_tag_set = set(row_tags)
    row_tag_by_norm = {_norm_tag_for_lookup(t): t for t in row_tags}
    # Be tolerant to UI payload forms: canonical tag values, display labels, normalized variants,
    # and occasional single-string payloads from frontend events.
    if changed_values is None:
        changed_iter: List[Any] = []
    elif isinstance(changed_values, str):
        changed_iter = [changed_values]
    elif isinstance(changed_values, (list, tuple, set)):
        changed_iter = list(changed_values)
    else:
        changed_iter = [changed_values]
    # Be tolerant to UI payload forms: canonical tag values, display labels, or normalized variants.
    new_set: Set[str] = set()
    for raw in changed_iter:
        if raw in row_tag_set:
            new_set.add(raw)
            continue
        raw_norm = _norm_tag_for_lookup(str(raw))
        mapped = row_tag_by_norm.get(raw_norm)
        if mapped:
            new_set.add(mapped)
    prev_row_selected = {t for t in row_tags if t in selected}
    # Ignore non-user/no-op events (e.g. programmatic value re-sets) deterministically.
    if new_set == prev_row_selected:
        prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
        checkbox_updates = [gr.skip() for _ in range(max_rows)]
        return [sorted(selected), rows_dirty_state, gr.skip(), prev_values, prompt_text, *checkbox_updates]
    # Replace this row's contribution to the global selection wholesale.
    selected.difference_update(row_tag_set)
    selected.update(new_set)
    # Symmetric difference = the tags the user actually flipped.
    toggled_tags = prev_row_selected ^ new_set
    new_row_values_state: List[List[str]] = []
    affected_rows: Set[int] = {row_idx}
    for idx, row_item in enumerate(row_defs_ui):
        tags = _dedupe_norm_tags(row_item.get("tags", []))
        values = [t for t in tags if t in selected]
        new_row_values_state.append(values)
        # Any other row showing a flipped tag must be re-rendered too.
        if toggled_tags and any(t in toggled_tags for t in tags):
            affected_rows.add(idx)
    checkbox_updates = []
    for idx in range(max_rows):
        if idx >= len(row_defs_ui):
            checkbox_updates.append(gr.skip())
            continue
        if idx in affected_rows:
            checkbox_updates.append(gr.update(value=new_row_values_state[idx]))
        else:
            checkbox_updates.append(gr.skip())
    prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
    return [
        sorted(selected),
        True,
        gr.update(visible=True, interactive=True),
        new_row_values_state,
        prompt_text,
        *checkbox_updates,
    ]
def _build_ui_payload(
    *,
    console_text: str,
    row_defs: List[Dict[str, Any]],
    selected_tags: List[str],
    suggested_prompt_text: str | None = None,
):
    """Build the full positional output payload for a completed run.

    Wraps `_build_row_component_updates` and derives the UI-visible selection
    (tags that actually appear checked in some row) from the row values.
    NOTE: relies on the module-level `display_max_rows_default` defined
    elsewhere in this module.
    """
    prompt_text, row_values_state, header_updates, checkbox_updates = _build_row_component_updates(
        row_defs=row_defs,
        selected_tags=selected_tags,
        max_rows=display_max_rows_default,
    )
    # An explicit suggested prompt overrides the one derived from selections.
    if suggested_prompt_text is not None:
        prompt_text = str(suggested_prompt_text)
    # Flatten the per-row values into a deduped, row-ordered selection list.
    selected_ui: List[str] = []
    selected_ui_seen: Set[str] = set()
    for vals in row_values_state:
        for t in vals:
            if t in selected_ui_seen:
                continue
            selected_ui_seen.add(t)
            selected_ui.append(t)
    tooltip_payload = _build_tooltip_payload(row_defs, display_max_rows_default)
    # Positional contract: console, rows-visible, tooltips, prompt, selection,
    # dirty flag (reset), apply-button update (hidden), row defs, row values,
    # then all header updates followed by all checkbox updates.
    return [
        console_text,
        gr.update(visible=bool(row_defs)),
        tooltip_payload,
        prompt_text,
        selected_ui,
        False,
        gr.update(visible=False, interactive=False),
        row_defs,
        row_values_state,
        *header_updates,
        *checkbox_updates,
    ]
| def _format_user_facing_error(exc: Exception) -> str: | |
| msg = str(exc or "").strip() | |
| msg_l = msg.lower() | |
| if "rewrite: empty output" in msg_l: | |
| return ( | |
| "Could not rewrite that prompt. Try simpler, neutral wording and remove sensitive phrasing, " | |
| "then click Run again." | |
| ) | |
| if "openrouter_api_key" in msg_l: | |
| return "Service configuration is missing. Please contact the app owner." | |
| if "timed out" in msg_l: | |
| return "The model request timed out. Please try again with a shorter or simpler prompt." | |
| if "index selection failed" in msg_l: | |
| return "Tag selection failed for this request. Please try again." | |
| if "startup preflight failed" in msg_l: | |
| return "App startup checks failed. Please contact the app owner." | |
| return "Something went wrong while processing the prompt. Please try again." | |
def _prepare_run_ui() -> List[Any]:
    """Build the positional payload that puts the UI into the "Running" state.

    Hides and clears every row header/checkbox slot, resets selection/dirty
    state, and shows progress text. NOTE: relies on the module-level
    `display_max_rows_default` defined elsewhere in this module. The output
    order mirrors `_build_ui_payload`'s positional contract.
    """
    header_updates = [gr.update(value="", visible=False) for _ in range(display_max_rows_default)]
    checkbox_updates = [
        gr.update(choices=[], value=[], visible=False)
        for _ in range(display_max_rows_default)
    ]
    return [
        "Running...",
        gr.skip(),  # rows container visibility left untouched during the run
        "{}",  # empty tooltip payload
        "Running... usually completes in about 20 seconds.",
        [],  # selection cleared
        False,  # dirty flag reset
        gr.update(visible=False, interactive=False),  # apply button hidden
        [],  # row defs cleared
        [],  # row values cleared
        *header_updates,
        *checkbox_updates,
    ]
def _update_run_button_visibility(prompt_text: str, last_run_prompt: str):
    """Show the Run button only for a non-empty prompt that differs from the last run."""
    current = (prompt_text or "").strip()
    previous = (last_run_prompt or "").strip()
    runnable = current != previous if current else False
    return gr.update(visible=runnable, interactive=runnable)
def _mark_run_triggered(prompt_text: str):
    """Hide/disable the Run button and record the stripped prompt that triggered this run."""
    triggered_prompt = (prompt_text or "").strip()
    return gr.update(visible=False, interactive=False), triggered_prompt
def _rebuild_rows_from_selected(
    selected_tags_state: List[str],
    row_defs_state: List[Dict[str, Any]],
    row_values_state: List[List[str]],
    display_top_groups: float,
    display_top_tags_per_group: float,
    display_rank_top_k: float,
) -> List[Any]:
    """Recompute the grouped tag rows and suggested prompt after a checkbox toggle.

    Inputs are the persisted UI states (selected tags, row definitions, per-row
    checkbox values) and the display sliders (Gradio delivers numeric component
    values as floats). Returns the flat update list expected by the event wiring:
    rows-visible update, tooltip payload, prompt text, sorted selected tags, a
    False flag, a hidden/disabled run-button update, the rebuilt row-def and
    row-value states, then per-row header and checkbox updates.
    """
    existing_rows = row_defs_state or []
    existing_values = list(row_values_state or [])
    selected_from_state = _collect_selected_from_state(selected_tags_state, existing_rows)
    selected_from_rows = _collect_selected_from_row_values(existing_rows, existing_values)
    # Rebuild source-of-truth is current row checkbox values; fall back only when unavailable.
    selected_seed = selected_from_rows if existing_values else selected_from_state
    # Normalize, drop artist/excluded tags, and dedupe while preserving order.
    selected_active = list(
        dict.fromkeys(
            _norm_tag_for_lookup(t)
            for t in selected_seed
            if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
        )
    )
    retrieved_candidate_tags: List[str] = []
    tag_selection_origins: Dict[str, str] = {}
    # Carry forward every still-valid candidate tag and its first-seen origin
    # from the existing row definitions.
    for row in existing_rows:
        row_tags = row.get("tags", []) if isinstance(row, dict) else []
        row_meta = row.get("tag_meta", {}) if isinstance(row, dict) else {}
        if not isinstance(row_meta, dict):
            row_meta = {}
        for t in row_tags:
            tn = _norm_tag_for_lookup(t)
            if not tn or _is_artist_tag(tn) or _is_excluded_recommendation_tag(tn):
                continue
            retrieved_candidate_tags.append(tn)
            if tn not in tag_selection_origins:
                # Meta is keyed by the raw (pre-normalization) tag string.
                meta = row_meta.get(t, {}) if isinstance(row_meta.get(t, {}), dict) else {}
                tag_selection_origins[tn] = _normalize_selection_origin(str(meta.get("origin", "selection")))
    # Selected tags without a known origin were toggled on by the user; keep
    # them in the candidate pool so they survive the rebuild.
    for t in selected_active:
        tag_selection_origins.setdefault(t, "user")
        retrieved_candidate_tags.append(t)
    # Split implied tags out, then order the direct selections by origin rank
    # (stable within rank thanks to the original-position tiebreaker).
    implied_selected_tags = [t for t in selected_active if tag_selection_origins.get(t) == "implied"]
    implied_set = set(implied_selected_tags)
    direct_selected_tags = [t for t in selected_active if t not in implied_set]
    direct_idx = {t: i for i, t in enumerate(direct_selected_tags)}
    direct_selected_tags.sort(
        key=lambda t: (
            _selection_source_rank(tag_selection_origins.get(t, "selection")),
            direct_idx.get(t, 10**9),
        )
    )
    implied_parent_map = _build_implied_parent_map(
        direct_tags_ordered=direct_selected_tags,
        implied_tags=implied_selected_tags,
    )
    # Regroup candidates into display rows; sliders are clamped to >= 1.
    toggle_rows = _build_toggle_rows(
        seed_terms=list(selected_active),
        selected_tags=selected_active,
        retrieved_candidate_tags=list(dict.fromkeys(retrieved_candidate_tags)),
        tag_selection_origins=tag_selection_origins,
        implied_parent_map=implied_parent_map,
        top_groups=max(1, int(display_top_groups)),
        top_tags_per_group=max(1, int(display_top_tags_per_group)),
        group_rank_top_k=max(1, int(display_rank_top_k)),
    )
    prompt_text, row_values_state, header_updates, checkbox_updates = _build_row_component_updates(
        row_defs=toggle_rows,
        selected_tags=selected_active,
        max_rows=display_max_rows_default,
    )
    tooltip_payload = _build_tooltip_payload(toggle_rows, display_max_rows_default)
    return [
        gr.update(visible=bool(toggle_rows)),
        tooltip_payload,
        prompt_text,
        sorted(selected_active),
        False,
        gr.update(visible=False, interactive=False),
        toggle_rows,
        row_values_state,
        *header_updates,
        *checkbox_updates,
    ]
| def _build_selection_query( | |
| prompt_in: str, | |
| rewritten: str, | |
| structural_tags: List[str], | |
| probe_tags: List[str], | |
| ) -> str: | |
| lines = [f"IMAGE DESCRIPTION: {prompt_in.strip()}"] | |
| if rewritten and rewritten.strip(): | |
| lines.append(f"REWRITE PHRASES: {rewritten.strip()}") | |
| hint_tags = [] | |
| if structural_tags: | |
| hint_tags.extend(structural_tags) | |
| if probe_tags: | |
| hint_tags.extend(probe_tags) | |
| if hint_tags: | |
| # Keep hints as context only; selection still must choose by candidate indices. | |
| lines.append( | |
| "INFERRED TAG HINTS (context only): " + ", ".join(sorted(set(hint_tags))) | |
| ) | |
| return "\n".join(lines) | |
# Set up logging.
# Minimal prod logging: warnings+ to stderr, no file handler by default.
# NOTE(review): os and logging are already imported at the top of the file;
# this re-import is redundant but harmless (kept to avoid touching imports).
import os, logging
# PSQ_LOG_LEVEL names a logging level ("DEBUG", "INFO", ...); unknown names
# silently fall back to WARNING via getattr's default below.
LOG_LEVEL = os.environ.get("PSQ_LOG_LEVEL", "WARNING").upper()
logging.basicConfig(
    level=getattr(logging, LOG_LEVEL, logging.WARNING),
    format="%(asctime)s %(levelname)s:%(message)s",
    handlers=[logging.StreamHandler()]  # no file handler -> avoids huge logs on Spaces
)
# Quiet down common noisy third-party libraries.
for _name in ("gensim", "gradio", "hnswlib", "httpx", "uvicorn"):
    logging.getLogger(_name).setLevel(logging.ERROR)
# Disable Gradio analytics phone-home to avoid background-thread errors.
# NOTE(review): set after `import gradio` at the top of the file — verify
# gradio reads this env var late enough for the setting to take effect.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "0"
# Mascot artwork shipped alongside this file; loaded lazily by _load_mascot_image().
MASCOT_DIR = Path(__file__).parent / "mascotimages"
MASCOT_FILE = MASCOT_DIR / "transparentsquirrel.png"
def _load_mascot_image():
    """Load the mascot image as RGBA; return None when missing or unreadable."""
    if MASCOT_FILE.exists():
        try:
            return Image.open(MASCOT_FILE).convert("RGBA")
        except Exception as exc:
            logging.warning("Failed to load mascot image (%s): %s", MASCOT_FILE, exc)
            return None
    logging.warning("Mascot image missing: %s", MASCOT_FILE)
    return None
# Hotfix: gradio_client's JSON-schema -> Python-type helpers assume every
# schema node is a dict, but JSON Schema also allows bare booleans (True/False).
# Wrap the three helpers so non-dict schemas degrade to "any" instead of
# raising. Best-effort: if gradio_client is absent or its internals changed,
# the except below logs and the app continues unpatched.
try:
    from gradio_client import utils as _gc_utils
    # Keep originals so the safe wrappers can delegate the dict case.
    _orig_get_type = _gc_utils.get_type
    _orig_j2p = _gc_utils._json_schema_to_python_type
    # NOTE(review): _orig_pub is captured but never called below — presumably
    # kept as a debugging/restore handle; confirm before removing.
    _orig_pub = _gc_utils.json_schema_to_python_type
    def _get_type_safe(schema):
        # Sometimes schema is a bare True/False (JSON Schema boolean form).
        if not isinstance(schema, dict):
            return "any"
        return _orig_get_type(schema)
    def _j2p_safe(schema, defs=None):
        # Accept non-dict schemas (True/False/None) and treat as "any".
        if not isinstance(schema, dict):
            return "any"
        return _orig_j2p(schema, defs or schema.get("$defs"))
    def _pub_safe(schema):
        # Public wrapper used by Gradio; keep it resilient too.
        if not isinstance(schema, dict):
            return "any"
        return _j2p_safe(schema, schema.get("$defs"))
    # Install the wrappers (private helper patched too, since the public
    # entry point is not the only caller inside gradio_client).
    _gc_utils.get_type = _get_type_safe
    _gc_utils._json_schema_to_python_type = _j2p_safe
    _gc_utils.json_schema_to_python_type = _pub_safe
except Exception as e:
    print("gradio_client hotfix not applied:", e)
| # ------------------------------------------------------------------------------- | |
| allow_nsfw_tags = False | |
| def _is_production_runtime() -> bool: | |
| """Best-effort detection for deployed runtime (HF Spaces or explicit env).""" | |
| if os.environ.get("PSQ_PRODUCTION", "").strip().lower() in {"1", "true", "yes"}: | |
| return True | |
| if os.environ.get("SPACE_ID"): | |
| return True | |
| if os.environ.get("HF_SPACE_ID"): | |
| return True | |
| if os.environ.get("SYSTEM") == "spaces": | |
| return True | |
| return False | |
# ---- Runtime configuration (env-overridable) ----
# Verbose retrieval logging defaults off in production, on locally.
verbose_retrieval_default = "0" if _is_production_runtime() else "1"
verbose_retrieval = os.environ.get("PSQ_VERBOSE_RETRIEVAL", verbose_retrieval_default).strip().lower() in {"1", "true", "yes"}
# When verbose: log every candidate row (True) or only the first N per phrase.
verbose_retrieval_all = False
verbose_retrieval_limit = 20
# Probe-tag inference runs as a third concurrent LLM call unless disabled.
enable_probe_tags = os.environ.get("PSQ_ENABLE_PROBE", "1").strip() not in {"0", "false", "False"}
# Display shape of the grouped tag rows.
display_top_groups_default = int(os.environ.get("PSQ_DISPLAY_TOP_GROUPS", "10"))
display_top_tags_per_group_default = int(os.environ.get("PSQ_DISPLAY_TOP_TAGS_PER_GROUP", "7"))
display_rank_top_k_default = int(os.environ.get("PSQ_DISPLAY_GROUP_RANK_TOP_K", "7"))
display_max_rows_default = int(os.environ.get("PSQ_DISPLAY_MAX_ROWS", "14"))
# Retrieval fan-out sizes (global pool and per rewrite phrase).
retrieval_global_k = int(os.environ.get("PSQ_RETRIEVAL_GLOBAL_K", "300"))
retrieval_per_phrase_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_K", "10"))
retrieval_per_phrase_final_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_FINAL_K", "1"))
# Stage-3 selection strategy and sizing; a candidate cap of 0 means "no cap".
selection_mode = os.environ.get("PSQ_SELECTION_MODE", "chunked_map_union").strip()
selection_chunk_size = int(os.environ.get("PSQ_SELECTION_CHUNK_SIZE", "60"))
selection_per_phrase_k = int(os.environ.get("PSQ_SELECTION_PER_PHRASE_K", "2"))
selection_candidate_cap = int(os.environ.get("PSQ_SELECTION_CANDIDATE_CAP", "0"))
# Per-stage LLM timeouts (seconds); stage-3 selection also gets a shorter
# fast-retry budget on failure/timeouts.
stage1_rewrite_timeout_s = float(os.environ.get("PSQ_TIMEOUT_REWRITE_S", "45"))
stage1_struct_timeout_s = float(os.environ.get("PSQ_TIMEOUT_STRUCT_S", "45"))
stage1_probe_timeout_s = float(os.environ.get("PSQ_TIMEOUT_PROBE_S", "45"))
stage3_select_timeout_s = float(os.environ.get("PSQ_TIMEOUT_SELECT_S", "50"))
stage3_select_retry_timeout_s = float(os.environ.get("PSQ_TIMEOUT_SELECT_RETRY_S", "20"))
stage3_fast_retry_count = max(0, int(os.environ.get("PSQ_STAGE3_FAST_RETRY_COUNT", "1")))
# JSONL sink for per-run stage timings (written best-effort after each run).
timing_log_path = Path(os.environ.get("PSQ_TIMING_LOG_PATH", "data/runtime_metrics/ui_pipeline_timings.jsonl"))
| def _startup_preflight_errors() -> List[str]: | |
| errs: List[str] = [] | |
| if not os.getenv("OPENROUTER_API_KEY"): | |
| errs.append("OPENROUTER_API_KEY is missing. Set it in Space Secrets or environment variables.") | |
| return errs | |
# Evaluate preflight once at import time and surface failures in the logs.
# The pipeline entry point also consults STARTUP_PREFLIGHT_ERRORS per request
# and refuses to run while any error is present.
STARTUP_PREFLIGHT_ERRORS = _startup_preflight_errors()
if STARTUP_PREFLIGHT_ERRORS:
    for _err in STARTUP_PREFLIGHT_ERRORS:
        logging.error("Startup preflight error: %s", _err)
# CSS injected into the Gradio Blocks app. Covers: the scrollable result
# panes (.scrollable-content, .pane-left/.pane-right), the "lego brick" skin
# for tag checkboxes (.lego-tags — selected colors driven by the
# data-psq-preselected / data-psq-origin attributes), the source-legend chips
# (.source-legend), and layout/typography tweaks for headings, hints, the
# About/docs pane, and hidden helper components (.psq-hidden).
css = """
.scrollable-content{
max-height: 420px;
overflow-y: scroll; /* always show scrollbar */
overflow-x: hidden;
padding-right: 8px;
padding-bottom: 14px; /* <— add this */
scrollbar-gutter: stable; /* prevent layout shift as it fills */
/* Firefox */
scrollbar-width: auto;
scrollbar-color: rgba(180,180,180,.9) rgba(0,0,0,.15);
}
/* WebKit/Chromium (Chrome/Edge/Safari) */
.scrollable-content::-webkit-scrollbar{ width: 10px; }
.scrollable-content::-webkit-scrollbar-thumb{ background: rgba(180,180,180,.9); border-radius: 8px; }
.scrollable-content::-webkit-scrollbar-track{ background: rgba(0,0,0,.15); }
/* (Optional) make both scroll panes taller so they fill more of the column */
.pane-left .scrollable-content,
.pane-right .scrollable-content {
max-height: 610px; /* was 420px; tweak to taste */
}
.lego-tags .gr-checkboxgroup,
.lego-tags .wrap {
display: flex !important;
flex-wrap: wrap !important;
gap: 10px !important;
}
.lego-tags label {
margin: 0 !important;
padding: 0 !important;
position: relative !important;
}
/* Hide native checkbox visuals completely */
.lego-tags input[type="checkbox"] {
appearance: none !important;
-webkit-appearance: none !important;
-moz-appearance: none !important;
position: absolute !important;
width: 1px !important;
height: 1px !important;
opacity: 0 !important;
pointer-events: none !important;
display: none !important;
}
/* Brick button skin (works for both +span and ~span structures) */
.lego-tags input[type="checkbox"] + span,
.lego-tags input[type="checkbox"] ~ span {
--on-bg1: #ffd166;
--on-bg2: #f39c4a;
--on-border: #b86e21;
--on-text: #2e1706;
position: relative !important;
display: inline-flex !important;
align-items: center !important;
min-height: 40px !important;
padding: 10px 15px 9px 22px !important;
border: 1px solid #9aa6b8 !important;
border-radius: 10px !important;
background: linear-gradient(180deg, #dfe5ee 0%, #bec8d6 100%) !important;
color: #364254 !important;
font-size: 0.97rem !important;
font-weight: 800 !important;
line-height: 1.15 !important;
cursor: pointer !important;
user-select: none !important;
letter-spacing: 0.01em !important;
box-shadow: 0 3px 0 rgba(0,0,0,0.16), inset 0 1px 0 rgba(255,255,255,0.55) !important;
transition: transform 0.08s ease, box-shadow 0.08s ease, filter 0.08s ease !important;
}
.lego-tags input[type="checkbox"] + span::before,
.lego-tags input[type="checkbox"] ~ span::before {
content: "" !important;
position: absolute !important;
top: 5px !important;
left: 8px !important;
width: 8px !important;
height: 8px !important;
border-radius: 50% !important;
background: rgba(255,255,255,0.58) !important;
box-shadow: 22px 0 0 rgba(255,255,255,0.58) !important;
pointer-events: none !important;
}
/* Unselected cue: show "+" on the left. */
.lego-tags input[type="checkbox"] + span::after,
.lego-tags input[type="checkbox"] ~ span::after {
content: "+" !important;
position: absolute !important;
left: 6px !important;
top: 50% !important;
transform: translateY(-52%) !important;
font-size: 1rem !important;
font-weight: 900 !important;
color: #4b5563 !important;
opacity: 0.95 !important;
pointer-events: none !important;
}
/* Bright color cycle used only when selected */
.lego-tags label:nth-child(8n+1) span { --on-bg1: #ffd166; --on-bg2: #f39c4a; --on-border: #b86e21; --on-text: #2e1706; }
.lego-tags label:nth-child(8n+2) span { --on-bg1: #6ee7ff; --on-bg2: #1fb7ff; --on-border: #157cb3; --on-text: #07263c; }
.lego-tags label:nth-child(8n+3) span { --on-bg1: #9dff8f; --on-bg2: #45c96f; --on-border: #2a8b4b; --on-text: #0d2917; }
.lego-tags label:nth-child(8n+4) span { --on-bg1: #ff8fab; --on-bg2: #ff5c7a; --on-border: #b83956; --on-text: #3f0f1d; }
.lego-tags label:nth-child(8n+5) span { --on-bg1: #d0a8ff; --on-bg2: #a46cff; --on-border: #7147b3; --on-text: #25143f; }
.lego-tags label:nth-child(8n+6) span { --on-bg1: #ffe27a; --on-bg2: #f7bf39; --on-border: #ad7f1f; --on-text: #332407; }
.lego-tags label:nth-child(8n+7) span { --on-bg1: #8effd5; --on-bg2: #2ed6b5; --on-border: #1e947d; --on-text: #0d2a25; }
.lego-tags label:nth-child(8n+8) span { --on-bg1: #ffb47e; --on-bg2: #ff8753; --on-border: #b95b2d; --on-text: #391a0a; }
/* Source-driven selected colors (applies when tags are preselected by the pipeline). */
.lego-tags label[data-psq-preselected="1"][data-psq-origin="rewrite"] span {
--on-bg1: #77f0d7;
--on-bg2: #26b9a3;
--on-border: #187869;
--on-text: #062923;
}
.lego-tags label[data-psq-preselected="1"][data-psq-origin="selection"] span {
--on-bg1: #ffd98a;
--on-bg2: #f0a93c;
--on-border: #a66f1f;
--on-text: #382206;
}
.lego-tags label[data-psq-preselected="1"][data-psq-origin="probe"] span {
--on-bg1: #d8b4ff;
--on-bg2: #9a6cff;
--on-border: #6745b0;
--on-text: #24143b;
}
.lego-tags label[data-psq-preselected="1"][data-psq-origin="structural"] span {
--on-bg1: #a6f79a;
--on-bg2: #53c368;
--on-border: #2f8442;
--on-text: #102d17;
}
.lego-tags label[data-psq-preselected="1"][data-psq-origin="implied"] span {
--on-bg1: #d7dde8;
--on-bg2: #a8b3c4;
--on-border: #6f7e95;
--on-text: #1d2633;
}
/* User-selected tags (not initially selected by the pipeline). */
.lego-tags label[data-psq-preselected="0"] span {
--on-bg1: #9ec5ff;
--on-bg2: #4f86ff;
--on-border: #2f5fbf;
--on-text: #0b1f42;
}
.lego-tags label:hover span {
filter: brightness(1.02) !important;
transform: translateY(1px) !important;
}
/* ON state: brighter + visibly recessed */
.lego-tags input[type="checkbox"]:checked + span,
.lego-tags input[type="checkbox"]:checked ~ span,
.lego-tags label:has(input[type="checkbox"]:checked) span {
background: linear-gradient(180deg, var(--on-bg1) 0%, var(--on-bg2) 100%) !important;
color: var(--on-text) !important;
border-color: var(--on-border) !important;
filter: saturate(1.2) brightness(1.12) !important;
transform: translateY(-2px) !important;
box-shadow:
inset 0 3px 6px rgba(0,0,0,0.20),
inset 0 -1px 0 rgba(255,255,255,0.36),
0 6px 0 rgba(0,0,0,0.32) !important;
}
.lego-tags input[type="checkbox"]:checked + span::after,
.lego-tags input[type="checkbox"]:checked ~ span::after,
.lego-tags label:has(input[type="checkbox"]:checked) span::after {
content: "" !important;
}
.source-legend {
display: flex;
flex-wrap: wrap;
align-items: center;
gap: 8px;
margin: 4px 0 10px 0;
}
.source-legend .legend-title {
font-size: 0.92rem;
font-weight: 900;
color: #334155;
margin-right: 4px;
}
.source-legend .chip {
display: inline-flex;
align-items: center;
border-radius: 10px;
border: 1px solid #6c7788;
padding: 6px 12px;
font-size: 0.85rem;
font-weight: 800;
color: #111827;
background: #f3f6fb;
}
.source-legend .chip.rewrite { background: #26b9a3; color: #062923; border-color: #187869; }
.source-legend .chip.selection { background: #f0a93c; color: #382206; border-color: #a66f1f; }
.source-legend .chip.probe { background: #9a6cff; color: #ffffff; border-color: #6745b0; }
.source-legend .chip.structural { background: #53c368; color: #102d17; border-color: #2f8442; }
.source-legend .chip.implied { background: #a8b3c4; color: #1d2633; border-color: #6f7e95; }
.source-legend .chip.user { background: #4f86ff; color: #ffffff; border-color: #2f5fbf; }
.source-legend .chip.unselected { background: #c7ced8; color: #2d3440; border-color: #7d8897; }
.row-heading p {
margin: 8px 0 0 0 !important;
font-size: 1.18rem !important;
font-weight: 850 !important;
line-height: 1.2 !important;
}
.row-instruction {
text-align: center;
margin: 8px 0 12px 0;
}
.row-instruction p {
margin: 0 !important;
font-size: 1.02rem !important;
font-style: italic !important;
font-weight: 800 !important;
color: #1d4ed8 !important;
}
.about-docs {
margin-top: 4px;
}
.about-docs > p {
line-height: 1.42 !important;
}
.about-docs img {
max-width: 100% !important;
height: auto !important;
border: 1px solid #d2d7e0;
border-radius: 10px;
background: #ffffff;
}
.arch-diagram-wrap {
margin: 6px 0 10px 0;
}
.arch-diagram-wrap h2 {
margin: 0 0 8px 0 !important;
}
.top-instruction {
text-align: center;
margin: 2px 0 6px 0;
}
.top-instruction p {
margin: 0 !important;
font-size: 1.02rem !important;
font-style: italic !important;
font-weight: 800 !important;
color: #1d4ed8 !important;
}
.run-hint {
margin-top: 6px;
text-align: center;
}
.run-hint p {
margin: 0 !important;
font-size: 0.9rem !important;
font-style: italic !important;
color: #475569 !important;
}
.prompt-card {
background: transparent !important;
border: none !important;
box-shadow: none !important;
padding: 0 !important;
}
.suggested-prompt-box {
margin-top: 2px !important;
}
.suggested-prompt-card {
margin-top: 10px !important;
}
.psq-hidden {
display: none !important;
}
"""
# Client-side JS run once on app load. Reads the JSON tooltip payload from the
# hidden "#psq-tooltip-map" textbox ({rows: [[tag, ...], ...], tips: {tag: text}})
# and mirrors it onto each .lego-tags label as a native title tooltip.
# A MutationObserver re-applies (coalesced via requestAnimationFrame) after
# Gradio re-renders the checkbox rows.
client_js = """
() => {
const readTooltipMap = () => {
const el = document.querySelector("#psq-tooltip-map textarea, #psq-tooltip-map input");
if (!el) return { rows: [], tips: {} };
const raw = (el.value || "").trim();
if (!raw) return { rows: [], tips: {} };
try {
const obj = JSON.parse(raw);
if (!obj || typeof obj !== "object") return { rows: [], tips: {} };
const rows = Array.isArray(obj.rows) ? obj.rows : [];
const tips = (obj.tips && typeof obj.tips === "object") ? obj.tips : {};
return { rows, tips };
} catch (_) {
return { rows: [], tips: {} };
}
};
const applyTooltips = () => {
const payload = readTooltipMap();
const rowTags = Array.isArray(payload.rows) ? payload.rows : [];
const tipMap = (payload.tips && typeof payload.tips === "object") ? payload.tips : {};
const rowEls = document.querySelectorAll(".lego-tags");
rowEls.forEach((rowEl, rowIdx) => {
const tags = Array.isArray(rowTags[rowIdx]) ? rowTags[rowIdx] : [];
const labels = rowEl.querySelectorAll("label");
labels.forEach((label, tagIdx) => {
const span = label.querySelector("span");
const tag = (tagIdx < tags.length) ? tags[tagIdx] : "";
const tip = tag && Object.prototype.hasOwnProperty.call(tipMap, tag) ? (tipMap[tag] || "") : "";
if (tip) {
label.title = tip;
if (span) span.title = tip;
} else {
label.removeAttribute("title");
if (span) span.removeAttribute("title");
}
});
});
};
let scheduled = false;
const scheduleApply = () => {
if (scheduled) return;
scheduled = true;
requestAnimationFrame(() => {
scheduled = false;
applyTooltips();
});
};
scheduleApply();
const observer = new MutationObserver(() => scheduleApply());
observer.observe(document.body, { childList: true, subtree: true });
}
"""
| def rag_pipeline_ui( | |
| user_prompt: str, | |
| display_top_groups: float, | |
| display_top_tags_per_group: float, | |
| display_rank_top_k: float, | |
| ): | |
| logs = [] | |
| def log(s): logs.append(s) | |
| try: | |
| stage_timings = {} | |
| def _record_timing(stage: str, dt_s: float): | |
| stage_timings[stage] = float(dt_s) | |
| def _emit_timing_summary(total_s: float): | |
| summary_order = [ | |
| "preprocess", | |
| "rewrite", | |
| "structural", | |
| "probe", | |
| "retrieval", | |
| "selection", | |
| "implication_expansion", | |
| "prompt_composition", | |
| "group_display", | |
| ] | |
| lines = [] | |
| for k in summary_order: | |
| if k in stage_timings: | |
| lines.append(f"{k}={stage_timings[k]:.2f}s") | |
| slowest = max(stage_timings.items(), key=lambda kv: kv[1])[0] if stage_timings else "n/a" | |
| log("Timing Summary: " + ", ".join(lines)) | |
| log(f"Timing Slowest Stage: {slowest}") | |
| log(f"Timing Total: {total_s:.2f}s") | |
| def _append_timing_jsonl(total_s: float): | |
| try: | |
| timing_log_path.parent.mkdir(parents=True, exist_ok=True) | |
| rec = { | |
| "timestamp_utc": datetime.utcnow().isoformat(timespec="seconds") + "Z", | |
| "stages_s": stage_timings, | |
| "total_s": float(total_s), | |
| "config": { | |
| "timeout_rewrite_s": stage1_rewrite_timeout_s, | |
| "timeout_struct_s": stage1_struct_timeout_s, | |
| "timeout_probe_s": stage1_probe_timeout_s, | |
| "timeout_select_s": stage3_select_timeout_s, | |
| }, | |
| } | |
| with timing_log_path.open("a", encoding="utf-8") as f: | |
| f.write(json.dumps(rec, ensure_ascii=True) + "\n") | |
| log(f"Timing Log: wrote {timing_log_path}") | |
| except Exception as e: | |
| log(f"Timing Log: failed ({type(e).__name__}: {e})") | |
| def _future_with_timeout( | |
| fut, | |
| timeout_s: float, | |
| stage_name: str, | |
| fallback, | |
| *, | |
| strict: bool = False, | |
| ): | |
| t0 = time.perf_counter() | |
| try: | |
| out = fut.result(timeout=max(1.0, float(timeout_s))) | |
| dt = time.perf_counter() - t0 | |
| log(f"{stage_name}: {dt:.2f}s") | |
| stage_key = { | |
| "Rewrite": "rewrite", | |
| "Structural inference": "structural", | |
| "Probe inference": "probe", | |
| "Index selection": "selection", | |
| }.get(stage_name) | |
| if stage_key: | |
| _record_timing(stage_key, dt) | |
| return out | |
| except FutureTimeoutError: | |
| fut.cancel() | |
| msg = f"{stage_name}: timed out after {timeout_s:.0f}s" | |
| if strict: | |
| raise RuntimeError(msg) | |
| log(f"{msg}; using fallback") | |
| return fallback | |
| except Exception as e: | |
| msg = f"{stage_name}: failed ({type(e).__name__}: {e})" | |
| if strict: | |
| raise RuntimeError(msg) | |
| log(f"{msg}; using fallback") | |
| return fallback | |
| t_total0 = time.perf_counter() | |
| log("Start: received prompt") | |
| if STARTUP_PREFLIGHT_ERRORS: | |
| log("Startup preflight failed:") | |
| for e in STARTUP_PREFLIGHT_ERRORS: | |
| log(f"- {e}") | |
| return _build_ui_payload( | |
| console_text="\n".join(logs), | |
| row_defs=[], | |
| selected_tags=[], | |
| suggested_prompt_text="Error: startup preflight failed. Check console details.", | |
| ) | |
| prompt_in = (user_prompt or "").strip() | |
| if not prompt_in: | |
| return _build_ui_payload( | |
| console_text="Error: empty prompt", | |
| row_defs=[], | |
| selected_tags=[], | |
| suggested_prompt_text='Enter a prompt and click "Run".', | |
| ) | |
| log("Input:") | |
| log(prompt_in) | |
| log("") | |
| log( | |
| "Runtime config: " | |
| f"retrieval_global_k={retrieval_global_k} " | |
| f"retrieval_per_phrase_k={retrieval_per_phrase_k} " | |
| f"retrieval_per_phrase_final_k={retrieval_per_phrase_final_k} " | |
| f"selection_mode={selection_mode} " | |
| f"selection_chunk_size={selection_chunk_size} " | |
| f"selection_per_phrase_k={selection_per_phrase_k} " | |
| f"min_tag_count={_get_min_tag_count()} " | |
| f"select_timeout_s={stage3_select_timeout_s:.0f} " | |
| f"select_retry_timeout_s={stage3_select_retry_timeout_s:.0f} " | |
| f"select_fast_retries={stage3_fast_retry_count}" | |
| ) | |
| log("") | |
| t0 = time.perf_counter() | |
| min_tag_count = _get_min_tag_count() | |
| user_tags_raw = extract_user_provided_tags_upto_3_words(prompt_in) | |
| user_tags, removed_user_low = _filter_min_count_tags(user_tags_raw, min_tag_count) | |
| user_tags, removed_user_excluded = _filter_excluded_recommendation_tags(user_tags) | |
| dt = time.perf_counter()-t0 | |
| _record_timing("preprocess", dt) | |
| log(f"Preprocess (user tag extraction): {dt:.2f}s") | |
| log("Heuristically extracted user tags:") | |
| if user_tags: | |
| log(", ".join(user_tags)) | |
| else: | |
| log("(none)") | |
| if removed_user_low: | |
| log( | |
| f"Filtered {len(removed_user_low)} low-frequency user tags " | |
| f"(<{min_tag_count}): {', '.join(removed_user_low)}" | |
| ) | |
| if removed_user_excluded: | |
| log( | |
| f"Filtered {len(removed_user_excluded)} excluded user tags: " | |
| f"{', '.join(removed_user_excluded)}" | |
| ) | |
| log("") | |
| log("Step 1: LLM rewrite + structural inference + probe (concurrent)") | |
| max_workers = 3 if enable_probe_tags else 2 | |
| ex = ThreadPoolExecutor(max_workers=max_workers) | |
| try: | |
| fut_rewrite = ex.submit(llm_rewrite_prompt, prompt_in, log) | |
| fut_struct = ex.submit(llm_infer_structural_tags, prompt_in, log=log) | |
| fut_probe = ex.submit(llm_infer_probe_tags, prompt_in, log=log) if enable_probe_tags else None | |
| rewritten = _future_with_timeout( | |
| fut_rewrite, | |
| stage1_rewrite_timeout_s, | |
| "Rewrite", | |
| "", | |
| strict=True, | |
| ) | |
| structural_tags = _future_with_timeout( | |
| fut_struct, stage1_struct_timeout_s, "Structural inference", [] | |
| ) | |
| probe_tags = ( | |
| _future_with_timeout(fut_probe, stage1_probe_timeout_s, "Probe inference", []) | |
| if fut_probe else [] | |
| ) | |
| finally: | |
| ex.shutdown(wait=False, cancel_futures=True) | |
| structural_tags, removed_struct_low = _filter_min_count_tags(structural_tags, min_tag_count) | |
| probe_tags, removed_probe_low = _filter_min_count_tags(probe_tags, min_tag_count) | |
| structural_tags, removed_struct_excluded = _filter_excluded_recommendation_tags(structural_tags) | |
| probe_tags, removed_probe_excluded = _filter_excluded_recommendation_tags(probe_tags) | |
| if removed_struct_low: | |
| log( | |
| f"Filtered {len(removed_struct_low)} low-frequency structural tags " | |
| f"(<{min_tag_count}): {', '.join(removed_struct_low)}" | |
| ) | |
| if removed_probe_low: | |
| log( | |
| f"Filtered {len(removed_probe_low)} low-frequency probe tags " | |
| f"(<{min_tag_count}): {', '.join(removed_probe_low)}" | |
| ) | |
| if removed_struct_excluded: | |
| log( | |
| f"Filtered {len(removed_struct_excluded)} excluded structural tags: " | |
| f"{', '.join(removed_struct_excluded)}" | |
| ) | |
| if removed_probe_excluded: | |
| log( | |
| f"Filtered {len(removed_probe_excluded)} excluded probe tags: " | |
| f"{', '.join(removed_probe_excluded)}" | |
| ) | |
| if not rewritten: | |
| raise RuntimeError("Rewrite: empty output") | |
| log("Rewrite:") | |
| log(rewritten if rewritten else "(empty)") | |
| log("") | |
| rewrite_for_retrieval = rewritten | |
| if user_tags: | |
| # keep them separate in logs, but allow them to help retrieval | |
| rewrite_for_retrieval = (rewrite_for_retrieval + ", " + ", ".join(user_tags)).strip(", ").strip() | |
| log("Step 2: Prompt Squirrel retrieval (hidden)") | |
| try: | |
| t0 = time.perf_counter() | |
| retrieval_context_tags = list(dict.fromkeys((structural_tags or []) + (probe_tags or []))) | |
| rewrite_phrases = [p.strip() for p in (rewrite_for_retrieval or "").split(",") if p.strip()] | |
| retrieval_result = psq_candidates_from_rewrite_phrases( | |
| rewrite_phrases=rewrite_phrases, | |
| allow_nsfw_tags=allow_nsfw_tags, | |
| context_tags=retrieval_context_tags, | |
| global_k=max(1, retrieval_global_k), | |
| per_phrase_k=max(1, retrieval_per_phrase_k), | |
| per_phrase_final_k=max(1, retrieval_per_phrase_final_k), | |
| min_tag_count=max(0, min_tag_count), | |
| verbose=verbose_retrieval, | |
| ) | |
| if isinstance(retrieval_result, tuple): | |
| candidates, phrase_reports = retrieval_result | |
| else: | |
| candidates, phrase_reports = retrieval_result, [] | |
| candidates, removed_candidate_excluded = _filter_excluded_candidates(candidates) | |
| if removed_candidate_excluded: | |
| log( | |
| f"Filtered {len(removed_candidate_excluded)} excluded retrieved tags: " | |
| f"{', '.join(removed_candidate_excluded[:25])}" | |
| + (" ..." if len(removed_candidate_excluded) > 25 else "") | |
| ) | |
| if selection_candidate_cap > 0 and len(candidates) > selection_candidate_cap: | |
| candidates = candidates[:selection_candidate_cap] | |
| log(f"Selection candidate cap applied: {selection_candidate_cap}") | |
| dt = time.perf_counter()-t0 | |
| _record_timing("retrieval", dt) | |
| log(f"Retrieval: {dt:.2f}s") | |
| log(f"Retrieved {len(candidates)} candidate tags") | |
| if verbose_retrieval: | |
| log(f"Total unique candidates: {len(candidates)}") | |
| limit = None if verbose_retrieval_all else max(1, int(verbose_retrieval_limit)) | |
| for report in phrase_reports: | |
| phrase = report.get("normalized") or report.get("phrase") or "" | |
| lookup = report.get("lookup") or "" | |
| tfidf_vocab = report.get("tfidf_vocab") | |
| log(f"Phrase: {phrase} (lookup={lookup}) tfidf_vocab={tfidf_vocab}") | |
| rows = report.get("candidates", []) | |
| shown = rows if limit is None else rows[:limit] | |
| for row in shown: | |
| tag = row.get("tag") | |
| alias_token = row.get("alias_token") | |
| score_fasttext = row.get("score_fasttext") | |
| score_context = row.get("score_context") | |
| score_combined = row.get("score_combined") | |
| count = row.get("count") | |
| alias_part = "" | |
| if alias_token and alias_token != tag: | |
| alias_part = f" [alias_token={alias_token}]" | |
| fasttext_str = ( | |
| f"{score_fasttext:.3f}" if isinstance(score_fasttext, (int, float)) else score_fasttext | |
| ) | |
| if score_context is None: | |
| context_str = "None" | |
| else: | |
| context_str = ( | |
| f"{score_context:.3f}" if isinstance(score_context, (int, float)) else score_context | |
| ) | |
| combined_str = ( | |
| f"{score_combined:.3f}" if isinstance(score_combined, (int, float)) else score_combined | |
| ) | |
| log( | |
| f" {tag}{alias_part} | fasttext={fasttext_str} context={context_str} " | |
| f"combined={combined_str} count={count}" | |
| ) | |
| if limit is not None and len(rows) > limit: | |
| log(f" ... ({len(rows) - limit} more)") | |
| except Exception as e: | |
| log(f"Retrieval fallback: {type(e).__name__}: {e}") | |
| candidates = [] | |
| retrieved_candidate_tags = list( | |
| dict.fromkeys( | |
| _norm_tag_for_lookup(c.tag) | |
| for c in (candidates or []) | |
| if getattr(c, "tag", None) | |
| ) | |
| ) | |
| log("Step 3: LLM index selection (uses rewrite + structural/probe context)") | |
| selection_query = _build_selection_query( | |
| prompt_in=prompt_in, | |
| rewritten=rewritten, | |
| structural_tags=structural_tags, | |
| probe_tags=probe_tags, | |
| ) | |
| picked_indices = None | |
| last_stage3_error: Exception | None = None | |
| stage3_attempts = 1 + int(stage3_fast_retry_count) | |
| for attempt_i in range(stage3_attempts): | |
| timeout_s = stage3_select_timeout_s if attempt_i == 0 else stage3_select_retry_timeout_s | |
| if attempt_i > 0: | |
| log( | |
| f"Index selection: fast retry {attempt_i}/{stage3_fast_retry_count} " | |
| f"(timeout={timeout_s:.0f}s)" | |
| ) | |
| ex = ThreadPoolExecutor(max_workers=1) | |
| try: | |
| fut_sel = ex.submit( | |
| llm_select_indices, | |
| query_text=selection_query, | |
| candidates=candidates, | |
| max_pick=0, | |
| log=log, | |
| mode=selection_mode, | |
| chunk_size=max(1, selection_chunk_size), | |
| per_phrase_k=max(1, selection_per_phrase_k), | |
| ) | |
| picked_indices = _future_with_timeout( | |
| fut_sel, | |
| timeout_s, | |
| "Index selection", | |
| [], | |
| strict=True, | |
| ) | |
| last_stage3_error = None | |
| break | |
| except Exception as e: | |
| last_stage3_error = e | |
| log(f"Index selection attempt {attempt_i + 1} failed: {e}") | |
| finally: | |
| ex.shutdown(wait=False, cancel_futures=True) | |
| if picked_indices is None: | |
| raise RuntimeError( | |
| f"Index selection failed after {stage3_attempts} attempt(s): {last_stage3_error}" | |
| ) | |
| selection_selected_tags = [candidates[i].tag for i in picked_indices] if picked_indices else [] | |
| selection_selected_tags, removed_stage3_low = _filter_min_count_tags(selection_selected_tags, min_tag_count) | |
| if removed_stage3_low: | |
| log( | |
| f" Filtered {len(removed_stage3_low)} low-frequency stage3 tags " | |
| f"(<{min_tag_count}): {', '.join(removed_stage3_low)}" | |
| ) | |
| selected_tags = list(selection_selected_tags) | |
| if structural_tags: | |
| # Add structural tags that aren't already selected | |
| existing = {t for t in selected_tags} | |
| new_structural = [t for t in structural_tags if t not in existing] | |
| selected_tags.extend(new_structural) | |
| log(f" Added {len(new_structural)} structural tags: {', '.join(new_structural)}") | |
| else: | |
| log(" No structural tags inferred") | |
| if probe_tags: | |
| existing = {t for t in selected_tags} | |
| new_probe = [t for t in probe_tags if t not in existing] | |
| selected_tags.extend(new_probe) | |
| log(f" Added {len(new_probe)} probe tags: {', '.join(new_probe)}") | |
| elif enable_probe_tags: | |
| log(" No probe tags inferred") | |
| selected_tags, removed_excluded_direct = _filter_excluded_recommendation_tags(selected_tags) | |
| if removed_excluded_direct: | |
| log(f" Removed {len(removed_excluded_direct)} excluded tags: {', '.join(removed_excluded_direct)}") | |
| direct_selected_tags = list(dict.fromkeys(selected_tags)) | |
| log("Step 3c: Expand via tag implications") | |
| t0 = time.perf_counter() | |
| tag_set = set(selected_tags) | |
| expanded, implied_only = expand_tags_via_implications(tag_set) | |
| dt = time.perf_counter()-t0 | |
| _record_timing("implication_expansion", dt) | |
| log(f"Implication expansion: {dt:.2f}s") | |
| implied_selected_tags = sorted(implied_only) if implied_only else [] | |
| if implied_only: | |
| implied_added = sorted(implied_only) | |
| implied_added, removed_implied_low = _filter_min_count_tags(implied_added, min_tag_count) | |
| implied_selected_tags = list(implied_added) | |
| if implied_added: | |
| selected_tags.extend(implied_added) | |
| log(f" Added {len(implied_added)} implied tags: {', '.join(implied_added)}") | |
| if removed_implied_low: | |
| log( | |
| f" Filtered {len(removed_implied_low)} low-frequency implied tags " | |
| f"(<{min_tag_count}): {', '.join(removed_implied_low)}" | |
| ) | |
| else: | |
| log(" No additional implied tags") | |
| selected_tags, removed_excluded_implied = _filter_excluded_recommendation_tags(selected_tags) | |
| implied_selected_tags = [ | |
| t for t in implied_selected_tags if not _is_excluded_recommendation_tag(t) | |
| ] | |
| if removed_excluded_implied: | |
| log( | |
| f" Removed {len(removed_excluded_implied)} excluded tags after implications: " | |
| f"{', '.join(removed_excluded_implied)}" | |
| ) | |
| log("Step 4: Compose final prompt") | |
| t0 = time.perf_counter() | |
| final_prompt = compose_final_prompt(rewritten, selected_tags) | |
| dt = time.perf_counter()-t0 | |
| _record_timing("prompt_composition", dt) | |
| log(f"Prompt composition: {dt:.2f}s") | |
| log("Step 5: Build ranked group/category display") | |
| t0 = time.perf_counter() | |
| seed_terms = [] | |
| seed_terms.extend(user_tags) | |
| seed_terms.extend([p.strip() for p in (rewritten or "").split(",") if p.strip()]) | |
| seed_terms.extend(structural_tags or []) | |
| seed_terms.extend(probe_tags or []) | |
| seed_terms.extend(selected_tags) | |
| seed_terms = list(dict.fromkeys(seed_terms)) | |
| active_selected_tags = list(dict.fromkeys(selected_tags)) | |
| structural_set = {_norm_tag_for_lookup(t) for t in (structural_tags or []) if t} | |
| probe_set = {_norm_tag_for_lookup(t) for t in (probe_tags or []) if t} | |
| implied_set = {_norm_tag_for_lookup(t) for t in (implied_selected_tags or []) if t} | |
| rewrite_set = { | |
| _norm_tag_for_lookup(t) | |
| for t in (list(user_tags or []) + [p.strip() for p in (rewritten or "").split(",") if p.strip()]) | |
| if t | |
| } | |
| selection_set = {_norm_tag_for_lookup(t) for t in (selection_selected_tags or []) if t} | |
| tag_selection_origins: Dict[str, str] = {} | |
| for tag in active_selected_tags: | |
| tag_norm = _norm_tag_for_lookup(tag) | |
| if tag_norm in structural_set: | |
| origin = "structural" | |
| elif tag_norm in probe_set: | |
| origin = "probe" | |
| elif tag_norm in rewrite_set: | |
| origin = "rewrite" | |
| elif tag_norm in selection_set: | |
| origin = "selection" | |
| elif tag_norm in implied_set: | |
| origin = "implied" | |
| else: | |
| # Unknown/fallback tags use selection color. | |
| origin = "selection" | |
| tag_selection_origins[tag] = origin | |
| if tag_norm and tag_norm != tag: | |
| tag_selection_origins[tag_norm] = origin | |
| direct_tags_for_implied = list( | |
| dict.fromkeys(_norm_tag_for_lookup(t) for t in (direct_selected_tags or []) if t) | |
| ) | |
| direct_tags_for_implied_idx = {t: i for i, t in enumerate(direct_tags_for_implied)} | |
| direct_tags_for_implied.sort( | |
| key=lambda t: ( | |
| _selection_source_rank(tag_selection_origins.get(t, "selection")), | |
| direct_tags_for_implied_idx.get(t, 10**9), | |
| ) | |
| ) | |
| implied_parent_map = _build_implied_parent_map( | |
| direct_tags_ordered=direct_tags_for_implied, | |
| implied_tags=implied_selected_tags, | |
| ) | |
| toggle_rows = _build_toggle_rows( | |
| seed_terms=seed_terms, | |
| selected_tags=active_selected_tags, | |
| retrieved_candidate_tags=retrieved_candidate_tags, | |
| tag_selection_origins=tag_selection_origins, | |
| implied_parent_map=implied_parent_map, | |
| top_groups=max(1, int(display_top_groups)), | |
| top_tags_per_group=max(1, int(display_top_tags_per_group)), | |
| group_rank_top_k=max(1, int(display_rank_top_k)), | |
| ) | |
| dt = time.perf_counter()-t0 | |
| _record_timing("group_display", dt) | |
| log(f"Ranked group display: {dt:.2f}s ({len(toggle_rows)} rows)") | |
| log( | |
| _build_display_audit_line( | |
| toggle_rows, | |
| active_selected_tags=active_selected_tags, | |
| direct_selected_tags=direct_selected_tags, | |
| implied_selected_tags=implied_selected_tags, | |
| ) | |
| ) | |
| for idx, row in enumerate(toggle_rows[: max(0, int(display_max_rows_default))]): | |
| tags_preview = ", ".join(row.get("tags", [])) | |
| log(f"UI Row {idx}: {row.get('label', '')} :: {tags_preview}") | |
| total_dt = time.perf_counter()-t_total0 | |
| _emit_timing_summary(total_dt) | |
| _append_timing_jsonl(total_dt) | |
| log("Done: final prompt ready") | |
| return _build_ui_payload( | |
| console_text="\n".join(logs), | |
| row_defs=toggle_rows, | |
| selected_tags=active_selected_tags, | |
| ) | |
| except Exception as e: | |
| log(f"Error: {type(e).__name__}: {e}") | |
| return _build_ui_payload( | |
| console_text="\n".join(logs), | |
| row_defs=[], | |
| selected_tags=[], | |
| suggested_prompt_text=_format_user_facing_error(e), | |
| ) | |
# ---------------------------------------------------------------------------
# Top-level Gradio UI.
#
# NOTE: component creation order inside `gr.Blocks` determines the on-page
# layout, and the `.then()` chains below define handler execution order —
# statement order in this block is load-bearing; do not reorder.
# ---------------------------------------------------------------------------
with gr.Blocks(css=css, js=client_js) as app:
    with gr.Row():
        # Left column: prompt input plus read-only suggested-prompt output.
        with gr.Column(scale=3, elem_classes=["prompt-col"]):
            gr.Markdown(
                'Describe your image under "Enter Prompt" and click "Run". '
                'Prompt Squirrel will translate it into image board tags.',
                elem_classes=["top-instruction"],
            )
            with gr.Group(elem_classes=["prompt-card"]):
                image_tags = gr.Textbox(
                    label="Enter Prompt",
                    placeholder="e.g. fox, outside, detailed background, .",
                    lines=1,
                    elem_classes=["enter-prompt-box"],
                )
            with gr.Group(elem_classes=["prompt-card", "suggested-prompt-card"]):
                suggested_prompt = gr.Textbox(
                    label="Suggested Prompt (Read-only)",
                    lines=2,
                    interactive=False,
                    show_copy_button=True,
                    placeholder='Suggested prompt will appear here after you click "Run".',
                    elem_classes=["suggested-prompt-box"],
                )
        # Right column: mascot image (best-effort load with a markdown
        # fallback) and the Run button.
        with gr.Column(scale=1):
            _mascot_pil = _load_mascot_image()
            if _mascot_pil is not None:
                mascot_img = gr.Image(
                    value=_mascot_pil,
                    show_label=False,
                    interactive=False,
                    height=240,
                    elem_id="mascot"
                )
            else:
                # Placeholder so the column layout is stable without the image.
                mascot_img = gr.Markdown("`(mascot image unavailable)`")
            # Hidden and disabled until the prompt text warrants a run;
            # visibility is driven by the `image_tags.change` handler below.
            submit_button = gr.Button("Run", variant="primary", visible=False, interactive=False)
            gr.Markdown("Typical runtime: up to ~20 seconds.", elem_classes=["run-hint"])
    # Per-session server-side state shared between event handlers.
    last_run_prompt_state = gr.State("")  # prompt text of the last triggered run
    selected_tags_state = gr.State([])  # tags currently toggled on across all rows
    rows_dirty_state = gr.State(False)  # True once the user edits rows after a run
    row_defs_state = gr.State([])  # row definitions produced by the pipeline
    row_values_state = gr.State([])  # per-row currently-checked tag values
    toggle_instruction = gr.Markdown(
        "Click tag buttons to add or remove tags from the suggested prompt.",
        elem_classes=["row-instruction"],
        visible=False,
    )
    # Pre-allocate a fixed pool of row slots (header + checkbox group per
    # row); Gradio components cannot be created after launch, so unused rows
    # are simply kept hidden (visible=False) until a run populates them.
    row_headers: List[gr.Markdown] = []
    row_checkboxes: List[gr.CheckboxGroup] = []
    for _ in range(display_max_rows_default):
        with gr.Row():
            with gr.Column(scale=2, min_width=170):
                row_headers.append(gr.Markdown(value="", visible=False, elem_classes=["row-heading"]))
            with gr.Column(scale=10):
                row_checkboxes.append(
                    gr.CheckboxGroup(
                        choices=[],
                        value=[],
                        visible=False,
                        interactive=True,
                        container=False,
                        elem_classes=["lego-tags"],
                    )
                )
    with gr.Row():
        with gr.Column(scale=10):
            # Static legend explaining the tag-chip color coding used in the
            # rows above (chip classes are styled via the page CSS).
            gr.HTML(
                """
                <div class="source-legend">
                    <span class="legend-title">Legend:</span>
                    <span class="chip rewrite">Rewrite phrase</span>
                    <span class="chip selection">General selection</span>
                    <span class="chip probe">Probe query</span>
                    <span class="chip structural">Structural query</span>
                    <span class="chip implied">Implied</span>
                    <span class="chip user">User-toggled</span>
                    <span class="chip unselected">Unselected</span>
                </div>
                """
            )
        with gr.Column(scale=2, min_width=180):
            # Shown/enabled only after the user toggles tags (rows_dirty).
            rebuild_rows_button = gr.Button(
                "Rebuild Rows",
                variant="primary",
                visible=False,
                interactive=False,
            )
    with gr.Accordion("Display Settings", open=False):
        with gr.Row():
            display_top_groups = gr.Number(
                value=display_top_groups_default,
                precision=0,
                label="Rows (Top Groups/Categories)",
                minimum=1,
            )
            display_top_tags_per_group = gr.Number(
                value=display_top_tags_per_group_default,
                precision=0,
                label="Top Tags Shown Per Row",
                minimum=1,
            )
            display_rank_top_k = gr.Number(
                value=display_rank_top_k_default,
                precision=0,
                label="Top Tags Used for Row Ranking",
                minimum=1,
            )
    with gr.Accordion("Console", open=False):
        # Receives the pipeline's progress log lines.
        console = gr.Textbox(
            label="Console",
            lines=10,
            interactive=False,
            placeholder="Progress logs will appear here."
        )
    with gr.Accordion("How Prompt Squirrel Works", open=False):
        # Render the docs markdown; when it contains the architecture-diagram
        # marker, split around it and inline the diagram HTML at that spot.
        _about_md = _load_about_docs_markdown()
        _about_before, _about_after, _has_arch_slot = _split_about_docs_for_diagram(_about_md)
        if _has_arch_slot:
            if _about_before:
                gr.Markdown(
                    _about_before,
                    elem_id="about-docs",
                    elem_classes=["about-docs"],
                )
            gr.HTML(
                _build_arch_diagram_html(),
                elem_classes=["about-docs"],
            )
            if _about_after:
                gr.Markdown(
                    _about_after,
                    elem_classes=["about-docs"],
                )
        else:
            gr.Markdown(
                _about_md,
                elem_id="about-docs",
                elem_classes=["about-docs"],
            )
    # Hidden textbox carrying a JSON tooltip map for the client-side JS
    # (visible=True but styled away via the psq-hidden CSS class).
    tooltip_map_payload = gr.Textbox(
        value="{}",
        visible=True,
        interactive=False,
        container=False,
        elem_id="psq-tooltip-map",
        elem_classes=["psq-hidden"],
    )
    # Shared output list for run-triggering events; order must match the
    # tuples returned by `_prepare_run_ui` and `rag_pipeline_ui`.
    run_outputs = [
        console,
        toggle_instruction,
        tooltip_map_payload,
        suggested_prompt,
        selected_tags_state,
        rows_dirty_state,
        rebuild_rows_button,
        row_defs_state,
        row_values_state,
        *row_headers,
        *row_checkboxes,
    ]
    # Show/hide the Run button as the prompt text diverges from the last run.
    image_tags.change(
        _update_run_button_visibility,
        inputs=[image_tags, last_run_prompt_state],
        outputs=[submit_button],
        queue=False,
        show_progress="hidden",
    )
    # Run trigger #1: clicking the button. Three-stage chain: record the
    # trigger, reset the UI, then run the (queued) pipeline.
    submit_button.click(
        _mark_run_triggered,
        inputs=[image_tags],
        outputs=[submit_button, last_run_prompt_state],
        queue=False,
        show_progress="hidden",
    ).then(
        _prepare_run_ui,
        inputs=[],
        outputs=run_outputs,
        queue=False,
        show_progress="hidden",
    ).then(
        rag_pipeline_ui,
        inputs=[image_tags, display_top_groups, display_top_tags_per_group, display_rank_top_k],
        outputs=run_outputs,
    )
    # Run trigger #2: pressing Enter in the prompt box (same chain as above).
    image_tags.submit(
        _mark_run_triggered,
        inputs=[image_tags],
        outputs=[submit_button, last_run_prompt_state],
        queue=False,
        show_progress="hidden",
    ).then(
        _prepare_run_ui,
        inputs=[],
        outputs=run_outputs,
        queue=False,
        show_progress="hidden",
    ).then(
        rag_pipeline_ui,
        inputs=[image_tags, display_top_groups, display_top_tags_per_group, display_rank_top_k],
        outputs=run_outputs,
    )
    # One change handler per pre-allocated row; the row index is bound via
    # the `i=idx` default to avoid the late-binding closure pitfall.
    for idx, row_cb in enumerate(row_checkboxes):
        row_cb.change(
            fn=lambda changed_values, selected_state, rows_dirty, row_defs, row_values, i=idx: _on_toggle_row(
                i,
                changed_values,
                selected_state,
                rows_dirty,
                row_defs,
                row_values,
                display_max_rows_default,
            ),
            inputs=[row_cb, selected_tags_state, rows_dirty_state, row_defs_state, row_values_state],
            outputs=[selected_tags_state, rows_dirty_state, rebuild_rows_button, row_values_state, suggested_prompt, *row_checkboxes],
            queue=False,
            show_progress="hidden",
        )
    # Re-rank and re-populate the rows from the user's current tag selection.
    rebuild_rows_button.click(
        _rebuild_rows_from_selected,
        inputs=[selected_tags_state, row_defs_state, row_values_state, display_top_groups, display_top_tags_per_group, display_rank_top_k],
        outputs=[
            toggle_instruction,
            tooltip_map_payload,
            suggested_prompt,
            selected_tags_state,
            rows_dirty_state,
            rebuild_rows_button,
            row_defs_state,
            row_values_state,
            *row_headers,
            *row_checkboxes,
        ],
        queue=False,
        show_progress="hidden",
    )
# Script entry point: enable the request queue (required for the long-running
# pipeline handler) and serve, letting Gradio expose static assets from the
# mascot and docs directories.
if __name__ == "__main__":
    static_dirs = [str(MASCOT_DIR), str(DOCS_DIR)]
    app.queue().launch(allowed_paths=static_dirs)