Spaces:
Running
Running
| from pathlib import Path | |
| import sys | |
| import traceback | |
| import os | |
# Bootstrap: treat the repository root (the parent of /scripts) as both the
# first import location and the working directory for everything that follows.
repo_root = Path(__file__).resolve().parents[1]
os.chdir(repo_root)
sys.path.insert(0, str(repo_root))
def main():
    """Run the psq_rag smoke checks.

    Default mode: log a best-effort import check, send one fixed prompt
    through the LLM rewrite step, then log the sizes of the artist set and
    the NSFW tag list (each of those two lookups is best-effort).

    With ``--stage2-only`` anywhere in ``sys.argv`` the LLM rewrite is skipped
    and the two stage-2 retrieval contract tests run instead.  Every test is
    attempted even when an earlier one fails, and the process exits with
    status 1 if any contract assertion failed.
    """
    import importlib

    # NOTE(review): these project imports are function-local, presumably so
    # they only execute after the module-level sys.path / cwd setup -- confirm.
    from psq_rag.llm.rewrite import llm_rewrite_prompt
    from psq_rag.retrieval.psq_retrieval import (
        psq_candidates_from_rewrite_phrases,
    )
    from psq_rag.retrieval.state import (
        get_artist_set,
        get_nsfw_tags,
    )

    def log(x=""):
        """Print one line; also handed to llm_rewrite_prompt as its logger."""
        print(x)

    def assert_true(condition, message):
        """Raise AssertionError(message) when *condition* is falsy.

        Used instead of the ``assert`` statement so the checks still run
        under ``python -O``.
        """
        if not condition:
            raise AssertionError(message)

    def print_failure(message, exc):
        """Log ``FAIL: <message>`` followed by the formatted exception, if any."""
        log(f"FAIL: {message}")
        if exc is not None:
            for line in traceback.format_exception_only(type(exc), exc):
                log(line.rstrip())

    def import_sanity():
        """Import the core modules and log the outcome.

        Best-effort: an Exception raised by any import is logged instead of
        propagating, and the remaining modules are then not attempted.
        """
        module_names = (
            "psq_rag.retrieval.state",
            "psq_rag.retrieval.psq_retrieval",
            "psq_rag.parsing.prompt_grammar",
            "psq_rag.llm.rewrite",
            "app",
        )
        try:
            for module_name in module_names:
                importlib.import_module(module_name)
        except Exception as e:
            log(f"import sanity: {type(e).__name__}: {e}")
        else:
            log("import sanity: ok")

    def run_stage2_test_a():
        """Check the shape of the verbose per-phrase report.

        Retrieves candidates for "big shirt" and "grey shirt" and asserts that
        the report covers both phrases plus the "shirt" head-noun expansion,
        that every report and candidate row carries the required keys, and
        that the tag ``big_shirt`` is among the candidates for "big shirt".

        Raises:
            AssertionError: on the first violated expectation.
        """
        phrases = ["big shirt", "grey shirt"]
        cands, per_phrase = psq_candidates_from_rewrite_phrases(
            rewrite_phrases=phrases,
            allow_nsfw_tags=True,
            verbose=True,
            global_k=300,
            per_phrase_k=50,
            per_phrase_final_k=10,
        )
        log(f"cands: {len(cands)}")

        assert_true(isinstance(per_phrase, list), "per_phrase must be a list")
        phrase_set = {report.get("phrase") for report in per_phrase}
        for expected in phrases:
            assert_true(
                expected in phrase_set,
                f"per_phrase missing entry for '{expected}'",
            )
        assert_true(
            "shirt" in phrase_set,
            "per_phrase missing head-noun expansion for 'shirt'",
        )

        required_report_keys = {
            "phrase",
            "normalized",
            "lookup",
            "tfidf_vocab",
            "oov_terms",
            "candidates",
        }
        required_row_keys = {
            "tag",
            "alias_token",
            "score_fasttext",
            "score_context",
            "score_combined",
            "context_imputed",
            "count",
        }
        for report in per_phrase:
            assert_true(
                required_report_keys.issubset(report.keys()),
                "per_phrase missing required keys",
            )
            rows = report.get("candidates", [])
            assert_true(isinstance(rows, list), "per_phrase candidates must be a list")
            for row in rows:
                assert_true(
                    required_row_keys.issubset(row.keys()),
                    "candidate row missing required keys",
                )

        # First report whose phrase is exactly "big shirt", or None.
        big_report = next(
            (report for report in per_phrase if report.get("phrase") == "big shirt"),
            None,
        )
        assert_true(big_report is not None, "no per_phrase report found for 'big shirt'")
        big_tags = {row.get("tag") for row in big_report.get("candidates", [])}
        assert_true(
            "big_shirt" in big_tags,
            "big_shirt missing from per_phrase_final_k for 'big shirt'",
        )
        log("stage2-only test A: PASS")

    def run_stage2_test_b():
        """Check that ``allow_nsfw_tags`` filters the tag ``anus``.

        Queries the misspelled phrase "anuss" twice and asserts that ``anus``
        is returned when NSFW tags are allowed and absent when they are not.

        Raises:
            AssertionError: when either expectation is violated.
        """
        phrases = ["anuss"]

        def candidates_for(allow_nsfw):
            """Return the candidate rows for *phrases* with the given NSFW flag."""
            result = psq_candidates_from_rewrite_phrases(
                rewrite_phrases=phrases,
                allow_nsfw_tags=allow_nsfw,
                verbose=False,
                global_k=300,
                per_phrase_k=50,
                per_phrase_final_k=10,
            )
            # Accept either a tuple whose first item is the candidates, or
            # the bare candidates collection.
            return result[0] if isinstance(result, tuple) else result

        def extract_tag(row):
            """Read ``tag`` from a mapping-like row or from an attribute."""
            if hasattr(row, "get"):
                return row.get("tag")
            return getattr(row, "tag", None)

        unfiltered_tags = {extract_tag(row) for row in candidates_for(True)}
        filtered_tags = {extract_tag(row) for row in candidates_for(False)}

        in_unfiltered = "anus" in unfiltered_tags
        in_filtered = "anus" in filtered_tags
        assert_true(in_unfiltered, "anus missing from unfiltered candidates")
        assert_true(not in_filtered, "anus unexpectedly present in filtered candidates")
        log(
            "stage2-only test B: PASS "
            f"(anus in unfiltered={in_unfiltered}, in filtered={in_filtered})"
        )

    import_sanity()

    stage2_only = "--stage2-only" in sys.argv

    if stage2_only:
        # Run every contract test so one failure does not hide the others;
        # only AssertionError counts as a contract failure, anything else
        # propagates with its traceback.
        failed = False
        for stage2_test in (run_stage2_test_a, run_stage2_test_b):
            try:
                stage2_test()
            except AssertionError as exc:
                print_failure("stage2 contract assertion failed", exc)
                failed = True
        if failed:
            sys.exit(1)
        return

    prompt = "ape, raised arms, looking at viewer"
    rewrite = llm_rewrite_prompt(prompt, log)
    if rewrite:
        log(f"rewrite: {rewrite}")
    else:
        log("LLM rewrite: no result (continuing)")

    # Artist set check (optional in RAG mode)
    try:
        artists = get_artist_set()
        log(f"artist set size: {len(artists)}")
    except Exception as e:
        log(f"artist set: {type(e).__name__}: {e}")

    # NSFW tag list check, best-effort like the artist set above.
    try:
        nsfw_tags = get_nsfw_tags()
        log(f"nsfw tag count: {len(nsfw_tags)}")
    except Exception as e:
        log(f"nsfw tags: {type(e).__name__}: {e}")
# Script entry point: run the smoke checks only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()