"""Smoke checks for the psq_rag pipeline.

Default run: an import sanity check, one LLM prompt rewrite, then the
artist-set and NSFW-tag lookups. When ``--stage2-only`` is present on the
command line, the LLM rewrite and the lookups are skipped and the stage-2
candidate-retrieval contract tests run instead; the process exits with
status 1 if one of their assertions fails.
"""
from pathlib import Path
import sys
import traceback
import os

# Add repo root (parent of /scripts) to sys.path
repo_root = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(repo_root))
os.chdir(repo_root)


def main():
    """Run the smoke checks selected by the command line (see module docstring)."""
    from psq_rag.llm.rewrite import llm_rewrite_prompt
    from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
    from psq_rag.retrieval.state import get_artist_set, get_nsfw_tags

    # Retrieval limits shared by every stage-2 call below.
    k_limits = {"global_k": 300, "per_phrase_k": 50, "per_phrase_final_k": 10}

    def log(x=""):
        """Print one line; also handed to the LLM rewrite as its logger."""
        print(x)

    def assert_true(condition, message):
        """Raise AssertionError(message) when condition is falsy."""
        if condition:
            return
        raise AssertionError(message)

    def print_failure(message, exc):
        """Log a FAIL line followed by the formatted exception, if one is given."""
        log(f"FAIL: {message}")
        if exc is None:
            return
        for text in traceback.format_exception_only(type(exc), exc):
            log(text.rstrip())

    def import_sanity():
        """Try importing the key modules; log the outcome instead of raising."""
        module_names = (
            "psq_rag.retrieval.state",
            "psq_rag.retrieval.psq_retrieval",
            "psq_rag.parsing.prompt_grammar",
            "psq_rag.llm.rewrite",
            "app",
        )
        try:
            for module_name in module_names:
                __import__(module_name)
            log("import sanity: ok")
        except Exception as e:
            log(f"import sanity: {type(e).__name__}: {e}")

    def run_stage2_test_a():
        """Check the shape of the verbose per-phrase report for two shirt phrases."""
        candidates, reports = psq_candidates_from_rewrite_phrases(
            rewrite_phrases=["big shirt", "grey shirt"],
            allow_nsfw_tags=True,
            verbose=True,
            **k_limits,
        )
        print("cands:", len(candidates))
        assert_true(isinstance(reports, list), "per_phrase must be a list")

        seen_phrases = {entry.get("phrase") for entry in reports}
        expected_phrases = (
            ("big shirt", "per_phrase missing entry for 'big shirt'"),
            ("grey shirt", "per_phrase missing entry for 'grey shirt'"),
            ("shirt", "per_phrase missing head-noun expansion for 'shirt'"),
        )
        for phrase, failure_message in expected_phrases:
            assert_true(phrase in seen_phrases, failure_message)

        report_keys = {
            "phrase",
            "normalized",
            "lookup",
            "tfidf_vocab",
            "oov_terms",
            "candidates",
        }
        row_keys = {
            "tag",
            "alias_token",
            "score_fasttext",
            "score_context",
            "score_combined",
            "context_imputed",
            "count",
        }
        for entry in reports:
            assert_true(
                not report_keys.difference(entry.keys()),
                "per_phrase missing required keys",
            )
            entry_rows = entry.get("candidates", [])
            assert_true(
                isinstance(entry_rows, list),
                "per_phrase candidates must be a list",
            )
            for candidate in entry_rows:
                assert_true(
                    not row_keys.difference(candidate.keys()),
                    "candidate row missing required keys",
                )

        # First report whose phrase is exactly "big shirt", or None.
        big_shirt = next(
            (entry for entry in reports if entry.get("phrase") == "big shirt"),
            None,
        )
        assert_true(big_shirt is not None, "no per_phrase report found for 'big shirt'")
        big_shirt_tags = {
            candidate.get("tag") for candidate in big_shirt.get("candidates", [])
        }
        assert_true(
            "big_shirt" in big_shirt_tags,
            "big_shirt missing from per_phrase_final_k for 'big shirt'",
        )
        log("stage2-only test A: PASS")

    def run_stage2_test_b():
        """Check that 'anus' is returned only when NSFW tags are allowed."""
        phrases = ["anuss"]

        def fetch(allow_nsfw):
            return psq_candidates_from_rewrite_phrases(
                rewrite_phrases=phrases,
                allow_nsfw_tags=allow_nsfw,
                verbose=False,
                **k_limits,
            )

        def candidates_of(result):
            # A tuple result carries the candidate list in slot 0.
            return result[0] if isinstance(result, tuple) else result

        def extract_tag(row):
            # Rows may be mapping-like or attribute-style objects.
            return row.get("tag") if hasattr(row, "get") else getattr(row, "tag", None)

        # Both retrievals run before either result is inspected.
        raw_unfiltered = fetch(True)
        raw_filtered = fetch(False)
        unfiltered_rows = candidates_of(raw_unfiltered)
        filtered_rows = candidates_of(raw_filtered)

        unfiltered_tags = {extract_tag(row) for row in unfiltered_rows}
        filtered_tags = {extract_tag(row) for row in filtered_rows}
        in_unfiltered = "anus" in unfiltered_tags
        in_filtered = "anus" in filtered_tags

        assert_true(in_unfiltered, "anus missing from unfiltered candidates")
        assert_true(not in_filtered, "anus unexpectedly present in filtered candidates")
        log(
            f"stage2-only test B: PASS (anus in unfiltered={in_unfiltered}, "
            f"in filtered={in_filtered})"
        )

    def report_size(loader, ok_label, fail_label):
        """Log the size of whatever loader() returns, or the error it raised."""
        try:
            loaded = loader()
            log(f"{ok_label}: {len(loaded)}")
        except Exception as e:
            log(f"{fail_label}: {type(e).__name__}: {e}")

    import_sanity()

    if "--stage2-only" in sys.argv:
        try:
            run_stage2_test_a()
            run_stage2_test_b()
        except AssertionError as exc:
            print_failure("stage2 contract assertion failed", exc)
            sys.exit(1)
        return

    rewrite = llm_rewrite_prompt("ape, raised arms, looking at viewer", log)
    if not rewrite:
        log("LLM rewrite: no result (continuing)")
    else:
        print("rewrite:", rewrite)

    # Artist set check (optional in RAG mode)
    report_size(get_artist_set, "artist set size", "artist set")
    report_size(get_nsfw_tags, "nsfw tag count", "nsfw tags")


if __name__ == "__main__":
    main()