import argparse import json import os from pathlib import Path import requests CAPTION_FIELDS = ["caption_llm_4", "caption_llm_6", "caption_cogvlm"] # Start with something minimal. You will iterate this. REWRITE_SYSTEM = """Rewrite the input into a concise, comma-separated list of short phrases that resemble image tags. Use short, literal phrases that reflect how visual concepts are commonly written in image tag vocabularies. Multi-word phrases are appropriate when they represent one coherent visual idea. Examples of tag-shaped phrases: - wolf, angry - blue jacket, striped tail - long hair, raised ears - holding object, hand on shoulder - looking at viewer, looking down - simple background, outdoor scene - wooden table, plant - running, sleeping - smiling, angry expression - bedroom, forest - sonic the hedgehog, princess peach Do not invent details or guess identities. Do not infer demographic attributes (e.g., gender/age) unless explicitly stated. Output ONLY the rewritten list.""" def load_jsonl(path: Path): with path.open("r", encoding="utf-8") as f: for line in f: yield json.loads(line) def openrouter_chat(model: str, system: str, user: str, temperature: float = 0.0, max_tokens: int = 200) -> str: api_key = os.environ.get("OPENROUTER_API_KEY") if not api_key: raise RuntimeError("Set OPENROUTER_API_KEY in your environment.") url = "https://openrouter.ai/api/v1/chat/completions" headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", } payload = { "model": model, "temperature": temperature, "max_tokens": max_tokens, "messages": [ {"role": "system", "content": system}, {"role": "user", "content": user}, ], } r = requests.post(url, headers=headers, json=payload, timeout=60) r.raise_for_status() data = r.json() return data["choices"][0]["message"]["content"].strip() def main() -> None: ap = argparse.ArgumentParser(description="Interactive prompt **query rewriting** playground.") ap.add_argument("--sample", type=str, required=True, help="Path to the trimmed JSONL sample.") ap.add_argument("--field", type=str, default="caption_llm_6", choices=CAPTION_FIELDS) ap.add_argument("--model", type=str, default="meta-llama/llama-3.1-8b-instruct") ap.add_argument("--temperature", type=float, default=0.0) ap.add_argument("--max-tokens", type=int, default=200) ap.add_argument("--start", type=int, default=0, help="Index to start from within the loaded examples.") args = ap.parse_args() rows = [] for row in load_jsonl(Path(args.sample)): text = (row.get(args.field) or "").strip() if text: gt = row.get("tags_ground_truth_categorized") rows.append((str(row["id"]), text, gt)) if not rows: raise RuntimeError(f"No non-empty rows found for field={args.field}") print(f"Loaded {len(rows)} examples from {args.sample} using {args.field}.") print("Commands: [Enter]=next | r=rerun current (same input) | q=quit\n") if args.start < 0 or args.start >= len(rows): raise ValueError(f"--start must be in [0, {len(rows)-1}] but got {args.start}") idx = args.start while True: row_id, prompt, gt = rows[idx] print("=" * 80) print(f"row_id: {row_id}") print(f"ORIGINAL:\n{prompt}\n") rewritten = openrouter_chat( model=args.model, system=REWRITE_SYSTEM, user=prompt, temperature=args.temperature, max_tokens=args.max_tokens, ) print(f"REWRITE:\n{rewritten}\n") if gt: gt_dict = json.loads(gt) flat_gt = sorted({tag for tags in gt_dict.values() for tag in tags}) print(f"GROUND TRUTH TAGS:\n{', '.join(flat_gt)}\n") cmd = input("> ").strip().lower() if cmd == "q": break if cmd == "r": continue idx += 1 if idx >= len(rows): print("End of samples.") break if __name__ == "__main__": main()