# finephrase/app/scripts/extract_finephrase_samples.py
# Commit 77f7fc5 by joelniklaus (HF Staff):
# "added finephrase section and moved progress monitoring section there from infra"
#!/usr/bin/env python3
"""Extract aligned sample rows from the FinePhrase dataset for the blog data explorer widget.
Streams through HuggingFaceFW/finephrase and collects aligned samples where the
same source document has outputs for all four prompt configs (faq, math, table, tutorial).
Stops once 1000 aligned samples are found.
Usage:
python app/scripts/extract_finephrase_samples.py
"""
import json
import logging
from pathlib import Path
from datasets import load_dataset
# Module-level logger with plain console output, suitable for a one-shot script.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
# Source dataset on the Hugging Face Hub; each prompt config is a separate subset.
REPO_ID = "HuggingFaceFW/finephrase"
# The four prompt configs that must all have a row for a document to count as "aligned".
PROMPTS = ["faq", "math", "table", "tutorial"]
# Stop streaming once this many aligned documents have been collected.
TARGET_SAMPLES = 1000
# JSONL output for the blog data explorer widget; script lives in app/scripts/,
# so parent.parent resolves to app/ and the file lands in app/public/data/.
OUTPUT_PATH = Path(__file__).parent.parent / "public" / "data" / "finephrase-samples.jsonl"
def _open_streams() -> dict:
    """Open a streaming iterator over the train split of every prompt config.

    Returns a mapping from prompt name to a row iterator. Streaming avoids
    downloading the full dataset just to take a 1000-document prefix.
    """
    iterators = {}
    for prompt in PROMPTS:
        logger.info("Opening stream for %s...", prompt)
        ds = load_dataset(REPO_ID, name=prompt, split="train", streaming=True)
        iterators[prompt] = iter(ds)
    return iterators


def _drain_batch(prompt, iterator, prompt_data, complete_set, complete_ids, batch_size):
    """Pull up to ``batch_size`` rows from one prompt stream.

    Each row is indexed by its document id in ``prompt_data[prompt]``; any id
    now present in all prompt configs is appended to ``complete_ids`` (with
    ``complete_set`` as the O(1) membership mirror). Stops early once
    TARGET_SAMPLES aligned documents exist.

    Returns the number of rows actually fetched; 0 means the stream is exhausted.
    """
    fetched = 0
    for _ in range(batch_size):
        try:
            row = next(iterator)
        except StopIteration:
            break
        fetched += 1
        doc_id = row["id"]
        prompt_data[prompt][doc_id] = row
        # A document is "aligned" once every prompt config has a row for it.
        if doc_id not in complete_set and all(
            doc_id in prompt_data[p] for p in PROMPTS
        ):
            complete_set.add(doc_id)
            complete_ids.append(doc_id)
            if len(complete_ids) >= TARGET_SAMPLES:
                break
    return fetched


def _build_entry(doc_id: str, prompt_data: dict) -> dict:
    """Assemble one output record: source metadata plus per-prompt rollout text."""
    # Metadata (url/file_path/text) is taken from the first prompt's row;
    # presumably these fields are identical across configs for the same doc.
    ref_row = prompt_data[PROMPTS[0]][doc_id]
    entry: dict = {
        "id": doc_id,
        "url": ref_row.get("url", ""),
        "file_path": ref_row.get("file_path", ""),
        "source": str(ref_row.get("text", "")),
    }
    for prompt in PROMPTS:
        rollout = prompt_data[prompt][doc_id].get("rollout_results", [])
        # Only the first rollout is surfaced; empty string when none exist.
        entry[prompt] = str(rollout[0]["text"]) if rollout else ""
    return entry


def extract_samples() -> list[dict]:
    """Stream all prompts and collect aligned samples until we reach TARGET_SAMPLES.

    Returns a list of dicts, one per aligned document, each carrying the source
    text/metadata and one generated text per prompt config.
    """
    # Rows accumulated per prompt, keyed by document id.
    prompt_data: dict[str, dict[str, dict]] = {p: {} for p in PROMPTS}
    # Ids complete across all 4 prompts, in discovery order; the set mirrors
    # the list for fast membership tests.
    complete_ids: list[str] = []
    complete_set: set[str] = set()
    iterators = _open_streams()
    batch_size = 500
    while len(complete_ids) < TARGET_SAMPLES:
        # Round-robin a batch from each prompt so the streams stay roughly in step.
        any_progress = False
        for prompt in PROMPTS:
            fetched = _drain_batch(
                prompt,
                iterators[prompt],
                prompt_data,
                complete_set,
                complete_ids,
                batch_size,
            )
            if fetched > 0:
                any_progress = True
                logger.info(
                    "  %s: %d rows loaded, %d aligned so far",
                    prompt,
                    len(prompt_data[prompt]),
                    len(complete_ids),
                )
            if len(complete_ids) >= TARGET_SAMPLES:
                break
        if not any_progress:
            # Every stream returned nothing this round: all exhausted.
            logger.warning("All streams exhausted before reaching target")
            break
    logger.info("Found %d aligned documents", len(complete_ids))
    samples = [
        _build_entry(doc_id, prompt_data) for doc_id in complete_ids[:TARGET_SAMPLES]
    ]
    logger.info("Built %d samples", len(samples))
    return samples
def main() -> None:
    """Extract aligned samples and write them as JSON Lines to OUTPUT_PATH."""
    samples = extract_samples()
    OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    # ensure_ascii=False emits raw Unicode, so the file must be opened as UTF-8
    # explicitly — relying on the platform default encoding can raise
    # UnicodeEncodeError (e.g. cp1252 on Windows) for non-Latin source text.
    with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
        for sample in samples:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")
    logger.info("Saved %d samples to %s", len(samples), OUTPUT_PATH)


if __name__ == "__main__":
    main()