import base64
import csv
import json
import logging
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Set, Tuple

import gradio as gr
from PIL import Image

from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words
from psq_rag.llm.rewrite import llm_rewrite_prompt
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup
from psq_rag.llm.select import llm_select_indices, llm_infer_structural_tags, llm_infer_probe_tags
from psq_rag.retrieval.state import (
    expand_tags_via_implications,
    get_tag_type_name,
    get_tag_implications,
    get_tag_counts,
)
from psq_rag.ui.group_ranked_display import rank_groups_from_tfidf, _load_enabled_groups


APP_DIR = Path(__file__).parent
DOCS_DIR = APP_DIR / "docs"
ARCH_DIAGRAM_FILE = DOCS_DIR / "assets" / "architecture_overview.png"
ARCH_DIAGRAM_MARKER = "{{ARCHITECTURE_DIAGRAM}}"
ARCH_DIAGRAM_INSERT_BEFORE_HEADING = "## What Each Step Does"

# Legacy docs heading that marks where the diagram belongs.
_ARCH_HEADING_RE = re.compile(r"(?m)^##\s+Architecture At A Glance\s*$")
# Preferred insertion point; derived from ARCH_DIAGRAM_INSERT_BEFORE_HEADING so the
# constant and the matcher cannot drift apart.
_ARCH_INSERT_BEFORE_RE = re.compile(
    r"(?m)^##\s+"
    + re.escape(ARCH_DIAGRAM_INSERT_BEFORE_HEADING.lstrip("#").strip())
    + r"\s*$"
)

# Patterns searched against normalized tags; any hit hard-blocks the tag from
# recommendations regardless of the editable CSV exclusion lists.
# Each alternative must be a whole "_"-delimited token; word stems carry an
# explicit [a-z]* suffix so inflected forms (e.g. "masturbation") also match.
_CORPORATE_HARDBLOCK_PATTERNS = [
    # Rating-like explicitness markers.
    re.compile(r"(^|_)(nsfw|explicit|questionable)(_|$)", re.IGNORECASE),
    # Unambiguous sexual anatomy.
    re.compile(
        r"(^|_)(breast|breasts|boob|boobs|nipple|nipples|penis|vagina|pussy|clit|testicle|scrotum|genital|crotch|anus|anal|areola)(_|$)",
        re.IGNORECASE,
    ),
    # Unambiguous sexual activity.
    re.compile(
        r"(^|_)(sex|sexual|fucking|fuck|blowjob|handjob|masturbat[a-z]*|penetrat[a-z]*|thrust|orgasm|cum|ejaculat[a-z]*|creampie|nude|naked|topless|bottomless|moan|sexy)(_|$)",
        re.IGNORECASE,
    ),
    # Common kink/fetish markers.
    re.compile(
        r"(^|_)(fetish|bdsm|bondage|dominatrix|submission|vore|inflation|watersports)(_|$)",
        re.IGNORECASE,
    ),
]

# Selection origins understood by the UI. "implied" is included so that
# implication-derived selections survive a round-trip through a row's tag_meta;
# _rebuild_rows_from_selected relies on seeing it again.
_KNOWN_SELECTION_ORIGINS = frozenset(
    {"rewrite", "selection", "probe", "structural", "user", "candidate", "implied"}
)


def _split_prompt_commas(s: str) -> List[str]:
    """Split a comma-separated prompt into stripped, non-empty parts."""
    return [p.strip() for p in (s or "").split(",") if p.strip()]


def _norm_for_dedupe(tag: str) -> str:
    """Return the canonical key used to dedupe prompt parts (lowercased, then lookup-normalized)."""
    return _norm_tag_for_lookup(tag.lower())


def compose_final_prompt(rewritten_prompt: str, selected_tags: List[str]) -> str:
    """Join the rewritten prompt's comma parts with the selected tags, deduplicated.

    Parts are compared by their `_norm_for_dedupe` key; the first occurrence wins
    and keeps its original spelling.
    """
    parts = _split_prompt_commas(rewritten_prompt)
    parts.extend(selected_tags)
    seen = set()
    out = []
    for p in parts:
        key = _norm_for_dedupe(p)
        if key in seen:
            continue
        seen.add(key)
        out.append(p)
    return ", ".join(out)


def _display_tag_text(tag: str) -> str:
    """Render a tag for display by turning underscores into spaces."""
    return tag.replace("_", " ")


def _display_row_label(name: str) -> str:
    """Return a human-readable row header for an internal row name."""
    n = (name or "").strip()
    if not n:
        return ""
    if n == "selected_other":
        return "Selected (Other)"
    return n.replace("_", " ").title()


def _normalize_selection_origin(origin: str) -> str:
    """Lowercase/strip an origin string; unknown values collapse to "selection"."""
    o = (origin or "").strip().lower()
    return o if o in _KNOWN_SELECTION_ORIGINS else "selection"


def _choice_label_with_source_meta(tag: str, *, origin: str, preselected: bool) -> str:
    """Return the checkbox label for a tag.

    `origin` and `preselected` are accepted for interface stability but are
    intentionally unused: labels stay plain to avoid frontend text/value
    desynchronization.
    """
    return _display_tag_text(tag)


@lru_cache(maxsize=1)
def _load_tag_wiki_defs() -> Dict[str, str]:
    """Load `data/tag_wiki_defs.json` (path relative to the CWD) as {normalized tag: text}.

    Whitespace inside each definition is collapsed. Returns {} when the file is
    missing or cannot be parsed (a warning is logged); the result is cached.
    """
    p = Path("data/tag_wiki_defs.json")
    if not p.exists():
        return {}
    try:
        with p.open("r", encoding="utf-8") as f:
            data = json.load(f)
        out: Dict[str, str] = {}
        if isinstance(data, dict):
            for k, v in data.items():
                tag = _norm_tag_for_lookup(str(k))
                text = " ".join(str(v or "").split())
                if tag and text:
                    out[tag] = text
        return out
    except Exception as e:  # best-effort: tooltips are optional
        logging.warning("Failed to load tag wiki definitions (%s): %s", p, e)
        return {}


@lru_cache(maxsize=1)
def _load_about_docs_markdown() -> str:
    """Return the first non-empty docs markdown file, with YAML front matter stripped.

    Falls back to a short "unavailable" notice when no candidate file is readable.
    """
    candidates = [
        DOCS_DIR / "space_overview.md",
        APP_DIR / "PROJECT_SUMMARY.md",
    ]
    for p in candidates:
        if not p.exists():
            continue
        try:
            raw = p.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError):
            continue
        text = raw.strip()
        if not text:
            continue
        # Strip YAML front matter if present.
        if text.startswith("---"):
            parts = text.split("---", 2)
            if len(parts) >= 3:
                text = parts[2].strip()
        if text:
            return text
    return (
        "Documentation is unavailable.\n\n"
        "Expected file: `docs/space_overview.md`"
    )


def _tooltip_text_for_tag(tag: str) -> str:
    """Build tooltip text for a tag: its corpus count (if known) and wiki definition."""
    t = _norm_tag_for_lookup(tag)
    parts: List[str] = []
    try:
        count = get_tag_counts().get(t)
    except Exception:  # counts are decorative; never fail a tooltip over them
        count = None
    if isinstance(count, int):
        parts.append(f"Count: {count:,}")
    d = _load_tag_wiki_defs().get(t, "")
    if d:
        parts.append(d)
    return "\n".join(parts).strip()


@lru_cache(maxsize=1)
def _load_arch_diagram_data_uri() -> str:
    """Return the architecture PNG as a base64 data URI, or "" if missing/unreadable/empty."""
    if not ARCH_DIAGRAM_FILE.exists():
        return ""
    try:
        raw = ARCH_DIAGRAM_FILE.read_bytes()
    except OSError:
        return ""
    if not raw:
        return ""
    b64 = base64.b64encode(raw).decode("ascii")
    return f"data:image/png;base64,{b64}"


def _split_about_docs_for_diagram(md: str) -> Tuple[str, str, bool]:
    """Split docs markdown into (before, after, found) around the diagram slot.

    Lookup order: explicit ARCH_DIAGRAM_MARKER (last occurrence, marker removed),
    then a legacy "## Architecture At A Glance" heading (heading removed), then the
    ARCH_DIAGRAM_INSERT_BEFORE_HEADING heading (kept at the start of `after`).
    When nothing matches, returns (text, "", False).
    """
    text = (md or "").strip()
    if ARCH_DIAGRAM_MARKER in text:
        before, after = text.rsplit(ARCH_DIAGRAM_MARKER, 1)
        return before.strip(), after.strip(), True

    # Backward compatibility if an explicit architecture heading exists in docs.
    m_arch = _ARCH_HEADING_RE.search(text)
    if m_arch:
        return text[: m_arch.start()].strip(), text[m_arch.end():].strip(), True

    # Preferred insertion point: inject the diagram right before the step-by-step section.
    m_steps = _ARCH_INSERT_BEFORE_RE.search(text)
    if m_steps:
        return text[: m_steps.start()].strip(), text[m_steps.start():].strip(), True

    return text, "", False


def _build_arch_diagram_html() -> str:
    """Return HTML that embeds the architecture diagram, or a placeholder notice.

    NOTE(review): the original markup was lost in this copy of the file; class
    names and heading level here are a reconstruction — confirm against the CSS.
    """
    uri = _load_arch_diagram_data_uri()
    if not uri:
        return (
            '<div class="arch-diagram arch-diagram-missing">'
            "<em>(architecture diagram unavailable)</em>"
            "</div>"
        )
    # `uri` is a locally built base64 data URI, so it contains no characters that
    # need HTML escaping inside a double-quoted attribute.
    return f"""
<div class="arch-diagram">
<h2>Architecture At A Glance</h2>
<img src="{uri}" alt="Architecture diagram" style="max-width: 100%; height: auto;" />
</div>
"""


def _selection_source_rank(origin: str) -> int:
    """Sort rank for a selection origin: structural (0) < probe (1) < everything else (2)."""
    o = _normalize_selection_origin(origin)
    if o == "structural":
        return 0
    if o == "probe":
        return 1
    # Keep rewrite/user in the same priority band as general selection for row ordering.
    return 2


def _build_implied_parent_map(
    direct_tags_ordered: List[str],
    implied_tags: List[str],
) -> Dict[str, str]:
    """Map each implied tag to the first direct tag (in the given order) that implies it.

    Implications are followed transitively using `get_tag_implications()`, read as
    {tag: iterable of tags it implies}. Implied tags reachable from no direct tag
    are absent from the result.
    """
    implied_set = {_norm_tag_for_lookup(t) for t in (implied_tags or []) if t}
    if not implied_set or not direct_tags_ordered:
        return {}
    impl = get_tag_implications()
    parent_by_implied: Dict[str, str] = {}
    for direct in direct_tags_ordered:
        d = _norm_tag_for_lookup(direct)
        if not d:
            continue
        stack = [d]
        seen = {d}
        while stack:
            t = stack.pop()
            for nxt in impl.get(t, ()):
                p = _norm_tag_for_lookup(nxt)
                if not p or p in seen:
                    continue
                seen.add(p)
                # Earlier direct tags win; later ones never overwrite an assignment.
                if p in implied_set and p not in parent_by_implied:
                    parent_by_implied[p] = d
                stack.append(p)
    return parent_by_implied


def _order_selected_tags_for_row(
    *,
    row_selected_tags: List[str],
    selected_index: Dict[str, int],
    tag_selection_origins: Dict[str, str],
    implied_parent_map: Dict[str, str],
) -> List[str]:
    """Order a row's selected tags so implied tags follow the tag that implies them.

    Direct tags sort by (origin rank, selection index, name). Each is immediately
    followed by its implied children present in the row. Implied tags whose
    parent is not in the row come last, sorted by their parent's rank/index.
    """
    row_selected_norm = [_norm_tag_for_lookup(t) for t in (row_selected_tags or []) if t]
    implied_in_row = {t for t in row_selected_norm if t in implied_parent_map}
    base_tags = [t for t in row_selected_norm if t not in implied_in_row]
    base_tags.sort(
        key=lambda t: (
            _selection_source_rank(tag_selection_origins.get(t, "selection")),
            selected_index.get(t, 10**9),
            t,
        )
    )

    children_by_parent: Dict[str, List[str]] = {}
    for implied in implied_in_row:
        parent = implied_parent_map.get(implied)
        if parent:
            children_by_parent.setdefault(parent, []).append(implied)
    # Set iteration order is arbitrary, so sort children for determinism.
    for children in children_by_parent.values():
        children.sort(key=lambda t: (selected_index.get(t, 10**9), t))

    ordered: List[str] = []
    emitted: Set[str] = set()
    for tag in base_tags:
        if tag in emitted:
            continue
        ordered.append(tag)
        emitted.add(tag)
        for child in children_by_parent.get(tag, []):
            if child not in emitted:
                ordered.append(child)
                emitted.add(child)

    remaining_implied = [t for t in row_selected_norm if t not in emitted]
    remaining_implied.sort(
        key=lambda t: (
            _selection_source_rank(tag_selection_origins.get(implied_parent_map.get(t, ""), "selection")),
            selected_index.get(implied_parent_map.get(t, ""), 10**9),
            selected_index.get(t, 10**9),
            t,
        )
    )
    for t in remaining_implied:
        if t not in emitted:
            ordered.append(t)
            emitted.add(t)
    return ordered


def _escape_prompt_tag(tag: str) -> str:
    """Format a tag for the prompt: underscores to spaces, parentheses backslash-escaped."""
    return (
        tag.replace("_", " ")
        .replace("(", "\\(")
        .replace(")", "\\)")
    )


def _ordered_selected_for_prompt(selected: Set[str], row_defs: List[Dict[str, Any]]) -> List[str]:
    """Return selected tags in row display order, each tag once."""
    out: List[str] = []
    seen: Set[str] = set()
    for row in row_defs:
        for tag in row.get("tags", []):
            if tag in selected and tag not in seen:
                out.append(tag)
                seen.add(tag)
    return out


def _compose_toggle_prompt_text(selected_tags: List[str], row_defs: List[Dict[str, Any]]) -> str:
    """Build the comma-separated prompt text from the selected tags in row order."""
    selected = {t for t in (selected_tags or []) if t}
    ordered = _ordered_selected_for_prompt(selected, row_defs or [])
    return ", ".join(_escape_prompt_tag(t) for t in ordered)


def _is_artist_tag(tag: str) -> bool:
    """True if the tag is typed "artist" or (as a fallback) starts with "by_"."""
    t = _norm_tag_for_lookup(str(tag))
    if not t:
        return False
    # Keep a resilient fallback for malformed/missing tag typing metadata.
    return get_tag_type_name(t) == "artist" or t.startswith("by_")


@lru_cache(maxsize=1)
def _load_excluded_recommendation_tags() -> Set[str]:
    """Load the set of normalized tags excluded from recommendations (cached).

    Sources (paths relative to the CWD), both optional:
    - category registry CSV: rows whose `category_status` is "excluded";
    - `data/corporate_excluded_tags.csv`: rows whose `enabled` column is not a
      false-like value (missing column counts as enabled).
    Read failures are logged and whatever was read so far is kept.
    """
    out: Set[str] = set()

    # Existing category-registry driven exclusions.
    csv_path = Path("data/category_registry.csv")
    if not csv_path.exists():
        csv_path = Path("data/analysis/category_registry.csv")
    if csv_path.exists():
        try:
            with csv_path.open("r", encoding="utf-8", newline="") as f:
                for row in csv.DictReader(f):
                    tag = _norm_tag_for_lookup(str(row.get("tag") or ""))
                    if not tag:
                        continue
                    status = str(row.get("category_status") or "").strip().lower()
                    if status == "excluded":
                        out.add(tag)
        except Exception as e:  # best-effort: a bad registry must not break the UI
            logging.warning("Failed to read category registry (%s): %s", csv_path, e)

    # Corporate-safety exclusions (editable runtime list).
    corp_path = Path("data/corporate_excluded_tags.csv")
    if corp_path.exists():
        try:
            with corp_path.open("r", encoding="utf-8", newline="") as f:
                for row in csv.DictReader(f):
                    tag = _norm_tag_for_lookup(str(row.get("tag") or ""))
                    if not tag:
                        continue
                    enabled_raw = str(row.get("enabled", "1")).strip().lower()
                    if enabled_raw not in {"0", "false", "no", "off"}:
                        out.add(tag)
        except Exception as e:  # best-effort: a bad list must not break the UI
            logging.warning("Failed to read corporate exclusion list (%s): %s", corp_path, e)

    return out


def _is_hardblocked_corporate_tag(tag: str) -> bool:
    """True if the normalized tag matches any hard-block pattern."""
    t = _norm_tag_for_lookup(str(tag))
    if not t:
        return False
    return any(rx.search(t) for rx in _CORPORATE_HARDBLOCK_PATTERNS)


def _is_excluded_recommendation_tag(tag: str) -> bool:
    """True if the tag is hard-blocked or listed in the exclusion CSVs."""
    t = _norm_tag_for_lookup(str(tag))
    if not t:
        return False
    if _is_hardblocked_corporate_tag(t):
        return True
    return t in _load_excluded_recommendation_tags()


def _get_min_tag_count() -> int:
    """Read PSQ_MIN_TAG_COUNT (default 100, clamped to >= 0; 100 if not an integer)."""
    try:
        return max(0, int(os.environ.get("PSQ_MIN_TAG_COUNT", "100")))
    except ValueError:
        return 100


def _filter_min_count_tags(tags: List[str], min_count: int) -> Tuple[List[str], List[str]]:
    """Split tags into (kept, removed) by corpus count.

    Kept tags are normalized and deduplicated in first-seen order. Removed tags
    (count < min_count) are returned sorted and unique. min_count <= 0 keeps all.
    """
    if min_count <= 0:
        return _dedupe_norm_tags(tags), []
    tag_counts = get_tag_counts()
    keep: List[str] = []
    removed: List[str] = []
    seen: Set[str] = set()
    for raw in (tags or []):
        t = _norm_tag_for_lookup(str(raw))
        if not t:
            continue
        c = int(tag_counts.get(t, 0) or 0)
        if c < min_count:
            removed.append(t)
            continue
        if t in seen:
            continue
        seen.add(t)
        keep.append(t)
    return keep, sorted(set(removed))


def _filter_excluded_recommendation_tags(tags: List[str]) -> Tuple[List[str], List[str]]:
    """Split tags into (kept, removed) using the CSV exclusion set only.

    Kept tags are normalized and deduplicated in first-seen order; removed tags
    are returned sorted and unique. Hard-block patterns are not applied here.
    """
    excluded = _load_excluded_recommendation_tags()
    if not excluded:
        return _dedupe_norm_tags(tags), []
    keep: List[str] = []
    removed: List[str] = []
    seen: Set[str] = set()
    for raw in (tags or []):
        t = _norm_tag_for_lookup(str(raw))
        if not t:
            continue
        if t in excluded:
            removed.append(t)
            continue
        if t in seen:
            continue
        seen.add(t)
        keep.append(t)
    return keep, sorted(set(removed))


def _filter_excluded_candidates(candidates: List[Any]) -> Tuple[List[Any], List[str]]:
    """Drop candidates whose `.tag` is in the CSV exclusion set.

    Returns (kept candidates in original order, sorted unique removed tags).
    """
    excluded = _load_excluded_recommendation_tags()
    if not excluded:
        return list(candidates or []), []
    keep: List[Any] = []
    removed: List[str] = []
    for c in (candidates or []):
        tag = _norm_tag_for_lookup(str(getattr(c, "tag", "") or ""))
        if tag and tag in excluded:
            removed.append(tag)
            continue
        keep.append(c)
    return keep, sorted(set(removed))


def _dedupe_norm_tags(tags: List[str]) -> List[str]:
    """Normalize tags, dropping empties and duplicates while keeping first-seen order."""
    out: List[str] = []
    seen: Set[str] = set()
    for raw in (tags or []):
        t = _norm_tag_for_lookup(str(raw))
        if not t or t in seen:
            continue
        seen.add(t)
        out.append(t)
    return out


def _collect_visible_tags(row_defs: List[Dict[str, Any]]) -> Set[str]:
    """Return the set of normalized tags shown across all rows."""
    out: Set[str] = set()
    for row in (row_defs or []):
        out.update(_dedupe_norm_tags(row.get("tags", []) if isinstance(row, dict) else []))
    return out


def _collect_selected_from_state(
    selected_tags_state: List[str],
    row_defs: List[Dict[str, Any]],
) -> List[str]:
    """Return selected-state tags that are visible in some row, in state order, deduped."""
    visible_tags = _collect_visible_tags(row_defs)
    if not visible_tags:
        return []
    selected: List[str] = []
    seen: Set[str] = set()
    visible_by_norm = {_norm_tag_for_lookup(t): t for t in visible_tags}
    for raw in (selected_tags_state or []):
        t = _norm_tag_for_lookup(str(raw))
        if not t:
            continue
        mapped = t if t in visible_tags else visible_by_norm.get(t)
        if not mapped or mapped in seen:
            continue
        seen.add(mapped)
        selected.append(mapped)
    return selected


def _collect_selected_from_row_values(
    row_defs: List[Dict[str, Any]],
    row_values_state: List[List[str]],
) -> List[str]:
    """Return the tags checked in each row's value list, mapped onto that row's tags.

    Values that are neither a row tag nor normalize to one are ignored. Result is
    in row order, deduplicated across rows.
    """
    selected: List[str] = []
    seen: Set[str] = set()
    values = list(row_values_state or [])
    for idx, row in enumerate(row_defs or []):
        row_tags = _dedupe_norm_tags(row.get("tags", []) if isinstance(row, dict) else [])
        if not row_tags:
            continue
        row_tag_set = set(row_tags)
        row_tag_by_norm = {_norm_tag_for_lookup(t): t for t in row_tags}
        raw_vals = values[idx] if idx < len(values) else []
        for raw in (raw_vals or []):
            if raw in row_tag_set:
                if raw not in seen:
                    seen.add(raw)
                    selected.append(raw)
                continue
            mapped = row_tag_by_norm.get(_norm_tag_for_lookup(str(raw)))
            if mapped and mapped not in seen:
                seen.add(mapped)
                selected.append(mapped)
    return selected


def _build_tag_meta(
    tags: List[str],
    tag_selection_origins: Dict[str, str],
    selected_set: Set[str] | None,
) -> Dict[str, Dict[str, Any]]:
    """Build per-tag UI metadata {tag: {"origin", "preselected"}}.

    When `selected_set` is None every tag is marked preselected.
    """
    return {
        t: {
            "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
            "preselected": True if selected_set is None else t in selected_set,
        }
        for t in tags
    }


def _build_toggle_rows(
    *,
    seed_terms: List[str],
    selected_tags: List[str],
    retrieved_candidate_tags: List[str],
    tag_selection_origins: Dict[str, str],
    implied_parent_map: Dict[str, str],
    top_groups: int,
    top_tags_per_group: int,
    group_rank_top_k: int,
) -> List[Dict[str, Any]]:
    """Build the toggle-row definitions shown in the UI.

    Row order:
    1. "selected_other": selected tags not shown in any other displayed row;
    2. one row per ranked group (selected tags first, then ranked suggestions);
    3. "other_retrieved" (only if non-empty): retrieved tags in no enabled group.
    Artist and excluded tags are never shown. A row is trimmed to
    `top_tags_per_group` but never drops one of its selected tags.
    """
    per_row_limit = max(1, int(top_tags_per_group))
    ranked_rows = rank_groups_from_tfidf(
        seed_terms=seed_terms,
        top_groups=max(1, int(top_groups)),
        top_tags_per_group=per_row_limit,
        group_rank_top_k=max(1, int(group_rank_top_k)),
    )
    groups_map = _load_enabled_groups()

    selected_active = list(
        dict.fromkeys(
            _norm_tag_for_lookup(t)
            for t in selected_tags
            if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
        )
    )
    selected_active_set = set(selected_active)
    selected_index: Dict[str, int] = {t: i for i, t in enumerate(selected_active)}

    def _order_row(row_selected: List[str]) -> List[str]:
        # Same ordering inputs for every row; only the tag subset differs.
        return _order_selected_tags_for_row(
            row_selected_tags=row_selected,
            selected_index=selected_index,
            tag_selection_origins=tag_selection_origins,
            implied_parent_map=implied_parent_map,
        )

    enabled_group_tag_sets: Dict[str, Set[str]] = {
        name: {t for t in tags if not _is_artist_tag(t)} for name, tags in groups_map.items()
    }
    tags_in_any_enabled_group: Set[str] = set().union(*enabled_group_tag_sets.values())

    displayed_group_tag_sets: Dict[str, Set[str]] = {
        r.group_name: enabled_group_tag_sets.get(r.group_name, set()) for r in ranked_rows
    }
    tags_in_any_displayed_group: Set[str] = set().union(*displayed_group_tag_sets.values())

    retrieved_uncategorized_ranked = list(
        dict.fromkeys(
            _norm_tag_for_lookup(t)
            for t in (retrieved_candidate_tags or [])
            if t
            and not _is_artist_tag(t)
            and not _is_excluded_recommendation_tag(t)
            and _norm_tag_for_lookup(t) not in tags_in_any_enabled_group
        )
    )
    retrieved_other_row: Dict[str, Any] | None = None
    if retrieved_uncategorized_ranked:
        uncategorized_set = set(retrieved_uncategorized_ranked)
        selected_in_other = _order_row([t for t in selected_active if t in uncategorized_set])
        selected_in_other_set = set(selected_in_other)
        merged_other = _dedupe_norm_tags(
            selected_in_other
            + [t for t in retrieved_uncategorized_ranked if t not in selected_in_other_set]
        )
        merged_other = merged_other[: max(per_row_limit, len(selected_in_other))]
        retrieved_other_row = {
            "name": "other_retrieved",
            "label": "Other (Retrieved)",
            "tags": merged_other,
            "tag_meta": _build_tag_meta(merged_other, tag_selection_origins, selected_active_set),
        }

    # "Selected (Other)" holds selected tags not already shown in any displayed row.
    # "Other (Retrieved)" counts as displayed, to avoid duplicates across those rows.
    tags_in_displayed_rows = set(tags_in_any_displayed_group)
    if retrieved_other_row:
        tags_in_displayed_rows.update(retrieved_other_row.get("tags", []))
    selected_other = _dedupe_norm_tags(
        _order_row([t for t in selected_active if t not in tags_in_displayed_rows])
    )
    row_defs: List[Dict[str, Any]] = [
        {
            "name": "selected_other",
            "label": _display_row_label("selected_other"),
            "tags": selected_other,
            "tag_meta": _build_tag_meta(selected_other, tag_selection_origins, None),
        }
    ]

    for row in ranked_rows:
        group_name = row.group_name
        group_tag_set = displayed_group_tag_sets.get(group_name, set())
        selected_in_group = _order_row([t for t in selected_active if t in group_tag_set])
        selected_in_group_set = set(selected_in_group)
        # row.tags is iterated as pairs whose first element is the tag.
        ranked_tags = _dedupe_norm_tags(
            [
                _norm_tag_for_lookup(t)
                for t, _ in row.tags
                if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
            ]
        )
        merged = _dedupe_norm_tags(
            selected_in_group + [t for t in ranked_tags if t not in selected_in_group_set]
        )
        merged = merged[: max(per_row_limit, len(selected_in_group))]
        row_defs.append(
            {
                "name": group_name,
                "label": _display_row_label(group_name),
                "tags": merged,
                "tag_meta": _build_tag_meta(merged, tag_selection_origins, selected_active_set),
            }
        )

    # Keep this row at the bottom so category/group rows remain contiguous.
    if retrieved_other_row:
        row_defs.append(retrieved_other_row)
    return row_defs


def _build_display_audit_line(
    row_defs: List[Dict[str, Any]],
    *,
    active_selected_tags: List[str],
    direct_selected_tags: List[str],
    implied_selected_tags: List[str],
) -> str:
    """Return a one-line JSON audit of every displayed tag: its rows and source labels."""

    def _norm_non_artist(tags: List[str]) -> Set[str]:
        return {_norm_tag_for_lookup(t) for t in (tags or []) if t and not _is_artist_tag(t)}

    active_set = _norm_non_artist(active_selected_tags)
    direct_set = _norm_non_artist(direct_selected_tags)
    implied_set = _norm_non_artist(implied_selected_tags)

    info_by_tag: Dict[str, Dict[str, Any]] = {}
    for row in row_defs or []:
        row_name = row.get("name", "")
        row_label = row.get("label", row_name)
        if row_name == "selected_other":
            row_source = "selected_other_row"
        elif row_name == "other_retrieved":
            row_source = "other_retrieved_row"
        else:
            row_source = "ranked_group_row"
        for tag in row.get("tags", []):
            rec = info_by_tag.setdefault(tag, {"rows": [], "sources": set()})
            rec["rows"].append(row_label)
            rec["sources"].add(row_source)
            if tag in active_set:
                rec["sources"].add("selected_active")
            if tag in direct_set:
                rec["sources"].add("selected_direct")
            if tag in implied_set:
                rec["sources"].add("selected_implied")

    payload = {
        "n_tags": len(info_by_tag),
        "tags": [
            {
                "tag": tag,
                "rows": rec["rows"],
                "sources": sorted(rec["sources"]),
            }
            for tag, rec in sorted(info_by_tag.items())
        ],
    }
    return "Display Tag Audit: " + json.dumps(payload, ensure_ascii=True)


def _build_tooltip_payload(row_defs: List[Dict[str, Any]], max_rows: int) -> str:
    """Return JSON {"rows": [[tag, ...], ...], "tips": {tag: tooltip}} for the first max_rows rows."""
    row_defs_ui = (row_defs or [])[: max(0, int(max_rows))]
    tips: Dict[str, str] = {}
    rows: List[List[str]] = []
    for row in row_defs_ui:
        tags = _dedupe_norm_tags(row.get("tags", []) if isinstance(row, dict) else [])
        rows.append(tags)
        for t in tags:
            if t not in tips:
                tips[t] = _tooltip_text_for_tag(t)
    return json.dumps({"rows": rows, "tips": tips}, ensure_ascii=True)


def _build_row_component_updates(
    row_defs: List[Dict[str, Any]],
    selected_tags: List[str],
    max_rows: int,
):
    """Build Gradio updates for a fixed pool of `max_rows` header/checkbox components.

    Returns (prompt_text, row_values_state, header_updates, checkbox_updates).
    Components beyond the available rows are cleared and hidden; rows with no
    tags are hidden.
    """
    selected = {t for t in (selected_tags or []) if t}
    # Coerce once: callers may pass a float (e.g. from a slider) and range() needs an int.
    n_rows = max(0, int(max_rows))
    row_defs_ui = (row_defs or [])[:n_rows]
    row_values_state: List[List[str]] = []
    header_updates = []
    checkbox_updates = []
    for idx in range(n_rows):
        if idx >= len(row_defs_ui):
            header_updates.append(gr.update(value="", visible=False))
            checkbox_updates.append(gr.update(choices=[], value=[], visible=False))
            continue
        row = row_defs_ui[idx]
        tags = _dedupe_norm_tags(row.get("tags", []))
        values = [t for t in tags if t in selected]
        row_values_state.append(values)
        visible = bool(tags)
        header_updates.append(gr.update(value=row.get("label", ""), visible=visible))
        tag_meta = row.get("tag_meta", {})
        if not isinstance(tag_meta, dict):
            tag_meta = {}
        choices = []
        for t in tags:
            meta = tag_meta.get(t, {})
            if not isinstance(meta, dict):
                meta = {}
            origin = _normalize_selection_origin(str(meta.get("origin", "selection")))
            preselected = bool(meta.get("preselected", False))
            choices.append((_choice_label_with_source_meta(t, origin=origin, preselected=preselected), t))
        checkbox_updates.append(gr.update(choices=choices, value=values, visible=visible))
    prompt_text = _compose_toggle_prompt_text(list(selected), row_defs_ui)
    return prompt_text, row_values_state, header_updates, checkbox_updates


def _on_toggle_row(
    row_idx: int,
    changed_values: List[str],
    selected_tags_state: List[str],
    rows_dirty_state: bool,
    row_defs_state: List[Dict[str, Any]],
    row_values_state: List[List[str]],
    max_rows: int,
):
    """Handle a checkbox change in row `row_idx`.

    Returns [selected_tags, rows_dirty, rebuild_button_update, row_values_state,
    prompt_text, *checkbox_updates]. No-op events (the row's selection did not
    change) return skips and leave the dirty flag untouched. Otherwise every row
    that shares a toggled tag gets a value update so duplicates stay in sync.
    """
    n_rows = max(0, int(max_rows))
    row_defs_ui = (row_defs_state or [])[:n_rows]
    prev_values = list(row_values_state or [])

    selected_from_state = _collect_selected_from_state(selected_tags_state, row_defs_ui)
    selected_from_rows = _collect_selected_from_row_values(row_defs_ui, prev_values)
    # Prefer row-value state as source-of-truth (closest to visible UI), with selected-state as fallback.
    selected: Set[str] = set(selected_from_rows or selected_from_state)

    row = row_defs_ui[row_idx] if 0 <= row_idx < len(row_defs_ui) else {}
    row_tags = _dedupe_norm_tags(row.get("tags", []))
    row_tag_set = set(row_tags)
    row_tag_by_norm = {_norm_tag_for_lookup(t): t for t in row_tags}

    # Be tolerant to UI payload forms: canonical tag values, normalized variants,
    # and occasional single-string payloads from frontend events.
    if changed_values is None:
        changed_iter: List[Any] = []
    elif isinstance(changed_values, str):
        changed_iter = [changed_values]
    elif isinstance(changed_values, (list, tuple, set)):
        changed_iter = list(changed_values)
    else:
        changed_iter = [changed_values]

    new_set: Set[str] = set()
    for raw in changed_iter:
        if raw in row_tag_set:
            new_set.add(raw)
            continue
        mapped = row_tag_by_norm.get(_norm_tag_for_lookup(str(raw)))
        if mapped:
            new_set.add(mapped)

    prev_row_selected = {t for t in row_tags if t in selected}

    # Ignore non-user/no-op events (e.g., programmatic value re-sets) deterministically.
    if new_set == prev_row_selected:
        prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
        checkbox_updates = [gr.skip() for _ in range(n_rows)]
        return [sorted(selected), rows_dirty_state, gr.skip(), prev_values, prompt_text, *checkbox_updates]

    selected.difference_update(row_tag_set)
    selected.update(new_set)
    toggled_tags = prev_row_selected ^ new_set

    new_row_values_state: List[List[str]] = []
    affected_rows: Set[int] = {row_idx}
    for idx, row_item in enumerate(row_defs_ui):
        tags = _dedupe_norm_tags(row_item.get("tags", []))
        new_row_values_state.append([t for t in tags if t in selected])
        if not toggled_tags.isdisjoint(tags):
            affected_rows.add(idx)

    checkbox_updates = [
        gr.update(value=new_row_values_state[idx])
        if idx < len(row_defs_ui) and idx in affected_rows
        else gr.skip()
        for idx in range(n_rows)
    ]
    prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
    return [
        sorted(selected),
        True,
        gr.update(visible=True, interactive=True),
        new_row_values_state,
        prompt_text,
        *checkbox_updates,
    ]


def _build_ui_payload(
    *,
    console_text: str,
    row_defs: List[Dict[str, Any]],
    selected_tags: List[str],
    suggested_prompt_text: str | None = None,
):
    """Build the full output list for a completed run.

    Order: console, rows container, tooltip JSON, prompt, selected tags, dirty flag,
    rebuild button, row defs, row values, headers..., checkboxes...
    `suggested_prompt_text`, when given, overrides the composed prompt.
    """
    prompt_text, row_values_state, header_updates, checkbox_updates = _build_row_component_updates(
        row_defs=row_defs,
        selected_tags=selected_tags,
        max_rows=display_max_rows_default,
    )
    if suggested_prompt_text is not None:
        prompt_text = str(suggested_prompt_text)
    # Selected tags as actually checked in the UI, first-seen order across rows.
    selected_ui: List[str] = list(dict.fromkeys(t for vals in row_values_state for t in vals))
    tooltip_payload = _build_tooltip_payload(row_defs, display_max_rows_default)
    return [
        console_text,
        gr.update(visible=bool(row_defs)),
        tooltip_payload,
        prompt_text,
        selected_ui,
        False,
        gr.update(visible=False, interactive=False),
        row_defs,
        row_values_state,
        *header_updates,
        *checkbox_updates,
    ]


def _format_user_facing_error(exc: Exception) -> str:
    """Map an internal exception to a short, non-technical message for the UI."""
    msg_l = str(exc or "").strip().lower()
    if "rewrite: empty output" in msg_l:
        return (
            "Could not rewrite that prompt. Try simpler, neutral wording and remove sensitive phrasing, "
            "then click Run again."
        )
    if "openrouter_api_key" in msg_l:
        return "Service configuration is missing. Please contact the app owner."
    if "timed out" in msg_l:
        return "The model request timed out. Please try again with a shorter or simpler prompt."
    if "index selection failed" in msg_l:
        return "Tag selection failed for this request. Please try again."
    if "startup preflight failed" in msg_l:
        return "App startup checks failed. Please contact the app owner."
    return "Something went wrong while processing the prompt. Please try again."


def _prepare_run_ui() -> List[Any]:
    """Return the output list that resets the UI into its "running" state."""
    header_updates = [gr.update(value="", visible=False) for _ in range(display_max_rows_default)]
    checkbox_updates = [
        gr.update(choices=[], value=[], visible=False) for _ in range(display_max_rows_default)
    ]
    return [
        "Running...",
        gr.skip(),
        "{}",
        "Running... usually completes in about 20 seconds.",
        [],
        False,
        gr.update(visible=False, interactive=False),
        [],
        [],
        *header_updates,
        *checkbox_updates,
    ]


def _update_run_button_visibility(prompt_text: str, last_run_prompt: str):
    """Show the Run button only when the (stripped) prompt is non-empty and differs from the last run."""
    curr = (prompt_text or "").strip()
    last = (last_run_prompt or "").strip()
    can_run = bool(curr) and curr != last
    return gr.update(visible=can_run, interactive=can_run)


def _mark_run_triggered(prompt_text: str):
    """Hide the Run button and record the stripped prompt as the last-run prompt."""
    curr = (prompt_text or "").strip()
    return gr.update(visible=False, interactive=False), curr


def _rebuild_rows_from_selected(
    selected_tags_state: List[str],
    row_defs_state: List[Dict[str, Any]],
    row_values_state: List[List[str]],
    display_top_groups: float,
    display_top_tags_per_group: float,
    display_rank_top_k: float,
):
    """Re-rank and rebuild all rows using the currently selected tags as seeds.

    Origins are carried over from the existing rows' tag_meta; newly selected
    tags without one are treated as "user". Tags whose origin is "implied" are
    placed after the direct tag that implies them.
    """
    existing_rows = row_defs_state or []
    existing_values = list(row_values_state or [])
    selected_from_state = _collect_selected_from_state(selected_tags_state, existing_rows)
    selected_from_rows = _collect_selected_from_row_values(existing_rows, existing_values)
    # Rebuild source-of-truth is current row checkbox values; fall back only when unavailable.
    selected_seed = selected_from_rows if existing_values else selected_from_state
    selected_active = list(
        dict.fromkeys(
            _norm_tag_for_lookup(t)
            for t in selected_seed
            if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
        )
    )

    retrieved_candidate_tags: List[str] = []
    tag_selection_origins: Dict[str, str] = {}
    for row in existing_rows:
        row_tags = row.get("tags", []) if isinstance(row, dict) else []
        row_meta = row.get("tag_meta", {}) if isinstance(row, dict) else {}
        if not isinstance(row_meta, dict):
            row_meta = {}
        for t in row_tags:
            tn = _norm_tag_for_lookup(t)
            if not tn or _is_artist_tag(tn) or _is_excluded_recommendation_tag(tn):
                continue
            retrieved_candidate_tags.append(tn)
            if tn not in tag_selection_origins:
                meta = row_meta.get(t, {})
                if not isinstance(meta, dict):
                    meta = {}
                tag_selection_origins[tn] = _normalize_selection_origin(str(meta.get("origin", "selection")))
    for t in selected_active:
        tag_selection_origins.setdefault(t, "user")
        retrieved_candidate_tags.append(t)

    implied_selected_tags = [t for t in selected_active if tag_selection_origins.get(t) == "implied"]
    implied_set = set(implied_selected_tags)
    direct_selected_tags = [t for t in selected_active if t not in implied_set]
    direct_idx = {t: i for i, t in enumerate(direct_selected_tags)}
    direct_selected_tags.sort(
        key=lambda t: (
            _selection_source_rank(tag_selection_origins.get(t, "selection")),
            direct_idx.get(t, 10**9),
        )
    )
    implied_parent_map = _build_implied_parent_map(
        direct_tags_ordered=direct_selected_tags,
        implied_tags=implied_selected_tags,
    )

    toggle_rows = _build_toggle_rows(
        seed_terms=list(selected_active),
        selected_tags=selected_active,
        retrieved_candidate_tags=list(dict.fromkeys(retrieved_candidate_tags)),
        tag_selection_origins=tag_selection_origins,
        implied_parent_map=implied_parent_map,
        top_groups=max(1, int(display_top_groups)),
        top_tags_per_group=max(1, int(display_top_tags_per_group)),
        group_rank_top_k=max(1, int(display_rank_top_k)),
    )
    prompt_text, new_row_values_state, header_updates, checkbox_updates = _build_row_component_updates(
        row_defs=toggle_rows,
        selected_tags=selected_active,
        max_rows=display_max_rows_default,
    )
    tooltip_payload = _build_tooltip_payload(toggle_rows, display_max_rows_default)
    return [
        gr.update(visible=bool(toggle_rows)),
        tooltip_payload,
        prompt_text,
        sorted(selected_active),
        False,
        gr.update(visible=False, interactive=False),
        toggle_rows,
        new_row_values_state,
        *header_updates,
        *checkbox_updates,
    ]


def _build_selection_query(
    prompt_in: str,
    rewritten: str,
    structural_tags: List[str],
    probe_tags: List[str],
) -> str:
    """Build the multi-line query text sent to the selection step.

    Always includes the image description; adds the rewrite phrases and a sorted,
    deduplicated hint-tag line when present.
    """
    lines = [f"IMAGE DESCRIPTION: {(prompt_in or '').strip()}"]
    if rewritten and rewritten.strip():
        lines.append(f"REWRITE PHRASES: {rewritten.strip()}")
    hint_tags = list(structural_tags or []) + list(probe_tags or [])
    if hint_tags:
        # Keep hints as context only; selection still must choose by candidate indices.
        lines.append("INFERRED TAG HINTS (context only): " + ", ".join(sorted(set(hint_tags))))
    return "\n".join(lines)


# Set up logging.
# Minimal prod logging: warnings+ to stderr, no file by default.
LOG_LEVEL = os.environ.get("PSQ_LOG_LEVEL", "WARNING").upper()
logging.basicConfig(
    level=getattr(logging, LOG_LEVEL, logging.WARNING),
    format="%(asctime)s %(levelname)s:%(message)s",
    handlers=[logging.StreamHandler()],  # no file -> avoids huge logs on Spaces
)

# Quiet down common noisy libs (optional).
for _name in ("gensim", "gradio", "hnswlib", "httpx", "uvicorn"):
    logging.getLogger(_name).setLevel(logging.ERROR)

# Turn off Gradio analytics phone-home to avoid those background thread errors (optional).
os.environ["GRADIO_ANALYTICS_ENABLED"] = "0"

MASCOT_DIR = APP_DIR / "mascotimages"
MASCOT_FILE = MASCOT_DIR / "transparentsquirrel.png"


def _load_mascot_image():
    """Load mascot image if available; return None when missing/unreadable."""
    if not MASCOT_FILE.exists():
        logging.warning("Mascot image missing: %s", MASCOT_FILE)
        return None
    try:
        # The context manager closes the underlying file; convert() returns an
        # independent, fully loaded image.
        with Image.open(MASCOT_FILE) as img:
            return img.convert("RGBA")
    except Exception as e:
        logging.warning("Failed to load mascot image (%s): %s", MASCOT_FILE, e)
        return None


# Hotfix: make gradio_client's JSON-schema helpers tolerate non-dict schemas
# (JSON Schema boolean form True/False, or None) by treating them as "any".
try:
    from gradio_client import utils as _gc_utils

    _orig_get_type = _gc_utils.get_type
    _orig_j2p = _gc_utils._json_schema_to_python_type
    # Not used directly; looking it up makes the hotfix abort (via the except
    # below) if this gradio_client version lacks the public helper.
    _orig_pub = _gc_utils.json_schema_to_python_type

    def _get_type_safe(schema):
        # Sometimes schema is a bare True/False (JSON Schema boolean form).
        if not isinstance(schema, dict):
            return "any"
        return _orig_get_type(schema)

    def _j2p_safe(schema, defs=None):
        # Accept non-dict schemas (True/False/None) and treat as "any".
        if not isinstance(schema, dict):
            return "any"
        return _orig_j2p(schema, defs or schema.get("$defs"))

    def _pub_safe(schema):
        # Public wrapper used by Gradio; keep it resilient too.
        if not isinstance(schema, dict):
            return "any"
        return _j2p_safe(schema, schema.get("$defs"))

    _gc_utils.get_type = _get_type_safe
    _gc_utils._json_schema_to_python_type = _j2p_safe
    _gc_utils.json_schema_to_python_type = _pub_safe
except Exception as e:
    logging.warning("gradio_client hotfix not applied: %s", e)

# -------------------------------------------------------------------------------

allow_nsfw_tags = False


def _is_production_runtime() -> bool:
    """Best-effort detection for deployed runtime (HF Spaces or explicit env)."""
    if os.environ.get("PSQ_PRODUCTION", "").strip().lower() in {"1", "true", "yes"}:
        return True
    if os.environ.get("SPACE_ID"):
        return True
    if os.environ.get("HF_SPACE_ID"):
        return True
    if os.environ.get("SYSTEM") == "spaces":
        return True
    return False


verbose_retrieval_default = "0" if _is_production_runtime() else "1"
verbose_retrieval = os.environ.get("PSQ_VERBOSE_RETRIEVAL", verbose_retrieval_default).strip().lower() in {"1", "true", "yes"}
verbose_retrieval_all = False
verbose_retrieval_limit = 20
# Case-insensitive, and accepts the same false-like values as the exclusion CSV's "enabled" column.
enable_probe_tags = os.environ.get("PSQ_ENABLE_PROBE", "1").strip().lower() not in {"0", "false", "no", "off"}
display_top_groups_default = int(os.environ.get("PSQ_DISPLAY_TOP_GROUPS", "10"))
display_top_tags_per_group_default = int(os.environ.get("PSQ_DISPLAY_TOP_TAGS_PER_GROUP", "7"))
display_rank_top_k_default = int(os.environ.get("PSQ_DISPLAY_GROUP_RANK_TOP_K", "7"))
display_max_rows_default = int(os.environ.get("PSQ_DISPLAY_MAX_ROWS", "14"))
retrieval_global_k = int(os.environ.get("PSQ_RETRIEVAL_GLOBAL_K", "300"))
retrieval_per_phrase_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_K", "10"))
retrieval_per_phrase_final_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_FINAL_K", "1"))
selection_mode = os.environ.get("PSQ_SELECTION_MODE", "chunked_map_union").strip()
selection_chunk_size = int(os.environ.get("PSQ_SELECTION_CHUNK_SIZE", "60"))
selection_per_phrase_k = int(os.environ.get("PSQ_SELECTION_PER_PHRASE_K", "2"))
selection_candidate_cap = int(os.environ.get("PSQ_SELECTION_CANDIDATE_CAP", "0"))
stage1_rewrite_timeout_s = float(os.environ.get("PSQ_TIMEOUT_REWRITE_S", "45"))
stage1_struct_timeout_s = float(os.environ.get("PSQ_TIMEOUT_STRUCT_S", "45"))
stage1_probe_timeout_s = float(os.environ.get("PSQ_TIMEOUT_PROBE_S", "45"))
stage3_select_timeout_s = float(os.environ.get("PSQ_TIMEOUT_SELECT_S", "50"))
stage3_select_retry_timeout_s = float(os.environ.get("PSQ_TIMEOUT_SELECT_RETRY_S", "20"))
stage3_fast_retry_count = max(0, int(os.environ.get("PSQ_STAGE3_FAST_RETRY_COUNT", "1")))
timing_log_path = Path(os.environ.get("PSQ_TIMING_LOG_PATH", "data/runtime_metrics/ui_pipeline_timings.jsonl"))


def _startup_preflight_errors() -> List[str]:
    """Return human-readable configuration problems detected at startup."""
    errs: List[str] = []
    if not os.getenv("OPENROUTER_API_KEY"):
        errs.append("OPENROUTER_API_KEY is missing. Set it in Space Secrets or environment variables.")
    return errs


STARTUP_PREFLIGHT_ERRORS = _startup_preflight_errors()
for _err in STARTUP_PREFLIGHT_ERRORS:
    logging.error("Startup preflight error: %s", _err)


css = """
.scrollable-content{
  max-height: 420px;
  overflow-y: scroll;         /* always show scrollbar */
  overflow-x: hidden;
  padding-right: 8px;
  padding-bottom: 14px;
  scrollbar-gutter: stable;   /* prevent layout shift as it fills */
  /* Firefox */
  scrollbar-width: auto;
  scrollbar-color: rgba(180,180,180,.9) rgba(0,0,0,.15);
}
/* WebKit/Chromium (Chrome/Edge/Safari) */
.scrollable-content::-webkit-scrollbar{ width: 10px; }
.scrollable-content::-webkit-scrollbar-thumb{
  background: rgba(180,180,180,.9);
  border-radius: 8px;
}
.scrollable-content::-webkit-scrollbar-track{ background: rgba(0,0,0,.15); }

/* (Optional) make both scroll panes taller so they fill more of the column */
.pane-left .scrollable-content,
.pane-right .scrollable-content {
  max-height: 610px; /* was 420px; tweak to taste */
}

.lego-tags .gr-checkboxgroup,
.lego-tags .wrap {
  display: flex !important;
  flex-wrap: wrap !important;
  gap: 10px !important;
}
.lego-tags label {
  margin: 0 !important;
  padding: 0 !important;
  position: relative !important;
}
/* Hide native checkbox visuals completely */
.lego-tags input[type="checkbox"] {
  appearance: none !important;
  -webkit-appearance: none !important;
  -moz-appearance: none !important;
  position: absolute !important;
  width: 1px !important;
  height: 1px !important;
  opacity: 0 !important;
  pointer-events: none !important;
  display: none !important;
}
/* Brick button skin (works for both +span and ~span structures) */
.lego-tags input[type="checkbox"] + span,
.lego-tags input[type="checkbox"] ~ span {
  --on-bg1: #ffd166;
  --on-bg2: #f39c4a;
  --on-border: #b86e21;
  --on-text: #2e1706;
  position: relative !important;
  display: inline-flex !important;
  align-items: center !important;
  min-height: 40px !important;
  padding: 10px 15px 9px 22px !important;
  border: 1px solid #9aa6b8 !important;
  border-radius: 10px !important;
  background: linear-gradient(180deg, #dfe5ee 0%, #bec8d6 100%) !important;
  color: #364254 !important;
  font-size: 0.97rem !important;
  font-weight: 800 !important;
  line-height: 1.15 !important;
  cursor: pointer !important;
  user-select: none !important;
  letter-spacing: 0.01em !important;
  box-shadow: 0 3px 0 rgba(0,0,0,0.16), inset 0 1px 0 rgba(255,255,255,0.55) !important;
  transition: transform 0.08s ease, box-shadow 0.08s ease, filter 0.08s ease !important;
}
.lego-tags input[type="checkbox"] + span::before,
.lego-tags input[type="checkbox"] ~ span::before {
  content: "" !important;
  position: absolute !important;
  top: 5px !important;
  left: 8px !important;
  width: 8px !important;
  height: 8px !important;
  border-radius: 50% !important;
  background: rgba(255,255,255,0.58) !important;
  box-shadow: 22px 0 0 rgba(255,255,255,0.58) !important;
  pointer-events: none !important;
}
/* Unselected cue: show "+" on the left. */
.lego-tags input[type="checkbox"] + span::after,
.lego-tags input[type="checkbox"] ~ span::after {
  content: "+" !important;
  position: absolute !important;
  left: 6px !important;
  top: 50% !important;
  transform: translateY(-52%) !important;
  font-size: 1rem !important;
  font-weight: 900 !important;
  color: #4b5563 !important;
  opacity: 0.95 !important;
  pointer-events: none !important;
}
/* Bright color cycle used only when selected */
.lego-tags label:nth-child(8n+1) span { --on-bg1: #ffd166; --on-bg2: #f39c4a; --on-border: #b86e21; --on-text: #2e1706; }
.lego-tags label:nth-child(8n+2) span { --on-bg1: #6ee7ff; --on-bg2: #1fb7ff; --on-border: #157cb3; --on-text: #07263c; }
.lego-tags label:nth-child(8n+3) span { --on-bg1: #9dff8f; --on-bg2: #45c96f; --on-border: #2a8b4b; --on-text: #0d2917; }
.lego-tags label:nth-child(8n+4) span { --on-bg1: #ff8fab; --on-bg2: #ff5c7a; --on-border: #b83956; --on-text: #3f0f1d; }
.lego-tags label:nth-child(8n+5) span { --on-bg1: #d0a8ff; --on-bg2: #a46cff; --on-border: #7147b3; --on-text: #25143f; }
.lego-tags label:nth-child(8n+6) span { --on-bg1: #ffe27a; --on-bg2: #f7bf39; --on-border: #ad7f1f; --on-text: #332407; }
.lego-tags label:nth-child(8n+7) span { --on-bg1: #8effd5; --on-bg2: #2ed6b5; --on-border: #1e947d; --on-text: #0d2a25; }
.lego-tags label:nth-child(8n+8) span { --on-bg1: #ffb47e; --on-bg2: #ff8753; --on-border: #b95b2d; --on-text: #391a0a; }
/* Source-driven selected colors (applies when tags are preselected by the pipeline). */
.lego-tags label[data-psq-preselected="1"][data-psq-origin="rewrite"] span { --on-bg1: #77f0d7; --on-bg2: #26b9a3; --on-border: #187869; --on-text: #062923; }
.lego-tags label[data-psq-preselected="1"][data-psq-origin="selection"] span { --on-bg1: #ffd98a; --on-bg2: #f0a93c; --on-border: #a66f1f; --on-text: #382206; }
.lego-tags label[data-psq-preselected="1"][data-psq-origin="probe"] span { --on-bg1: #d8b4ff; --on-bg2: #9a6cff; --on-border: #6745b0; --on-text: #24143b; }
.lego-tags label[data-psq-preselected="1"][data-psq-origin="structural"] span { --on-bg1: #a6f79a; --on-bg2: #53c368; --on-border: #2f8442; --on-text: #102d17; }
.lego-tags label[data-psq-preselected="1"][data-psq-origin="implied"] span { --on-bg1: #d7dde8; --on-bg2: #a8b3c4; --on-border: #6f7e95; --on-text: #1d2633; }
/* User-selected tags (not initially selected by the pipeline). */
.lego-tags label[data-psq-preselected="0"] span { --on-bg1: #9ec5ff; --on-bg2: #4f86ff; --on-border: #2f5fbf; --on-text: #0b1f42; }
.lego-tags label:hover span {
  filter: brightness(1.02) !important;
  transform: translateY(1px) !important;
}
/* ON state: brighter + visibly recessed */
.lego-tags input[type="checkbox"]:checked + span,
.lego-tags input[type="checkbox"]:checked ~ span,
.lego-tags label:has(input[type="checkbox"]:checked) span {
  background: linear-gradient(180deg, var(--on-bg1) 0%, var(--on-bg2) 100%) !important;
  color: var(--on-text) !important;
  border-color: var(--on-border) !important;
  filter: saturate(1.2) brightness(1.12) !important;
  transform: translateY(-2px) !important;
  box-shadow: inset 0 3px 6px rgba(0,0,0,0.20), inset 0 -1px 0 rgba(255,255,255,0.36), 0 6px 0 rgba(0,0,0,0.32) !important;
}
.lego-tags input[type="checkbox"]:checked + span::after,
.lego-tags input[type="checkbox"]:checked ~ span::after,
.lego-tags label:has(input[type="checkbox"]:checked) span::after {
  content: "" !important;
}
.source-legend {
  display: flex;
  flex-wrap: wrap;
  align-items: center;
  gap: 8px;
  margin: 4px 0 10px 0;
}
.source-legend .legend-title {
  font-size: 0.92rem;
  font-weight: 900;
  color: #334155;
  margin-right: 4px;
}
.source-legend .chip {
  display: inline-flex;
  align-items: center;
  border-radius: 10px;
  border: 1px solid #6c7788;
  padding: 6px 12px;
  font-size: 0.85rem;
  font-weight: 800;
  color: #111827;
  background: #f3f6fb;
}
.source-legend .chip.rewrite { background: #26b9a3; color: #062923; border-color: #187869; }
.source-legend .chip.selection { background: #f0a93c; color: #382206; border-color: #a66f1f; }
.source-legend .chip.probe { background: #9a6cff; color: #ffffff; border-color: #6745b0; }
.source-legend .chip.structural { background: #53c368; color: #102d17; border-color: #2f8442; }
.source-legend .chip.implied { background: #a8b3c4; color: #1d2633; border-color: #6f7e95; }
.source-legend .chip.user { background: #4f86ff; color: #ffffff; border-color: #2f5fbf; }
.source-legend .chip.unselected { background: #c7ced8; color: #2d3440; border-color: #7d8897; }
.row-heading p {
  margin: 8px 0 0 0 !important;
  font-size: 1.18rem !important;
  font-weight: 850 !important;
  line-height: 1.2 !important;
}
.row-instruction { text-align: center; margin: 8px 0 12px 0; }
.row-instruction p {
  margin: 0 !important;
  font-size: 1.02rem !important;
  font-style: italic !important;
  font-weight: 800 !important;
  color: #1d4ed8 !important;
}
.about-docs { margin-top: 4px; }
.about-docs > p { line-height: 1.42 !important; }
.about-docs img {
  max-width: 100% !important;
  height: auto !important;
  border: 1px solid #d2d7e0;
  border-radius: 10px;
  background: #ffffff;
}
.arch-diagram-wrap { margin: 6px 0 10px 0; }
.arch-diagram-wrap h2 { margin: 0 0 8px 0 !important; }
.top-instruction { text-align: center; margin: 2px 0 6px 0; }
.top-instruction p {
  margin: 0 !important;
  font-size: 1.02rem !important;
  font-style: italic !important;
  font-weight: 800 !important;
  color: #1d4ed8 !important;
}
.run-hint { margin-top: 6px; text-align: center; }
.run-hint p {
  margin: 0 !important;
  font-size: 0.9rem !important;
  font-style: italic !important;
  color: #475569 !important;
}
.prompt-card {
  background: transparent !important;
  border: none !important;
  box-shadow: none !important;
  padding: 0 !important;
}
.suggested-prompt-box { margin-top: 2px !important; }
.suggested-prompt-card { margin-top: 10px !important; }
.psq-hidden { display: none !important; }
"""

# Client-side script: reads the JSON tooltip payload from the hidden
# #psq-tooltip-map textbox ({"rows": [[tag, ...], ...], "tips": {tag: text}})
# and applies title attributes to the tag buttons. Re-applied (at most once per
# animation frame) whenever the DOM subtree changes.
client_js = """
() => {
  const readTooltipMap = () => {
    const el = document.querySelector("#psq-tooltip-map textarea, #psq-tooltip-map input");
    if (!el) return { rows: [], tips: {} };
    const raw = (el.value || "").trim();
    if (!raw) return { rows: [], tips: {} };
    try {
      const obj = JSON.parse(raw);
      if (!obj || typeof obj !== "object") return { rows: [], tips: {} };
      const rows = Array.isArray(obj.rows) ? obj.rows : [];
      const tips = (obj.tips && typeof obj.tips === "object") ? obj.tips : {};
      return { rows, tips };
    } catch (_) {
      return { rows: [], tips: {} };
    }
  };

  const applyTooltips = () => {
    const payload = readTooltipMap();
    const rowTags = Array.isArray(payload.rows) ? payload.rows : [];
    const tipMap = (payload.tips && typeof payload.tips === "object") ? payload.tips : {};
    const rowEls = document.querySelectorAll(".lego-tags");
    rowEls.forEach((rowEl, rowIdx) => {
      const tags = Array.isArray(rowTags[rowIdx]) ? rowTags[rowIdx] : [];
      const labels = rowEl.querySelectorAll("label");
      labels.forEach((label, tagIdx) => {
        const span = label.querySelector("span");
        const tag = (tagIdx < tags.length) ? tags[tagIdx] : "";
        const tip = tag && Object.prototype.hasOwnProperty.call(tipMap, tag) ? (tipMap[tag] || "") : "";
        if (tip) {
          label.title = tip;
          if (span) span.title = tip;
        } else {
          label.removeAttribute("title");
          if (span) span.removeAttribute("title");
        }
      });
    });
  };

  let scheduled = false;
  const scheduleApply = () => {
    if (scheduled) return;
    scheduled = true;
    requestAnimationFrame(() => {
      scheduled = false;
      applyTooltips();
    });
  };

  scheduleApply();
  const observer = new MutationObserver(() => scheduleApply());
  observer.observe(document.body, { childList: true, subtree: true });
}
"""


def rag_pipeline_ui(
    user_prompt: str,
    display_top_groups: float,
    display_top_tags_per_group: float,
    display_rank_top_k: float,
):
    """Run the Prompt Squirrel pipeline for one prompt and build the UI outputs.

    Steps: user-tag extraction, concurrent LLM rewrite / structural / probe
    inference, candidate retrieval, LLM index selection (with fast retries),
    implication expansion, prompt composition and ranked row display.

    Args:
        user_prompt: Free-text image description from the prompt box.
        display_top_groups: Number of ranked rows to build (Gradio Number; may
            be None if the field was cleared, in which case the default is used).
        display_top_tags_per_group: Tags shown per row (same None handling).
        display_rank_top_k: Tags used for row ranking (same None handling).

    Returns:
        The list built by ``_build_ui_payload``. Exceptions are not propagated:
        they are logged to the console text and summarized via
        ``_format_user_facing_error`` in the suggested-prompt box.
    """
    logs: List[str] = []

    def log(s):
        logs.append(s)

    try:
        stage_timings: Dict[str, float] = {}
        # Maps _future_with_timeout stage labels to timing-summary keys.
        stage_timing_keys = {
            "Rewrite": "rewrite",
            "Structural inference": "structural",
            "Probe inference": "probe",
            "Index selection": "selection",
        }

        def _record_timing(stage: str, dt_s: float):
            stage_timings[stage] = float(dt_s)

        def _emit_timing_summary(total_s: float):
            summary_order = [
                "preprocess",
                "rewrite",
                "structural",
                "probe",
                "retrieval",
                "selection",
                "implication_expansion",
                "prompt_composition",
                "group_display",
            ]
            lines = [f"{k}={stage_timings[k]:.2f}s" for k in summary_order if k in stage_timings]
            slowest = max(stage_timings.items(), key=lambda kv: kv[1])[0] if stage_timings else "n/a"
            log("Timing Summary: " + ", ".join(lines))
            log(f"Timing Slowest Stage: {slowest}")
            log(f"Timing Total: {total_s:.2f}s")

        def _append_timing_jsonl(total_s: float):
            # Best-effort metrics append; failures are logged, never raised.
            try:
                timing_log_path.parent.mkdir(parents=True, exist_ok=True)
                rec = {
                    # Same "YYYY-MM-DDTHH:MM:SSZ" format as before, without the
                    # deprecated naive datetime.utcnow().
                    "timestamp_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
                    "stages_s": stage_timings,
                    "total_s": float(total_s),
                    "config": {
                        "timeout_rewrite_s": stage1_rewrite_timeout_s,
                        "timeout_struct_s": stage1_struct_timeout_s,
                        "timeout_probe_s": stage1_probe_timeout_s,
                        "timeout_select_s": stage3_select_timeout_s,
                    },
                }
                with timing_log_path.open("a", encoding="utf-8") as f:
                    f.write(json.dumps(rec, ensure_ascii=True) + "\n")
                log(f"Timing Log: wrote {timing_log_path}")
            except Exception as e:
                log(f"Timing Log: failed ({type(e).__name__}: {e})")

        def _future_with_timeout(
            fut,
            timeout_s: float,
            stage_name: str,
            fallback,
            *,
            strict: bool = False,
        ):
            """Wait for *fut*; on timeout/failure return *fallback* or, if strict, raise RuntimeError."""
            effective_timeout_s = max(1.0, float(timeout_s))
            t0 = time.perf_counter()
            try:
                out = fut.result(timeout=effective_timeout_s)
            except FutureTimeoutError as e:
                fut.cancel()
                msg = f"{stage_name}: timed out after {effective_timeout_s:.0f}s"
                if strict:
                    raise RuntimeError(msg) from e
                log(f"{msg}; using fallback")
                return fallback
            except Exception as e:
                msg = f"{stage_name}: failed ({type(e).__name__}: {e})"
                if strict:
                    raise RuntimeError(msg) from e
                log(f"{msg}; using fallback")
                return fallback
            dt = time.perf_counter() - t0
            log(f"{stage_name}: {dt:.2f}s")
            stage_key = stage_timing_keys.get(stage_name)
            if stage_key:
                _record_timing(stage_key, dt)
            return out

        def _as_positive_int(value, default: int) -> int:
            """Coerce a Gradio Number value to an int >= 1, falling back to *default*.

            A cleared Number field arrives as None; previously int(None) raised
            only at the display step, discarding the finished LLM work.
            """
            try:
                return max(1, int(value))
            except (TypeError, ValueError, OverflowError):
                return max(1, int(default))

        def _fmt_score(v):
            # Numbers get 3 decimals; anything else (None, strings) is shown as-is.
            return f"{v:.3f}" if isinstance(v, (int, float)) else v

        t_total0 = time.perf_counter()
        log("Start: received prompt")

        if STARTUP_PREFLIGHT_ERRORS:
            log("Startup preflight failed:")
            for e in STARTUP_PREFLIGHT_ERRORS:
                log(f"- {e}")
            return _build_ui_payload(
                console_text="\n".join(logs),
                row_defs=[],
                selected_tags=[],
                suggested_prompt_text="Error: startup preflight failed. Check console details.",
            )

        prompt_in = (user_prompt or "").strip()
        if not prompt_in:
            return _build_ui_payload(
                console_text="Error: empty prompt",
                row_defs=[],
                selected_tags=[],
                suggested_prompt_text='Enter a prompt and click "Run".',
            )

        log("Input:")
        log(prompt_in)
        log("")

        min_tag_count = _get_min_tag_count()
        log(
            "Runtime config: "
            f"retrieval_global_k={retrieval_global_k} "
            f"retrieval_per_phrase_k={retrieval_per_phrase_k} "
            f"retrieval_per_phrase_final_k={retrieval_per_phrase_final_k} "
            f"selection_mode={selection_mode} "
            f"selection_chunk_size={selection_chunk_size} "
            f"selection_per_phrase_k={selection_per_phrase_k} "
            f"min_tag_count={min_tag_count} "
            f"select_timeout_s={stage3_select_timeout_s:.0f} "
            f"select_retry_timeout_s={stage3_select_retry_timeout_s:.0f} "
            f"select_fast_retries={stage3_fast_retry_count}"
        )
        log("")

        # --- Preprocess: heuristic user tag extraction -----------------------
        t0 = time.perf_counter()
        user_tags_raw = extract_user_provided_tags_upto_3_words(prompt_in)
        user_tags, removed_user_low = _filter_min_count_tags(user_tags_raw, min_tag_count)
        user_tags, removed_user_excluded = _filter_excluded_recommendation_tags(user_tags)
        dt = time.perf_counter() - t0
        _record_timing("preprocess", dt)
        log(f"Preprocess (user tag extraction): {dt:.2f}s")

        log("Heuristically extracted user tags:")
        log(", ".join(user_tags) if user_tags else "(none)")
        if removed_user_low:
            log(
                f"Filtered {len(removed_user_low)} low-frequency user tags "
                f"(<{min_tag_count}): {', '.join(removed_user_low)}"
            )
        if removed_user_excluded:
            log(
                f"Filtered {len(removed_user_excluded)} excluded user tags: "
                f"{', '.join(removed_user_excluded)}"
            )
        log("")

        # --- Step 1: concurrent LLM calls ------------------------------------
        log("Step 1: LLM rewrite + structural inference + probe (concurrent)")
        max_workers = 3 if enable_probe_tags else 2
        ex = ThreadPoolExecutor(max_workers=max_workers)
        try:
            fut_rewrite = ex.submit(llm_rewrite_prompt, prompt_in, log)
            fut_struct = ex.submit(llm_infer_structural_tags, prompt_in, log=log)
            fut_probe = ex.submit(llm_infer_probe_tags, prompt_in, log=log) if enable_probe_tags else None

            # Rewrite is required (strict); structural/probe degrade to [].
            rewritten = _future_with_timeout(
                fut_rewrite,
                stage1_rewrite_timeout_s,
                "Rewrite",
                "",
                strict=True,
            )
            structural_tags = _future_with_timeout(
                fut_struct, stage1_struct_timeout_s, "Structural inference", []
            )
            probe_tags = (
                _future_with_timeout(fut_probe, stage1_probe_timeout_s, "Probe inference", [])
                if fut_probe
                else []
            )
        finally:
            # Don't block the UI on abandoned (timed-out) worker threads.
            ex.shutdown(wait=False, cancel_futures=True)

        structural_tags, removed_struct_low = _filter_min_count_tags(structural_tags, min_tag_count)
        probe_tags, removed_probe_low = _filter_min_count_tags(probe_tags, min_tag_count)
        structural_tags, removed_struct_excluded = _filter_excluded_recommendation_tags(structural_tags)
        probe_tags, removed_probe_excluded = _filter_excluded_recommendation_tags(probe_tags)
        if removed_struct_low:
            log(
                f"Filtered {len(removed_struct_low)} low-frequency structural tags "
                f"(<{min_tag_count}): {', '.join(removed_struct_low)}"
            )
        if removed_probe_low:
            log(
                f"Filtered {len(removed_probe_low)} low-frequency probe tags "
                f"(<{min_tag_count}): {', '.join(removed_probe_low)}"
            )
        if removed_struct_excluded:
            log(
                f"Filtered {len(removed_struct_excluded)} excluded structural tags: "
                f"{', '.join(removed_struct_excluded)}"
            )
        if removed_probe_excluded:
            log(
                f"Filtered {len(removed_probe_excluded)} excluded probe tags: "
                f"{', '.join(removed_probe_excluded)}"
            )

        if not rewritten:
            raise RuntimeError("Rewrite: empty output")

        log("Rewrite:")
        log(rewritten)
        log("")

        rewrite_for_retrieval = rewritten
        if user_tags:
            # keep them separate in logs, but allow them to help retrieval
            rewrite_for_retrieval = (rewrite_for_retrieval + ", " + ", ".join(user_tags)).strip(", ").strip()

        # --- Step 2: retrieval -----------------------------------------------
        log("Step 2: Prompt Squirrel retrieval (hidden)")
        try:
            t0 = time.perf_counter()
            retrieval_context_tags = list(dict.fromkeys((structural_tags or []) + (probe_tags or [])))
            rewrite_phrases = _split_prompt_commas(rewrite_for_retrieval)
            retrieval_result = psq_candidates_from_rewrite_phrases(
                rewrite_phrases=rewrite_phrases,
                allow_nsfw_tags=allow_nsfw_tags,
                context_tags=retrieval_context_tags,
                global_k=max(1, retrieval_global_k),
                per_phrase_k=max(1, retrieval_per_phrase_k),
                per_phrase_final_k=max(1, retrieval_per_phrase_final_k),
                min_tag_count=max(0, min_tag_count),
                verbose=verbose_retrieval,
            )
            # The retrieval helper may return (candidates, phrase_reports) or
            # just candidates; normalize both shapes.
            if isinstance(retrieval_result, tuple):
                candidates, phrase_reports = retrieval_result
            else:
                candidates, phrase_reports = retrieval_result, []

            candidates, removed_candidate_excluded = _filter_excluded_candidates(candidates)
            if removed_candidate_excluded:
                log(
                    f"Filtered {len(removed_candidate_excluded)} excluded retrieved tags: "
                    f"{', '.join(removed_candidate_excluded[:25])}"
                    + (" ..." if len(removed_candidate_excluded) > 25 else "")
                )
            if selection_candidate_cap > 0 and len(candidates) > selection_candidate_cap:
                candidates = candidates[:selection_candidate_cap]
                log(f"Selection candidate cap applied: {selection_candidate_cap}")

            dt = time.perf_counter() - t0
            _record_timing("retrieval", dt)
            log(f"Retrieval: {dt:.2f}s")
            log(f"Retrieved {len(candidates)} candidate tags")

            if verbose_retrieval:
                log(f"Total unique candidates: {len(candidates)}")
                limit = None if verbose_retrieval_all else max(1, int(verbose_retrieval_limit))
                for report in phrase_reports:
                    phrase = report.get("normalized") or report.get("phrase") or ""
                    lookup = report.get("lookup") or ""
                    tfidf_vocab = report.get("tfidf_vocab")
                    log(f"Phrase: {phrase} (lookup={lookup}) tfidf_vocab={tfidf_vocab}")
                    rows = report.get("candidates", [])
                    shown = rows if limit is None else rows[:limit]
                    for row in shown:
                        tag = row.get("tag")
                        alias_token = row.get("alias_token")
                        alias_part = ""
                        if alias_token and alias_token != tag:
                            alias_part = f" [alias_token={alias_token}]"
                        log(
                            f" {tag}{alias_part} | fasttext={_fmt_score(row.get('score_fasttext'))} "
                            f"context={_fmt_score(row.get('score_context'))} "
                            f"combined={_fmt_score(row.get('score_combined'))} count={row.get('count')}"
                        )
                    if limit is not None and len(rows) > limit:
                        log(f" ... ({len(rows) - limit} more)")
        except Exception as e:
            log(f"Retrieval fallback: {type(e).__name__}: {e}")
            candidates = []

        retrieved_candidate_tags = list(
            dict.fromkeys(
                _norm_tag_for_lookup(c.tag)
                for c in (candidates or [])
                if getattr(c, "tag", None)
            )
        )

        # --- Step 3: LLM index selection (with fast retries) -----------------
        log("Step 3: LLM index selection (uses rewrite + structural/probe context)")
        selection_query = _build_selection_query(
            prompt_in=prompt_in,
            rewritten=rewritten,
            structural_tags=structural_tags,
            probe_tags=probe_tags,
        )
        picked_indices = None
        last_stage3_error: Exception | None = None
        stage3_attempts = 1 + int(stage3_fast_retry_count)
        for attempt_i in range(stage3_attempts):
            timeout_s = stage3_select_timeout_s if attempt_i == 0 else stage3_select_retry_timeout_s
            if attempt_i > 0:
                log(
                    f"Index selection: fast retry {attempt_i}/{stage3_fast_retry_count} "
                    f"(timeout={timeout_s:.0f}s)"
                )
            ex = ThreadPoolExecutor(max_workers=1)
            try:
                fut_sel = ex.submit(
                    llm_select_indices,
                    query_text=selection_query,
                    candidates=candidates,
                    max_pick=0,
                    log=log,
                    mode=selection_mode,
                    chunk_size=max(1, selection_chunk_size),
                    per_phrase_k=max(1, selection_per_phrase_k),
                )
                picked_indices = _future_with_timeout(
                    fut_sel,
                    timeout_s,
                    "Index selection",
                    [],
                    strict=True,
                )
                last_stage3_error = None
                break
            except Exception as e:
                last_stage3_error = e
                log(f"Index selection attempt {attempt_i + 1} failed: {e}")
            finally:
                ex.shutdown(wait=False, cancel_futures=True)
        if picked_indices is None:
            raise RuntimeError(
                f"Index selection failed after {stage3_attempts} attempt(s): {last_stage3_error}"
            ) from last_stage3_error

        # The indices come from an LLM: drop anything outside the candidate
        # list (a negative index would otherwise silently pick the wrong tag,
        # and a too-large one would abort the whole run).
        picked_list = list(picked_indices or [])
        valid_indices = [i for i in picked_list if 0 <= i < len(candidates)]
        if len(valid_indices) != len(picked_list):
            log(f" Ignored {len(picked_list) - len(valid_indices)} out-of-range selection indices")
        selection_selected_tags = [candidates[i].tag for i in valid_indices]
        selection_selected_tags, removed_stage3_low = _filter_min_count_tags(selection_selected_tags, min_tag_count)
        if removed_stage3_low:
            log(
                f" Filtered {len(removed_stage3_low)} low-frequency stage3 tags "
                f"(<{min_tag_count}): {', '.join(removed_stage3_low)}"
            )
        selected_tags = list(selection_selected_tags)

        if structural_tags:
            # Add structural tags that aren't already selected
            existing = set(selected_tags)
            new_structural = [t for t in structural_tags if t not in existing]
            selected_tags.extend(new_structural)
            log(f" Added {len(new_structural)} structural tags: {', '.join(new_structural)}")
        else:
            log(" No structural tags inferred")

        if probe_tags:
            existing = set(selected_tags)
            new_probe = [t for t in probe_tags if t not in existing]
            selected_tags.extend(new_probe)
            log(f" Added {len(new_probe)} probe tags: {', '.join(new_probe)}")
        elif enable_probe_tags:
            log(" No probe tags inferred")

        selected_tags, removed_excluded_direct = _filter_excluded_recommendation_tags(selected_tags)
        if removed_excluded_direct:
            log(f" Removed {len(removed_excluded_direct)} excluded tags: {', '.join(removed_excluded_direct)}")
        direct_selected_tags = list(dict.fromkeys(selected_tags))

        # --- Step 3c: implication expansion ----------------------------------
        log("Step 3c: Expand via tag implications")
        t0 = time.perf_counter()
        tag_set = set(selected_tags)
        expanded, implied_only = expand_tags_via_implications(tag_set)
        dt = time.perf_counter() - t0
        _record_timing("implication_expansion", dt)
        log(f"Implication expansion: {dt:.2f}s")
        implied_selected_tags = sorted(implied_only) if implied_only else []
        if implied_only:
            implied_added = sorted(implied_only)
            implied_added, removed_implied_low = _filter_min_count_tags(implied_added, min_tag_count)
            implied_selected_tags = list(implied_added)
            if implied_added:
                selected_tags.extend(implied_added)
                log(f" Added {len(implied_added)} implied tags: {', '.join(implied_added)}")
            if removed_implied_low:
                log(
                    f" Filtered {len(removed_implied_low)} low-frequency implied tags "
                    f"(<{min_tag_count}): {', '.join(removed_implied_low)}"
                )
        else:
            log(" No additional implied tags")

        selected_tags, removed_excluded_implied = _filter_excluded_recommendation_tags(selected_tags)
        implied_selected_tags = [
            t for t in implied_selected_tags if not _is_excluded_recommendation_tag(t)
        ]
        if removed_excluded_implied:
            log(
                f" Removed {len(removed_excluded_implied)} excluded tags after implications: "
                f"{', '.join(removed_excluded_implied)}"
            )

        # --- Step 4: prompt composition --------------------------------------
        log("Step 4: Compose final prompt")
        t0 = time.perf_counter()
        final_prompt = compose_final_prompt(rewritten, selected_tags)
        dt = time.perf_counter() - t0
        _record_timing("prompt_composition", dt)
        log(f"Prompt composition: {dt:.2f}s")

        # --- Step 5: ranked group/category display ---------------------------
        log("Step 5: Build ranked group/category display")
        t0 = time.perf_counter()
        rewrite_phrase_list = _split_prompt_commas(rewritten)
        seed_terms = list(
            dict.fromkeys(
                [
                    *user_tags,
                    *rewrite_phrase_list,
                    *(structural_tags or []),
                    *(probe_tags or []),
                    *selected_tags,
                ]
            )
        )
        active_selected_tags = list(dict.fromkeys(selected_tags))

        structural_set = {_norm_tag_for_lookup(t) for t in (structural_tags or []) if t}
        probe_set = {_norm_tag_for_lookup(t) for t in (probe_tags or []) if t}
        implied_set = {_norm_tag_for_lookup(t) for t in (implied_selected_tags or []) if t}
        rewrite_set = {
            _norm_tag_for_lookup(t)
            for t in (list(user_tags or []) + rewrite_phrase_list)
            if t
        }
        selection_set = {_norm_tag_for_lookup(t) for t in (selection_selected_tags or []) if t}

        # Attribute each active tag to the pipeline stage that produced it
        # (first match wins, in this precedence order); drives button colors.
        tag_selection_origins: Dict[str, str] = {}
        for tag in active_selected_tags:
            tag_norm = _norm_tag_for_lookup(tag)
            if tag_norm in structural_set:
                origin = "structural"
            elif tag_norm in probe_set:
                origin = "probe"
            elif tag_norm in rewrite_set:
                origin = "rewrite"
            elif tag_norm in selection_set:
                origin = "selection"
            elif tag_norm in implied_set:
                origin = "implied"
            else:
                # Unknown/fallback tags use selection color.
                origin = "selection"
            tag_selection_origins[tag] = origin
            if tag_norm and tag_norm != tag:
                tag_selection_origins[tag_norm] = origin

        direct_tags_for_implied = list(
            dict.fromkeys(_norm_tag_for_lookup(t) for t in (direct_selected_tags or []) if t)
        )
        direct_tags_for_implied_idx = {t: i for i, t in enumerate(direct_tags_for_implied)}
        direct_tags_for_implied.sort(
            key=lambda t: (
                _selection_source_rank(tag_selection_origins.get(t, "selection")),
                direct_tags_for_implied_idx.get(t, 10**9),
            )
        )
        implied_parent_map = _build_implied_parent_map(
            direct_tags_ordered=direct_tags_for_implied,
            implied_tags=implied_selected_tags,
        )

        toggle_rows = _build_toggle_rows(
            seed_terms=seed_terms,
            selected_tags=active_selected_tags,
            retrieved_candidate_tags=retrieved_candidate_tags,
            tag_selection_origins=tag_selection_origins,
            implied_parent_map=implied_parent_map,
            top_groups=_as_positive_int(display_top_groups, display_top_groups_default),
            top_tags_per_group=_as_positive_int(display_top_tags_per_group, display_top_tags_per_group_default),
            group_rank_top_k=_as_positive_int(display_rank_top_k, display_rank_top_k_default),
        )
        dt = time.perf_counter() - t0
        _record_timing("group_display", dt)
        log(f"Ranked group display: {dt:.2f}s ({len(toggle_rows)} rows)")
        log(
            _build_display_audit_line(
                toggle_rows,
                active_selected_tags=active_selected_tags,
                direct_selected_tags=direct_selected_tags,
                implied_selected_tags=implied_selected_tags,
            )
        )
        for idx, row in enumerate(toggle_rows[: max(0, int(display_max_rows_default))]):
            tags_preview = ", ".join(row.get("tags", []))
            log(f"UI Row {idx}: {row.get('label', '')} :: {tags_preview}")

        total_dt = time.perf_counter() - t_total0
        _emit_timing_summary(total_dt)
        _append_timing_jsonl(total_dt)
        log("Done: final prompt ready")

        return _build_ui_payload(
            console_text="\n".join(logs),
            row_defs=toggle_rows,
            selected_tags=active_selected_tags,
        )

    except Exception as e:
        log(f"Error: {type(e).__name__}: {e}")
        return _build_ui_payload(
            console_text="\n".join(logs),
            row_defs=[],
            selected_tags=[],
            suggested_prompt_text=_format_user_facing_error(e),
        )


# Color legend for tag-source chips; classes match the .source-legend rules in css.
_SOURCE_LEGEND_HTML = """
<div class="source-legend">
  <span class="legend-title">Legend:</span>
  <span class="chip rewrite">Rewrite phrase</span>
  <span class="chip selection">General selection</span>
  <span class="chip probe">Probe query</span>
  <span class="chip structural">Structural query</span>
  <span class="chip implied">Implied</span>
  <span class="chip user">User-toggled</span>
  <span class="chip unselected">Unselected</span>
</div>
"""


with gr.Blocks(css=css, js=client_js) as app:
    with gr.Row():
        with gr.Column(scale=3, elem_classes=["prompt-col"]):
            gr.Markdown(
                'Describe your image under "Enter Prompt" and click "Run". '
                'Prompt Squirrel will translate it into image board tags.',
                elem_classes=["top-instruction"],
            )
            with gr.Group(elem_classes=["prompt-card"]):
                image_tags = gr.Textbox(
                    label="Enter Prompt",
                    placeholder="e.g. fox, outside, detailed background, .",
                    lines=1,
                    elem_classes=["enter-prompt-box"],
                )
            with gr.Group(elem_classes=["prompt-card", "suggested-prompt-card"]):
                suggested_prompt = gr.Textbox(
                    label="Suggested Prompt (Read-only)",
                    lines=2,
                    interactive=False,
                    show_copy_button=True,
                    placeholder='Suggested prompt will appear here after you click "Run".',
                    elem_classes=["suggested-prompt-box"],
                )
        with gr.Column(scale=1):
            _mascot_pil = _load_mascot_image()
            if _mascot_pil is not None:
                mascot_img = gr.Image(
                    value=_mascot_pil,
                    show_label=False,
                    interactive=False,
                    height=240,
                    elem_id="mascot",
                )
            else:
                mascot_img = gr.Markdown("`(mascot image unavailable)`")
            submit_button = gr.Button("Run", variant="primary", visible=False, interactive=False)
            gr.Markdown("Typical runtime: up to ~20 seconds.", elem_classes=["run-hint"])

    last_run_prompt_state = gr.State("")
    selected_tags_state = gr.State([])
    rows_dirty_state = gr.State(False)
    row_defs_state = gr.State([])
    row_values_state = gr.State([])

    toggle_instruction = gr.Markdown(
        "Click tag buttons to add or remove tags from the suggested prompt.",
        elem_classes=["row-instruction"],
        visible=False,
    )

    # Pre-allocate a fixed pool of (header, checkbox-group) rows; the pipeline
    # toggles their visibility/content instead of creating components.
    row_headers: List[gr.Markdown] = []
    row_checkboxes: List[gr.CheckboxGroup] = []
    for _ in range(display_max_rows_default):
        with gr.Row():
            with gr.Column(scale=2, min_width=170):
                row_headers.append(gr.Markdown(value="", visible=False, elem_classes=["row-heading"]))
            with gr.Column(scale=10):
                row_checkboxes.append(
                    gr.CheckboxGroup(
                        choices=[],
                        value=[],
                        visible=False,
                        interactive=True,
                        container=False,
                        elem_classes=["lego-tags"],
                    )
                )

    with gr.Row():
        with gr.Column(scale=10):
            gr.HTML(_SOURCE_LEGEND_HTML)
        with gr.Column(scale=2, min_width=180):
            rebuild_rows_button = gr.Button(
                "Rebuild Rows",
                variant="primary",
                visible=False,
                interactive=False,
            )

    with gr.Accordion("Display Settings", open=False):
        with gr.Row():
            display_top_groups = gr.Number(
                value=display_top_groups_default,
                precision=0,
                label="Rows (Top Groups/Categories)",
                minimum=1,
            )
            display_top_tags_per_group = gr.Number(
                value=display_top_tags_per_group_default,
                precision=0,
                label="Top Tags Shown Per Row",
                minimum=1,
            )
            display_rank_top_k = gr.Number(
                value=display_rank_top_k_default,
                precision=0,
                label="Top Tags Used for Row Ranking",
                minimum=1,
            )

    with gr.Accordion("Console", open=False):
        console = gr.Textbox(
            label="Console",
            lines=10,
            interactive=False,
            placeholder="Progress logs will appear here.",
        )

    with gr.Accordion("How Prompt Squirrel Works", open=False):
        _about_md = _load_about_docs_markdown()
        _about_before, _about_after, _has_arch_slot = _split_about_docs_for_diagram(_about_md)
        if _has_arch_slot:
            if _about_before:
                gr.Markdown(
                    _about_before,
                    elem_id="about-docs",
                    elem_classes=["about-docs"],
                )
            gr.HTML(
                _build_arch_diagram_html(),
                elem_classes=["about-docs"],
            )
            if _about_after:
                gr.Markdown(
                    _about_after,
                    elem_classes=["about-docs"],
                )
        else:
            gr.Markdown(
                _about_md,
                elem_id="about-docs",
                elem_classes=["about-docs"],
            )

    # Hidden JSON carrier read by client_js to attach per-tag tooltips.
    tooltip_map_payload = gr.Textbox(
        value="{}",
        visible=True,
        interactive=False,
        container=False,
        elem_id="psq-tooltip-map",
        elem_classes=["psq-hidden"],
    )

    run_outputs = [
        console,
        toggle_instruction,
        tooltip_map_payload,
        suggested_prompt,
        selected_tags_state,
        rows_dirty_state,
        rebuild_rows_button,
        row_defs_state,
        row_values_state,
        *row_headers,
        *row_checkboxes,
    ]

    image_tags.change(
        _update_run_button_visibility,
        inputs=[image_tags, last_run_prompt_state],
        outputs=[submit_button],
        queue=False,
        show_progress="hidden",
    )

    # Clicking "Run" and pressing Enter in the prompt box share one chain:
    # mark the run, reset the UI, then run the (queued) pipeline.
    for _run_trigger in (submit_button.click, image_tags.submit):
        _run_trigger(
            _mark_run_triggered,
            inputs=[image_tags],
            outputs=[submit_button, last_run_prompt_state],
            queue=False,
            show_progress="hidden",
        ).then(
            _prepare_run_ui,
            inputs=[],
            outputs=run_outputs,
            queue=False,
            show_progress="hidden",
        ).then(
            rag_pipeline_ui,
            inputs=[image_tags, display_top_groups, display_top_tags_per_group, display_rank_top_k],
            outputs=run_outputs,
        )

    for idx, row_cb in enumerate(row_checkboxes):
        # i=idx binds the row index now (avoids the late-binding closure pitfall).
        row_cb.change(
            fn=lambda changed_values, selected_state, rows_dirty, row_defs, row_values, i=idx: _on_toggle_row(
                i,
                changed_values,
                selected_state,
                rows_dirty,
                row_defs,
                row_values,
                display_max_rows_default,
            ),
            inputs=[row_cb, selected_tags_state, rows_dirty_state, row_defs_state, row_values_state],
            outputs=[
                selected_tags_state,
                rows_dirty_state,
                rebuild_rows_button,
                row_values_state,
                suggested_prompt,
                *row_checkboxes,
            ],
            queue=False,
            show_progress="hidden",
        )

    rebuild_rows_button.click(
        _rebuild_rows_from_selected,
        inputs=[
            selected_tags_state,
            row_defs_state,
            row_values_state,
            display_top_groups,
            display_top_tags_per_group,
            display_rank_top_k,
        ],
        outputs=[
            toggle_instruction,
            tooltip_map_payload,
            suggested_prompt,
            selected_tags_state,
            rows_dirty_state,
            rebuild_rows_button,
            row_defs_state,
            row_values_state,
            *row_headers,
            *row_checkboxes,
        ],
        queue=False,
        show_progress="hidden",
    )


if __name__ == "__main__":
    app.queue().launch(allowed_paths=[str(MASCOT_DIR), str(DOCS_DIR)])