#!/usr/bin/env python3
"""Extract aligned sample rows from the FinePhrase dataset for the blog data explorer widget.

Streams through HuggingFaceFW/finephrase and collects aligned samples where the
same source document has outputs for all four prompt configs (faq, math, table,
tutorial). Stops once 1000 aligned samples are found.

Usage:
    python app/scripts/extract_finephrase_samples.py
"""

import json
import logging
from pathlib import Path

from datasets import load_dataset

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

REPO_ID = "HuggingFaceFW/finephrase"
PROMPTS = ["faq", "math", "table", "tutorial"]
TARGET_SAMPLES = 1000
OUTPUT_PATH = Path(__file__).parent.parent / "public" / "data" / "finephrase-samples.jsonl"


def extract_samples() -> list[dict]:
    """Stream all prompts and collect aligned samples until we reach TARGET_SAMPLES.

    Returns:
        A list of dicts, one per aligned document, each carrying the source
        metadata/text plus one key per prompt config with that prompt's first
        rollout text ("" when the rollout list is empty).
    """
    # Accumulate rows by document id for each prompt.
    prompt_data: dict[str, dict[str, dict]] = {p: {} for p in PROMPTS}

    # Track which doc ids are complete (present in all 4 prompts).
    # complete_ids preserves discovery order; complete_set gives O(1) membership.
    complete_ids: list[str] = []
    complete_set: set[str] = set()

    # Stream all prompts in parallel via iterators.
    iterators = {}
    for prompt in PROMPTS:
        logger.info("Opening stream for %s...", prompt)
        ds = load_dataset(REPO_ID, name=prompt, split="train", streaming=True)
        iterators[prompt] = iter(ds)

    batch_size = 500
    while len(complete_ids) < TARGET_SAMPLES:
        # Fetch a batch from each prompt. any_progress guards against an
        # infinite loop once every stream is exhausted.
        any_progress = False
        for prompt in PROMPTS:
            it = iterators[prompt]
            fetched = 0
            for _ in range(batch_size):
                try:
                    row = next(it)
                except StopIteration:
                    break
                fetched += 1
                doc_id = row["id"]
                prompt_data[prompt][doc_id] = row
                # Check if this doc_id is now complete across all prompts.
                if doc_id not in complete_set and all(
                    doc_id in prompt_data[p] for p in PROMPTS
                ):
                    complete_set.add(doc_id)
                    complete_ids.append(doc_id)
                    if len(complete_ids) >= TARGET_SAMPLES:
                        break
            if fetched > 0:
                any_progress = True
            logger.info(
                "  %s: %d rows loaded, %d aligned so far",
                prompt,
                len(prompt_data[prompt]),
                len(complete_ids),
            )
            if len(complete_ids) >= TARGET_SAMPLES:
                break
        if not any_progress:
            # NOTE: a single exhausted stream is fine (other streams can still
            # complete doc ids already buffered for it); we only stop when no
            # stream yielded anything this round.
            logger.warning("All streams exhausted before reaching target")
            break

    logger.info("Found %d aligned documents", len(complete_ids))

    samples: list[dict] = []
    for doc_id in complete_ids[:TARGET_SAMPLES]:
        # Source metadata is taken from the first prompt's row; all prompts
        # share the same underlying document.
        ref_row = prompt_data[PROMPTS[0]][doc_id]
        entry: dict = {
            "id": doc_id,
            "url": ref_row.get("url", ""),
            "file_path": ref_row.get("file_path", ""),
            "source": str(ref_row.get("text", "")),
        }
        for prompt in PROMPTS:
            rollout = prompt_data[prompt][doc_id].get("rollout_results", [])
            entry[prompt] = str(rollout[0]["text"]) if rollout else ""
        samples.append(entry)

    logger.info("Built %d samples", len(samples))
    return samples


def main() -> None:
    """Extract aligned samples and write them as JSONL to OUTPUT_PATH."""
    samples = extract_samples()
    OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    # encoding="utf-8" is required: ensure_ascii=False emits raw non-ASCII
    # characters, which would break under a non-UTF-8 locale default.
    with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
        for sample in samples:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")
    logger.info("Saved %d samples to %s", len(samples), OUTPUT_PATH)


if __name__ == "__main__":
    main()