# Prompt_Squirrel_RAG/scripts/smoke_test.py
# Food Desert
# Add alias-based character tag filtering for Stage 3
# c6be992
from pathlib import Path
import sys
import traceback
import os
# Resolve the repo root (the parent of the directory holding this script),
# make it the current working directory, and put it first on sys.path.
repo_root = Path(__file__).resolve().parents[1]
os.chdir(repo_root)
sys.path.insert(0, str(repo_root))
def main():
    """Run the smoke checks and print their results.

    After importing the project entry points, an import sanity pass is run
    over the core modules. With ``--stage2-only`` on the command line, the
    two Stage 2 candidate-retrieval contract tests are executed and the
    process exits with status 1 if one of their assertions fails. Without
    the flag, a sample prompt is sent through the LLM rewrite step and the
    sizes of the artist set and the NSFW tag collection are logged.
    """
    # Imported here rather than at module level; the module-level code puts
    # the repo root on sys.path before main() is called.
    from psq_rag.llm.rewrite import llm_rewrite_prompt
    from psq_rag.retrieval.psq_retrieval import (
        psq_candidates_from_rewrite_phrases,
    )
    from psq_rag.retrieval.state import (
        get_artist_set,
        get_nsfw_tags,
    )

    def log(x=""):
        """Print one line; also passed to the rewrite step as its logger."""
        print(x)

    def assert_true(condition, message):
        """Raise ``AssertionError(message)`` unless ``condition`` is truthy."""
        if condition:
            return
        raise AssertionError(message)

    def print_failure(message, exc):
        """Log a FAIL line, followed by the exception summary when given."""
        log(f"FAIL: {message}")
        if exc is None:
            return
        for text in traceback.format_exception_only(type(exc), exc):
            log(text.rstrip())

    def import_sanity():
        """Import the core modules and ``app``; log success or the error."""
        module_names = (
            "psq_rag.retrieval.state",
            "psq_rag.retrieval.psq_retrieval",
            "psq_rag.parsing.prompt_grammar",
            "psq_rag.llm.rewrite",
        )
        try:
            for module_name in module_names:
                __import__(module_name)
            import app  # noqa: F401 -- imported only to prove it loads
            log("import sanity: ok")
        except Exception as err:
            # Best effort: report the problem and let the run continue.
            log(f"import sanity: {type(err).__name__}: {err}")

    def run_stage2_test_a():
        """Check the structure of the verbose per-phrase Stage 2 report."""
        cands, per_phrase = psq_candidates_from_rewrite_phrases(
            rewrite_phrases=["big shirt", "grey shirt"],
            allow_nsfw_tags=True,
            verbose=True,
            global_k=300,
            per_phrase_k=50,
            per_phrase_final_k=10,
        )
        print("cands:", len(cands))
        assert_true(isinstance(per_phrase, list), "per_phrase must be a list")

        # Both input phrases plus the bare "shirt" entry must be reported.
        seen_phrases = {entry.get("phrase") for entry in per_phrase}
        assert_true("big shirt" in seen_phrases, "per_phrase missing entry for 'big shirt'")
        assert_true("grey shirt" in seen_phrases, "per_phrase missing entry for 'grey shirt'")
        assert_true("shirt" in seen_phrases, "per_phrase missing head-noun expansion for 'shirt'")

        report_keys = {"phrase", "normalized", "lookup", "tfidf_vocab", "oov_terms", "candidates"}
        row_keys = {
            "tag",
            "alias_token",
            "score_fasttext",
            "score_context",
            "score_combined",
            "context_imputed",
            "count",
        }
        for entry in per_phrase:
            assert_true(report_keys.issubset(entry.keys()), "per_phrase missing required keys")
            candidate_rows = entry.get("candidates", [])
            assert_true(isinstance(candidate_rows, list), "per_phrase candidates must be a list")
            for candidate in candidate_rows:
                assert_true(row_keys.issubset(candidate.keys()), "candidate row missing required keys")

        # First report whose phrase is exactly "big shirt", if any.
        big_report = next(
            (entry for entry in per_phrase if entry.get("phrase") == "big shirt"),
            None,
        )
        assert_true(big_report is not None, "no per_phrase report found for 'big shirt'")
        big_tags = {candidate.get("tag") for candidate in big_report.get("candidates", [])}
        assert_true("big_shirt" in big_tags, "big_shirt missing from per_phrase_final_k for 'big shirt'")
        log("stage2-only test A: PASS")

    def run_stage2_test_b():
        """Check that ``allow_nsfw_tags=False`` drops 'anus' from the candidates."""
        # Everything except allow_nsfw_tags is identical between the two calls.
        shared_kwargs = dict(
            rewrite_phrases=["anuss"],
            verbose=False,
            global_k=300,
            per_phrase_k=50,
            per_phrase_final_k=10,
        )
        result_unfiltered = psq_candidates_from_rewrite_phrases(
            allow_nsfw_tags=True, **shared_kwargs
        )
        result_filtered = psq_candidates_from_rewrite_phrases(
            allow_nsfw_tags=False, **shared_kwargs
        )

        def candidates_of(result):
            # Accept either a tuple whose first item is the candidate list,
            # or the candidate collection itself.
            return result[0] if isinstance(result, tuple) else result

        def extract_tag(row):
            # Rows may expose the tag through .get("tag") or a .tag attribute.
            if hasattr(row, "get"):
                return row.get("tag")
            return getattr(row, "tag", None)

        unfiltered_tags = {extract_tag(row) for row in candidates_of(result_unfiltered)}
        filtered_tags = {extract_tag(row) for row in candidates_of(result_filtered)}
        assert_true("anus" in unfiltered_tags, "anus missing from unfiltered candidates")
        assert_true("anus" not in filtered_tags, "anus unexpectedly present in filtered candidates")
        log(f"stage2-only test B: PASS (anus in unfiltered={'anus' in unfiltered_tags}, in filtered={'anus' in filtered_tags})")

    import_sanity()

    stage2_only = "--stage2-only" in sys.argv
    if stage2_only:
        # Contract tests only: skip the LLM rewrite and the state checks.
        try:
            run_stage2_test_a()
            run_stage2_test_b()
        except AssertionError as exc:
            print_failure("stage2 contract assertion failed", exc)
            sys.exit(1)
        return

    rewrite = llm_rewrite_prompt("ape, raised arms, looking at viewer", log)
    if not rewrite:
        log("LLM rewrite: no result (continuing)")
    else:
        print("rewrite:", rewrite)

    # Artist set check (optional in RAG mode)
    try:
        artist_count = len(get_artist_set())
        log(f"artist set size: {artist_count}")
    except Exception as err:
        log(f"artist set: {type(err).__name__}: {err}")

    # NSFW tag check; a failure is logged and does not abort the run.
    try:
        nsfw_count = len(get_nsfw_tags())
        log(f"nsfw tag count: {nsfw_count}")
    except Exception as err:
        log(f"nsfw tags: {type(err).__name__}: {err}")
# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    main()