File size: 4,191 Bytes
c6be992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import argparse
import json
import os
from pathlib import Path

import requests

# Keys of a sample row that may hold the caption to rewrite; main() also uses
# this list as the allowed values of --field.
CAPTION_FIELDS = ["caption_llm_4", "caption_llm_6", "caption_cogvlm"]

# System prompt sent with every rewrite request (main() passes it as `system`
# to openrouter_chat). It asks the model to turn a free-form caption into a
# comma-separated list of tag-like phrases; main() prints the row's
# ground-truth tags next to each rewrite so prompt changes can be eyeballed.
# NOTE(review): the "sonic the hedgehog, princess peach" example sits in some
# tension with "guess identities" — presumably names are meant to be kept only
# when the caption states them; confirm the intended behavior.
# Start with something minimal. You will iterate this.
REWRITE_SYSTEM = """Rewrite the input into a concise, comma-separated list of short phrases
that resemble image tags.

Use short, literal phrases that reflect how visual concepts are commonly
written in image tag vocabularies.

Multi-word phrases are appropriate when they represent one coherent
visual idea.

Examples of tag-shaped phrases:
- wolf, angry
- blue jacket, striped tail
- long hair, raised ears
- holding object, hand on shoulder
- looking at viewer, looking down
- simple background, outdoor scene
- wooden table, plant
- running, sleeping
- smiling, angry expression
- bedroom, forest
- sonic the hedgehog, princess peach

Do not invent details or guess identities.
Do not infer demographic attributes (e.g., gender/age) unless explicitly stated.

Output ONLY the rewritten list."""


def load_jsonl(path: Path):
    """Yield one decoded JSON value per non-blank line of a JSON Lines file.

    Blank or whitespace-only lines (e.g. a stray empty line at the end of a
    hand-edited or concatenated file) are skipped instead of making
    ``json.loads("")`` raise.

    Args:
        path: File to read; decoded as UTF-8.

    Yields:
        The value decoded from each non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            yield json.loads(line)


def openrouter_chat(model: str, system: str, user: str, temperature: float = 0.0, max_tokens: int = 200) -> str:
    """Send one system + user chat turn to OpenRouter and return the reply text.

    Args:
        model: OpenRouter model identifier (e.g. "meta-llama/llama-3.1-8b-instruct").
        system: System prompt.
        user: User message.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Completion token limit forwarded to the API.

    Returns:
        The first choice's message content with surrounding whitespace
        stripped, or "" if the model returned no content.

    Raises:
        RuntimeError: if OPENROUTER_API_KEY is not set, or the response body
            reports an error / contains no choices.
        requests.HTTPError: on a non-2xx status; the message includes the
            (truncated) response body.
        requests.RequestException: on other transport failures (e.g. timeout).
    """
    api_key = os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        raise RuntimeError("Set OPENROUTER_API_KEY in your environment.")

    url = "https://openrouter.ai/api/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
    }

    r = requests.post(url, headers=headers, json=payload, timeout=60)
    try:
        r.raise_for_status()
    except requests.HTTPError as err:
        # raise_for_status() only reports the status line; the body carries the
        # API's explanation (bad model id, quota, etc.), so include it.
        raise requests.HTTPError(f"{err} -- response body: {r.text[:1000]}", response=r) from err
    data = r.json()

    # Errors can also arrive in the body, either top-level or per choice; fail
    # with that message instead of an opaque KeyError.
    if "error" in data or not data.get("choices"):
        raise RuntimeError(f"OpenRouter returned no completion: {data.get('error', data)}")
    choice = data["choices"][0]
    if choice.get("error"):
        raise RuntimeError(f"OpenRouter choice error: {choice['error']}")

    # message.content may be null (e.g. the token budget ran out before any
    # visible output); treat that as an empty rewrite rather than crashing.
    content = (choice.get("message") or {}).get("content")
    return (content or "").strip()


def _flatten_ground_truth(gt) -> list:
    """Return the sorted, de-duplicated tags in a categorized ground-truth value.

    Args:
        gt: Either a JSON string encoding a mapping of category -> tags (the
            form the loop originally assumed) or an already-decoded mapping.
            Presumably each value is a list of tag strings — confirm against
            the sample schema. A bare string value is kept as one tag rather
            than being split into characters.

    Raises:
        ValueError: if ``gt`` is a string that is not valid JSON.
        AttributeError / TypeError: if the decoded value is not a mapping of
            iterables, or its tags cannot be sorted together.
    """
    gt_dict = json.loads(gt) if isinstance(gt, str) else gt
    tags = set()
    for value in gt_dict.values():
        if isinstance(value, str):
            tags.add(value)
        else:
            tags.update(value)
    return sorted(tags)


def main() -> None:
    """Interactively rewrite sample captions with an LLM and show ground truth.

    Loads every row of ``--sample`` whose ``--field`` caption is non-empty.
    Starting at ``--start``, it prints each caption, its rewrite under
    ``REWRITE_SYSTEM`` and, when the row has one, the flattened
    ``tags_ground_truth_categorized`` tags. After each example the user types
    Enter (next example), ``r`` (query the same example again) or ``q``
    (quit). End of input (Ctrl-D) also quits.

    Raises:
        RuntimeError: if no row has a non-empty caption in ``--field``.
        ValueError: if ``--start`` is outside the loaded rows.
    """
    ap = argparse.ArgumentParser(description="Interactive prompt **query rewriting** playground.")
    ap.add_argument("--sample", type=str, required=True, help="Path to the trimmed JSONL sample.")
    ap.add_argument("--field", type=str, default="caption_llm_6", choices=CAPTION_FIELDS)
    ap.add_argument("--model", type=str, default="meta-llama/llama-3.1-8b-instruct")
    ap.add_argument("--temperature", type=float, default=0.0)
    ap.add_argument("--max-tokens", type=int, default=200)
    ap.add_argument("--start", type=int, default=0, help="Index to start from within the loaded examples.")
    args = ap.parse_args()

    # Keep only rows that actually have text in the chosen caption field.
    rows = []
    for row in load_jsonl(Path(args.sample)):
        text = (row.get(args.field) or "").strip()
        if text:
            gt = row.get("tags_ground_truth_categorized")
            rows.append((str(row["id"]), text, gt))

    if not rows:
        raise RuntimeError(f"No non-empty rows found for field={args.field}")

    print(f"Loaded {len(rows)} examples from {args.sample} using {args.field}.")
    print("Commands: [Enter]=next  |  r=rerun current (same input)  |  q=quit\n")

    if args.start < 0 or args.start >= len(rows):
        raise ValueError(f"--start must be in [0, {len(rows)-1}] but got {args.start}")
    idx = args.start
    while True:
        row_id, prompt, gt = rows[idx]
        print("=" * 80)
        print(f"row_id: {row_id}")
        print(f"ORIGINAL:\n{prompt}\n")

        try:
            rewritten = openrouter_chat(
                model=args.model,
                system=REWRITE_SYSTEM,
                user=prompt,
                temperature=args.temperature,
                max_tokens=args.max_tokens,
            )
        except requests.RequestException as err:
            # Timeouts and HTTP errors are often transient; report them and keep
            # the session alive so the user can retry ("r") or move on (Enter).
            print(f"REWRITE FAILED:\n{err}\n")
        else:
            print(f"REWRITE:\n{rewritten}\n")

        if gt:
            try:
                gt_text = ", ".join(_flatten_ground_truth(gt))
            except (ValueError, AttributeError, TypeError) as err:
                # Unexpected ground-truth shape: show it raw instead of crashing.
                print(f"GROUND TRUTH TAGS (unparsed: {err}):\n{gt}\n")
            else:
                print(f"GROUND TRUTH TAGS:\n{gt_text}\n")

        try:
            cmd = input("> ").strip().lower()
        except EOFError:
            # Ctrl-D / exhausted stdin: quit cleanly rather than with a traceback.
            print()
            break
        if cmd == "q":
            break
        if cmd == "r":
            continue  # same idx -> re-query the same example
        idx += 1
        if idx >= len(rows):
            print("End of samples.")
            break


# Start the interactive loop only when run as a script, not when imported.
if __name__ == "__main__":
    main()