# Scraped from a hosted file view (Spaces: Running; file size: 4,191 bytes; c6be992).
import argparse
import json
import os
from pathlib import Path
import requests
# Caption columns of the sample rows that can be selected with --field.
CAPTION_FIELDS = ["caption_llm_4", "caption_llm_6", "caption_cogvlm"]

# System prompt sent with every rewrite request. It asks the model to turn the
# input text into a comma-separated list of short, tag-like phrases.
# Start with something minimal. You will iterate this.
REWRITE_SYSTEM = """Rewrite the input into a concise, comma-separated list of short phrases
that resemble image tags.
Use short, literal phrases that reflect how visual concepts are commonly
written in image tag vocabularies.
Multi-word phrases are appropriate when they represent one coherent
visual idea.
Examples of tag-shaped phrases:
- wolf, angry
- blue jacket, striped tail
- long hair, raised ears
- holding object, hand on shoulder
- looking at viewer, looking down
- simple background, outdoor scene
- wooden table, plant
- running, sleeping
- smiling, angry expression
- bedroom, forest
- sonic the hedgehog, princess peach
Do not invent details or guess identities.
Do not infer demographic attributes (e.g., gender/age) unless explicitly stated.
Output ONLY the rewritten list."""
def load_jsonl(path: Path):
    """Yield one decoded JSON value per non-blank line of a JSON Lines file.

    Blank or whitespace-only lines (for example a trailing empty line at the
    end of the file) are skipped rather than passed to ``json.loads``, which
    would raise ``json.JSONDecodeError`` on them.

    Args:
        path: Location of the UTF-8 encoded JSONL file.

    Yields:
        The value decoded from each non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Nothing to decode on an empty line; keep streaming.
                continue
            yield json.loads(line)
def openrouter_chat(model: str, system: str, user: str, temperature: float = 0.0, max_tokens: int = 200) -> str:
    """Send one system+user exchange to OpenRouter and return the reply text.

    Args:
        model: Model identifier placed in the request payload.
        system: Content of the system message.
        user: Content of the user message.
        temperature: Sampling temperature placed in the request payload.
        max_tokens: Completion token limit placed in the request payload.

    Returns:
        The first choice's message content with surrounding whitespace
        stripped; an empty string when the content is null or missing text.

    Raises:
        RuntimeError: If ``OPENROUTER_API_KEY`` is unset or empty, if the
            response body carries an ``error`` entry, or if the body does not
            contain ``choices[0].message.content``.
        requests.HTTPError: If the HTTP status signals failure
            (via ``raise_for_status``).
    """
    api_key = os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        raise RuntimeError("Set OPENROUTER_API_KEY in your environment.")

    url = "https://openrouter.ai/api/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
    }
    r = requests.post(url, headers=headers, json=payload, timeout=60)
    r.raise_for_status()
    data = r.json()

    # NOTE(review): per OpenRouter's error documentation a failure can arrive
    # in the JSON body even when the HTTP status is 200; surface it with a
    # readable message instead of a bare KeyError on "choices".
    if isinstance(data, dict) and data.get("error"):
        raise RuntimeError(f"OpenRouter returned an error: {data['error']}")

    try:
        content = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        raise RuntimeError(f"Unexpected OpenRouter response: {data!r}") from exc

    # A null content would otherwise crash on .strip(); treat it as empty text.
    return (content or "").strip()
def main() -> None:
    """Run the interactive rewrite playground from the command line.

    Loads the JSONL sample given by ``--sample``, keeps the rows whose
    ``--field`` caption is non-empty, and then loops: print the original
    caption, request a rewrite via ``openrouter_chat``, print the rewrite and
    (when present) the row's ground-truth tags, and wait for a command.

    Commands at the prompt: Enter advances to the next row, ``r`` reruns the
    current row, ``q`` quits. A closed stdin (Ctrl-D or an exhausted pipe) is
    treated like ``q`` instead of ending the session with a traceback.

    Raises:
        RuntimeError: If no row has a non-empty value for the chosen field.
        ValueError: If ``--start`` is outside the range of loaded rows.
    """
    ap = argparse.ArgumentParser(description="Interactive prompt **query rewriting** playground.")
    ap.add_argument("--sample", type=str, required=True, help="Path to the trimmed JSONL sample.")
    ap.add_argument("--field", type=str, default="caption_llm_6", choices=CAPTION_FIELDS)
    ap.add_argument("--model", type=str, default="meta-llama/llama-3.1-8b-instruct")
    ap.add_argument("--temperature", type=float, default=0.0)
    ap.add_argument("--max-tokens", type=int, default=200)
    ap.add_argument("--start", type=int, default=0, help="Index to start from within the loaded examples.")
    args = ap.parse_args()

    # Keep only rows that have text in the selected caption field; each entry
    # is (row id as str, caption text, raw ground-truth value or None).
    rows = []
    for row in load_jsonl(Path(args.sample)):
        text = (row.get(args.field) or "").strip()
        if text:
            gt = row.get("tags_ground_truth_categorized")
            rows.append((str(row["id"]), text, gt))
    if not rows:
        raise RuntimeError(f"No non-empty rows found for field={args.field}")

    print(f"Loaded {len(rows)} examples from {args.sample} using {args.field}.")
    print("Commands: [Enter]=next | r=rerun current (same input) | q=quit\n")

    # --start indexes the filtered list, so it can only be checked after loading.
    if args.start < 0 or args.start >= len(rows):
        raise ValueError(f"--start must be in [0, {len(rows)-1}] but got {args.start}")

    idx = args.start
    while True:
        row_id, prompt, gt = rows[idx]
        print("=" * 80)
        print(f"row_id: {row_id}")
        print(f"ORIGINAL:\n{prompt}\n")

        rewritten = openrouter_chat(
            model=args.model,
            system=REWRITE_SYSTEM,
            user=prompt,
            temperature=args.temperature,
            max_tokens=args.max_tokens,
        )
        print(f"REWRITE:\n{rewritten}\n")

        if gt:
            # The ground truth is stored as a JSON-encoded mapping whose values
            # are collections of tags (presumably category -> tags); flatten
            # and de-duplicate them for display.
            gt_dict = json.loads(gt)
            flat_gt = sorted({tag for tags in gt_dict.values() for tag in tags})
            print(f"GROUND TRUTH TAGS:\n{', '.join(flat_gt)}\n")

        try:
            cmd = input("> ").strip().lower()
        except EOFError:
            # stdin was closed; leave the loop cleanly, as "q" would.
            print()
            break
        if cmd == "q":
            break
        if cmd == "r":
            # Re-enter the loop without advancing: same row, new request.
            continue
        idx += 1
        if idx >= len(rows):
            print("End of samples.")
            break
# Start the playground only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
# End of file.