import gradio as gr
import os
import logging
import time
import json
import csv
import re
import base64
from datetime import datetime
from functools import lru_cache
from PIL import Image
from pathlib import Path
from typing import Any, Dict, List, Set, Tuple
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words
from psq_rag.llm.rewrite import llm_rewrite_prompt
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup
from psq_rag.llm.select import llm_select_indices, llm_infer_structural_tags, llm_infer_probe_tags
from psq_rag.retrieval.state import (
expand_tags_via_implications,
get_tag_type_name,
get_tag_implications,
get_tag_counts,
)
from psq_rag.ui.group_ranked_display import rank_groups_from_tfidf, _load_enabled_groups
# Filesystem anchors for bundled docs/assets, resolved relative to this file.
APP_DIR = Path(__file__).parent
DOCS_DIR = APP_DIR / "docs"
ARCH_DIAGRAM_FILE = DOCS_DIR / "assets" / "architecture_overview.png"
# Placeholder token docs may embed to mark where the diagram is injected.
ARCH_DIAGRAM_MARKER = "{{ARCHITECTURE_DIAGRAM}}"
# Fallback insertion heading when no explicit marker exists in the docs.
ARCH_DIAGRAM_INSERT_BEFORE_HEADING = "## What Each Step Does"
# Always-on safety filters: a normalized tag matching any of these patterns is
# hard-blocked from recommendations regardless of the editable CSV exclusion
# lists. Patterns anchor on whole underscore-delimited words within a tag.
_CORPORATE_HARDBLOCK_PATTERNS = [
    # Rating-like explicitness markers.
    re.compile(r"(^|_)(nsfw|explicit|questionable)(_|$)", re.IGNORECASE),
    # Unambiguous sexual anatomy.
    re.compile(
        r"(^|_)(breast|breasts|boob|boobs|nipple|nipples|penis|vagina|pussy|clit|testicle|scrotum|genital|crotch|anus|anal|areola)(_|$)",
        re.IGNORECASE,
    ),
    # Unambiguous sexual activity.
    re.compile(
        r"(^|_)(sex|sexual|fucking|fuck|blowjob|handjob|masturbat|penetrat|thrust|orgasm|cum|ejaculat|creampie|nude|naked|topless|bottomless|moan|sexy)(_|$)",
        re.IGNORECASE,
    ),
    # Common kink/fetish markers.
    re.compile(r"(^|_)(fetish|bdsm|bondage|dominatrix|submission|vore|inflation|watersports)(_|$)", re.IGNORECASE),
]
def _split_prompt_commas(s: str) -> List[str]:
return [p.strip() for p in (s or "").split(",") if p.strip()]
def _norm_for_dedupe(tag: str) -> str:
    # Canonical form for lookup/dedupe: lowercase first, then the project-wide
    # tag normalization shared with retrieval (_norm_tag_for_lookup).
    return _norm_tag_for_lookup(tag.lower())
def compose_final_prompt(rewritten_prompt: str, selected_tags: List[str]) -> str:
    """Join the rewritten prompt's comma parts with the selected tags.

    Duplicates are dropped by canonical (normalized) form while keeping
    first-seen order and original spelling.
    """
    candidates = _split_prompt_commas(rewritten_prompt)
    candidates.extend(selected_tags)
    emitted: Set[str] = set()
    kept: List[str] = []
    for candidate in candidates:
        canonical = _norm_for_dedupe(candidate)
        if canonical not in emitted:
            emitted.add(canonical)
            kept.append(candidate)
    return ", ".join(kept)
def _display_tag_text(tag: str) -> str:
return tag.replace("_", " ")
def _display_row_label(name: str) -> str:
n = (name or "").strip()
if not n:
return ""
if n == "selected_other":
return "Selected (Other)"
return n.replace("_", " ").title()
def _normalize_selection_origin(origin: str) -> str:
o = (origin or "").strip().lower()
if o in {"rewrite", "selection", "probe", "structural", "user", "candidate"}:
return o
return "selection"
def _choice_label_with_source_meta(tag: str, *, origin: str, preselected: bool) -> str:
    """Checkbox label for *tag*.

    *origin* and *preselected* are accepted for interface stability but
    intentionally unused: labels stay plain (no badges) to avoid frontend
    text/value desynchronization.
    """
    del origin, preselected  # accepted but deliberately ignored
    return _display_tag_text(tag)
@lru_cache(maxsize=1)
def _load_tag_wiki_defs() -> Dict[str, str]:
    """Load tag -> wiki-definition text from data/tag_wiki_defs.json.

    Best-effort: returns {} when the file is missing, unreadable, or not a
    JSON object. Keys are normalized tags; values have internal whitespace
    collapsed to single spaces. Cached for the process lifetime.
    """
    p = Path("data/tag_wiki_defs.json")
    if not p.exists():
        return {}
    try:
        with p.open("r", encoding="utf-8") as f:
            data = json.load(f)
        out: Dict[str, str] = {}
        if isinstance(data, dict):
            for k, v in data.items():
                tag = _norm_tag_for_lookup(str(k))
                # Collapse newlines/runs of spaces so tooltips render cleanly.
                text = " ".join(str(v or "").split())
                if tag and text:
                    out[tag] = text
        return out
    except Exception:
        # Corrupt/undecodable file: behave as "no definitions available".
        return {}
@lru_cache(maxsize=1)
def _load_about_docs_markdown() -> str:
    """Return the About-tab markdown from the first readable candidate file.

    Tries docs/space_overview.md then PROJECT_SUMMARY.md, skipping files
    that are missing, unreadable, or effectively empty. YAML front matter
    is stripped. A fixed fallback message is returned when nothing usable
    is found. Cached for the process lifetime.
    """
    for path in (DOCS_DIR / "space_overview.md", APP_DIR / "PROJECT_SUMMARY.md"):
        if not path.exists():
            continue
        try:
            content = path.read_text(encoding="utf-8")
        except Exception:
            continue
        body = content.strip()
        if not body:
            continue
        # Strip YAML front matter (--- ... ---) if present.
        if body.startswith("---"):
            chunks = body.split("---", 2)
            if len(chunks) >= 3:
                body = chunks[2].strip()
        if body:
            return body
    return (
        "Documentation is unavailable.\n\n"
        "Expected file: `docs/space_overview.md`"
    )
def _tooltip_text_for_tag(tag: str) -> str:
    """Tooltip body for *tag*: usage count (when known) plus wiki definition.

    Returns "" when neither piece of information is available.
    """
    t = _norm_tag_for_lookup(tag)
    parts: List[str] = []
    try:
        count = get_tag_counts().get(t)
    except Exception:
        # Count data is optional; tooltip degrades to definition-only.
        count = None
    if isinstance(count, int):
        parts.append(f"Count: {count:,}")
    d = _load_tag_wiki_defs().get(t, "")
    if d:
        parts.append(d)
    return "\n".join(parts).strip()
@lru_cache(maxsize=1)
def _load_arch_diagram_data_uri() -> str:
    """Read the bundled architecture PNG and return it as a data: URI.

    Returns "" when the asset is absent, unreadable, or empty. Cached for
    the process lifetime.
    """
    if not ARCH_DIAGRAM_FILE.exists():
        return ""
    try:
        payload = ARCH_DIAGRAM_FILE.read_bytes()
    except Exception:
        return ""
    if not payload:
        return ""
    encoded = base64.b64encode(payload).decode("ascii")
    return f"data:image/png;base64,{encoded}"
def _split_about_docs_for_diagram(md: str) -> Tuple[str, str, bool]:
    """Split the About markdown around the architecture-diagram slot.

    Returns (before, after, has_slot). Resolution order: explicit
    {{ARCHITECTURE_DIAGRAM}} marker (last occurrence wins), then a legacy
    "## Architecture At A Glance" heading (heading is dropped), then the
    "## What Each Step Does" heading (kept at the start of *after* so the
    diagram lands just before it). Falls back to (text, "", False).

    NOTE(review): the heading regex below duplicates the text of
    ARCH_DIAGRAM_INSERT_BEFORE_HEADING — keep the two in sync.
    """
    text = (md or "").strip()
    if ARCH_DIAGRAM_MARKER in text:
        before, after = text.rsplit(ARCH_DIAGRAM_MARKER, 1)
        return before.strip(), after.strip(), True
    # Backward compatibility if an explicit architecture heading exists in docs.
    m_arch = re.search(r"(?m)^##\s+Architecture At A Glance\s*$", text)
    if m_arch:
        before = text[: m_arch.start()].strip()
        after = text[m_arch.end() :].strip()
        return before, after, True
    # Preferred insertion point: inject diagram right before "What Each Step Does".
    m_steps = re.search(r"(?m)^##\s+What Each Step Does\s*$", text)
    if m_steps:
        before = text[: m_steps.start()].strip()
        after = text[m_steps.start() :].strip()
        return before, after, True
    return text, "", False
def _build_arch_diagram_html() -> str:
    """Build the HTML fragment embedding the architecture diagram.

    NOTE(review): the original body contained an unterminated string literal
    (plain text where markup was expected — a SyntaxError, likely markup lost
    in transit). This restores a well-formed fragment with the same fallback
    text and the same "Architecture At A Glance" caption.
    """
    uri = _load_arch_diagram_data_uri()
    if not uri:
        # Graceful placeholder when the PNG asset is missing or unreadable.
        return "<p><em>(architecture diagram unavailable)</em></p>"
    return (
        '<figure style="margin:0;text-align:center;">'
        f'<img src="{uri}" alt="Architecture At A Glance" '
        'style="max-width:100%;height:auto;" />'
        "<figcaption>Architecture At A Glance</figcaption>"
        "</figure>"
    )
def _selection_source_rank(origin: str) -> int:
    """Row-ordering priority for a selection origin: structural < probe < rest.

    rewrite/user/candidate share the same priority band (2) as general
    selection so row ordering stays stable for those origins.
    """
    ranks = {"structural": 0, "probe": 1}
    return ranks.get(_normalize_selection_origin(origin), 2)
def _build_implied_parent_map(
    direct_tags_ordered: List[str],
    implied_tags: List[str],
) -> Dict[str, str]:
    """Map each implied tag to the first direct tag that (transitively) implies it.

    Walks the tag-implication graph outward from each direct tag in the given
    order, so earlier direct tags win ties. Only tags present in
    *implied_tags* appear as keys. Returns {} when either input is empty.
    """
    implied_set = {_norm_tag_for_lookup(t) for t in (implied_tags or []) if t}
    if not implied_set or not direct_tags_ordered:
        return {}
    impl = get_tag_implications()
    parent_by_implied: Dict[str, str] = {}
    for direct in direct_tags_ordered:
        d = _norm_tag_for_lookup(direct)
        if not d:
            continue
        # DFS over the transitive implications of this direct tag.
        queue = [d]
        seen = {d}
        while queue:
            t = queue.pop()
            for parent in impl.get(t, ()):
                p = _norm_tag_for_lookup(parent)
                if not p or p in seen:
                    continue
                seen.add(p)
                # First direct tag to reach an implied tag claims it.
                if p in implied_set and p not in parent_by_implied:
                    parent_by_implied[p] = d
                queue.append(p)
    return parent_by_implied
def _order_selected_tags_for_row(
    *,
    row_selected_tags: List[str],
    selected_index: Dict[str, int],
    tag_selection_origins: Dict[str, str],
    implied_parent_map: Dict[str, str],
) -> List[str]:
    """Order a row's selected tags for display.

    Direct tags come first, sorted by origin rank then original selection
    order (*selected_index*; missing entries sort last via the 10**9
    sentinel, tag text as the final tiebreak). Each direct tag is
    immediately followed by the implied tags it is responsible for; implied
    tags whose parent is not in this row are appended at the end.
    """
    row_selected_norm = [_norm_tag_for_lookup(t) for t in (row_selected_tags or []) if t]
    implied_in_row = {t for t in row_selected_norm if t in implied_parent_map}
    # Direct (non-implied) tags, sorted by origin priority then selection order.
    base_tags = [t for t in row_selected_norm if t not in implied_in_row]
    base_tags.sort(
        key=lambda t: (
            _selection_source_rank(tag_selection_origins.get(t, "selection")),
            selected_index.get(t, 10**9),
            t,
        )
    )
    # Group implied tags under the direct tag that implies them.
    children_by_parent: Dict[str, List[str]] = {}
    for implied in implied_in_row:
        parent = implied_parent_map.get(implied)
        if parent:
            children_by_parent.setdefault(parent, []).append(implied)
    for parent, children in children_by_parent.items():
        children.sort(key=lambda t: (selected_index.get(t, 10**9), t))
    ordered: List[str] = []
    emitted: Set[str] = set()
    for tag in base_tags:
        if tag in emitted:
            continue
        ordered.append(tag)
        emitted.add(tag)
        # Emit each implied child right after its parent.
        for child in children_by_parent.get(tag, []):
            if child not in emitted:
                ordered.append(child)
                emitted.add(child)
    # Orphan implied tags: ordered by their parent's rank/position, then own.
    remaining_implied = [t for t in row_selected_norm if t not in emitted]
    remaining_implied.sort(
        key=lambda t: (
            _selection_source_rank(tag_selection_origins.get(implied_parent_map.get(t, ""), "selection")),
            selected_index.get(implied_parent_map.get(t, ""), 10**9),
            selected_index.get(t, 10**9),
            t,
        )
    )
    for t in remaining_implied:
        if t not in emitted:
            ordered.append(t)
            emitted.add(t)
    return ordered
def _escape_prompt_tag(tag: str) -> str:
return (
tag.replace("_", " ")
.replace("(", "\\(")
.replace(")", "\\)")
)
def _ordered_selected_for_prompt(selected: Set[str], row_defs: List[Dict[str, Any]]) -> List[str]:
out: List[str] = []
seen: Set[str] = set()
for row in row_defs:
for tag in row.get("tags", []):
if tag in selected and tag not in seen:
out.append(tag)
seen.add(tag)
return out
def _compose_toggle_prompt_text(selected_tags: List[str], row_defs: List[Dict[str, Any]]) -> str:
    """Compose the prompt string from selected tags, in display-row order."""
    active = {tag for tag in (selected_tags or []) if tag}
    in_order = _ordered_selected_for_prompt(active, row_defs or [])
    return ", ".join(_escape_prompt_tag(tag) for tag in in_order)
def _is_artist_tag(tag: str) -> bool:
    """True when *tag* is typed as an artist tag (or carries the by_ prefix).

    The "by_" prefix check is a resilient fallback for malformed/missing
    tag typing metadata.
    """
    t = _norm_tag_for_lookup(str(tag))
    if not t:
        return False
    return get_tag_type_name(t) == "artist" or t.startswith("by_")
@lru_cache(maxsize=1)
def _load_excluded_recommendation_tags() -> Set[str]:
    """Load the set of tags excluded from recommendations (cached once).

    Merges two best-effort CSV sources; each source is skipped silently when
    missing or unreadable:
      - category_registry.csv rows with category_status == "excluded"
      - corporate_excluded_tags.csv rows unless explicitly disabled
    """
    out: Set[str] = set()
    # Existing category-registry driven exclusions.
    csv_path = Path("data/category_registry.csv")
    if not csv_path.exists():
        # Fallback location used by analysis tooling.
        csv_path = Path("data/analysis/category_registry.csv")
    if csv_path.exists():
        try:
            with csv_path.open("r", encoding="utf-8", newline="") as f:
                reader = csv.DictReader(f)
                for row in reader:
                    tag = _norm_tag_for_lookup(str(row.get("tag") or ""))
                    if not tag:
                        continue
                    status = str(row.get("category_status") or "").strip().lower()
                    if status == "excluded":
                        out.add(tag)
        except Exception:
            pass
    # Corporate-safety exclusions (editable runtime list).
    corp_path = Path("data/corporate_excluded_tags.csv")
    if corp_path.exists():
        try:
            with corp_path.open("r", encoding="utf-8", newline="") as f:
                reader = csv.DictReader(f)
                for row in reader:
                    tag = _norm_tag_for_lookup(str(row.get("tag") or ""))
                    if not tag:
                        continue
                    # Rows default to enabled; only explicit "off" values disable.
                    enabled_raw = str(row.get("enabled", "1")).strip().lower()
                    enabled = enabled_raw not in {"0", "false", "no", "off"}
                    if enabled:
                        out.add(tag)
        except Exception:
            pass
    return out
def _is_hardblocked_corporate_tag(tag: str) -> bool:
    """True when the normalized tag trips any always-on safety pattern."""
    normalized = _norm_tag_for_lookup(str(tag))
    if not normalized:
        return False
    for pattern in _CORPORATE_HARDBLOCK_PATTERNS:
        if pattern.search(normalized):
            return True
    return False
def _is_excluded_recommendation_tag(tag: str) -> bool:
    """True when *tag* must not be recommended (hard-block or CSV lists)."""
    normalized = _norm_tag_for_lookup(str(tag))
    if not normalized:
        return False
    return (
        _is_hardblocked_corporate_tag(normalized)
        or normalized in _load_excluded_recommendation_tags()
    )
def _get_min_tag_count() -> int:
try:
return max(0, int(os.environ.get("PSQ_MIN_TAG_COUNT", "100")))
except Exception:
return 100
def _filter_min_count_tags(tags: List[str], min_count: int) -> Tuple[List[str], List[str]]:
    """Split *tags* into (kept, removed) by corpus-count threshold.

    Kept tags are normalized and deduped in first-seen order; removed tags
    are returned sorted and deduped. A non-positive *min_count* keeps
    everything (normalized/deduped) and removes nothing.
    """
    if min_count <= 0:
        return list(dict.fromkeys(_norm_tag_for_lookup(t) for t in (tags or []) if t)), []
    tag_counts = get_tag_counts()
    keep: List[str] = []
    removed: List[str] = []
    seen: Set[str] = set()
    for raw in (tags or []):
        t = _norm_tag_for_lookup(str(raw))
        if not t:
            continue
        # Unknown tags count as 0 and are therefore removed.
        c = int(tag_counts.get(t, 0) or 0)
        if c < min_count:
            removed.append(t)
            continue
        if t in seen:
            continue
        seen.add(t)
        keep.append(t)
    return keep, sorted(set(removed))
def _filter_excluded_recommendation_tags(tags: List[str]) -> Tuple[List[str], List[str]]:
    """Split *tags* into (kept, removed) using the CSV exclusion lists.

    Kept tags are normalized and deduped in first-seen order; removed tags
    are returned sorted and deduped. When no exclusions are configured,
    everything is kept (normalized/deduped).
    """
    excluded = _load_excluded_recommendation_tags()
    if not excluded:
        return list(dict.fromkeys(_norm_tag_for_lookup(t) for t in (tags or []) if t)), []
    keep: List[str] = []
    removed: List[str] = []
    seen: Set[str] = set()
    for raw in (tags or []):
        t = _norm_tag_for_lookup(str(raw))
        if not t:
            continue
        if t in excluded:
            removed.append(t)
            continue
        if t in seen:
            continue
        seen.add(t)
        keep.append(t)
    return keep, sorted(set(removed))
def _filter_excluded_candidates(candidates: List[Any]) -> Tuple[List[Any], List[str]]:
    """Drop retrieval candidates whose `.tag` attribute is excluded.

    Returns (kept candidates, sorted/deduped removed tag names). Candidates
    without a usable `.tag` attribute are kept as-is.
    """
    excluded = _load_excluded_recommendation_tags()
    if not excluded:
        return list(candidates or []), []
    keep: List[Any] = []
    removed: List[str] = []
    for c in (candidates or []):
        tag = _norm_tag_for_lookup(str(getattr(c, "tag", "") or ""))
        if tag and tag in excluded:
            removed.append(tag)
            continue
        keep.append(c)
    return keep, sorted(set(removed))
def _dedupe_norm_tags(tags: List[str]) -> List[str]:
    """Normalize tags and dedupe them, preserving first-seen order."""
    out: List[str] = []
    seen: Set[str] = set()
    for raw in (tags or []):
        t = _norm_tag_for_lookup(str(raw))
        if not t or t in seen:
            continue
        seen.add(t)
        out.append(t)
    return out
def _collect_visible_tags(row_defs: List[Dict[str, Any]]) -> Set[str]:
    """Union of normalized tags across all row definitions.

    Non-dict entries contribute nothing.
    """
    visible: Set[str] = set()
    for row in (row_defs or []):
        row_tags = row.get("tags", []) if isinstance(row, dict) else []
        visible.update(_dedupe_norm_tags(row_tags))
    return visible
def _collect_selected_from_state(
    selected_tags_state: List[str],
    row_defs: List[Dict[str, Any]],
) -> List[str]:
    """Project the persisted selected-tags state onto currently visible tags.

    Returns selected tags (deduped, state order) that map onto a visible
    tag either directly or via their normalized form; [] when no tags are
    visible at all.
    """
    visible_tags = _collect_visible_tags(row_defs)
    if not visible_tags:
        return []
    selected: List[str] = []
    seen: Set[str] = set()
    # Visible tags should already be normalized; the lookup map is a safety net.
    visible_by_norm = {_norm_tag_for_lookup(t): t for t in visible_tags}
    for raw in (selected_tags_state or []):
        t = _norm_tag_for_lookup(str(raw))
        if not t:
            continue
        mapped = t if t in visible_tags else visible_by_norm.get(t)
        if not mapped or mapped in seen:
            continue
        seen.add(mapped)
        selected.append(mapped)
    return selected
def _collect_selected_from_row_values(
    row_defs: List[Dict[str, Any]],
    row_values_state: List[List[str]],
) -> List[str]:
    """Collect selected tags from per-row checkbox values, in row order.

    Values are matched against each row's own tags, exactly first and then
    by normalized form; unmatched values are dropped. Result is deduped
    across rows in first-seen order.
    """
    selected: List[str] = []
    seen: Set[str] = set()
    values = list(row_values_state or [])
    for idx, row in enumerate(row_defs or []):
        row_tags = _dedupe_norm_tags(row.get("tags", []) if isinstance(row, dict) else [])
        if not row_tags:
            continue
        row_tag_set = set(row_tags)
        row_tag_by_norm = {_norm_tag_for_lookup(t): t for t in row_tags}
        # Rows beyond the stored values list contribute nothing.
        raw_vals = values[idx] if 0 <= idx < len(values) else []
        for raw in (raw_vals or []):
            if raw in row_tag_set:
                if raw not in seen:
                    seen.add(raw)
                    selected.append(raw)
                continue
            raw_norm = _norm_tag_for_lookup(str(raw))
            mapped = row_tag_by_norm.get(raw_norm)
            if mapped and mapped not in seen:
                seen.add(mapped)
                selected.append(mapped)
    return selected
def _build_toggle_rows(
    *,
    seed_terms: List[str],
    selected_tags: List[str],
    retrieved_candidate_tags: List[str],
    tag_selection_origins: Dict[str, str],
    implied_parent_map: Dict[str, str],
    top_groups: int,
    top_tags_per_group: int,
    group_rank_top_k: int,
) -> List[Dict[str, Any]]:
    """Build the ordered row definitions for the tag-toggle UI.

    Row layout: a "Selected (Other)" row first (selected tags not shown in
    any other displayed row), then TF-IDF-ranked group rows, then an
    "Other (Retrieved)" row at the bottom for retrieved tags that belong to
    no enabled group. Artist and excluded tags are filtered everywhere.
    Each row dict carries name/label/tags/tag_meta, where tag_meta maps a
    tag to its normalized origin and whether it is preselected.
    """
    ranked_rows = rank_groups_from_tfidf(
        seed_terms=seed_terms,
        top_groups=max(1, int(top_groups)),
        top_tags_per_group=max(1, int(top_tags_per_group)),
        group_rank_top_k=max(1, int(group_rank_top_k)),
    )
    groups_map = _load_enabled_groups()
    # Active selection: normalized, deduped, artist/excluded tags dropped.
    selected_active = list(
        dict.fromkeys(
            _norm_tag_for_lookup(t)
            for t in selected_tags
            if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
        )
    )
    selected_index: Dict[str, int] = {t: i for i, t in enumerate(selected_active)}
    row_defs: List[Dict[str, Any]] = []
    # Per-group tag sets (artist tags removed) and their overall union.
    enabled_group_tag_sets: Dict[str, Set[str]] = {
        name: {t for t in tags if not _is_artist_tag(t)}
        for name, tags in groups_map.items()
    }
    tags_in_any_enabled_group: Set[str] = set()
    for tag_set in enabled_group_tag_sets.values():
        tags_in_any_enabled_group.update(tag_set)
    # Restrict to the groups actually being displayed this round.
    displayed_group_names = [r.group_name for r in ranked_rows]
    displayed_group_tag_sets: Dict[str, Set[str]] = {
        name: enabled_group_tag_sets.get(name, set())
        for name in displayed_group_names
    }
    tags_in_any_displayed_group: Set[str] = set()
    for tag_set in displayed_group_tag_sets.values():
        tags_in_any_displayed_group.update(tag_set)
    # Retrieved tags that belong to no enabled group at all ("Other (Retrieved)").
    retrieved_uncategorized_ranked = list(
        dict.fromkeys(
            _norm_tag_for_lookup(t)
            for t in (retrieved_candidate_tags or [])
            if t
            and not _is_artist_tag(t)
            and not _is_excluded_recommendation_tag(t)
            and _norm_tag_for_lookup(t) not in tags_in_any_enabled_group
        )
    )
    retrieved_other_row: Dict[str, Any] | None = None
    if retrieved_uncategorized_ranked:
        retrieved_uncategorized_set = set(retrieved_uncategorized_ranked)
        selected_in_retrieved_other_raw = [
            t for t in selected_active if t in retrieved_uncategorized_set
        ]
        selected_in_retrieved_other = _order_selected_tags_for_row(
            row_selected_tags=selected_in_retrieved_other_raw,
            selected_index=selected_index,
            tag_selection_origins=tag_selection_origins,
            implied_parent_map=implied_parent_map,
        )
        # Selected tags lead; remaining retrieved tags follow in rank order.
        merged_retrieved_other = selected_in_retrieved_other + [
            t for t in retrieved_uncategorized_ranked if t not in selected_in_retrieved_other
        ]
        merged_retrieved_other = _dedupe_norm_tags(merged_retrieved_other)
        # Never truncate below the number of already-selected tags.
        keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_retrieved_other))
        merged_retrieved_other = merged_retrieved_other[:keep_n]
        retrieved_other_meta = {
            t: {
                "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
                "preselected": t in selected_active,
            }
            for t in merged_retrieved_other
        }
        retrieved_other_row = {
            "name": "other_retrieved",
            "label": "Other (Retrieved)",
            "tags": merged_retrieved_other,
            "tag_meta": retrieved_other_meta,
        }
    # "Selected (Other)" should contain selected tags not already shown in any displayed row.
    # Include "Other (Retrieved)" in that displayed-row set to avoid duplicates across those rows.
    tags_in_displayed_rows = set(tags_in_any_displayed_group)
    if retrieved_other_row:
        tags_in_displayed_rows.update(retrieved_other_row.get("tags", []))
    selected_other_raw = [t for t in selected_active if t not in tags_in_displayed_rows]
    selected_other = _order_selected_tags_for_row(
        row_selected_tags=selected_other_raw,
        selected_index=selected_index,
        tag_selection_origins=tag_selection_origins,
        implied_parent_map=implied_parent_map,
    )
    selected_other = _dedupe_norm_tags(selected_other)
    selected_other_meta = {
        t: {
            "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
            "preselected": True,
        }
        for t in selected_other
    }
    row_defs.append(
        {
            "name": "selected_other",
            "label": _display_row_label("selected_other"),
            "tags": selected_other,
            "tag_meta": selected_other_meta,
        }
    )
    # One row per ranked group: selected tags first, then ranked suggestions.
    for row in ranked_rows:
        group_name = row.group_name
        group_tag_set = displayed_group_tag_sets.get(group_name, set())
        selected_in_group_raw = [t for t in selected_active if t in group_tag_set]
        selected_in_group = _order_selected_tags_for_row(
            row_selected_tags=selected_in_group_raw,
            selected_index=selected_index,
            tag_selection_origins=tag_selection_origins,
            implied_parent_map=implied_parent_map,
        )
        ranked_tags = [
            _norm_tag_for_lookup(t)
            for t, _ in row.tags
            if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
        ]
        ranked_tags = _dedupe_norm_tags(ranked_tags)
        merged = selected_in_group + [t for t in ranked_tags if t not in selected_in_group]
        merged = _dedupe_norm_tags(merged)
        # Never truncate below the number of already-selected tags.
        keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_group))
        merged = merged[:keep_n]
        tag_meta = {
            t: {
                "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
                "preselected": t in selected_active,
            }
            for t in merged
        }
        row_defs.append(
            {
                "name": group_name,
                "label": _display_row_label(group_name),
                "tags": merged,
                "tag_meta": tag_meta,
            }
        )
    # Keep this row at the bottom so category/group rows remain contiguous.
    if retrieved_other_row:
        row_defs.append(retrieved_other_row)
    return row_defs
def _build_display_audit_line(
    row_defs: List[Dict[str, Any]],
    *,
    active_selected_tags: List[str],
    direct_selected_tags: List[str],
    implied_selected_tags: List[str],
) -> str:
    """Build a one-line JSON audit of every displayed tag.

    For each tag visible in *row_defs*, records which row labels show it and
    a set of source markers (row kind plus whether it is active/direct/
    implied-selected). Intended for logging/debugging, not the UI.
    """
    active_set = {
        _norm_tag_for_lookup(t)
        for t in (active_selected_tags or [])
        if t and not _is_artist_tag(t)
    }
    direct_set = {
        _norm_tag_for_lookup(t)
        for t in (direct_selected_tags or [])
        if t and not _is_artist_tag(t)
    }
    implied_set = {
        _norm_tag_for_lookup(t)
        for t in (implied_selected_tags or [])
        if t and not _is_artist_tag(t)
    }
    info_by_tag: Dict[str, Dict[str, Any]] = {}
    for row in row_defs or []:
        row_name = row.get("name", "")
        row_label = row.get("label", row_name)
        for tag in row.get("tags", []):
            # A tag may appear in several rows; accumulate all of them.
            rec = info_by_tag.setdefault(tag, {"rows": [], "sources": set()})
            rec["rows"].append(row_label)
            if row_name == "selected_other":
                rec["sources"].add("selected_other_row")
            elif row_name == "other_retrieved":
                rec["sources"].add("other_retrieved_row")
            else:
                rec["sources"].add("ranked_group_row")
            if tag in active_set:
                rec["sources"].add("selected_active")
            if tag in direct_set:
                rec["sources"].add("selected_direct")
            if tag in implied_set:
                rec["sources"].add("selected_implied")
    payload = {
        "n_tags": len(info_by_tag),
        "tags": [
            {
                "tag": tag,
                "rows": rec["rows"],
                "sources": sorted(rec["sources"]),
            }
            for tag, rec in sorted(info_by_tag.items())
        ],
    }
    return "Display Tag Audit: " + json.dumps(payload, ensure_ascii=True)
def _build_tooltip_payload(row_defs: List[Dict[str, Any]], max_rows: int) -> str:
    """JSON payload of visible row tags plus per-tag tooltip text."""
    visible_rows = (row_defs or [])[: max(0, int(max_rows))]
    tips: Dict[str, str] = {}
    rows: List[List[str]] = []
    for row in visible_rows:
        row_tags = _dedupe_norm_tags(row.get("tags", []) if isinstance(row, dict) else [])
        rows.append(row_tags)
        for tag in row_tags:
            # Compute each tooltip once even when a tag spans multiple rows.
            if tag not in tips:
                tips[tag] = _tooltip_text_for_tag(tag)
    return json.dumps({"rows": rows, "tips": tips}, ensure_ascii=True)
def _build_row_component_updates(
    row_defs: List[Dict[str, Any]],
    selected_tags: List[str],
    max_rows: int,
):
    """Build Gradio updates for the fixed pool of row header/checkbox widgets.

    Returns (prompt_text, row_values_state, header_updates, checkbox_updates).
    The two update lists always contain exactly *max_rows* entries; widgets
    beyond the visible rows are hidden and emptied.
    """
    selected = {t for t in (selected_tags or []) if t}
    row_defs_ui = (row_defs or [])[: max(0, int(max_rows))]
    row_values_state: List[List[str]] = []
    header_updates = []
    checkbox_updates = []
    for idx in range(max_rows):
        if idx < len(row_defs_ui):
            row = row_defs_ui[idx]
            tags = _dedupe_norm_tags(row.get("tags", []))
            values = [t for t in tags if t in selected]
            row_values_state.append(values)
            # A row with no tags is kept in state but hidden in the UI.
            visible = bool(tags)
            header_updates.append(gr.update(value=row.get("label", ""), visible=visible))
            tag_meta = row.get("tag_meta", {}) if isinstance(row.get("tag_meta", {}), dict) else {}
            choices = []
            for t in tags:
                meta = tag_meta.get(t, {}) if isinstance(tag_meta.get(t, {}), dict) else {}
                origin = _normalize_selection_origin(str(meta.get("origin", "selection")))
                preselected = bool(meta.get("preselected", False))
                # (label, value) pairs; label derived from origin/preselected meta.
                choices.append((_choice_label_with_source_meta(t, origin=origin, preselected=preselected), t))
            checkbox_updates.append(
                gr.update(
                    choices=choices,
                    value=values,
                    visible=visible,
                )
            )
        else:
            # Surplus widget slots: hide and clear.
            header_updates.append(gr.update(value="", visible=False))
            checkbox_updates.append(gr.update(choices=[], value=[], visible=False))
    prompt_text = _compose_toggle_prompt_text(list(selected), row_defs_ui)
    return prompt_text, row_values_state, header_updates, checkbox_updates
def _on_toggle_row(
    row_idx: int,
    changed_values: List[str],
    selected_tags_state: List[str],
    rows_dirty_state: bool,
    row_defs_state: List[Dict[str, Any]],
    row_values_state: List[List[str]],
    max_rows: int,
):
    """Handle a checkbox change on row *row_idx*.

    Reconciles the event payload against current state, applies the toggle,
    and returns [selected_tags, dirty_flag, apply-button update,
    row_values_state, prompt_text, *per-row checkbox updates]. No-op events
    (programmatic re-sets) return gr.skip() updates and leave state as-is.
    """
    row_defs = row_defs_state or []
    row_defs_ui = row_defs[: max(0, int(max_rows))]
    prev_values = list(row_values_state or [])
    selected_from_state = _collect_selected_from_state(selected_tags_state, row_defs_ui)
    selected_from_rows = _collect_selected_from_row_values(row_defs_ui, prev_values)
    # Prefer row-value state as source-of-truth (closest to visible UI), with selected-state as fallback.
    selected: Set[str] = set(selected_from_rows or selected_from_state)
    row = row_defs_ui[row_idx] if 0 <= row_idx < len(row_defs_ui) else {}
    row_tags = _dedupe_norm_tags(row.get("tags", []))
    row_label = str(row.get("label", ""))  # currently unused; kept for debugging
    row_tag_set = set(row_tags)
    row_tag_by_norm = {_norm_tag_for_lookup(t): t for t in row_tags}
    # Be tolerant to UI payload forms: canonical tag values, display labels, normalized variants,
    # and occasional single-string payloads from frontend events.
    if changed_values is None:
        changed_iter: List[Any] = []
    elif isinstance(changed_values, str):
        changed_iter = [changed_values]
    elif isinstance(changed_values, (list, tuple, set)):
        changed_iter = list(changed_values)
    else:
        changed_iter = [changed_values]
    # Map each payload entry back onto this row's canonical tags.
    new_set: Set[str] = set()
    for raw in changed_iter:
        if raw in row_tag_set:
            new_set.add(raw)
            continue
        raw_norm = _norm_tag_for_lookup(str(raw))
        mapped = row_tag_by_norm.get(raw_norm)
        if mapped:
            new_set.add(mapped)
    prev_row_selected = {t for t in row_tags if t in selected}
    # Ignore non-user/no-op events (e.g. programmatic value re-sets) deterministically.
    if new_set == prev_row_selected:
        prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
        checkbox_updates = [gr.skip() for _ in range(max_rows)]
        return [sorted(selected), rows_dirty_state, gr.skip(), prev_values, prompt_text, *checkbox_updates]
    # Replace this row's contribution to the global selection.
    selected.difference_update(row_tag_set)
    selected.update(new_set)
    toggled_tags = prev_row_selected ^ new_set
    new_row_values_state: List[List[str]] = []
    affected_rows: Set[int] = {row_idx}
    for idx, row_item in enumerate(row_defs_ui):
        tags = _dedupe_norm_tags(row_item.get("tags", []))
        values = [t for t in tags if t in selected]
        new_row_values_state.append(values)
        # A tag can appear in several rows; refresh every row it touches.
        if toggled_tags and any(t in toggled_tags for t in tags):
            affected_rows.add(idx)
    checkbox_updates = []
    for idx in range(max_rows):
        if idx >= len(row_defs_ui):
            checkbox_updates.append(gr.skip())
            continue
        if idx in affected_rows:
            checkbox_updates.append(gr.update(value=new_row_values_state[idx]))
        else:
            checkbox_updates.append(gr.skip())
    prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
    return [
        sorted(selected),
        True,
        gr.update(visible=True, interactive=True),
        new_row_values_state,
        prompt_text,
        *checkbox_updates,
    ]
def _build_ui_payload(
    *,
    console_text: str,
    row_defs: List[Dict[str, Any]],
    selected_tags: List[str],
    suggested_prompt_text: str | None = None,
):
    """Assemble the full output list for a run/rebuild event.

    NOTE(review): relies on module-level `display_max_rows_default`, which is
    defined elsewhere in this file (not visible in this chunk) — confirm it
    is set before any event handler fires.
    """
    prompt_text, row_values_state, header_updates, checkbox_updates = _build_row_component_updates(
        row_defs=row_defs,
        selected_tags=selected_tags,
        max_rows=display_max_rows_default,
    )
    if suggested_prompt_text is not None:
        # Caller-provided prompt overrides the composed one.
        prompt_text = str(suggested_prompt_text)
    # Flatten the per-row values back into a deduped selected-tags list.
    selected_ui: List[str] = []
    selected_ui_seen: Set[str] = set()
    for vals in row_values_state:
        for t in vals:
            if t in selected_ui_seen:
                continue
            selected_ui_seen.add(t)
            selected_ui.append(t)
    tooltip_payload = _build_tooltip_payload(row_defs, display_max_rows_default)
    return [
        console_text,
        gr.update(visible=bool(row_defs)),
        tooltip_payload,
        prompt_text,
        selected_ui,
        False,
        gr.update(visible=False, interactive=False),
        row_defs,
        row_values_state,
        *header_updates,
        *checkbox_updates,
    ]
def _format_user_facing_error(exc: Exception) -> str:
msg = str(exc or "").strip()
msg_l = msg.lower()
if "rewrite: empty output" in msg_l:
return (
"Could not rewrite that prompt. Try simpler, neutral wording and remove sensitive phrasing, "
"then click Run again."
)
if "openrouter_api_key" in msg_l:
return "Service configuration is missing. Please contact the app owner."
if "timed out" in msg_l:
return "The model request timed out. Please try again with a shorter or simpler prompt."
if "index selection failed" in msg_l:
return "Tag selection failed for this request. Please try again."
if "startup preflight failed" in msg_l:
return "App startup checks failed. Please contact the app owner."
return "Something went wrong while processing the prompt. Please try again."
def _prepare_run_ui() -> List[Any]:
    """Reset the UI into the 'Running...' state before a pipeline run.

    Hides/clears every row widget and returns placeholder values in the same
    output order as _build_ui_payload. NOTE(review): relies on module-level
    `display_max_rows_default`, defined elsewhere in this file.
    """
    header_updates = [gr.update(value="", visible=False) for _ in range(display_max_rows_default)]
    checkbox_updates = [
        gr.update(choices=[], value=[], visible=False)
        for _ in range(display_max_rows_default)
    ]
    return [
        "Running...",
        gr.skip(),
        "{}",  # empty tooltip payload
        "Running... usually completes in about 20 seconds.",
        [],
        False,
        gr.update(visible=False, interactive=False),
        [],
        [],
        *header_updates,
        *checkbox_updates,
    ]
def _update_run_button_visibility(prompt_text: str, last_run_prompt: str):
    """Show/enable Run only for a non-empty prompt that differs from the last run."""
    current = (prompt_text or "").strip()
    previous = (last_run_prompt or "").strip()
    runnable = bool(current) and current != previous
    return gr.update(visible=runnable, interactive=runnable)
def _mark_run_triggered(prompt_text: str):
    """Hide the Run button and record the prompt text that was just submitted."""
    just_ran = (prompt_text or "").strip()
    return gr.update(visible=False, interactive=False), just_ran
def _rebuild_rows_from_selected(
    selected_tags_state: List[str],
    row_defs_state: List[Dict[str, Any]],
    row_values_state: List[List[str]],
    display_top_groups: float,
    display_top_tags_per_group: float,
    display_rank_top_k: float,
):
    """Re-rank and rebuild the toggle rows from the current selection.

    Uses the currently selected tags as new seed terms, carries forward
    per-tag origins from the existing row metadata, and returns the same
    output shape as _build_ui_payload (minus console text).
    """
    existing_rows = row_defs_state or []
    existing_values = list(row_values_state or [])
    selected_from_state = _collect_selected_from_state(selected_tags_state, existing_rows)
    selected_from_rows = _collect_selected_from_row_values(existing_rows, existing_values)
    # Rebuild source-of-truth is current row checkbox values; fall back only when unavailable.
    selected_seed = selected_from_rows if existing_values else selected_from_state
    selected_active = list(
        dict.fromkeys(
            _norm_tag_for_lookup(t)
            for t in selected_seed
            if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
        )
    )
    # Recover candidate pool and origin metadata from the previous rows.
    retrieved_candidate_tags: List[str] = []
    tag_selection_origins: Dict[str, str] = {}
    for row in existing_rows:
        row_tags = row.get("tags", []) if isinstance(row, dict) else []
        row_meta = row.get("tag_meta", {}) if isinstance(row, dict) else {}
        if not isinstance(row_meta, dict):
            row_meta = {}
        for t in row_tags:
            tn = _norm_tag_for_lookup(t)
            if not tn or _is_artist_tag(tn) or _is_excluded_recommendation_tag(tn):
                continue
            retrieved_candidate_tags.append(tn)
            if tn not in tag_selection_origins:
                meta = row_meta.get(t, {}) if isinstance(row_meta.get(t, {}), dict) else {}
                tag_selection_origins[tn] = _normalize_selection_origin(str(meta.get("origin", "selection")))
    for t in selected_active:
        tag_selection_origins.setdefault(t, "user")
        retrieved_candidate_tags.append(t)
    # NOTE(review): _normalize_selection_origin never yields "implied", so this
    # list is expected to be empty on this path — confirm intent.
    implied_selected_tags = [t for t in selected_active if tag_selection_origins.get(t) == "implied"]
    implied_set = set(implied_selected_tags)
    direct_selected_tags = [t for t in selected_active if t not in implied_set]
    direct_idx = {t: i for i, t in enumerate(direct_selected_tags)}
    direct_selected_tags.sort(
        key=lambda t: (
            _selection_source_rank(tag_selection_origins.get(t, "selection")),
            direct_idx.get(t, 10**9),
        )
    )
    implied_parent_map = _build_implied_parent_map(
        direct_tags_ordered=direct_selected_tags,
        implied_tags=implied_selected_tags,
    )
    toggle_rows = _build_toggle_rows(
        seed_terms=list(selected_active),
        selected_tags=selected_active,
        retrieved_candidate_tags=list(dict.fromkeys(retrieved_candidate_tags)),
        tag_selection_origins=tag_selection_origins,
        implied_parent_map=implied_parent_map,
        top_groups=max(1, int(display_top_groups)),
        top_tags_per_group=max(1, int(display_top_tags_per_group)),
        group_rank_top_k=max(1, int(display_rank_top_k)),
    )
    prompt_text, row_values_state, header_updates, checkbox_updates = _build_row_component_updates(
        row_defs=toggle_rows,
        selected_tags=selected_active,
        max_rows=display_max_rows_default,
    )
    tooltip_payload = _build_tooltip_payload(toggle_rows, display_max_rows_default)
    return [
        gr.update(visible=bool(toggle_rows)),
        tooltip_payload,
        prompt_text,
        sorted(selected_active),
        False,
        gr.update(visible=False, interactive=False),
        toggle_rows,
        row_values_state,
        *header_updates,
        *checkbox_updates,
    ]
def _build_selection_query(
prompt_in: str,
rewritten: str,
structural_tags: List[str],
probe_tags: List[str],
) -> str:
lines = [f"IMAGE DESCRIPTION: {prompt_in.strip()}"]
if rewritten and rewritten.strip():
lines.append(f"REWRITE PHRASES: {rewritten.strip()}")
hint_tags = []
if structural_tags:
hint_tags.extend(structural_tags)
if probe_tags:
hint_tags.extend(probe_tags)
if hint_tags:
# Keep hints as context only; selection still must choose by candidate indices.
lines.append(
"INFERRED TAG HINTS (context only): " + ", ".join(sorted(set(hint_tags)))
)
return "\n".join(lines)
# Set up logging.
# Minimal prod logging: warnings+ to stderr, no log file by default.
# (The redundant mid-file `import os, logging` was removed: both modules are
# already imported at the top of this file.)
LOG_LEVEL = os.environ.get("PSQ_LOG_LEVEL", "WARNING").upper()
logging.basicConfig(
    level=getattr(logging, LOG_LEVEL, logging.WARNING),
    format="%(asctime)s %(levelname)s:%(message)s",
    handlers=[logging.StreamHandler()],  # no file -> avoids huge logs on Spaces
)
# Quiet down common noisy libs (optional).
for _name in ("gensim", "gradio", "hnswlib", "httpx", "uvicorn"):
    logging.getLogger(_name).setLevel(logging.ERROR)
# Turn off Gradio analytics phone-home to avoid those background thread errors (optional).
os.environ["GRADIO_ANALYTICS_ENABLED"] = "0"
# Optional mascot artwork shipped alongside the app.
MASCOT_DIR = Path(__file__).parent / "mascotimages"
MASCOT_FILE = MASCOT_DIR / "transparentsquirrel.png"

def _load_mascot_image():
    """Load mascot image if available; return None when missing/unreadable."""
    if not MASCOT_FILE.exists():
        logging.warning("Mascot image missing: %s", MASCOT_FILE)
        return None
    try:
        # RGBA keeps the transparent background intact.
        return Image.open(MASCOT_FILE).convert("RGBA")
    except Exception as e:
        logging.warning("Failed to load mascot image (%s): %s", MASCOT_FILE, e)
        return None
# Hotfix: gradio_client's JSON-schema -> python-type helpers can crash on
# boolean-form schemas (a bare True/False instead of a dict). Wrap the three
# entry points so non-dict schemas map to "any". Best-effort: if patching
# fails for any reason, log and continue unpatched.
try:
    from gradio_client import utils as _gc_utils

    # Keep references to the originals so the safe wrappers can delegate.
    _orig_get_type = _gc_utils.get_type
    _orig_j2p = _gc_utils._json_schema_to_python_type
    _orig_pub = _gc_utils.json_schema_to_python_type

    def _get_type_safe(schema):
        # Sometimes schema is a bare True/False (JSON Schema boolean form)
        if not isinstance(schema, dict):
            return "any"
        return _orig_get_type(schema)

    def _j2p_safe(schema, defs=None):
        # Accept non-dict schemas (True/False/None) and treat as "any"
        if not isinstance(schema, dict):
            return "any"
        return _orig_j2p(schema, defs or schema.get("$defs"))

    def _pub_safe(schema):
        # Public wrapper used by Gradio; keep it resilient too
        if not isinstance(schema, dict):
            return "any"
        return _j2p_safe(schema, schema.get("$defs"))

    # Install the wrappers in place of the originals.
    _gc_utils.get_type = _get_type_safe
    _gc_utils._json_schema_to_python_type = _j2p_safe
    _gc_utils.json_schema_to_python_type = _pub_safe
except Exception as e:
    print("gradio_client hotfix not applied:", e)
# -------------------------------------------------------------------------------
allow_nsfw_tags = False
def _is_production_runtime() -> bool:
"""Best-effort detection for deployed runtime (HF Spaces or explicit env)."""
if os.environ.get("PSQ_PRODUCTION", "").strip().lower() in {"1", "true", "yes"}:
return True
if os.environ.get("SPACE_ID"):
return True
if os.environ.get("HF_SPACE_ID"):
return True
if os.environ.get("SYSTEM") == "spaces":
return True
return False
# --- Runtime configuration (all overridable via environment variables) ---
# Verbose retrieval logging defaults off in production, on locally.
verbose_retrieval_default = "0" if _is_production_runtime() else "1"
verbose_retrieval = os.environ.get("PSQ_VERBOSE_RETRIEVAL", verbose_retrieval_default).strip().lower() in {"1", "true", "yes"}
verbose_retrieval_all = False  # when True, log every candidate row instead of a capped preview
verbose_retrieval_limit = 20  # max candidate rows logged per phrase when not logging all
# Probe-tag inference (third concurrent stage-1 LLM call) can be disabled.
enable_probe_tags = os.environ.get("PSQ_ENABLE_PROBE", "1").strip() not in {"0", "false", "False"}
# UI display knobs: group/tag row counts.
display_top_groups_default = int(os.environ.get("PSQ_DISPLAY_TOP_GROUPS", "10"))
display_top_tags_per_group_default = int(os.environ.get("PSQ_DISPLAY_TOP_TAGS_PER_GROUP", "7"))
display_rank_top_k_default = int(os.environ.get("PSQ_DISPLAY_GROUP_RANK_TOP_K", "7"))
display_max_rows_default = int(os.environ.get("PSQ_DISPLAY_MAX_ROWS", "14"))
# Retrieval (stage 2) candidate-pool sizes.
retrieval_global_k = int(os.environ.get("PSQ_RETRIEVAL_GLOBAL_K", "300"))
retrieval_per_phrase_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_K", "10"))
retrieval_per_phrase_final_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_FINAL_K", "1"))
# Selection (stage 3) strategy and limits; cap of 0 means "no cap".
selection_mode = os.environ.get("PSQ_SELECTION_MODE", "chunked_map_union").strip()
selection_chunk_size = int(os.environ.get("PSQ_SELECTION_CHUNK_SIZE", "60"))
selection_per_phrase_k = int(os.environ.get("PSQ_SELECTION_PER_PHRASE_K", "2"))
selection_candidate_cap = int(os.environ.get("PSQ_SELECTION_CANDIDATE_CAP", "0"))
# Per-stage LLM timeouts (seconds); stage 3 gets a shorter fast-retry budget.
stage1_rewrite_timeout_s = float(os.environ.get("PSQ_TIMEOUT_REWRITE_S", "45"))
stage1_struct_timeout_s = float(os.environ.get("PSQ_TIMEOUT_STRUCT_S", "45"))
stage1_probe_timeout_s = float(os.environ.get("PSQ_TIMEOUT_PROBE_S", "45"))
stage3_select_timeout_s = float(os.environ.get("PSQ_TIMEOUT_SELECT_S", "50"))
stage3_select_retry_timeout_s = float(os.environ.get("PSQ_TIMEOUT_SELECT_RETRY_S", "20"))
stage3_fast_retry_count = max(0, int(os.environ.get("PSQ_STAGE3_FAST_RETRY_COUNT", "1")))
# JSONL file where per-run stage timings are appended (best effort).
timing_log_path = Path(os.environ.get("PSQ_TIMING_LOG_PATH", "data/runtime_metrics/ui_pipeline_timings.jsonl"))
def _startup_preflight_errors() -> List[str]:
errs: List[str] = []
if not os.getenv("OPENROUTER_API_KEY"):
errs.append("OPENROUTER_API_KEY is missing. Set it in Space Secrets or environment variables.")
return errs
# Evaluate preflight once at import time; errors are logged here and surfaced
# again per-request inside rag_pipeline_ui.
STARTUP_PREFLIGHT_ERRORS = _startup_preflight_errors()
if STARTUP_PREFLIGHT_ERRORS:
    for _err in STARTUP_PREFLIGHT_ERRORS:
        logging.error("Startup preflight error: %s", _err)
css = """
.scrollable-content{
max-height: 420px;
overflow-y: scroll; /* always show scrollbar */
overflow-x: hidden;
padding-right: 8px;
padding-bottom: 14px; /* <— add this */
scrollbar-gutter: stable; /* prevent layout shift as it fills */
/* Firefox */
scrollbar-width: auto;
scrollbar-color: rgba(180,180,180,.9) rgba(0,0,0,.15);
}
/* WebKit/Chromium (Chrome/Edge/Safari) */
.scrollable-content::-webkit-scrollbar{ width: 10px; }
.scrollable-content::-webkit-scrollbar-thumb{ background: rgba(180,180,180,.9); border-radius: 8px; }
.scrollable-content::-webkit-scrollbar-track{ background: rgba(0,0,0,.15); }
/* (Optional) make both scroll panes taller so they fill more of the column */
.pane-left .scrollable-content,
.pane-right .scrollable-content {
max-height: 610px; /* was 420px; tweak to taste */
}
.lego-tags .gr-checkboxgroup,
.lego-tags .wrap {
display: flex !important;
flex-wrap: wrap !important;
gap: 10px !important;
}
.lego-tags label {
margin: 0 !important;
padding: 0 !important;
position: relative !important;
}
/* Hide native checkbox visuals completely */
.lego-tags input[type="checkbox"] {
appearance: none !important;
-webkit-appearance: none !important;
-moz-appearance: none !important;
position: absolute !important;
width: 1px !important;
height: 1px !important;
opacity: 0 !important;
pointer-events: none !important;
display: none !important;
}
/* Brick button skin (works for both +span and ~span structures) */
.lego-tags input[type="checkbox"] + span,
.lego-tags input[type="checkbox"] ~ span {
--on-bg1: #ffd166;
--on-bg2: #f39c4a;
--on-border: #b86e21;
--on-text: #2e1706;
position: relative !important;
display: inline-flex !important;
align-items: center !important;
min-height: 40px !important;
padding: 10px 15px 9px 22px !important;
border: 1px solid #9aa6b8 !important;
border-radius: 10px !important;
background: linear-gradient(180deg, #dfe5ee 0%, #bec8d6 100%) !important;
color: #364254 !important;
font-size: 0.97rem !important;
font-weight: 800 !important;
line-height: 1.15 !important;
cursor: pointer !important;
user-select: none !important;
letter-spacing: 0.01em !important;
box-shadow: 0 3px 0 rgba(0,0,0,0.16), inset 0 1px 0 rgba(255,255,255,0.55) !important;
transition: transform 0.08s ease, box-shadow 0.08s ease, filter 0.08s ease !important;
}
.lego-tags input[type="checkbox"] + span::before,
.lego-tags input[type="checkbox"] ~ span::before {
content: "" !important;
position: absolute !important;
top: 5px !important;
left: 8px !important;
width: 8px !important;
height: 8px !important;
border-radius: 50% !important;
background: rgba(255,255,255,0.58) !important;
box-shadow: 22px 0 0 rgba(255,255,255,0.58) !important;
pointer-events: none !important;
}
/* Unselected cue: show "+" on the left. */
.lego-tags input[type="checkbox"] + span::after,
.lego-tags input[type="checkbox"] ~ span::after {
content: "+" !important;
position: absolute !important;
left: 6px !important;
top: 50% !important;
transform: translateY(-52%) !important;
font-size: 1rem !important;
font-weight: 900 !important;
color: #4b5563 !important;
opacity: 0.95 !important;
pointer-events: none !important;
}
/* Bright color cycle used only when selected */
.lego-tags label:nth-child(8n+1) span { --on-bg1: #ffd166; --on-bg2: #f39c4a; --on-border: #b86e21; --on-text: #2e1706; }
.lego-tags label:nth-child(8n+2) span { --on-bg1: #6ee7ff; --on-bg2: #1fb7ff; --on-border: #157cb3; --on-text: #07263c; }
.lego-tags label:nth-child(8n+3) span { --on-bg1: #9dff8f; --on-bg2: #45c96f; --on-border: #2a8b4b; --on-text: #0d2917; }
.lego-tags label:nth-child(8n+4) span { --on-bg1: #ff8fab; --on-bg2: #ff5c7a; --on-border: #b83956; --on-text: #3f0f1d; }
.lego-tags label:nth-child(8n+5) span { --on-bg1: #d0a8ff; --on-bg2: #a46cff; --on-border: #7147b3; --on-text: #25143f; }
.lego-tags label:nth-child(8n+6) span { --on-bg1: #ffe27a; --on-bg2: #f7bf39; --on-border: #ad7f1f; --on-text: #332407; }
.lego-tags label:nth-child(8n+7) span { --on-bg1: #8effd5; --on-bg2: #2ed6b5; --on-border: #1e947d; --on-text: #0d2a25; }
.lego-tags label:nth-child(8n+8) span { --on-bg1: #ffb47e; --on-bg2: #ff8753; --on-border: #b95b2d; --on-text: #391a0a; }
/* Source-driven selected colors (applies when tags are preselected by the pipeline). */
.lego-tags label[data-psq-preselected="1"][data-psq-origin="rewrite"] span {
--on-bg1: #77f0d7;
--on-bg2: #26b9a3;
--on-border: #187869;
--on-text: #062923;
}
.lego-tags label[data-psq-preselected="1"][data-psq-origin="selection"] span {
--on-bg1: #ffd98a;
--on-bg2: #f0a93c;
--on-border: #a66f1f;
--on-text: #382206;
}
.lego-tags label[data-psq-preselected="1"][data-psq-origin="probe"] span {
--on-bg1: #d8b4ff;
--on-bg2: #9a6cff;
--on-border: #6745b0;
--on-text: #24143b;
}
.lego-tags label[data-psq-preselected="1"][data-psq-origin="structural"] span {
--on-bg1: #a6f79a;
--on-bg2: #53c368;
--on-border: #2f8442;
--on-text: #102d17;
}
.lego-tags label[data-psq-preselected="1"][data-psq-origin="implied"] span {
--on-bg1: #d7dde8;
--on-bg2: #a8b3c4;
--on-border: #6f7e95;
--on-text: #1d2633;
}
/* User-selected tags (not initially selected by the pipeline). */
.lego-tags label[data-psq-preselected="0"] span {
--on-bg1: #9ec5ff;
--on-bg2: #4f86ff;
--on-border: #2f5fbf;
--on-text: #0b1f42;
}
.lego-tags label:hover span {
filter: brightness(1.02) !important;
transform: translateY(1px) !important;
}
/* ON state: brighter + visibly recessed */
.lego-tags input[type="checkbox"]:checked + span,
.lego-tags input[type="checkbox"]:checked ~ span,
.lego-tags label:has(input[type="checkbox"]:checked) span {
background: linear-gradient(180deg, var(--on-bg1) 0%, var(--on-bg2) 100%) !important;
color: var(--on-text) !important;
border-color: var(--on-border) !important;
filter: saturate(1.2) brightness(1.12) !important;
transform: translateY(-2px) !important;
box-shadow:
inset 0 3px 6px rgba(0,0,0,0.20),
inset 0 -1px 0 rgba(255,255,255,0.36),
0 6px 0 rgba(0,0,0,0.32) !important;
}
.lego-tags input[type="checkbox"]:checked + span::after,
.lego-tags input[type="checkbox"]:checked ~ span::after,
.lego-tags label:has(input[type="checkbox"]:checked) span::after {
content: "" !important;
}
.source-legend {
display: flex;
flex-wrap: wrap;
align-items: center;
gap: 8px;
margin: 4px 0 10px 0;
}
.source-legend .legend-title {
font-size: 0.92rem;
font-weight: 900;
color: #334155;
margin-right: 4px;
}
.source-legend .chip {
display: inline-flex;
align-items: center;
border-radius: 10px;
border: 1px solid #6c7788;
padding: 6px 12px;
font-size: 0.85rem;
font-weight: 800;
color: #111827;
background: #f3f6fb;
}
.source-legend .chip.rewrite { background: #26b9a3; color: #062923; border-color: #187869; }
.source-legend .chip.selection { background: #f0a93c; color: #382206; border-color: #a66f1f; }
.source-legend .chip.probe { background: #9a6cff; color: #ffffff; border-color: #6745b0; }
.source-legend .chip.structural { background: #53c368; color: #102d17; border-color: #2f8442; }
.source-legend .chip.implied { background: #a8b3c4; color: #1d2633; border-color: #6f7e95; }
.source-legend .chip.user { background: #4f86ff; color: #ffffff; border-color: #2f5fbf; }
.source-legend .chip.unselected { background: #c7ced8; color: #2d3440; border-color: #7d8897; }
.row-heading p {
margin: 8px 0 0 0 !important;
font-size: 1.18rem !important;
font-weight: 850 !important;
line-height: 1.2 !important;
}
.row-instruction {
text-align: center;
margin: 8px 0 12px 0;
}
.row-instruction p {
margin: 0 !important;
font-size: 1.02rem !important;
font-style: italic !important;
font-weight: 800 !important;
color: #1d4ed8 !important;
}
.about-docs {
margin-top: 4px;
}
.about-docs > p {
line-height: 1.42 !important;
}
.about-docs img {
max-width: 100% !important;
height: auto !important;
border: 1px solid #d2d7e0;
border-radius: 10px;
background: #ffffff;
}
.arch-diagram-wrap {
margin: 6px 0 10px 0;
}
.arch-diagram-wrap h2 {
margin: 0 0 8px 0 !important;
}
.top-instruction {
text-align: center;
margin: 2px 0 6px 0;
}
.top-instruction p {
margin: 0 !important;
font-size: 1.02rem !important;
font-style: italic !important;
font-weight: 800 !important;
color: #1d4ed8 !important;
}
.run-hint {
margin-top: 6px;
text-align: center;
}
.run-hint p {
margin: 0 !important;
font-size: 0.9rem !important;
font-style: italic !important;
color: #475569 !important;
}
.prompt-card {
background: transparent !important;
border: none !important;
box-shadow: none !important;
padding: 0 !important;
}
.suggested-prompt-box {
margin-top: 2px !important;
}
.suggested-prompt-card {
margin-top: 10px !important;
}
.psq-hidden {
display: none !important;
}
"""
client_js = """
() => {
const readTooltipMap = () => {
const el = document.querySelector("#psq-tooltip-map textarea, #psq-tooltip-map input");
if (!el) return { rows: [], tips: {} };
const raw = (el.value || "").trim();
if (!raw) return { rows: [], tips: {} };
try {
const obj = JSON.parse(raw);
if (!obj || typeof obj !== "object") return { rows: [], tips: {} };
const rows = Array.isArray(obj.rows) ? obj.rows : [];
const tips = (obj.tips && typeof obj.tips === "object") ? obj.tips : {};
return { rows, tips };
} catch (_) {
return { rows: [], tips: {} };
}
};
const applyTooltips = () => {
const payload = readTooltipMap();
const rowTags = Array.isArray(payload.rows) ? payload.rows : [];
const tipMap = (payload.tips && typeof payload.tips === "object") ? payload.tips : {};
const rowEls = document.querySelectorAll(".lego-tags");
rowEls.forEach((rowEl, rowIdx) => {
const tags = Array.isArray(rowTags[rowIdx]) ? rowTags[rowIdx] : [];
const labels = rowEl.querySelectorAll("label");
labels.forEach((label, tagIdx) => {
const span = label.querySelector("span");
const tag = (tagIdx < tags.length) ? tags[tagIdx] : "";
const tip = tag && Object.prototype.hasOwnProperty.call(tipMap, tag) ? (tipMap[tag] || "") : "";
if (tip) {
label.title = tip;
if (span) span.title = tip;
} else {
label.removeAttribute("title");
if (span) span.removeAttribute("title");
}
});
});
};
let scheduled = false;
const scheduleApply = () => {
if (scheduled) return;
scheduled = true;
requestAnimationFrame(() => {
scheduled = false;
applyTooltips();
});
};
scheduleApply();
const observer = new MutationObserver(() => scheduleApply());
observer.observe(document.body, { childList: true, subtree: true });
}
"""
def rag_pipeline_ui(
    user_prompt: str,
    display_top_groups: float,
    display_top_tags_per_group: float,
    display_rank_top_k: float,
):
    """Run the full prompt-to-tags pipeline for one UI request.

    Stages (each logged to an in-memory console and timed):
      preprocess -> concurrent LLM rewrite/structural/probe (step 1) ->
      retrieval (step 2) -> LLM index selection (step 3, with fast retries) ->
      implication expansion (step 3c) -> prompt composition (step 4) ->
      ranked group display rows (step 5).

    Args:
        user_prompt: Raw free-text description typed by the user.
        display_top_groups / display_top_tags_per_group / display_rank_top_k:
            Display knobs from gr.Number components (arrive as floats,
            coerced with int() before use).

    Returns:
        Whatever _build_ui_payload produces (the Gradio output list); on any
        failure the payload carries the accumulated console text and a
        user-facing error string.
    """
    # Console lines accumulated for the UI; every stage appends via log().
    logs = []
    def log(s): logs.append(s)
    try:
        # Per-stage wall-clock timings, keyed by short stage names.
        stage_timings = {}
        def _record_timing(stage: str, dt_s: float):
            # Record one stage's duration in seconds.
            stage_timings[stage] = float(dt_s)
        def _emit_timing_summary(total_s: float):
            # Log recorded stage timings in pipeline order plus the slowest stage.
            summary_order = [
                "preprocess",
                "rewrite",
                "structural",
                "probe",
                "retrieval",
                "selection",
                "implication_expansion",
                "prompt_composition",
                "group_display",
            ]
            lines = []
            for k in summary_order:
                if k in stage_timings:
                    lines.append(f"{k}={stage_timings[k]:.2f}s")
            slowest = max(stage_timings.items(), key=lambda kv: kv[1])[0] if stage_timings else "n/a"
            log("Timing Summary: " + ", ".join(lines))
            log(f"Timing Slowest Stage: {slowest}")
            log(f"Timing Total: {total_s:.2f}s")
        def _append_timing_jsonl(total_s: float):
            # Best-effort append of one JSON record to timing_log_path; never raises.
            try:
                timing_log_path.parent.mkdir(parents=True, exist_ok=True)
                rec = {
                    # NOTE(review): datetime.utcnow() is deprecated since Python
                    # 3.12; consider datetime.now(timezone.utc) — confirm target
                    # runtime version before changing.
                    "timestamp_utc": datetime.utcnow().isoformat(timespec="seconds") + "Z",
                    "stages_s": stage_timings,
                    "total_s": float(total_s),
                    "config": {
                        "timeout_rewrite_s": stage1_rewrite_timeout_s,
                        "timeout_struct_s": stage1_struct_timeout_s,
                        "timeout_probe_s": stage1_probe_timeout_s,
                        "timeout_select_s": stage3_select_timeout_s,
                    },
                }
                with timing_log_path.open("a", encoding="utf-8") as f:
                    f.write(json.dumps(rec, ensure_ascii=True) + "\n")
                log(f"Timing Log: wrote {timing_log_path}")
            except Exception as e:
                log(f"Timing Log: failed ({type(e).__name__}: {e})")
        def _future_with_timeout(
            fut,
            timeout_s: float,
            stage_name: str,
            fallback,
            *,
            strict: bool = False,
        ):
            # Await a future with a floor of 1s; on timeout/error either raise
            # (strict=True) or log and return `fallback`. Also records timing
            # for the known stage names.
            t0 = time.perf_counter()
            try:
                out = fut.result(timeout=max(1.0, float(timeout_s)))
                dt = time.perf_counter() - t0
                log(f"{stage_name}: {dt:.2f}s")
                stage_key = {
                    "Rewrite": "rewrite",
                    "Structural inference": "structural",
                    "Probe inference": "probe",
                    "Index selection": "selection",
                }.get(stage_name)
                if stage_key:
                    _record_timing(stage_key, dt)
                return out
            except FutureTimeoutError:
                fut.cancel()
                msg = f"{stage_name}: timed out after {timeout_s:.0f}s"
                if strict:
                    raise RuntimeError(msg)
                log(f"{msg}; using fallback")
                return fallback
            except Exception as e:
                msg = f"{stage_name}: failed ({type(e).__name__}: {e})"
                if strict:
                    raise RuntimeError(msg)
                log(f"{msg}; using fallback")
                return fallback
        t_total0 = time.perf_counter()
        log("Start: received prompt")
        # Fail fast if startup preflight (e.g. missing API key) reported errors.
        if STARTUP_PREFLIGHT_ERRORS:
            log("Startup preflight failed:")
            for e in STARTUP_PREFLIGHT_ERRORS:
                log(f"- {e}")
            return _build_ui_payload(
                console_text="\n".join(logs),
                row_defs=[],
                selected_tags=[],
                suggested_prompt_text="Error: startup preflight failed. Check console details.",
            )
        prompt_in = (user_prompt or "").strip()
        if not prompt_in:
            return _build_ui_payload(
                console_text="Error: empty prompt",
                row_defs=[],
                selected_tags=[],
                suggested_prompt_text='Enter a prompt and click "Run".',
            )
        log("Input:")
        log(prompt_in)
        log("")
        log(
            "Runtime config: "
            f"retrieval_global_k={retrieval_global_k} "
            f"retrieval_per_phrase_k={retrieval_per_phrase_k} "
            f"retrieval_per_phrase_final_k={retrieval_per_phrase_final_k} "
            f"selection_mode={selection_mode} "
            f"selection_chunk_size={selection_chunk_size} "
            f"selection_per_phrase_k={selection_per_phrase_k} "
            f"min_tag_count={_get_min_tag_count()} "
            f"select_timeout_s={stage3_select_timeout_s:.0f} "
            f"select_retry_timeout_s={stage3_select_retry_timeout_s:.0f} "
            f"select_fast_retries={stage3_fast_retry_count}"
        )
        log("")
        # --- Preprocess: heuristic tag extraction from the raw prompt,
        # filtered by minimum corpus frequency and the exclusion list. ---
        t0 = time.perf_counter()
        min_tag_count = _get_min_tag_count()
        user_tags_raw = extract_user_provided_tags_upto_3_words(prompt_in)
        user_tags, removed_user_low = _filter_min_count_tags(user_tags_raw, min_tag_count)
        user_tags, removed_user_excluded = _filter_excluded_recommendation_tags(user_tags)
        dt = time.perf_counter()-t0
        _record_timing("preprocess", dt)
        log(f"Preprocess (user tag extraction): {dt:.2f}s")
        log("Heuristically extracted user tags:")
        if user_tags:
            log(", ".join(user_tags))
        else:
            log("(none)")
        if removed_user_low:
            log(
                f"Filtered {len(removed_user_low)} low-frequency user tags "
                f"(<{min_tag_count}): {', '.join(removed_user_low)}"
            )
        if removed_user_excluded:
            log(
                f"Filtered {len(removed_user_excluded)} excluded user tags: "
                f"{', '.join(removed_user_excluded)}"
            )
        log("")
        # --- Step 1: run the rewrite / structural / probe LLM calls in
        # parallel. Rewrite is strict (pipeline cannot proceed without it);
        # the other two fall back to empty lists on timeout/error. ---
        log("Step 1: LLM rewrite + structural inference + probe (concurrent)")
        max_workers = 3 if enable_probe_tags else 2
        ex = ThreadPoolExecutor(max_workers=max_workers)
        try:
            fut_rewrite = ex.submit(llm_rewrite_prompt, prompt_in, log)
            fut_struct = ex.submit(llm_infer_structural_tags, prompt_in, log=log)
            fut_probe = ex.submit(llm_infer_probe_tags, prompt_in, log=log) if enable_probe_tags else None
            rewritten = _future_with_timeout(
                fut_rewrite,
                stage1_rewrite_timeout_s,
                "Rewrite",
                "",
                strict=True,
            )
            structural_tags = _future_with_timeout(
                fut_struct, stage1_struct_timeout_s, "Structural inference", []
            )
            probe_tags = (
                _future_with_timeout(fut_probe, stage1_probe_timeout_s, "Probe inference", [])
                if fut_probe else []
            )
        finally:
            # Don't block the request on stragglers; cancel anything queued.
            ex.shutdown(wait=False, cancel_futures=True)
        # Apply the same frequency/exclusion filters to the inferred tags.
        structural_tags, removed_struct_low = _filter_min_count_tags(structural_tags, min_tag_count)
        probe_tags, removed_probe_low = _filter_min_count_tags(probe_tags, min_tag_count)
        structural_tags, removed_struct_excluded = _filter_excluded_recommendation_tags(structural_tags)
        probe_tags, removed_probe_excluded = _filter_excluded_recommendation_tags(probe_tags)
        if removed_struct_low:
            log(
                f"Filtered {len(removed_struct_low)} low-frequency structural tags "
                f"(<{min_tag_count}): {', '.join(removed_struct_low)}"
            )
        if removed_probe_low:
            log(
                f"Filtered {len(removed_probe_low)} low-frequency probe tags "
                f"(<{min_tag_count}): {', '.join(removed_probe_low)}"
            )
        if removed_struct_excluded:
            log(
                f"Filtered {len(removed_struct_excluded)} excluded structural tags: "
                f"{', '.join(removed_struct_excluded)}"
            )
        if removed_probe_excluded:
            log(
                f"Filtered {len(removed_probe_excluded)} excluded probe tags: "
                f"{', '.join(removed_probe_excluded)}"
            )
        if not rewritten:
            raise RuntimeError("Rewrite: empty output")
        log("Rewrite:")
        log(rewritten if rewritten else "(empty)")
        log("")
        rewrite_for_retrieval = rewritten
        if user_tags:
            # keep them separate in logs, but allow them to help retrieval
            rewrite_for_retrieval = (rewrite_for_retrieval + ", " + ", ".join(user_tags)).strip(", ").strip()
        # --- Step 2: candidate retrieval per rewrite phrase. Any failure here
        # degrades to an empty candidate list rather than aborting the run. ---
        log("Step 2: Prompt Squirrel retrieval (hidden)")
        try:
            t0 = time.perf_counter()
            retrieval_context_tags = list(dict.fromkeys((structural_tags or []) + (probe_tags or [])))
            rewrite_phrases = [p.strip() for p in (rewrite_for_retrieval or "").split(",") if p.strip()]
            retrieval_result = psq_candidates_from_rewrite_phrases(
                rewrite_phrases=rewrite_phrases,
                allow_nsfw_tags=allow_nsfw_tags,
                context_tags=retrieval_context_tags,
                global_k=max(1, retrieval_global_k),
                per_phrase_k=max(1, retrieval_per_phrase_k),
                per_phrase_final_k=max(1, retrieval_per_phrase_final_k),
                min_tag_count=max(0, min_tag_count),
                verbose=verbose_retrieval,
            )
            # Tolerate both return shapes: (candidates, reports) or candidates only.
            if isinstance(retrieval_result, tuple):
                candidates, phrase_reports = retrieval_result
            else:
                candidates, phrase_reports = retrieval_result, []
            candidates, removed_candidate_excluded = _filter_excluded_candidates(candidates)
            if removed_candidate_excluded:
                log(
                    f"Filtered {len(removed_candidate_excluded)} excluded retrieved tags: "
                    f"{', '.join(removed_candidate_excluded[:25])}"
                    + (" ..." if len(removed_candidate_excluded) > 25 else "")
                )
            if selection_candidate_cap > 0 and len(candidates) > selection_candidate_cap:
                candidates = candidates[:selection_candidate_cap]
                log(f"Selection candidate cap applied: {selection_candidate_cap}")
            dt = time.perf_counter()-t0
            _record_timing("retrieval", dt)
            log(f"Retrieval: {dt:.2f}s")
            log(f"Retrieved {len(candidates)} candidate tags")
            if verbose_retrieval:
                # Dump per-phrase candidate scores, capped unless
                # verbose_retrieval_all is set.
                log(f"Total unique candidates: {len(candidates)}")
                limit = None if verbose_retrieval_all else max(1, int(verbose_retrieval_limit))
                for report in phrase_reports:
                    phrase = report.get("normalized") or report.get("phrase") or ""
                    lookup = report.get("lookup") or ""
                    tfidf_vocab = report.get("tfidf_vocab")
                    log(f"Phrase: {phrase} (lookup={lookup}) tfidf_vocab={tfidf_vocab}")
                    rows = report.get("candidates", [])
                    shown = rows if limit is None else rows[:limit]
                    for row in shown:
                        tag = row.get("tag")
                        alias_token = row.get("alias_token")
                        score_fasttext = row.get("score_fasttext")
                        score_context = row.get("score_context")
                        score_combined = row.get("score_combined")
                        count = row.get("count")
                        alias_part = ""
                        if alias_token and alias_token != tag:
                            alias_part = f" [alias_token={alias_token}]"
                        # Scores may arrive as numbers or already-formatted values.
                        fasttext_str = (
                            f"{score_fasttext:.3f}" if isinstance(score_fasttext, (int, float)) else score_fasttext
                        )
                        if score_context is None:
                            context_str = "None"
                        else:
                            context_str = (
                                f"{score_context:.3f}" if isinstance(score_context, (int, float)) else score_context
                            )
                        combined_str = (
                            f"{score_combined:.3f}" if isinstance(score_combined, (int, float)) else score_combined
                        )
                        log(
                            f" {tag}{alias_part} | fasttext={fasttext_str} context={context_str} "
                            f"combined={combined_str} count={count}"
                        )
                    if limit is not None and len(rows) > limit:
                        log(f" ... ({len(rows) - limit} more)")
        except Exception as e:
            # Retrieval is best-effort: continue with no candidates.
            log(f"Retrieval fallback: {type(e).__name__}: {e}")
            candidates = []
        # Normalized, order-preserving, deduplicated view of retrieved tags.
        retrieved_candidate_tags = list(
            dict.fromkeys(
                _norm_tag_for_lookup(c.tag)
                for c in (candidates or [])
                if getattr(c, "tag", None)
            )
        )
        # --- Step 3: LLM picks candidate indices; one full-timeout attempt
        # plus stage3_fast_retry_count shorter retries. Each attempt gets a
        # fresh single-worker executor. ---
        log("Step 3: LLM index selection (uses rewrite + structural/probe context)")
        selection_query = _build_selection_query(
            prompt_in=prompt_in,
            rewritten=rewritten,
            structural_tags=structural_tags,
            probe_tags=probe_tags,
        )
        picked_indices = None
        last_stage3_error: Exception | None = None
        stage3_attempts = 1 + int(stage3_fast_retry_count)
        for attempt_i in range(stage3_attempts):
            timeout_s = stage3_select_timeout_s if attempt_i == 0 else stage3_select_retry_timeout_s
            if attempt_i > 0:
                log(
                    f"Index selection: fast retry {attempt_i}/{stage3_fast_retry_count} "
                    f"(timeout={timeout_s:.0f}s)"
                )
            ex = ThreadPoolExecutor(max_workers=1)
            try:
                fut_sel = ex.submit(
                    llm_select_indices,
                    query_text=selection_query,
                    candidates=candidates,
                    max_pick=0,
                    log=log,
                    mode=selection_mode,
                    chunk_size=max(1, selection_chunk_size),
                    per_phrase_k=max(1, selection_per_phrase_k),
                )
                picked_indices = _future_with_timeout(
                    fut_sel,
                    timeout_s,
                    "Index selection",
                    [],
                    strict=True,
                )
                last_stage3_error = None
                break
            except Exception as e:
                last_stage3_error = e
                log(f"Index selection attempt {attempt_i + 1} failed: {e}")
            finally:
                ex.shutdown(wait=False, cancel_futures=True)
        if picked_indices is None:
            # All attempts failed; this aborts the run via the outer handler.
            raise RuntimeError(
                f"Index selection failed after {stage3_attempts} attempt(s): {last_stage3_error}"
            )
        selection_selected_tags = [candidates[i].tag for i in picked_indices] if picked_indices else []
        selection_selected_tags, removed_stage3_low = _filter_min_count_tags(selection_selected_tags, min_tag_count)
        if removed_stage3_low:
            log(
                f" Filtered {len(removed_stage3_low)} low-frequency stage3 tags "
                f"(<{min_tag_count}): {', '.join(removed_stage3_low)}"
            )
        # Merge selection results with structural/probe tags (dedup by exact string).
        selected_tags = list(selection_selected_tags)
        if structural_tags:
            # Add structural tags that aren't already selected
            existing = {t for t in selected_tags}
            new_structural = [t for t in structural_tags if t not in existing]
            selected_tags.extend(new_structural)
            log(f" Added {len(new_structural)} structural tags: {', '.join(new_structural)}")
        else:
            log(" No structural tags inferred")
        if probe_tags:
            existing = {t for t in selected_tags}
            new_probe = [t for t in probe_tags if t not in existing]
            selected_tags.extend(new_probe)
            log(f" Added {len(new_probe)} probe tags: {', '.join(new_probe)}")
        elif enable_probe_tags:
            log(" No probe tags inferred")
        selected_tags, removed_excluded_direct = _filter_excluded_recommendation_tags(selected_tags)
        if removed_excluded_direct:
            log(f" Removed {len(removed_excluded_direct)} excluded tags: {', '.join(removed_excluded_direct)}")
        # Snapshot of directly chosen tags before implication expansion.
        direct_selected_tags = list(dict.fromkeys(selected_tags))
        # --- Step 3c: add tags implied by the current selection. ---
        log("Step 3c: Expand via tag implications")
        t0 = time.perf_counter()
        tag_set = set(selected_tags)
        expanded, implied_only = expand_tags_via_implications(tag_set)
        dt = time.perf_counter()-t0
        _record_timing("implication_expansion", dt)
        log(f"Implication expansion: {dt:.2f}s")
        implied_selected_tags = sorted(implied_only) if implied_only else []
        if implied_only:
            implied_added = sorted(implied_only)
            implied_added, removed_implied_low = _filter_min_count_tags(implied_added, min_tag_count)
            implied_selected_tags = list(implied_added)
            if implied_added:
                selected_tags.extend(implied_added)
                log(f" Added {len(implied_added)} implied tags: {', '.join(implied_added)}")
            if removed_implied_low:
                log(
                    f" Filtered {len(removed_implied_low)} low-frequency implied tags "
                    f"(<{min_tag_count}): {', '.join(removed_implied_low)}"
                )
        else:
            log(" No additional implied tags")
        selected_tags, removed_excluded_implied = _filter_excluded_recommendation_tags(selected_tags)
        implied_selected_tags = [
            t for t in implied_selected_tags if not _is_excluded_recommendation_tag(t)
        ]
        if removed_excluded_implied:
            log(
                f" Removed {len(removed_excluded_implied)} excluded tags after implications: "
                f"{', '.join(removed_excluded_implied)}"
            )
        # --- Step 4: compose the suggested prompt string. ---
        log("Step 4: Compose final prompt")
        t0 = time.perf_counter()
        # NOTE(review): final_prompt is not referenced again in this function
        # and is not passed to _build_ui_payload below — verify whether
        # _build_ui_payload derives the suggested prompt itself or this is a
        # dropped value.
        final_prompt = compose_final_prompt(rewritten, selected_tags)
        dt = time.perf_counter()-t0
        _record_timing("prompt_composition", dt)
        log(f"Prompt composition: {dt:.2f}s")
        # --- Step 5: build ranked group rows and per-tag origin metadata that
        # drives the colored "lego" buttons in the UI. ---
        log("Step 5: Build ranked group/category display")
        t0 = time.perf_counter()
        seed_terms = []
        seed_terms.extend(user_tags)
        seed_terms.extend([p.strip() for p in (rewritten or "").split(",") if p.strip()])
        seed_terms.extend(structural_tags or [])
        seed_terms.extend(probe_tags or [])
        seed_terms.extend(selected_tags)
        seed_terms = list(dict.fromkeys(seed_terms))
        active_selected_tags = list(dict.fromkeys(selected_tags))
        # Normalized lookup sets used to attribute each tag to its origin.
        structural_set = {_norm_tag_for_lookup(t) for t in (structural_tags or []) if t}
        probe_set = {_norm_tag_for_lookup(t) for t in (probe_tags or []) if t}
        implied_set = {_norm_tag_for_lookup(t) for t in (implied_selected_tags or []) if t}
        rewrite_set = {
            _norm_tag_for_lookup(t)
            for t in (list(user_tags or []) + [p.strip() for p in (rewritten or "").split(",") if p.strip()])
            if t
        }
        selection_set = {_norm_tag_for_lookup(t) for t in (selection_selected_tags or []) if t}
        # Origin priority: structural > probe > rewrite > selection > implied.
        tag_selection_origins: Dict[str, str] = {}
        for tag in active_selected_tags:
            tag_norm = _norm_tag_for_lookup(tag)
            if tag_norm in structural_set:
                origin = "structural"
            elif tag_norm in probe_set:
                origin = "probe"
            elif tag_norm in rewrite_set:
                origin = "rewrite"
            elif tag_norm in selection_set:
                origin = "selection"
            elif tag_norm in implied_set:
                origin = "implied"
            else:
                # Unknown/fallback tags use selection color.
                origin = "selection"
            tag_selection_origins[tag] = origin
            # Index under the normalized form too, when it differs.
            if tag_norm and tag_norm != tag:
                tag_selection_origins[tag_norm] = origin
        # Order direct tags by origin rank (stable within rank via original index)
        # before mapping implied tags to their parents.
        direct_tags_for_implied = list(
            dict.fromkeys(_norm_tag_for_lookup(t) for t in (direct_selected_tags or []) if t)
        )
        direct_tags_for_implied_idx = {t: i for i, t in enumerate(direct_tags_for_implied)}
        direct_tags_for_implied.sort(
            key=lambda t: (
                _selection_source_rank(tag_selection_origins.get(t, "selection")),
                direct_tags_for_implied_idx.get(t, 10**9),
            )
        )
        implied_parent_map = _build_implied_parent_map(
            direct_tags_ordered=direct_tags_for_implied,
            implied_tags=implied_selected_tags,
        )
        toggle_rows = _build_toggle_rows(
            seed_terms=seed_terms,
            selected_tags=active_selected_tags,
            retrieved_candidate_tags=retrieved_candidate_tags,
            tag_selection_origins=tag_selection_origins,
            implied_parent_map=implied_parent_map,
            top_groups=max(1, int(display_top_groups)),
            top_tags_per_group=max(1, int(display_top_tags_per_group)),
            group_rank_top_k=max(1, int(display_rank_top_k)),
        )
        dt = time.perf_counter()-t0
        _record_timing("group_display", dt)
        log(f"Ranked group display: {dt:.2f}s ({len(toggle_rows)} rows)")
        log(
            _build_display_audit_line(
                toggle_rows,
                active_selected_tags=active_selected_tags,
                direct_selected_tags=direct_selected_tags,
                implied_selected_tags=implied_selected_tags,
            )
        )
        # Preview only as many rows as the UI can render.
        for idx, row in enumerate(toggle_rows[: max(0, int(display_max_rows_default))]):
            tags_preview = ", ".join(row.get("tags", []))
            log(f"UI Row {idx}: {row.get('label', '')} :: {tags_preview}")
        total_dt = time.perf_counter()-t_total0
        _emit_timing_summary(total_dt)
        _append_timing_jsonl(total_dt)
        log("Done: final prompt ready")
        return _build_ui_payload(
            console_text="\n".join(logs),
            row_defs=toggle_rows,
            selected_tags=active_selected_tags,
        )
    except Exception as e:
        # Top-level boundary: surface the error in the console pane and the
        # suggested-prompt box instead of crashing the Gradio handler.
        log(f"Error: {type(e).__name__}: {e}")
        return _build_ui_payload(
            console_text="\n".join(logs),
            row_defs=[],
            selected_tags=[],
            suggested_prompt_text=_format_user_facing_error(e),
        )
# ---------------------------------------------------------------------------
# Gradio UI definition.
#
# Layout (top to bottom):
#   * prompt entry + read-only suggested-prompt output, with mascot image
#   * hidden "Run" button (revealed by _update_run_button_visibility)
#   * pre-allocated tag-toggle rows (markdown headers + checkbox groups)
#   * legend + "Rebuild Rows" button
#   * display-settings / console / docs accordions
#   * hidden textbox carrying the tag-tooltip map for client-side JS
# Event wiring follows the layout. Clicking "Run" and pressing Enter in the
# prompt box previously had two verbatim copies of the same 3-step chain;
# they now share one registration loop (behavior unchanged).
# ---------------------------------------------------------------------------
with gr.Blocks(css=css, js=client_js) as app:
    with gr.Row():
        with gr.Column(scale=3, elem_classes=["prompt-col"]):
            gr.Markdown(
                'Describe your image under "Enter Prompt" and click "Run". '
                'Prompt Squirrel will translate it into image board tags.',
                elem_classes=["top-instruction"],
            )
            with gr.Group(elem_classes=["prompt-card"]):
                image_tags = gr.Textbox(
                    label="Enter Prompt",
                    placeholder="e.g. fox, outside, detailed background, .",
                    lines=1,
                    elem_classes=["enter-prompt-box"],
                )
            with gr.Group(elem_classes=["prompt-card", "suggested-prompt-card"]):
                suggested_prompt = gr.Textbox(
                    label="Suggested Prompt (Read-only)",
                    lines=2,
                    interactive=False,
                    show_copy_button=True,
                    placeholder='Suggested prompt will appear here after you click "Run".',
                    elem_classes=["suggested-prompt-box"],
                )
        with gr.Column(scale=1):
            # Mascot is optional: fall back to a placeholder note when the
            # image cannot be loaded.
            _mascot_pil = _load_mascot_image()
            if _mascot_pil is not None:
                mascot_img = gr.Image(
                    value=_mascot_pil,
                    show_label=False,
                    interactive=False,
                    height=240,
                    elem_id="mascot",
                )
            else:
                mascot_img = gr.Markdown("`(mascot image unavailable)`")

    # Hidden/disabled until the prompt text differs from the last executed
    # prompt (see image_tags.change wiring below).
    submit_button = gr.Button("Run", variant="primary", visible=False, interactive=False)
    gr.Markdown("Typical runtime: up to ~20 seconds.", elem_classes=["run-hint"])

    # Cross-callback state shared by the run / toggle / rebuild handlers.
    last_run_prompt_state = gr.State("")
    selected_tags_state = gr.State([])
    rows_dirty_state = gr.State(False)
    row_defs_state = gr.State([])
    row_values_state = gr.State([])

    toggle_instruction = gr.Markdown(
        "Click tag buttons to add or remove tags from the suggested prompt.",
        elem_classes=["row-instruction"],
        visible=False,
    )

    # Pre-allocate the maximum number of toggle rows; unused rows stay hidden
    # and are populated per-run via run_outputs below.
    row_headers: List[gr.Markdown] = []
    row_checkboxes: List[gr.CheckboxGroup] = []
    for _ in range(display_max_rows_default):
        with gr.Row():
            with gr.Column(scale=2, min_width=170):
                row_headers.append(gr.Markdown(value="", visible=False, elem_classes=["row-heading"]))
            with gr.Column(scale=10):
                row_checkboxes.append(
                    gr.CheckboxGroup(
                        choices=[],
                        value=[],
                        visible=False,
                        interactive=True,
                        container=False,
                        elem_classes=["lego-tags"],
                    )
                )

    with gr.Row():
        with gr.Column(scale=10):
            gr.HTML(
                """
Legend:
Rewrite phrase
General selection
Probe query
Structural query
Implied
User-toggled
Unselected
"""
            )
        with gr.Column(scale=2, min_width=180):
            rebuild_rows_button = gr.Button(
                "Rebuild Rows",
                variant="primary",
                visible=False,
                interactive=False,
            )

    with gr.Accordion("Display Settings", open=False):
        with gr.Row():
            display_top_groups = gr.Number(
                value=display_top_groups_default,
                precision=0,
                label="Rows (Top Groups/Categories)",
                minimum=1,
            )
            display_top_tags_per_group = gr.Number(
                value=display_top_tags_per_group_default,
                precision=0,
                label="Top Tags Shown Per Row",
                minimum=1,
            )
            display_rank_top_k = gr.Number(
                value=display_rank_top_k_default,
                precision=0,
                label="Top Tags Used for Row Ranking",
                minimum=1,
            )

    with gr.Accordion("Console", open=False):
        console = gr.Textbox(
            label="Console",
            lines=10,
            interactive=False,
            placeholder="Progress logs will appear here.",
        )

    with gr.Accordion("How Prompt Squirrel Works", open=False):
        _about_md = _load_about_docs_markdown()
        _about_before, _about_after, _has_arch_slot = _split_about_docs_for_diagram(_about_md)
        if _has_arch_slot:
            # Docs contain the architecture-diagram marker: render the
            # markdown split around an inline diagram.
            if _about_before:
                gr.Markdown(
                    _about_before,
                    elem_id="about-docs",
                    elem_classes=["about-docs"],
                )
            gr.HTML(
                _build_arch_diagram_html(),
                elem_classes=["about-docs"],
            )
            if _about_after:
                gr.Markdown(
                    _about_after,
                    elem_classes=["about-docs"],
                )
        else:
            gr.Markdown(
                _about_md,
                elem_id="about-docs",
                elem_classes=["about-docs"],
            )

    # JSON payload consumed by client_js for tag tooltips. Kept visible=True
    # (hidden via the "psq-hidden" CSS class) so the DOM node exists for the
    # client script to read.
    tooltip_map_payload = gr.Textbox(
        value="{}",
        visible=True,
        interactive=False,
        container=False,
        elem_id="psq-tooltip-map",
        elem_classes=["psq-hidden"],
    )

    # Output component list shared by both run triggers; order must match
    # the payload built by _prepare_run_ui / rag_pipeline_ui.
    run_outputs = [
        console,
        toggle_instruction,
        tooltip_map_payload,
        suggested_prompt,
        selected_tags_state,
        rows_dirty_state,
        rebuild_rows_button,
        row_defs_state,
        row_values_state,
        *row_headers,
        *row_checkboxes,
    ]

    # Reveal/hide the Run button as the prompt text changes relative to the
    # last executed prompt.
    image_tags.change(
        _update_run_button_visibility,
        inputs=[image_tags, last_run_prompt_state],
        outputs=[submit_button],
        queue=False,
        show_progress="hidden",
    )

    # "Run" click and Enter-in-prompt-box trigger the identical 3-step chain:
    # mark the run (disable button, remember prompt) -> reset the UI ->
    # execute the pipeline. Register once per trigger instead of duplicating
    # the wiring.
    for _run_trigger in (submit_button.click, image_tags.submit):
        _run_trigger(
            _mark_run_triggered,
            inputs=[image_tags],
            outputs=[submit_button, last_run_prompt_state],
            queue=False,
            show_progress="hidden",
        ).then(
            _prepare_run_ui,
            inputs=[],
            outputs=run_outputs,
            queue=False,
            show_progress="hidden",
        ).then(
            rag_pipeline_ui,
            inputs=[image_tags, display_top_groups, display_top_tags_per_group, display_rank_top_k],
            outputs=run_outputs,
        )

    # Each checkbox row feeds the shared toggle handler; the `i=idx` default
    # binds the row index at lambda-definition time (avoids the late-binding
    # closure pitfall).
    for idx, row_cb in enumerate(row_checkboxes):
        row_cb.change(
            fn=lambda changed_values, selected_state, rows_dirty, row_defs, row_values, i=idx: _on_toggle_row(
                i,
                changed_values,
                selected_state,
                rows_dirty,
                row_defs,
                row_values,
                display_max_rows_default,
            ),
            inputs=[row_cb, selected_tags_state, rows_dirty_state, row_defs_state, row_values_state],
            outputs=[selected_tags_state, rows_dirty_state, rebuild_rows_button, row_values_state, suggested_prompt, *row_checkboxes],
            queue=False,
            show_progress="hidden",
        )

    rebuild_rows_button.click(
        _rebuild_rows_from_selected,
        inputs=[selected_tags_state, row_defs_state, row_values_state, display_top_groups, display_top_tags_per_group, display_rank_top_k],
        outputs=[
            toggle_instruction,
            tooltip_map_payload,
            suggested_prompt,
            selected_tags_state,
            rows_dirty_state,
            rebuild_rows_button,
            row_defs_state,
            row_values_state,
            *row_headers,
            *row_checkboxes,
        ],
        queue=False,
        show_progress="hidden",
    )
if __name__ == "__main__":
    # Launch with request queuing enabled; expose the mascot and docs
    # directories so Gradio may serve their static assets.
    _asset_roots = [str(directory) for directory in (MASCOT_DIR, DOCS_DIR)]
    app.queue().launch(allowed_paths=_asset_roots)