File size: 5,686 Bytes
c6be992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
from pathlib import Path
import sys
import traceback
import os

# Locate the repository root: this script lives in <repo>/scripts, so the
# root is the grandparent of the resolved script path.
repo_root = Path(__file__).resolve().parents[1]
# Prepend the root to the import path (index 0, so it is searched before any
# other entry), then make it the working directory so relative paths used
# later in the run are resolved against it.
sys.path[:0] = [os.fspath(repo_root)]
os.chdir(repo_root)


def main():
    """Run the psq_rag smoke checks and report results on stdout.

    An import sanity check always runs first; its failures are logged, not
    raised. The rest depends on the command line:

    * ``--stage2-only`` in ``sys.argv``: run stage-2 retrieval contract
      tests A and B. If any contract assertion fails, a ``FAIL:`` line is
      logged and the process exits with status 1.
    * otherwise: run one LLM rewrite of a fixed prompt, then log the sizes of
      the artist set and the NSFW tag set. Errors from the two set lookups are
      logged rather than raised. The stage-2 tests are not run in this mode.
    """
    # Project imports happen here, i.e. after the module-level sys.path /
    # chdir setup has already run.
    from psq_rag.llm.rewrite import llm_rewrite_prompt
    from psq_rag.retrieval.psq_retrieval import (
        psq_candidates_from_rewrite_phrases,
    )
    from psq_rag.retrieval.state import (
        get_artist_set,
        get_nsfw_tags,
    )

    def log(x=""):
        print(x)

    def assert_true(condition, message):
        # Raising explicitly (rather than using `assert`) keeps these checks
        # active under `python -O`, which strips assert statements.
        if not condition:
            raise AssertionError(message)

    def print_failure(message, exc):
        log(f"FAIL: {message}")
        if exc is not None:
            for line in traceback.format_exception_only(type(exc), exc):
                log(line.rstrip())

    def is_mapping_like(obj):
        # Duck-typed check: the contract checks below only use .get() and .keys().
        return callable(getattr(obj, "get", None)) and callable(getattr(obj, "keys", None))

    def import_sanity():
        # Best-effort: report an import failure without aborting the run.
        try:
            __import__("psq_rag.retrieval.state")
            __import__("psq_rag.retrieval.psq_retrieval")
            __import__("psq_rag.parsing.prompt_grammar")
            __import__("psq_rag.llm.rewrite")
            import app  # noqa: F401  (imported only to prove it loads)
            log("import sanity: ok")
        except Exception as e:
            log(f"import sanity: {type(e).__name__}: {e}")

    import_sanity()

    stage2_only = "--stage2-only" in sys.argv

    if not stage2_only:
        prompt = "ape, raised arms, looking at viewer"
        rewrite = llm_rewrite_prompt(prompt, log)
        if rewrite:
            log(f"rewrite: {rewrite}")
        else:
            log("LLM rewrite: no result (continuing)")

    def run_stage2_test_a():
        """Check the per-phrase report contract of the verbose retrieval call."""
        phrases = ["big shirt", "grey shirt"]
        cands, per_phrase = psq_candidates_from_rewrite_phrases(
            rewrite_phrases=phrases,
            allow_nsfw_tags=True,
            verbose=True,
            global_k=300,
            per_phrase_k=50,
            per_phrase_final_k=10,
        )
        log(f"cands: {len(cands)}")

        assert_true(isinstance(per_phrase, list), "per_phrase must be a list")
        # Validate entry shape up front, so a malformed report is reported as a
        # contract failure instead of escaping as an AttributeError from .get().
        for index, report in enumerate(per_phrase):
            assert_true(
                is_mapping_like(report),
                f"per_phrase[{index}] must be a mapping, got {type(report).__name__}",
            )

        phrase_set = {report.get("phrase") for report in per_phrase}
        for expected in ("big shirt", "grey shirt"):
            assert_true(
                expected in phrase_set,
                f"per_phrase missing entry for {expected!r} (got {phrase_set!r})",
            )
        assert_true(
            "shirt" in phrase_set,
            f"per_phrase missing head-noun expansion for 'shirt' (got {phrase_set!r})",
        )

        required_report_keys = {"phrase", "normalized", "lookup", "tfidf_vocab", "oov_terms", "candidates"}
        required_row_keys = {
            "tag",
            "alias_token",
            "score_fasttext",
            "score_context",
            "score_combined",
            "context_imputed",
            "count",
        }
        for report in per_phrase:
            phrase = report.get("phrase")
            missing = required_report_keys - set(report.keys())
            assert_true(
                not missing,
                f"per_phrase missing required keys for {phrase!r}: {sorted(missing)}",
            )
            rows = report.get("candidates", [])
            assert_true(
                isinstance(rows, list),
                f"per_phrase candidates must be a list (phrase {phrase!r})",
            )
            for row in rows:
                assert_true(
                    is_mapping_like(row),
                    f"candidate row for {phrase!r} must be a mapping, got {type(row).__name__}",
                )
                missing_row = required_row_keys - set(row.keys())
                assert_true(
                    not missing_row,
                    f"candidate row {row.get('tag')!r} for {phrase!r} "
                    f"missing required keys: {sorted(missing_row)}",
                )

        big_report = next(
            (report for report in per_phrase if report.get("phrase") == "big shirt"),
            None,
        )
        assert_true(big_report is not None, "no per_phrase report found for 'big shirt'")
        big_tags = {row.get("tag") for row in big_report.get("candidates", [])}
        assert_true(
            "big_shirt" in big_tags,
            "big_shirt missing from per_phrase_final_k for 'big shirt' "
            f"(got {sorted(map(str, big_tags))})",
        )

        log("stage2-only test A: PASS")

    def run_stage2_test_b():
        """Check that allow_nsfw_tags=False removes 'anus' from the candidates."""
        phrases = ["anuss"]

        def retrieve(allow_nsfw_tags):
            result = psq_candidates_from_rewrite_phrases(
                rewrite_phrases=phrases,
                allow_nsfw_tags=allow_nsfw_tags,
                verbose=False,
                global_k=300,
                per_phrase_k=50,
                per_phrase_final_k=10,
            )
            # Accept either a (candidates, ...) tuple or the bare candidates.
            return result[0] if isinstance(result, tuple) else result

        # Same order as before: the unfiltered call runs first.
        cands_unfiltered = retrieve(True)
        cands_filtered = retrieve(False)

        def extract_tag(row):
            if hasattr(row, "get"):
                return row.get("tag")
            return getattr(row, "tag", None)

        unfiltered_tags = {extract_tag(row) for row in cands_unfiltered}
        filtered_tags = {extract_tag(row) for row in cands_filtered}
        assert_true("anus" in unfiltered_tags, "anus missing from unfiltered candidates")
        assert_true("anus" not in filtered_tags, "anus unexpectedly present in filtered candidates")
        log(f"stage2-only test B: PASS (anus in unfiltered={ 'anus' in unfiltered_tags }, in filtered={ 'anus' in filtered_tags })")

    if stage2_only:
        try:
            run_stage2_test_a()
            run_stage2_test_b()
        except AssertionError as exc:
            print_failure("stage2 contract assertion failed", exc)
            sys.exit(1)
        return

    # Artist set check (optional in RAG mode)
    try:
        artists = get_artist_set()
        log(f"artist set size: {len(artists)}")
    except Exception as e:
        log(f"artist set: {type(e).__name__}: {e}")

    try:
        nsfw_tags = get_nsfw_tags()
        log(f"nsfw tag count: {len(nsfw_tags)}")
    except Exception as e:
        log(f"nsfw tags: {type(e).__name__}: {e}")

# Run the smoke checks only when this file is executed directly (optionally
# with --stage2-only), not when it is imported. Note that the module-level
# sys.path/chdir setup above still runs on import.
if __name__ == "__main__":
    main()