Spaces:
Running
Running
File size: 5,686 Bytes
c6be992 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 | from pathlib import Path
import sys
import traceback
import os
# Repository root: the directory that contains /scripts, i.e. two levels above
# this file once symlinks have been resolved.
repo_root = Path(__file__).resolve().parents[1]
# Prepend (rather than append) so the checkout's packages take precedence over
# any other copy already reachable on sys.path.
sys.path.insert(0, os.fspath(repo_root))
# Relative paths used later by the script now resolve against the repo root.
os.chdir(repo_root)
def main():
    """Run smoke checks and stage-2 contract tests for the psq_rag pipeline.

    Steps, in order:

    1. Import sanity: try importing the core ``psq_rag`` modules and ``app``
       and log the first failure (or ``ok``) instead of raising.
    2. Import the project functions this script drives.
    3. Unless ``--stage2-only`` is in ``sys.argv``: rewrite a fixed prompt with
       the LLM and log the result; a falsy result is logged and ignored.
    4. With ``--stage2-only``: run contract tests A and B against
       ``psq_candidates_from_rewrite_phrases``. On an ``AssertionError`` a
       ``FAIL`` line is logged and the process exits with status 1; on success
       the function returns.
    5. Otherwise: log the artist-set size and the NSFW-tag count, logging
       (not raising) any exception from either lookup.
    """

    def log(x=""):
        """Print one line of diagnostic output."""
        print(x)

    def assert_true(condition, message):
        """Raise ``AssertionError(message)`` when *condition* is falsy.

        Unlike a bare ``assert`` statement, this still fires under ``python -O``.
        """
        if not condition:
            raise AssertionError(message)

    def print_failure(message, exc):
        """Log ``FAIL: <message>`` followed by the one-line summary of *exc*."""
        log(f"FAIL: {message}")
        if exc is not None:
            for line in traceback.format_exception_only(type(exc), exc):
                log(line.rstrip())

    def import_sanity():
        """Import the core modules one by one; log the first failure or ``ok``."""
        try:
            __import__("psq_rag.retrieval.state")
            __import__("psq_rag.retrieval.psq_retrieval")
            __import__("psq_rag.parsing.prompt_grammar")
            __import__("psq_rag.llm.rewrite")
            __import__("app")
            log("import sanity: ok")
        except Exception as e:
            log(f"import sanity: {type(e).__name__}: {e}")

    # Run the sanity check BEFORE the project imports below. If one of those
    # modules is broken, the failure is reported here first instead of the
    # script dying with a bare traceback before any diagnostics are printed.
    import_sanity()

    from psq_rag.llm.rewrite import llm_rewrite_prompt
    from psq_rag.retrieval.psq_retrieval import (
        psq_candidates_from_rewrite_phrases,
    )
    from psq_rag.retrieval.state import (
        get_artist_set,
        get_nsfw_tags,
    )

    def candidates_of(result):
        """Return the candidate list whether or not *result* is a tuple.

        Mirrors the original handling: a tuple is treated as
        ``(candidates, ...)``; anything else is taken to be the candidates.
        """
        return result[0] if isinstance(result, tuple) else result

    def extract_tag(row):
        """Return the tag name of one candidate row, or ``None`` if absent.

        Handles dict-like rows (``row.get("tag")``) and objects with a ``tag``
        attribute. NOTE(review): bare tag strings are also accepted, in case
        candidates are ever returned as plain strings — confirm against
        psq_retrieval.
        """
        if isinstance(row, str):
            return row
        if hasattr(row, "get"):
            return row.get("tag")
        return getattr(row, "tag", None)

    def run_stage2_test_a():
        """Contract test A: verbose per-phrase reports for two shirt phrases.

        Asserts that the verbose call returns a ``(candidates, per_phrase)``
        pair, and that ``per_phrase`` is a list with reports for both input
        phrases plus the head-noun expansion ``"shirt"``. It also asserts that
        every report and candidate row carries the required keys, and that
        ``big_shirt`` is among the final candidates for ``"big shirt"``.
        """
        phrases = ["big shirt", "grey shirt"]
        result = psq_candidates_from_rewrite_phrases(
            rewrite_phrases=phrases,
            allow_nsfw_tags=True,
            verbose=True,
            global_k=300,
            per_phrase_k=50,
            per_phrase_final_k=10,
        )
        # Check the shape explicitly so a contract change is reported as a
        # FAIL line instead of an unpacking TypeError/ValueError traceback.
        assert_true(
            isinstance(result, (tuple, list)) and len(result) == 2,
            "verbose=True must return (candidates, per_phrase)",
        )
        cands, per_phrase = result
        log(f"cands: {len(cands)}")

        assert_true(isinstance(per_phrase, list), "per_phrase must be a list")
        phrase_set = {report.get("phrase") for report in per_phrase}
        for expected in phrases:
            assert_true(
                expected in phrase_set,
                f"per_phrase missing entry for '{expected}'",
            )
        assert_true("shirt" in phrase_set, "per_phrase missing head-noun expansion for 'shirt'")

        required_report_keys = {"phrase", "normalized", "lookup", "tfidf_vocab", "oov_terms", "candidates"}
        required_row_keys = {
            "tag",
            "alias_token",
            "score_fasttext",
            "score_context",
            "score_combined",
            "context_imputed",
            "count",
        }
        for report in per_phrase:
            missing = required_report_keys.difference(report.keys())
            assert_true(
                not missing,
                f"per_phrase missing required keys {sorted(missing)} "
                f"(phrase={report.get('phrase')!r})",
            )
            rows = report.get("candidates", [])
            assert_true(isinstance(rows, list), "per_phrase candidates must be a list")
            for row in rows:
                missing_row = required_row_keys.difference(row.keys())
                assert_true(
                    not missing_row,
                    f"candidate row missing required keys {sorted(missing_row)} "
                    f"(tag={row.get('tag')!r})",
                )

        # First report for the exact phrase "big shirt", if any.
        big_report = next(
            (report for report in per_phrase if report.get("phrase") == "big shirt"),
            None,
        )
        assert_true(big_report is not None, "no per_phrase report found for 'big shirt'")
        big_tags = {row.get("tag") for row in big_report.get("candidates", [])}
        assert_true("big_shirt" in big_tags, "big_shirt missing from per_phrase_final_k for 'big shirt'")
        log("stage2-only test A: PASS")

    def run_stage2_test_b():
        """Contract test B: the NSFW filter removes ``anus`` for phrase ``anuss``.

        Runs the same query with ``allow_nsfw_tags`` True and then False, and
        asserts ``anus`` appears only in the unfiltered candidates.
        """

        def tags_for(allow_nsfw_tags):
            # Tag names of all candidates for the single phrase "anuss".
            result = psq_candidates_from_rewrite_phrases(
                rewrite_phrases=["anuss"],
                allow_nsfw_tags=allow_nsfw_tags,
                verbose=False,
                global_k=300,
                per_phrase_k=50,
                per_phrase_final_k=10,
            )
            return {extract_tag(row) for row in candidates_of(result)}

        unfiltered_tags = tags_for(True)
        filtered_tags = tags_for(False)
        assert_true("anus" in unfiltered_tags, "anus missing from unfiltered candidates")
        assert_true("anus" not in filtered_tags, "anus unexpectedly present in filtered candidates")
        log(f"stage2-only test B: PASS (anus in unfiltered={'anus' in unfiltered_tags}, in filtered={'anus' in filtered_tags})")

    stage2_only = "--stage2-only" in sys.argv
    if not stage2_only:
        prompt = "ape, raised arms, looking at viewer"
        rewrite = llm_rewrite_prompt(prompt, log)
        if rewrite:
            log(f"rewrite: {rewrite}")
        else:
            log("LLM rewrite: no result (continuing)")

    if stage2_only:
        try:
            run_stage2_test_a()
            run_stage2_test_b()
        except AssertionError as exc:
            print_failure("stage2 contract assertion failed", exc)
            sys.exit(1)
        return

    # Artist set check (optional in RAG mode)
    try:
        artists = get_artist_set()
        log(f"artist set size: {len(artists)}")
    except Exception as e:
        log(f"artist set: {type(e).__name__}: {e}")
    try:
        nsfw_tags = get_nsfw_tags()
        log(f"nsfw tag count: {len(nsfw_tags)}")
    except Exception as e:
        log(f"nsfw tags: {type(e).__name__}: {e}")
# Run the checks only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|