# Food Desert
# Inject architecture section only in accordion, keep docs file clean
# eef9f1d
import gradio as gr
import os
import logging
import time
import json
import csv
import re
import base64
from datetime import datetime
from functools import lru_cache
from PIL import Image
from pathlib import Path
from typing import Any, Dict, List, Set, Tuple
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words
from psq_rag.llm.rewrite import llm_rewrite_prompt
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup
from psq_rag.llm.select import llm_select_indices, llm_infer_structural_tags, llm_infer_probe_tags
from psq_rag.retrieval.state import (
expand_tags_via_implications,
get_tag_type_name,
get_tag_implications,
get_tag_counts,
)
from psq_rag.ui.group_ranked_display import rank_groups_from_tfidf, _load_enabled_groups
# Filesystem anchors resolved relative to this file (not the working directory).
APP_DIR = Path(__file__).parent
DOCS_DIR = APP_DIR / "docs"
ARCH_DIAGRAM_FILE = DOCS_DIR / "assets" / "architecture_overview.png"
# Placeholder token looked up first by _split_about_docs_for_diagram to locate the diagram slot.
ARCH_DIAGRAM_MARKER = "{{ARCHITECTURE_DIAGRAM}}"
# NOTE(review): not referenced in the visible code; _split_about_docs_for_diagram
# hard-codes the same heading inside its regex. Confirm whether this is still needed.
ARCH_DIAGRAM_INSERT_BEFORE_HEADING = "## What Each Step Does"
_CORPORATE_HARDBLOCK_PATTERNS = [
# Rating-like explicitness markers.
re.compile(r"(^|_)(nsfw|explicit|questionable)(_|$)", re.IGNORECASE),
# Unambiguous sexual anatomy.
re.compile(
r"(^|_)(breast|breasts|boob|boobs|nipple|nipples|penis|vagina|pussy|clit|testicle|scrotum|genital|crotch|anus|anal|areola)(_|$)",
re.IGNORECASE,
),
# Unambiguous sexual activity.
re.compile(
r"(^|_)(sex|sexual|fucking|fuck|blowjob|handjob|masturbat|penetrat|thrust|orgasm|cum|ejaculat|creampie|nude|naked|topless|bottomless|moan|sexy)(_|$)",
re.IGNORECASE,
),
# Common kink/fetish markers.
re.compile(r"(^|_)(fetish|bdsm|bondage|dominatrix|submission|vore|inflation|watersports)(_|$)", re.IGNORECASE),
]
def _split_prompt_commas(s: str) -> List[str]:
return [p.strip() for p in (s or "").split(",") if p.strip()]
def _norm_for_dedupe(tag: str) -> str:
    """Return the dedupe key for *tag*: lower-cased, then run through the lookup normalizer."""
    lowered = tag.lower()
    return _norm_tag_for_lookup(lowered)
def compose_final_prompt(rewritten_prompt: str, selected_tags: List[str]) -> str:
    """Join rewrite phrases and selected tags into one comma-separated prompt.

    Phrases from *rewritten_prompt* come first, followed by *selected_tags*.
    Entries whose dedupe key (see ``_norm_for_dedupe``) was already seen are
    dropped, so the first spelling of each entry wins.
    """
    first_spelling = {}
    for piece in [*_split_prompt_commas(rewritten_prompt), *selected_tags]:
        # setdefault keeps the earliest piece for each key; dicts preserve insertion order.
        first_spelling.setdefault(_norm_for_dedupe(piece), piece)
    return ", ".join(first_spelling.values())
def _display_tag_text(tag: str) -> str:
return tag.replace("_", " ")
def _display_row_label(name: str) -> str:
n = (name or "").strip()
if not n:
return ""
if n == "selected_other":
return "Selected (Other)"
return n.replace("_", " ").title()
def _normalize_selection_origin(origin: str) -> str:
o = (origin or "").strip().lower()
if o in {"rewrite", "selection", "probe", "structural", "user", "candidate"}:
return o
return "selection"
def _choice_label_with_source_meta(tag: str, *, origin: str, preselected: bool) -> str:
    """Return the checkbox label for *tag*.

    ``origin`` and ``preselected`` are accepted but currently unused: labels
    are kept plain to avoid frontend text/value desynchronization.
    """
    label = _display_tag_text(tag)
    return label
@lru_cache(maxsize=1)
def _load_tag_wiki_defs() -> Dict[str, str]:
p = Path("data/tag_wiki_defs.json")
if not p.exists():
return {}
try:
with p.open("r", encoding="utf-8") as f:
data = json.load(f)
out: Dict[str, str] = {}
if isinstance(data, dict):
for k, v in data.items():
tag = _norm_tag_for_lookup(str(k))
text = " ".join(str(v or "").split())
if tag and text:
out[tag] = text
return out
except Exception:
return {}
@lru_cache(maxsize=1)
def _load_about_docs_markdown() -> str:
    """Return the "about" markdown, trying known doc locations in order.

    The first candidate that exists, is readable, and still has content after
    stripping an optional leading YAML front-matter block wins. When none
    qualifies, a short "documentation unavailable" message is returned.
    The result is cached for the process lifetime.
    """
    for doc_path in (DOCS_DIR / "space_overview.md", APP_DIR / "PROJECT_SUMMARY.md"):
        if not doc_path.exists():
            continue
        try:
            body = doc_path.read_text(encoding="utf-8").strip()
        except Exception:
            continue
        # Strip YAML front matter if present ("---" ... "---" at the very top).
        if body.startswith("---"):
            pieces = body.split("---", 2)
            if len(pieces) >= 3:
                body = pieces[2].strip()
        if body:
            return body
    return (
        "Documentation is unavailable.\n\n"
        "Expected file: `docs/space_overview.md`"
    )
def _tooltip_text_for_tag(tag: str) -> str:
    """Build the tooltip for *tag*: a usage-count line and/or its wiki definition.

    Either part is omitted when unavailable; the parts are newline-separated.
    """
    key = _norm_tag_for_lookup(tag)
    try:
        count = get_tag_counts().get(key)
    except Exception:
        # Counts are optional decoration; never let them break the tooltip.
        count = None
    lines: List[str] = [f"Count: {count:,}"] if isinstance(count, int) else []
    definition = _load_tag_wiki_defs().get(key, "")
    if definition:
        lines.append(definition)
    return "\n".join(lines).strip()
@lru_cache(maxsize=1)
def _load_arch_diagram_data_uri() -> str:
    """Return the architecture diagram as a base64 PNG ``data:`` URI.

    Returns ``""`` when the file is missing, unreadable, or empty. Cached for
    the process lifetime.
    """
    if not ARCH_DIAGRAM_FILE.exists():
        return ""
    try:
        payload = ARCH_DIAGRAM_FILE.read_bytes()
    except Exception:
        payload = b""
    if not payload:
        return ""
    encoded = base64.b64encode(payload).decode("ascii")
    return "data:image/png;base64," + encoded
def _split_about_docs_for_diagram(md: str) -> Tuple[str, str, bool]:
    """Split docs markdown into the parts before and after the diagram slot.

    Returns ``(before, after, found)``. The slot is located by, in priority order:

    1. the last occurrence of ``ARCH_DIAGRAM_MARKER`` (the marker is removed);
    2. a ``## Architecture At A Glance`` heading (the heading line is removed);
    3. a ``## What Each Step Does`` heading (the heading stays at the top of ``after``).

    When none is present the whole text is returned as ``before`` with ``found=False``.
    """
    text = (md or "").strip()

    if ARCH_DIAGRAM_MARKER in text:
        head, _, tail = text.rpartition(ARCH_DIAGRAM_MARKER)
        return head.strip(), tail.strip(), True

    # Backward compatibility if an explicit architecture heading exists in docs.
    arch_heading = re.search(r"(?m)^##\s+Architecture At A Glance\s*$", text)
    if arch_heading is not None:
        return text[: arch_heading.start()].strip(), text[arch_heading.end() :].strip(), True

    # Preferred insertion point: inject diagram right before "What Each Step Does".
    steps_heading = re.search(r"(?m)^##\s+What Each Step Does\s*$", text)
    if steps_heading is not None:
        cut = steps_heading.start()
        return text[:cut].strip(), text[cut:].strip(), True

    return text, "", False
def _build_arch_diagram_html() -> str:
    """Return the HTML snippet for the architecture diagram section.

    Embeds the diagram as a data URI under its own heading, or returns a short
    placeholder paragraph when the diagram could not be loaded.
    """
    data_uri = _load_arch_diagram_data_uri()
    if data_uri:
        return f"""
<div class="arch-diagram-wrap">
<h2>Architecture At A Glance</h2>
<img src="{data_uri}" alt="Architecture diagram" />
</div>
"""
    return "<p><code>(architecture diagram unavailable)</code></p>"
def _selection_source_rank(origin: str) -> int:
    """Return the row-ordering priority for a selection origin (lower sorts first).

    ``structural`` -> 0, ``probe`` -> 1, every other origin -> 2.
    """
    # Keep rewrite/user in the same priority band as general selection for row ordering.
    priority = {"structural": 0, "probe": 1}
    return priority.get(_normalize_selection_origin(origin), 2)
def _build_implied_parent_map(
    direct_tags_ordered: List[str],
    implied_tags: List[str],
) -> Dict[str, str]:
    """Map each implied tag to the first direct tag whose implications reach it.

    For every direct tag (in the given order) the implication graph returned by
    ``get_tag_implications()`` is walked transitively. The first direct tag to
    reach an implied tag claims it; later direct tags do not override.
    Returns ``{}`` when either input is empty.
    """
    wanted = {_norm_tag_for_lookup(t) for t in (implied_tags or []) if t}
    if not (wanted and direct_tags_ordered):
        return {}

    implications = get_tag_implications()
    owner_of: Dict[str, str] = {}
    for direct in direct_tags_ordered:
        root = _norm_tag_for_lookup(direct)
        if not root:
            continue
        visited = {root}
        pending = [root]
        while pending:
            # Popping from the end makes this a depth-first walk.
            current = pending.pop()
            for raw_parent in implications.get(current, ()):
                candidate = _norm_tag_for_lookup(raw_parent)
                if not candidate or candidate in visited:
                    continue
                visited.add(candidate)
                if candidate in wanted:
                    owner_of.setdefault(candidate, root)
                pending.append(candidate)
    return owner_of
def _order_selected_tags_for_row(
    *,
    row_selected_tags: List[str],
    selected_index: Dict[str, int],
    tag_selection_origins: Dict[str, str],
    implied_parent_map: Dict[str, str],
) -> List[str]:
    """Order one row's selected tags so implied tags follow their parent.

    Direct tags (those without an entry in *implied_parent_map*) are sorted by
    origin rank, then original selection position, then name. Each direct tag is
    immediately followed by its implied children that are in the same row.
    Implied tags whose parent is not a direct tag of this row are appended last,
    ordered by the parent's rank and position. Duplicates are emitted once.
    """
    unranked = 10**9

    def _position(tag: str) -> int:
        return selected_index.get(tag, unranked)

    def _origin_rank(tag: str) -> int:
        return _selection_source_rank(tag_selection_origins.get(tag, "selection"))

    normalized = [_norm_tag_for_lookup(t) for t in (row_selected_tags or []) if t]
    implied_here = {t for t in normalized if t in implied_parent_map}

    direct_sorted = sorted(
        (t for t in normalized if t not in implied_here),
        key=lambda t: (_origin_rank(t), _position(t), t),
    )

    # Visit implied tags in their final order so every child list comes out sorted.
    children_of: Dict[str, List[str]] = {}
    for child in sorted(implied_here, key=lambda t: (_position(t), t)):
        owner = implied_parent_map.get(child)
        if owner:
            children_of.setdefault(owner, []).append(child)

    ordered: List[str] = []
    emitted: Set[str] = set()

    def _emit(tag: str) -> None:
        if tag not in emitted:
            emitted.add(tag)
            ordered.append(tag)

    for tag in direct_sorted:
        if tag in emitted:
            continue
        _emit(tag)
        for child in children_of.get(tag, []):
            _emit(child)

    # Whatever is left is implied by a tag that is not a direct tag of this row.
    leftovers = sorted(
        (t for t in normalized if t not in emitted),
        key=lambda t: (
            _origin_rank(implied_parent_map.get(t, "")),
            _position(implied_parent_map.get(t, "")),
            _position(t),
            t,
        ),
    )
    for tag in leftovers:
        _emit(tag)
    return ordered
def _escape_prompt_tag(tag: str) -> str:
return (
tag.replace("_", " ")
.replace("(", "\\(")
.replace(")", "\\)")
)
def _ordered_selected_for_prompt(selected: Set[str], row_defs: List[Dict[str, Any]]) -> List[str]:
out: List[str] = []
seen: Set[str] = set()
for row in row_defs:
for tag in row.get("tags", []):
if tag in selected and tag not in seen:
out.append(tag)
seen.add(tag)
return out
def _compose_toggle_prompt_text(selected_tags: List[str], row_defs: List[Dict[str, Any]]) -> str:
    """Render the currently selected tags as comma-separated prompt text.

    Tags are emitted in on-screen row order and escaped for prompt use.
    """
    chosen = {tag for tag in (selected_tags or []) if tag}
    in_display_order = _ordered_selected_for_prompt(chosen, row_defs or [])
    escaped = [_escape_prompt_tag(tag) for tag in in_display_order]
    return ", ".join(escaped)
def _is_artist_tag(tag: str) -> bool:
    """Return True when *tag* is typed as an artist tag or uses the ``by_`` prefix."""
    normalized = _norm_tag_for_lookup(str(tag))
    if normalized:
        # Keep a resilient fallback for malformed/missing tag typing metadata.
        return get_tag_type_name(normalized) == "artist" or normalized.startswith("by_")
    return False
@lru_cache(maxsize=1)
def _load_excluded_recommendation_tags() -> Set[str]:
out: Set[str] = set()
# Existing category-registry driven exclusions.
csv_path = Path("data/category_registry.csv")
if not csv_path.exists():
csv_path = Path("data/analysis/category_registry.csv")
if csv_path.exists():
try:
with csv_path.open("r", encoding="utf-8", newline="") as f:
reader = csv.DictReader(f)
for row in reader:
tag = _norm_tag_for_lookup(str(row.get("tag") or ""))
if not tag:
continue
status = str(row.get("category_status") or "").strip().lower()
if status == "excluded":
out.add(tag)
except Exception:
pass
# Corporate-safety exclusions (editable runtime list).
corp_path = Path("data/corporate_excluded_tags.csv")
if corp_path.exists():
try:
with corp_path.open("r", encoding="utf-8", newline="") as f:
reader = csv.DictReader(f)
for row in reader:
tag = _norm_tag_for_lookup(str(row.get("tag") or ""))
if not tag:
continue
enabled_raw = str(row.get("enabled", "1")).strip().lower()
enabled = enabled_raw not in {"0", "false", "no", "off"}
if enabled:
out.add(tag)
except Exception:
pass
return out
def _is_hardblocked_corporate_tag(tag: str) -> bool:
    """Return True when the normalized tag matches any corporate hard-block pattern."""
    normalized = _norm_tag_for_lookup(str(tag))
    if not normalized:
        return False
    for pattern in _CORPORATE_HARDBLOCK_PATTERNS:
        if pattern.search(normalized):
            return True
    return False
def _is_excluded_recommendation_tag(tag: str) -> bool:
    """Return True when *tag* is hard-blocked or listed in the exclusion CSVs.

    Tags that normalize to an empty string are never considered excluded.
    """
    normalized = _norm_tag_for_lookup(str(tag))
    if not normalized:
        return False
    # Pattern check first: it needs no file access and short-circuits the lookup.
    return _is_hardblocked_corporate_tag(normalized) or normalized in _load_excluded_recommendation_tags()
def _get_min_tag_count() -> int:
try:
return max(0, int(os.environ.get("PSQ_MIN_TAG_COUNT", "100")))
except Exception:
return 100
def _filter_min_count_tags(tags: List[str], min_count: int) -> Tuple[List[str], List[str]]:
    """Drop tags whose corpus count is below *min_count*.

    Returns ``(kept, removed)``: ``kept`` holds normalized tags in first-seen
    order without duplicates, ``removed`` is the sorted set of tags dropped.
    With ``min_count <= 0`` nothing is removed and the counts table is not
    consulted.

    Both paths share the same cleaning rules (skip falsy inputs, coerce to
    ``str``, normalize, skip tags that normalize to empty), so the result shape
    does not depend on whether the threshold is active.
    """
    threshold_active = min_count > 0
    tag_counts = get_tag_counts() if threshold_active else {}

    keep: List[str] = []
    removed: Set[str] = set()
    seen: Set[str] = set()
    for raw in (tags or []):
        if not raw:
            continue
        t = _norm_tag_for_lookup(str(raw))
        if not t:
            continue
        if threshold_active and int(tag_counts.get(t, 0) or 0) < min_count:
            removed.add(t)
            continue
        if t in seen:
            continue
        seen.add(t)
        keep.append(t)
    return keep, sorted(removed)
def _filter_excluded_recommendation_tags(tags: List[str]) -> Tuple[List[str], List[str]]:
    """Drop tags listed in the exclusion CSVs.

    Returns ``(kept, removed)``: ``kept`` holds normalized tags in first-seen
    order without duplicates, ``removed`` is the sorted set of tags dropped.

    The same cleaning rules apply whether or not any exclusions are loaded
    (skip falsy inputs, coerce to ``str``, normalize, skip tags that normalize
    to empty).

    NOTE(review): only the CSV-driven set is consulted here; the regex
    hard-block applied by ``_is_excluded_recommendation_tag`` is not. Confirm
    that callers intend this difference.
    """
    excluded = _load_excluded_recommendation_tags()

    keep: List[str] = []
    removed: Set[str] = set()
    seen: Set[str] = set()
    for raw in (tags or []):
        if not raw:
            continue
        t = _norm_tag_for_lookup(str(raw))
        if not t:
            continue
        if t in excluded:
            removed.add(t)
            continue
        if t in seen:
            continue
        seen.add(t)
        keep.append(t)
    return keep, sorted(removed)
def _filter_excluded_candidates(candidates: List[Any]) -> Tuple[List[Any], List[str]]:
    """Drop candidate objects whose ``tag`` attribute is in the exclusion CSV set.

    Returns ``(kept_candidates, removed_tags)``; the kept candidates preserve
    input order and ``removed_tags`` is sorted and de-duplicated. Candidates
    without a usable ``tag`` are kept.
    """
    blocked = _load_excluded_recommendation_tags()
    pool = list(candidates or [])
    if not blocked:
        return pool, []

    kept: List[Any] = []
    dropped: Set[str] = set()
    for cand in pool:
        tag = _norm_tag_for_lookup(str(getattr(cand, "tag", "") or ""))
        if tag and tag in blocked:
            dropped.add(tag)
        else:
            kept.append(cand)
    return kept, sorted(dropped)
def _dedupe_norm_tags(tags: List[str]) -> List[str]:
    """Normalize *tags* and return them without duplicates, in first-seen order.

    Inputs are coerced with ``str()``; tags that normalize to an empty string
    are dropped.
    """
    normalized = (_norm_tag_for_lookup(str(raw)) for raw in (tags or []))
    # dict.fromkeys keeps the first occurrence of each key in order.
    return [tag for tag in dict.fromkeys(normalized) if tag]
def _collect_visible_tags(row_defs: List[Dict[str, Any]]) -> Set[str]:
    """Return the set of normalized tags that appear in any row definition.

    Entries that are not dicts contribute nothing.
    """
    visible: Set[str] = set()
    for row in (row_defs or []):
        row_tags = row.get("tags", []) if isinstance(row, dict) else []
        visible.update(_dedupe_norm_tags(row_tags))
    return visible
def _collect_selected_from_state(
    selected_tags_state: List[str],
    row_defs: List[Dict[str, Any]],
) -> List[str]:
    """Return the selected-state tags that are visible in *row_defs*.

    Each state entry is normalized and matched against the visible tags;
    entries that are not visible are dropped. The result keeps state order and
    contains no duplicates. With no visible tags the result is empty.
    """
    visible = _collect_visible_tags(row_defs)
    if not visible:
        return []

    by_norm = {_norm_tag_for_lookup(v): v for v in visible}
    chosen: Dict[str, None] = {}
    for raw in (selected_tags_state or []):
        key = _norm_tag_for_lookup(str(raw))
        if not key:
            continue
        resolved = key if key in visible else by_norm.get(key)
        if resolved:
            chosen.setdefault(resolved, None)
    return list(chosen)
def _collect_selected_from_row_values(
    row_defs: List[Dict[str, Any]],
    row_values_state: List[List[str]],
) -> List[str]:
    """Return the selected tags implied by per-row checkbox values.

    ``row_values_state[i]`` is read as the checked values of ``row_defs[i]``.
    A value counts only if it is one of that row's tags, either verbatim or
    after normalization. The result is in row order without duplicates.
    """
    per_row_values = list(row_values_state or [])
    chosen: Dict[str, None] = {}
    for idx, row in enumerate(row_defs or []):
        row_tags = _dedupe_norm_tags(row.get("tags", []) if isinstance(row, dict) else [])
        if not row_tags:
            continue
        known = set(row_tags)
        by_norm = {_norm_tag_for_lookup(t): t for t in row_tags}
        checked = per_row_values[idx] if idx < len(per_row_values) else []
        for raw in (checked or []):
            # Exact hit first; only normalize when the raw value is not a row tag.
            resolved = raw if raw in known else by_norm.get(_norm_tag_for_lookup(str(raw)))
            if resolved:
                chosen.setdefault(resolved, None)
    return list(chosen)
def _build_toggle_rows(
    *,
    seed_terms: List[str],
    selected_tags: List[str],
    retrieved_candidate_tags: List[str],
    tag_selection_origins: Dict[str, str],
    implied_parent_map: Dict[str, str],
    top_groups: int,
    top_tags_per_group: int,
    group_rank_top_k: int,
) -> List[Dict[str, Any]]:
    """Build the row definitions behind the tag toggle UI.

    Rows are returned in display order:

    1. ``selected_other``: selected tags not shown in any other row (always
       appended, possibly with no tags);
    2. one row per group returned by ``rank_groups_from_tfidf``;
    3. ``other_retrieved``: retrieved candidates that belong to no enabled
       group (only when there are any).

    Each row is a dict with ``name``, ``label``, ``tags`` and ``tag_meta``
    (``tag -> {"origin", "preselected"}``). Artist tags and excluded tags are
    filtered out throughout. Within a row, selected tags come first and are
    never truncated; the remaining slots up to ``top_tags_per_group`` are
    filled with ranked/retrieved tags.
    """
    # The numeric knobs are clamped to at least 1 before ranking.
    ranked_rows = rank_groups_from_tfidf(
        seed_terms=seed_terms,
        top_groups=max(1, int(top_groups)),
        top_tags_per_group=max(1, int(top_tags_per_group)),
        group_rank_top_k=max(1, int(group_rank_top_k)),
    )
    groups_map = _load_enabled_groups()
    # Selected tags that are allowed to be displayed, normalized and de-duplicated in order.
    selected_active = list(
        dict.fromkeys(
            _norm_tag_for_lookup(t)
            for t in selected_tags
            if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
        )
    )
    # Position of each selected tag, used as a sort key when ordering rows.
    selected_index: Dict[str, int] = {t: i for i, t in enumerate(selected_active)}
    row_defs: List[Dict[str, Any]] = []
    enabled_group_tag_sets: Dict[str, Set[str]] = {
        name: {t for t in tags if not _is_artist_tag(t)}
        for name, tags in groups_map.items()
    }
    tags_in_any_enabled_group: Set[str] = set()
    for tag_set in enabled_group_tag_sets.values():
        tags_in_any_enabled_group.update(tag_set)
    # Only the groups that made the ranking are actually displayed.
    displayed_group_names = [r.group_name for r in ranked_rows]
    displayed_group_tag_sets: Dict[str, Set[str]] = {
        name: enabled_group_tag_sets.get(name, set())
        for name in displayed_group_names
    }
    tags_in_any_displayed_group: Set[str] = set()
    for tag_set in displayed_group_tag_sets.values():
        tags_in_any_displayed_group.update(tag_set)
    # Retrieved candidates that belong to no enabled group, in retrieval order.
    retrieved_uncategorized_ranked = list(
        dict.fromkeys(
            _norm_tag_for_lookup(t)
            for t in (retrieved_candidate_tags or [])
            if t
            and not _is_artist_tag(t)
            and not _is_excluded_recommendation_tag(t)
            and _norm_tag_for_lookup(t) not in tags_in_any_enabled_group
        )
    )
    retrieved_other_row: Dict[str, Any] | None = None
    if retrieved_uncategorized_ranked:
        retrieved_uncategorized_set = set(retrieved_uncategorized_ranked)
        selected_in_retrieved_other_raw = [
            t for t in selected_active if t in retrieved_uncategorized_set
        ]
        selected_in_retrieved_other = _order_selected_tags_for_row(
            row_selected_tags=selected_in_retrieved_other_raw,
            selected_index=selected_index,
            tag_selection_origins=tag_selection_origins,
            implied_parent_map=implied_parent_map,
        )
        # Selected tags lead the row; unselected retrieved tags fill the rest.
        merged_retrieved_other = selected_in_retrieved_other + [
            t for t in retrieved_uncategorized_ranked if t not in selected_in_retrieved_other
        ]
        merged_retrieved_other = _dedupe_norm_tags(merged_retrieved_other)
        # Never truncate below the number of selected tags in this row.
        keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_retrieved_other))
        merged_retrieved_other = merged_retrieved_other[:keep_n]
        retrieved_other_meta = {
            t: {
                "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
                "preselected": t in selected_active,
            }
            for t in merged_retrieved_other
        }
        retrieved_other_row = {
            "name": "other_retrieved",
            "label": "Other (Retrieved)",
            "tags": merged_retrieved_other,
            "tag_meta": retrieved_other_meta,
        }
    # "Selected (Other)" should contain selected tags not already shown in any displayed row.
    # Include "Other (Retrieved)" in that displayed-row set to avoid duplicates across those rows.
    tags_in_displayed_rows = set(tags_in_any_displayed_group)
    if retrieved_other_row:
        tags_in_displayed_rows.update(retrieved_other_row.get("tags", []))
    selected_other_raw = [t for t in selected_active if t not in tags_in_displayed_rows]
    selected_other = _order_selected_tags_for_row(
        row_selected_tags=selected_other_raw,
        selected_index=selected_index,
        tag_selection_origins=tag_selection_origins,
        implied_parent_map=implied_parent_map,
    )
    selected_other = _dedupe_norm_tags(selected_other)
    selected_other_meta = {
        t: {
            "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
            "preselected": True,
        }
        for t in selected_other
    }
    # This row is appended unconditionally, even when it has no tags.
    row_defs.append(
        {
            "name": "selected_other",
            "label": _display_row_label("selected_other"),
            "tags": selected_other,
            "tag_meta": selected_other_meta,
        }
    )
    for row in ranked_rows:
        group_name = row.group_name
        group_tag_set = displayed_group_tag_sets.get(group_name, set())
        selected_in_group_raw = [t for t in selected_active if t in group_tag_set]
        selected_in_group = _order_selected_tags_for_row(
            row_selected_tags=selected_in_group_raw,
            selected_index=selected_index,
            tag_selection_origins=tag_selection_origins,
            implied_parent_map=implied_parent_map,
        )
        # row.tags is iterated as (tag, <second item>) pairs; only the tag is used here.
        ranked_tags = [
            _norm_tag_for_lookup(t)
            for t, _ in row.tags
            if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
        ]
        ranked_tags = _dedupe_norm_tags(ranked_tags)
        merged = selected_in_group + [t for t in ranked_tags if t not in selected_in_group]
        merged = _dedupe_norm_tags(merged)
        # Never truncate below the number of selected tags in this group.
        keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_group))
        merged = merged[:keep_n]
        tag_meta = {
            t: {
                "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
                "preselected": t in selected_active,
            }
            for t in merged
        }
        row_defs.append(
            {
                "name": group_name,
                "label": _display_row_label(group_name),
                "tags": merged,
                "tag_meta": tag_meta,
            }
        )
    # Keep this row at the bottom so category/group rows remain contiguous.
    if retrieved_other_row:
        row_defs.append(retrieved_other_row)
    return row_defs
def _build_display_audit_line(
    row_defs: List[Dict[str, Any]],
    *,
    active_selected_tags: List[str],
    direct_selected_tags: List[str],
    implied_selected_tags: List[str],
) -> str:
    """Return a one-line JSON audit of every displayed tag.

    For each tag found in *row_defs* the payload lists the row labels it
    appears in and a sorted list of source markers: the kind of row
    (``selected_other_row`` / ``other_retrieved_row`` / ``ranked_group_row``)
    and its membership in the active / direct / implied selection sets
    (artist tags are left out of those sets). The line is prefixed with
    ``"Display Tag Audit: "``.
    """

    def _non_artist_norm_set(tags: List[str]) -> Set[str]:
        return {_norm_tag_for_lookup(t) for t in (tags or []) if t and not _is_artist_tag(t)}

    active_set = _non_artist_norm_set(active_selected_tags)
    direct_set = _non_artist_norm_set(direct_selected_tags)
    implied_set = _non_artist_norm_set(implied_selected_tags)
    memberships = (
        ("selected_active", active_set),
        ("selected_direct", direct_set),
        ("selected_implied", implied_set),
    )

    info_by_tag: Dict[str, Dict[str, Any]] = {}
    for row in row_defs or []:
        row_name = row.get("name", "")
        row_label = row.get("label", row_name)
        if row_name == "selected_other":
            row_source = "selected_other_row"
        elif row_name == "other_retrieved":
            row_source = "other_retrieved_row"
        else:
            row_source = "ranked_group_row"
        for tag in row.get("tags", []):
            rec = info_by_tag.setdefault(tag, {"rows": [], "sources": set()})
            rec["rows"].append(row_label)
            rec["sources"].add(row_source)
            for marker, members in memberships:
                if tag in members:
                    rec["sources"].add(marker)

    audited = []
    for tag in sorted(info_by_tag):
        rec = info_by_tag[tag]
        audited.append({"tag": tag, "rows": rec["rows"], "sources": sorted(rec["sources"])})
    payload = {"n_tags": len(info_by_tag), "tags": audited}
    return "Display Tag Audit: " + json.dumps(payload, ensure_ascii=True)
def _build_tooltip_payload(row_defs: List[Dict[str, Any]], max_rows: int) -> str:
    """Serialize tooltip data for the first *max_rows* rows as JSON.

    The payload has ``rows`` (the de-duplicated tag list of each row) and
    ``tips`` (tag -> tooltip text, computed once per distinct tag).
    """
    limit = max(0, int(max_rows))
    rows: List[List[str]] = [
        _dedupe_norm_tags(row.get("tags", []) if isinstance(row, dict) else [])
        for row in (row_defs or [])[:limit]
    ]
    tips: Dict[str, str] = {}
    for row_tags in rows:
        for tag in row_tags:
            if tag not in tips:
                tips[tag] = _tooltip_text_for_tag(tag)
    return json.dumps({"rows": rows, "tips": tips}, ensure_ascii=True)
def _build_row_component_updates(
    row_defs: List[Dict[str, Any]],
    selected_tags: List[str],
    max_rows: int,
):
    """Compute the Gradio updates that render *row_defs* into the fixed row slots.

    Returns ``(prompt_text, row_values_state, header_updates, checkbox_updates)``:

    * ``prompt_text``: prompt composed from the selected tags visible in the
      first ``max_rows`` rows;
    * ``row_values_state``: checked values per populated row;
    * ``header_updates`` / ``checkbox_updates``: exactly ``max_rows`` updates
      each; slots beyond the available rows are cleared and hidden.

    ``max_rows`` is coerced to a non-negative int once, so the slice and the
    slot loop agree (a float previously worked for the slice but made
    ``range()`` raise).
    """
    selected = {t for t in (selected_tags or []) if t}
    n_rows = max(0, int(max_rows))
    row_defs_ui = (row_defs or [])[:n_rows]
    row_values_state: List[List[str]] = []
    header_updates = []
    checkbox_updates = []
    for idx in range(n_rows):
        if idx >= len(row_defs_ui):
            # Unused slot: clear and hide both components.
            header_updates.append(gr.update(value="", visible=False))
            checkbox_updates.append(gr.update(choices=[], value=[], visible=False))
            continue
        row = row_defs_ui[idx]
        tags = _dedupe_norm_tags(row.get("tags", []))
        values = [t for t in tags if t in selected]
        row_values_state.append(values)
        # A row without tags keeps its slot but stays hidden.
        visible = bool(tags)
        header_updates.append(gr.update(value=row.get("label", ""), visible=visible))
        raw_meta = row.get("tag_meta", {})
        tag_meta = raw_meta if isinstance(raw_meta, dict) else {}
        choices = []
        for t in tags:
            entry = tag_meta.get(t, {})
            meta = entry if isinstance(entry, dict) else {}
            origin = _normalize_selection_origin(str(meta.get("origin", "selection")))
            preselected = bool(meta.get("preselected", False))
            # Choices are (label, value) pairs; the value is always the canonical tag.
            choices.append((_choice_label_with_source_meta(t, origin=origin, preselected=preselected), t))
        checkbox_updates.append(
            gr.update(
                choices=choices,
                value=values,
                visible=visible,
            )
        )
    prompt_text = _compose_toggle_prompt_text(list(selected), row_defs_ui)
    return prompt_text, row_values_state, header_updates, checkbox_updates
def _on_toggle_row(
    row_idx: int,
    changed_values: List[str],
    selected_tags_state: List[str],
    rows_dirty_state: bool,
    row_defs_state: List[Dict[str, Any]],
    row_values_state: List[List[str]],
    max_rows: int,
):
    """Handle a change event from one row's checkbox group.

    Returns a list shaped as
    ``[selected_tags, rows_dirty, <component update>, row_values_state,
    prompt_text, *checkbox_updates]`` with ``max_rows`` checkbox updates.

    * If the event does not change the row's selection (e.g. a programmatic
      value re-set), state is passed through and every component is skipped.
    * Otherwise the row's selection is replaced by the new values, the rows are
      marked dirty, the third output is made visible/interactive, and only
      checkbox groups containing a toggled tag are refreshed.

    The current selection is taken from ``row_values_state`` whenever it is
    available, even if every row is empty; ``selected_tags_state`` is used only
    when no row values exist. This matches ``_rebuild_rows_from_selected`` and
    keeps an "everything unchecked" state from resurrecting stale selections.
    """
    n_rows = max(0, int(max_rows))
    row_defs = row_defs_state or []
    row_defs_ui = row_defs[:n_rows]
    prev_values = list(row_values_state or [])
    selected_from_state = _collect_selected_from_state(selected_tags_state, row_defs_ui)
    selected_from_rows = _collect_selected_from_row_values(row_defs_ui, prev_values)
    # Row-value state is the source of truth (closest to visible UI); fall back only when unavailable.
    selected: Set[str] = set(selected_from_rows if prev_values else selected_from_state)

    row = row_defs_ui[row_idx] if 0 <= row_idx < len(row_defs_ui) else {}
    row_tags = _dedupe_norm_tags(row.get("tags", []))
    row_tag_set = set(row_tags)
    row_tag_by_norm = {_norm_tag_for_lookup(t): t for t in row_tags}

    # Be tolerant to UI payload forms: canonical tag values, display labels, normalized variants,
    # and occasional single-string payloads from frontend events.
    if changed_values is None:
        changed_iter: List[Any] = []
    elif isinstance(changed_values, str):
        changed_iter = [changed_values]
    elif isinstance(changed_values, (list, tuple, set)):
        changed_iter = list(changed_values)
    else:
        changed_iter = [changed_values]

    new_set: Set[str] = set()
    for raw in changed_iter:
        if raw in row_tag_set:
            new_set.add(raw)
            continue
        mapped = row_tag_by_norm.get(_norm_tag_for_lookup(str(raw)))
        if mapped:
            new_set.add(mapped)

    prev_row_selected = {t for t in row_tags if t in selected}
    # Ignore non-user/no-op events (e.g., programmatic value re-sets) deterministically.
    if new_set == prev_row_selected:
        prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
        checkbox_updates = [gr.skip() for _ in range(n_rows)]
        return [sorted(selected), rows_dirty_state, gr.skip(), prev_values, prompt_text, *checkbox_updates]

    # Replace this row's contribution to the selection with the new values.
    selected.difference_update(row_tag_set)
    selected.update(new_set)
    toggled_tags = prev_row_selected ^ new_set

    # A toggled tag may also be shown in other rows; those need a refresh too.
    new_row_values_state: List[List[str]] = []
    affected_rows: Set[int] = {row_idx}
    for idx, row_item in enumerate(row_defs_ui):
        tags = _dedupe_norm_tags(row_item.get("tags", []))
        new_row_values_state.append([t for t in tags if t in selected])
        if toggled_tags and any(t in toggled_tags for t in tags):
            affected_rows.add(idx)

    checkbox_updates = [
        gr.update(value=new_row_values_state[idx])
        if idx < len(row_defs_ui) and idx in affected_rows
        else gr.skip()
        for idx in range(n_rows)
    ]
    prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
    return [
        sorted(selected),
        True,
        gr.update(visible=True, interactive=True),
        new_row_values_state,
        prompt_text,
        *checkbox_updates,
    ]
def _build_ui_payload(
    *,
    console_text: str,
    row_defs: List[Dict[str, Any]],
    selected_tags: List[str],
    suggested_prompt_text: str | None = None,
):
    """Assemble the full list of UI outputs after a run.

    The rows are rendered via ``_build_row_component_updates`` using the
    module-level ``display_max_rows_default``. When *suggested_prompt_text* is
    given it replaces the prompt derived from the rows. The selected-tags
    output is rebuilt from the per-row values, so it only contains tags that
    are actually displayed.
    """
    derived_prompt, row_values_state, header_updates, checkbox_updates = _build_row_component_updates(
        row_defs=row_defs,
        selected_tags=selected_tags,
        max_rows=display_max_rows_default,
    )
    prompt_text = derived_prompt if suggested_prompt_text is None else str(suggested_prompt_text)

    # Flatten row values in display order, keeping the first occurrence of each tag.
    selected_ui: List[str] = list(dict.fromkeys(tag for vals in row_values_state for tag in vals))

    tooltip_payload = _build_tooltip_payload(row_defs, display_max_rows_default)
    payload: List[Any] = [
        console_text,
        gr.update(visible=bool(row_defs)),
        tooltip_payload,
        prompt_text,
        selected_ui,
        False,
        gr.update(visible=False, interactive=False),
        row_defs,
        row_values_state,
    ]
    payload.extend(header_updates)
    payload.extend(checkbox_updates)
    return payload
def _format_user_facing_error(exc: Exception) -> str:
msg = str(exc or "").strip()
msg_l = msg.lower()
if "rewrite: empty output" in msg_l:
return (
"Could not rewrite that prompt. Try simpler, neutral wording and remove sensitive phrasing, "
"then click Run again."
)
if "openrouter_api_key" in msg_l:
return "Service configuration is missing. Please contact the app owner."
if "timed out" in msg_l:
return "The model request timed out. Please try again with a shorter or simpler prompt."
if "index selection failed" in msg_l:
return "Tag selection failed for this request. Please try again."
if "startup preflight failed" in msg_l:
return "App startup checks failed. Please contact the app owner."
return "Something went wrong while processing the prompt. Please try again."
def _prepare_run_ui() -> List[Any]:
    """Return the UI outputs shown while a run is in progress.

    Every row slot (``display_max_rows_default`` of them) is cleared and
    hidden, state values are reset, and "Running..." messages are shown.
    The second output is left untouched via ``gr.skip()``.
    """
    slot_count = display_max_rows_default
    cleared_headers = [gr.update(value="", visible=False) for _ in range(slot_count)]
    cleared_checkboxes = [gr.update(choices=[], value=[], visible=False) for _ in range(slot_count)]
    outputs: List[Any] = [
        "Running...",
        gr.skip(),
        "{}",
        "Running... usually completes in about 20 seconds.",
        [],
        False,
        gr.update(visible=False, interactive=False),
        [],
        [],
    ]
    outputs.extend(cleared_headers)
    outputs.extend(cleared_checkboxes)
    return outputs
def _update_run_button_visibility(prompt_text: str, last_run_prompt: str):
    """Show and enable the target component only for a non-empty prompt that differs from the last run.

    Both prompts are compared after trimming surrounding whitespace.
    """
    current = (prompt_text or "").strip()
    previous = (last_run_prompt or "").strip()
    runnable = current != "" and current != previous
    return gr.update(visible=runnable, interactive=runnable)
def _mark_run_triggered(prompt_text: str):
    """Return a hide-and-disable update plus the trimmed prompt that was just submitted."""
    submitted = (prompt_text or "").strip()
    hide_update = gr.update(visible=False, interactive=False)
    return hide_update, submitted
def _rebuild_rows_from_selected(
    selected_tags_state: List[str],
    row_defs_state: List[Dict[str, Any]],
    row_values_state: List[List[str]],
    display_top_groups: float,
    display_top_tags_per_group: float,
    display_rank_top_k: float,
):
    """Re-rank and rebuild the toggle rows from the current selection.

    The selection is read from the per-row checkbox values when any exist
    (otherwise from ``selected_tags_state``), filtered for artist/excluded
    tags, and used both as seed terms and as the selected set for
    ``_build_toggle_rows``. All tags currently on screen are fed back in as the
    retrieved-candidate pool, together with the origins recorded in each row's
    ``tag_meta``.

    Returns ``[<component update>, tooltip_payload, prompt_text,
    selected_tags, False, <component update>, row_defs, row_values_state,
    *header_updates, *checkbox_updates]``.
    """
    existing_rows = row_defs_state or []
    existing_values = list(row_values_state or [])
    selected_from_state = _collect_selected_from_state(selected_tags_state, existing_rows)
    selected_from_rows = _collect_selected_from_row_values(existing_rows, existing_values)
    # Rebuild source-of-truth is current row checkbox values; fall back only when unavailable.
    selected_seed = selected_from_rows if existing_values else selected_from_state
    selected_active = list(
        dict.fromkeys(
            _norm_tag_for_lookup(t)
            for t in selected_seed
            if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
        )
    )
    # Recover the candidate pool and per-tag origins from what is currently displayed.
    retrieved_candidate_tags: List[str] = []
    tag_selection_origins: Dict[str, str] = {}
    for row in existing_rows:
        row_tags = row.get("tags", []) if isinstance(row, dict) else []
        row_meta = row.get("tag_meta", {}) if isinstance(row, dict) else {}
        if not isinstance(row_meta, dict):
            row_meta = {}
        for t in row_tags:
            tn = _norm_tag_for_lookup(t)
            if not tn or _is_artist_tag(tn) or _is_excluded_recommendation_tag(tn):
                continue
            retrieved_candidate_tags.append(tn)
            # First row to mention a tag decides its origin.
            if tn not in tag_selection_origins:
                meta = row_meta.get(t, {}) if isinstance(row_meta.get(t, {}), dict) else {}
                tag_selection_origins[tn] = _normalize_selection_origin(str(meta.get("origin", "selection")))
    # Selected tags with no recorded origin are attributed to the user.
    for t in selected_active:
        tag_selection_origins.setdefault(t, "user")
        retrieved_candidate_tags.append(t)
    # NOTE(review): every origin stored above went through _normalize_selection_origin
    # (which never returns "implied") or is "user", so this comparison cannot be true.
    # As written, implied_selected_tags is always empty and implied_parent_map is always {}.
    # Confirm whether "implied" should be a recognized origin.
    implied_selected_tags = [t for t in selected_active if tag_selection_origins.get(t) == "implied"]
    implied_set = set(implied_selected_tags)
    direct_selected_tags = [t for t in selected_active if t not in implied_set]
    # Stable ordering: origin rank first, then original position in the selection.
    direct_idx = {t: i for i, t in enumerate(direct_selected_tags)}
    direct_selected_tags.sort(
        key=lambda t: (
            _selection_source_rank(tag_selection_origins.get(t, "selection")),
            direct_idx.get(t, 10**9),
        )
    )
    implied_parent_map = _build_implied_parent_map(
        direct_tags_ordered=direct_selected_tags,
        implied_tags=implied_selected_tags,
    )
    # The display knobs are typed as floats here and are coerced to ints >= 1.
    toggle_rows = _build_toggle_rows(
        seed_terms=list(selected_active),
        selected_tags=selected_active,
        retrieved_candidate_tags=list(dict.fromkeys(retrieved_candidate_tags)),
        tag_selection_origins=tag_selection_origins,
        implied_parent_map=implied_parent_map,
        top_groups=max(1, int(display_top_groups)),
        top_tags_per_group=max(1, int(display_top_tags_per_group)),
        group_rank_top_k=max(1, int(display_rank_top_k)),
    )
    prompt_text, row_values_state, header_updates, checkbox_updates = _build_row_component_updates(
        row_defs=toggle_rows,
        selected_tags=selected_active,
        max_rows=display_max_rows_default,
    )
    tooltip_payload = _build_tooltip_payload(toggle_rows, display_max_rows_default)
    return [
        gr.update(visible=bool(toggle_rows)),
        tooltip_payload,
        prompt_text,
        sorted(selected_active),
        False,
        gr.update(visible=False, interactive=False),
        toggle_rows,
        row_values_state,
        *header_updates,
        *checkbox_updates,
    ]
def _build_selection_query(
prompt_in: str,
rewritten: str,
structural_tags: List[str],
probe_tags: List[str],
) -> str:
lines = [f"IMAGE DESCRIPTION: {prompt_in.strip()}"]
if rewritten and rewritten.strip():
lines.append(f"REWRITE PHRASES: {rewritten.strip()}")
hint_tags = []
if structural_tags:
hint_tags.extend(structural_tags)
if probe_tags:
hint_tags.extend(probe_tags)
if hint_tags:
# Keep hints as context only; selection still must choose by candidate indices.
lines.append(
"INFERRED TAG HINTS (context only): " + ", ".join(sorted(set(hint_tags)))
)
return "\n".join(lines)
# Logging setup.
# Minimal production logging: WARNING and above to a stream handler, no log file by default.
# The level can be overridden with PSQ_LOG_LEVEL; unknown level names fall back to WARNING.
import os, logging
LOG_LEVEL = os.environ.get("PSQ_LOG_LEVEL", "WARNING").upper()
logging.basicConfig(
    level=getattr(logging, LOG_LEVEL, logging.WARNING),
    format="%(asctime)s %(levelname)s:%(message)s",
    handlers=[logging.StreamHandler()]  # no file -> avoids huge logs on Spaces
)
# Quiet down common noisy libraries (optional): only ERROR and above from these loggers.
for _name in ("gensim", "gradio", "hnswlib", "httpx", "uvicorn"):
    logging.getLogger(_name).setLevel(logging.ERROR)
# Turn off Gradio analytics phone-home to avoid background thread errors (optional).
os.environ["GRADIO_ANALYTICS_ENABLED"] = "0"
# Mascot image location, resolved relative to this file; read by _load_mascot_image.
MASCOT_DIR = Path(__file__).parent / "mascotimages"
MASCOT_FILE = MASCOT_DIR / "transparentsquirrel.png"
def _load_mascot_image():
"""Load mascot image if available; return None when missing/unreadable."""
if not MASCOT_FILE.exists():
logging.warning("Mascot image missing: %s", MASCOT_FILE)
return None
try:
return Image.open(MASCOT_FILE).convert("RGBA")
except Exception as e:
logging.warning("Failed to load mascot image (%s): %s", MASCOT_FILE, e)
return None
# Compatibility hotfix for gradio_client's JSON-schema helpers: they are
# wrapped so that non-dict schemas (e.g. the boolean schema form) are
# reported as "any" instead of being passed to the original implementation.
# If anything goes wrong while patching, the failure is printed and ignored.
try:
    from gradio_client import utils as _gc_utils

    # Capture the originals before they are replaced below.
    _orig_get_type = _gc_utils.get_type
    _orig_j2p = _gc_utils._json_schema_to_python_type
    _orig_pub = _gc_utils.json_schema_to_python_type

    def _get_type_safe(schema):
        """Delegate to the original get_type for dict schemas only."""
        # Sometimes schema is a bare True/False (JSON Schema boolean form)
        return _orig_get_type(schema) if isinstance(schema, dict) else "any"

    def _j2p_safe(schema, defs=None):
        """Delegate to the original converter, defaulting defs to the schema's $defs."""
        if isinstance(schema, dict):
            return _orig_j2p(schema, defs or schema.get("$defs"))
        # Accept non-dict schemas (True/False/None) and treat as "any"
        return "any"

    def _pub_safe(schema):
        """Public wrapper used by Gradio; routes through the safe private converter."""
        if isinstance(schema, dict):
            return _j2p_safe(schema, schema.get("$defs"))
        return "any"

    _gc_utils.get_type = _get_type_safe
    _gc_utils._json_schema_to_python_type = _j2p_safe
    _gc_utils.json_schema_to_python_type = _pub_safe
except Exception as e:
    print("gradio_client hotfix not applied:", e)
# -------------------------------------------------------------------------------
# Runtime configuration. Most values can be overridden through PSQ_* environment
# variables; they are read once, at import time.

# Forwarded to retrieval as ``allow_nsfw_tags``.
allow_nsfw_tags = False


def _is_production_runtime() -> bool:
    """Best-effort detection for deployed runtime (HF Spaces or explicit env).

    Returns True when PSQ_PRODUCTION is set to 1/true/yes (case-insensitive),
    when SPACE_ID or HF_SPACE_ID is non-empty, or when SYSTEM equals "spaces".
    """
    if os.environ.get("PSQ_PRODUCTION", "").strip().lower() in {"1", "true", "yes"}:
        return True
    if os.environ.get("SPACE_ID"):
        return True
    if os.environ.get("HF_SPACE_ID"):
        return True
    if os.environ.get("SYSTEM") == "spaces":
        return True
    return False


# Verbose retrieval logging defaults to off in production and on elsewhere.
verbose_retrieval_default = "0" if _is_production_runtime() else "1"
verbose_retrieval = os.environ.get("PSQ_VERBOSE_RETRIEVAL", verbose_retrieval_default).strip().lower() in {"1", "true", "yes"}
verbose_retrieval_all = False
verbose_retrieval_limit = 20

# Probe inference is enabled unless explicitly switched off. The comparison is
# case-insensitive (like the other boolean flags above) so that values such as
# "FALSE" or "no" disable it instead of being silently treated as "enabled".
enable_probe_tags = os.environ.get("PSQ_ENABLE_PROBE", "1").strip().lower() not in {"0", "false", "no"}

# Defaults for the "Display Settings" controls and the number of UI rows.
display_top_groups_default = int(os.environ.get("PSQ_DISPLAY_TOP_GROUPS", "10"))
display_top_tags_per_group_default = int(os.environ.get("PSQ_DISPLAY_TOP_TAGS_PER_GROUP", "7"))
display_rank_top_k_default = int(os.environ.get("PSQ_DISPLAY_GROUP_RANK_TOP_K", "7"))
display_max_rows_default = int(os.environ.get("PSQ_DISPLAY_MAX_ROWS", "14"))

# Retrieval sizing.
retrieval_global_k = int(os.environ.get("PSQ_RETRIEVAL_GLOBAL_K", "300"))
retrieval_per_phrase_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_K", "10"))
retrieval_per_phrase_final_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_FINAL_K", "1"))

# Index-selection behaviour. A candidate cap of 0 means "no cap".
selection_mode = os.environ.get("PSQ_SELECTION_MODE", "chunked_map_union").strip()
selection_chunk_size = int(os.environ.get("PSQ_SELECTION_CHUNK_SIZE", "60"))
selection_per_phrase_k = int(os.environ.get("PSQ_SELECTION_PER_PHRASE_K", "2"))
selection_candidate_cap = int(os.environ.get("PSQ_SELECTION_CANDIDATE_CAP", "0"))

# Per-stage timeouts (seconds) and the number of fast retries for stage 3.
stage1_rewrite_timeout_s = float(os.environ.get("PSQ_TIMEOUT_REWRITE_S", "45"))
stage1_struct_timeout_s = float(os.environ.get("PSQ_TIMEOUT_STRUCT_S", "45"))
stage1_probe_timeout_s = float(os.environ.get("PSQ_TIMEOUT_PROBE_S", "45"))
stage3_select_timeout_s = float(os.environ.get("PSQ_TIMEOUT_SELECT_S", "50"))
stage3_select_retry_timeout_s = float(os.environ.get("PSQ_TIMEOUT_SELECT_RETRY_S", "20"))
stage3_fast_retry_count = max(0, int(os.environ.get("PSQ_STAGE3_FAST_RETRY_COUNT", "1")))

# JSONL file that receives one timing record per pipeline run.
timing_log_path = Path(os.environ.get("PSQ_TIMING_LOG_PATH", "data/runtime_metrics/ui_pipeline_timings.jsonl"))
def _startup_preflight_errors() -> List[str]:
    """Return human-readable configuration problems detected at startup.

    Currently the only check is that OPENROUTER_API_KEY is set to a
    non-empty value. An empty list means no problems were found.
    """
    if os.getenv("OPENROUTER_API_KEY"):
        return []
    return ["OPENROUTER_API_KEY is missing. Set it in Space Secrets or environment variables."]


# Evaluated once at import; the pipeline entry point consults this list.
STARTUP_PREFLIGHT_ERRORS = _startup_preflight_errors()
for _err in STARTUP_PREFLIGHT_ERRORS:
    logging.error("Startup preflight error: %s", _err)
# Stylesheet passed to gr.Blocks(css=...). Sections, in order:
#   - .scrollable-content: scroll pane sizing and scrollbar appearance.
#   - .lego-tags: "brick" skin for the tag checkbox rows, including the
#     per-position colour cycle, the data-psq-origin / data-psq-preselected
#     colour overrides, and the checked ("ON") state.
#   - .source-legend: colour chips for the legend shown under the rows.
#   - .row-heading, .row-instruction, .top-instruction, .run-hint: text styling.
#   - .about-docs, .arch-diagram-wrap: documentation accordion content.
#   - .prompt-card, .suggested-prompt-box, .suggested-prompt-card, .psq-hidden.
css = """
.scrollable-content{
max-height: 420px;
overflow-y: scroll; /* always show scrollbar */
overflow-x: hidden;
padding-right: 8px;
padding-bottom: 14px; /* <— add this */
scrollbar-gutter: stable; /* prevent layout shift as it fills */
/* Firefox */
scrollbar-width: auto;
scrollbar-color: rgba(180,180,180,.9) rgba(0,0,0,.15);
}
/* WebKit/Chromium (Chrome/Edge/Safari) */
.scrollable-content::-webkit-scrollbar{ width: 10px; }
.scrollable-content::-webkit-scrollbar-thumb{ background: rgba(180,180,180,.9); border-radius: 8px; }
.scrollable-content::-webkit-scrollbar-track{ background: rgba(0,0,0,.15); }
/* (Optional) make both scroll panes taller so they fill more of the column */
.pane-left .scrollable-content,
.pane-right .scrollable-content {
max-height: 610px; /* was 420px; tweak to taste */
}
.lego-tags .gr-checkboxgroup,
.lego-tags .wrap {
display: flex !important;
flex-wrap: wrap !important;
gap: 10px !important;
}
.lego-tags label {
margin: 0 !important;
padding: 0 !important;
position: relative !important;
}
/* Hide native checkbox visuals completely */
.lego-tags input[type="checkbox"] {
appearance: none !important;
-webkit-appearance: none !important;
-moz-appearance: none !important;
position: absolute !important;
width: 1px !important;
height: 1px !important;
opacity: 0 !important;
pointer-events: none !important;
display: none !important;
}
/* Brick button skin (works for both +span and ~span structures) */
.lego-tags input[type="checkbox"] + span,
.lego-tags input[type="checkbox"] ~ span {
--on-bg1: #ffd166;
--on-bg2: #f39c4a;
--on-border: #b86e21;
--on-text: #2e1706;
position: relative !important;
display: inline-flex !important;
align-items: center !important;
min-height: 40px !important;
padding: 10px 15px 9px 22px !important;
border: 1px solid #9aa6b8 !important;
border-radius: 10px !important;
background: linear-gradient(180deg, #dfe5ee 0%, #bec8d6 100%) !important;
color: #364254 !important;
font-size: 0.97rem !important;
font-weight: 800 !important;
line-height: 1.15 !important;
cursor: pointer !important;
user-select: none !important;
letter-spacing: 0.01em !important;
box-shadow: 0 3px 0 rgba(0,0,0,0.16), inset 0 1px 0 rgba(255,255,255,0.55) !important;
transition: transform 0.08s ease, box-shadow 0.08s ease, filter 0.08s ease !important;
}
.lego-tags input[type="checkbox"] + span::before,
.lego-tags input[type="checkbox"] ~ span::before {
content: "" !important;
position: absolute !important;
top: 5px !important;
left: 8px !important;
width: 8px !important;
height: 8px !important;
border-radius: 50% !important;
background: rgba(255,255,255,0.58) !important;
box-shadow: 22px 0 0 rgba(255,255,255,0.58) !important;
pointer-events: none !important;
}
/* Unselected cue: show "+" on the left. */
.lego-tags input[type="checkbox"] + span::after,
.lego-tags input[type="checkbox"] ~ span::after {
content: "+" !important;
position: absolute !important;
left: 6px !important;
top: 50% !important;
transform: translateY(-52%) !important;
font-size: 1rem !important;
font-weight: 900 !important;
color: #4b5563 !important;
opacity: 0.95 !important;
pointer-events: none !important;
}
/* Bright color cycle used only when selected */
.lego-tags label:nth-child(8n+1) span { --on-bg1: #ffd166; --on-bg2: #f39c4a; --on-border: #b86e21; --on-text: #2e1706; }
.lego-tags label:nth-child(8n+2) span { --on-bg1: #6ee7ff; --on-bg2: #1fb7ff; --on-border: #157cb3; --on-text: #07263c; }
.lego-tags label:nth-child(8n+3) span { --on-bg1: #9dff8f; --on-bg2: #45c96f; --on-border: #2a8b4b; --on-text: #0d2917; }
.lego-tags label:nth-child(8n+4) span { --on-bg1: #ff8fab; --on-bg2: #ff5c7a; --on-border: #b83956; --on-text: #3f0f1d; }
.lego-tags label:nth-child(8n+5) span { --on-bg1: #d0a8ff; --on-bg2: #a46cff; --on-border: #7147b3; --on-text: #25143f; }
.lego-tags label:nth-child(8n+6) span { --on-bg1: #ffe27a; --on-bg2: #f7bf39; --on-border: #ad7f1f; --on-text: #332407; }
.lego-tags label:nth-child(8n+7) span { --on-bg1: #8effd5; --on-bg2: #2ed6b5; --on-border: #1e947d; --on-text: #0d2a25; }
.lego-tags label:nth-child(8n+8) span { --on-bg1: #ffb47e; --on-bg2: #ff8753; --on-border: #b95b2d; --on-text: #391a0a; }
/* Source-driven selected colors (applies when tags are preselected by the pipeline). */
.lego-tags label[data-psq-preselected="1"][data-psq-origin="rewrite"] span {
--on-bg1: #77f0d7;
--on-bg2: #26b9a3;
--on-border: #187869;
--on-text: #062923;
}
.lego-tags label[data-psq-preselected="1"][data-psq-origin="selection"] span {
--on-bg1: #ffd98a;
--on-bg2: #f0a93c;
--on-border: #a66f1f;
--on-text: #382206;
}
.lego-tags label[data-psq-preselected="1"][data-psq-origin="probe"] span {
--on-bg1: #d8b4ff;
--on-bg2: #9a6cff;
--on-border: #6745b0;
--on-text: #24143b;
}
.lego-tags label[data-psq-preselected="1"][data-psq-origin="structural"] span {
--on-bg1: #a6f79a;
--on-bg2: #53c368;
--on-border: #2f8442;
--on-text: #102d17;
}
.lego-tags label[data-psq-preselected="1"][data-psq-origin="implied"] span {
--on-bg1: #d7dde8;
--on-bg2: #a8b3c4;
--on-border: #6f7e95;
--on-text: #1d2633;
}
/* User-selected tags (not initially selected by the pipeline). */
.lego-tags label[data-psq-preselected="0"] span {
--on-bg1: #9ec5ff;
--on-bg2: #4f86ff;
--on-border: #2f5fbf;
--on-text: #0b1f42;
}
.lego-tags label:hover span {
filter: brightness(1.02) !important;
transform: translateY(1px) !important;
}
/* ON state: brighter + visibly recessed */
.lego-tags input[type="checkbox"]:checked + span,
.lego-tags input[type="checkbox"]:checked ~ span,
.lego-tags label:has(input[type="checkbox"]:checked) span {
background: linear-gradient(180deg, var(--on-bg1) 0%, var(--on-bg2) 100%) !important;
color: var(--on-text) !important;
border-color: var(--on-border) !important;
filter: saturate(1.2) brightness(1.12) !important;
transform: translateY(-2px) !important;
box-shadow:
inset 0 3px 6px rgba(0,0,0,0.20),
inset 0 -1px 0 rgba(255,255,255,0.36),
0 6px 0 rgba(0,0,0,0.32) !important;
}
.lego-tags input[type="checkbox"]:checked + span::after,
.lego-tags input[type="checkbox"]:checked ~ span::after,
.lego-tags label:has(input[type="checkbox"]:checked) span::after {
content: "" !important;
}
.source-legend {
display: flex;
flex-wrap: wrap;
align-items: center;
gap: 8px;
margin: 4px 0 10px 0;
}
.source-legend .legend-title {
font-size: 0.92rem;
font-weight: 900;
color: #334155;
margin-right: 4px;
}
.source-legend .chip {
display: inline-flex;
align-items: center;
border-radius: 10px;
border: 1px solid #6c7788;
padding: 6px 12px;
font-size: 0.85rem;
font-weight: 800;
color: #111827;
background: #f3f6fb;
}
.source-legend .chip.rewrite { background: #26b9a3; color: #062923; border-color: #187869; }
.source-legend .chip.selection { background: #f0a93c; color: #382206; border-color: #a66f1f; }
.source-legend .chip.probe { background: #9a6cff; color: #ffffff; border-color: #6745b0; }
.source-legend .chip.structural { background: #53c368; color: #102d17; border-color: #2f8442; }
.source-legend .chip.implied { background: #a8b3c4; color: #1d2633; border-color: #6f7e95; }
.source-legend .chip.user { background: #4f86ff; color: #ffffff; border-color: #2f5fbf; }
.source-legend .chip.unselected { background: #c7ced8; color: #2d3440; border-color: #7d8897; }
.row-heading p {
margin: 8px 0 0 0 !important;
font-size: 1.18rem !important;
font-weight: 850 !important;
line-height: 1.2 !important;
}
.row-instruction {
text-align: center;
margin: 8px 0 12px 0;
}
.row-instruction p {
margin: 0 !important;
font-size: 1.02rem !important;
font-style: italic !important;
font-weight: 800 !important;
color: #1d4ed8 !important;
}
.about-docs {
margin-top: 4px;
}
.about-docs > p {
line-height: 1.42 !important;
}
.about-docs img {
max-width: 100% !important;
height: auto !important;
border: 1px solid #d2d7e0;
border-radius: 10px;
background: #ffffff;
}
.arch-diagram-wrap {
margin: 6px 0 10px 0;
}
.arch-diagram-wrap h2 {
margin: 0 0 8px 0 !important;
}
.top-instruction {
text-align: center;
margin: 2px 0 6px 0;
}
.top-instruction p {
margin: 0 !important;
font-size: 1.02rem !important;
font-style: italic !important;
font-weight: 800 !important;
color: #1d4ed8 !important;
}
.run-hint {
margin-top: 6px;
text-align: center;
}
.run-hint p {
margin: 0 !important;
font-size: 0.9rem !important;
font-style: italic !important;
color: #475569 !important;
}
.prompt-card {
background: transparent !important;
border: none !important;
box-shadow: none !important;
padding: 0 !important;
}
.suggested-prompt-box {
margin-top: 2px !important;
}
.suggested-prompt-card {
margin-top: 10px !important;
}
.psq-hidden {
display: none !important;
}
"""
# Browser-side script passed to gr.Blocks(js=...).
# It reads a JSON payload of the form {"rows": [...], "tips": {...}} from the
# element with id "psq-tooltip-map", then, for every ".lego-tags" row, sets (or
# clears) the `title` attribute of each label/span using the tag at the same
# row/position. A MutationObserver on document.body re-applies the tooltips
# whenever the DOM changes; updates are coalesced via requestAnimationFrame.
client_js = """
() => {
const readTooltipMap = () => {
const el = document.querySelector("#psq-tooltip-map textarea, #psq-tooltip-map input");
if (!el) return { rows: [], tips: {} };
const raw = (el.value || "").trim();
if (!raw) return { rows: [], tips: {} };
try {
const obj = JSON.parse(raw);
if (!obj || typeof obj !== "object") return { rows: [], tips: {} };
const rows = Array.isArray(obj.rows) ? obj.rows : [];
const tips = (obj.tips && typeof obj.tips === "object") ? obj.tips : {};
return { rows, tips };
} catch (_) {
return { rows: [], tips: {} };
}
};
const applyTooltips = () => {
const payload = readTooltipMap();
const rowTags = Array.isArray(payload.rows) ? payload.rows : [];
const tipMap = (payload.tips && typeof payload.tips === "object") ? payload.tips : {};
const rowEls = document.querySelectorAll(".lego-tags");
rowEls.forEach((rowEl, rowIdx) => {
const tags = Array.isArray(rowTags[rowIdx]) ? rowTags[rowIdx] : [];
const labels = rowEl.querySelectorAll("label");
labels.forEach((label, tagIdx) => {
const span = label.querySelector("span");
const tag = (tagIdx < tags.length) ? tags[tagIdx] : "";
const tip = tag && Object.prototype.hasOwnProperty.call(tipMap, tag) ? (tipMap[tag] || "") : "";
if (tip) {
label.title = tip;
if (span) span.title = tip;
} else {
label.removeAttribute("title");
if (span) span.removeAttribute("title");
}
});
});
};
let scheduled = false;
const scheduleApply = () => {
if (scheduled) return;
scheduled = true;
requestAnimationFrame(() => {
scheduled = false;
applyTooltips();
});
};
scheduleApply();
const observer = new MutationObserver(() => scheduleApply());
observer.observe(document.body, { childList: true, subtree: true });
}
"""
def rag_pipeline_ui(
    user_prompt: str,
    display_top_groups: float,
    display_top_tags_per_group: float,
    display_rank_top_k: float,
):
    """Run the prompt-to-tags pipeline for one request and build the UI payload.

    Stages, in order: user-tag preprocessing, concurrent LLM rewrite /
    structural / probe inference, candidate retrieval, LLM index selection
    (with fast retries), implication expansion, prompt composition, and the
    ranked group/category display. Progress and timings are collected in an
    in-memory log that becomes the console text of the payload.

    Args:
        user_prompt: Free-text image description entered by the user.
        display_top_groups: Number of group rows to show (UI number field).
        display_top_tags_per_group: Number of tags to show per row.
        display_rank_top_k: Number of top tags used when ranking rows.
            The three display values may be None when the corresponding UI
            field is empty; the configured defaults are used in that case.

    Returns:
        The value produced by ``_build_ui_payload``. Exceptions raised inside
        the pipeline are caught, logged, and turned into an error payload.
    """
    logs = []

    def log(s):
        logs.append(s)

    try:
        stage_timings = {}

        def _record_timing(stage: str, dt_s: float):
            stage_timings[stage] = float(dt_s)

        def _as_positive_int(value, default: int) -> int:
            """Coerce a UI number to an int >= 1, using ``default`` when unset/invalid."""
            try:
                return max(1, int(value))
            except (TypeError, ValueError, OverflowError):
                # A cleared number field arrives as None; NaN/inf cannot be converted.
                return max(1, int(default))

        def _emit_timing_summary(total_s: float):
            """Log per-stage timings (fixed order), the slowest stage, and the total."""
            summary_order = [
                "preprocess",
                "rewrite",
                "structural",
                "probe",
                "retrieval",
                "selection",
                "implication_expansion",
                "prompt_composition",
                "group_display",
            ]
            lines = []
            for k in summary_order:
                if k in stage_timings:
                    lines.append(f"{k}={stage_timings[k]:.2f}s")
            slowest = max(stage_timings.items(), key=lambda kv: kv[1])[0] if stage_timings else "n/a"
            log("Timing Summary: " + ", ".join(lines))
            log(f"Timing Slowest Stage: {slowest}")
            log(f"Timing Total: {total_s:.2f}s")

        def _append_timing_jsonl(total_s: float):
            """Append one timing record to ``timing_log_path``; failures are only logged."""
            try:
                # Local import: only ``datetime`` itself is imported at file level.
                from datetime import timezone

                timing_log_path.parent.mkdir(parents=True, exist_ok=True)
                rec = {
                    # Timezone-aware replacement for the deprecated datetime.utcnow();
                    # the emitted format (seconds precision, trailing "Z") is unchanged.
                    "timestamp_utc": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
                    "stages_s": stage_timings,
                    "total_s": float(total_s),
                    "config": {
                        "timeout_rewrite_s": stage1_rewrite_timeout_s,
                        "timeout_struct_s": stage1_struct_timeout_s,
                        "timeout_probe_s": stage1_probe_timeout_s,
                        "timeout_select_s": stage3_select_timeout_s,
                    },
                }
                with timing_log_path.open("a", encoding="utf-8") as f:
                    f.write(json.dumps(rec, ensure_ascii=True) + "\n")
                log(f"Timing Log: wrote {timing_log_path}")
            except Exception as e:
                log(f"Timing Log: failed ({type(e).__name__}: {e})")

        def _future_with_timeout(
            fut,
            timeout_s: float,
            stage_name: str,
            fallback,
            *,
            strict: bool = False,
        ):
            """Wait for ``fut`` (at least 1s) and return its result.

            On timeout or failure: raise RuntimeError when ``strict`` is set,
            otherwise log the problem and return ``fallback``. The recorded
            duration is measured from the start of this wait.
            """
            t0 = time.perf_counter()
            try:
                out = fut.result(timeout=max(1.0, float(timeout_s)))
                dt = time.perf_counter() - t0
                log(f"{stage_name}: {dt:.2f}s")
                stage_key = {
                    "Rewrite": "rewrite",
                    "Structural inference": "structural",
                    "Probe inference": "probe",
                    "Index selection": "selection",
                }.get(stage_name)
                if stage_key:
                    _record_timing(stage_key, dt)
                return out
            except FutureTimeoutError:
                fut.cancel()
                msg = f"{stage_name}: timed out after {timeout_s:.0f}s"
                if strict:
                    raise RuntimeError(msg)
                log(f"{msg}; using fallback")
                return fallback
            except Exception as e:
                msg = f"{stage_name}: failed ({type(e).__name__}: {e})"
                if strict:
                    raise RuntimeError(msg)
                log(f"{msg}; using fallback")
                return fallback

        t_total0 = time.perf_counter()
        log("Start: received prompt")

        # Refuse to run when the startup checks found configuration problems.
        if STARTUP_PREFLIGHT_ERRORS:
            log("Startup preflight failed:")
            for e in STARTUP_PREFLIGHT_ERRORS:
                log(f"- {e}")
            return _build_ui_payload(
                console_text="\n".join(logs),
                row_defs=[],
                selected_tags=[],
                suggested_prompt_text="Error: startup preflight failed. Check console details.",
            )

        prompt_in = (user_prompt or "").strip()
        if not prompt_in:
            return _build_ui_payload(
                console_text="Error: empty prompt",
                row_defs=[],
                selected_tags=[],
                suggested_prompt_text='Enter a prompt and click "Run".',
            )

        log("Input:")
        log(prompt_in)
        log("")
        log(
            "Runtime config: "
            f"retrieval_global_k={retrieval_global_k} "
            f"retrieval_per_phrase_k={retrieval_per_phrase_k} "
            f"retrieval_per_phrase_final_k={retrieval_per_phrase_final_k} "
            f"selection_mode={selection_mode} "
            f"selection_chunk_size={selection_chunk_size} "
            f"selection_per_phrase_k={selection_per_phrase_k} "
            f"min_tag_count={_get_min_tag_count()} "
            f"select_timeout_s={stage3_select_timeout_s:.0f} "
            f"select_retry_timeout_s={stage3_select_retry_timeout_s:.0f} "
            f"select_fast_retries={stage3_fast_retry_count}"
        )
        log("")

        # Preprocess: extract tags the user typed directly, then filter them.
        t0 = time.perf_counter()
        min_tag_count = _get_min_tag_count()
        user_tags_raw = extract_user_provided_tags_upto_3_words(prompt_in)
        user_tags, removed_user_low = _filter_min_count_tags(user_tags_raw, min_tag_count)
        user_tags, removed_user_excluded = _filter_excluded_recommendation_tags(user_tags)
        dt = time.perf_counter()-t0
        _record_timing("preprocess", dt)
        log(f"Preprocess (user tag extraction): {dt:.2f}s")
        log("Heuristically extracted user tags:")
        if user_tags:
            log(", ".join(user_tags))
        else:
            log("(none)")
        if removed_user_low:
            log(
                f"Filtered {len(removed_user_low)} low-frequency user tags "
                f"(<{min_tag_count}): {', '.join(removed_user_low)}"
            )
        if removed_user_excluded:
            log(
                f"Filtered {len(removed_user_excluded)} excluded user tags: "
                f"{', '.join(removed_user_excluded)}"
            )
        log("")

        log("Step 1: LLM rewrite + structural inference + probe (concurrent)")
        max_workers = 3 if enable_probe_tags else 2
        ex = ThreadPoolExecutor(max_workers=max_workers)
        try:
            fut_rewrite = ex.submit(llm_rewrite_prompt, prompt_in, log)
            fut_struct = ex.submit(llm_infer_structural_tags, prompt_in, log=log)
            fut_probe = ex.submit(llm_infer_probe_tags, prompt_in, log=log) if enable_probe_tags else None
            # The rewrite is mandatory (strict); structural/probe fall back to [].
            rewritten = _future_with_timeout(
                fut_rewrite,
                stage1_rewrite_timeout_s,
                "Rewrite",
                "",
                strict=True,
            )
            structural_tags = _future_with_timeout(
                fut_struct, stage1_struct_timeout_s, "Structural inference", []
            )
            probe_tags = (
                _future_with_timeout(fut_probe, stage1_probe_timeout_s, "Probe inference", [])
                if fut_probe else []
            )
        finally:
            # Do not block the request on workers that are still running.
            ex.shutdown(wait=False, cancel_futures=True)

        structural_tags, removed_struct_low = _filter_min_count_tags(structural_tags, min_tag_count)
        probe_tags, removed_probe_low = _filter_min_count_tags(probe_tags, min_tag_count)
        structural_tags, removed_struct_excluded = _filter_excluded_recommendation_tags(structural_tags)
        probe_tags, removed_probe_excluded = _filter_excluded_recommendation_tags(probe_tags)
        if removed_struct_low:
            log(
                f"Filtered {len(removed_struct_low)} low-frequency structural tags "
                f"(<{min_tag_count}): {', '.join(removed_struct_low)}"
            )
        if removed_probe_low:
            log(
                f"Filtered {len(removed_probe_low)} low-frequency probe tags "
                f"(<{min_tag_count}): {', '.join(removed_probe_low)}"
            )
        if removed_struct_excluded:
            log(
                f"Filtered {len(removed_struct_excluded)} excluded structural tags: "
                f"{', '.join(removed_struct_excluded)}"
            )
        if removed_probe_excluded:
            log(
                f"Filtered {len(removed_probe_excluded)} excluded probe tags: "
                f"{', '.join(removed_probe_excluded)}"
            )
        if not rewritten:
            raise RuntimeError("Rewrite: empty output")
        log("Rewrite:")
        log(rewritten if rewritten else "(empty)")
        log("")

        rewrite_for_retrieval = rewritten
        if user_tags:
            # keep them separate in logs, but allow them to help retrieval
            rewrite_for_retrieval = (rewrite_for_retrieval + ", " + ", ".join(user_tags)).strip(", ").strip()

        log("Step 2: Prompt Squirrel retrieval (hidden)")
        try:
            t0 = time.perf_counter()
            retrieval_context_tags = list(dict.fromkeys((structural_tags or []) + (probe_tags or [])))
            rewrite_phrases = [p.strip() for p in (rewrite_for_retrieval or "").split(",") if p.strip()]
            retrieval_result = psq_candidates_from_rewrite_phrases(
                rewrite_phrases=rewrite_phrases,
                allow_nsfw_tags=allow_nsfw_tags,
                context_tags=retrieval_context_tags,
                global_k=max(1, retrieval_global_k),
                per_phrase_k=max(1, retrieval_per_phrase_k),
                per_phrase_final_k=max(1, retrieval_per_phrase_final_k),
                min_tag_count=max(0, min_tag_count),
                verbose=verbose_retrieval,
            )
            # The result is either (candidates, phrase_reports) or just candidates.
            if isinstance(retrieval_result, tuple):
                candidates, phrase_reports = retrieval_result
            else:
                candidates, phrase_reports = retrieval_result, []
            candidates, removed_candidate_excluded = _filter_excluded_candidates(candidates)
            if removed_candidate_excluded:
                log(
                    f"Filtered {len(removed_candidate_excluded)} excluded retrieved tags: "
                    f"{', '.join(removed_candidate_excluded[:25])}"
                    + (" ..." if len(removed_candidate_excluded) > 25 else "")
                )
            if selection_candidate_cap > 0 and len(candidates) > selection_candidate_cap:
                candidates = candidates[:selection_candidate_cap]
                log(f"Selection candidate cap applied: {selection_candidate_cap}")
            dt = time.perf_counter()-t0
            _record_timing("retrieval", dt)
            log(f"Retrieval: {dt:.2f}s")
            log(f"Retrieved {len(candidates)} candidate tags")
            if verbose_retrieval:
                log(f"Total unique candidates: {len(candidates)}")
                limit = None if verbose_retrieval_all else max(1, int(verbose_retrieval_limit))
                for report in phrase_reports:
                    phrase = report.get("normalized") or report.get("phrase") or ""
                    lookup = report.get("lookup") or ""
                    tfidf_vocab = report.get("tfidf_vocab")
                    log(f"Phrase: {phrase} (lookup={lookup}) tfidf_vocab={tfidf_vocab}")
                    rows = report.get("candidates", [])
                    shown = rows if limit is None else rows[:limit]
                    for row in shown:
                        tag = row.get("tag")
                        alias_token = row.get("alias_token")
                        score_fasttext = row.get("score_fasttext")
                        score_context = row.get("score_context")
                        score_combined = row.get("score_combined")
                        count = row.get("count")
                        alias_part = ""
                        if alias_token and alias_token != tag:
                            alias_part = f" [alias_token={alias_token}]"
                        # Numeric scores are shown with 3 decimals; anything else verbatim.
                        fasttext_str = (
                            f"{score_fasttext:.3f}" if isinstance(score_fasttext, (int, float)) else score_fasttext
                        )
                        if score_context is None:
                            context_str = "None"
                        else:
                            context_str = (
                                f"{score_context:.3f}" if isinstance(score_context, (int, float)) else score_context
                            )
                        combined_str = (
                            f"{score_combined:.3f}" if isinstance(score_combined, (int, float)) else score_combined
                        )
                        log(
                            f" {tag}{alias_part} | fasttext={fasttext_str} context={context_str} "
                            f"combined={combined_str} count={count}"
                        )
                    if limit is not None and len(rows) > limit:
                        log(f" ... ({len(rows) - limit} more)")
        except Exception as e:
            # Retrieval problems are not fatal: continue with no candidates.
            log(f"Retrieval fallback: {type(e).__name__}: {e}")
            candidates = []

        retrieved_candidate_tags = list(
            dict.fromkeys(
                _norm_tag_for_lookup(c.tag)
                for c in (candidates or [])
                if getattr(c, "tag", None)
            )
        )

        log("Step 3: LLM index selection (uses rewrite + structural/probe context)")
        selection_query = _build_selection_query(
            prompt_in=prompt_in,
            rewritten=rewritten,
            structural_tags=structural_tags,
            probe_tags=probe_tags,
        )
        picked_indices = None
        last_stage3_error: Exception | None = None
        stage3_attempts = 1 + int(stage3_fast_retry_count)
        for attempt_i in range(stage3_attempts):
            # The first attempt gets the full timeout; retries use the shorter one.
            timeout_s = stage3_select_timeout_s if attempt_i == 0 else stage3_select_retry_timeout_s
            if attempt_i > 0:
                log(
                    f"Index selection: fast retry {attempt_i}/{stage3_fast_retry_count} "
                    f"(timeout={timeout_s:.0f}s)"
                )
            ex = ThreadPoolExecutor(max_workers=1)
            try:
                fut_sel = ex.submit(
                    llm_select_indices,
                    query_text=selection_query,
                    candidates=candidates,
                    max_pick=0,
                    log=log,
                    mode=selection_mode,
                    chunk_size=max(1, selection_chunk_size),
                    per_phrase_k=max(1, selection_per_phrase_k),
                )
                picked_indices = _future_with_timeout(
                    fut_sel,
                    timeout_s,
                    "Index selection",
                    [],
                    strict=True,
                )
                last_stage3_error = None
                break
            except Exception as e:
                last_stage3_error = e
                log(f"Index selection attempt {attempt_i + 1} failed: {e}")
            finally:
                ex.shutdown(wait=False, cancel_futures=True)
        if picked_indices is None:
            raise RuntimeError(
                f"Index selection failed after {stage3_attempts} attempt(s): {last_stage3_error}"
            )

        selection_selected_tags = [candidates[i].tag for i in picked_indices] if picked_indices else []
        selection_selected_tags, removed_stage3_low = _filter_min_count_tags(selection_selected_tags, min_tag_count)
        if removed_stage3_low:
            log(
                f" Filtered {len(removed_stage3_low)} low-frequency stage3 tags "
                f"(<{min_tag_count}): {', '.join(removed_stage3_low)}"
            )
        selected_tags = list(selection_selected_tags)
        if structural_tags:
            # Add structural tags that aren't already selected
            existing = {t for t in selected_tags}
            new_structural = [t for t in structural_tags if t not in existing]
            selected_tags.extend(new_structural)
            log(f" Added {len(new_structural)} structural tags: {', '.join(new_structural)}")
        else:
            log(" No structural tags inferred")
        if probe_tags:
            existing = {t for t in selected_tags}
            new_probe = [t for t in probe_tags if t not in existing]
            selected_tags.extend(new_probe)
            log(f" Added {len(new_probe)} probe tags: {', '.join(new_probe)}")
        elif enable_probe_tags:
            log(" No probe tags inferred")
        selected_tags, removed_excluded_direct = _filter_excluded_recommendation_tags(selected_tags)
        if removed_excluded_direct:
            log(f" Removed {len(removed_excluded_direct)} excluded tags: {', '.join(removed_excluded_direct)}")
        # Snapshot of the directly selected tags, taken before implications are added.
        direct_selected_tags = list(dict.fromkeys(selected_tags))

        log("Step 3c: Expand via tag implications")
        t0 = time.perf_counter()
        tag_set = set(selected_tags)
        expanded, implied_only = expand_tags_via_implications(tag_set)
        dt = time.perf_counter()-t0
        _record_timing("implication_expansion", dt)
        log(f"Implication expansion: {dt:.2f}s")
        implied_selected_tags = sorted(implied_only) if implied_only else []
        if implied_only:
            implied_added = sorted(implied_only)
            implied_added, removed_implied_low = _filter_min_count_tags(implied_added, min_tag_count)
            implied_selected_tags = list(implied_added)
            if implied_added:
                selected_tags.extend(implied_added)
                log(f" Added {len(implied_added)} implied tags: {', '.join(implied_added)}")
            if removed_implied_low:
                log(
                    f" Filtered {len(removed_implied_low)} low-frequency implied tags "
                    f"(<{min_tag_count}): {', '.join(removed_implied_low)}"
                )
        else:
            log(" No additional implied tags")
        selected_tags, removed_excluded_implied = _filter_excluded_recommendation_tags(selected_tags)
        implied_selected_tags = [
            t for t in implied_selected_tags if not _is_excluded_recommendation_tag(t)
        ]
        if removed_excluded_implied:
            log(
                f" Removed {len(removed_excluded_implied)} excluded tags after implications: "
                f"{', '.join(removed_excluded_implied)}"
            )

        log("Step 4: Compose final prompt")
        t0 = time.perf_counter()
        final_prompt = compose_final_prompt(rewritten, selected_tags)
        dt = time.perf_counter()-t0
        _record_timing("prompt_composition", dt)
        log(f"Prompt composition: {dt:.2f}s")

        log("Step 5: Build ranked group/category display")
        t0 = time.perf_counter()
        # Seed terms for group ranking: every term seen so far, order-preserving, de-duplicated.
        seed_terms = []
        seed_terms.extend(user_tags)
        seed_terms.extend([p.strip() for p in (rewritten or "").split(",") if p.strip()])
        seed_terms.extend(structural_tags or [])
        seed_terms.extend(probe_tags or [])
        seed_terms.extend(selected_tags)
        seed_terms = list(dict.fromkeys(seed_terms))
        active_selected_tags = list(dict.fromkeys(selected_tags))

        # Normalized membership sets used to attribute each selected tag to a source.
        structural_set = {_norm_tag_for_lookup(t) for t in (structural_tags or []) if t}
        probe_set = {_norm_tag_for_lookup(t) for t in (probe_tags or []) if t}
        implied_set = {_norm_tag_for_lookup(t) for t in (implied_selected_tags or []) if t}
        rewrite_set = {
            _norm_tag_for_lookup(t)
            for t in (list(user_tags or []) + [p.strip() for p in (rewritten or "").split(",") if p.strip()])
            if t
        }
        selection_set = {_norm_tag_for_lookup(t) for t in (selection_selected_tags or []) if t}
        tag_selection_origins: Dict[str, str] = {}
        for tag in active_selected_tags:
            tag_norm = _norm_tag_for_lookup(tag)
            # Precedence: structural > probe > rewrite > selection > implied.
            if tag_norm in structural_set:
                origin = "structural"
            elif tag_norm in probe_set:
                origin = "probe"
            elif tag_norm in rewrite_set:
                origin = "rewrite"
            elif tag_norm in selection_set:
                origin = "selection"
            elif tag_norm in implied_set:
                origin = "implied"
            else:
                # Unknown/fallback tags use selection color.
                origin = "selection"
            tag_selection_origins[tag] = origin
            if tag_norm and tag_norm != tag:
                tag_selection_origins[tag_norm] = origin

        # Order direct tags by source rank, breaking ties by original position.
        direct_tags_for_implied = list(
            dict.fromkeys(_norm_tag_for_lookup(t) for t in (direct_selected_tags or []) if t)
        )
        direct_tags_for_implied_idx = {t: i for i, t in enumerate(direct_tags_for_implied)}
        direct_tags_for_implied.sort(
            key=lambda t: (
                _selection_source_rank(tag_selection_origins.get(t, "selection")),
                direct_tags_for_implied_idx.get(t, 10**9),
            )
        )
        implied_parent_map = _build_implied_parent_map(
            direct_tags_ordered=direct_tags_for_implied,
            implied_tags=implied_selected_tags,
        )
        toggle_rows = _build_toggle_rows(
            seed_terms=seed_terms,
            selected_tags=active_selected_tags,
            retrieved_candidate_tags=retrieved_candidate_tags,
            tag_selection_origins=tag_selection_origins,
            implied_parent_map=implied_parent_map,
            # Empty/invalid UI number fields fall back to the configured defaults
            # instead of failing here, after all the LLM work has been done.
            top_groups=_as_positive_int(display_top_groups, display_top_groups_default),
            top_tags_per_group=_as_positive_int(
                display_top_tags_per_group, display_top_tags_per_group_default
            ),
            group_rank_top_k=_as_positive_int(display_rank_top_k, display_rank_top_k_default),
        )
        dt = time.perf_counter()-t0
        _record_timing("group_display", dt)
        log(f"Ranked group display: {dt:.2f}s ({len(toggle_rows)} rows)")
        log(
            _build_display_audit_line(
                toggle_rows,
                active_selected_tags=active_selected_tags,
                direct_selected_tags=direct_selected_tags,
                implied_selected_tags=implied_selected_tags,
            )
        )
        for idx, row in enumerate(toggle_rows[: max(0, int(display_max_rows_default))]):
            tags_preview = ", ".join(row.get("tags", []))
            log(f"UI Row {idx}: {row.get('label', '')} :: {tags_preview}")

        total_dt = time.perf_counter()-t_total0
        _emit_timing_summary(total_dt)
        _append_timing_jsonl(total_dt)
        log("Done: final prompt ready")
        return _build_ui_payload(
            console_text="\n".join(logs),
            row_defs=toggle_rows,
            selected_tags=active_selected_tags,
        )
    except Exception as e:
        # Top-level boundary: report the failure in the console and to the user.
        log(f"Error: {type(e).__name__}: {e}")
        return _build_ui_payload(
            console_text="\n".join(logs),
            row_defs=[],
            selected_tags=[],
            suggested_prompt_text=_format_user_facing_error(e),
        )
# ---------------------------------------------------------------------------
# Gradio UI definition.
# Components are declared first (prompt inputs, mascot, Run button, state
# holders, tag rows, legend, settings/console/about accordions); the event
# handlers that drive them are wired up at the bottom of this block.
# ---------------------------------------------------------------------------
with gr.Blocks(css=css, js=client_js) as app:
with gr.Row():
# Prompt column: instruction text, the editable prompt box and the
# read-only suggested-prompt box.
with gr.Column(scale=3, elem_classes=["prompt-col"]):
gr.Markdown(
'Describe your image under "Enter Prompt" and click "Run". '
'Prompt Squirrel will translate it into image board tags.',
elem_classes=["top-instruction"],
)
with gr.Group(elem_classes=["prompt-card"]):
# NOTE(review): the placeholder ends with ", ." which looks like a leftover
# or typo -- confirm the intended example text.
image_tags = gr.Textbox(
label="Enter Prompt",
placeholder="e.g. fox, outside, detailed background, .",
lines=1,
elem_classes=["enter-prompt-box"],
)
with gr.Group(elem_classes=["prompt-card", "suggested-prompt-card"]):
# Output-only box; written by the run, toggle and rebuild handlers below.
suggested_prompt = gr.Textbox(
label="Suggested Prompt (Read-only)",
lines=2,
interactive=False,
show_copy_button=True,
placeholder='Suggested prompt will appear here after you click "Run".',
elem_classes=["suggested-prompt-box"],
)
with gr.Column(scale=1):
# Show the mascot image when it could be loaded; otherwise fall back to a
# short Markdown note so the layout still has an element in this position.
_mascot_pil = _load_mascot_image()
if _mascot_pil is not None:
mascot_img = gr.Image(
value=_mascot_pil,
show_label=False,
interactive=False,
height=240,
elem_id="mascot"
)
else:
mascot_img = gr.Markdown("`(mascot image unavailable)`")
# The Run button is created hidden and disabled; its state is an output of
# _update_run_button_visibility (fired on prompt edits) and of
# _mark_run_triggered (fired when a run starts).
submit_button = gr.Button("Run", variant="primary", visible=False, interactive=False)
gr.Markdown("Typical runtime: up to ~20 seconds.", elem_classes=["run-hint"])
# Per-session state holders.
# last_run_prompt_state: written by _mark_run_triggered, read by
#   _update_run_button_visibility.
# selected_tags_state / rows_dirty_state / row_defs_state / row_values_state:
#   written by the run chain and read/written by the toggle and rebuild
#   handlers below.
last_run_prompt_state = gr.State("")
selected_tags_state = gr.State([])
rows_dirty_state = gr.State(False)
row_defs_state = gr.State([])
row_values_state = gr.State([])
# Hidden at start; it is one of the run/rebuild outputs.
toggle_instruction = gr.Markdown(
"Click tag buttons to add or remove tags from the suggested prompt.",
elem_classes=["row-instruction"],
visible=False,
)
# Pre-create a fixed pool of display_max_rows_default header/checkbox-group
# pairs, all hidden and empty. The handlers receive every one of them as an
# output (see run_outputs), so presumably they fill and reveal as many as
# are needed per run -- confirm against _build_ui_payload.
row_headers: List[gr.Markdown] = []
row_checkboxes: List[gr.CheckboxGroup] = []
for _ in range(display_max_rows_default):
with gr.Row():
with gr.Column(scale=2, min_width=170):
row_headers.append(gr.Markdown(value="", visible=False, elem_classes=["row-heading"]))
with gr.Column(scale=10):
row_checkboxes.append(
gr.CheckboxGroup(
choices=[],
value=[],
visible=False,
interactive=True,
container=False,
elem_classes=["lego-tags"],
)
)
# Static legend for the tag chip categories, next to the "Rebuild Rows" button.
# The chip class names presumably match rules in `css` -- confirm there.
with gr.Row():
with gr.Column(scale=10):
gr.HTML(
"""
<div class="source-legend">
<span class="legend-title">Legend:</span>
<span class="chip rewrite">Rewrite phrase</span>
<span class="chip selection">General selection</span>
<span class="chip probe">Probe query</span>
<span class="chip structural">Structural query</span>
<span class="chip implied">Implied</span>
<span class="chip user">User-toggled</span>
<span class="chip unselected">Unselected</span>
</div>
"""
)
with gr.Column(scale=2, min_width=180):
# Created hidden and disabled; it is an output of the run, toggle and
# rebuild handlers.
rebuild_rows_button = gr.Button(
"Rebuild Rows",
variant="primary",
visible=False,
interactive=False,
)
# Integer-valued display knobs (precision=0, minimum=1). All three are passed
# as inputs to rag_pipeline_ui and _rebuild_rows_from_selected.
with gr.Accordion("Display Settings", open=False):
with gr.Row():
display_top_groups = gr.Number(
value=display_top_groups_default,
precision=0,
label="Rows (Top Groups/Categories)",
minimum=1,
)
display_top_tags_per_group = gr.Number(
value=display_top_tags_per_group_default,
precision=0,
label="Top Tags Shown Per Row",
minimum=1,
)
display_rank_top_k = gr.Number(
value=display_rank_top_k_default,
precision=0,
label="Top Tags Used for Row Ranking",
minimum=1,
)
# Read-only log view; first entry of run_outputs.
with gr.Accordion("Console", open=False):
console = gr.Textbox(
label="Console",
lines=10,
interactive=False,
placeholder="Progress logs will appear here."
)
# About section. _split_about_docs_for_diagram returns the text before and
# after the diagram slot plus a flag saying whether a slot was found. With a
# slot, the diagram HTML is rendered between the two Markdown parts (each part
# is skipped when empty); without one, the whole document is rendered as-is.
with gr.Accordion("How Prompt Squirrel Works", open=False):
_about_md = _load_about_docs_markdown()
_about_before, _about_after, _has_arch_slot = _split_about_docs_for_diagram(_about_md)
if _has_arch_slot:
# NOTE(review): elem_id="about-docs" is only attached to the "before" part,
# so when _about_before is empty no element carries that id in this branch.
if _about_before:
gr.Markdown(
_about_before,
elem_id="about-docs",
elem_classes=["about-docs"],
)
gr.HTML(
_build_arch_diagram_html(),
elem_classes=["about-docs"],
)
if _about_after:
gr.Markdown(
_about_after,
elem_classes=["about-docs"],
)
else:
gr.Markdown(
_about_md,
elem_id="about-docs",
elem_classes=["about-docs"],
)
# Carrier textbox for the tooltip map (initially the JSON text "{}"). It is
# declared visible=True but tagged with the "psq-hidden" class -- presumably
# so it stays in the DOM for client_js to read via #psq-tooltip-map while CSS
# hides it; confirm against `css` / `client_js`.
tooltip_map_payload = gr.Textbox(
value="{}",
visible=True,
interactive=False,
container=False,
elem_id="psq-tooltip-map",
elem_classes=["psq-hidden"],
)
# Output list shared by _prepare_run_ui and rag_pipeline_ui. Both handlers
# must return values for exactly these components, in this order.
run_outputs = [
console,
toggle_instruction,
tooltip_map_payload,
suggested_prompt,
selected_tags_state,
rows_dirty_state,
rebuild_rows_button,
row_defs_state,
row_values_state,
*row_headers,
*row_checkboxes,
]
# Every prompt edit re-evaluates the Run button against the last-run prompt.
image_tags.change(
_update_run_button_visibility,
inputs=[image_tags, last_run_prompt_state],
outputs=[submit_button],
queue=False,
show_progress="hidden",
)
# Run chain (button click), three sequential steps:
#   1. _mark_run_triggered - updates the Run button and records the prompt.
#   2. _prepare_run_ui     - takes no inputs and writes to all run_outputs.
#   3. rag_pipeline_ui     - receives the prompt and the three display
#                            settings and writes to all run_outputs.
# Steps 1-2 are registered with queue=False and a hidden progress indicator;
# step 3 uses the default queue/progress behaviour.
submit_button.click(
_mark_run_triggered,
inputs=[image_tags],
outputs=[submit_button, last_run_prompt_state],
queue=False,
show_progress="hidden",
).then(
_prepare_run_ui,
inputs=[],
outputs=run_outputs,
queue=False,
show_progress="hidden",
).then(
rag_pipeline_ui,
inputs=[image_tags, display_top_groups, display_top_tags_per_group, display_rank_top_k],
outputs=run_outputs,
)
# Same three-step chain, triggered by the prompt textbox's submit event.
# NOTE(review): this duplicates the click chain above argument-for-argument;
# keep the two in sync when changing either.
image_tags.submit(
_mark_run_triggered,
inputs=[image_tags],
outputs=[submit_button, last_run_prompt_state],
queue=False,
show_progress="hidden",
).then(
_prepare_run_ui,
inputs=[],
outputs=run_outputs,
queue=False,
show_progress="hidden",
).then(
rag_pipeline_ui,
inputs=[image_tags, display_top_groups, display_top_tags_per_group, display_rank_top_k],
outputs=run_outputs,
)
# One change handler per tag row. The row index is bound through the lambda
# default `i=idx` so each handler keeps its own index instead of the loop's
# final value. Outputs cover every checkbox group (not the headers), plus the
# selection/dirty state, the rebuild button and the suggested prompt.
for idx, row_cb in enumerate(row_checkboxes):
row_cb.change(
fn=lambda changed_values, selected_state, rows_dirty, row_defs, row_values, i=idx: _on_toggle_row(
i,
changed_values,
selected_state,
rows_dirty,
row_defs,
row_values,
display_max_rows_default,
),
inputs=[row_cb, selected_tags_state, rows_dirty_state, row_defs_state, row_values_state],
outputs=[selected_tags_state, rows_dirty_state, rebuild_rows_button, row_values_state, suggested_prompt, *row_checkboxes],
queue=False,
show_progress="hidden",
)
# "Rebuild Rows" passes the current selection, row state and display settings
# to _rebuild_rows_from_selected. Its outputs are the same components as
# run_outputs except for the console, which is not included here.
rebuild_rows_button.click(
_rebuild_rows_from_selected,
inputs=[selected_tags_state, row_defs_state, row_values_state, display_top_groups, display_top_tags_per_group, display_rank_top_k],
outputs=[
toggle_instruction,
tooltip_map_payload,
suggested_prompt,
selected_tags_state,
rows_dirty_state,
rebuild_rows_button,
row_defs_state,
row_values_state,
*row_headers,
*row_checkboxes,
],
queue=False,
show_progress="hidden",
)
if __name__ == "__main__":
    # Script entry point: enable request queuing, then start the Gradio server.
    # Only the mascot and docs directories are passed as allowed_paths, i.e. the
    # locations on disk that Gradio is permitted to serve files from.
    _queued_app = app.queue()
    _served_dirs = [str(served_dir) for served_dir in (MASCOT_DIR, DOCS_DIR)]
    _queued_app.launch(allowed_paths=_served_dirs)