Spaces:
Running
Running
import argparse
import json
import os
from pathlib import Path
from typing import Iterator

import requests
# Caption columns in the sample JSONL that may be rewritten (the --field choices).
CAPTION_FIELDS = [
    "caption_llm_4",
    "caption_llm_6",
    "caption_cogvlm",
]

# System prompt for the rewriter. Start with something minimal; you will iterate this.
# NOTE(review): the character-name example line ("sonic the hedgehog, ...") sits next to
# "Do not invent details or guess identities" -- presumably names should only be emitted
# when the input states them explicitly; consider saying so when iterating the prompt.
REWRITE_SYSTEM = """Rewrite the input into a concise, comma-separated list of short phrases
that resemble image tags.
Use short, literal phrases that reflect how visual concepts are commonly
written in image tag vocabularies.
Multi-word phrases are appropriate when they represent one coherent
visual idea.
Examples of tag-shaped phrases:
- wolf, angry
- blue jacket, striped tail
- long hair, raised ears
- holding object, hand on shoulder
- looking at viewer, looking down
- simple background, outdoor scene
- wooden table, plant
- running, sleeping
- smiling, angry expression
- bedroom, forest
- sonic the hedgehog, princess peach
Do not invent details or guess identities.
Do not infer demographic attributes (e.g., gender/age) unless explicitly stated.
Output ONLY the rewritten list."""
| def load_jsonl(path: Path): | |
| with path.open("r", encoding="utf-8") as f: | |
| for line in f: | |
| yield json.loads(line) | |
| def openrouter_chat(model: str, system: str, user: str, temperature: float = 0.0, max_tokens: int = 200) -> str: | |
| api_key = os.environ.get("OPENROUTER_API_KEY") | |
| if not api_key: | |
| raise RuntimeError("Set OPENROUTER_API_KEY in your environment.") | |
| url = "https://openrouter.ai/api/v1/chat/completions" | |
| headers = { | |
| "Authorization": f"Bearer {api_key}", | |
| "Content-Type": "application/json", | |
| } | |
| payload = { | |
| "model": model, | |
| "temperature": temperature, | |
| "max_tokens": max_tokens, | |
| "messages": [ | |
| {"role": "system", "content": system}, | |
| {"role": "user", "content": user}, | |
| ], | |
| } | |
| r = requests.post(url, headers=headers, json=payload, timeout=60) | |
| r.raise_for_status() | |
| data = r.json() | |
| return data["choices"][0]["message"]["content"].strip() | |
| def main() -> None: | |
| ap = argparse.ArgumentParser(description="Interactive prompt **query rewriting** playground.") | |
| ap.add_argument("--sample", type=str, required=True, help="Path to the trimmed JSONL sample.") | |
| ap.add_argument("--field", type=str, default="caption_llm_6", choices=CAPTION_FIELDS) | |
| ap.add_argument("--model", type=str, default="meta-llama/llama-3.1-8b-instruct") | |
| ap.add_argument("--temperature", type=float, default=0.0) | |
| ap.add_argument("--max-tokens", type=int, default=200) | |
| ap.add_argument("--start", type=int, default=0, help="Index to start from within the loaded examples.") | |
| args = ap.parse_args() | |
| rows = [] | |
| for row in load_jsonl(Path(args.sample)): | |
| text = (row.get(args.field) or "").strip() | |
| if text: | |
| gt = row.get("tags_ground_truth_categorized") | |
| rows.append((str(row["id"]), text, gt)) | |
| if not rows: | |
| raise RuntimeError(f"No non-empty rows found for field={args.field}") | |
| print(f"Loaded {len(rows)} examples from {args.sample} using {args.field}.") | |
| print("Commands: [Enter]=next | r=rerun current (same input) | q=quit\n") | |
| if args.start < 0 or args.start >= len(rows): | |
| raise ValueError(f"--start must be in [0, {len(rows)-1}] but got {args.start}") | |
| idx = args.start | |
| while True: | |
| row_id, prompt, gt = rows[idx] | |
| print("=" * 80) | |
| print(f"row_id: {row_id}") | |
| print(f"ORIGINAL:\n{prompt}\n") | |
| rewritten = openrouter_chat( | |
| model=args.model, | |
| system=REWRITE_SYSTEM, | |
| user=prompt, | |
| temperature=args.temperature, | |
| max_tokens=args.max_tokens, | |
| ) | |
| print(f"REWRITE:\n{rewritten}\n") | |
| if gt: | |
| gt_dict = json.loads(gt) | |
| flat_gt = sorted({tag for tags in gt_dict.values() for tag in tags}) | |
| print(f"GROUND TRUTH TAGS:\n{', '.join(flat_gt)}\n") | |
| cmd = input("> ").strip().lower() | |
| if cmd == "q": | |
| break | |
| if cmd == "r": | |
| continue | |
| idx += 1 | |
| if idx >= len(rows): | |
| print("End of samples.") | |
| break | |
# Script entry point: run the interactive playground only when executed directly.
if __name__ == "__main__":
    main()