# Prompt_Squirrel_RAG / scripts/rewrite_playground.py
# Food Desert
# Commit c6be992: "Add alias-based character tag filtering for Stage 3"
import argparse
import json
import os
from pathlib import Path
import requests
# Caption columns in the sample JSONL that can be fed to the rewriter.
# main() uses this list as the allowed `choices` for the --field option.
CAPTION_FIELDS = ["caption_llm_4", "caption_llm_6", "caption_cogvlm"]
# System prompt sent with every rewrite request (see the openrouter_chat call
# in main()). It asks the model to turn a free-form caption into a
# comma-separated list of short, tag-like phrases. Intentionally minimal:
# this is the knob being tuned in this playground, so edit it freely.
REWRITE_SYSTEM = """Rewrite the input into a concise, comma-separated list of short phrases
that resemble image tags.
Use short, literal phrases that reflect how visual concepts are commonly
written in image tag vocabularies.
Multi-word phrases are appropriate when they represent one coherent
visual idea.
Examples of tag-shaped phrases:
- wolf, angry
- blue jacket, striped tail
- long hair, raised ears
- holding object, hand on shoulder
- looking at viewer, looking down
- simple background, outdoor scene
- wooden table, plant
- running, sleeping
- smiling, angry expression
- bedroom, forest
- sonic the hedgehog, princess peach
Do not invent details or guess identities.
Do not infer demographic attributes (e.g., gender/age) unless explicitly stated.
Output ONLY the rewritten list."""
def load_jsonl(path: Path):
    """Lazily yield one decoded JSON value per non-blank line of a JSONL file.

    The file is read as UTF-8. Blank or whitespace-only lines (e.g. a stray
    empty line at the end of a hand-edited sample) are skipped rather than
    passed to ``json.loads``, which would raise ``JSONDecodeError`` on them.

    Args:
        path: Location of the ``.jsonl`` file.

    Yields:
        The decoded value of each non-blank line, in file order.
    """
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            yield json.loads(line)
def _extract_content(data: dict) -> str:
    """Return the first choice's message text from a decoded chat-completions body.

    Raises:
        RuntimeError: if the body carries an ``error`` entry or has no
            ``choices``, so the caller sees a readable message instead of a
            bare KeyError/IndexError.
    """
    if "error" in data:
        err = data["error"]
        msg = err.get("message", err) if isinstance(err, dict) else err
        raise RuntimeError(f"OpenRouter returned an error: {msg}")
    choices = data.get("choices") or []
    if not choices:
        raise RuntimeError(f"OpenRouter response has no choices: {str(data)[:500]}")
    # `content` may be null in the response; guard so .strip() cannot fail.
    content = (choices[0].get("message") or {}).get("content")
    return (content or "").strip()


def openrouter_chat(
    model: str,
    system: str,
    user: str,
    temperature: float = 0.0,
    max_tokens: int = 200,
    timeout: float = 60,
) -> str:
    """Send one system + user exchange to OpenRouter and return the reply text.

    Args:
        model: OpenRouter model id (e.g. ``meta-llama/llama-3.1-8b-instruct``).
        system: System prompt.
        user: User message.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Completion token cap forwarded to the API.
        timeout: Seconds to wait for the HTTP request.

    Returns:
        The first choice's message content with surrounding whitespace
        stripped, or ``""`` if the API returned no content.

    Raises:
        RuntimeError: if ``OPENROUTER_API_KEY`` is unset, or the response body
            reports an error / contains no choices.
        requests.HTTPError: on a non-2xx status; the message includes the
            start of the response body.
    """
    api_key = os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        raise RuntimeError("Set OPENROUTER_API_KEY in your environment.")
    url = "https://openrouter.ai/api/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
    }
    r = requests.post(url, headers=headers, json=payload, timeout=timeout)
    try:
        r.raise_for_status()
    except requests.HTTPError as err:
        # raise_for_status() omits the response body; append a prefix of it so
        # any server-side error description is visible in the playground.
        raise requests.HTTPError(f"{err} | body: {r.text[:500]}", response=r) from err
    return _extract_content(r.json())
def _load_rows(sample: Path, field: str) -> list:
    """Collect ``(row_id, caption, ground_truth)`` tuples for rows with a non-empty ``field``.

    ``ground_truth`` is the raw ``tags_ground_truth_categorized`` value (may be
    ``None``). A row without an ``id`` gets a placeholder id instead of
    aborting the whole load with a KeyError.
    """
    rows = []
    for row in load_jsonl(sample):
        text = (row.get(field) or "").strip()
        if text:
            gt = row.get("tags_ground_truth_categorized")
            rows.append((str(row.get("id", "<missing id>")), text, gt))
    return rows


def _flatten_ground_truth(gt) -> list:
    """Return the sorted, de-duplicated tags from a categorized ground-truth value.

    ``gt`` may be a JSON-encoded string or an already-decoded mapping of
    category -> iterable of tags. Both are accepted because the JSONL sample
    could presumably store the field either way.

    Raises:
        ValueError: if ``gt`` is a string that is not valid JSON.
    """
    gt_dict = json.loads(gt) if isinstance(gt, str) else gt
    if not gt_dict:
        return []
    return sorted({tag for tags in gt_dict.values() for tag in tags})


def main() -> None:
    """Step through sample captions interactively, rewriting each via OpenRouter.

    For each row this prints:
      - the original caption;
      - the model's rewrite;
      - when present, the flattened ground-truth tags.

    It then waits for a command: Enter = next row, ``r`` = rerun the same row,
    ``q`` (or EOF) = quit. Network/HTTP failures for a row are reported, and
    the session continues so the row can be retried with ``r``.
    """
    ap = argparse.ArgumentParser(description="Interactive prompt **query rewriting** playground.")
    ap.add_argument("--sample", type=str, required=True, help="Path to the trimmed JSONL sample.")
    ap.add_argument("--field", type=str, default="caption_llm_6", choices=CAPTION_FIELDS)
    ap.add_argument("--model", type=str, default="meta-llama/llama-3.1-8b-instruct")
    ap.add_argument("--temperature", type=float, default=0.0)
    ap.add_argument("--max-tokens", type=int, default=200)
    ap.add_argument("--start", type=int, default=0, help="Index to start from within the loaded examples.")
    args = ap.parse_args()

    rows = _load_rows(Path(args.sample), args.field)
    if not rows:
        raise RuntimeError(f"No non-empty rows found for field={args.field}")

    print(f"Loaded {len(rows)} examples from {args.sample} using {args.field}.")
    print("Commands: [Enter]=next | r=rerun current (same input) | q=quit\n")

    if args.start < 0 or args.start >= len(rows):
        raise ValueError(f"--start must be in [0, {len(rows)-1}] but got {args.start}")

    idx = args.start
    while True:
        row_id, prompt, gt = rows[idx]
        print("=" * 80)
        print(f"row_id: {row_id}")
        print(f"ORIGINAL:\n{prompt}\n")

        try:
            rewritten = openrouter_chat(
                model=args.model,
                system=REWRITE_SYSTEM,
                user=prompt,
                temperature=args.temperature,
                max_tokens=args.max_tokens,
            )
        except requests.RequestException as err:
            # A timeout or HTTP error on one row should not end the session;
            # report it so the user can retry with "r" or move on.
            print(f"REWRITE FAILED:\n{err}\n")
        else:
            print(f"REWRITE:\n{rewritten}\n")

        if gt:
            try:
                flat_gt = _flatten_ground_truth(gt)
            except ValueError as err:
                print(f"GROUND TRUTH TAGS: <could not parse: {err}>\n")
            else:
                print(f"GROUND TRUTH TAGS:\n{', '.join(flat_gt)}\n")

        try:
            cmd = input("> ").strip().lower()
        except EOFError:
            # Ctrl-D / closed stdin: quit cleanly instead of a traceback.
            print()
            break
        if cmd == "q":
            break
        if cmd == "r":
            continue
        idx += 1
        if idx >= len(rows):
            print("End of samples.")
            break
# Run the interactive playground only when executed as a script, not on import.
if __name__ == "__main__":
    main()